In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
def KL_(mu,sigma_sq):
    """
    prior: p(z) ~ N(0,I)
    posterior approx.: q(z|x) ~ N(mu,sigma^2)
    :param mu: tensor of size [J,], holding the mean value
    :param sigma: tensor of size [J,], holding the s.d.
    :return: KL(q||p)
    """
    return 0.5*torch.sum(1 + torch.log(1e-10 + sigma_sq) - mu**2 - sigma_sq) #1e-10 is added inside the log to set a min value.


class Model(nn.Module):
    def __init__(self, latent_dim,device):
        """Initialize a VAE.

        Args:
            latent_dim: dimension of embedding
            device: run on cpu or gpu
        """
        super(Model, self).__init__()
        self.device = device
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 1, 2),  # B,  32, 28, 28
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),  # B,  32, 14, 14
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),  # B,  64,  7, 7
        )

        self.mu = nn.Linear(64 * 8 * 8, latent_dim)
        self.logvar = nn.Linear(64 * 8 * 8, latent_dim)

        self.upsample = nn.Linear(latent_dim, 64 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # B,  64,  14,  14
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1, 1),  # B,  32, 28, 28
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 3, 4, 1, 2),  # B, 1, 28, 28
            nn.Sigmoid()
        )



    def sample(self,sample_size,mu=None,logvar=None):
        '''
        :param sample_size: Number of samples
        :param mu: z mean, None for prior (init with zeros)
        :param logvar: z logstd, None for prior (init with zeros)
        :return:
        '''
        if mu==None:
            mu = torch.zeros((sample_size,self.latent_dim)).to(self.device)
        if logvar == None:
            logvar = torch.zeros((sample_size,self.latent_dim)).to(self.device)
        #TODO
        if mu == None:  # Inference
            z = torch.randn((sample_size, self.latent_dim))
        else:
            z = self.z_sample(mu, logvar) #mu, e^logvar=var=sigma^2
        #z = self.z_sample(mu, logvar) #mu, e^logvar=var=sigma^2
        z = self.upsample(z)
        z = z.view([sample_size, 64, 8 , 8])
        return self.decoder(z)


    def z_sample(self, mu, logvar):
        return mu + torch.exp(0.5 * logvar) * torch.randn(mu.size()).to(self.device)


        self.prior_dist = torch.distributions.Normal(mu, torch.exp(0.5*logvar))
        return self.prior_dist.sample()


    def loss(self,x,recon,mu,logvar):
        """
        :param x:
        :param recon: reconstructed
        :param mu:
        :param logvar:
        :return: loss = recon_loss - KL = -ELBO
        """
        #
        # log_p_xi_zi = F.binary_cross_entropy(recon, x, reduction='none')
        # log_p_x_z = torch.sum(log_p_xi_zi, dim=1)
        # # KL
        # Ki_Li = -0.5 * (logvar - torch.exp(logvar) - torch.square(mu) + 1)  # -0.5*(2log(sig)-sig^2-mu^2+1)
        # KL = torch.sum(Ki_Li, dim=1)  # KL(p||q) = sum(KL(pi||qi))
        # return 0.5 * torch.sum(1 + torch.log(1e-10 + sigma_sq) - mu ** 2 - sigma_sq)  # 1e-10 is added inside the log to set a min value.
        # ELBO = log_p_x_z + KL

        #TODO
        #BCE_loss = nn.BCELoss(reduction="sum")
        #recon_loss = BCE_loss(recon, x)
        recon_loss = F.binary_cross_entropy(recon, x, reduction='sum')
        return recon_loss - KL_(mu, torch.exp(logvar)) #minus according to the definition in the paper

    def forward(self, x):
        #TODO
        z = self.encoder(x)
        z_reshape = z.view([z.shape[0], -1])
        mu = self.mu(z_reshape)
        logvar = self.logvar(z_reshape)
        sample = self.upsample(self.z_sample(mu, logvar))
        sample = sample.view([sample.shape[0],64,8,8])
        x_recon = self.decoder(sample)
        return x_recon, mu, logvar

