<a href="https://colab.research.google.com/github/yingzibu/drug_design_JAK/blob/main/VAE/VAE_digits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [27]:
"""
Reference: https://github.com/wohlert/semi-supervised-pytorch/tree/master
"""

import torch
from torch.autograd import Variable
import math
import torch.nn as nn
import torch.nn.functional as F
from itertools import repeat
from torchvision import transforms, datasets
from tqdm import tqdm

In [28]:
def enumerate_discrete(x, y_dim):
    """
    Build one-hot labels for every class, repeated for each sample in x.

    The result stacks `y_dim` groups of `x.size(0)` rows: the first group
    is the one-hot vector for class 0 repeated batch-size times, the next
    group is class 1, and so on. This matches the layout of
    `x.repeat(y_dim, 1)`.

    :param x: batch tensor; only its first dimension (batch size) and its
              device are used
    :param y_dim: (int) number of discrete classes
    :return: float32 tensor of shape (y_dim * batch_size, y_dim) located
             on the same device as x
    """
    batch_size = x.size(0)
    # Rows of the identity matrix are exactly the one-hot vectors. Creating
    # it on x.device keeps the labels next to the data on any GPU, not only
    # the current default CUDA device.
    classes = torch.eye(y_dim, dtype=torch.float32, device=x.device)
    return classes.repeat_interleave(batch_size, dim=0)

def onehot(k):
    """
    Create an encoder that maps an integer label to a one-hot vector.

    :param k: (int) length of the produced vector
    :return: function `encode(label)` returning a tensor of
             torch.Size([k]); the vector is all zeros when label >= k
    """
    def encode(label):
        vec = torch.zeros(k)
        if label < k:
            # Only labels below k get a hot entry; others stay all-zero.
            vec[label] = 1
        return vec

    return encode

def log_sum_exp(tensor, dim=-1, sum_op=torch.sum):
    """
    Numerically stabilised LogSumExp (LSE) along one dimension.

    The per-slice maximum is subtracted before exponentiating and added
    back after the logarithm.

    :param tensor: tensor to reduce in the log domain
    :param dim: dimension to perform the operation over
    :param sum_op: reduction applied to the exponentials:
                   torch.sum or torch.mean
    :return: LSE with the reduced dimension kept (size 1)
    """
    peak = torch.max(tensor, dim=dim, keepdim=True)[0]
    shifted = torch.exp(tensor - peak)
    reduced = sum_op(shifted, dim=dim, keepdim=True)
    # The small constant guards the log against an exactly-zero reduction.
    return torch.log(reduced + 1e-8) + peak

def log_gaussian(x, mu, logvar):
    """
    Log-density of a diagonal Gaussian, summed over the last dimension.

    :param x: point(s) to evaluate
    :param mu: mean of the distribution
    :param logvar: log-variance of the distribution
    :return: log N(x|mu, var)
    """
    normaliser = -0.5 * math.log(2 * math.pi)
    half_logvar = logvar / 2
    scaled_sq_err = (x - mu) ** 2 / (2 * torch.exp(logvar))
    per_dim = normaliser - half_logvar - scaled_sq_err
    return torch.sum(per_dim, dim=-1)

def log_standard_gaussian(x):
    """
    Evaluate the log-density of a standard normal distribution at x.

    :param x: point(s) to evaluate
    :return: log N(x|0, I), summed over the last dimension
    """
    # zeros_like keeps mu/logvar on the same device as x; a plain
    # torch.zeros(x.shape) is always created on the CPU and fails when x
    # lives on the GPU.
    mu = torch.zeros_like(x)
    logvar = torch.zeros_like(x)  # unit variance: log(1) == 0
    return log_gaussian(x, mu, logvar)

