In [1]:
import scipy.io as sio
import numpy as np
from sklearn.model_selection import train_test_split
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.utils.data import DataLoader,Dataset
from matplotlib import pyplot as plt
import pyro
import pyro.distributions as dist
from pyro.distributions.transforms import spline_autoregressive, conditional_spline_autoregressive

data_location = r'/home/taymaz/Documents/Project_VAE/material_vae/data/MP_v2.mat'

# Pick the compute device. The default tensor type is switched along with it, so tensors and
# module parameters created later without an explicit device are allocated there.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    torch.set_default_tensor_type("torch.FloatTensor")
else:
    device = torch.device("cuda")
    torch.set_default_tensor_type("torch.cuda.FloatTensor")
    torch.cuda.empty_cache()
    

In [2]:
# define prediction model
class ResidualBlock(nn.Module):
    """
    Pre-activation residual block for 1-dim (feature-vector) inputs.

    Computes  inputs + Linear_2(Dropout(ReLU(BN_2(Linear_1(ReLU(BN_1(inputs)))))))
    where the BN steps are skipped when ``batch_norm`` is False.

    Args:
        dim: feature width; input and output widths match so the skip connection adds directly.
        dropout: drop probability applied right before the second linear layer.
        zero_initialization: if True, the second linear layer's weight and bias are drawn from
            U(-1e-3, 1e-3), so the block starts close to the identity map.
        batch_norm: whether to apply BatchNorm1d (eps=1e-3) before each ReLU.
    """

    def __init__(self, dim, dropout=0.65, zero_initialization=True, batch_norm=True):
        super().__init__()
        self.batch_norm = batch_norm
        self.linear_layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(2))
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()
        if zero_initialization:
            last_linear = self.linear_layers[-1]
            for param in (last_linear.weight, last_linear.bias):
                torch.nn.init.uniform_(param, -1e-3, 1e-3)
        if batch_norm:
            self.batch_norm_layers = nn.ModuleList(nn.BatchNorm1d(dim, eps=1e-3) for _ in range(2))

    def forward(self, inputs):
        hidden = inputs
        for stage, linear in enumerate(self.linear_layers):
            if self.batch_norm:
                hidden = self.batch_norm_layers[stage](hidden)
            hidden = self.relu(hidden)
            # dropout only guards the second (output) projection
            if stage == 1:
                hidden = self.dropout(hidden)
            hidden = linear(hidden)
        return inputs + hidden

    