#a = Model(100,'cpu')
#print(a)

In [None]:
"""Training procedure for NICE.
"""
from google.colab import drive
drive.mount('/content/drive')

import argparse
import torch, torchvision
from torchvision import transforms
import numpy as np
#from VAE import Model
import matplotlib.pyplot as plt

def train(vae, trainloader, optimizer, epoch,device):
    vae.train()  # set to training mode
    #TODO
    import torchvision
    train_loss = 0
    for inputs_rgb, _ in trainloader:
        #if(inputs_rgb.size()[1]==3):
        #    inputs = torchvision.transforms.functional.rgb_to_grayscale(inputs_rgb)
        #else:
        inputs = inputs_rgb
        inputs = inputs.to(device)
        optimizer.zero_grad()
        output, mu, logvar = vae(inputs)
        loss = vae.loss(inputs, output, mu, logvar).mean() #loss = -ELBO
        #print(loss)
        train_loss += loss.item()
        #ELBO = - loss
        loss.backward()
        optimizer.step()
    train_loss = train_loss / len(trainloader.dataset)
    print("==> epoch %d: train_loss = %03.2f " % (epoch, train_loss))
    return train_loss

def test(vae, testloader, filename, epoch, device):
    vae.eval()  # set to inference mode
    test_loss = 0
    with torch.no_grad():
        # TODO
        if ((epoch +1)%50 ==0):
            samples = vae.sample(64).to(device)
            torchvision.utils.save_image(samples,'/content/drive/My Drive/VAE_Project_DGM/samples/' + filename + 'epoch%d.png' % epoch)
        for inputs_rgb, _ in testloader:
            #if(inputs_rgb.size()[1]==3):
            #    inputs = torchvision.transforms.functional.rgb_to_grayscale(inputs_rgb)
            #else:
            inputs = inputs_rgb
            inputs = inputs.to(device)
            output, mu, logvar = vae(inputs)
            test_loss += vae.loss(inputs, output, mu, logvar).item()
    test_loss /= len(testloader.dataset)
    print("==> epoch %d: test_loss = %03.2f " % (epoch, test_loss))
    return test_loss

def dequantization(x): #I have an issue with transforms.Lambda pickle, so I change to this function
    return x + torch.zeros_like(x).uniform_(0., 1./256.)

