In [17]:
"""
Since the output of at any CVT points for HDPE is not available,
we will not carry out evaluating metrics on the generated virtual samples,
as the way that have done for magical_sinus benchmarking dataset. Instead,
we indirectly confirm the quality of the generated samples by showing the
ability of our RegCGAN on capturing P(y|x).
"""


import importlib

import dataset, metrics, plotting, config, network
from models import reg_cgan_model
import numpy as np
import random

importlib.reload(network)
importlib.reload(dataset)
importlib.reload(metrics)
importlib.reload(plotting)
importlib.reload(config)
importlib.reload(reg_cgan_model)

<module 'models.reg_cgan_model' from '/Users/zhongsheng/Documents/GitWorkspace/RegCGAN/models/reg_cgan_model.py'>

In [24]:
import os

dataset_config = config.DatasetConfig(scenario="hdpeuce")

assert(dataset_config.scenario == "magical_sinus"
      or dataset_config.scenario == "hdpeuce")
fig_dir = f"../figures/{dataset_config.scenario}"

try:
    os.mkdir(fig_dir)
    print(f"Directory {fig_dir} created ")
except FileExistsError:
    print(f"Directory {fig_dir} already exists replacing files in this notebook")

Directory ../figures/hdpeuce already exists replacing files in this notebook


In [25]:
exp_config = config.Config(
    model=config.ModelConfig(activation="elu", lr_gen=0.0001, lr_disc=0.0001,
                             optim_gen="Adam", optim_disc="Adam", z_input_size=5),
    training=config.TrainingConfig(n_epochs=10000, batch_size=100, n_sampling=200),
    dataset=dataset_config,
    run=config.RunConfig(save_fig=1)
)

In [26]:
# Set random seed
np.random.seed(exp_config.model.random_seed)
random.seed(exp_config.model.random_seed)

from tensorflow import set_random_seed
set_random_seed(exp_config.model.random_seed)

In [27]:
X_train, y_train, X_valid, y_valid = dataset.get_dataset(scenario=exp_config.dataset.scenario,
                                                                         seed=exp_config.model.random_seed)

In [28]:
from sklearn.preprocessing import StandardScaler
X_scaler = StandardScaler()
X_train_scaled = X_scaler.fit_transform(X_train)
X_valid_scaled = X_scaler.fit_transform(X_valid)

y_scaler = StandardScaler()
y_train_scaled = y_scaler.fit_transform(y_train.reshape(-1, 1))
y_valid_scaled = y_scaler.fit_transform(y_valid.reshape(-1, 1))

## Gaussian Process

In [29]:
import GPy

variance = 0.1
length = 1

kernel = GPy.kern.RBF(input_dim=15, variance=variance, lengthscale=length)
gpr = GPy.models.GPRegression(X_train_scaled, y_train_scaled, kernel)

run_hyperopt_search = True
if run_hyperopt_search:
    gpr.optimize(messages=True)