class ResidualNet(nn.Module):
    """
    A general-purpose residual network for 1-dim inputs.
    Option to be used as a Gaussian encoder network (or mixture of Gaussian encoder network).

    Args:
        in_dim: size of the input feature dimension.
        out_dim: size of the output (or of each Gaussian's mean / log-variance when encoding).
        hidden_dim: width of the residual blocks.
        num_blocks: number of ResidualBlock layers.
        dropout, batch_norm: forwarded to every ResidualBlock.
        gauss_encoder: if True, forward returns Gaussian parameters instead of a single output.
        gauss_mix: together with gauss_encoder, return the parameters of a mixture of
            ``num_gauss`` Gaussians.
        num_gauss: number of mixture components.

    Forward returns, for a (batch, in_dim) input (assumed shape — BatchNorm1d in the blocks
    expects 2-D input):
        plain:          outputs of shape (batch, out_dim)
        gauss_encoder:  (mu, logvar), each (batch, out_dim)
        gauss_mix:      (mus, logvars, weights); mus / logvars are (num_gauss, batch, out_dim)
                        and weights is (num_gauss, batch), softmax-normalised over components.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, num_blocks=2, dropout=0.65, batch_norm=True, gauss_encoder=False, gauss_mix=False, num_gauss = 14):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.initial_layer = nn.Linear(in_dim, hidden_dim)
        self.blocks = nn.ModuleList([ResidualBlock(dim=hidden_dim, dropout=dropout, batch_norm=batch_norm) for _ in range(num_blocks)])
        self.final_layer = nn.Linear(hidden_dim, out_dim)
        # components are stacked along dim 0, so normalise mixture weights over dim 0
        self.softmax = nn.Softmax(dim=0)
        self.encode = gauss_encoder
        self.gaussian_mix = gauss_mix
        self.num_gauss = num_gauss
        # if using ResNet as a VAE encoder
        if self.encode and not self.gaussian_mix:
            self.final_layer_loc = nn.Linear(hidden_dim, out_dim)
            self.final_layer_scale = nn.Linear(hidden_dim, out_dim)
        if self.encode and self.gaussian_mix:
            self.final_layers_loc = nn.ModuleList([nn.Linear(hidden_dim, out_dim) for _ in range(self.num_gauss)])
            self.final_layers_scale = nn.ModuleList([nn.Linear(hidden_dim, out_dim) for _ in range(self.num_gauss)])
            self.final_layers_weight = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(self.num_gauss)])

    def forward(self, inputs):
        temps = self.initial_layer(inputs)
        for block in self.blocks:
            temps = block(temps)
        if self.encode and not self.gaussian_mix:
            mu = self.final_layer_loc(temps)
            logvar = self.final_layer_scale(temps)
            return mu, logvar
        elif self.encode and self.gaussian_mix:
            mus = torch.stack([layer(temps) for layer in self.final_layers_loc])
            logvars = torch.stack([layer(temps) for layer in self.final_layers_scale])
            # (num_gauss, batch, 1) -> (num_gauss, batch). Only the trailing singleton axis is
            # removed: a bare squeeze() would also drop the batch axis for batch size 1 (or the
            # component axis for num_gauss == 1), breaking the softmax / downstream permute.
            weights = torch.stack([layer(temps) for layer in self.final_layers_weight]).squeeze(-1)
            return mus, logvars, self.softmax(weights)
        else:
            return self.final_layer(temps)

    
class FlowVAE(nn.Module):
    """
    Conditional (Autoregressive) Spline Flow based VAE with ResNet encoder/decoder.
    Two hierarchies: structural and energetic.
    Option to use a mixture of Gaussians as prior.

    The encoder outputs Gaussian parameters (mu, logvar), or a mixture of them when
    ``gauss_mix`` is True. The base distribution is pushed first through ``num_flows`` conditional
    spline flows conditioned on the structural context, and then through ``num_flows``
    conditional spline flows conditioned on the energetic context. The resulting latent sample is
    decoded back to input space.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_flows, gauss_mix=False, num_gauss=14):
        super(FlowVAE, self).__init__()

        self.input = input_dim
        self.hidden = hidden_dim
        self.latent = latent_dim
        self.gaussian_mixture_prior = gauss_mix
        self.resnet_encoder = ResidualNet(input_dim, latent_dim, hidden_dim, gauss_encoder=True, gauss_mix=gauss_mix, num_gauss=num_gauss)
        self.resnet_decoder = ResidualNet(latent_dim, input_dim, hidden_dim)
        self.numflow = num_flows
        # each flow takes a 1-dim context (context_dim=1)
        self.flow_structural = [conditional_spline_autoregressive(self.latent, context_dim=1) for _ in range(self.numflow)]
        self.flow_energetic = [conditional_spline_autoregressive(self.latent, context_dim=1) for _ in range(self.numflow)]
        # register the flows so their parameters are trained / moved with the model
        self.flow_modules = nn.ModuleList(self.flow_structural + self.flow_energetic)

    def encode(self, x):
        if self.gaussian_mixture_prior:
            mus, logvars, weights = self.resnet_encoder(x)
            return mus, logvars, weights
        else:
            mu, logvar = self.resnet_encoder(x)
            return mu, logvar

    def decode(self, z):
        out = self.resnet_decoder(z)
        return out

    @staticmethod
    def _sample(distribution):
        """Draw a reparameterised sample when the distribution supports it, else a plain sample.

        ``sample()`` on a transformed distribution runs entirely under ``torch.no_grad``, so no
        gradient would reach the encoder or the flows. ``MixtureSameFamily`` has no ``rsample``.
        NOTE(review): with the mixture prior, the reconstruction loss therefore still cannot
        train the flows or encoder; only the KLD term trains the encoder there.
        """
        if distribution.has_rsample:
            return distribution.rsample()
        return distribution.sample()

    def forward(self, x, context_structure, context_energy):
        """Encode x, sample through both flow hierarchies and decode.

        Args:
            x: input batch for the encoder.
            context_structure: structural context, one value per sample (context_dim=1).
            context_energy: energetic context, one value per sample (context_dim=1).

        Returns:
            (reconstruction, mu, logvar, z_energetic)
        """
        if self.gaussian_mixture_prior:
            mu, logvar, weights = self.encode(x)
            # encoder stacks components on dim 0; move them behind the batch axis
            mixture = dist.Categorical(weights.permute(1, 0))
            # logvar is a log-variance (see the KLD term), so the Normal scale is exp(logvar / 2)
            std = torch.exp(0.5 * logvar)
            component = dist.Independent(dist.Normal(mu.permute(1, 0, 2), std.permute(1, 0, 2)), 1)
            prior = dist.MixtureSameFamily(mixture, component)
        else:
            mu, logvar = self.encode(x)
            prior = dist.Normal(mu, torch.exp(0.5 * logvar))

        # the flows' conditioning networks need floating-point contexts
        context_structure = context_structure.float()
        context_energy = context_energy.float()

        # Condition each hierarchy on its own context: the energetic flows act on the already
        # structurally-conditioned distribution (conditioning the nested distribution with the
        # energy context would also feed that context into the structural flows).
        structural_embed = dist.ConditionalTransformedDistribution(prior, self.flow_structural).condition(context_structure)
        energetic_embed = dist.ConditionalTransformedDistribution(structural_embed, self.flow_energetic).condition(context_energy)
        with pyro.plate("xrd", x.shape[0]):
            z_energetic = self._sample(energetic_embed)

        return self.decode(z_energetic), mu, logvar, z_energetic

