# FDVAE

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('To enable a high-RAM runtime, select the Runtime > "Change runtime type"')
  print('menu, and then select High-RAM in the Runtime shape dropdown. Then, ')
  print('re-execute this cell.')
else:
  print('You are using a high-RAM runtime!')

Tue Oct 26 09:43:53 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.74       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# ---------------------------------------------------------------------------------------------------

In [2]:
import os
import sys
import numpy as np
import json
import random
import matplotlib.pyplot as plt
%matplotlib inline

from google.colab import drive
import time
import glob

import torch
from torch.nn import functional as F
import torch.optim as optim
import torchvision
from torchvision.utils import save_image
import torch.utils.data as data_utils
import torch.distributions as dist


from sklearn.decomposition import PCA
import pandas as pd
import seaborn as sns
from sklearn.manifold import TSNE
import random
from tqdm import tqdm

from tensorflow.python.client import device_lib
device_lib.list_local_devices()

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 3326323590203360323, name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 15434776576
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 8093311519215737807
 physical_device_desc: "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0"]

In [3]:
if not os.path.isdir('drive'):
  drive.mount('drive')
else:
  print('drive already mounted')

base_path = os.path.join('drive', 'My Drive', 'Colab Notebooks', 'GEN')
if not os.path.isdir(base_path):
  os.makedirs(base_path)

Mounted at drive


In [7]:
PATH = os.path.join("drive", "My Drive", "Colab Notebooks", "GEN") 

sys.path.append(PATH+'/biomatsim')
from dataset.data_loader_synthetic import ToyCell, ToyCellpair
from model_vae_synthetic import VAE
from model_qzfreg_synthetic import QZFREG
from model_main_synthetic import FDVAE, REGVAE, Discriminator
from diva.pixel_cnn_utils import sample

print(torch.cuda.is_available())

True


In [5]:
# https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
##########################
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
##########################

### Load data

In [6]:
# get cell data
dataset_name = 'dataset2_128_fixloc_1scale'

if not os.path.isdir('dataset2_128_fixloc_1scale'): 
  start_time = time.time()
  if not os.path.isfile('dataset2_128_fixloc_1scale.zip'): 
    ! wget -O dataset2_128_fixloc_1scale.zip "https://surfdrive.surf.nl/files/index.php/s/VuU2UKuHlBpvxgh/download"
  ! unzip -q dataset2_128_fixloc_1scale.zip -d .
  print("Unzipped.")
  print("Elapsed time: {} seconds.".format(time.time()-start_time))
else:
  print('Found folder dataset2_128_fixloc_1scale')

--2021-10-26 09:44:52--  https://surfdrive.surf.nl/files/index.php/s/VuU2UKuHlBpvxgh/download
Resolving surfdrive.surf.nl (surfdrive.surf.nl)... 145.100.27.67, 2001:610:108:203b:0:a11:da7a:5afe
Connecting to surfdrive.surf.nl (surfdrive.surf.nl)|145.100.27.67|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1280707500 (1.2G) [application/zip]
Saving to: ‘dataset2_128_fixloc_1scale.zip’


2021-10-26 09:45:48 (22.1 MB/s) - ‘dataset2_128_fixloc_1scale.zip’ saved [1280707500/1280707500]

Unzipped.
Elapsed time: 65.55587840080261 seconds.


### Train

In [9]:
# VAE
"""
x = (batch_size, 3, 128, 128) // (batch_size, 3, 64, 64)
"""
def run_epoch_vae(data_loader, model, optimizer, epoch, rec_freq=None, path_reconstructions='', train=True):
    #model.train()
    avg_loss = 0
    ce_x = 0
    kl_z = 0

    for batch_idx, (x, _) in tqdm(enumerate(data_loader)):
        # To device
        x = x.to(device) 

        if train and rec_freq and (epoch % rec_freq == 0) and (batch_idx == 1):
            model.eval()
            save_reconstructions_vae(model, x, path_reconstructions)

        if train: 
          model.train()
          optimizer.zero_grad()
          loss, CE_x, KL_z, x_recon = model.loss_function(x) 
          loss.backward()
          optimizer.step()
        else:
          with torch.no_grad():
            loss, CE_x, KL_z, x_recon = model.loss_function(x)

        avg_loss += loss.item()
        ce_x += CE_x.item()
        kl_z += KL_z.item()

    avg_loss /= len(data_loader.dataset)
    ce_x /= len(data_loader.dataset)
    kl_z /= len(data_loader.dataset)

    return avg_loss, ce_x, kl_z

