In [1]:
import os
import sys
from itertools import chain

import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# Make the repository root importable so the project-local RB_ZTF package resolves
# (the path is relative to the notebook's working directory).
sys.path.insert(0, '../..')
from RB_ZTF.scripts.datasets import *
from RB_ZTF.scripts.vae import *
from RB_ZTF.scripts.losses import *

# Keep this after the wildcard imports: a later import wins, so a same-named
# export from RB_ZTF cannot shadow it.
from torchvision.transforms import RandomRotation
# Seed the random number generators before any dataset or model object is created.
# NOTE(review): `set_random_seed` comes from one of the RB_ZTF wildcard imports;
# presumably it seeds torch / numpy / random — TODO confirm which generators.
set_random_seed(7)

In [2]:
# Per-experiment settings, keyed by the name chosen below:
#   latent_dim -- size of the VAE latent space
#   transform  -- optional augmentation applied to training frames (None = no augmentation)
models_args = dict(
    baseline=dict(latent_dim=36, transform=None),
    aug=dict(latent_dim=36, transform=RandomRotation(degrees=20)),
    ld78=dict(latent_dim=78, transform=None),
)

In [3]:
# Interactively choose which experiment from `models_args` to train.
# Surrounding whitespace and letter case in the answer are ignored. An unknown name
# fails immediately with the list of valid choices instead of a bare KeyError.
name = input('Choose model to train (baseline/ld78/aug):  ').strip().lower()
if name not in models_args:
    raise ValueError(f'Unknown model {name!r}; choose one of: {", ".join(models_args)}')
args = models_args[name]

Choose model to train (baseline/ld78/aug):   aug


**NOTE:** To train several models, restart the Jupyter kernel before training each new model. Every model then starts from the same freshly seeded random state (set by `set_random_seed(7)` in the first cell), which is necessary for the final results to be fully consistent.

In [4]:
# Build the training set from the object IDs listed in the akb.ztf.snad.space export.
# Only `oids` is used below; `targets` is returned by the helper but is not needed here.
oids, targets = get_only_r_oids('../akb.ztf.snad.space.json')

frames_dataset = AllFramesDataset(oids, transform=args['transform'])

BATCH_SIZE = 256
# 32 workers was the original setting. It is capped at the CPU count so smaller
# machines are not oversubscribed (DataLoader warns when num_workers exceeds the
# suggested maximum). os.cpu_count() may return None, hence the fallback to 1.
# Note: with random augmentation, the exact draws depend on the number of workers.
NUM_WORKERS = min(32, os.cpu_count() or 1)
train_loader = DataLoader(
    frames_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [5]:
# Hyperparameters for the selected experiment.
latent_dim = args['latent_dim']
learning_rate = 5e-5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The encoder is built with twice the latent size — presumably one half for the
# posterior means and one for the (log-)variances; TODO confirm in VAEEncoder.
# Encoder is constructed before decoder, so seeded weight initialisation consumes
# the RNG in the same order as before.
encoder = VAEEncoder(latent_dim=latent_dim * 2).to(device)
decoder = Decoder(latent_dim=latent_dim).to(device)

# One Adam optimizer updates the encoder and decoder parameters jointly.
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *decoder.parameters()],
    lr=learning_rate,
)

In [None]:
# Train for N_EPOCHS epochs. `train_vae` is called once per epoch (epoch numbers
# start at 1) and each return value is collected in `losses`.
N_EPOCHS = 100

losses = []
for epoch in range(1, N_EPOCHS + 1):
    epoch_result = train_vae(
        enc=encoder,
        dec=decoder,
        optimizer=optimizer,
        loader=train_loader,
        epoch=epoch,
        single_pass_handler=vae_pass_handler,
        loss_handler=vae_loss_handler,
        device=device,
    )
    # Append after every epoch (rather than building the list in one expression)
    # so completed epochs stay in `losses` if training is interrupted.
    losses.append(epoch_result)

In [None]:
# Save the trained weights and the collected losses.
# The baseline model keeps unsuffixed file names; every other experiment gets a
# `_<name>` suffix (e.g. encoder_aug.zip).
SAVE_DIR = '../trained_models/vae'
os.makedirs(SAVE_DIR, exist_ok=True)  # torch.save / np.save do not create missing directories

suffix = '' if name == 'baseline' else f'_{name}'
torch.save(encoder.state_dict(), f'{SAVE_DIR}/encoder{suffix}.zip')
torch.save(decoder.state_dict(), f'{SAVE_DIR}/decoder{suffix}.zip')
# NOTE(review): np.save converts `losses` to an array; this assumes train_vae
# returns plain numbers rather than (CUDA) tensors — TODO confirm.
np.save(f'{SAVE_DIR}/loss{suffix}.npy', losses)