# REGVAE

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('To enable a high-RAM runtime, select the Runtime > "Change runtime type"')
  print('menu, and then select High-RAM in the Runtime shape dropdown. Then, ')
  print('re-execute this cell.')
else:
  print('You are using a high-RAM runtime!')

Tue Oct 26 15:45:08 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.74       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P0    28W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# ---------------------------------------------------------------------------------------------------

In [2]:
import os
import sys
import numpy as np
import json
import random
import matplotlib.pyplot as plt
%matplotlib inline

from google.colab import drive
import time
import glob

import torch
from torch.nn import functional as F
import torch.optim as optim
import torchvision
from torchvision.utils import save_image
import torch.utils.data as data_utils
import torch.distributions as dist


from sklearn.decomposition import PCA
import pandas as pd
import seaborn as sns
from sklearn.manifold import TSNE
import random
from tqdm import tqdm

from tensorflow.python.client import device_lib
device_lib.list_local_devices()

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 12381423808501878097, name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 15434776576
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 6212672082793317692
 physical_device_desc: "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0"]

In [3]:
if not os.path.isdir('drive'):
  drive.mount('drive')
else:
  print('drive already mounted')

base_path = os.path.join('drive', 'My Drive', 'Colab Notebooks', 'GEN')
if not os.path.isdir(base_path):
  os.makedirs(base_path)

drive already mounted


In [4]:
PATH = os.path.join("drive", "My Drive", "Colab Notebooks", "GEN") 

sys.path.append(PATH+'/biomatsim')
from dataset.data_loader_synthetic import ToyCell, ToyCellpair
from model_vae_synthetic import VAE
from model_qzfreg_synthetic import QZFREG
from model_main_synthetic import FDVAE, REGVAE, Discriminator
from diva.pixel_cnn_utils import sample

print(torch.cuda.is_available())

True


In [5]:
# https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
##########################
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
##########################

### Load data

In [6]:
# get topography data
dataset_name = 'pdata1_128'

if not os.path.isdir('pdata1_128'): 
  start_time = time.time()
  if not os.path.isfile('pdata1_128.zip'): 
    ! wget -O pdata1_128.zip "https://surfdrive.surf.nl/files/index.php/s/QrMeUvt9ZPIOKnP/download"
  ! unzip -q pdata1_128.zip -d .
  print("Unzipped.")
  print("Elapsed time: {} seconds.".format(time.time()-start_time))
else:
  print('Found folder pdata1_128')

Found folder pdata1_128


In [7]:
# get cell data
dataset_name = 'dataset2_128_fixloc_1scale'

if not os.path.isdir('dataset2_128_fixloc_1scale'):
  start_time = time.time()
  if not os.path.isfile('dataset2_128_fixloc_1scale.zip'): 
    ! wget -O dataset2_128_fixloc_1scale.zip "https://surfdrive.surf.nl/files/index.php/s/VuU2UKuHlBpvxgh/download"
  ! unzip -q dataset2_128_fixloc_1scale.zip -d .
  print("Unzipped.")
  print("Elapsed time: {} seconds.".format(time.time()-start_time))
else:
  print('Found folder dataset2_128_fixloc_1scale')

Found folder dataset2_128_fixloc_1scale


### Initialize FDVAE

In [8]:
model_name_fdvae = 'FDVAE_seed_1_zdims_222_2_v1'  #_epoch_20'
model_path_fdvae = PATH+'/model_output_synthetic/fdvae/' + model_name_fdvae

print(model_name_fdvae)
print(model_path_fdvae)

FDVAE_seed_1_zdims_222_2_v1
drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/fdvae/FDVAE_seed_1_zdims_222_2_v1


In [9]:
with open(model_path_fdvae + '/' + model_name_fdvae +\
          '.json', 'r') as configfile:
    args_fdvae = json.load(configfile)
args_fdvae = dotdict(args_fdvae)

args_fdvae.cuda = not args_fdvae.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args_fdvae.cuda else "cpu")
kwargs = {'num_workers': 4, 'pin_memory': False} if args_fdvae.cuda else {}

# Set seed
torch.manual_seed(1)
torch.backends.cudnn.benchmark = False
np.random.seed(args_fdvae.seed)

In [10]:
model_fdvae = torch.load(model_path_fdvae + '/' + model_name_fdvae + '_epoch_2' + '.model')