In [3]:
# VAE loss: weighted reconstruction MSE plus the KL divergence of N(mu, exp(logvar)) from N(0, I)
def simplevae_elbo_loss_function_with_energy(recon_x, x, mu, logvar):
    """Return (0.1 * MSE + 0.01 * KLD, MSE, KLD).

    MSE is the mean squared reconstruction error; KLD is summed over every element of mu/logvar,
    where logvar is treated as a log-variance.
    """
    criterion = nn.MSELoss(reduction='mean')
    mse = criterion(recon_x, x)
    kld = -0.5 * (1 + logvar - mu ** 2 - torch.exp(logvar)).sum()
    total = 0.1 * mse + 0.01 * kld
    return total, mse, kld

def pred_loss_CE(tar, tar_pred):
    """Mean cross-entropy between logits ``tar_pred`` (batch, classes) and class indices ``tar``.

    ``tar`` may be (batch,) or a (batch, 1) column. Only the trailing singleton axis is removed.
    A bare ``squeeze()`` would turn a batch of one into a 0-dim target and crash the loss.
    """
    loss = nn.CrossEntropyLoss()
    target = tar.squeeze(-1) if tar.dim() > 1 and tar.shape[-1] == 1 else tar
    CE_tar = loss(tar_pred, target)
    return CE_tar

def pred_loss_MSE(tar, tar_pred):
    """Mean squared error between the prediction ``tar_pred`` and the target ``tar``."""
    criterion = nn.MSELoss(reduction='mean')
    return criterion(tar_pred, tar)

