In [5]:
from scipy.io import savemat

import pandas as pd
import numpy as np
import pickle as pkl
from sklearn.preprocessing import StandardScaler

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

import tqdm
import seaborn as sns

from models import MLP, TandemNet, cVAE, cGAN, INN
from utils import evaluate_simple_inverse, evaluate_tandem_accuracy, evaluate_vae_inverse, evaluate_gan_inverse, evaluate_inn_inverse
from configs import get_configs
from plotting_utils import compare_cie_dist, compare_param_dist, plot_cie, plot_cie_raw_pred
from datasets import get_dataloaders, SiliconColor

from sklearn.metrics import r2_score

# All inference in this notebook is pinned to the CPU.
# To opt into a GPU when one is present, use instead:
#     'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = 'cpu'

# One set of loaders is shared by every cell below; only the test split's
# dataset is handed to the evaluate_* helpers.
(train_loader,
 val_loader,
 test_loader) = get_dataloaders('tandem_net')


### Loading Data Predicted for Direct Inverse Training

In [61]:
# Restore the pretrained forward and inverse MLPs from their checkpoints.
# MLP(4, 3) / MLP(3, 4): presumably 4 design parameters <-> 3 CIE values — TODO confirm
# against models.MLP's signature.
# map_location follows DEVICE so the checkpoint lands on the same device as the model.
forward_model = MLP(4, 3).to(DEVICE)
forward_model.load_state_dict(
    torch.load('./models/forward_model.pth', map_location=DEVICE)['model_state_dict']
)
inverse_model = MLP(3, 4).to(DEVICE)
inverse_model.load_state_dict(
    torch.load('./models/inverse_model.pth', map_location=DEVICE)['model_state_dict']
)

# Evaluate the direct inverse model on the held-out test dataset.
cie_raw, param_raw, cie_pred, param_pred = evaluate_simple_inverse(
    forward_model, inverse_model, test_loader.dataset
)

# Saving the predicted data.
# A forward slash is used as the separator: the previous "data_predicted\param..."
# literal relied on the invalid escape "\p" and only resolved to a sub-directory on
# Windows. The data_predicted/ directory must already exist.
mdic = {"param_pred": param_pred}
savemat("data_predicted/param_inverse_pred.mat", mdic)

# Saving the testing data (the same test set is used for all models)
mdic = {"param_test": param_raw, "CIE_x": cie_raw}
savemat("testing_data.mat", mdic)

Simple net Design RMSE loss 1.683
Simple net RMSE loss 117.920
Reconstruct RMSE loss 1.175
Reconstruct RMSE loss raw 0.593


### Loading Data predicted for Tandem (fixed decoder)

In [48]:
# Wrap the already-loaded forward/inverse models in a TandemNet and restore the
# tandem checkpoint on top of them.
# NOTE(review): if TandemNet keeps these modules by reference, load_state_dict below
# also overwrites the weights of forward_model / inverse_model that other cells use —
# confirm this is intended (heading says the decoder is fixed).
tandem_model = TandemNet(forward_model, inverse_model)
tandem_model.load_state_dict(
    torch.load('./models/tandem_net.pth', map_location=DEVICE)['model_state_dict']
)
cie_raw, param_raw, cie_pred, param_pred = evaluate_tandem_accuracy(
    tandem_model, test_loader.dataset
)

# Saving the predicted data.
# Forward slash instead of "\p" (an invalid escape sequence that also only works as a
# path separator on Windows). The data_predicted/ directory must already exist.
mdic = {"param_pred": param_pred}
savemat("data_predicted/param_tandem_pred.mat", mdic)



Tandem net Design RMSE loss 3.639
Tandem Design RMSE loss 249.367
Reconstruct RMSE loss 0.306
Reconstruct RMSE loss raw 0.121


### Loading VAE

In [49]:
# Build the conditional VAE from its config entry and restore its checkpoint.
# map_location follows DEVICE so the weights land where the model lives.
configs = get_configs('vae')
vae_model = cVAE(configs['input_dim'], configs['latent_dim']).to(DEVICE)
vae_model.load_state_dict(
    torch.load('./models/vae.pth', map_location=DEVICE)['model_state_dict']
)

# forward_model is passed alongside the VAE; presumably it maps the generated designs
# back to CIE values for the reconstruction metrics — verify in utils.
cie_raw, param_raw, cie_pred, param_pred = evaluate_vae_inverse(
    forward_model, vae_model, configs, test_loader.dataset
)

# Saving the predicted data.
# Forward slash instead of "\p" (an invalid escape sequence that also only works as a
# path separator on Windows). The data_predicted/ directory must already exist.
mdic = {"param_pred": param_pred}
savemat("data_predicted/param_vae_pred.mat", mdic)



Simple net Design RMSE loss 1.955
Simple net RMSE loss 139.427
Reconstruct RMSE loss 4.398
Reconstruct RMSE loss raw 1.258


### Loading GAN


In [50]:
# Build the conditional GAN and restore its checkpoint.
# The noise dimension comes from the config; the hidden width (128) is fixed here and
# must match the value the checkpoint was trained with.
configs = get_configs('gan')
cgan = cGAN(3, 4, noise_dim=configs['noise_dim'], hidden_dim=128).to(DEVICE)
cgan.load_state_dict(
    torch.load('./models/gan.pth', map_location=DEVICE)['model_state_dict']
)

cie_raw, param_raw, cie_pred, param_pred = evaluate_gan_inverse(
    forward_model, cgan, configs, test_loader.dataset
)

# Saving the predicted data.
# Forward slash instead of "\p" (an invalid escape sequence that also only works as a
# path separator on Windows). The data_predicted/ directory must already exist.
mdic = {"param_pred": param_pred}
savemat("data_predicted/param_gan_pred.mat", mdic)


Simple net Design RMSE loss 2.851
Simple net RMSE loss 183.594
Reconstruct RMSE loss 4.276
Reconstruct RMSE loss raw 2.079


### Loading INN

In [51]:
# Build the invertible network from its config entry and restore its checkpoint.
configs = get_configs('inn')
model = INN(
    configs['ndim_total'],
    configs['input_dim'],
    configs['output_dim'],
    dim_z=configs['latent_dim'],
).to(DEVICE)
# NOTE(review): strict=False silently ignores missing/unexpected checkpoint keys —
# confirm the mismatch is expected and not a sign of a stale checkpoint.
model.load_state_dict(
    torch.load('./models/inn.pth', map_location=DEVICE)['model_state_dict'],
    strict=False,
)

# Reuse the test_loader created in the setup cell instead of calling get_dataloaders
# again: every model is then scored on the exact test set that was written to
# testing_data.mat, and the shared loader globals are not rebound mid-notebook.
cie_raw, param_raw, cie_pred, param_pred = evaluate_inn_inverse(
    forward_model, model, configs, test_loader.dataset
)

# Saving the predicted data.
# Forward slash instead of "\p" (an invalid escape sequence that also only works as a
# path separator on Windows). The data_predicted/ directory must already exist.
mdic = {"param_pred": param_pred}
savemat("data_predicted/param_inn_pred.mat", mdic)



Simple net Design RMSE loss 2.085
Simple net RMSE loss 148.667
Reconstruct RMSE loss 6.459
Reconstruct RMSE loss raw 2.243
