In [4]:
from scipy.io import savemat

import pandas as pd
import numpy as np
import pickle as pkl
from sklearn.preprocessing import StandardScaler

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

import tqdm
import seaborn as sns

from models import MLP, TandemNet, cVAE, cGAN, INN
from utils import evaluate_simple_inverse, evaluate_tandem_accuracy, evaluate_vae_inverse, evaluate_gan_inverse, evaluate_inn_inverse
from configs import get_configs
from plotting_utils import compare_cie_dist, compare_param_dist, plot_cie, plot_cie_raw_pred
from datasets import get_dataloaders, SiliconColor

from sklearn.metrics import r2_score

# Device used for every model in this notebook. Forced to the CPU; use
# 'cuda' if torch.cuda.is_available() else 'cpu' to run on a GPU instead.
DEVICE = 'cpu'

# Seed the RNGs so the stochastic inverse models (VAE / GAN / INN sampling) and
# any random split or shuffling inside get_dataloaders are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

train_loader, val_loader, test_loader = get_dataloaders('tandem_net')


### Define a function that checks whether a predicted structure is valid

In [46]:
def struc_check(structure):
    """Check whether a predicted structure has physically valid parameters.

    The parameters are read as rows of 4 values (reshaped to ``(-1, 4)``).
    A structure is rejected when:
      * any parameter is negative, or
      * for any row, column 1 + column 3 >= column 2
        (gap + diameter >= period).

    Args:
        structure: array-like of numbers (list, tuple or numpy array). Its
            size must be a multiple of 4 once the negative check has passed.

    Returns:
        int: 1 if the structure is valid, 0 otherwise.

    Raises:
        ValueError: if no parameter is negative but the number of elements
            is not a multiple of 4.
    """
    params = np.asarray(structure, dtype=float)
    # Negative parameters are checked element-wise rather than via
    # sum(abs(x) - x), which becomes NaN (and so never > 0) when +inf is present.
    if np.any(params < 0):
        return 0
    struc = params.reshape(-1, 4)
    # Geometric constraint, checked for all rows at once.
    if np.any(struc[:, 1] + struc[:, 3] >= struc[:, 2]):
        return 0
    return 1

### Load the direct inverse model and save its predictions

In [31]:
# Load the pretrained forward model MLP(4, 3) and direct inverse model MLP(3, 4)
# (presumably 4 structure parameters <-> 3 CIE values -- confirm in models.py).
forward_model = MLP(4, 3).to(DEVICE)
forward_model.load_state_dict(torch.load('./models/forward_model_trained.pth', map_location=DEVICE)['model_state_dict'])
inverse_model = MLP(3, 4).to(DEVICE)
inverse_model.load_state_dict(torch.load('./models/inverse_model_trained.pth', map_location=DEVICE)['model_state_dict'])

cie_raw, param_raw, cie_pred, param_pred = evaluate_simple_inverse(forward_model, inverse_model, test_loader.dataset)
# Number of test samples.
M = np.shape(cie_raw)[0]

# Save the direct-inverse predictions together with the test data they were made for.
# Forward slashes keep the path portable: a backslash here is an invalid escape
# sequence and only acts as a directory separator on Windows.
mdic = {"param_pred": param_pred, "param_test": param_raw, "CIE_x": cie_raw}
savemat("data_predicted/param_inverse_pred.mat", mdic)

# Save the test data on its own; every model in this notebook is evaluated on
# test_loader.dataset.
mdic = {"param_test": param_raw, "CIE_x": cie_raw}
savemat("data_predicted/data_testing.mat", mdic)


Simple net Design RMSE loss 1.659
Simple net RMSE loss 111.462
Reconstruct RMSE loss 0.612
Reconstruct RMSE loss raw 0.076


### Load the tandem network (fixed decoder) and save its predictions

In [35]:
# Build the tandem network from the forward and inverse models and load its trained weights.
# NOTE(review): if TandemNet keeps references to these modules rather than copies,
# loading this checkpoint also overwrites the weights of `forward_model` and
# `inverse_model`. `forward_model` is reused by later cells. Confirm, and pass
# copy.deepcopy(...) of each model if the original weights must be preserved.
tandem_model = TandemNet(forward_model, inverse_model)
tandem_model.load_state_dict(torch.load('./models/tandem_net_trained.pth', map_location=DEVICE)['model_state_dict'])
cie_raw, param_raw, cie_pred, param_pred = evaluate_tandem_accuracy(tandem_model, test_loader.dataset)

# Save the tandem predictions together with the test data they were made for
# (forward slashes keep the path portable across operating systems).
mdic = {"param_pred": param_pred, "param_test": param_raw, "CIE_x": cie_raw}
savemat("data_predicted/param_tandem_pred.mat", mdic)


Tandem net Design RMSE loss 2.273
Tandem Design RMSE loss 153.133
Reconstruct RMSE loss 0.203
Reconstruct RMSE loss raw 0.025


### Load the conditional VAE (cVAE) and save its predictions

In [42]:
# Load the conditional VAE and draw several independent (stochastic) inverse designs.
configs = get_configs('vae')
vae_model = cVAE(configs['input_dim'], configs['latent_dim']).to(DEVICE)
vae_model.load_state_dict(torch.load('./models/vae_trained.pth', map_location=DEVICE)['model_state_dict'])

N_VAE_SAMPLES = 5  # number of independent VAE design samples to draw
vae_samples = []
for _ in range(N_VAE_SAMPLES):
    cie_raw, param_raw, cie_pred, sample_pred = evaluate_vae_inverse(forward_model, vae_model, configs, test_loader.dataset)
    vae_samples.append(sample_pred)
