<a href="https://colab.research.google.com/github/rahulsm27/ML/blob/main/VAE_Gumbel_Softmax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [29]:

import torch.utils.data

from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

import torch.nn.functional as F
import numpy as np
import pandas as pd
import math


In [19]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [20]:


train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

In [21]:
# Code to implement VAE-gumple_softmax in pytorch

# The code has been modified from pytorch example vae code and inspired by the origianl \
# tensorflow implementation of gumble-softmax by Eric Jang.




def sample_gumbel(shape, eps=1e-20):
    U = torch.rand(shape).to(device)

    return -torch.log(-torch.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature, hard=False):
    """
    ST-gumple-softmax
    input: [*, n_class]
    return: flatten --> [*, n_class] an one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)

    if not hard:
        return y.view(-1, latent_dim * categorical_dim)

    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Set gradients w.r.t. y_hard gradients w.r.t. y
    y_hard = (y_hard - y).detach() + y
    return y_hard.view(-1, latent_dim * categorical_dim)


class VAE_gumbel(nn.Module):
    def __init__(self, temp):
        super(VAE_gumbel, self).__init__()

        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, latent_dim * categorical_dim)

        self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.relu(self.fc3(h2))

    def decode(self, z):
        h4 = self.relu(self.fc4(z))
        h5 = self.relu(self.fc5(h4))
        return self.sigmoid(self.fc6(h5))

    def forward(self, x, temp, hard):
        q = self.encode(x.view(-1, 784))
        q_y = q.view(q.size(0), latent_dim, categorical_dim)
        z = gumbel_softmax(q_y, temp, hard)
        return self.decode(z), F.softmax(q_y, dim=-1).reshape(*q.size())

In [22]:
latent_dim = 30
categorical_dim = 10  # one-of-K vector

temp_min = 1
ANNEAL_RATE = 0.00003

model = VAE_gumbel(temp).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, qy):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False) / x.shape[0]

    log_ratio = torch.log(qy * categorical_dim + 1e-20)
    KLD = torch.sum(qy * log_ratio, dim=-1).mean()

    return BCE + KLD

In [23]:
def train(epoch, model, train_loader, optimizer, temp, cuda=True, hard=False):
  model.train()
  train_loss = 0
  for batch_idx, (data, _) in enumerate(train_loader):
    data.to(device)
    optimizer.zero_grad()

    recon_batch, q_y = model(data, temp, hard)

    loss = loss_function(recon_batch, data, q_y)
    loss.backward()

    train_loss += loss.item() * len(data)
    optimizer.step()

    if batch_idx % 100 == 1:
        temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)

    if batch_idx % 100 == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),
                  100. * batch_idx / len(train_loader),
                  loss.item()))

  print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))


In [35]:
def test(epoch, model, test_loader, temp, cuda=True, hard=False):
  model.eval()
  test_loss = 0

  for i, (data, _) in enumerate(test_loader):
    data.to(device)

    recon_batch, qy = model(data, temp, hard)
    test_loss += loss_function(recon_batch, data, qy).item() * len(data)

    if i % 100 == 1:
        temp = np.maximum(temp * np.exp(-ANNEAL_RATE * i), temp_min)

    if i == 0:
        n = min(data.size(0), 8)
        comparison = torch.cat([data[:n],recon_batch.view(128, 1, 28, 28)[:n]])
        save_image(comparison.data.to(device),f"./reconstruction_{epoch:03d}.png", nrow=n)

  test_loss /= len(test_loader.dataset)
  print('====> Test set loss: {:.4f}'.format(test_loss))

In [36]:
epochs = 50
prec = math.ceil(math.log10(epochs / 100))


latent_dim = 32
temp = 1.0
temp_min = 0.5


model = VAE_gumbel(latent_dim)
model.to(device)

VAE_gumbel(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=320, bias=True)
  (fc4): Linear(in_features=320, out_features=256, bias=True)
  (fc5): Linear(in_features=256, out_features=512, bias=True)
  (fc6): Linear(in_features=512, out_features=784, bias=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
)

In [38]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, epochs + 1):
    train(epoch, model, train_loader, optimizer, temp, True)
    test(epoch, model, test_loader, temp, True)

    M = 64 * latent_dim
    np_y = np.zeros((M, 2), dtype=np.float32)
    np_y[range(M), np.random.choice(2, M)] = 1
    np_y = np.reshape(np_y, [M // latent_dim, latent_dim, 2])

    sample = torch.from_numpy(np_y).view(M // latent_dim, latent_dim * 2)
    sample = sample.to(device)
    sample = model.decode(sample).to(device)

    save_image(sample.data.view(M // latent_dim, 1, 28, 28),f"./sample_{epoch:03d}.png")

====> Epoch: 1 Average loss: 164.9451
====> Test set loss: 149.1496


RuntimeError: ignored