In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from PIL import Image

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

In [4]:
# Load the dSprites archive from its .npz file. The returned object is indexed by
# array name further down (the visible cells read only the 'imgs' entry).
# NOTE(review): the hard-coded '/content/...' path looks Colab-specific — confirm
# the location before running this notebook in another environment.
dataset_zip = np.load('/content/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', encoding='bytes')

In [7]:
class dSpritesDataset(Dataset):
    """Map-style dataset over an array of dSprites images.

    Args:
        imgs: Indexable collection of 2-D image arrays. Only ``len()`` and
            integer indexing are used on it.
        transform: Optional callable applied to each image after it has been
            converted to a PIL image. When falsy, the raw stored image is
            returned unchanged.
    """

    def __init__(self, imgs, transform=None):
        self.imgs = imgs
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return image ``idx``, passed through ``transform`` when one is set.

        dSprites stores its binary masks as integers in {0, 1}. PIL-based
        transforms such as ``ToTensor`` treat 8-bit images as spanning
        [0, 255] and divide by 255, which would map foreground pixels to
        1/255 instead of 1.0. Integer images whose maximum is at most 1 are
        therefore rescaled to {0, 255} before the PIL conversion. Images that
        already use a wider range (or a non-integer dtype) are passed through
        untouched.
        """
        img = self.imgs[idx]
        if self.transform:
            pixels = np.asarray(img)
            is_binary_int = np.issubdtype(pixels.dtype, np.integer) and pixels.max() <= 1
            if is_binary_int:
                # uint8 * 255 stays uint8, so PIL builds an 8-bit greyscale image.
                pixels = pixels.astype(np.uint8) * 255
            img = Image.fromarray(pixels)
            img = self.transform(img)
        return img

# Per-image preprocessing pipeline handed to the dataset; it currently contains a
# single step that converts the PIL image into a tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Pull the image array out of the loaded archive and wrap it in the dataset class.
imgs = dataset_zip['imgs']
dsprites_dataset = dSpritesDataset(imgs, transform=transform)

# Mini-batch loader used by the training loop; samples are reshuffled on every pass.
batch_size = 64
data_loader = DataLoader(dsprites_dataset, batch_size=batch_size, shuffle=True)

In [9]:
# Hyperparameters (consumed when the model and optimizer are built in the training cell)

# Size of the latent vector produced by the encoder; chosen for the dSprites dataset.
z_dim = 6 # for dsprites dataset
# Weight applied to the KL-divergence term in BetaVAE.loss_function.
beta = 4
# Step size passed to the Adam optimizer.
learning_rate = 1e-4

In [10]:
class Encoder(nn.Module):
    """Convolutional encoder that outputs the mean and log-variance of q(z|x).

    Four stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size, so a 64x64 input is reduced to 4x4 before the fully connected
    layers. The first linear layer is sized for exactly that 4x4 map, so
    inputs must be 64x64.

    Args:
        z_dim: Dimensionality of the latent code.
        n_channels: Number of channels in the input image.
        n_filters: Number of filters in the first two convolutions; the last
            two use twice as many.
    """

    def __init__(self, z_dim, n_channels=1, n_filters=32):
        super(Encoder, self).__init__()

        self.n_filters = n_filters

        self.conv1 = nn.Conv2d(n_channels, n_filters, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(n_filters, n_filters, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(n_filters, n_filters*2, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(n_filters*2, n_filters*2, kernel_size=4, stride=2, padding=1)
        self.fc = nn.Linear(4 * 4 * n_filters * 2, 128)
        self.fc_mu = nn.Linear(128, z_dim)
        self.fc_logvar = nn.Linear(128, z_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape (batch, n_channels, 64, 64).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (batch, z_dim).

        Raises:
            RuntimeError: If the spatial size of ``x`` does not reduce to the
                4x4 feature map the first linear layer was built for.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        # Flatten per sample. A `view(-1, features)` would silently merge
        # several samples into one row when the input has the wrong spatial
        # size; keeping the batch axis fixed makes that mismatch fail loudly
        # in the linear layer instead.
        x = x.flatten(start_dim=1)
        # Non-linearity between the shared hidden layer and the two heads;
        # without it `fc` followed by `fc_mu`/`fc_logvar` is a single linear map.
        x = F.relu(self.fc(x))
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent code back to an image.

    Two fully connected layers expand the latent vector into a 4x4 feature
    map, which four stride-2 transposed convolutions (kernel 4, padding 1)
    upsample to 64x64. No activation is applied after the last layer, so the
    output is unnormalised.

    Args:
        z_dim: Dimensionality of the latent code.
        n_channels: Number of channels in the generated image.
        n_filters: Base filter count; the first two transposed convolutions
            operate on twice as many channels.
    """

    def __init__(self, z_dim, n_channels=1, n_filters=32):
        super().__init__()

        self.n_filters = n_filters
        wide = n_filters * 2
        upsample = dict(kernel_size=4, stride=2, padding=1)

        # Layers are registered in the same order as before so that seeded
        # initialisation and state_dict keys are unaffected.
        self.fc1 = nn.Linear(z_dim, 128)
        self.fc2 = nn.Linear(128, 4 * 4 * wide)
        self.conv1 = nn.ConvTranspose2d(wide, wide, **upsample)
        self.conv2 = nn.ConvTranspose2d(wide, n_filters, **upsample)
        self.conv3 = nn.ConvTranspose2d(n_filters, n_filters, **upsample)
        self.conv4 = nn.ConvTranspose2d(n_filters, n_channels, **upsample)

    def forward(self, x):
        """Decode latent codes ``x`` of shape (batch, z_dim) into images.

        Returns:
            Tensor of shape (batch, n_channels, 64, 64) with no output
            activation applied.
        """
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        maps = hidden.view(-1, self.n_filters * 2, 4, 4)
        for deconv in (self.conv1, self.conv2, self.conv3):
            maps = F.relu(deconv(maps))
        return self.conv4(maps)