# Stack the samples column-wise. Assuming each sample has 4 columns
# (as the original 4*5 preallocation did), columns 4*k .. 4*k+3 hold sample k.
# Building the array from the samples avoids relying on a sample count `M`
# defined in another cell.
param_pred = np.hstack(vae_samples)

# Save the predicted data (the test inputs are saved separately in data_testing.mat).
# Forward slashes keep the path portable across operating systems.
mdic = {"param_pred": param_pred}
savemat("data_predicted/param_vae_pred.mat", mdic)
print(param_pred[1, :])



Simple net Design RMSE loss 1.977
Simple net RMSE loss 127.359
Reconstruct RMSE loss 0.846
Reconstruct RMSE loss raw 0.107
Simple net Design RMSE loss 1.971
Simple net RMSE loss 126.434
Reconstruct RMSE loss 0.851
Reconstruct RMSE loss raw 0.108
Simple net Design RMSE loss 1.970
Simple net RMSE loss 127.684
Reconstruct RMSE loss 0.848
Reconstruct RMSE loss raw 0.108
Simple net Design RMSE loss 1.972
Simple net RMSE loss 127.008
Reconstruct RMSE loss 0.854
Reconstruct RMSE loss raw 0.108
Simple net Design RMSE loss 1.988
Simple net RMSE loss 126.845
Reconstruct RMSE loss 0.845
Reconstruct RMSE loss raw 0.107
[117.99251203 241.46220995 576.95083995 149.44528419 130.49229993
 284.9304042  596.87736339 143.02588886 111.64615283 240.64331808
 583.27546121 150.71603313 131.50109387 251.84657023 580.34361379
 143.81422012 114.71137796 221.41466188 536.07680277 152.13184117]


### Load the conditional GAN (cGAN) and save its predictions


In [50]:
# Load the conditional GAN and generate inverse designs for the test set.
configs = get_configs('gan')
# hidden_dim must match the architecture stored in gan_trained.pth
# (load_state_dict is strict by default and would fail otherwise).
cgan = cGAN(3, 4, noise_dim=configs['noise_dim'], hidden_dim=128).to(DEVICE)
cgan.load_state_dict(torch.load('./models/gan_trained.pth', map_location=DEVICE)['model_state_dict'])

cie_raw, param_raw, cie_pred, param_pred = evaluate_gan_inverse(forward_model, cgan, configs, test_loader.dataset)

# Save the predicted data (forward slashes keep the path portable).
mdic = {"param_pred": param_pred}
savemat("data_predicted/param_gan_pred.mat", mdic)


Simple net Design RMSE loss 2.851
Simple net RMSE loss 183.594
Reconstruct RMSE loss 4.276
Reconstruct RMSE loss raw 2.079


### Load the invertible neural network (INN) and save its predictions

In [12]:
# Load the invertible neural network (INN) and generate inverse designs for the test set.
configs = get_configs('inn')
model = INN(configs['ndim_total'], configs['input_dim'], configs['output_dim'], dim_z=configs['latent_dim']).to(DEVICE)
# strict=False tolerates key mismatches between checkpoint and model; report any
# mismatch instead of ignoring it silently, so a wrong checkpoint gets noticed.
load_result = model.load_state_dict(torch.load('./models/inn_trained.pth', map_location=DEVICE)['model_state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('INN checkpoint key mismatch:', load_result)

# Reuse test_loader from the first cell rather than calling get_dataloaders again.
# A second call could produce a different split if the split is random, and then
# the INN would be evaluated on a different test set than the one in data_testing.mat.
cie_raw, param_raw, cie_pred, param_pred = evaluate_inn_inverse(forward_model, model, configs, test_loader.dataset)

# Save the predicted data (forward slashes keep the path portable).
mdic = {"param_pred": param_pred}
savemat("data_predicted/param_inn_pred.mat", mdic)
print(param_pred)

# Draw a second, independent set of INN designs to show sampling variability.
# It is kept under its own name so `param_pred` still matches the saved file.
_, _, _, param_pred_resampled = evaluate_inn_inverse(forward_model, model, configs, test_loader.dataset)
print(param_pred_resampled)



Simple net Design RMSE loss 1.822
Simple net RMSE loss 117.517
Reconstruct RMSE loss 0.576
Reconstruct RMSE loss raw 0.072
[[191.96848636 203.88141006 477.56673329  92.08189938]
 [115.12096164 248.59299629 583.07927293 152.19496664]
 [ 40.40140844 206.99763644 409.19576999  96.98265555]
 ...
 [ 76.29257146 304.47152441 480.77474021  78.97236482]
 [ 58.08068083 256.07671342 540.75013687 151.31030206]
 [171.16747552 249.17232783 688.30404958 102.63288651]]
Simple net Design RMSE loss 1.877
Simple net RMSE loss 122.762
Reconstruct RMSE loss 0.568
Reconstruct RMSE loss raw 0.071
[[190.27730387 185.1183781  369.71112716  86.56274261]
 [123.72582126 274.2716442  596.42329906 148.16740826]
 [ 54.84954884 227.41168913 499.36919222 100.7150707 ]
 ...
 [133.25130511 246.22498883 475.11334253  85.30973684]
 [ 51.67697672 220.74685232 561.74602552 151.84664993]
 [153.20832597 307.19890767 687.02560922 107.68840573]]