def log_standard_categorical(p):
    """
    Returns H(p, u), the cross entropy between p and a standard (uniform)
    categorical distribution u over the classes in dimension 1.

    :param p: tensor of class probabilities / one-hot labels with the
              classes along dim 1
    :return: cross entropy with dim 1 summed out
    """
    # Softmax of equal scores is the uniform distribution. `dim` is given
    # explicitly so it agrees with the sum below: the implicit choice is
    # deprecated, warns, and picks dim 0 for 3-D input.
    prior = F.softmax(torch.ones_like(p), dim=1)
    # ones_like never tracks gradients, so the prior is a constant.
    cross_entropy = -torch.sum(p * torch.log(prior + 1e-8), dim=1)
    return cross_entropy

In [29]:
class IWS:
    """
    Importance weighted sampler (Burda 2015) to be used in conjunction
    with SVI.
    """

    def __init__(self, mc=1, iw=1):
        """
        :param mc: number of Monte Carlo samples
        :param iw: number of Importance Weighted samples
        """
        self.mc = mc
        self.iw = iw

    def resample(self, x):
        """Tile x along dim 0 so there are mc * iw copies of the batch."""
        copies = self.mc * self.iw
        return x.repeat(copies, 1)

    def __call__(self, elbo):
        """
        Combine the per-sample ELBO values produced for a resampled batch.

        :param elbo: ELBO values with mc * iw * batch entries
        :return: flat tensor, log-mean-exp over the iw axis followed by a
                 mean over the mc axis
        """
        grouped = elbo.view(self.mc, self.iw, -1)
        per_mc = log_sum_exp(grouped, dim=1, sum_op=torch.mean)
        averaged = torch.mean(per_mc, dim=0)
        return averaged.view(-1)

class SVI(nn.Module):
    """
    Stochastic variational inference (SVI).

    Wraps a semi-supervised model and evaluates its variational bound for
    labelled batches (y given) or unlabelled batches (y is None).
    """
    base_sampler = IWS(mc=1, iw=1)

    def __init__(self, model, likelihood=F.binary_cross_entropy,
                 beta=repeat(1), sampler=base_sampler):
        """
        Initializes a new SVI optimizer for semi-supervised learning

        :param model: semi-supervised model to evaluate
        :param likelihood: p(x|y, z) for example BCE or MSE
        :param beta: iterator yielding the warm-up/scaling of the KL-term,
                     advanced once per forward call
        :param sampler: sampler for x and y, e.g. for Monte Carlo
        """
        super().__init__()
        self.model = model
        self.likelihood = likelihood
        self.beta = beta
        self.sampler = sampler

    def forward(self, x, y=None):
        """
        Evaluate the bound for one batch.

        :param x: input batch
        :param y: labels for x, or None for an unlabelled batch
        :return: scalar; mean ELBO when y is given, otherwise the mean of
                 the classifier entropy plus the classifier-weighted ELBO
        """
        has_labels = y is not None
        inputs, targets = x, y

        if not has_labels:
            # No labels: pair every sample with every possible class so
            # the bound can be weighted by the classifier output below.
            targets = enumerate_discrete(inputs, self.model.y_dim)
            inputs = inputs.repeat(self.model.y_dim, 1)

        # Replicate for the Monte Carlo / importance-weighted samples.
        inputs = self.sampler.resample(inputs)
        targets = self.sampler.resample(targets)

        # x, y -> z, y -> x
        recon = self.model(inputs, targets)

        # p(x|y, z)
        log_lik = -self.likelihood(recon, inputs)

        # p(y)
        log_prior = -log_standard_categorical(targets)

        # -L(x, y) = E_q_theta(z|x, y) [log p_theta(x|y, z) + log p(y)
        #                               + log p(z) - log q_phi(z|x, y)]
        #          = likelihood + prior - KL_divergence
        bound = log_lik + log_prior - next(self.beta) * self.model.kl_divergence
        bound = self.sampler(bound)

        if has_labels:
            return torch.mean(bound)

        # Unlabelled case: `probs` is the classifier output for the
        # original (non-replicated) batch.
        probs = self.model.classify(x)
        # Reshape the flat per-(class, sample) bound to line up with probs.
        bound = bound.view_as(probs.t()).t()

        entropy = -torch.sum(torch.mul(probs, torch.log(probs + 1e-8)), dim=-1)
        weighted_bound = torch.sum(torch.mul(probs, bound), dim=-1)

        return torch.mean(entropy + weighted_bound)