def accuracy(tar, tar_pred):
    """Count (as a 0-dim tensor) how many argmax predictions of ``tar_pred`` equal ``tar``.

    ``tar`` may be (batch,) or (batch, 1). Both sides are flattened so a 1-D target cannot
    broadcast against a column of predictions into a (batch, batch) comparison.
    """
    predicted = torch.argmax(tar_pred, dim=1)
    acc = torch.sum(predicted == tar.reshape(-1))
    return acc

In [4]:
# define auxiliary dataset class
class ndarrayDataset(Dataset):
    """Simple in-memory dataset over numpy arrays.

    Each item is ``(X[idx], y_structure[idx], y_energy[idx])``, all as float32 tensors.
    The structural label is kept as float; consumers cast it to ``long`` where a class
    index is needed. The energy is a continuous target, so it must stay floating point:
    casting it to ``long`` would truncate values toward zero.
    """

    def __init__(self, X, y_structure, y_energy):
        super(ndarrayDataset, self).__init__()
        self.X = torch.from_numpy(X).float()
        self.y_structure = torch.from_numpy(y_structure).float()
        self.y_energy = torch.from_numpy(y_energy).float()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y_structure[idx], self.y_energy[idx]

In [5]:
# load data (version 2): the .mat file stores all materials in one matrix under the key 'MP'
data = sio.loadmat(data_location)
input_mat = data['MP']

# columns 0 and 1: id and atom type
id = input_mat[:, 0]
atom_type = input_mat[:, 1]

# columns 2..3601: the 3600 input features (training data)
X = input_mat[:, 2:3602]

# columns 3602 onward: target values
targets = input_mat[:, 3602:]
# subtract 1 before the int cast (presumably 1-based space-group numbers -> 0-based index)
space_group = (input_mat[:, 3602] - 1).astype(int)
band_gap, energy, mag_moment, energy_above_hull = (input_mat[:, col] for col in range(3603, 3607))

In [6]:
def _space_group_to_bravais_table():
    """Lookup table: 0-based space-group index (0..229) -> lattice class 0..13."""
    # entry k lists the inclusive (first, last) space-group index ranges belonging to class k
    class_ranges = [
        [(0, 1)],
        [(2, 3), (5, 6), (9, 10), (12, 13)],
        [(4, 4), (7, 8), (11, 11), (14, 14)],
        [(15, 18), (24, 33), (46, 61)],
        [(19, 20), (34, 40), (62, 67)],
        [(21, 21), (41, 42), (68, 69)],
        [(22, 23), (43, 45), (70, 73)],
        [(74, 77), (80, 80), (82, 85), (88, 95), (98, 105), (110, 117), (122, 137)],
        [(78, 79), (81, 81), (86, 87), (96, 97), (106, 109), (118, 121), (138, 141)],
        [(142, 166)],
        [(167, 193)],
        [(194, 194), (197, 197), (199, 200), (204, 204), (206, 207), (211, 212), (214, 214), (217, 217), (220, 223)],
        [(195, 195), (201, 202), (208, 209), (215, 215), (218, 218), (224, 227)],
        [(196, 196), (198, 198), (203, 203), (205, 205), (210, 210), (213, 213), (216, 216), (219, 219), (228, 229)],
    ]
    table = np.arange(230)  # every slot is overwritten below; the ranges cover 0..229 exactly
    for lattice_class, ranges in enumerate(class_ranges):
        for first, last in ranges:
            table[first:last + 1] = lattice_class
    return table


# Map every space-group index to its lattice class; values outside 0..229 are left unchanged.
_sg_table = _space_group_to_bravais_table()
bravais = np.copy(space_group)
_in_table = (space_group >= 0) & (space_group < len(_sg_table))
bravais[_in_table] = _sg_table[space_group[_in_table]]

In [7]:
X_train, X_test, energy_train, energy_test, bravais_train, bravais_test = train_test_split(X, energy, bravais, test_size=0.40, shuffle=True, random_state=9)



