In [1]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import make_grid

In [2]:
# Make the repository root (the parent of the notebook's working directory)
# importable so the `core` package imported below can be found.
# The membership check keeps this cell idempotent: re-running it does not
# keep appending duplicate entries to sys.path.
base_dir = os.path.dirname(os.getcwd())
if base_dir not in sys.path:
    sys.path.append(base_dir)

from core.dataloader import CelebALoader
from core.models import VAE, Discriminator
from core.models import modules
from core.engine import ConfigFile, NCTrainer
# Dataset root, relative to the current working directory. Like base_dir above,
# this resolves against the parent of the working directory, i.e. to <base_dir>/data,
# so the notebook is expected to be launched from a sub-folder of the repository.
data_dir = "../data/"



# ./Session

In [3]:
# Session configuration object; the following cells register the dataloader,
# model, criterion, optimizer and extra kwargs on it, and config.dump() is called at the end.
# NOTE(review): the meaning of the empty-string argument (path? name?) is not
# visible here — confirm against core.engine.ConfigFile.
config = ConfigFile("")

In [4]:
# Name the session, then let ConfigFile prepare it.
# NOTE(review): what setup_session() creates (presumably a session directory
# keyed by this name) is defined in core.engine.ConfigFile — confirm there.
config.set_session_name("sandbox_session")
config.setup_session()




# ./Dataloader

In [5]:
# CelebA loader for the session.
# Both splits are resized to 256x256: the resolution the mask generator (size),
# the VAE (input_size) and the discriminator (input_size) are configured for in
# the cells below. The validation transform must resize as well; otherwise
# validation images keep CelebA's native resolution (178x218 for the aligned
# images) and would not match the model input.
# validation_split=0.3 presumably holds out 30% of the training split for
# validation — confirm in core.dataloader.CelebALoader.
dataloader = CelebALoader(data_dir=data_dir,
                          batch_size=8,
                          train_transform=transforms.Compose([transforms.Resize((256, 256)),
                                                              transforms.ToTensor()]),
                          val_transform=transforms.Compose([transforms.Resize((256, 256)),
                                                            transforms.ToTensor()]),
                          validation_split=0.3)
config.set_dataloader(dataloader)

In [6]:
# Inspect the loader. In the rendered output, the dataset summary lines report
# the split being loaded; the CelebALoader object itself shows only the default
# object repr.
dataloader

Dataset CelebA
    Number of datapoints: 162770
    Root location: ../data/
    Target type: ['attr']
    Split: train
<core.dataloader.celeba.CelebALoader object at 0x10fad4048>

__Mask generator:__

In [7]:
# Mask-generator settings, stored on the config under the `masks` keyword.
# `size` matches the 256x256 resolution used by the transforms, VAE and discriminator.
# NOTE(review): `coverage` is presumably the (min, max) fraction of the image
# that a mask covers — confirm in the mask generator implementation.
masks_kwargs = dict(size=(256, 256),
                    coverage=(0.1, 0.5))

config.update_kwargs(masks=masks_kwargs)

# ./Model

__VAE:__

In [8]:
# Generator: a VAE with latent size z_dim=32, taking 6-channel 256x256 inputs
# and producing a 3-channel output (out_channels=3).
# NOTE(review): the 6 input channels are presumably image + mask stacked
# together — confirm against the trainer in core.engine.
vae = VAE(
    input_size=(6, 256, 256),
    z_dim=32,
    enc_nf=[32, 64],                      # presumably filters per encoder layer
    dec_nf=[256, 128, 128, 128, 64, 64],  # presumably filters per decoder layer
    enc_kwargs={'padding': 1},
    out_channels=3,
    out_kwargs={'output_padding': 1},
)
config.set_model(vae)

__Discriminator:__

In [9]:
# The discriminator is not instantiated in this notebook; only its settings
# (presumably constructor kwargs) are stored on the config under `discriminator`.
# Its 3-channel 256x256 input matches the VAE's out_channels=3 at 256x256.
disc_kwargs = dict(input_size=(3, 256, 256),
                   nb_filters=[32, 64])
config.update_kwargs(discriminator=disc_kwargs)

# ./Training params

__Criterion:__

In [10]:
# Binary cross-entropy averaged over elements ('mean' is BCELoss's default reduction).
# NOTE(review): BCELoss expects probabilities in [0, 1], so whatever is fed to it
# must already be sigmoid-activated — confirm in core.models.
criterion = nn.BCELoss(reduction='mean')
config.set_criterion(criterion)

__Optimizers:__

In [11]:
# The discriminator exists here only as kwargs (see disc_kwargs), so its optimizer
# is likewise given as settings (lr, weight_decay); presumably the trainer builds
# it once the discriminator is instantiated. The VAE already exists, so the
# generator's Adam optimizer is constructed directly over its parameters.
disc_optimizer = dict(lr=1e-3, weight_decay=1e-6)
gen_optimizer = torch.optim.Adam(params=vae.parameters(), lr=1e-3, weight_decay=1e-6)

config.set_optimizer(gen_optimizer)
config.update_kwargs(disc_optimizer=disc_optimizer)

__Metrics:__

In [12]:
# Metrics: none configured yet — this cell is a placeholder.
# TODO: register evaluation metrics on `config` (e.g. via config.update_kwargs,
# as done for masks and the discriminator) once they are chosen.
pass

In [None]:
# Dump the assembled configuration (dataloader, model, criterion, generator
# optimizer and the extra kwargs registered above).
# NOTE(review): destination and format are decided by ConfigFile.dump —
# presumably the session prepared by setup_session(); confirm in core.engine.
config.dump()