In [None]:
!git clone https://github.com/wohlert/semi-supervised-pytorch.git

In [30]:
cd /content/semi-supervised-pytorch/semi-supervised/

/content/semi-supervised-pytorch/semi-supervised


In [31]:
from models import StackedDeepGenerativeModel, DeepGenerativeModel
from models import VariationalAutoencoder

In [32]:
# Dimensions: input, latent code and number of classes.
x_dim, z_dim, y_dim = 784, 32, 10
# Hidden layer widths handed to the model together with x_dim and z_dim.
h_dim = [256, 128]

model_VAE = VariationalAutoencoder([x_dim, z_dim, h_dim])
model_VAE

VariationalAutoencoder(
  (encoder): Encoder(
    (hidden): ModuleList(
      (0): Linear(in_features=784, out_features=256, bias=True)
      (1): Linear(in_features=256, out_features=128, bias=True)
    )
    (sample): GaussianSample(
      (mu): Linear(in_features=128, out_features=32, bias=True)
      (log_var): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (decoder): Decoder(
    (hidden): ModuleList(
      (0): Linear(in_features=32, out_features=128, bias=True)
      (1): Linear(in_features=128, out_features=256, bias=True)
    )
    (reconstruction): Linear(in_features=256, out_features=784, bias=True)
    (output_activation): Sigmoid()
  )
)

In [33]:
cuda = torch.cuda.is_available()
import matplotlib.pyplot as plt
%matplotlib inline
import sys

In [34]:
import torch
import numpy as np
import sys
from urllib import request
from torch.utils.data import Dataset
sys.path.append("../semi-supervised")
n_labels = 10
cuda = torch.cuda.is_available()


class SpriteDataset(Dataset):
    """
    A PyTorch wrapper for the dSprites dataset by
    Matthey et al. 2017. The dataset provides a 2D scene
    with a sprite under different transformations:
    * color
    * shape
    * scale
    * orientation
    * x-position
    * y-position

    The archive is read from "./dsprites.npz" and downloaded there first
    when the file does not exist. Only the "imgs" array is kept.
    """
    def __init__(self, transform=None):
        """
        :param transform: optional callable applied to each sample
                          returned by __getitem__
        """
        self.transform = transform
        url = "https://github.com/deepmind/dsprites-dataset/raw/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"
        path = "./dsprites.npz"

        try:
            self.dset = self._load_images(path)
        except FileNotFoundError:
            # First use: fetch the archive, then load it from disk.
            request.urlretrieve(url, path)
            self.dset = self._load_images(path)

    @staticmethod
    def _load_images(path):
        """Read the "imgs" array from the archive and close the file."""
        # np.load returns a lazy NpzFile for .npz archives; the context
        # manager closes its file handle once the array has been read.
        with np.load(path, encoding="bytes") as archive:
            return archive["imgs"]

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, idx):
        sample = self.dset[idx]

        if self.transform:
            sample = self.transform(sample)

        return sample


def get_mnist(location="./", batch_size=64, labels_per_class=100):
    """
    Build labelled, unlabelled and validation DataLoaders for MNIST.

    Images are converted to tensors, flattened and stochastically
    binarised with a Bernoulli draw on every access; targets are passed
    through `onehot(n_labels)`. Relies on the module-level `n_labels` and
    `cuda` variables.

    :param location: directory where MNIST is stored / downloaded
    :param batch_size: batch size used by all three loaders
    :param labels_per_class: maximum number of training samples per class
                             exposed by the labelled loader
    :return: (labelled, unlabelled, validation) DataLoaders; the first
             two draw from the training split, the last from the test
             split
    """
    from torch.utils.data.sampler import SubsetRandomSampler
    from torchvision.datasets import MNIST
    import torchvision.transforms as transforms
    from utils import onehot

    def flatten_bernoulli(x):
        return transforms.ToTensor()(x).view(-1).bernoulli()

    mnist_train = MNIST(location, train=True, download=True,
                        transform=flatten_bernoulli, target_transform=onehot(n_labels))
    mnist_valid = MNIST(location, train=False, download=True,
                        transform=flatten_bernoulli, target_transform=onehot(n_labels))

    def get_sampler(labels, n=None):
        # Only choose digits in n_labels
        (indices,) = np.where(np.isin(labels, np.arange(n_labels)))

        # Ensure uniform distribution of labels: shuffle once, then keep
        # the first n shuffled indices of every class (all when n is None).
        np.random.shuffle(indices)
        shuffled_labels = labels[indices]
        indices = np.hstack([indices[shuffled_labels == i][:n] for i in range(n_labels)])

        return SubsetRandomSampler(torch.from_numpy(indices))

    # `targets` replaces the deprecated train_labels / test_labels attributes.
    train_labels = mnist_train.targets.numpy()
    valid_labels = mnist_valid.targets.numpy()

    # Dataloaders for MNIST
    labelled = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, num_workers=2, pin_memory=cuda,
                                           sampler=get_sampler(train_labels, labels_per_class))
    unlabelled = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, num_workers=2, pin_memory=cuda,
                                             sampler=get_sampler(train_labels))
    validation = torch.utils.data.DataLoader(mnist_valid, batch_size=batch_size, num_workers=2, pin_memory=cuda,
                                             sampler=get_sampler(valid_labels))

    return labelled, unlabelled, validation