In [10]:
# FDVAE
"""
x = (batch_size, 3, 128, 128) // (batch_size, 3, 64, 64)
fs = [(batch_size, 1), (batch_size, N), ...]; fs[k] array corresponds to fk
"""
def run_epoch(data_loader, model, optimizer, epoch, rec_freq=None, path_reconstructions='', train=True):
    #model.train()
    avg_loss = 0
    reg_qzf_loss = [0] * model.num_features
    reg_pzf_loss = [0] * model.num_features
    ce_x = 0
    kl_eps = 0
    kls_zfk = [0] * model.num_features
    kls_zfk_full = [0] * model.num_features

    for batch_idx, (x, fs) in tqdm(enumerate(data_loader)):
        # To device
        x = x.to(device)
        for k in range(len(fs)):
            fs[k] = fs[k].to(device)
            # fs[k].shape = (batch_size, 1) if the feature is continuous; 
                        #   (batch_size, N) if the feature is categorical with N categories (one-hot)

        if train and rec_freq and (epoch % rec_freq == 0) and (batch_idx == 1):
            model.eval()  # .eval() before saving an image
            save_reconstructions(model, x, fs, path_reconstructions)

        if train: 
          model.train()
          optimizer.zero_grad()
          loss, CE_x, KL_zeps, KLs_zfk, KLs_zfk_full, MSE_fs, MSE_fs_2 = model.loss_function(x, fs) 
          loss.backward()
          optimizer.step()
        else:
          with torch.no_grad():
            loss, CE_x, KL_zeps, KLs_zfk, KLs_zfk_full, MSE_fs, MSE_fs_2 = model.loss_function(x, fs)

        avg_loss += loss.item()
        ce_x += CE_x.item()
        kl_eps += KL_zeps.item()
        for k in range(len(fs)):
            reg_qzf_loss[k] += MSE_fs[k].item()
            reg_pzf_loss[k] += MSE_fs_2[k].item()
            kls_zfk[k] += KLs_zfk[k].item()
            kls_zfk_full[k] += KLs_zfk_full[k].item()

    avg_loss /= len(data_loader.dataset)
    ce_x /= len(data_loader.dataset)
    kl_eps /= len(data_loader.dataset)
    for k in range(len(fs)):
        reg_qzf_loss[k] /= len(data_loader.dataset)
        reg_pzf_loss[k] /= len(data_loader.dataset)
        kls_zfk[k] /= len(data_loader.dataset)
        kls_zfk_full[k] /= len(data_loader.dataset)

    return avg_loss, ce_x, reg_qzf_loss, reg_pzf_loss, kl_eps, kls_zfk, kls_zfk_full  # float, [float, ...]

In [12]:
# VAE
"""
x = (batch_size, 3, 128, 128) // (batch_size, 3, 64, 64)
"""
def save_reconstructions_vae(model, x, path='', add='0'):
    # Save reconstuction
    with torch.no_grad():

        x = x[:8]

        x_recon, _, _, _ = model.forward(x)
        sample_t = sample(x_recon)

        comparison = torch.cat([x, sample_t])
        save_image(comparison.cpu(),
                   path + 'reconstruction_vae_' + str(epoch-1) + '_' + add + '.png', nrow=8)

# FDVAE
"""
x = (batch_size, 3, 128, 128) // (batch_size, 3, 64, 64)
fs = [(batch_size, 1), (batch_size, N), ...]; fs[k] array corresponds to fk
"""
def save_reconstructions(model, x, fs, path='', add='0'):
    # Save reconstuction
    with torch.no_grad():

        x = x[:8]
        fs = [arr[:8] for arr in fs]

        x_recon, _, _, _, _, _, _, _, _, _, _ = model.forward(x, fs)
        sample_t = sample(x_recon)

        comparison = torch.cat([x, sample_t])
        save_image(comparison.cpu(),
                   path + 'reconstruction_' + str(epoch-1) + '_' + add + '.png', nrow=8)

Setting the parameters

In [21]:
################ VAE
# args = {
#     'no_cuda': False, 
#     'seed': 1, 
#     'batch_size': 50, 
#     'epochs': 500,
#     'lr': 0.001,

