In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.utils import save_image
from datetime import datetime

In [2]:
import logging

# File name of the per-epoch loss log; Trainer joins it onto its save directory.
TRAIN_LOSSES_LOGFILE = "train_losses.log"

class LossesLogger(object):
    """Write the per-epoch mean of every tracked loss to a CSV-style text file.

    The file starts with the header ``Epoch,Loss,Value`` and receives one
    ``epoch,loss_name,mean_value`` row per loss on every call to :meth:`log`.

    Parameters
    ----------
    file_path_name : str
        Path of the log file. An existing file at that path is deleted first.
    """

    def __init__(self, file_path_name):
        self.logger = logging.getLogger("losses_logger")
        self.logger.setLevel(1)  # 1 is below logging.DEBUG (10): emit every record

        # logging.getLogger returns the same object for the same name, so a
        # second LossesLogger (e.g. after re-running a cell) would otherwise
        # stack another FileHandler on top of the old one and duplicate rows.
        # Close the old handlers *before* deleting the file they may hold open.
        for stale_handler in list(self.logger.handlers):
            self.logger.removeHandler(stale_handler)
            stale_handler.close()

        if os.path.isfile(file_path_name):
            os.remove(file_path_name)

        file_handler = logging.FileHandler(file_path_name)
        file_handler.setLevel(1)
        self.logger.addHandler(file_handler)

        header = ",".join(["Epoch", "Loss", "Value"])
        self.logger.debug(header)

    def log(self, epoch, losses_storer):
        """Append one row per loss with the mean of its recorded values.

        Parameters
        ----------
        epoch : int
            Epoch identifier written in the first column.
        losses_storer : dict
            Maps a loss name to the list of values recorded for it.
            Entries with an empty list are skipped (their mean is undefined).
        """
        for loss_name, values in losses_storer.items():
            if len(values) == 0:
                continue
            mean_value = sum(values) / len(values)
            log_string = ",".join(str(item) for item in [epoch, loss_name, mean_value])
            self.logger.debug(log_string)

In [3]:
from timeit import default_timer
from collections import defaultdict
from tqdm import trange

