In [None]:
# libraries
import numpy as np
import pandas as pd 
import torch
import random
import lightning as L
from utils import GWSDatasetFromPandasVAE, RiboDatasetGWS, train_vae
from pytorch_lightning.loggers import WandbLogger

In [None]:
# model parameters
# Configurations used previously (kept as a record only; the live values are below):
#   AT=0.0, LZT=1e+10, PNT=1.0,  LD=512, BS=128, split='6h'
#   AT=0.5, LZT=20,    PNT=0.05, LD=512, BS=256, split='6h'

# data selection / filtering (passed to RiboDatasetGWS in the data cell)
annot_thresh = 0.3          # `threshold` argument of RiboDatasetGWS
longZerosThresh_val = 20    # `longZerosThresh` argument of RiboDatasetGWS
percNansThresh_val = 0.05   # `percNansThresh` argument of RiboDatasetGWS
dataset_split = '6h'        # either 'ALL' / 'GOOD' / '6h'; selects the processed data folder

# model architecture (passed to train_vae)
latent_dims = 32
hidden_dims = 128
num_layers = 1
bidirectional = True

# optimisation
batch_size_val = 8          # used both for the DataLoaders and as train_vae's `bs`
lr = 1e-4
total_epochs = 100

In [None]:
# reproducibility: seed every RNG source the pipeline may touch
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
L.seed_everything(SEED)

# dataset paths: one processed folder per supported dataset_split value
DATA_ROOT = '/net/lts2gdk0/mnt/scratch/lts2/nallapar/rb-prof/data/Darnell_Full/Darnell/'
SPLIT_SUBFOLDERS = {
    'ALL': 'data_conds_split',
    'GOOD': 'data_conds_split_good',
    '6h': 'data_conds_split_6h',
}
# Fail fast on an unsupported value instead of a NameError on `data_folder` further down.
if dataset_split not in SPLIT_SUBFOLDERS:
    raise ValueError(
        f"dataset_split must be one of {list(SPLIT_SUBFOLDERS)}, got {dataset_split!r}"
    )
data_folder = DATA_ROOT + SPLIT_SUBFOLDERS[dataset_split] + '/processed/'

# model name and output folder path
# NOTE(review): lr and bidirectional are not encoded in the name, so runs that differ only in
# those settings share the same output_loc and wandb run name. Consider adding them
# (this would change the saved-model paths other code may rely on).
model_name = (
    f'VAE-{dataset_split}-LD-{latent_dims}-BS-{batch_size_val}-AT-{annot_thresh}'
    f'-LZT-{longZerosThresh_val}-PNT-{percNansThresh_val}-HD-{hidden_dims}-NL-{num_layers}'
)
output_loc = "saved_models/" + model_name

# start a new wandb run to track this script
wandb_logger = WandbLogger(log_model="all", project="vae", name=model_name)

# generate dataset
ds = 'ALL'  # uses all the three conditions
# RiboDatasetGWS returns the (train, test) splits as pandas dataframes
train_df, test_df = RiboDatasetGWS(
    data_folder,
    dataset_split,
    ds,
    threshold=annot_thresh,
    longZerosThresh=longZerosThresh_val,
    percNansThresh=percNansThresh_val,
)

# convert pandas dataframes into torch datasets
train_dataset = GWSDatasetFromPandasVAE(train_df)
test_dataset = GWSDatasetFromPandasVAE(test_df)
print("samples in train dataset: ", len(train_dataset))
print("samples in test dataset: ", len(test_dataset))

# convert datasets into dataloaders; only the training set is shuffled
NUM_WORKERS = 4
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size_val, shuffle=True, num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_val, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# train model: hyperparameters come from the config cell, loaders and logger from the data cell
train_vae(
    latent_dims,
    output_loc,
    train_loader,
    test_loader,
    num_epochs=total_epochs,
    bs=batch_size_val,
    wandb_logger=wandb_logger,
    hidden_dims=hidden_dims,
    num_layers=num_layers,
    bidirectional=bidirectional,
    lr=lr,
)