#     # data info
#     'datapath': 'dataset2_128_fixloc_1scale/', 
#     'data_info_filename': 'dataset2_128_fixloc_1scale.csv', 
#     'features_names': ['roundness', 'elongation', 'nucleus_size'],  #'rotation_angle'],  --- here for data loader

#     # Model VAE
#     'z_dim': 8,  # size of the latent space z

#     # Beta VAE part
#     'beta': 1,  # for KL(q(ze|x) || p(ze))

#     ## https://arxiv.org/abs/1804.03599 approach for z to control the capacity of the latent space
#     'gamma': 0,  # beta_eps is used when gamma = 0
#     'wc': 10,  # warmup for c
#     'c': 2,
#     ##

#     'prior_z_scale_train': False, 

#     # warm-up
#     'w': 10, #5,  #100,  # 'number of epochs for warm-up. Set to 0 to turn warmup off.' (DIVA paper: https://github.com/AMLab-Amsterdam/DIVA)
#     'max_beta': 1.,  # 'max beta for warm-up'
#     'min_beta': 0.,

#     'outpath': PATH+'/model_output_synthetic/vae',   #'./'    
# }

# args = dotdict(args)
# args.cuda = not args.no_cuda and torch.cuda.is_available()
# device = torch.device("cuda" if args.cuda else "cpu")
# kwargs = {'num_workers': 0, #1,   # https://github.com/pytorch/pytorch/issues/5301
#           'pin_memory': False} if args.cuda else {}

In [38]:
############################ FDVAE
args = {
    'no_cuda': False, 
    'seed': 1, 
    'batch_size': 50,
    'epochs': 500,
    'lr': 0.001,

    # data info
    'datapath': 'dataset2_128_fixloc_1scale/', 
    'data_info_filename': 'dataset2_128_fixloc_1scale.csv', 
    'features_names': ['roundness', 'elongation', 'nucleus_size'],  #'rotation_angle'],

    # Model FVDA
    'zf_dims': [2] * 3,  # sizes of the latent spaces zf
    'zeps_dim': 2,  # size of the latent space ze
    'f_dims': [1] * 3,  # dimensionalities of features (1)
    'qf_activations': [False] * 3,  # LeakyReLU activation in regression objectives before the linear layer


    # Aux multipliers
    'aux_loss_multipliers': [10 ** 5] * 3,  # [alpha_f1, alpha_f2, ...]
    'aux_loss_multipliers_2': [0] * 3,   

    # Beta VAE part
    'betas_f': [1, 1, 1],  # 1],  # [beta_f1, beta_f2, ...] for KL(q(zf|x) || p(zf|f))
    'beta_eps': 300,  # for KL(q(ze|x) || p(ze)) ############################

    ## https://arxiv.org/abs/1804.03599 approach for z_eps to control the capacity of the latent space
    'gamma': 0,  # beta_eps is used when gamma = 0
    'wc': 10,  # warmup for c
    'c': 2,
    ##

    'pzf_full_prior': True,
    'betas_f_full': [1, 1, 1],  # for KL(p(zf|f) || p(zf))

    # Trainable relatvie variance of the prior
    'prior_zeps_scale_train': False, 
    'prior_zf_full_scale_train_list': [True, True, True], 

    ### not used
    #'splitdim': False,  
    #'mult_loc': np.pi / 4,
    #'mult_scale': np.pi / 16,

    # warm-up
    'w': 10, #5,  #100,  # 'number of epochs for warm-up. Set to 0 to turn warmup off.' (DIVA paper: https://github.com/AMLab-Amsterdam/DIVA)
    'max_beta': 1.,  # 'max beta for warm-up'
    'min_beta': 0.,

    'outpath': PATH+'/model_output_synthetic/fdvae',  #'./'    
}

args = dotdict(args)
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
kwargs = {'num_workers': 0, #1,   # https://github.com/pytorch/pytorch/issues/5301
          'pin_memory': False} if args.cuda else {}

In [39]:
##### Model name
#model_name = 'VAE_seed_' + str(args.seed) + '_zdim_8_v1'
model_name = 'FDVAE_seed_' + str(args.seed) + '_zdims_222_2_v1' 

model_path = args.outpath + '/' + model_name
print(args.outpath)
print(model_name)
print(model_path)

# Set seed
torch.manual_seed(args.seed)
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)

drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/fdvae
FDVAE_seed_1_zdims_222_2_v1
drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/fdvae/FDVAE_seed_1_zdims_222_2_v1


# ----------------------

Data loader

In [16]:
data_info_table = pd.read_csv(args.datapath + args.data_info_filename, index_col=0)
data_info_table['filename'] = data_info_table['idx'].apply(lambda x: str(x).zfill(7) + '.png')
data_info_table.drop('idx', axis=1, inplace=True)

###
np.random.seed(1)
img_list_val = list(np.random.choice(data_info_table['filename'].values, int(round(data_info_table.shape[0] * 0.2)), replace=False))
img_list_train = [img for img in data_info_table['filename'].values if img not in img_list_val]

print(len(img_list_train), len(img_list_val))

40000 10000


In [23]:
# import glob
# c1 = "00[0-2]***"
# c2 = "0030000"
# file_list = sorted(list(set(glob.glob(args.datapath + c1 + ".png")).union(
#                         set(glob.glob(args.datapath + c2 + ".png")))), reverse=True)

train_data = ToyCell(path=args.datapath, 
                     data_info_filename=args.data_info_filename, 
                     features_names=args.features_names,
                     imgsize=(128, 128),
                     #condition="00[0]*",
                     img_list=img_list_train,
                     scaler=True
                     )
train_loader = data_utils.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, **kwargs)

  "Argument interpolation should be of type InterpolationMode instead of int. "
100%|██████████| 40000/40000 [00:50<00:00, 799.74it/s]


In [24]:
val_data = ToyCell(path=args.datapath, 
                   data_info_filename=args.data_info_filename, 
                   features_names=args.features_names,
                   imgsize=(128, 128),
                   #condition="00[0]*",
                   img_list=img_list_val,
                   scaler=True
                   )
val_loader = data_utils.DataLoader(val_data, batch_size=args.batch_size, shuffle=True, **kwargs)

  "Argument interpolation should be of type InterpolationMode instead of int. "
100%|██████████| 10000/10000 [00:12<00:00, 824.63it/s]


# ---------------------

Setup the model

In [None]:
################################# VAE
# model = VAE(args).to(device) 
# model.cuda()

In [40]:
################################# FDVAE
model = FDVAE(args).to(device)
model.cuda()