X_train, X_test, tar_train, tar_test = train_test_split(X, bravais, test_size=0.40, shuffle=True, random_state=9)
X_val1, X_val2, tar_val1, tar_val2 = train_test_split(X_test, tar_test, test_size=0.50, shuffle=True, random_state=9)

In [8]:
# add some parameters
class Args:
    batch_size = 256 
    epochs_pre = 20     # pretraining epochs
    epochs_vae = 1000   # VAE epochs
    epochs_pred = 500   # prediction epochs
    seed = 9            # random seed (default: 9)
    log_interval = 10   # how many batches to wait before logging training status (default 10)
    latent_dim = 15     # VAE latent space dimension
    
args=Args()
torch.manual_seed(args.seed)

train_dataset = ndarrayDataset(X_train, tar_train)
train_loader = DataLoader(train_dataset, batch_size = args.batch_size)
v1_dataset = ndarrayDataset(X_val1, tar_val1)
v1_loader = DataLoader(v1_dataset, batch_size=1000)
v1_loader2 = DataLoader(v1_dataset, batch_size=args.batch_size)
v2_dataset = ndarrayDataset(X_val2, tar_val2)
v2_loader = DataLoader(v2_dataset, batch_size=1000)



"""

Body part of pretraining

"""

structural_estimator_model = ResidualNet(3600, 14, int((3600+14)/2))
structural_estimator_optimizer = optim.Adam(structural_estimator_model.parameters(), lr=1e-5)

energetic_estimator_model = ResidualNet(3600, 1, int(3600/2))
energetic_estimator_optimizer = optim.Adam(energetic_estimator_model.parameters(), lr=1e-5)

def structural_estimator_train(epoch):
    """Train the structural (lattice-class) estimator for one epoch on train_loader.

    Returns:
        (loss of the last batch, fraction of correctly classified training samples)
    """
    structural_estimator_model.train()
    matches = 0
    # dataset items are (features, structural label, energy); the energy is not used here
    for batch_idx, (data, tar, _) in enumerate(train_loader):
        data = data.to(device)
        tar = tar.view(-1, 1).to(torch.long).to(device)
        structural_estimator_optimizer.zero_grad()
        tar_pred = structural_estimator_model(data)
        loss = pred_loss_CE(tar, tar_pred)
        loss.backward()
        structural_estimator_optimizer.step()
        matches += accuracy(tar, tar_pred).item()
    # distinct local name: assigning to `accuracy` would shadow the metric function for the
    # whole body and raise UnboundLocalError inside the loop
    acc = matches / len(train_loader.dataset)
    return loss, acc

def structural_estimator_test(epoch):
    """Evaluate the structural estimator on v1_loader.

    Returns:
        (loss of the last batch, fraction of correctly classified samples in v1_loader's dataset)
    """
    structural_estimator_model.eval()
    matches = 0
    with torch.no_grad():
        for batch_idx, (data, tar, _) in enumerate(v1_loader):
            data = data.to(device)
            tar = tar.view(-1, 1).to(torch.long).to(device)
            tar_pred = structural_estimator_model(data)
            loss = pred_loss_CE(tar, tar_pred)
            matches += accuracy(tar, tar_pred).item()
            if batch_idx % args.log_interval == 0:
                print(f'Epoch:{epoch}, pre-train val prediction loss: {loss.item()}')
    # normalise by the dataset actually iterated (v1), not the second validation split
    acc = matches / len(v1_loader.dataset)
    return loss, acc

def energetic_estimator_train(epoch):
    """Train the energetic (energy regression) estimator for one epoch on train_loader.

    Returns:
        (loss of the last batch, mean absolute error over the training set)
    The second value replaces an argmax "accuracy", which is meaningless for a one-output regressor.
    """
    energetic_estimator_model.train()
    abs_error = 0.0
    # dataset items are (features, structural label, energy); only the energy is the target
    for batch_idx, (data, _, tar) in enumerate(train_loader):
        data = data.to(device)
        # continuous target: keep it floating point (a long cast would truncate the energies)
        tar = tar.view(-1, 1).float().to(device)
        energetic_estimator_optimizer.zero_grad()
        tar_pred = energetic_estimator_model(data)
        loss = pred_loss_MSE(tar, tar_pred)
        loss.backward()
        energetic_estimator_optimizer.step()
        abs_error += torch.sum(torch.abs(tar_pred.detach() - tar)).item()
    mae = abs_error / len(train_loader.dataset)
    return loss, mae

