## Import packages.

In [3]:
%matplotlib inline

# import packages
import numpy as np
import time
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
from matplotlib import gridspec

import torch
from torch import nn
from torch import distributions
from torch.nn.parameter import Parameter

from sklearn import cluster, datasets, mixture
from sklearn.preprocessing import StandardScaler

In [4]:
# define plot properties
from cycler import cycler
import matplotlib.cm as cm

from matplotlib import rcParams
from matplotlib import rc
from mpl_toolkits.axes_grid1 import make_axes_locatable

def rgb(r,g,b):
    """Convert 8-bit RGB channel values (0-255) to the float tuple matplotlib expects.

    Each channel is divided by 255 so that 255 maps to exactly 1.0. A divisor of 256
    would make every colour slightly darker and could never reach full intensity.
    """
    return tuple(float(c) / 255. for c in (r, g, b))

# 8-colour cycle: four strong colours (blue, orange, green, red), then lighter shades of the same four
cb2 = [rgb(*channels) for channels in [(31,120,180), (255,127,0), (51,160,44), (227,26,28),
                                       (166,206,227), (253,191,111), (178,223,138), (251,154,153)]]

# global matplotlib defaults, applied in this order
rcParams.update({
    'figure.figsize': (9,7.5),
    # 'figure.dpi': 300,
    'lines.linewidth': 1,
    'axes.prop_cycle': cycler('color', cb2),
    'axes.facecolor': 'white',
    'axes.grid': False,
    'patch.facecolor': cb2[0],
    'patch.edgecolor': 'white',
    # 'font.family': 'Bitstream Vera Sans',
    'font.size': 23,
    'font.weight': 300,
})


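## Read data

Each row of `geology_data.txt` holds 20 input columns (10 velocities and 10 Voronoi-cell depths) followed by the predicted Love-wave velocities, one per frequency (11 frequencies).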

In [None]:
temp = np.loadtxt("geology_data.txt")

# Columns 0-19 are the inputs (velocities and depths of the Voronoi cells, 10 each);
# the remaining columns are the predicted Love-wave velocities, one per frequency.
x, y = np.split(temp, [20], axis=1)
print(temp.shape)

In [None]:
# one curve per training example: its predicted Love-wave velocities against frequency index
plt.plot(y.T, color='k', alpha=0.2)
plt.xlabel("frequency index")
plt.ylabel("Love-wave velocity")
plt.title("Training outputs");

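## Conditional RealNVP

Define the conditional RealNVP flow, restore the trained model from `../flow_final.pt`, and draw one sample from the flow for every training input.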

In [9]:
# import packages
import torch
from torch import nn
from torch import distributions
from torch.nn.parameter import Parameter
from torch.autograd import Variable


#=======================================================================================================
# define normalizing flow
class RealNVP(nn.Module):
    """Conditional RealNVP normalizing flow made of affine coupling layers.

    Coupling layer ``i`` leaves the coordinates selected by ``mask[i]`` unchanged and
    applies an element-wise affine map (log-scale ``s``, shift ``t``) to the others.
    Both ``s`` and ``t`` are the element-wise product of a network acting on the
    unchanged coordinates (``self.s[i]`` / ``self.t[i]``) and a network acting on the
    conditioning variable ``y`` (``self.s2[i]`` / ``self.t2[i]``).

    Args:
        nets: zero-argument factory for the log-scale network applied to the masked input.
        nett: zero-argument factory for the shift network applied to the masked input.
        mask: stack of masks, one row per coupling layer. Presumably 0/1: the layers
            rely on ``1 - mask`` selecting the complementary coordinates.
        prior: base distribution; only its ``log_prob`` is used here.
        nets2: zero-argument factory for the log-scale network applied to ``y``.
        nett2: zero-argument factory for the shift network applied to ``y``.
            If ``nets2`` / ``nett2`` are omitted, module-level callables with those
            names are used, which is how this class originally obtained them.

    Raises:
        TypeError: if no ``nets2`` / ``nett2`` factory is supplied or found.
    """

    def __init__(self, nets, nett, mask, prior, nets2=None, nett2=None):
        super(RealNVP, self).__init__()

        # Backward compatibility: these factories used to be read from module globals.
        nets2 = nets2 if nets2 is not None else globals().get("nets2")
        nett2 = nett2 if nett2 is not None else globals().get("nett2")
        if nets2 is None or nett2 is None:
            raise TypeError("RealNVP requires the conditioning-network factories nets2 and nett2")

        # one coupling layer per row of the `mask` argument (not a global `masks`)
        n_layers = len(mask)

        self.prior = prior
        # frozen Parameter: not trained, but registered on the module with the networks
        self.mask = nn.Parameter(mask, requires_grad=False)
        # creation order kept as t, s, s2, t2 so parameter initialisation is unchanged
        self.t = torch.nn.ModuleList([nett() for _ in range(n_layers)])
        self.s = torch.nn.ModuleList([nets() for _ in range(n_layers)])
        self.s2 = torch.nn.ModuleList([nets2() for _ in range(n_layers)])
        self.t2 = torch.nn.ModuleList([nett2() for _ in range(n_layers)])

    def _scale_shift(self, i, masked, y):
        """Return ``(s, t)`` of coupling layer ``i``, zeroed on the coordinates ``mask[i]`` keeps fixed."""
        free = 1 - self.mask[i]
        s = self.s[i](masked) * self.s2[i](y) * free
        t = self.t[i](masked) * self.t2[i](y) * free
        return s, t

    def g(self, z, y):
        """Map latent ``z`` to data space conditioned on ``y`` (generative direction).

        Applies coupling layers 0 .. L-1 in order; for 0/1 masks this inverts ``f``.
        """
        x = z
        for i in range(len(self.t)):
            x_ = x * self.mask[i]
            s, t = self._scale_shift(i, x_, y)
            x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
        return x

    def f(self, x, y):
        """Map data ``x`` to latent space conditioned on ``y`` (inference direction).

        Returns:
            ``(z, log_det_J)``: the latent tensor and, per row, the accumulated
            log-determinant of the Jacobian dz/dx (``-sum(s)`` for each layer).
        """
        log_det_J, z = x.new_zeros(x.shape[0]), x
        for i in reversed(range(len(self.t))):
            z_ = self.mask[i] * z
            s, t = self._scale_shift(i, z_, y)
            z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
            log_det_J -= s.sum(dim=1)
        return z, log_det_J

    def log_prob(self, x, y):
        """Conditional log-density log p(x | y) via the change-of-variables formula."""
        z, logp = self.f(x, y)
        return self.prior.log_prob(z) + logp

    def sample(self, z, y):
        """Transform latent draws ``z`` into samples conditioned on ``y`` (same as ``g``)."""
        return self.g(z, y)



#=======================================================================================================
# restore the trained flow, mapped onto the CPU
# NOTE(security): the checkpoint is a fully pickled nn.Module, so loading it executes pickled
# code -- only load checkpoints from a trusted source. Recent PyTorch releases default
# torch.load to weights_only=True, which refuses whole-module pickles, so opt out explicitly
# when the installed version accepts the keyword (older versions reject it).
import inspect
load_kwargs = {"map_location": lambda storage, loc: storage}
if "weights_only" in inspect.signature(torch.load).parameters:
    load_kwargs["weights_only"] = False
flow = torch.load("../flow_final.pt", **load_kwargs)
flow.eval();

#--------------------------------------------------------------------------------------------------------
# import training set: first 20 columns are the conditioning inputs, the rest the outputs
temp = np.loadtxt("geology_data.txt")
x_tr = temp[:,:20]
y_tr = temp[:,20:]

# dimension of the variable the flow transforms (= number of output columns);
# latent draws z2 below have this size too
dim_in = y_tr.shape[-1]

#=======================================================================================================
# condition on every training input
x_tr_choose = torch.from_numpy(x_tr).type(torch.FloatTensor)

# one standard-normal latent vector per conditioning row, pushed through the flow;
# no_grad avoids building an autograd graph over the whole dataset
z2 = np.random.multivariate_normal(np.zeros(dim_in), np.eye(dim_in), x_tr_choose.shape[0])
with torch.no_grad():
    y2 = flow.sample(torch.from_numpy(z2).type(torch.FloatTensor), x_tr_choose).cpu().numpy()


In [None]:
# one curve per flow sample; the x-axis indexes the output columns (one per frequency)
plt.plot(y2.T)
plt.xlabel("frequency index")
plt.ylabel("Love-wave velocity")
plt.title("Samples from the conditional flow");

In [None]:
# (number of samples, number of output columns) -- shown as the cell's rich output
y2.shape

In [None]:
# 2x2 figure: latent space on the top row, the Kiel diagram (data space) on the bottom row
plt.figure(figsize=[15,15]);

temp = np.load("real_nvp_results.npz")
print(temp["z1"].shape)

#-----------------------------------------------------------------------------------------
# the latent space: z1 is f(X), z2 is drawn from the base distribution p(z)
for position, key, title in ((221, "z1", r'$z = f(X)$'),
                             (222, "z2", r'$z \sim p(z)$')):
    z = temp[key]
    plt.subplot(position)
    plt.xlim([-4,4])
    plt.ylim([-4,4])
    plt.scatter(z[:, 0], z[:, 1], s=0.1)
    plt.title(title)

#-----------------------------------------------------------------------------------------
# the Kiel diagram: x1 is the data X, x2 is g(z); both axes are inverted
for position, key, title in ((223, "x1", r'$X \sim p(X)$ [2D]'),
                             (224, "x2", r'$X = g(z)$ [2D]')):
    x = temp[key]
    plt.subplot(position)
    plt.xlim([3000,8000])
    plt.ylim([0,5])
    plt.gca().invert_xaxis()
    plt.gca().invert_yaxis()
    plt.scatter(x[:,0], x[:,1], c='r', s=0.01)
    plt.title(title)

#===================================================================================
# lay out and save the figure
plt.tight_layout(w_pad=0.2,h_pad=0.3)
plt.savefig("Kiel_Diagram.png")



In [None]:
# import package
import seaborn as sns

# thicker lines from here on (this changes the global rcParams)
rcParams['lines.linewidth'] = 3

#=====================================================================================
# restore stellar labels: x1 holds the real APOGEE-Payne labels.
# To plot the real NVP reconstruction instead, use temp["x2"] and save the figure
# as "Reconstruction_Dwarf.png".
temp = np.load("real_nvp_results.npz")
x = temp["x1"]

#=====================================================================================
# one invisible full-figure axes carries the shared axis labels of the 4x4 grid
fig = plt.figure(figsize=[20,20]);
ax = fig.add_subplot(111)
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('none')
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)
ax.set_xlabel("[Fe/H]", fontsize=40, labelpad=30);
ax.set_ylabel("[X/Fe]", fontsize=40, labelpad=40);

#----------------------------------------------------------------------------------
# element shown in each panel; panel u1 plots column u1+3 of x against column 2
label_name = ["C", "N", "O", "Mg",
              "Al", "Si", "S", "K",
              "Ca", "Ti", "Cr", "Mn",
              "Ni", "Cu"]

# cuts shared by every panel: column 1 > 4 (presumably log g, i.e. dwarfs) and column 2 > -1
common_cut = (x[:,1] > 4) & (x[:,2] > -1.)

#----------------------------------------------------------------------------------
for u1, element in enumerate(label_name):
    ax = fig.add_subplot(4, 4, u1+1)
    ax.tick_params(axis='x', pad=10);

    # plotting range, panel label and at most ~4 ticks per axis
    ax.set_xlim([-1,0.4])
    ax.set_ylim([-0.4,0.5])
    ax.text(-0.9, 0.35, element, color="white", fontsize=30)
    ax.locator_params(nbins=4)

    # y tick labels only on the leftmost column; x tick labels only on the last
    # four panels (with 14 panels, those with no panel beneath them)
    if u1 % 4:
        plt.setp(ax.get_yticklabels(), visible=False)
    if u1 < len(label_name) - 4:
        plt.setp(ax.get_xticklabels(), visible=False)

    # density of [X/Fe] against [Fe/H], dropping |[X/Fe]| >= 0.5
    choose = common_cut & (np.abs(x[:,u1+3]) < 0.5)
    plt.hexbin(x[:,2][choose], x[:,u1+3][choose],
               gridsize=50, cmap="viridis")

#===================================================================================
# lay out and save the figure
plt.tight_layout(w_pad=0.2,h_pad=0.3)
plt.savefig("Real_Dwarf.png")
