In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from torchvision import transforms
from torch.utils.data import DataLoader

In [2]:
# Load dataset
# NOTE(security): allow_pickle=True is required because 'metadata' is stored as
# a pickled object array (it is unwrapped into a dict below). Unpickling can
# execute arbitrary code, so only load this archive from a trusted source.
dataset_zip = np.load('dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', allow_pickle=True, encoding='bytes')

# `.files` is the plain list of array names stored in the archive; printing
# `.keys()` only shows an opaque KeysView repr of the NpzFile object.
print('Keys in the dataset:', dataset_zip.files)
imgs = dataset_zip['imgs']
latents_values = dataset_zip['latents_values']
latents_classes = dataset_zip['latents_classes']
# 'metadata' is a 0-d object array; indexing with () extracts the dict inside.
metadata = dataset_zip['metadata'][()]

print('Metadata: \n', metadata)

Keys in the dataset: KeysView(<numpy.lib.npyio.NpzFile object at 0x7fb0fdf69c60>)
Metadata: 
 {b'date': b'April 2017', b'description': b'Disentanglement test Sprites dataset.Procedurally generated 2D shapes, from 6 disentangled latent factors.This dataset uses 6 latents, controlling the color, shape, scale, rotation and position of a sprite. All possible variations of the latents are present. Ordering along dimension 1 is fixed and can be mapped back to the exact latent values that generated that image.We made sure that the pixel outputs are different. No noise added.', b'version': 1, b'latents_names': (b'color', b'shape', b'scale', b'orientation', b'posX', b'posY'), b'latents_possible_values': {b'orientation': array([0.        , 0.16110732, 0.32221463, 0.48332195, 0.64442926,
       0.80553658, 0.96664389, 1.12775121, 1.28885852, 1.44996584,
       1.61107316, 1.77218047, 1.93328779, 2.0943951 , 2.25550242,
       2.41660973, 2.57771705, 2.73882436, 2.89993168, 3.061039  ,
       3.22

In [3]:
class DspritesDataset(torch.utils.data.Dataset):
    """Image-only dataset over the ``'imgs'`` array of a loaded dSprites archive.

    Each item is a single image returned as a float32 tensor with a leading
    channel axis, i.e. shape ``(1, H, W)``. No label is returned.

    Args:
        dataset_zip: mapping-like object (e.g. the result of ``np.load``) that
            exposes the image stack under the key ``'imgs'``.
        transform: optional callable applied to the ``(1, H, W)`` float tensor
            before it is returned. Defaults to no extra transform.
    """

    def __init__(self, dataset_zip, transform=None):
        self.imgs = dataset_zip['imgs']
        self.transform = transform

    def __len__(self):
        """Return the number of images in the stack."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return image ``idx`` as a float32 tensor of shape ``(1, H, W)``."""
        # Pixel values are kept as stored. dSprites images are uint8 with
        # values 0/1 (per the dataset documentation), so they are already valid
        # Bernoulli targets. torchvision's ToTensor would divide uint8 input by
        # 255 and squash every image to a maximum of ~0.0039.
        img = torch.as_tensor(self.imgs[idx], dtype=torch.float32).unsqueeze(0)
        if self.transform is not None:
            img = self.transform(img)
        return img


# Wrap the loaded archive in a Dataset, then serve it in shuffled
# mini-batches of 64 images.
dataset = DspritesDataset(dataset_zip)
data_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
)


In [None]:
class BetaVAE(nn.Module):
    """Convolutional beta-VAE: a VAE whose KL term is weighted by ``beta``.

    The encoder is four stride-2 convolutions followed by a 256-unit hidden
    layer feeding two linear heads (mean and log-variance of the approximate
    posterior). The decoder mirrors it with four stride-2 transposed
    convolutions and a final sigmoid, so reconstructions are per-pixel
    probabilities in [0, 1], as ``F.binary_cross_entropy`` requires.

    Args:
        beta: weight applied to the KL-divergence term of the loss.
        z_dim: dimensionality of the latent code.
        n_rows: input image height in pixels.
        n_cols: input image width in pixels.
        n_channels: number of image channels.
        n_filters: filters in the first two conv layers (the last two use
            twice as many).
        kernel_size: kernel size shared by every conv / transposed-conv layer
            (an int, or a ``(height, width)`` pair).

    Raises:
        ValueError: if the input is too small to survive four stride-2
            convolutions with the given kernel size.

    Note:
        The decoder reproduces the input resolution exactly only when every
        transposed convolution doubles the spatial size, i.e. ``kernel_size=4``
        with ``n_rows`` and ``n_cols`` divisible by 16 (for example 64x64).
    """

    def __init__(self, beta, z_dim, n_rows, n_cols, n_channels, n_filters, kernel_size):
        super(BetaVAE, self).__init__()
        self.z_dim = z_dim
        self.n_rows = n_rows
        self.n_cols = n_cols
        self.n_channels = n_channels
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.beta = beta

        # Spatial size of the feature map after the four encoder convolutions.
        # For 64x64 inputs and kernel_size=4 this is 4x4.
        if isinstance(kernel_size, int):
            k_rows, k_cols = kernel_size, kernel_size
        else:
            k_rows, k_cols = kernel_size
        self.feat_rows = self._conv_out_size(n_rows, k_rows)
        self.feat_cols = self._conv_out_size(n_cols, k_cols)
        if self.feat_rows < 1 or self.feat_cols < 1:
            raise ValueError(
                "input of size {}x{} is too small for four stride-2 convolutions "
                "with kernel_size={}".format(n_rows, n_cols, kernel_size)
            )
        self.flat_dim = self.feat_rows * self.feat_cols * n_filters * 2

        # Encoder: channel widths must chain (n_channels -> f -> f -> 2f -> 2f).
        self.conv1 = nn.Conv2d(n_channels, n_filters, kernel_size, stride=2, padding=1)
        self.conv2 = nn.Conv2d(n_filters, n_filters, kernel_size, stride=2, padding=1)
        # conv3 consumes conv2's n_filters channels and widens to 2 * n_filters.
        self.conv3 = nn.Conv2d(n_filters, n_filters*2, kernel_size, stride=2, padding=1)
        self.conv4 = nn.Conv2d(n_filters*2, n_filters*2, kernel_size, stride=2, padding=1)

        self.fc1 = nn.Linear(self.flat_dim, 256)
        self.fc21 = nn.Linear(256, z_dim)  # posterior mean head
        self.fc22 = nn.Linear(256, z_dim)  # posterior log-variance head

        # Decoder: mirror image of the encoder.
        self.fc_reverse1 = nn.Linear(z_dim, 256)
        self.fc_reverse2 = nn.Linear(256, self.flat_dim)
        self.upconv1 = nn.ConvTranspose2d(n_filters*2, n_filters*2, kernel_size, stride=2, padding=1)
        self.upconv2 = nn.ConvTranspose2d(n_filters*2, n_filters, kernel_size, stride=2, padding=1)
        self.upconv3 = nn.ConvTranspose2d(n_filters, n_filters, kernel_size, stride=2, padding=1)
        self.upconv4 = nn.ConvTranspose2d(n_filters, n_channels, kernel_size, stride=2, padding=1)

    @staticmethod
    def _conv_out_size(size, kernel_size, n_layers=4, stride=2, padding=1):
        """Return the spatial size left after ``n_layers`` identical strided convolutions."""
        for _ in range(n_layers):
            size = (size + 2 * padding - kernel_size) // stride + 1
        return size

    def encoder(self, x):
        """Map a batch of images to the posterior parameters ``(mu, logvar)``.

        Args:
            x: tensor of shape ``(batch, n_channels, n_rows, n_cols)``.

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(batch, z_dim)``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        # Flatten everything except the batch axis. A hard-coded width would
        # silently fold batch entries together whenever the sizes disagree.
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.fc1(x))
        mu = self.fc21(x)
        logvar = self.fc22(x)
        return mu, logvar

    def decoder(self, z):
        """Map latent codes to reconstructed images with values in [0, 1].

        Args:
            z: tensor of shape ``(batch, z_dim)``.

        Returns:
            Tensor of per-pixel probabilities; its shape equals the input image
            shape when the conditions in the class note hold.
        """
        z = F.relu(self.fc_reverse1(z))
        z = F.relu(self.fc_reverse2(z))
        z = z.view(-1, self.n_filters*2, self.feat_rows, self.feat_cols)
        z = F.relu(self.upconv1(z))
        z = F.relu(self.upconv2(z))
        z = F.relu(self.upconv3(z))
        # Sigmoid turns the logits into probabilities. binary_cross_entropy
        # rejects inputs outside [0, 1].
        return torch.sigmoid(self.upconv4(z))

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` so gradients reach ``mu`` and ``logvar``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

    def loss_function(self, y, x, mu, logvar):
        """Return the beta-VAE loss, summed over the batch.

        Args:
            y: reconstruction, probabilities in [0, 1], same shape as ``x``.
            x: target images with values in [0, 1].
            mu: posterior means, shape ``(batch, z_dim)``.
            logvar: posterior log-variances, shape ``(batch, z_dim)``.

        Returns:
            Scalar tensor ``BCE + beta * KLD`` (sum reduction, not a mean).
        """
        BCE = F.binary_cross_entropy(y, x, reduction='sum')
        # Closed-form KL divergence between N(mu, sigma^2) and N(0, I).
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + self.beta * KLD

In [None]:
# training phase
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
beta = 4
# The dataset yields (1, 64, 64) tensors, so the model needs n_channels=1.
# `beta` is passed through instead of repeating the literal.
model = BetaVAE(
    beta=beta,
    z_dim=32,
    n_rows=64,
    n_cols=64,
    n_channels=1,
    n_filters=32,
    kernel_size=4,
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

epochs = 10

for epoch in range(epochs):
    model.train()
    train_loss = 0
    # DspritesDataset returns images only (no labels), so every batch is a
    # single tensor. Unpacking it as (data, _) would raise a ValueError.
    for data in data_loader:
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    # The loss is summed over each batch, so dividing by the dataset size gives
    # the average per-image loss for the epoch.
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(data_loader.dataset)))