In [11]:
def nparams(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

nparams(model_fdvae)

2678167

In [12]:
# freeze fully
for name, p in model_fdvae.named_parameters():
    p.requires_grad = False

nparams(model_fdvae)

0

### Train

In [13]:
def save_reconstructions_lx1(model, l, x, path=''):
    """
    Generate a cell image based on a topography image. 8 examples.
    zeps is sampled from prior (simulation of experiment; actual application scenario).
    N.B. At training zeps is sampled from the posterior q(ze|x) (Training scheme 2Ax)
    """
    with torch.no_grad():
        x = x[:8]
        l = l[:8]
        x_recon, _, _, _, _, _, _, _, _, _ = model.forward_lx(l)
        sample_t = sample(x_recon)

        comparison = torch.cat([l, sample_t])
        save_image(comparison.cpu(),
                   path + 'reconstruction_lx1_' + str(epoch-1) + '.png', nrow=8)

def save_reconstructions_lx1_same(model, l, x, path='', n=2):
    """
    Generate 8 cell images based on a single topography image. 
    zeps is sampled from prior (simulation of experiment; actual application scenario).
    N.B. At training zeps is sampled from the posterior q(ze|x) (Training scheme 2Ax)
    """
    with torch.no_grad():
        x_rep = x[:n].repeat_interleave(8, dim=0)
        l_rep = l[:n].repeat_interleave(8, dim=0)    
    
        x_recon, _, _, _, _, _, _, _, _, _ = model.forward_lx(l_rep)
        sample_t = sample(x_recon)

        for i in range(n):
            comparison = torch.cat([l_rep[i*8:(i+1)*8], sample_t[i*8:(i+1)*8]])
            save_image(comparison.cpu(),
                    path + 'reconstruction_lx1_same_' + str(epoch-1) + f'{i}.png', nrow=8)

def save_reconstructions_l(model, l, path=''):
    """
    Topography image reconstruction. 8 examples.
    """
    with torch.no_grad():
        l = l[:8]
        l_recon, _, _, _, _, _, _, _, _, _, _, _, _ = model.forward_l(l)
        sample_t = sample(l_recon)

        comparison = torch.cat([l, sample_t])
        save_image(comparison.cpu(),
                   path + 'reconstruction_l_' + str(epoch-1) + '.png', nrow=8)


def save_reconstructions_xl(model, l, x, path=''):
    """
    Generate a topography image based on a cell image. 8 examples.
    Training scheme 2Bp - l_eps sampled from posterior q(le|p) (not used).
    """
    with torch.no_grad():
        x = x[:8]
        l = l[:8]
        l_recon2, _, _, _, _, _, _, _, _, _ = model.forward_xl(l, x)
        sample_t = sample(l_recon2)

        comparison = torch.cat([x, sample_t])
        save_image(comparison.cpu(),
                   path + 'reconstruction_xl_' + str(epoch-1) + '.png', nrow=8)

def save_reconstructions_xl_prior(model, l, x, path=''):
    """
    Generate a topography image based on a cell image.
    Training scheme 2B (used by default) - l_eps sampled from prior.
    Actual application scenario for cell image-conditioned topography generation.
    """
    # Save reconstuction
    with torch.no_grad():
        x = x[:8]
        l = l[:8]
        l_recon2, _, _, _, _, _, _, _, _, _ = model.forward_xl_prior(l, x)
        sample_t = sample(l_recon2)

        comparison = torch.cat([x, sample_t])
        save_image(comparison.cpu(),
                   path + 'reconstruction_xlp_' + str(epoch-1) + '.png', nrow=8)
        
def save_reconstructions_xl_prior_same(model, l, x, path='', n=2):
    """
    Generate 8 topography images for a sinlge cell image. 
    Training scheme 2B (used by default) - l_eps sampled from prior.
    Actual application scenario for cell image-conditioned topography generation.
    """
    with torch.no_grad():
        # x = x[:8]
        # l = l[:8]
        x_rep = x[:n].repeat_interleave(8, dim=0)
        l_rep = l[:n].repeat_interleave(8, dim=0)
        l_recon2, _, _, _, _, _, _, _, _, _ = model.forward_xl_prior(l_rep, x_rep)
        sample_t = sample(l_recon2)

        for i in range(n):
          comparison = torch.cat([x_rep[i*8:(i+1)*8], sample_t[i*8:(i+1)*8]])
          save_image(comparison.cpu(),
                    path + 'reconstruction_xlp_' + str(epoch-1) + f'{i}.png', nrow=8)

In [14]:
"""
l ----TRAIN qzeps qzf_list ----> ze/zf/(le) --freeze--> x'
"""
def run_epoch_lx(data_loader, model, optimizer, epoch, rec_freq=None, path_reconstructions='', train=True):
    #model.train()
    avg_loss = 0
    ce_x = 0
    d_qzx_pzx = 0
    kl_zeps = 0
    kls_zfk = [0] * model.fdvae.num_features

    for batch_idx, (l, fl, x, fx) in tqdm(enumerate(data_loader)):
        l = l.to(device)
        x = x.to(device)

        if train and rec_freq and (epoch % rec_freq == 0) and (batch_idx == 1):
            model.eval()
            #save_reconstructions_lx(model, l, x, path_reconstructions)  # [x, sample_x]
            save_reconstructions_lx1(model, l, x, path_reconstructions)  # [l, sample_x]
            save_reconstructions_lx1_same(model, l, x, path_reconstructions)  # [l, sample_x]

        if train: 
          model.train()
          optimizer.zero_grad()
          loss, CE_x, D_qzx_pzx, KL_zeps, KLs_zfk = model.loss_function_x(l, x)  #l ----TRAIN qzeps qzf_list ----> ze/zf/(le) --freeze--> x'
          loss.backward()
          optimizer.step()
        else:
          with torch.no_grad():
            loss, CE_x, D_qzx_pzx, KL_zeps, KLs_zfk = model.loss_function_x(l, x)  #l ----TRAIN qzeps qzf_list ----> ze/zf/(le) --freeze--> x'

        avg_loss += loss.item()
        ce_x += CE_x.item()
        d_qzx_pzx += D_qzx_pzx.item()
        kl_zeps += KL_zeps#.item()
        for k in range(model.fdvae.num_features):
            kls_zfk[k] += KLs_zfk[k].item()

    avg_loss /= len(data_loader.dataset)
    ce_x /= len(data_loader.dataset)
    d_qzx_pzx /= len(data_loader.dataset)
    kl_zeps /= len(data_loader.dataset)
    for k in range(model.fdvae.num_features):
        kls_zfk[k] /= len(data_loader.dataset)

    return avg_loss, ce_x, d_qzx_pzx, kl_zeps, kls_zfk

In [15]:
"""
MAIN OBJECTIVE:
l ---freeze---------> ze/zf/---
l ---TRAIN qleps ---> le -------TRAIN pl---> l'

REGULARIZE (loss * self.eta):
x ---freeze (fdvae) --> ze/zf/---
l ---TRAIN qleps------> le -----TRAIN pl---> l'  
"""

def run_epoch_l(data_loader, model, optimizer, epoch, with_x=True, rec_freq=None, path_reconstructions='', train=True):
    #model.train()
    avg_loss = 0
    ce_l = 0
    ce_l2 = 0
    kl_zeps = 0
    kl_leps = 0
    kls_zfk = [0] * model.fdvae.num_features
    kl_combined = 0
    kl_combined2 = 0

    for batch_idx, out in tqdm(enumerate(data_loader)):
        if with_x:
          l = out[0].to(device)  # l, fl, x, fx
          x = out[2].to(device)
        else:
           l = out[0].to(device)  # l, fl, x, fx

        if train and rec_freq and (epoch % rec_freq == 0) and (batch_idx == 1):
            model.eval()
            save_reconstructions_l(model, l, path_reconstructions)
            #save_reconstructions_xl(model, l, x, path_reconstructions)
            save_reconstructions_xl_prior(model, l, x, path_reconstructions)
            save_reconstructions_xl_prior_same(model, l, x, path_reconstructions)

        if train: 
          model.train()
          optimizer.zero_grad()
          if with_x:
            loss, CE_l, KL_zeps, KL_leps, KLs_zfk, CE_l2, KL_combined, KL_combined2, l_recon, l_recon2 = model.loss_function_l(l, x)
          else:
            loss, CE_l, KL_zeps, KL_leps, KLs_zfk, CE_l2, KL_combined, KL_combined2, l_recon, l_recon2 = model.loss_function_l(l)  # this is only for the case when we don't want to use the regularizer
          loss.backward()
          optimizer.step()
        else:
          with torch.no_grad():
            if with_x:
              loss, CE_l, KL_zeps, KL_leps, KLs_zfk, CE_l2, KL_combined, KL_combined2, l_recon, l_recon2 = model.loss_function_l(l, x)
            else:
              loss, CE_l, KL_zeps, KL_leps, KLs_zfk, CE_l2, KL_combined, KL_combined2, l_recon, l_recon2 = model.loss_function_l(l)  # this is only for the case when we don't want to use the regularizer

        avg_loss += loss.item()
        ce_l += CE_l.item()
        ce_l2 += CE_l2.item()
        kl_zeps += KL_zeps#.item()
        kl_leps += KL_leps.item()
        for k in range(model.fdvae.num_features):
            kls_zfk[k] += KLs_zfk[k].item()
        kl_combined += KL_combined.item()
        kl_combined2 += KL_combined2.item()

    avg_loss /= len(data_loader.dataset)
    ce_l /= len(data_loader.dataset)
    ce_l2 /= len(data_loader.dataset)
    kl_zeps /= len(data_loader.dataset)
    kl_leps /= len(data_loader.dataset)
    for k in range(model.fdvae.num_features):
        kls_zfk[k] /= len(data_loader.dataset)
    kl_combined /= len(data_loader.dataset)
    kl_combined2 /= len(data_loader.dataset)

    return avg_loss, ce_l, ce_l2, kl_zeps, kl_leps, kls_zfk, kl_combined, kl_combined2

Setting the parameters REGVAE

In [16]:
# REGVAE
args = {
    'no_cuda': False, 
    'seed': 1, 
    'batch_size': 50,  # 20,
    'epochs': 500,
    'lr': 0.01,

    # data info
    'path_l': 'pdata1_128/',
    'path_x': 'dataset2_128_fixloc_1scale/',
    'path_table': 'pdata1_128/', # 'pdata1_128/' ?
    'data_info_filename': 'pdata1_dataset2_w1000_reg2.csv', #'pdata1_dataset2_control.csv', #'pdata1_dataset2_w1000_reg2.csv'
    'features_names_l': ['roundness_g1', 'radius_g2'],
    'features_names_x': ['roundness_f1', 'elongation_f2', 'nucleus_size_f3'],  # 'rotation_angle_f4'

    # Model REGVAE
    'fdvae': model_fdvae,
    'leps_dim': 2, 
    'use_zeps': False,  # ! (not used by default)
    
    # Beta VAE part 
    'lbetas_f': [300, 300, 300],
    'lbeta_eps': 400,  # for z_eps as part of the topography latent space; used only if self.use_zeps (not used by default)
    'beta_leps': 500,
    ## https://arxiv.org/abs/1804.03599 approach for l_eps to control the capacity of the latent space
    'gamma_leps': 0,  # beta_eps is used when gamma = 0
    'c_leps': 3,  # warmup for c
    ##
    'prior_leps_scale_train': False, #True,
    'eta': 0.5,  # weight of the 'cell-conditioned topography design' objective

    ######### NOT USED
    # LH * (D + delta) - True; LH - False
    'kl_llh_multiplier': False, 
    'delta': 10,  # for CE_l(x_recon, x) * (KL + delta), where ze ~ q(ze|l), zf ~ q(zf|l)
    #########

    # not used
    'beta_combined': 0,
    'beta_combined2': 0,

    # warm-up (?)
    'w': 20,  #100,  # 'number of epochs for warm-up.
    'max_beta': 1.,  # 'max beta for warm-up'
    'min_beta': 0.,

    'outpath': PATH+'/model_output_synthetic/regvae',  #'./'    
}

args = dotdict(args)
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
kwargs = {'num_workers': 0, #1,   # https://github.com/pytorch/pytorch/issues/5301
          'pin_memory': False} if args.cuda else {}

In [17]:
##### Model name
model_name = 'REGVAE_seed_' + str(args.seed) + '_main_betas300_v1'  # or control dataset

model_path = args.outpath + '/' + model_name
print(args.outpath)
print(model_name)
print(model_path)

# Set seed
torch.manual_seed(args.seed)
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)

drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/regvae
REGVAE_seed_1_main_betas300_v1
drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/regvae/REGVAE_seed_1_main_betas300_v1


# ----------------------

Data loader

In [18]:
data_info_table = pd.read_csv(args.path_table + args.data_info_filename, index_col=0)

###
np.random.seed(1)
img_list_x_val = list(np.random.choice(data_info_table['filename_x'].values, int(round(data_info_table.shape[0] * 0.2)), replace=False))
img_list_x_train = [img for img in data_info_table['filename_x'].values if img not in img_list_x_val]

print(len(img_list_x_train), len(img_list_x_val))

32627 10000


In [19]:
train_data = ToyCellpair(path_l=args.path_l, 
                         path_x=args.path_x, # + "*",
                         path_table=args.path_table,
                         data_info_filename=args.data_info_filename, 
                         features_names_l=args.features_names_l,
                         features_names_x=args.features_names_x,
                         imgsize=(128, 128),
                         # condition_l="00[3-5]*" # "00[0-2]*",  ##"00[0]*"#"000*",
                         img_list_x=img_list_x_train,
                         scaler=True
                         )
train_loader = data_utils.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, **kwargs)

  "Argument interpolation should be of type InterpolationMode instead of int. "
100%|██████████| 32627/32627 [00:40<00:00, 814.41it/s]
100%|██████████| 32627/32627 [00:55<00:00, 584.73it/s]


In [20]:
val_data = ToyCellpair(path_l=args.path_l, 
                       path_x=args.path_x, # + "*",
                       path_table=args.path_table,
                       data_info_filename=args.data_info_filename, 
                       features_names_l=args.features_names_l,
                       features_names_x=args.features_names_x,
                       imgsize=(128, 128),
                       condition_x="00[3-5]*",         # "00[0-2]*",  ##"00[0]*"#"000*",  ### take part of the data to reduce gpu usage
                       img_list_x=img_list_x_val,
                       scaler=True
                       )
val_loader = data_utils.DataLoader(val_data, batch_size=args.batch_size, shuffle=True, **kwargs)

  "Argument interpolation should be of type InterpolationMode instead of int. "
100%|██████████| 6933/6933 [00:10<00:00, 683.72it/s]
100%|██████████| 6933/6933 [00:10<00:00, 670.69it/s]


In [None]:
# set([el.split("/")[1] for el in train_data.final_filenames_x]).intersection(set([el.split("/")[1] for el in val_data.final_filenames_x])) # set()

# ---------------------

### Setup the model

Step 2

In [21]:
######## REGVAE (STEP2: training the zf encoders of the topography model)
# setup the VAE
model = REGVAE(args).to(device)
##
model.fdvae = model_fdvae
##
model = model.cuda()

print(nparams(model))

# check 
print(model.lqzf_list[0].fc11[0].weight.requires_grad)
print(model.fdvae.qzf_list[0].fc11[0].weight.requires_grad)

# list(model.state_dict().keys())
####################################################

# STEP 1:  freeze qle, pl, train lqzf_list
for name, child in model.named_children():
  print(name)
  if name in ['pl', 'qleps']:
    for p in child.parameters():
      p.requires_grad = False

model.prior_leps_log_scale.requires_grad = False  # relevant when prior_leps_scale_train == True

print(nparams(model))

2678122
True
False
fdvae
pl
qleps
lqzf_list
1672620


In [None]:
############ MORE TRAINING
# model = torch.load(model_path + '/' + model_name + '_epoch_10' + '.model')#, map_location=torch.device('cpu'))
# model = model.to(device)

