In [1]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms

In [2]:
# Make the project root (parent of the notebook's working directory) importable
# so the `core` package below resolves. The membership check keeps re-runs of
# this cell from piling duplicate entries onto sys.path.
base_dir = os.path.dirname(os.getcwd())
if base_dir not in sys.path:
    sys.path.append(base_dir)

from core.dataloader import CelebALoader
from core.models import VAE, Discriminator
from core.models import modules
from core.engine import ConfigFile, NCTrainer
data_dir = "../data/"

# Fix RNG seeds so weight initialisation is reproducible under Restart & Run All.
# NOTE(review): presumably also makes CelebALoader's validation split
# deterministic, if it draws from numpy/torch RNGs — confirm.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# ./Session

In [3]:
# Create the configuration object that the following cells populate
# (session, dataloader, model, criterion, optimizer, epochs) before dump().
# NOTE(review): ConfigFile receives an empty string; what "" means here
# (no file path? default location?) is not visible from this notebook — confirm.
config = ConfigFile("")

In [4]:
# Name this experiment session, then let the config prepare it.
# NOTE(review): the `tree` listing at the end of the notebook shows a
# `sandbox_session/` directory with chkpt/, runs/ and scores/ subfolders,
# presumably created by setup_session() — confirm.
session_name = "sandbox_session"
config.set_session_name(session_name)
config.setup_session()




# ./Dataloader

In [7]:
def make_image_transform(size=(256, 256)):
    """Return a fresh pipeline that resizes an image to `size` and converts it to a tensor."""
    return transforms.Compose([transforms.Resize(size), transforms.ToTensor()])


# Train and validation splits get identical preprocessing. Each split receives
# its own Compose instance, so nothing done to one transform can leak into
# the other.
# NOTE(review): validation_split=0.3 presumably holds out 30% of the images
# for validation — confirm against CelebALoader.
dataloader = CelebALoader(
    data_dir=data_dir,
    batch_size=32,
    train_transform=make_image_transform(),
    val_transform=make_image_transform(),
    validation_split=0.3,
)
config.set_dataloader(dataloader)

__Mask Generator:__

In [8]:
# Mask-generator settings, registered on the config under the `masks` key.
# NOTE(review): `size` presumably has to match the 256x256 resize applied by
# the dataloader, and `coverage` looks like a (min, max) fraction of the image
# to mask — confirm both against the mask generator.
masks_kwargs = dict(
    size=(256, 256),
    coverage=(0.1, 0.5),
)
config.update_kwargs(masks=masks_kwargs)

# ./Model

__VAE:__

In [17]:
# Per-layer conv kwargs for the decoder: three layers with kernel 4, one with
# kernel 5 (four entries, matching the four entries of dec_nf).
# Built with a comprehension so every layer owns its own dict: `3 * [{...}]`
# would put the *same* dict object in three slots, and any in-place change the
# model makes to one layer's kwargs would silently affect the other two.
dec_layer_kwargs = [{'kernel_size': 4, 'padding': 1} for _ in range(3)]
dec_layer_kwargs.append({'kernel_size': 5, 'padding': 1})

# NOTE(review): 6 input channels vs 3 output channels — presumably the RGB
# image stacked with a 3-channel conditioning input; confirm against VAE.
vae = VAE(input_size=(6, 256, 256),
          z_dim=32,
          enc_nf=[32, 64],
          dec_nf=[256, 128, 64, 64],
          enc_kwargs={'padding': 1},
          dec_kwargs=dec_layer_kwargs,
          out_channels=3)
config.set_model(vae)
# Re-running this cell builds a new model: re-run the optimizer cell afterwards,
# otherwise gen_optimizer keeps updating the previous model's parameters.

__Discriminator:__

In [10]:
# The discriminator is not instantiated in this notebook; only its settings
# are registered on the config under the `discriminator` key.
# Its 3 input channels must agree with the VAE's out_channels=3.
disc_kwargs = dict(
    input_size=(3, 256, 256),
    nb_filters=[32, 64],
)
config.update_kwargs(discriminator=disc_kwargs)

# ./Training params

__Criterion:__

In [11]:
# Binary cross-entropy, averaged over the batch ('mean' is BCELoss's default).
# NOTE(review): nn.BCELoss expects probabilities in [0, 1] (i.e. after a
# sigmoid); if the discriminator returns raw logits, nn.BCEWithLogitsLoss is
# the right choice — confirm against core.models.Discriminator.
criterion = nn.BCELoss(reduction='mean')
config.set_criterion(criterion)

__Optimizers:__

In [12]:
# Adam settings shared by the generator (VAE) and the discriminator.
adam_hparams = {'lr': 1e-3, 'weight_decay': 1e-6}

# The generator optimizer is bound to the parameters of the current `vae`
# object. Re-run this cell whenever the VAE cell is re-run; otherwise it keeps
# updating the previously built model.
gen_optimizer = torch.optim.Adam(vae.parameters(), **adam_hparams)

# Only keyword arguments are registered for the discriminator's optimizer
# (a plain dict, not an optimizer instance).
# NOTE(review): presumably the trainer builds it once the Discriminator exists.
disc_optimizer = dict(adam_hparams)

config.set_optimizer(gen_optimizer)
config.update_kwargs(disc_optimizer=disc_optimizer)

__Metrics:__

In [13]:
# TODO: no metrics are configured yet; this cell is only a placeholder.
pass

__Training scope:__

In [14]:
# Length of the training run, in epochs.
N_EPOCHS = 128
epoch = N_EPOCHS  # original name kept in case other cells read it
config.set_epochs(epoch)

In [18]:
# Finalise the assembled configuration via ConfigFile.dump().
# NOTE(review): presumably serialises the settings into the session directory
# created by setup_session() — confirm what dump() writes and where.
config.dump()

In [16]:
!tree -d /media/raid/shahine/neural_conditioner/

/media/raid/shahine/neural_conditioner/
└── sandbox_session
    ├── chkpt
    ├── runs
    └── scores

4 directories