In [None]:
l, u, v = get_mnist(location='./', batch_size=64, labels_per_class=10)

In [35]:
_, train, val = get_mnist(location = './', batch_size=64)

In [37]:
def binary_cross_entropy(r, x):
    """
    Binary cross entropy between a reconstruction and its target,
    summed over the last dimension.

    :param r: reconstruction with values in (0, 1)
    :param x: target values
    :return: BCE per sample (last dimension reduced)
    """
    eps = 1e-8  # keeps both logarithms finite at r == 0 and r == 1
    on_term = x * torch.log(r + eps)
    off_term = (1 - x) * torch.log(1 - r + eps)
    return -torch.sum(on_term + off_term, dim=-1)

# Move the model to the GPU before building the optimizer so the optimizer
# is constructed over the parameters that are actually trained.
if cuda: model_VAE = model_VAE.cuda()

optimizer = torch.optim.Adam(model_VAE.parameters(),
                             lr=3e-4, betas=(0.9, 0.999))
# alpha = 0.1 * len(u)/len(l)
from itertools import cycle
# from inference import SVI, IWS
sampler = IWS(mc=1, iw=1)


In [44]:
from google.colab import files

In [None]:
# Train the VAE for 50 epochs by minimising the negative ELBO. Every fifth
# epoch the validation loss is reported and a checkpoint is saved and
# offered for download.
for epoch in range(50):
    model_VAE.train()
    total_loss = 0
    for (u, _) in tqdm(train, desc=f"epoch {epoch}"):
        if cuda: u = u.cuda(device=0)

        reconstruction = model_VAE(u)

        # ELBO = log-likelihood of the reconstruction minus the KL term
        # that the model stores during its forward pass.
        likelihood = -binary_cross_entropy(reconstruction, u)
        elbo = likelihood - model_VAE.kl_divergence

        L = -torch.mean(elbo)

        L.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += L.item()

    # Average the accumulated loss over the number of batches.
    m = len(train)
    print(f"Epoch {epoch}\tL: {total_loss/m:.2f}")

    if epoch % 5 == 0:
        model_VAE.eval()
        val_loss = 0
        # Evaluation needs no gradients; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for (u, _) in val:
                if cuda: u = u.cuda(device=0)
                recon = model_VAE(u)
                l = -binary_cross_entropy(recon, u)
                elbo = l - model_VAE.kl_divergence
                L = -torch.mean(elbo)
                val_loss += L.item()
        print(f"Valid \tL: {val_loss/len(val):.2f}")

        checkpoint_path = 'VAE_digits_pretrain_' + str(epoch) + ".pt"
        torch.save(model_VAE.state_dict(), checkpoint_path)
        files.download(checkpoint_path)
