In [1]:
import os
from itertools import chain

import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# Star imports keep their original relative order, and RandomRotation stays
# after them, so any name shadowing behaves exactly as before.
from datasets import *
from vae import *
from losses import *
from torchvision.transforms import RandomRotation
# Seed once, before any dataset, augmentation or model is created.
# set_random_seed comes from a project module via star import; presumably it
# seeds torch / numpy / random — TODO confirm in its definition.
SEED = 7
set_random_seed(SEED)

In [2]:

# get_only_r_oids (project helper) returns (oids, targets); only the oids are
# used below to build the frame dataset.
oids, targets = get_only_r_oids('akb.ztf.snad.space.json')

# Each frame is augmented on the fly with a random rotation in [-20°, 20°].
frames_dataset = AllFramesDataset(oids, transform=RandomRotation(degrees=20))

# Cap the worker count at the number of CPUs: requesting more worker processes
# than cores only oversubscribes the machine (DataLoader warns about it).
# os.cpu_count() may return None, hence the fallback to 1.
NUM_WORKERS = min(32, os.cpu_count() or 1)
train_loader = DataLoader(
    frames_dataset, batch_size=256, shuffle=True, num_workers=NUM_WORKERS
)

In [3]:
# Size of the latent space. The encoder is built with twice this output size —
# presumably one mean and one log-variance per latent dimension (TODO confirm
# against VAEEncoder in vae.py).
latent_dim = 36
learning_rate = 5e-5

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build each network and move it to the chosen device in one step.
encoder = VAEEncoder(latent_dim=latent_dim * 2).to(device)
decoder = Decoder(latent_dim=latent_dim).to(device)

# One Adam optimizer updates the encoder and decoder parameters jointly.
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *decoder.parameters()],
    lr=learning_rate,
)

In [None]:
# Train for n_epochs epochs (numbered from 1) and keep whatever train_vae
# returns for each epoch in `losses`; the save cell writes it to disk.
n_epochs = 100
losses = []
for epoch in tqdm(range(1, n_epochs + 1)):
    epoch_loss = train_vae(
        enc=encoder,
        dec=decoder,
        optimizer=optimizer,
        loader=train_loader,
        epoch=epoch,
        single_pass_handler=vae_pass_handler,
        loss_handler=vae_loss_handler,
        device=device,
    )
    losses.append(epoch_loss)



  0%|                                                   | 0/100 [00:00<?, ?it/s]

In [None]:
# Persist the trained weights and the per-epoch loss history.
# torch.save / np.save raise if the target directory is missing, which would
# throw away a finished training run, so create it up front.
os.makedirs('trained_models/vae', exist_ok=True)

torch.save(encoder.state_dict(), 'trained_models/vae/encoder_aug.zip')
torch.save(decoder.state_dict(), 'trained_models/vae/decoder_aug.zip')

# NOTE(review): np.save needs `losses` to be array-convertible. If train_vae
# returns tensors (CUDA tensors in particular), convert them with .item()
# before saving — confirm against train_vae's return value.
np.save('trained_models/vae/loss_aug.npy', losses)