def energetic_estimator_test(epoch):
    """Evaluate the energetic estimator on v1_loader.

    Returns:
        (loss of the last batch, mean absolute error over v1_loader's dataset)
    """
    energetic_estimator_model.eval()
    abs_error = 0.0
    with torch.no_grad():
        for batch_idx, (data, _, tar) in enumerate(v1_loader):
            data = data.to(device)
            tar = tar.view(-1, 1).float().to(device)
            tar_pred = energetic_estimator_model(data)
            loss = pred_loss_MSE(tar, tar_pred)
            abs_error += torch.sum(torch.abs(tar_pred - tar)).item()
            if batch_idx % args.log_interval == 0:
                print(f'Epoch:{epoch}, pre-train val prediction loss: {loss.item()}')
    mae = abs_error / len(v1_loader.dataset)
    return loss, mae





"""

Body part of VAE training

"""

vae_model = FlowVAE(input_dim=3600,hidden_dim=200,latent_dim=args.latent_dim,num_flows=3,gauss_mix=True,num_gauss=14).to(device)
vae_optimizer = optim.Adam(vae_model.parameters(), lr=1e-5)

def vae_train(epoch):
    """Train the flow VAE for one epoch on train_loader.

    The dataset's own structural label and energy are used as the structural and energetic
    flow contexts.

    Returns:
        (loss of the last batch, list of per-batch energetic latent samples as numpy arrays)
    """
    vae_model.train()
    latent = []
    for batch_idx, (data, tar_structure, tar_energy) in enumerate(train_loader):
        data = data.to(device)
        # contexts feed the flows' conditioning networks (context_dim=1): float column vectors
        tar_structure = tar_structure.view(-1, 1).float().to(device)
        tar_energy = tar_energy.view(-1, 1).float().to(device)
        vae_optimizer.zero_grad()
        x_pred, mu, logvar, energetic_embed = vae_model(data, tar_structure, tar_energy)
        latent.append(energetic_embed.detach().cpu().numpy())
        loss, _, _ = simplevae_elbo_loss_function_with_energy(x_pred, data, mu, logvar)
        loss.backward()
        vae_optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Total Loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    return loss, latent


def vae_test(epoch):
    """Evaluate the flow VAE on v1_loader, with the dataset labels as flow contexts.

    Returns:
        the VAE loss of the last evaluated batch
    """
    vae_model.eval()
    with torch.no_grad():
        for i, (data, tar_structure, tar_energy) in enumerate(v1_loader):
            data = data.to(device)
            tar_structure = tar_structure.view(-1, 1).float().to(device)
            tar_energy = tar_energy.view(-1, 1).float().to(device)
            x_pred, mu, logvar, _ = vae_model(data, tar_structure, tar_energy)
            loss, _, _ = simplevae_elbo_loss_function_with_energy(x_pred, data, mu, logvar)

    return loss


"""

Body part of prediction model training

"""

predictor_model = ResidualNet(args.latent_dim, 14, int((args.latent_dim+14)/2))
predictor_optimizer = optim.Adam(model2.parameters(), lr=1e-5)