# model.load_state_dict(torch.load('VAE_digits_pretrain_' + str(epoch) + ".pt",
#                                     map_location=device))

epoch 0: 100%|██████████| 938/938 [00:39<00:00, 23.50it/s]


Epoch: 0	L: 127.69





Valid 	L: 127.21



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

epoch 1: 100%|██████████| 938/938 [00:39<00:00, 23.68it/s]



Epoch: 1	L: 127.57


epoch 2: 100%|██████████| 938/938 [00:40<00:00, 23.25it/s]



Epoch: 2	L: 127.52


epoch 3: 100%|██████████| 938/938 [00:41<00:00, 22.48it/s]



Epoch: 3	L: 127.48


epoch 4: 100%|██████████| 938/938 [00:39<00:00, 23.47it/s]



Epoch: 4	L: 127.23


epoch 5: 100%|██████████| 938/938 [00:39<00:00, 23.62it/s]


Epoch: 5	L: 127.20





Valid 	L: 126.59



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

epoch 6: 100%|██████████| 938/938 [00:39<00:00, 23.51it/s]



Epoch: 6	L: 127.28


epoch 7:  81%|████████▏ | 763/938 [00:32<00:07, 24.73it/s]

In [45]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# kwargs = {'num_workers':1, 'pin_memory': True}
# batch_size = 126
# train_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=True, download=True,
#                    transform=transforms.ToTensor()),
#     batch_size=batch_size, shuffle=True, **kwargs)
# test_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=False, download=True,
#                    transform=transforms.ToTensor()),
#     batch_size=batch_size, shuffle=True, **kwargs)

# def loss_function(recon_x, x, mu, logvar):
#     BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
#     KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#     return BCE + KLD

In [40]:
val

<torch.utils.data.dataloader.DataLoader at 0x789fac0c02e0>

In [38]:
# device = 'cuda'
# for epoch in range(10):
#     model_VAE.train()
#     total_loss = 0
#     for i, (data, _) in tqdm(enumerate(train_loader), total=len(train_loader)):
#         # optimizer.zero_grad()
#         if cuda: data = data.to(device)
#         recon_batch = model_VAE(data.view(-1, 784))

#         BCE_loss = F.binary_cross_entropy(recon_batch, data.view(-1, 784),
#                                       reduction='sum')
#         loss = BCE_loss + torch.mean(model_VAE.kl_divergence)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()

#         total_loss += loss.data.item()
#     if epoch % 1 == 0:
#         model_VAE.eval()
#         print('Epoch :{}'.format(epoch))
#         print('[Train] \t loss:{:.2f}'.format(total_loss/len(train_loader)))
#         val_loss = 0
#         for (data, _) in test_loader:
#             if cuda: data = data.to(device)

#             recon_batch = model_VAE(data.view(-1, 784))
#             BCE_loss = F.binary_cross_entropy(recon_batch, data.view(-1, 784),
#                                       reduction='sum')
#             loss = BCE_loss + torch.mean(model_VAE.kl_divergence)
#             val_loss += loss.data.item()
#         print('[valid] \t loss:{:.2f}'.format(loss / len(test_loader)))




In [None]:
BCE_loss

tensor(69300.1406, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward0>)

In [None]:
log_gaussian(onehot(3)(2), torch.Tensor([1, 2, 3]), torch.Tensor([3, 2, 3]))

tensor(-7.1520)

In [None]:
def log_standard_gaussian_or(x):
    """
    Closed-form log pdf of a standard normal distribution at x.

    :param x: point to evaluate
    :return: log N(x|0,I), summed over the last dimension
    """
    log_norm = -0.5 * math.log(2 * math.pi)
    per_dim = log_norm - x ** 2 / 2
    return torch.sum(per_dim, dim=-1)

In [None]:
log_standard_gaussian(onehot(3)(2)) == log_standard_gaussian_or(onehot(3)(2))

tensor(True)

In [None]:
F.softmax(torch.ones_like(onehot(3)(2)).T, dim=1)

  F.softmax(torch.ones_like(onehot(3)(2)).T, dim=1)


IndexError: ignored