In [None]:
############ INITIALIZE FROM ANOTHER MODEL
# base_model_name = 'REGVAE_seed_1_main_betas300_v0'
# base_model_path = 'drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/regvae/REGVAE_seed_1_main_betas300_v0'
# model = torch.load(base_model_path + '/' + base_model_name + '_epoch_20' + '.model')#, map_location=torch.device('cpu'))
# model = model.to(device)

Step 3

In [34]:
######## Load the model trained in the previous step.
model = torch.load(model_path + '/' + model_name + '_epoch_1' + '.model')#, map_location=torch.device('cpu'))
model = model.to(device)

In [None]:
######## Load the model trained in the previous step.
# base_model_name = 'REGVAE_seed_1_main_betas300_v0'
# base_model_path = 'drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/regvae/REGVAE_seed_1_main_betas300_v0'
# model = torch.load(base_model_path + '/' + base_model_name + '_epoch_20' + '.model')#, map_location=torch.device('cpu'))
# model = model.to(device)

In [35]:
######## REGVAE (STEP3: training the remaining components of the topography model: qle, pl; freeze lqzf_list)
for name, child in model.named_children():
  print(name)
  if name in ['pl', 'qleps']:
    for p in child.parameters():
      p.requires_grad = True
  if name in ['lqzf_list']:
    for p in child.parameters():
      p.requires_grad = False

model.prior_leps_log_scale.requires_grad = True  # # relevant when prior_leps_scale_train == True

fdvae
pl
qleps
lqzf_list


In [None]:
##################################### Some notes on differents regimes of training 
# ###Here we only train qleps and pl.fc1
# ########## ! now freeze pl (except fc1) and train only qleps
# for name, child in model.named_children():
#   if name in ['pl']:#['pzf_list', 'qzf_list', 'qf_list', 'prior_zf_full_log_scales']:
#     for name1, child1 in child.named_children():
#       if name1 != 'fc1':  # fc1 True
#         print(name1)
#         for p in child1.parameters():
#           p.requires_grad = False

#######
# # we train ONLY pl, including fc1. qleps is frozen
# print('activate pl except pl.fc1')
# for name, child in model.named_children():
#   if name in ['pl']:#['pzf_list', 'qzf_list', 'qf_list', 'prior_zf_full_log_scales']:
#     for name1, child1 in child.named_children():
#     ###if name1 != 'fc1':  # disable for train_pl
#       print(name1)
#       for p in child1.parameters():
#         p.requires_grad = True ### we also train fc1 !!!

# # print('disable qzeps, px.fc1')
# for name, child in model.named_children():
#   if name in ['qleps']:#['pzf_list', 'qzf_list', 'qf_list', 'prior_zf_full_log_scales']:
#     for name1, child1 in child.named_children():
#       print(name1)
#       for p in child1.parameters():
#         p.requires_grad = False

# ###for p in model.pl.fc1.parameters():   # let's also train fc1
# ###   p.requires_grad = False

# ---------------------


In [36]:
# check 
print(model.qleps.conv1.weight.requires_grad)
print(model.lqzf_list[0].fc11[0].weight.requires_grad)
print(model.fdvae.qzf_list[0].fc11[0].weight.requires_grad)
print(model.pl.conv1.weight.requires_grad)

True
False
False
True


# ---------------------


Setup the optimizer

In [37]:
print(args.lr)

0.01


In [38]:
# args.lr = 0.001

In [39]:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)

best_loss = 1e+8
best_y_acc = 0.

early_stopping_counter = 1
max_early_stopping = 100

# ---------------------

In [40]:
# check the arguments
{i:args[i] for i in args if i!='fdvae'}

{'batch_size': 50,
 'beta_combined': 0,
 'beta_combined2': 0,
 'beta_leps': 500,
 'c_leps': 3,
 'cuda': True,
 'data_info_filename': 'pdata1_dataset2_w1000_reg2.csv',
 'delta': 10,
 'epochs': 500,
 'eta': 0.5,
 'features_names_l': ['roundness_g1', 'radius_g2'],
 'features_names_x': ['roundness_f1', 'elongation_f2', 'nucleus_size_f3'],
 'gamma_leps': 0,
 'kl_llh_multiplier': False,
 'lbeta_eps': 400,
 'lbetas_f': [300, 300, 300],
 'leps_dim': 2,
 'lr': 0.01,
 'max_beta': 1.0,
 'min_beta': 0.0,
 'no_cuda': False,
 'outpath': 'drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/regvae',
 'path_l': 'pdata1_128/',
 'path_table': 'pdata1_128/',
 'path_x': 'dataset2_128_fixloc_1scale/',
 'prior_leps_scale_train': False,
 'seed': 1,
 'use_zeps': False,
 'w': 20}

In [41]:
model_name, model_path

('REGVAE_seed_1_main_betas300_v1',
 'drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/regvae/REGVAE_seed_1_main_betas300_v1')

In [28]:
os.mkdir(model_path)

import json

with open(model_path + '/' + model_name + '.json', 'w') as configfile:
    json.dump({i:args[i] for i in args if i!='fdvae'}, configfile, indent=2)