class BetaVAE(nn.Module):
    """Beta-VAE: an encoder/decoder pair trained with a beta-weighted KL term.

    Args:
        beta: Multiplier applied to the KL-divergence term of the loss.
        z_dim: Dimensionality of the latent code shared by encoder and decoder.
    """

    def __init__(self, beta, z_dim):
        super().__init__()
        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)
        self.beta = beta

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * sigma`` with ``eps`` sampled from N(0, I).

        ``sigma`` is recovered from the log-variance as ``exp(0.5 * log_var)``.
        """
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(x_recon, mu, log_var)``.
        """
        mu, log_var = self.encoder(x)
        latent = self.reparameterize(mu, log_var)
        return self.decoder(latent), mu, log_var

    def loss_function(self, x_recon, x, mu, log_var):
        """Return reconstruction loss plus ``beta`` times the KL divergence.

        The reconstruction term is binary cross-entropy computed on logits,
        and both terms are summed over every element (no averaging).
        """
        recon_term = F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum')
        kl_inner = 1 + log_var - mu.pow(2) - log_var.exp()
        kl_term = -0.5 * kl_inner.sum()
        return recon_term + self.beta * kl_term

In [None]:
# Training phase: build the model and optimizer, then run a fixed number of epochs.
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BetaVAE(beta, z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

epochs = 10

for epoch in range(epochs):
    model.train()
    # Running total of the batch losses for this epoch.
    train_loss = 0
    for data in data_loader:
        data = data.to(device)
        recon_batch, mu, logvar = model(data)
        loss = model.loss_function(recon_batch, data, mu, logvar)
        # Clear gradients left over from the previous step before back-propagating.
        optimizer.zero_grad()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    # loss_function sums over all elements of a batch, so dividing the epoch total by
    # the dataset size reports a per-sample average. `epoch` is printed 0-based.
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(data_loader.dataset)))