HBox(children=(VBox(children=(IntProgress(value=0, max=1000), HTML(value=''))), Box(children=(HTML(value=''),)…

## Construct CGAN model

In [None]:
regcgan = reg_cgan_model.RegCGAN(exp_config)
d_loss_err, d_loss_true, d_loss_fake, g_loss_err, g_pred, g_true = regcgan.train(X_train, y_train,
                                                                              epochs=exp_config.training.n_epochs,
                                                                              batch_size=exp_config.training.batch_size,
                                                                              verbose=True)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Generator_input_x (InputLayer)  (None, 15)           0                                            
__________________________________________________________________________________________________
Generator_input_z (InputLayer)  (None, 5)            0                                            
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 60)           960         Generator_input_x[0][0]          
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 60)           360         Generator_input_z[0][0]          
__________________________________________________________________________________________________
concatenat



Epoch: 0 / dLoss: 12.228018760681152 / gLoss: 0.28676560521125793
Epoch: 1 / dLoss: 13.788334846496582 / gLoss: 0.01779765449464321
Epoch: 2 / dLoss: 9.354936599731445 / gLoss: 1.345212459564209
Epoch: 3 / dLoss: 0.0019480367191135883 / gLoss: 14.967238426208496
Epoch: 4 / dLoss: 0.0021685969550162554 / gLoss: 15.802855491638184
Epoch: 5 / dLoss: 0.007570695132017136 / gLoss: 15.96017074584961
Epoch: 6 / dLoss: 0.019252179190516472 / gLoss: 16.118101119995117
Epoch: 7 / dLoss: 0.03875555843114853 / gLoss: 16.107845306396484
Epoch: 8 / dLoss: 0.13578028976917267 / gLoss: 16.118101119995117
Epoch: 9 / dLoss: 0.1393631547689438 / gLoss: 16.118101119995117
Epoch: 10 / dLoss: 0.1544414460659027 / gLoss: 16.118101119995117
Epoch: 11 / dLoss: 0.1897730529308319 / gLoss: 16.118101119995117
Epoch: 12 / dLoss: 0.19185158610343933 / gLoss: 16.118101119995117
Epoch: 13 / dLoss: 0.06096334010362625 / gLoss: 16.118101119995117
Epoch: 14 / dLoss: 0.07821130007505417 / gLoss: 16.118101119995117
Epoch:

Epoch: 120 / dLoss: 4.054375494888518e-06 / gLoss: 16.118101119995117
Epoch: 121 / dLoss: 8.47651608637534e-06 / gLoss: 16.118101119995117
Epoch: 122 / dLoss: 8.021126996027306e-05 / gLoss: 16.118101119995117
Epoch: 123 / dLoss: 8.012338366825134e-05 / gLoss: 16.118101119995117
Epoch: 124 / dLoss: 1.3690977539226878e-05 / gLoss: 16.118101119995117
Epoch: 125 / dLoss: 2.5149642169708386e-05 / gLoss: 16.118101119995117
Epoch: 126 / dLoss: 3.1593932362739e-05 / gLoss: 16.118101119995117
Epoch: 127 / dLoss: 4.960812293575145e-05 / gLoss: 16.118101119995117
Epoch: 128 / dLoss: 5.778201375505887e-05 / gLoss: 16.118101119995117
Epoch: 129 / dLoss: 0.00014389862190000713 / gLoss: 16.118101119995117
Epoch: 130 / dLoss: 1.2983578017156105e-05 / gLoss: 16.118101119995117
Epoch: 131 / dLoss: 5.105144009576179e-05 / gLoss: 16.118101119995117
Epoch: 132 / dLoss: 0.0001809168024919927 / gLoss: 16.118101119995117
Epoch: 133 / dLoss: 9.366601443616673e-05 / gLoss: 16.118101119995117
Epoch: 134 / dLoss:

Epoch: 238 / dLoss: 1.5037245248095132e-05 / gLoss: 16.118101119995117
Epoch: 239 / dLoss: 6.66470123178442e-06 / gLoss: 16.118101119995117
Epoch: 240 / dLoss: 2.7140202291775495e-05 / gLoss: 16.118101119995117
Epoch: 241 / dLoss: 8.948137292463798e-06 / gLoss: 16.118101119995117
Epoch: 242 / dLoss: 3.724915222846903e-05 / gLoss: 16.118101119995117
Epoch: 243 / dLoss: 1.9643995983642526e-05 / gLoss: 16.118101119995117
Epoch: 244 / dLoss: 1.1353103218425531e-05 / gLoss: 16.118101119995117
Epoch: 245 / dLoss: 8.364590030396357e-06 / gLoss: 16.118101119995117
Epoch: 246 / dLoss: 8.719669494894333e-06 / gLoss: 16.118101119995117
Epoch: 247 / dLoss: 2.717143070185557e-05 / gLoss: 16.118101119995117
Epoch: 248 / dLoss: 4.7795842874620575e-06 / gLoss: 16.118101119995117
Epoch: 249 / dLoss: 7.540217211499112e-06 / gLoss: 16.118101119995117
Epoch: 250 / dLoss: 1.4824593563389499e-05 / gLoss: 16.118101119995117
Epoch: 251 / dLoss: 1.2833375876653008e-05 / gLoss: 16.118101119995117
Epoch: 252 / d

Epoch: 355 / dLoss: 5.053816948930034e-06 / gLoss: 16.118101119995117
Epoch: 356 / dLoss: 5.80244704906363e-06 / gLoss: 16.118101119995117
Epoch: 357 / dLoss: 1.0648957868397702e-05 / gLoss: 16.118101119995117
Epoch: 358 / dLoss: 8.449427696177736e-06 / gLoss: 16.118101119995117
Epoch: 359 / dLoss: 8.350467396667227e-06 / gLoss: 16.118101119995117
Epoch: 360 / dLoss: 2.22673884309188e-06 / gLoss: 16.118101119995117
Epoch: 361 / dLoss: 9.448383934795856e-06 / gLoss: 16.118101119995117
Epoch: 362 / dLoss: 8.63412697071908e-06 / gLoss: 16.118101119995117
Epoch: 363 / dLoss: 7.221046871563885e-06 / gLoss: 16.118101119995117
Epoch: 364 / dLoss: 9.615260751161259e-06 / gLoss: 16.118101119995117
Epoch: 365 / dLoss: 6.407592081814073e-06 / gLoss: 16.118101119995117
Epoch: 366 / dLoss: 2.2207643723959336e-06 / gLoss: 16.118101119995117
Epoch: 367 / dLoss: 2.1814580577483866e-06 / gLoss: 16.118101119995117
Epoch: 368 / dLoss: 1.3159478839952499e-06 / gLoss: 16.118101119995117
Epoch: 369 / dLoss:

Epoch: 475 / dLoss: 2.513958747840661e-07 / gLoss: 16.118101119995117
Epoch: 476 / dLoss: 2.502037830254267e-07 / gLoss: 16.118101119995117
Epoch: 477 / dLoss: 3.1219269658322446e-07 / gLoss: 16.118101119995117
Epoch: 478 / dLoss: 2.788140704979014e-07 / gLoss: 16.118101119995117
Epoch: 479 / dLoss: 2.1920931203567307e-07 / gLoss: 16.118101119995117
Epoch: 480 / dLoss: 1.0584783467493253e-06 / gLoss: 16.118101119995117
Epoch: 481 / dLoss: 2.4901169126678724e-07 / gLoss: 16.118101119995117
Epoch: 482 / dLoss: 2.7642988698062254e-07 / gLoss: 16.118101119995117
Epoch: 483 / dLoss: 1.8548272464613547e-06 / gLoss: 16.118101119995117
Epoch: 484 / dLoss: 1.0239047014692915e-06 / gLoss: 16.118101119995117
Epoch: 485 / dLoss: 1.0656265203579096e-06 / gLoss: 16.118101119995117
Epoch: 486 / dLoss: 9.964838909581886e-07 / gLoss: 16.118101119995117
Epoch: 487 / dLoss: 1.0596638730930863e-06 / gLoss: 16.118101119995117
Epoch: 488 / dLoss: 9.917134775605518e-07 / gLoss: 16.118101119995117
Epoch: 489 

Epoch: 593 / dLoss: 9.094388246921881e-07 / gLoss: 16.118101119995117
Epoch: 594 / dLoss: 4.469016516850388e-07 / gLoss: 16.118101119995117
Epoch: 595 / dLoss: 4.4451746816775994e-07 / gLoss: 16.118101119995117
Epoch: 596 / dLoss: 2.2397769328108552e-07 / gLoss: 16.118101119995117
Epoch: 597 / dLoss: 4.457094746612711e-07 / gLoss: 16.118101119995117
Epoch: 598 / dLoss: 2.2397769328108552e-07 / gLoss: 16.118101119995117
Epoch: 599 / dLoss: 2.2874606031564326e-07 / gLoss: 16.118101119995117
Epoch: 600 / dLoss: 2.1920931203567307e-07 / gLoss: 16.118101119995117
Epoch: 601 / dLoss: 2.2397769328108552e-07 / gLoss: 16.118101119995117
Epoch: 602 / dLoss: 2.1920931203567307e-07 / gLoss: 16.118101119995117
Epoch: 603 / dLoss: 2.1920931203567307e-07 / gLoss: 16.118101119995117
Epoch: 604 / dLoss: 2.2397769328108552e-07 / gLoss: 16.118101119995117
Epoch: 605 / dLoss: 2.2397769328108552e-07 / gLoss: 16.118101119995117
Epoch: 606 / dLoss: 2.1920931203567307e-07 / gLoss: 16.118101119995117
Epoch: 60

Epoch: 710 / dLoss: 4.814715452994278e-07 / gLoss: 16.118101119995117
Epoch: 711 / dLoss: 3.598771627366659e-07 / gLoss: 16.118101119995117
Epoch: 712 / dLoss: 3.491482516437827e-07 / gLoss: 16.118101119995117
Epoch: 713 / dLoss: 2.2159349555295194e-07 / gLoss: 16.118101119995117
Epoch: 714 / dLoss: 2.2159349555295194e-07 / gLoss: 16.118101119995117
Epoch: 715 / dLoss: 7.389652409983682e-07 / gLoss: 16.118101119995117


In [None]:
plotting.plot_training_curve(d_loss_err, d_loss_true, d_loss_fake, g_loss_err, g_pred, g_true, fig_dir, exp_config.run.save_fig)

## Generate pairs of virtual samples

In [None]:
from os.path import basename
X_cvt = np.load(f"{fig_dir}/{basename(fig_dir)}_cvt_samples.npy")
X_cvt_scaled = X_scaler.transform(X_cvt)

In [None]:
ypred_recgan_cvt = regcgan.predict(X_cvt_scaled)
ypred_gp_cvt, cov_cvt = gpr.predict(X_cvt_scaled)

In [None]:
ypred_recgan_cvt = y_scaler.inverse_transform(ypred_recgan_cvt)
ypred_gp_cvt = y_scaler.inverse_transform(ypred_gp_cvt)

In [None]:
X_cvt_paird = np.c_[X_cvt, ypred_recgan_cvt]

In [None]:
from os.path import basename
np.save(f"{fig_dir}/{basename(fig_dir)}_CVT_samples_paired.npy", X_cvt_paird)
