In [139]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import idx2numpy
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from matplotlib import pyplot as plt

In [270]:
input_dim = 784
hidden_dim = 200
latent_dim = 20
epochs = 30
learning_rate = 0.00003
batch_size = 32

In [271]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.linear_mu = nn.Linear(hidden_dim, latent_dim)
        self.linear_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = self.linear(x)
        print(x.shape, np.isnan(h.detach().numpy()).any())
        h = F.relu(h)
        mu = self.linear_mu(h)
        logvar = self.linear_logvar(h)
        sigma = torch.exp(0.5 * logvar)
        return mu, sigma
    
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = self.linear1(z)
        h = F.relu(h)
        h = self.linear2(h)
        x_hat = F.sigmoid(h)
        return h

def reparameterize(mu, sigma):
    eps = torch.randn_like(sigma)
    z = mu + eps * sigma
    return z

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def get_loss(self, x):
        mu, sigma = self.encoder(x)
        z = reparameterize(mu, sigma)
        x_hat = self.decoder(z)
        if np.isnan(mu.detach().numpy()).any():
            print("|||", mu[:10], sigma[:10], mu.shape, sigma.shape)
            raise
        batch_size = len(x)
        L1 = F.mse_loss(x_hat, x, reduction='sum')
        L2 = - torch.sum(1 + torch.log(sigma**2) - mu ** 2 - sigma ** 2)
        return (L1 + L2) / batch_size


In [272]:
for _ in range(100):
    a = torch.FloatTensor(images).flatten(1,-1)
    encoder = Encoder(input_dim, hidden_dim, latent_dim)
    m, c = encoder(a)
    print(c.shape)

torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size([60000, 784]) False
torch.Size([60000, 20])
torch.Size

In [None]:
images = idx2numpy.convert_from_file('data/MNIST/raw/train-images.idx3-ubyte')  # shape: (num_images, rows, cols)
labels = idx2numpy.convert_from_file('data/MNIST/raw/train-labels.idx1-ubyte')  # shape: (num_images, rows, cols)

class mnistDataset(Dataset):
    def __init__(self, images, labels):
        self.x_data = torch.FloatTensor(images, requires_grad = True).flatten(1,-1)
        self.y_data = torch.FloatTensor(labels)

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        x = self.x_data[idx]
        y = self.y_data[idx]

        return x, y

dataset = mnistDataset(images, labels)

dataloader = DataLoader(
    dataset, 
    batch_size = batch_size, 
    shuffle=True
)

In [None]:
model = VAE(input_dim, hidden_dim, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
losses = []

for epoch in range(epochs):
    loss_sum = 0
    cnt = 0

    for x, label in dataloader:
        optimizer.zero_grad()
        loss = model.get_loss(x)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        cnt += 1

    loss_avg = loss_sum / cnt
    losses.append(loss_avg)
    print(loss_avg)

torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]) False
torch.Size([32, 784]

RuntimeError: No active exception to reraise

In [None]:
for param in model.encoder.parameters():
    print(param)

Parameter containing:
tensor([[ 0.0008, -0.0330,  0.0316,  ..., -0.0192, -0.0342,  0.0032],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        ...,
        [ 0.0283, -0.0100,  0.0090,  ..., -0.0068, -0.0327,  0.0049],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [-0.0294,  0.0147,  0.0282,  ..., -0.0030, -0.0120,  0.0354]],
       requires_grad=True)
Parameter containing:
tensor([ 0.0280,     nan,     nan,  0.0063,     nan,     nan,  0.0131,     nan,
         0.0006,  0.0141,  0.0049,     nan, -0.0267,     nan,     nan, -0.0237,
            nan,     nan,     nan,  0.0180, -0.0212,     nan, -0.0154, -0.0184,
            nan,     nan, -0.0254,     nan,  0.0115, -0.0346, -0.0295, -0.0223,
        -0.0227,  0.0221,     nan, -0.0281,     nan,     nan,     nan,  0.0283,
            nan,     nan,  0.0226,  0.0076, -0.0317, -0.0058,  0.0324,     nan,
        -0.0306