def pred_train(epoch):
    """Train predictor_model for one epoch on the first validation split (v1_loader2).

    The estimators and the VAE are frozen: they run in eval mode without gradients, and their
    outputs only build the latent embedding. The structural estimator's predicted class and the
    energetic estimator's predicted energy serve as the flow contexts.
    NOTE(review): predictor_model has 14 outputs and is scored with `accuracy`, so it is trained
    here as a lattice-class classifier against the structural label. Confirm that intent.

    Returns:
        (loss of the last batch, list of per-batch embeddings as numpy arrays,
         fraction of correctly classified samples in v1_loader2's dataset)
    """
    structural_estimator_model.eval()
    energetic_estimator_model.eval()
    vae_model.eval()
    predictor_model.train()
    latent = []
    matches = 0
    for batch_idx, (data, tar_structure, _) in enumerate(v1_loader2):
        data = data.to(device)
        with torch.no_grad():
            context_structure = structural_estimator_model(data).argmax(dim=1).view(-1, 1).float()
            # the energetic estimator is a one-output regressor: use its prediction directly
            # (an argmax over a single output is always 0)
            context_energy = energetic_estimator_model(data).view(-1, 1)
            _, _, _, energetic_embedding = vae_model(data, context_structure, context_energy)
        tar_structure = tar_structure.view(-1, 1).to(torch.long).to(device)
        predictor_optimizer.zero_grad()
        pred_structure = predictor_model(energetic_embedding)
        latent.append(energetic_embedding.detach().cpu().numpy())
        loss = pred_loss_CE(tar_structure, pred_structure)
        loss.backward()
        predictor_optimizer.step()
        matches += accuracy(tar_structure, pred_structure).item()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Total Loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(v1_loader2.dataset),
                100. * batch_idx / len(v1_loader2),
                loss.item() / len(data)))
    acc = matches / len(v1_loader2.dataset)
    return loss, latent, acc
    
def pred_test(epoch):
    """Evaluate predictor_model on the second validation split (v2_loader).

    The flow contexts are built the same way as in pred_train: the predicted class from the
    structural estimator and the predicted energy from the energetic estimator.

    Returns:
        (loss of the last batch, fraction of correctly classified samples in v2_loader's dataset)
    """
    structural_estimator_model.eval()
    energetic_estimator_model.eval()
    vae_model.eval()
    predictor_model.eval()
    matches = 0
    with torch.no_grad():
        for batch_idx, (data, tar_structure, _) in enumerate(v2_loader):
            data = data.to(device)
            context_structure = structural_estimator_model(data).argmax(dim=1).view(-1, 1).float()
            context_energy = energetic_estimator_model(data).view(-1, 1)
            _, _, _, energetic_embedding = vae_model(data, context_structure, context_energy)
            tar_structure = tar_structure.view(-1, 1).to(torch.long).to(device)
            pred_structure = predictor_model(energetic_embedding)
            loss = pred_loss_CE(tar_structure, pred_structure)
            matches += accuracy(tar_structure, pred_structure).item()
    acc = matches / len(v2_loader.dataset)
    return loss, acc
    
    
    
    

# Stage 1: pretrain the structural (classifier) and energetic (regressor) estimators.
# Losses are stored detached so the per-epoch autograd graphs are not kept alive.
pre_train_loss_list = []
pre_test_loss_list = []
pretrain_accuracy = []
pretest_accuracy = []
energetic_pre_train_loss_list = []
energetic_pre_test_loss_list = []
for epoch in range(1, args.epochs_pre + 1):
    pre_train_loss, acc_train = structural_estimator_train(epoch)
    pre_test_loss, acc_val = structural_estimator_test(epoch)
    energetic_train_loss, _ = energetic_estimator_train(epoch)
    energetic_test_loss, _ = energetic_estimator_test(epoch)
    pre_train_loss_list.append(pre_train_loss.detach())
    pre_test_loss_list.append(pre_test_loss.detach())
    energetic_pre_train_loss_list.append(energetic_train_loss.detach())
    energetic_pre_test_loss_list.append(energetic_test_loss.detach())
    pretrain_accuracy.append(acc_train)
    pretest_accuracy.append(acc_val)
print('pretraining done')