In [42]:
def nparams(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

nparams(model)

1005502

# -------------------------

TRAINING -- STEP 2

In [None]:
# check
# model.lbetas_f, args.lbetas_f 

In [None]:
# check
# model.use_zeps

In [30]:
import copy
args_COPY = copy.deepcopy(dict(args))

In [None]:
######## Training loop REGVAE - STEP 2
print('\nStart training [[[STEP 2]]]:', {i:args[i] for i in args if i!='fdvae'})
for epoch in range(1, args.epochs + 1):

    ######################################
    ### WARMUP (increase beta gradually).
    ### Formula: min([final_value, curent_val + (final value - current_val) * (epoch * 1. - epochs_completed) / (args_COPY['w'] - epochs_completed)])
    #model.beta_leps = min([args_COPY['beta_leps'], 0 + (args_COPY['beta_leps'] - 0) * (epoch * 1. - 0) / (args_COPY['w'] - 0)])
    # for k in range(model.fdvae.num_features):
    #     model.lbetas_f[k] = min([args_COPY['lbetas_f'][k], 0 + (args_COPY['lbetas_f'][k] - 0) * (epoch * 1.) / (args_COPY['w'] - 0)])
    # print('BETAS: ', model.beta_leps, model.lbetas_f, args_COPY['beta_leps'], args_COPY['lbetas_f'])
    ######################################

    # Train
    train_loss, ce_x, d_qzx_pzx, kl_zeps, kls_zfk = run_epoch_lx(train_loader, model, optimizer, epoch, 
                                                                 rec_freq=1, path_reconstructions=model_path + '/', train=True)

    # logging train scores
    str_print = "{} EPOCH: avg loss {}".format(epoch, train_loss)
    str_print += f" avg ce_x {ce_x}"
    str_print += f" avg D {d_qzx_pzx}"
    str_print += f" avg KL_eps {kl_zeps} \n"
    for k in range(model.fdvae.num_features):      
        str_print += f", {train_data.features_names_x[k]} KL {kls_zfk[k]}"
    print(str_print)

    # Val
    val_loss, val_ce_x, val_d_qzx_pzx, val_kl_zeps, val_kls_zfk = run_epoch_lx(val_loader, model, optimizer, epoch, 
                                                                               rec_freq=1, path_reconstructions=model_path + '/', train=False)

    # logging val scores
    str_print = "{} EPOCH: VAL avg loss {}".format(epoch, val_loss)
    str_print += f" avg ce_x {val_ce_x}"
    str_print += f" avg D {val_d_qzx_pzx}"
    str_print += f" avg KL_eps {val_kl_zeps} \n"
    for k in range(model.fdvae.num_features):      
        str_print += f", {train_data.features_names_x[k]} KL {kls_zfk[k]}"
    print(str_print)

    if val_loss < best_loss:
        early_stopping_counter = 1

        best_loss = val_loss

        print("saving model: ", model_name + f'_epoch_{epoch}' + '.model')
        torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
    else:
        #######
        if epoch >= 1 and epoch % 10 == 0:
            print("RESERVE saving model: ", model_name + f'_epoch_{epoch}' + '.model')
            torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
        #######   
        early_stopping_counter += 1
        if early_stopping_counter == max_early_stopping:
            break

In [None]:
# 653it [03:50,  2.83it/s]
# 1 EPOCH: avg loss 208956.9319888436 avg ce_x 201253.2048610047 avg D 25.879229548136237 avg KL_eps 0.0 
# , roundness_f1 KL 4.301376469600693, elongation_f2 KL 4.7941591136136275, nucleus_size_f3 KL 16.583553050324113
# 139it [00:17,  8.17it/s]
# 1 EPOCH: VAL avg loss 207863.5752308136 avg ce_x 202214.06073283323 avg D 19.032232664254657 avg KL_eps 0.0 
# , roundness_f1 KL 4.301376469600693, elongation_f2 KL 4.7941591136136275, nucleus_size_f3 KL 16.583553050324113
# saving model:  REGVAE_seed_1_main_betas300_v1_epoch_1.model

# -------------------------

TRAINING -- STEP3

In [43]:
import copy
args_COPY = copy.deepcopy(dict(args))

In [None]:
######## Training loop REGVAE - STEP 3
print('\nStart training [[[STEP 3]]]:', {i:args[i] for i in args if i!='fdvae'})
for epoch in range(1, args.epochs + 1):

    ######################################
    ### WARMUP (increase beta gradually).
    ### Formula: min([final_value, curent_val + (final value - current_val) * (epoch * 1. - epochs_completed) / (args_COPY['w'] - epochs_completed)])
    #model.beta_leps = min([args_COPY['beta_leps'], 0 + (args_COPY['beta_leps'] - 0) * (epoch * 1. - 0) / (args_COPY['w'] - 0)])
    # for k in range(model.fdvae.num_features):
    #     model.lbetas_f[k] = min([args_COPY['lbetas_f'][k], 0 + (args_COPY['lbetas_f'][k] - 0) * (epoch * 1.) / (args_COPY['w'] - 0)])
    # print('BETAS: ', model.beta_leps, model.lbetas_f, args_COPY['beta_leps'], args_COPY['lbetas_f'])
    ######################################

    # Train
    train_loss, ce_l, ce_l2, kl_zeps, \
      kl_leps, kls_zfk, kl_combined, kl_combined2 = run_epoch_l(train_loader, model, optimizer, epoch, with_x=True, 
                                                                rec_freq=1, path_reconstructions=model_path + '/', train=True)

    # logging train scores
    str_print = "{} EPOCH: avg loss {}".format(epoch, train_loss)
    str_print += f" avg ce_l {ce_l}"
    str_print += f" avg ce_l2 {ce_l2}"
    str_print += f" avg KL_eps {kl_zeps}"
    str_print += f" avg KL_leps {kl_leps} \n"
    for k in range(model.fdvae.num_features):      
        str_print += f", {train_data.features_names_x[k]} KL {kls_zfk[k]}"
    str_print += f"\n avg KL_combined {kl_combined}"
    str_print += f" avg KL_combined2 {kl_combined2}"
    print(str_print)

    # Val
    val_loss, val_ce_l, val_ce_l2, val_kl_zeps, \
      val_kl_leps, val_kls_zfk, val_kl_combined, val_kl_combined2 = run_epoch_l(val_loader, model, optimizer, epoch, with_x=True, 
                                                                                rec_freq=1, path_reconstructions=model_path + '/', train=False)

    # logging val scores
    str_print = "{} EPOCH: VAL avg loss {}".format(epoch, val_loss)
    str_print += f" avg ce_l {val_ce_l}"
    str_print += f" avg ce_l2 {val_ce_l2}"
    str_print += f" avg KL_eps {val_kl_zeps}"
    str_print += f" avg KL_leps {val_kl_leps} \n"
    for k in range(model.fdvae.num_features):      
        str_print += f", {train_data.features_names_x[k]} KL {val_kls_zfk[k]}"
    str_print += f"\n avg KL_combined {val_kl_combined}"
    str_print += f" avg KL_combined2 {val_kl_combined2}"
    print(str_print)

    if val_loss < best_loss:
        early_stopping_counter = 1

        best_loss =  val_loss

        print("saving model: ", model_name + f'_epoch_{epoch}' + '.model')
        torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
    else:
        #######
        if epoch >= 1 and epoch % 10 == 0:
            print("RESERVE saving model: ", model_name + f'_epoch_{epoch}' + '.model')
            torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
        #######   
        early_stopping_counter += 1
        if early_stopping_counter == max_early_stopping:
            break

In [None]:
# 653it [06:36,  1.65it/s]
# 1 EPOCH: avg loss 295121.2520458516 avg ce_l 191066.69237134888 avg ce_l2 195054.9827137034 avg KL_eps 0.0 avg KL_leps 2.1511972773435724 
# , roundness_f1 KL 2.4177661060603683, elongation_f2 KL 4.094873486394722, nucleus_size_f3 KL 11.658924163230026
#  avg KL_combined 18.097879616560263 avg KL_combined2 18.049029053512584
# 139it [00:30,  4.62it/s]
# 1 EPOCH: VAL avg loss 257265.68147720717 avg ce_l 167737.74617714944 avg ce_l2 167739.378101558 avg KL_eps 0.0 avg KL_leps 0.020880545373947994 
# , roundness_f1 KL 2.417999768353444, elongation_f2 KL 4.094125511050568, nucleus_size_f3 KL 12.313896843964697
#  avg KL_combined 16.385829108888128 avg KL_combined2 16.40207270178159
# saving model:  REGVAE_seed_1_main_betas300_v1_epoch_1.model

# -------------------------

### ADDING A DISCRIMINATOR (an auxiliary GAN objective)

In [69]:
best_loss_discriminator = 1e+8
early_stopping_counter_discriminator = 1
max_early_stopping_discriminator = 100

criterion = torch.nn.BCELoss()
real_label = 1.
fake_label = 0.

In [55]:
# REGVAE
argsD = {
    'no_cuda': False, 
    'seed': 1, 
    'batch_size': 50,  # 20,
    'epochs': 500,
    'lrd': 0.001, #0.0002,
    #'beta1': 0.5,
    'outpath': PATH+'/model_output_synthetic/D',  #'./'    
}
argsD = dotdict(argsD)

# ##args = parser.parse_args()
# argsD.cuda = not argsD.no_cuda and torch.cuda.is_available()
# device = torch.device("cuda" if argsD.cuda else "cpu")
# kwargsD = {'num_workers': 0, #1,   # https://github.com/pytorch/pytorch/issues/5301
#           'pin_memory': False} if argsD.cuda else {}

Pretrain the discriminator for a few epochs

In [70]:
## Initialize Discriminator for pre-training
model_name_discriminator = 'D_seed_' + str(args.seed) + ''
model_path_discriminator = PATH+'/model_output_synthetic/D/' + model_name_discriminator
print(model_name_discriminator)
print(model_path_discriminator)

discriminator = Discriminator().to(device)

D_seed_1
drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/D/D_seed_1


In [63]:
os.mkdir(model_path_discriminator)

import json

with open(model_path_discriminator + '/' + model_name_discriminator + '.json', 'w') as configfile:
    json.dump({i:argsD[i] for i in argsD}, configfile, indent=2)

In [64]:
# optimizer for the disctiminator
optimizerD = optim.Adam(discriminator.parameters(), lr=argsD.lrd)#, betas=(argsD.beta1, 0.999))

In [65]:
# optimizer for the model
args.lr = 0.001
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)