def main(dataset = "cifar10", batch_size = 128, epochs = 50, sample_size = 64, latent_dim = 100, lr = 1e-3,continue_from = 0):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transform  = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(dequantization), #dequantization
        transforms.Normalize((0.,), (257./256.,)), #rescales to [0,1]

    ])

    if dataset == 'mnist':
        trainset = torchvision.datasets.MNIST(root='./data/MNIST',
            train=True, download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
            batch_size=batch_size, shuffle=True, num_workers=2)
        testset = torchvision.datasets.MNIST(root='./data/MNIST',
            train=False, download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
            batch_size=batch_size, shuffle=False, num_workers=2)
    elif dataset == 'fashion-mnist':
        trainset = torchvision.datasets.FashionMNIST(root='~/torch/data/FashionMNIST',
            train=True, download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
            batch_size=batch_size, shuffle=True, num_workers=2)
        testset = torchvision.datasets.FashionMNIST(root='./data/FashionMNIST',
            train=False, download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
            batch_size=batch_size, shuffle=False, num_workers=2)
    elif dataset == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root='~/torch/data/CIFAR10',
            train=True, download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset,
            batch_size=batch_size, shuffle=True, num_workers=2)
        testset = torchvision.datasets.CIFAR10(root='./data/CIFAR10',
            train=False, download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(testset,
            batch_size=batch_size, shuffle=False, num_workers=2)
    else:
        raise ValueError('Dataset not implemented')



    vae = Model(latent_dim=latent_dim,device=device).to(device)
    optimizer = torch.optim.Adam(vae.parameters(), lr=lr)
    #TODO
    filename = '%s_' % dataset \
            + 'batch%d_' % batch_size \
            + 'mid%d_' % latent_dim \
            + 'epoch%d_' % epochs
    train_loss_list = []
    test_loss_list = []
    
    if (continue_from > 0):
      checkpoint = torch.load("/content/drive/My Drive/VAE_Project_DGM/models/cifar10_batch128_mid100_epoch3250__VAE.pt",map_location=torch.device(device))
      #state = {'epoch': n_epoch + 1, 'state_dict': vae.state_dict(),'optimizer': optimizer.state_dict()}
      vae.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      #stopped_at_epoch = continue_from #checkpoint['epoch']    
    
    for n_epoch in range(continue_from,epochs):
        train_loss_list.append(train(vae, trainloader, optimizer, n_epoch, device))
        test_loss_list.append(test(vae, testloader, filename, n_epoch, device))

        if ((n_epoch +1)%50 ==0):
          filename2 = '%s_' % dataset \
            + 'batch%d_' % batch_size \
            + 'mid%d_' % latent_dim \
            + 'epoch%d_' % (n_epoch+1)
          state = {'epoch': n_epoch + 1, 'state_dict': vae.state_dict(),'optimizer': optimizer.state_dict()}
          torch.save(state, '/content/drive/My Drive/VAE_Project_DGM/models/' + filename2 + '_VAE.pt')
    print(train_loss_list)
    print(test_loss_list)



if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        help='dataset to be modeled.',
                        type=str,
                        default='cifar10')
    parser.add_argument('--batch_size',
                        help='number of images in a mini-batch.',
                        type=int,
                        default=128)
    parser.add_argument('--epochs',
                        help='maximum number of iterations.',
                        type=int,
                        default=1500)
    parser.add_argument('--sample_size',
                        help='number of images to generate.',
                        type=int,
                        default=64)

    parser.add_argument('--latent_dim',
                        help='.',
                        type=int,
                        default=100)
    parser.add_argument('--lr',
                        help='initial learning rate.',
                        type=float,
                        default=1e-3)

    #args = parser.parse_args()
    #main(dataset = "cifar10", batch_size = 128, epochs = 50, sample_size = 64, latent_dim = 100, lr = 1e-3)


Mounted at /content/drive


In [None]:
main(dataset = "cifar10", batch_size = 128, epochs = 4000, sample_size = 64, latent_dim = 100, lr = 1e-3, continue_from = 3250)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /root/torch/data/CIFAR10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /root/torch/data/CIFAR10/cifar-10-python.tar.gz to /root/torch/data/CIFAR10
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/CIFAR10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/CIFAR10/cifar-10-python.tar.gz to ./data/CIFAR10
==> epoch 3250: train_loss = 1814.97 
==> epoch 3250: test_loss = 1821.10 
==> epoch 3251: train_loss = 1815.04 
==> epoch 3251: test_loss = 1821.21 
==> epoch 3252: train_loss = 1815.03 
==> epoch 3252: test_loss = 1821.62 
==> epoch 3253: train_loss = 1815.08 
==> epoch 3253: test_loss = 1820.86 
==> epoch 3254: train_loss = 1814.95 
==> epoch 3254: test_loss = 1821.55 
==> epoch 3255: train_loss = 1815.01 
==> epoch 3255: test_loss = 1821.15 
==> epoch 3256: train_loss = 1815.06 
==> epoch 3256: test_loss = 1821.03 
==> epoch 3257: train_loss = 1815.00 
==> epoch 3257: test_loss = 1820.95 
==> epoch 3258: train_loss = 1815.00 
==> epoch 3258: test_loss = 1821.25 
==> epoch 3259: train_loss = 1815.05 
==> epoch 3259: test_loss = 1821.02 
==> epoch 3260: train_loss = 1814.98 
==> epoch 3260: test_loss = 1821.88 
==> epoch 3261: train_loss = 1815.01 
==> epoch 3261: test_loss = 1821.07 
==> epoch 3262: train_loss = 1815