# Stage 2: train the flow VAE
train_err_list = []
test_err_list = []
latent_list = []
for epoch in range(1, args.epochs_vae + 1):
    train_err, latent = vae_train(epoch)
    test_err = vae_test(epoch)
    train_err_list.append(train_err.detach())
    test_err_list.append(test_err.detach())
    latent_list.append(latent)
print('latent space training done')

# Stage 3: train the latent-space predictor
train_pred_loss_list = []
test_pred_loss_list = []
latent_list_pred = []
train_accuracy = []
test_accuracy = []
for epoch in range(1, args.epochs_pred + 1):
    train_pred_loss, latent2, acc_train = pred_train(epoch)
    test_pred_loss, acc_val = pred_test(epoch)
    train_pred_loss_list.append(train_pred_loss.detach())
    test_pred_loss_list.append(test_pred_loss.detach())
    latent_list_pred.append(latent2)
    train_accuracy.append(acc_train)
    test_accuracy.append(acc_val)


# Save the per-epoch VAE losses (converted to plain floats: np.savez cannot take CUDA tensors)
# and the trained VAE weights.
with open('MAPE_simpleVAE_%d.npz' % args.epochs_pred, 'wb') as f:
    np.savez(f, train_err=[float(e) for e in train_err_list], test_err=[float(e) for e in test_err_list])
torch.save(vae_model.state_dict(), 'model_simpleVAE_%d.pth' % args.epochs_pred)


TypeError: __init__() missing 1 required positional argument: 'y_energy'

In [None]:
# Plot the per-epoch VAE loss (0.1 * MSE + 0.01 * KLD from the last batch of each epoch).
# float() accepts both tensors (including CUDA tensors) and the numpy scalars produced by the
# save cell, so this works whichever cell ran first.
fig, ax = plt.subplots()
ax.set_title("VAE Loss")
ax.plot([float(e) for e in train_err_list], label='train')
ax.plot([float(e) for e in test_err_list], label='test')
ax.set_xlabel("Epoch")
ax.set_ylabel("0.1*MSE + 0.01*KLD")
ax.grid()
ax.legend()
plt.show()
#plt.savefig('simpleVAE_mape.png') # can be saved as .svg file

In [None]:
def _to_numpy(value):
    """Convert a (possibly CUDA, possibly grad-tracking) tensor to numpy; pass other values to np.asarray."""
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _losses_to_numpy(values):
    """Stack a list of scalar losses into a 1-D numpy array without mutating the list (re-runnable)."""
    return np.array([_to_numpy(v) for v in values])


def _epoch_latents(latents_per_epoch):
    """Concatenate each epoch's per-batch latent arrays along the batch axis and stack the epochs.

    The per-batch arrays are ragged (the last batch is usually smaller), so saving the nested
    lists directly fails in recent numpy versions.
    """
    if len(latents_per_epoch) == 0:
        return np.array([])
    return np.stack([np.concatenate(batches, axis=0) for batches in latents_per_epoch])


np.save('vae_cond_mm_train_loss', _losses_to_numpy(train_err_list))
np.save('vae_cond_mm_test_loss', _losses_to_numpy(test_err_list))
np.save('vae_cond_mm_pre_train_loss', _losses_to_numpy(pre_train_loss_list))
np.save('vae_cond_mm_pre_test_loss', _losses_to_numpy(pre_test_loss_list))
np.save('vae_cond_mm_train_pred_loss', _losses_to_numpy(train_pred_loss_list))
np.save('vae_cond_mm_test_pred_loss', _losses_to_numpy(test_pred_loss_list))

np.save('vae_cond_mm_label', _epoch_latents(latent_list))

np.save('vae_cond_mm_label_pred', _epoch_latents(latent_list_pred))

np.save('vae_cond_mm_train_acc', train_accuracy)

np.save('vae_cond_mm_test_acc', test_accuracy)

np.save('vae_cond_mm_pretrain_acc', pretrain_accuracy)

np.save('vae_cond_mm_pretest_acc', pretest_accuracy)