<a href="https://colab.research.google.com/github/yingzibu/drug_design_JAK/blob/main/VAE/semi_supervised_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Adapted from the PyTorch implementation at https://github.com/wohlert/semi-supervised-pytorch/tree/master

Reference paper: Kingma, Rezende, Mohamed & Welling (2014), *Semi-Supervised Learning with Deep Generative Models*

https://arxiv.org/abs/1406.5298


In [2]:
import torch
cuda = torch.cuda.is_available()
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import sys
# sys.path.append("../../semi-supervised")

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import init

# from layers import GaussianSample, GaussianMerge, GumbelSoftmax
# from inference import log_gaussian, log_standard_gaussian

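Notation: `logvar` in the code is the log-variance, $\log \mathrm{var} = \log \sigma^2$, so the standard deviation is recovered as $\sigma = \exp\left(\tfrac{1}{2}\log \sigma^2\right)$.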

The model code below follows https://github.com/wohlert/semi-supervised-pytorch/blob/master/semi-supervised/models/dgm.py


In [37]:
class Stochastic(nn.Module):
    """
    Base stochastic layer implementing the reparametrization trick.

    Subclasses produce a mean ``mu`` and a log-variance ``logvar``
    (``logvar = log sigma^2``) and call :meth:`reparametrize` to draw a
    differentiable sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.
    """
    def reparametrize(self, mu, logvar):
        """
        :param mu: mean of the Gaussian
        :param logvar: log-variance log(sigma^2), same shape as ``mu``
        :return: a sample with the same shape as ``mu``
        """
        std = torch.exp(0.5 * logvar)  # exp(log(sigma^2) / 2) = sigma
        eps = torch.randn_like(std)
        return mu + eps * std


class GaussianSample(Stochastic):
    """
    Layer that maps features to a diagonal Gaussian and samples from it.

    :param in_dim: size of the last dimension of the input
    :param out_dim: size of the latent sample
    """
    def __init__(self, in_dim, out_dim):
        super(GaussianSample, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mu_layer = nn.Linear(in_dim, out_dim)
        self.logvar_layer = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """
        :param x: input features whose last dimension is ``in_dim``
        :return: tuple ``(z, mu, logvar)``, where z is a reparametrized sample
        """
        mu = self.mu_layer(x)
        # Leave the log-variance unconstrained: any real number is a valid
        # log(sigma^2). Do NOT apply softplus here. That would force
        # logvar >= 0, i.e. sigma^2 >= 1, so the approximate posterior could
        # never be narrower than the N(0, I) prior.
        logvar = self.logvar_layer(x)
        return self.reparametrize(mu, logvar), mu, logvar

In [38]:
class Encoder(nn.Module):
    """ReLU MLP followed by a stochastic sampling layer: the inference network q_phi(z|x)."""
    def __init__(self, dims, sample_layer=GaussianSample):
        """
        Inference network.

        Approximates the posterior p(z|x) with a variational distribution
        q_phi(z|x) whose parameters (mu, log sigma^2) are computed from x.

        :param dims: ``[input_dim, [hidden_dims], latent_dim]``
        :param sample_layer: stochastic layer class, instantiated as
            ``sample_layer(hidden_dims[-1], latent_dim)``
        """
        super(Encoder, self).__init__()
        x_dim, hidden_dims, z_dim = dims
        widths = [x_dim] + list(hidden_dims)
        self.hidden = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )
        self.sample = sample_layer(hidden_dims[-1], z_dim)

    def forward(self, x):
        """
        :param x: input batch whose last dimension is ``input_dim``
        :return: whatever ``sample_layer`` returns; for GaussianSample this is
            ``(z, mu, logvar)``
        """
        h = x
        for linear in self.hidden:
            h = F.relu(linear(h))
        return self.sample(h)

class Decoder(nn.Module):
    """ReLU MLP with a sigmoid output layer: the generative network p_theta(x|z)."""
    def __init__(self, dims):
        """
        Generative network.

        Transforms a latent representation z through ReLU hidden layers into a
        reconstruction of x, squashed into (0, 1) by a sigmoid.

        :param dims: dimensions of the network, given as
            ``[latent_dim, [hidden_dims], input_dim]``
        """
        super(Decoder, self).__init__()
        [z_dim, h_dim, x_dim] = dims
        neurons = [z_dim, *h_dim]
        linear_layers = [nn.Linear(
            neurons[i-1], neurons[i]) for i in range(1, len(neurons))]
        self.hidden = nn.ModuleList(linear_layers)
        self.reconstruction = nn.Linear(h_dim[-1], x_dim)
        self.output_activation = nn.Sigmoid()

    def forward(self, z):
        """
        :param z: latent batch whose last dimension is ``latent_dim``
        :return: reconstruction with last dimension ``input_dim``, values in (0, 1)
        """
        # Run ALL hidden layers first. The output projection and sigmoid must
        # sit outside the loop. Inside it, the decoder would return after the
        # first hidden layer and crash on a shape mismatch whenever there is
        # more than one hidden layer.
        for layer in self.hidden:
            z = F.relu(layer(z))
        z = self.reconstruction(z)
        return self.output_activation(z)

In [45]:
class VAE(nn.Module):
    """
    Variational autoencoder (the M1 model): Encoder q_phi(z|x) followed by
    Decoder p_theta(x|z).
    """
    def __init__(self, dims):
        """
        :param dims: ``[x_dim, z_dim, h_dim]`` where ``h_dim`` is a list of
            hidden-layer widths; the decoder uses them in reverse order
        """
        super(VAE, self).__init__()
        [x_dim, z_dim, h_dim] = dims
        self.z_dim = z_dim
        self.flow = None
        self.encoder = Encoder([x_dim, h_dim, z_dim])
        self.decoder = Decoder([z_dim, list(reversed(h_dim)), x_dim])

        # Xavier-normal weights and zero biases for every linear layer.
        # Use the in-place ``xavier_normal_``; the underscore-less
        # ``xavier_normal`` is deprecated and emits a UserWarning.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)

    def add_flow(self, flow):
        """Store a flow module on the model (not used by ``forward`` in this class)."""
        self.flow = flow

    def forward(self, x, y=None):
        """
        Encode x, draw a latent sample and decode it.

        :param x: input data
        :param y: ignored here (DeepGenerativeModel overrides forward with a
            required y)
        :return: reconstructed input
        """
        z, _, _ = self.encoder(x)
        return self.decoder(z)

    def sample(self, z):
        """Decode latent codes ``z`` into data space."""
        return self.decoder(z)

In [49]:
class Classifier(nn.Module):
    """
    One-hidden-layer MLP that outputs class probabilities.

    :param dims: ``[x_dim, h_dim, y_dim]`` where ``h_dim`` is a single int
    """
    def __init__(self, dims):
        super(Classifier, self).__init__()
        in_features, hidden_features, n_classes = dims
        self.dense = nn.Linear(in_features, hidden_features)
        self.logits = nn.Linear(hidden_features, n_classes)

    def forward(self, x):
        """
        :param x: input batch with last dimension ``x_dim``
        :return: softmax probabilities over the last dimension (not raw logits)
        """
        hidden = F.relu(self.dense(x))
        return F.softmax(self.logits(hidden), dim=-1)

class DeepGenerativeModel(VAE):
    """
    M2 model from Kingma et al. (2014), "Semi-Supervised Learning with Deep
    Generative Models".

    A VAE whose encoder and decoder are both conditioned on the label y, plus a
    classifier over y given x.
    """
    def __init__(self, dims):
        """
        :param dims: ``[x_dim, y_dim, z_dim, h_dim]`` where ``h_dim`` is a list
            of hidden-layer widths and labels are one-hot vectors of length
            ``y_dim``
        """
        x_dim, y_dim, z_dim, h_dim = dims
        super(DeepGenerativeModel, self).__init__([x_dim, z_dim, h_dim])
        # Set plain attributes only after nn.Module has been initialised.
        self.y_dim = y_dim
        # Replace the unconditional encoder/decoder built by VAE.__init__ with
        # label-conditioned ones: the encoder sees [x, y], the decoder [z, y].
        self.encoder = Encoder([x_dim + y_dim, h_dim, z_dim])
        self.decoder = Decoder([z_dim + y_dim, list(reversed(h_dim)), x_dim])

        # The classifier has a single hidden layer, so only h_dim[0] is used.
        self.classifier = Classifier([x_dim, h_dim[0], y_dim])

        # Re-initialise every linear layer (including the replacements above).
        # ``xavier_normal_`` is the non-deprecated in-place initializer.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x, y):
        """
        :param x: 2-D input batch
        :param y: one-hot labels, concatenated with x along dim 1
        :return: reconstruction of x
        """
        # Match sample(): labels may arrive as an integer tensor, so cast them
        # to x's dtype before concatenating.
        y = y.to(x.dtype)
        z, _, _ = self.encoder(torch.cat([x, y], dim=1))
        return self.decoder(torch.cat([z, y], dim=1))

    def set_classifier(self, classifier):
        """Replace the classifier module."""
        self.classifier = classifier

    def classify(self, x):
        """Return class probabilities for x (Classifier applies a softmax)."""
        return self.classifier(x)

    def sample(self, z, y):
        """
        Samples from the decoder to generate x
        :param z: latent normal variable
        :param y: label (one-hot encoded)
        :return: x
        """
        y = y.float()
        x = self.decoder(torch.cat([z, y], dim=1))
        return x

In [50]:
class StackedDGM(DeepGenerativeModel):
    """
    M1 + M2 from 'Semi-Supervised Learning with Deep Generative Models'
    (Kingma 2014) in PyTorch.

    A frozen, pre-trained VAE (M1, ``features``) maps x to a latent code, and an
    M2 model operates on that code.
    """
    def __init__(self, dims, features):
        """
        :param dims: ``[x_dim, y_dim, z_dim, h_dim]``; ``x_dim`` is the size of
            the raw input that the decoder reconstructs
        :param features: M1 model (a VAE); must expose ``z_dim`` and ``encoder``
        """
        [x_dim, y_dim, z_dim, h_dim] = dims
        # The M2 part consumes M1's latent code, so its input size is features.z_dim...
        super(StackedDGM, self).__init__([features.z_dim, y_dim, z_dim, h_dim])
        # ...but its decoder reconstructs the raw input (x_dim outputs).
        in_dim = self.decoder.reconstruction.in_features
        self.decoder.reconstruction = nn.Linear(in_dim, x_dim)

        self.features = features
        # Put M1 in eval mode. NOTE: calling .train() on this model later will
        # switch the child ``features`` back to training mode as well.
        self.features.train(False)

        # Freeze M1: its parameters receive no gradients.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x, y):
        """
        :param x: raw input batch
        :param y: one-hot labels
        :return: decoder output with ``x_dim`` features
        """
        z_sample, _, _ = self.features.encoder(x)
        return super(StackedDGM, self).forward(z_sample, y)

    def classify(self, x):
        """
        Classify raw inputs using the mean of M1's latent posterior.

        The classifier was built with ``features.z_dim`` inputs, so it must be
        fed M1's latent mean rather than the raw ``x``.

        :param x: raw input batch
        :return: class probabilities from the classifier
        """
        _, z_mu, _ = self.features.encoder(x)
        return self.classifier(z_mu)

In [51]:
# M2 hyper-parameters for flattened 28x28 MNIST images.
x_dim = 28 * 28      # 784 input pixels
y_dim = 10           # digit classes 0-9
z_dim = 20           # latent size
h_dim = [256, 128]   # encoder hidden widths (decoder uses them reversed)
model = DeepGenerativeModel([x_dim, y_dim, z_dim, h_dim])
model

x_dim: 784, z_dim: 20, h_dim: [256, 128]


  init.xavier_normal(m.weight.data)
  init.xavier_normal(m.weight.data)


DeepGenerativeModel(
  (encoder): Encoder(
    (hidden): ModuleList(
      (0): Linear(in_features=794, out_features=256, bias=True)
      (1): Linear(in_features=256, out_features=128, bias=True)
    )
    (sample): GaussianSample(
      (mu_layer): Linear(in_features=128, out_features=20, bias=True)
      (logvar_layer): Linear(in_features=128, out_features=20, bias=True)
    )
  )
  (decoder): Decoder(
    (hidden): ModuleList(
      (0): Linear(in_features=30, out_features=128, bias=True)
      (1): Linear(in_features=128, out_features=256, bias=True)
    )
    (reconstruction): Linear(in_features=256, out_features=784, bias=True)
    (output_activation): Sigmoid()
  )
  (classifier): Classifier(
    (dense): Linear(in_features=784, out_features=256, bias=True)
    (logits): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [68]:
import torch
import numpy as np
import sys
from urllib import request
from torch.utils.data import Dataset
def onehot(k):
    """
    Build a function that converts an integer label to its one-hot
    (1-of-k) vector.

    :param k: (int) number of classes, i.e. length of the vector
    :return: ``encode(label)``, returning a float tensor of shape ``(k,)``
        with a 1 at index ``label``; labels outside ``[0, k)`` map to the
        all-zeros vector
    """
    def encode(label):
        y = torch.zeros(k)
        # Check both bounds: a negative label would otherwise index from the
        # end of the vector and silently mark the wrong class.
        if 0 <= label < k:
            y[label] = 1
        return y
    return encode

cuda = torch.cuda.is_available()


def flatten_bernoulli(x):
    """
    Dataset transform: image -> flat tensor of binary {0, 1} pixels.

    ``ToTensor`` converts the image to a float tensor; each value is then used
    as the probability of a Bernoulli draw (stochastic binarization).

    :param x: image accepted by ``torchvision.transforms.ToTensor`` (e.g. a PIL image)
    :return: 1-D float tensor of 0s and 1s
    """
    # Import locally: ``transforms`` is not defined at module scope (get_mnist
    # only imports it inside its own body), so a global lambda referencing it
    # raised NameError as soon as a sample was loaded.
    from torchvision import transforms
    return transforms.ToTensor()(x).view(-1).bernoulli()

def get_mnist(location='./', batch_size=64, labels_per_class=100, seed=None):
    """
    Build MNIST data loaders for semi-supervised training.

    Images are transformed by ``flatten_bernoulli`` and labels are one-hot
    encoded with ``onehot(10)``.

    :param location: directory where MNIST is stored / downloaded to
    :param batch_size: batch size for all three loaders
    :param labels_per_class: number of training examples per class kept for
        the ``labelled`` loader
    :param seed: optional seed for choosing the labelled subset; ``None``
        (default) uses NumPy's global RNG
    :return: ``(labelled, unlabelled, validation)`` DataLoaders. ``labelled``
        draws from ``labels_per_class`` examples of each digit, ``unlabelled``
        from the whole training split, and ``validation`` from the whole test split.
    """
    from torch.utils.data.sampler import SubsetRandomSampler
    from torchvision.datasets import MNIST
    n_labels = 10
    rng = np.random if seed is None else np.random.default_rng(seed)

    mnist_train = MNIST(location, train=True, download=True,
                        transform=flatten_bernoulli,
                        target_transform=onehot(n_labels))
    mnist_valid = MNIST(location, train=False, download=True,
                        transform=flatten_bernoulli,
                        target_transform=onehot(n_labels))

    def get_sampler(labels, n=None):
        """Random sampler over at most ``n`` shuffled indices per class (all if ``n`` is None)."""
        indices = np.flatnonzero(np.isin(labels, np.arange(n_labels)))
        rng.shuffle(indices)
        # Vectorised per-class selection; keeps the shuffled order within each class.
        indices = np.hstack([indices[labels[indices] == i][:n]
                             for i in range(n_labels)])
        return SubsetRandomSampler(torch.from_numpy(indices))

    def make_loader(dataset, sampler):
        """DataLoader with the settings shared by all three splits."""
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=2, pin_memory=cuda,
            sampler=sampler)

    # ``.targets`` replaces the deprecated ``train_labels`` / ``test_labels``.
    train_targets = mnist_train.targets.numpy()
    labelled = make_loader(mnist_train, get_sampler(train_targets, labels_per_class))
    unlabelled = make_loader(mnist_train, get_sampler(train_targets))
    validation = make_loader(mnist_valid, get_sampler(mnist_valid.targets.numpy()))
    return labelled, unlabelled, validation


In [None]:
labelled, unlabelled, validation = get_mnist(
    location="./",
    batch_size=64,
    labels_per_class=10,  # 10 labelled images per digit -> 100 labelled examples
)

In [66]:
tuple(len(loader) for loader in (labelled, unlabelled, validation))  # batches per loader

(2, 938, 157)

In [70]:
from torchvision.datasets import MNIST

mnist_train = MNIST(
    './',
    train=True,
    download=True,
    transform=flatten_bernoulli,
    target_transform=onehot(10),
)

In [71]:
# Sanity check: number of examples in the full MNIST training split.
len(mnist_train)

60000

In [73]:
# Scratch check: the label values 0..9 that get_sampler iterates over (np.arange(n_labels)).
np.arange(10)

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])