FDVAE(
  (px): px(
    (fc1): Sequential(
      (0): Linear(in_features=8, out_features=1024, bias=False)
      (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (rn1): IdResidualConvTBlockBNIdentity(
      (nonlin): LeakyReLU(negative_slope=0.01)
      (conv1): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (rn2): Upsample(size=8, mode=nearest)
    (rn3): IdResidualConvTBlockBNIdentity(
      (nonlin): LeakyReLU(negative_slope=0.01)
      (conv1): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_ru

In [None]:
############ MORE TRAINING
# model = torch.load(model_path + '/' + model_name + '_epoch_10' + '.model')#, map_location=torch.device('cpu'))
# model = model.to(device)

In [None]:
############ INITIALIZE FROM ANOTHER MODEL
# base_model_name = 'FDVAE_seed_1_zdims_222_2_v0'
# base_model_path = 'drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/fdvae/FDVAE_seed_1_zdims_222_2_v0'
# model = torch.load(base_model_path + '/' + base_model_name + '_epoch_20' + '.model')#, map_location=torch.device('cpu'))
# model = model.to(device)

In [None]:
############ INITIALIZE SPECIFIC COMPONENTS FROM ANOTHER MODEL
# model_path_old = 'drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/fdvae/old_name'
# model_name_old = 'old_name'
# model_old = torch.load(model_path_old + '/' + model_name_old + '_epoch_8' + '.model')
# old_model_dict = model_old.state_dict()
# new_model_dict = model.state_dict()
# pretrained_dict = {k: v for k, v in old_model_dict.items() if k.split(".")[0] == 'px' and k.split(".")[1] != 'fc1' and k in new_model_dict}

# #pretrained_dict.update({k: v for k, v in old_model_dict.items() if k.split(".")[0] == 'qzeps'})
# pretrained_dict.update({k: v for k, v in old_model_dict.items() if k.split(".")[0] == 'qzeps' and k.split(".")[1] not in ['fc11', 'fc12']})

# #pretrained_dict.update({k: v for k, v in old_model_dict.items() if k.split(".")[0] == 'prior_zeps_log_scale'})
# new_model_dict.update(pretrained_dict)
# model.load_state_dict(new_model_dict)

----------------------------

INITIALIZE FROM QZFREG (for FDVAE) - in case encoders/cond. priors/relative variance parameters are pretrained

In [None]:
# model_path_qzf1 = 'drive/My Drive/Colab Notebooks/GEN/FDVA_output/model_qzf/QZFreg1_seed_1_w0_bs50_zdim2_qzf3_fi0_10_1_100k_0_scaletrain'
# model_name_qzf1 = 'QZFreg1_seed_1_w0_bs50_zdim2_qzf3_fi0_10_1_100k_0_scaletrain'
# model_qzf1 = torch.load(model_path_qzf1 + '/' + model_name_qzf1 + '_epoch_77' + '.model')#, map_location=torch.device('cpu'))

# model_path_qzf2 = 'drive/My Drive/Colab Notebooks/GEN/FDVA_output/model_qzf/QZFreg1_seed_3_w0_bs50_zdim2_qzf3_fi1_10_1_100k_0_scaletrain'
# model_name_qzf2 = 'QZFreg1_seed_3_w0_bs50_zdim2_qzf3_fi1_10_1_100k_0_scaletrain'
# model_qzf2 = torch.load(model_path_qzf2 + '/' + model_name_qzf2 + '_epoch_56' + '.model')#, map_location=torch.device('cpu'))

# model_path_qzf3 = 'drive/My Drive/Colab Notebooks/GEN/FDVA_output/model_qzf/QZFreg1_seed_2_w0_bs50_zdim2_qzf3_fi2_10_1_100k_0_scaletrain'
# model_name_qzf3 = 'QZFreg1_seed_2_w0_bs50_zdim2_qzf3_fi2_10_1_100k_0_scaletrain'
# model_qzf3 = torch.load(model_path_qzf3 + '/' + model_name_qzf3 + '_epoch_67' + '.model')#, map_location=torch.device('cpu'))

# fdva_dict = model.state_dict()
# qzf1_dict = model_qzf1.state_dict()
# qzf2_dict = model_qzf2.state_dict()
# qzf3_dict = model_qzf3.state_dict()

# # model_qzf1.state_dict().keys()
# ## 'prior_log_scale'

# # model.state_dict().keys() 
# ## 'prior_zeps_log_scale'
# ## 'prior_zf_full_log_scales.0', 'prior_zf_full_log_scales.1', 'prior_zf_full_log_scales.2'

# pretrained_qzf1 = {'qzf_list.0.' + ".".join(k.split(".")[1:]): v for k, v in qzf1_dict.items() if k.split(".")[0] == 'qzf'}
# pretrained_qzf1.update({'pzf_list.0.' + ".".join(k.split(".")[1:]): v for k, v in qzf1_dict.items() if k.split(".")[0] == 'pzf'})
# pretrained_qzf1.update({'qf_list.0.' + ".".join(k.split(".")[1:]): v for k, v in qzf1_dict.items() if k.split(".")[0] == 'qf'})
# pretrained_qzf1.update({'prior_zf_full_log_scales.0': qzf1_dict['prior_log_scale']})

# pretrained_qzf2 = {'qzf_list.1.' + ".".join(k.split(".")[1:]): v for k, v in qzf2_dict.items() if k.split(".")[0] == 'qzf'}
# pretrained_qzf2.update({'pzf_list.1.' + ".".join(k.split(".")[1:]): v for k, v in qzf2_dict.items() if k.split(".")[0] == 'pzf'})
# pretrained_qzf2.update({'qf_list.1.' + ".".join(k.split(".")[1:]): v for k, v in qzf2_dict.items() if k.split(".")[0] == 'qf'})
# pretrained_qzf2.update({'prior_zf_full_log_scales.1': qzf2_dict['prior_log_scale']})

# pretrained_qzf3 = {'qzf_list.2.' + ".".join(k.split(".")[1:]): v for k, v in qzf3_dict.items() if k.split(".")[0] == 'qzf'}
# pretrained_qzf3.update({'pzf_list.2.' + ".".join(k.split(".")[1:]): v for k, v in qzf3_dict.items() if k.split(".")[0] == 'pzf'})
# pretrained_qzf3.update({'qf_list.2.' + ".".join(k.split(".")[1:]): v for k, v in qzf3_dict.items() if k.split(".")[0] == 'qf'})
# pretrained_qzf3.update({'prior_zf_full_log_scales.2': qzf3_dict['prior_log_scale']})

# fdva_dict.update(pretrained_qzf1)
# fdva_dict.update(pretrained_qzf2)
# fdva_dict.update(pretrained_qzf3)

# model.load_state_dict(fdva_dict)

In [None]:
### In case we want to FREEZE the pretrained encoders, conditional priors and relative variance parameters
# for name, child in model.named_children():
#   if name in ['pzf_list', 'qzf_list', 'qf_list', 'prior_zf_full_log_scales']:
#     for p in child.parameters():
#       p.requires_grad = False

In [None]:
# check 
# model.pzf_list[0].fc21s[0].weight.requires_grad, model.prior_zf_full_log_scales[0].requires_grad

# ---------------------

Freezing / unfreezing certain components

In [None]:
# for name, child in model.named_children():
#   if name in ['px']:#['pzf_list', 'qzf_list', 'qf_list', 'prior_zf_full_log_scales']:
#     for name1, child1 in child.named_children():
#       if name1 != 'fc1':
#         #print(name1)
#         for p in child1.parameters():
#           p.requires_grad = True
#       else:
#         for p in child1.parameters():
#           p.requires_grad = True

# for name, child in model.named_children():
#   if name in ['qzeps']:#['pzf_list', 'qzf_list', 'qf_list', 'prior_zf_full_log_scales']:
#     for name1, child1 in child.named_children():
#       if name1 not in ['fc11', 'fc12']:
#         #print(name1)
#         for p in child1.parameters():
#           p.requires_grad = True ###True (RY) ####False (RX)
#       else:
#         for p in child1.parameters():
#           p.requires_grad = True

In [None]:
# check
# model.qzeps.fc11[0].weight.requires_grad, model.px.fc1[0].weight.requires_grad, model.px.rn1.conv1.weight.requires_grad, model.qzeps.rn1.conv1.weight.requires_grad
##model.qzeps.fc11[0].weight.requires_grad, model.px.fc1[0].weight.requires_grad, model.px.rn1.conv1.weight.requires_grad

# ---------------------

Setup the optimizer

In [41]:
print(args.lr)

0.001


In [42]:
# args.lr = 0.001

In [43]:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)

best_loss = 1e+8
best_y_acc = 0.

early_stopping_counter = 1
max_early_stopping = 100

# ---------------------

In [44]:
# check the arguments
args

{'aux_loss_multipliers': [100000, 100000, 100000],
 'aux_loss_multipliers_2': [0, 0, 0],
 'batch_size': 50,
 'beta_eps': 300,
 'betas_f': [1, 1, 1],
 'betas_f_full': [1, 1, 1],
 'c': 2,
 'cuda': True,
 'data_info_filename': 'dataset2_128_fixloc_1scale.csv',
 'datapath': 'dataset2_128_fixloc_1scale/',
 'epochs': 500,
 'f_dims': [1, 1, 1],
 'features_names': ['roundness', 'elongation', 'nucleus_size'],
 'gamma': 0,
 'lr': 0.001,
 'max_beta': 1.0,
 'min_beta': 0.0,
 'no_cuda': False,
 'outpath': 'drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/fdvae',
 'prior_zeps_scale_train': False,
 'prior_zf_full_scale_train_list': [True, True, True],
 'pzf_full_prior': True,
 'qf_activations': [False, False, False],
 'seed': 1,
 'w': 10,
 'wc': 10,
 'zeps_dim': 2,
 'zf_dims': [2, 2, 2]}

In [45]:
model_name, model_path

('FDVAE_seed_1_zdims_222_2_v1',
 'drive/My Drive/Colab Notebooks/GEN/model_output_synthetic/fdvae/FDVAE_seed_1_zdims_222_2_v1')

In [31]:
# run only once when a new model is initialized
os.mkdir(model_path)

import json

with open(model_path + '/' + model_name + '.json', 'w') as configfile:
    json.dump(args, configfile, indent=2)

In [46]:
# check the number of parameters
def nparams(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

nparams(model)

2678167

# -------------------------

TRAINING VAE

In [35]:
# check
# print(model.beta_eps, args.beta_eps)

In [36]:
# import copy
# args_COPY = copy.deepcopy(dict(args))

In [None]:
######## Training loop VAE
# print('\nStart training:', args)
# for epoch in range(1, args.epochs + 1): 

#     ######################################
#     ### WARMUP (increase beta gradually).
#     ### Formula: min([final_value, curent_val + (final value - current_val) * (epoch * 1. - epochs_completed) / (args_COPY['w'] - epochs_completed)])
#     # model.beta = min([args_COPY['beta'], 10 + (args_COPY['beta'] - 10) * (epoch * 1. - 11) / (args_COPY['w'] - 11)])  # w!
#     # model.c = min([args_COPY['c'], 4 + (args_COPY['c'] - 4) * (epoch * 1. - 44) / (args_COPY['wc'] - 44)])  # wc!
#     # print('C, GAMMA: ',  model.c, model.gamma, args_COPY['c'], args_COPY['gamma'])
#     ######################################

#     # Train
#     train_loss, ce_x, kl_z = run_epoch_vae(train_loader, model, optimizer, epoch, rec_freq=1, path_reconstructions=model_path + '/', train=True)

#     # logging train scores
#     str_print = "{} EPOCH: avg loss {}".format(epoch, train_loss)
#     str_print += f" avg ce_x {ce_x}"
#     str_print += f" avg KL_z {kl_z}"
#     print(str_print) 

#     # Val
#     val_loss, val_ce_x, val_kl_z = run_epoch_vae(val_loader, model, optimizer, epoch, rec_freq=1, path_reconstructions=model_path + '/', train=False)

#     # logging val scores
#     str_print = "{} EPOCH: VAL avg loss {}".format(epoch, val_loss)
#     str_print += f" avg ce_x {val_ce_x}"
#     str_print += f" avg KL_z {val_kl_z}"
#     print(str_print) 

#     if val_loss < best_loss:
#         early_stopping_counter = 1

#         best_loss = val_loss

#         print("saving model: ", model_name + f'_epoch_{epoch}' + '.model')
#         torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
#     else:
#         #######
#         if epoch >= 1 and epoch % 10 == 0:
#             print("RESERVE saving model: ", model_name + f'_epoch_{epoch}' + '.model')
#             torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
#         #######   
#         early_stopping_counter += 1
#         if early_stopping_counter == max_early_stopping:
#             break

In [None]:
# 800it [03:48,  3.50it/s]
# 1 EPOCH: avg loss 266571.73815 avg ce_x 266494.25805 avg KL_z 77.4799271408081
# 200it [00:13, 14.95it/s]
# 1 EPOCH: VAL avg loss 212320.2766 avg ce_x 212220.141 avg KL_z 100.13557810058593
# saving model:  VAE_seed_1_zdim_8_v1_epoch_1.model

# -------------------------

TRAINING FDVAE

In [33]:
# check
# print(model.beta_eps, args.beta_eps)

In [47]:
import copy
args_COPY = copy.deepcopy(dict(args))

In [None]:
######## Training loop FDVAE
print('\nStart training:', args)
for epoch in range(1, args.epochs + 1): 

    ######################################
    ### WARMUP (increase beta gradually).
    ### Formula: min([final_value, curent_val + (final value - current_val) * (epoch * 1. - epochs_completed) / (args_COPY['w'] - epochs_completed)])
    # model.beta_eps = min([args_COPY['beta_eps'], 10 + (args_COPY['beta_eps'] - 10) * (epoch * 1. - 11) / (args_COPY['w'] - 11)])  # w!
    # model.c = min([args_COPY['c'], 4 + (args_COPY['c'] - 4) * (epoch * 1. - 44) / (args_COPY['wc'] - 44)])  # wc!
    # print('C, GAMMA: ',  model.c, model.gamma, args_COPY['c'], args_COPY['gamma'])
    # for k in range(model.num_features):
    #     model.betas_f[k] = min([args_COPY['betas_f'][k], 0 + (args_COPY['betas_f'][k] - 0) * (epoch * 1. - 0) / (args_COPY['w'] - 0)])
    # print('BETAS: ',  model.beta_eps, model.betas_f, model.betas_f_full, args_COPY['beta_eps'], args_COPY['betas_f'], args_COPY['betas_f_full'])
    # print('ALPHAS: ',  model.aux_loss_multipliers, model.aux_loss_multipliers_2, args_COPY['aux_loss_multipliers'], args_COPY['aux_loss_multipliers_2'])
    ######################################

    # Train
    train_loss, ce_x, reg_qzf_loss, reg_pzf_loss, \
          kl_eps, kls_zfk, kls_zfk_full = run_epoch(train_loader, model, optimizer, epoch, rec_freq=1, 
                                                    path_reconstructions=model_path + '/', train=True)
    # logging train scores
    str_print = "{} EPOCH: avg loss {}".format(epoch, train_loss)
    str_print += f" avg ce_x {ce_x}"
    str_print += f" avg KL_eps {kl_eps}"
    for k in range(model.num_features):
        str_print += f", \n{train_data.features_names[k]} feature loss qzf {reg_qzf_loss[k]}" 
        str_print += f", {train_data.features_names[k]} feature loss pzf {reg_pzf_loss[k]}"       
        str_print += f", {train_data.features_names[k]} KL {kls_zfk[k]}"
        str_print += f", {train_data.features_names[k]} KL full {kls_zfk_full[k]}"
    print(str_print) 

    # Val
    val_loss, val_ce_x, val_reg_qzf_loss, val_reg_pzf_loss, \
          val_kl_eps, val_kls_zfk, val_kls_zfk_full = run_epoch(val_loader, model, optimizer, epoch, rec_freq=1, 
                                                                path_reconstructions=model_path + '/', train=False)

    # logging val scores
    str_print = "{} EPOCH: VAL avg loss {}".format(epoch, val_loss)
    str_print += f" avg ce_x {val_ce_x}"
    str_print += f" avg KL_eps {val_kl_eps}"
    for k in range(model.num_features):
        str_print += f", \n{train_data.features_names[k]} feature loss qzf {val_reg_qzf_loss[k]}"
        str_print += f", {train_data.features_names[k]} feature loss pzf {val_reg_pzf_loss[k]}"       
        str_print += f", {train_data.features_names[k]} KL {val_kls_zfk[k]}"
        str_print += f", {train_data.features_names[k]} KL full {val_kls_zfk_full[k]}"
    print(str_print) 

    if val_loss < best_loss:
        early_stopping_counter = 1

        best_loss = val_loss

        print("saving model: ", model_name + f'_epoch_{epoch}' + '.model')
        torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
    else:
        #######
        if epoch >= 1 and epoch % 10 == 0:
            print("RESERVE saving model: ", model_name + f'_epoch_{epoch}' + '.model')
            torch.save(model, model_path + '/' + model_name + f'_epoch_{epoch}' + '.model')
        #######   
        early_stopping_counter += 1
        if early_stopping_counter == max_early_stopping:
            break

In [None]:
# 800it [06:27,  2.06it/s]
# 1 EPOCH: avg loss 245422.727475 avg ce_x 243660.99025 avg KL_eps 3.6915118152141573, 
# roundness feature loss qzf 0.002092827259283513, roundness feature loss pzf 0.054084887577593325, roundness KL 10.052883821487427, roundness KL full 0.6170354379117489, 
# elongation feature loss qzf 0.0028189792375080285, elongation feature loss pzf 0.05542313022315502, elongation KL 14.623862271881103, elongation KL full 2.246926694583893, 
# nucleus_size feature loss qzf 0.0012305661214282737, nucleus_size feature loss pzf 0.02648472283035517, nucleus_size KL 11.681087564086914, nucleus_size KL full 0.8255669524908066
# 200it [00:24,  8.27it/s]
# 1 EPOCH: VAL avg loss 286031.1892 avg ce_x 285150.3968 avg KL_eps 1.3013192699432372, 
# roundness feature loss qzf 0.0017515918269753457, roundness feature loss pzf 0.03549718990325928, roundness KL 10.45258600769043, roundness KL full 0.22406841302514077, 
# elongation feature loss qzf 0.0019195088326931, elongation feature loss pzf 0.041211951303482056, elongation KL 12.487350659179688, elongation KL full 1.5249429347991943, 
# nucleus_size feature loss qzf 0.000878522290661931, nucleus_size feature loss pzf 0.014190089377760886, nucleus_size KL 10.199532424926758, nucleus_size KL full 0.5483263230323792
# saving model:  FDVAE_seed_1_zdims_222_2_v1_epoch_1.model