In [71]:
def run_epoch_discriminator(data_loader, model, discriminator, optimizerD, train=True):
    #model.train()
    #avg_loss = 0
    errd = 0
    errd_real = 0
    errd_fake = 0
    dx = 0
    dg_z1 = 0

    for batch_idx, out in tqdm(enumerate(data_loader)):
        #if with_x:
        l = out[0].to(device)  # l, fl, x, fx
        x = out[2].to(device)
        # else:
        #    l = out[0].to(device)  # l, fl, x, fx

        if train: 
          model.train()
          optimizer.zero_grad()

          l_recon2, qzf_list2, _, zf_q_list2, qzeps2, _, zeps_q2,\
              _, pleps2, leps_p = model.forward_xl_prior(l, x)  # x -> zf (prior -> le) -> l'
          sample_t = sample(l_recon2)

          ############################
          # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
          ###########################
          ## Train with all-real batch
          label = torch.full((l.size(0),), real_label, dtype=torch.float, device=device)
          # Forward pass real batch through D
          output = discriminator.forward(l).view(-1)
          # Calculate loss on all-real batch
          errD_real = criterion(output, label)
          # Calculate gradients for D in backward pass
          errD_real.backward()
          D_x = output.mean().item()

          ## Train with all-fake batch
          fake = sample_t
          label.fill_(fake_label)
          # Classify all fake batch with D
          output = discriminator.forward(fake.detach()).view(-1)
          # Calculate D's loss on the all-fake batch
          errD_fake = criterion(output, label)
          # Calculate the gradients for this batch, accumulated (summed) with previous gradients
          errD_fake.backward()
          D_G_z1 = output.mean().item()
          # Compute error of D as sum over the fake and the real batches
          errD = errD_real + errD_fake
          # Update D
          optimizerD.step()

        else:
          with torch.no_grad():
            l_recon2, qzf_list2, _, zf_q_list2, qzeps2, _, zeps_q2,\
                _, pleps2, leps_p = model.forward_xl_prior(l, x)  # x -> zf (prior -> le) -> l'
            sample_t = sample(l_recon2)

            ## all-real batch
            label = torch.full((l.size(0),), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = discriminator.forward(l).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            D_x = output.mean().item()

            # process all-fake batch
            fake = sample_t
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = discriminator.forward(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake

        errd += errD.item()
        errd_real += errD_real.item()
        errd_fake += errD_fake.item()
        dx += D_x
        dg_z1 += D_G_z1

    errd /= len(data_loader.dataset)
    errd_real /= len(data_loader.dataset)
    errd_fake /= len(data_loader.dataset)
    dx /= len(data_loader.dataset)
    dg_z1 /= len(data_loader.dataset)

    return errd, errd_real, errd_fake, dx, dg_z1

In [None]:
########### Pretrain the Discriminator
print('\nStart training DISCRIMINATOR:', {i:args[i] for i in args if i!='fdvae'})
##G_losses = []
D_losses = []

for epoch in range(1, args.epochs + 1):

    # Train
    errd, errd_real, errd_fake, dx, dg_z1 = run_epoch_discriminator(train_loader, model, discriminator, optimizerD, train=True)

    # logging train scores
    str_print = "{} EPOCH: avg errd {}".format(epoch, errd)
    str_print += f" avg errd_real {errd_real}"
    str_print += f" avg errd_fake {errd_fake}"
    str_print += f" avg dx {dx}"
    str_print += f" avg dg_z1 {dg_z1}"
    print(str_print)

    # Val
    val_errd, val_errd_real, val_errd_fake, val_dx, val_dg_z1 = run_epoch_discriminator(val_loader, model, discriminator, optimizerD, train=False)

    # logging train scores
    str_print = "{} EPOCH: VAL avg errd {}".format(epoch, val_errd)
    str_print += f" avg errd_real {val_errd_real}"
    str_print += f" avg errd_fake {val_errd_fake}"
    str_print += f" avg dx {val_dx}"
    str_print += f" avg dg_z1 {val_dg_z1}"
    print(str_print)

    D_losses.append(val_errd)

    if val_errd < best_loss_discriminator:
        early_stopping_counter_discriminator = 1

        best_loss_discriminator = val_errd

        print("saving model: ", model_name_discriminator + f'_epoch_{epoch}' + '.model')
        torch.save(discriminator, model_path_discriminator + '/' + model_name_discriminator + f'_epoch_{epoch}' + '.model')
    else:
        #######
        if epoch >= 1 and epoch % 10 == 0:
            print("RESERVE saving model: ", model_name_discriminator + f'_epoch_{epoch}' + '.model')
            torch.save(discriminator, model_path_discriminator + '/' + model_name_discriminator + f'_epoch_{epoch}' + '.model')
        #######   
        early_stopping_counter_discriminator += 1
        if early_stopping_counter_discriminator == max_early_stopping_discriminator:
            break

In [None]:
# 653it [02:21,  4.61it/s]
# 1 EPOCH: avg errd 0.025731841070329678 avg errd_real 0.015352839881301742 avg errd_fake 0.01037900117715341 avg dx 0.009599916219280167 avg dg_z1 0.008043021754695004
# 139it [00:17,  8.10it/s]
# 1 EPOCH: VAL avg errd 0.026170091104865144 avg errd_real 0.015713524436648016 avg errd_fake 0.010456566741304183 avg dx 0.00946238425447978 avg dg_z1 0.008092846156060868
# saving model:  D_seed_1_epoch_1.model

Load the pretrained discriminator and train the combined objectve

In [73]:
best_loss_discriminator = 1e+8
early_stopping_counter_discriminator = 1
max_early_stopping_discriminator = 100

criterion = torch.nn.BCELoss()
real_label = 1.
fake_label = 0.

In [76]:
model_name_discriminator = 'D_seed_' + str(args.seed) + ''
model_path_discriminator = PATH+'/model_output_synthetic/D/' + model_name_discriminator
print(model_name_discriminator)
print(model_path_discriminator)

discriminator = torch.load(model_path_discriminator + '/' + model_name_discriminator + '_epoch_1' + '.model')#, map_location=torch.device('cpu'))

D_seed_1
drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/D/D_seed_1


In [None]:
# base_model_name_discriminator = 'D_seed_' + str(args.seed) + ''
# base_model_path_discriminator = PATH+'/FDVA_output/D/' + base_model_name_discriminator
# print(base_model_name_discriminator)
# print(base_model_path_discriminator)

# discriminator = torch.load(base_model_path_discriminator + '/' + base_model_name_discriminator + '_epoch_125' + '.model')#, map_location=torch.device('cpu'))

In [77]:
# optimizer for the disctiminator
optimizerD = optim.Adam(discriminator.parameters(), lr=argsD.lrd)#, betas=(argsD.beta1, 0.999))

In [78]:
# optimizer for the model
args.lr = 0.001
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)

In [79]:
XI_GEN = 1000000

In [82]:
def run_epoch_l_gan(data_loader, model, discriminator, optimizer, optimizerD, epoch, rec_freq=None, path_reconstructions='', train=True):
    #model.train()
    #discriminator.train()
    avg_loss = 0
    ce_l = 0
    ce_l2 = 0
    kl_zeps = 0
    kl_leps = 0
    kls_zfk = [0] * model.fdvae.num_features
    kl_combined = 0
    kl_combined2 = 0

    errd = 0
    errd_real = 0
    errd_fake = 0
    dx = 0
    dg_z1 = 0
    dg_z2 = 0
    errg = 0

    for batch_idx, out in tqdm(enumerate(data_loader)):
        #if with_x:
        l = out[0].to(device)  # l, fl, x, fx
        x = out[2].to(device)
        # else:
        #    l = out[0].to(device)  # l, fl, x, fx

        if train and rec_freq and (epoch % rec_freq == 0) and (batch_idx == 1):
            save_reconstructions_l(model, l, path_reconstructions)
            #save_reconstructions_xl(model, l, x, path_reconstructions)
            save_reconstructions_xl_prior(model, l, x, path_reconstructions)
            save_reconstructions_xl_prior_same(model, l, x, path_reconstructions)

        if train: 
          model.train()
          discriminator.train()
          optimizer.zero_grad()
          optimizerD.zero_grad()

          #if with_x:
          loss, CE_l, KL_zeps, KL_leps, KLs_zfk, CE_l2, KL_combined, KL_combined2, l_recon, l_recon2 = model.loss_function_l(l, x)  # MI, MI2, TC, LOGQZ, LOGQZ_prodmarginals
          # else:
          #   loss, CE_l, KL_zeps, KL_leps, KLs_zfk, CE_l2, KL_combined, KL_combined2, l_recon, _ = model.loss_function_l(l)  # this is only for case when we don't want to use the regularizer

          sample_t = sample(l_recon2)

          ##### (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
          ## Train with all-real batch
          label = torch.full((l.size(0),), real_label, dtype=torch.float, device=device)
          # Forward pass real batch through D
          output = discriminator.forward(l).view(-1)
          # Calculate loss on all-real batch
          errD_real = criterion(output, label)
          # Calculate gradients for D in backward pass
          errD_real.backward()
          D_x = output.mean().item()

          ## Train with all-fake batch
          fake = sample_t
          label.fill_(fake_label)
          # Classify all fake batch with D
          output = discriminator.forward(fake.detach()).view(-1)
          # Calculate D's loss on the all-fake batch
          errD_fake = criterion(output, label)
          # Calculate the gradients for this batch, accumulated (summed) with previous gradients
          errD_fake.backward()
          D_G_z1 = output.mean().item()
          # Compute error of D as sum over the fake and the real batches
          errD = errD_real + errD_fake
          # Update D every 10 batches
          #if epoch % 2 == 1 and batch_idx == 0:
          if batch_idx % 10 == 0:
            optimizerD.step()

          ##### (2) Update G network: maximize log(D(G(z)))
          label.fill_(real_label)  # fake labels are real for generator cost
          # Since we just updated D, perform another forward pass of all-fake batch through D
          output = discriminator.forward(fake).view(-1)
          # Calculate G's loss based on this output
          errG = criterion(output, label)
          # Calculate gradients for G
          loss += XI_GEN * errG
          D_G_z2 = output.mean().item()
          # Update G
          loss.backward()
          optimizer.step()

        else:
          with torch.no_grad():
            #if with_x:
            loss, CE_l, KL_zeps, KL_leps, KLs_zfk, CE_l2, KL_combined, KL_combined2, l_recon, l_recon2 = model.loss_function_l(l, x)  # MI, MI2, TC, LOGQZ, LOGQZ_prodmarginals
            # else:
            #   loss, CE_l, KL_zeps, KL_leps, KLs_zfk, CE_l2, KL_combined, KL_combined2, l_recon, _ = model.loss_function_l(l)  # this is only for case when we don't want to use the regularizer
            sample_t = sample(l_recon2)

            ##### D network
            ## Process all-real batch
            label = torch.full((l.size(0),), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = discriminator.forward(l).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            D_x = output.mean().item()

            ## Process all-fake batch
            fake = sample_t
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = discriminator.forward(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake

            ##### G network
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = discriminator.forward(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            loss += XI_GEN * errG
            D_G_z2 = output.mean().item()

        errd += errD.item()
        errd_real += errD_real.item()
        errd_fake += errD_fake.item()
        dx += D_x
        dg_z1 += D_G_z1
        dg_z2 += D_G_z2
        errg += errG.item()

        avg_loss += loss.item()
        ce_l += CE_l.item()
        ce_l2 += CE_l2.item()
        kl_zeps += KL_zeps#.item()
        kl_leps += KL_leps.item()
        for k in range(model.fdvae.num_features):
            kls_zfk[k] += KLs_zfk[k].item()
        kl_combined += KL_combined.item()
        kl_combined2 += KL_combined2.item()

    avg_loss /= len(data_loader.dataset)
    ce_l /= len(data_loader.dataset)
    ce_l2 /= len(data_loader.dataset)
    kl_zeps /= len(data_loader.dataset)
    kl_leps /= len(data_loader.dataset)
    for k in range(model.fdvae.num_features):
        kls_zfk[k] /= len(data_loader.dataset)
    kl_combined /= len(data_loader.dataset)
    kl_combined2 /= len(data_loader.dataset)
    errd /= len(data_loader.dataset)
    errd_real /= len(data_loader.dataset)
    errd_fake /= len(data_loader.dataset)
    dx /= len(data_loader.dataset)
    dg_z1 /= len(data_loader.dataset)
    dg_z2 /= len(data_loader.dataset)
    errg /= len(data_loader.dataset)

    return avg_loss, ce_l, ce_l2, kl_zeps, kl_leps, kls_zfk, kl_combined, kl_combined2, errd, errd_real, errd_fake, dx, dg_z1, dg_z2, errg

In [None]:
######## Training loop REGVAE (STEP 3) + GAN objective
print('\nStart training [[[STEP 3 GAN]]]:', {i:args[i] for i in args if i!='fdvae'})
G_losses = []
D_losses = []

for epoch in range(1, args.epochs + 1):

    # Train
    train_loss, ce_l, ce_l2, kl_zeps, kl_leps, kls_zfk, kl_combined, kl_combined2,\
                             errd, errd_real, errd_fake, dx, dg_z1, dg_z2, errg = run_epoch_l_gan(train_loader, model, discriminator, optimizer, optimizerD, 
                                                                                                  epoch, rec_freq=1, path_reconstructions=model_path + '/', train=True)

    # logging train scores
    str_print = "{} EPOCH: avg loss {}".format(epoch, train_loss)
    str_print += f" avg ce_l {ce_l}"
    str_print += f" avg ce_l2 {ce_l2}"
    str_print += f" avg KL_eps {kl_zeps}"
    str_print += f" avg KL_leps {kl_leps} \n"
    for k in range(model.fdvae.num_features):      
        str_print += f", {train_data.features_names_x[k]} KL {kls_zfk[k]}"
    str_print += f"\n avg KL_combined {kl_combined}"
    str_print += f" avg KL_combined2 {kl_combined2}"
    str_print += f"\n avg errd {errd}"
    str_print += f" avg errd_real {errd_real}"
    str_print += f" avg errd_fake {errd_fake}"
    str_print += f" avg dx {dx}"
    str_print += f" avg dg_z1 {dg_z1}"
    str_print += f" avg dg_z2 {dg_z2}"
    str_print += f" avg errg {errg}"
    print(str_print)

    # Val
    val_loss, val_ce_l, val_ce_l2, val_kl_zeps, val_kl_leps, val_kls_zfk, val_kl_combined, val_kl_combined2,\
                          val_errd, val_errd_real, val_errd_fake, val_dx, val_dg_z1, val_dg_z2, val_errg = run_epoch_l_gan(val_loader, model, discriminator, optimizer, optimizerD, 
                                                                                                                           epoch, rec_freq=1, path_reconstructions=model_path + '/', train=False)

    # logging val scores
    str_print = "{} EPOCH: VAL avg loss {}".format(epoch, val_loss)
    str_print += f" avg ce_l {val_ce_l}"
    str_print += f" avg ce_l2 {val_ce_l2}"
    str_print += f" avg KL_eps {val_kl_zeps}"
    str_print += f" avg KL_leps {val_kl_leps} \n"
    for k in range(model.fdvae.num_features):      
        str_print += f", {train_data.features_names_x[k]} KL {val_kls_zfk[k]}"
    str_print += f"\n avg KL_combined {val_kl_combined}"
    str_print += f" avg KL_combined2 {val_kl_combined2}"
    str_print += f"\n avg errd {val_errd}"
    str_print += f" avg errd_real {val_errd_real}"
    str_print += f" avg errd_fake {val_errd_fake}"
    str_print += f" avg dx {val_dx}"
    str_print += f" avg dg_z1 {val_dg_z1}"
    str_print += f" avg dg_z2 {val_dg_z2}"
    str_print += f" avg errg {val_errg}"
    print(str_print)

    G_losses.append(val_errg)
    D_losses.append(val_errd)

    #######
    # if epoch > 1:
    #     print("RRRRRRRRESERVE saving model: ", model_name + f'_epoch_{epoch}' + '.model')
    #     torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
    #     torch.save(discriminator, model_path_discriminator + '/' + model_name_discriminator + f'_epoch_{epoch}' + '.model')
    #######

    if val_loss < best_loss:
        early_stopping_counter = 1

        best_loss =  val_loss

        print("saving model: ", model_name + f'_epoch_{epoch}' + '.model')
        torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
    else:
        #######
        if epoch >= 1 and epoch % 10 == 0:
            print("RESERVE saving model: ", model_name + f'_epoch_{epoch}' + '.model')
            torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
        #######   
        early_stopping_counter += 1
        if early_stopping_counter == max_early_stopping:
            break

In [None]:
# 653it [08:05,  1.35it/s]
# 1 EPOCH: avg loss 289151.01078861067 avg ce_l 157206.03886351796 avg ce_l2 180442.97235418519 avg KL_eps 0.0 avg KL_leps 0.5841221608799044 
# , roundness_f1 KL 2.412954923863773, elongation_f2 KL 4.092082730713976, nucleus_size_f3 KL 11.67125448121176
#  avg KL_combined 16.529692080770133 avg KL_combined2 16.545611257801845
#  avg errd 0.015494498402052902 avg errd_real 0.008945446051539597 avg errd_fake 0.006549052354794986 avg dx 0.013661521159218183 avg dg_z1 0.00500406309477929 avg dg_z2 0.004678837584775385 avg errg 0.035978537051842195
# 139it [00:37,  3.68it/s]
# 1 EPOCH: VAL avg loss 291592.6690709752 avg ce_l 157421.42188401616 avg ce_l2 180489.10538084246 avg KL_eps 0.0 avg KL_leps 0.4629769557759174 
# , roundness_f1 KL 2.4202314355919556, elongation_f2 KL 4.107097751201403, nucleus_size_f3 KL 12.220713593955681
#  avg KL_combined 16.77856865303405 avg KL_combined2 16.74273874466638
#  avg errd 0.009150887805639631 avg errd_real 0.0036932395795866444 avg errd_fake 0.005457648245399561 avg dx 0.016831455198726555 avg dg_z1 0.004269581364410775 avg dg_z2 0.004269581364410775 avg errg 0.038070793964927484

In [None]:
####torch.save(discriminator, model_path_discriminator + '/' + model_name_discriminator + f'_epoch_{133}' + '.model')

In [None]:
###torch.save(model, model_path + '/' + model_name + f'_epoch_{133}' + '.model')