In [None]:
# Render matplotlib figures inline, and automatically reload edited project
# modules (the `src` package imported below) before each cell executes.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import os
import sys
import numpy as np
import torch
import torch.optim as optim
from matplotlib import pyplot as plt
import pandas as pd
from tqdm import tqdm_notebook, tnrange

In [None]:
# Make the project's `src` package importable.
# The repository root can be overridden with the TOOL_PRESENCE_ROOT environment
# variable so the notebook runs on machines other than the author's. The
# default preserves the original hard-coded location.
module_path = os.path.abspath(
    os.environ.get('TOOL_PRESENCE_ROOT', '/users/dli44/tool-presence'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src import constants as c
from src import utils
from src import visualization as v
from src import model as m

In [None]:
# Build the experiment configuration with utils.setup_argparse(), passing the
# flags inline rather than reading them from sys.argv.
# --root reuses `module_path` from the sys.path cell. os.path.join(..., '')
# appends the trailing separator the original literal had. The repository
# location is therefore defined once instead of in two hard-coded copies.
parser = utils.setup_argparse()
args = parser.parse_args(args=['--root=' + os.path.join(module_path, ''),
                               '--data-dir=data/larynx_data/',
                               '--image-size=64',
                               '--loss-function=mmd',
                               '--z-dim=10',
                               '--betas=1'
                              ])

datasets, dataloaders = utils.setup_data(args)

# Convert string options into the objects the model/training cells consume:
# - the loss name becomes the loss callable returned by
#   utils.select_loss_function;
# - the comma-separated --z-dim / --betas values become lists of numbers.
#   The model and training cells use only the first entry of each.
args.loss_function = utils.select_loss_function(args.loss_function)
args.z_dim = [int(x) for x in args.z_dim.split(',')]
args.betas = [float(x) for x in args.betas.split(',')]

In [None]:
# Construct the VAE for the configured image geometry.
# The two hidden-layer widths (h_dim1 / h_dim2) are fixed here. The latent
# size is the first value parsed from --z-dim. The model is then moved to the
# project's configured device.
vae_kwargs = dict(
    image_channels=args.image_channels,
    image_size=args.image_size,
    h_dim1=1024,
    h_dim2=128,
    zdim=args.z_dim[0],
)
model = m.VAE(**vae_kwargs)
model = model.to(c.device)

In [None]:
# Train the VAE.
# Re-running this cell continues from the model's current weights (the model
# is built in the previous cell) with a fresh optimizer and a fresh loss
# history.
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

# Per-epoch sums of each loss term. `losses` was previously never defined in
# this notebook, so the first epoch failed with a NameError on a fresh kernel.
losses = {'total': [], 'kl': [], 'rl': []}
n_train = len(dataloaders['train'].dataset)

tbar = tnrange(args.epochs)
for epoch in tbar:
    # --- Training ---
    model.train()
    train_loss, kl, rl = 0, 0, 0
    # leave=False clears each epoch's inner bar when it finishes, instead of
    # stacking one completed bar per epoch in the cell output.
    t2 = tqdm_notebook(dataloaders['train'], leave=False)
    for batch_idx, (data, _) in enumerate(t2):
        data = data.to(c.device)
        optimizer.zero_grad()
        recon_batch, z, mu, logvar = model(data)

        loss_params = {'recon': recon_batch,
                       'x': data,
                       'z': z,
                       'mu': mu,
                       'logvar': logvar,
                       # Use the actual size of this batch. The final batch can
                       # be smaller than args.batch_size unless the loader
                       # drops incomplete batches.
                       'batch_size': data.size(0),
                       'input_size': args.image_size,
                       'zdim': args.z_dim[0],
                       'beta': args.betas[0]}

        # Unpacked as (total, reconstruction term, divergence term), matching
        # how r / k are labelled in the progress bars. With
        # --loss-function=mmd the "KL" term is presumably the MMD penalty;
        # confirm in utils.select_loss_function.
        loss, r, k = args.loss_function(**loss_params)
        loss.backward()

        train_loss += loss.item()
        kl += k.item()
        rl += r.item()

        optimizer.step()

        t2.set_postfix(
            {"Reconstruction Loss": r.item(),
             "KL Divergence": k.item()})

    losses['total'].append(train_loss)
    losses['kl'].append(kl)
    losses['rl'].append(rl)

    # Per-sample averages. This assumes the loss function returns sums over
    # the batch; if it returns batch means, divide by len(dataloader) instead.
    # TODO confirm.
    tbar.set_postfix({"KL Divergence": kl / n_train,
                      "Reconstruction Loss": rl / n_train})