class Trainer():
    """Run the train / evaluate / sample loop for a VAE-style model.

    Parameters
    ----------
    model : torch.nn.Module
        Model to optimise. ``model(data)`` must return
        ``(recon_batch, latent_dist, latent_sample)``. The model must also
        expose ``latent_dim`` and ``decoder``, which are used to draw samples.
    optimizer : torch.optim.Optimizer
        Optimizer that updates the parameters of ``model``.
    loss_f : callable
        Called as ``loss_f(data, recon_batch, latent_dist, is_training, storer,
        latent_sample=latent_sample)``; must return a scalar tensor.
    device : torch.device
        Device on which the model and every batch are placed.
    logger : logging.Logger, optional
        Logger for progress messages.
    save_dir : str, optional
        Directory receiving the loss log and the ``samples`` sub-directory.
        It is created if it does not exist.
    is_progress_bar : bool, optional
        Whether to display a progress bar for every epoch.
    """

    # Size of the sample grid written after every epoch and the image shape
    # the decoder output is reshaped to (values kept from the original code).
    _N_SAMPLES = 64
    _SAMPLE_SHAPE = (1, 32, 32)

    def __init__(self, model, optimizer, loss_f, device, logger=logging.getLogger(__name__),
                 save_dir="results", is_progress_bar=True):
        self.device = device
        self.model = model.to(device)
        self.loss_f = loss_f
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.logger = logger
        # logging.FileHandler (used by LossesLogger) fails if the directory is missing.
        os.makedirs(self.save_dir, exist_ok=True)
        self.losses_logger = LossesLogger(os.path.join(self.save_dir, TRAIN_LOSSES_LOGFILE))
        self.logger.info("Training Device: {}".format(self.device))

    def __call__(self, data_loader, epochs=10, checkpoint_every=10):
        """Train for ``epochs`` epochs, saving a grid of decoder samples after each.

        Parameters
        ----------
        data_loader : iterable
            Yields ``(data, label)`` batches and supports ``len()``. The same
            loader is used for the training pass and the evaluation pass.
        epochs : int, optional
            Number of epochs to run.
        checkpoint_every : int, optional
            Accepted for interface compatibility; currently unused.
        """
        start = default_timer()
        samples_dir = os.path.join(self.save_dir, "samples")
        os.makedirs(samples_dir, exist_ok=True)

        for epoch in range(epochs):
            storer = defaultdict(list)
            self.model.train()
            train_loss = self._train_epoch(data_loader, storer, epoch)
            test_loss = self._test_epoch(data_loader, storer, epoch)
            self.logger.info("Epoch: {} Train loss: {:.4f} Test loss: {:.4f}".format(
                epoch + 1, train_loss, test_loss))
            self._save_samples(os.path.join(samples_dir, str(epoch) + '.png'))

        elapsed_minutes = (default_timer() - start) / 60
        self.logger.info("Finished training after {:.1f} min.".format(elapsed_minutes))

    def _save_samples(self, file_path):
        """Decode random latent vectors and write them to ``file_path`` as an image grid."""
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                sample = torch.randn(self._N_SAMPLES, self.model.latent_dim).to(self.device)
                sample = self.model.decoder(sample).cpu()  # save_image needs CPU data
            save_image(sample.view(self._N_SAMPLES, *self._SAMPLE_SHAPE), file_path)
        finally:
            # Leave the model in the mode the caller had it in.
            self.model.train(was_training)

    def _run_epoch(self, data_loader, storer, epoch, iteration_fn):
        """Apply ``iteration_fn`` to every batch and return the mean of its losses."""
        epoch_loss = 0.
        kwargs = dict(desc="Epoch {}".format(epoch + 1), leave=False,
                      disable=not self.is_progress_bar)
        with trange(len(data_loader), **kwargs) as t:
            for data, _ in data_loader:
                iter_loss = iteration_fn(data, storer)
                epoch_loss += iter_loss
                t.set_postfix(loss=iter_loss)
                t.update()
        return epoch_loss / len(data_loader)

    def _train_epoch(self, data_loader, storer, epoch):
        """Run one optimisation pass over ``data_loader``; return the mean batch loss."""
        return self._run_epoch(data_loader, storer, epoch, self._train_iteration)

    def _train_iteration(self, data, storer):
        """Take one optimizer step on ``data`` and return the loss as a float."""
        data = data.to(self.device)
        recon_batch, latent_dist, latent_sample = self.model(data)
        loss = self.loss_f(data, recon_batch, latent_dist, self.model.training,
                           storer, latent_sample=latent_sample)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def _test_epoch(self, data_loader, storer, epoch):
        """Evaluate on ``data_loader`` without updating the model; return the mean batch loss.

        The model is switched to eval mode and gradients are disabled for the
        duration of the pass; the previous train/eval mode is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                return self._run_epoch(data_loader, storer, epoch, self._test_iteration)
        finally:
            self.model.train(was_training)

    def _test_iteration(self, data, storer):
        """Compute the loss on ``data`` without touching the optimizer; return it as a float."""
        data = data.to(self.device)
        recon_batch, latent_dist, latent_sample = self.model(data)
        loss = self.loss_f(data, recon_batch, latent_dist, self.model.training,
                           storer, latent_sample=latent_sample)

        return loss.item()

In [4]:
# Loss identifiers; "betaB" is the one passed to get_loss_fn later in this notebook.
# NOTE(review): presumably the full set of names get_loss_fn accepts — confirm against losses.py.
LOSSES = ["betaH", "betaB"]
# Reconstruction-distribution names; "bernoulli" is the "rec_dist" value used later in this notebook.
RECON_DIST = ["bernoulli", "laplace", "gaussian"]

In [5]:
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms

batch_size = 64

# Both MNIST splits share one preprocessing pipeline: Resize(32), then ToTensor.
mnist_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])

mnist_dataset = datasets.MNIST('../../data', train=True, download=True,
                               transform=mnist_transform)
mnist_dataset_test = datasets.MNIST('../../data', train=False, download=True,
                                    transform=mnist_transform)

# One shuffled loader per split, both using the notebook-wide batch size.
train_loader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_dataset_test, batch_size=batch_size, shuffle=True)

In [6]:
RES_DIR = "./results"
formatter = logging.Formatter('%(asctime)s %(levelname)s - %(funcName)s: %(message)s',
                              "%H:%M:%S")
logger = logging.getLogger(__name__)
logger.setLevel(1)  # 1 is below logging.DEBUG (10): let every record through

# logging.getLogger returns the same object on every run of this cell, so
# without this cleanup each re-run would attach one more StreamHandler and
# every message would be printed once per handler.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)

stream = logging.StreamHandler()
stream.setLevel(1)
stream.setFormatter(formatter)
logger.addHandler(stream)

# Directory of this experiment; passed to Trainer as save_dir.
exp_dir = os.path.join(RES_DIR, "first")
# logger.info("Root directory for saving and loading experiments: {}".format(exp_dir))

In [7]:
from vae import VAE
from losses import get_loss_fn
from torch import optim

# Model and optimisation hyper-parameters.
latent_dim = 12
img_size = [1, 32, 32]
lr = 5e-4

# Extra keyword arguments forwarded to get_loss_fn for the "betaB" loss.
betaB_args = dict(
    rec_dist="bernoulli",
    reg_anneal=10000,
    betaH_B=4,
    betaB_initC=0,
    betaB_finC=25,
    betaB_G=100,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

loss_f = get_loss_fn("betaB",
                     n_data=len(train_loader.dataset),
                     device=device,
                     **betaB_args)

generator_model = VAE(img_size, latent_dim).to(device)
optimizer = optim.Adam(generator_model.parameters(), lr=lr)

In [8]:
# Build the trainer for the VAE; progress bars are off and messages go to `logger`.
trainer = Trainer(
    model=generator_model,
    optimizer=optimizer,
    loss_f=loss_f,
    device=device,
    logger=logger,
    save_dir=exp_dir,
    is_progress_bar=False,
)

15:16:40 INFO - __init__: Training Device: cuda


In [9]:
# epochs = 100
# checkpoint_every = 10
# trainer(train_loader, epochs=epochs, checkpoint_every=checkpoint_every)

In [9]:
class Config:
    """Plain settings holder: every constructor argument is stored as an attribute of the same name."""

    def __init__(self, generator, device, image_size=32, mode='train', model_path='./model/Siamese',
                 generate_path='./Generated', num_epochs=100, distance_weight=1.0, 
                 dataset='MNIST', tensorboard=True, batch_size=64, batch_size_test=1000):
        # Objects handed in by the caller.
        self.generator = generator
        self.device = device

        # Run mode and data description.
        self.mode = mode
        self.dataset = dataset
        self.image_size = image_size

        # Output locations.
        self.model_path = model_path
        self.generate_path = generate_path

        # Training schedule and weighting.
        self.num_epochs = num_epochs
        self.distance_weight = distance_weight
        self.batch_size = batch_size
        self.batch_size_test = batch_size_test

        # Reporting switch.
        self.tensorboard = tensorboard

# Settings for the siamese GAN run: 400 epochs, tensorboard on, other values at their defaults.
config = Config(
    generator=generator_model,
    device=device,
    num_epochs=400,
    tensorboard=True,
)

In [10]:
from discriminator import SiameseDiscriminator
from siameseDataset import SiameseMNIST

In [11]:
# Wrap both MNIST splits in the project's SiameseMNIST dataset.
mnist_siamese_dataset = SiameseMNIST(mnist_dataset)
mnist_siamese_dataset_test = SiameseMNIST(mnist_dataset_test)

# Shuffled loaders; batch sizes come from the shared config object.
siamese_data_loader = DataLoader(
    mnist_siamese_dataset,
    batch_size=config.batch_size,
    shuffle=True,
)
siamese_data_loader_test = DataLoader(
    mnist_siamese_dataset_test,
    batch_size=config.batch_size_test,
    shuffle=True,
)

In [12]:
from generator import SiameseGanSolver

# Time the full siamese GAN training run.
start_time = datetime.now()

solver = SiameseGanSolver(config, siamese_data_loader)
solver.train()

end_time = datetime.now()

# `send_mail` is neither defined nor imported anywhere in this notebook, so an
# unconditional call raises NameError right after a long training run. Use it
# when it exists in the kernel, otherwise just report the elapsed time.
_notify = globals().get("send_mail")
if callable(_notify):
    print(_notify(start_time, end_time))
else:
    print("Training finished in {}".format(end_time - start_time))

  0%|          | 0/400 [00:00<?, ?it/s]

We are training

0 2020-01-22 15:16:44.019290





NameError: name 'torch' is not defined