In [5]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Load Data

In [6]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)


train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False) 

# VAE

In [7]:
class ConvolutionalVAE(nn.Module):
    def __init__(self, latent_dim):
        super(ConvolutionalVAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.flatten = nn.Flatten() # Flatten the output of the convolutional encoder
        self.fc_mu = nn.Linear(32 * 7 * 7, latent_dim)  # Adjust the input size based on your image size and encoder layers
        self.fc_logvar = nn.Linear(32 * 7 * 7, latent_dim)
        self.unflatten = nn.Unflatten(1, (32, 7, 7)) # Unflatten before the decoder
        self.linear_decoder = nn.Linear(latent_dim, 32 * 7 * 7) 
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )
        

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        hidden = self.encoder(x)
        hidden = self.flatten(hidden) # Flatten for the fully connected layers
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        z = self.reparameterize(mu, logvar)
        z = self.linear_decoder(z)
        z = self.unflatten(z) # Unflatten before the convolutional decoder
        decoded = self.decoder(z)
        return decoded, mu, logvar


def loss_function(reconstructed_x, x, mu, logvar):
    reconstruction_loss = nn.MSELoss()(reconstructed_x, x)
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_loss + kl_divergence



In [8]:
vae_model = ConvolutionalVAE(latent_dim = 64)
optimizer = torch.optim.Adam(vae_model.parameters(), lr=0.001)

epoch = 1
for i in range(epoch):
    ite = 0
    for x,y in train_loader:
        
        optimizer.zero_grad()
        
        outputs, mu, logvar  = vae_model(x)
        loss = loss_function(outputs, x, mu, logvar)
        
        loss.backward()
        optimizer.step()
        
        ite += 1
        
        if ite % 10 == 0:
            print("iteration: ", ite , "Loss: ", float(loss))

iteration:  10 Loss:  2.7487003803253174
iteration:  20 Loss:  1.687086582183838
iteration:  30 Loss:  1.3851256370544434
iteration:  40 Loss:  1.2362749576568604
iteration:  50 Loss:  1.1777422428131104
iteration:  60 Loss:  1.1163502931594849
iteration:  70 Loss:  1.0577329397201538
iteration:  80 Loss:  1.03568696975708
iteration:  90 Loss:  1.0474708080291748
iteration:  100 Loss:  1.001891016960144
iteration:  110 Loss:  0.8907874822616577
iteration:  120 Loss:  0.922481894493103
iteration:  130 Loss:  0.8945369720458984
iteration:  140 Loss:  0.8715651631355286
iteration:  150 Loss:  0.8653495907783508
iteration:  160 Loss:  0.8826323747634888
iteration:  170 Loss:  0.8221346139907837
iteration:  180 Loss:  0.8206678628921509
iteration:  190 Loss:  0.8077137470245361
iteration:  200 Loss:  0.8222501277923584
iteration:  210 Loss:  0.8674689531326294
iteration:  220 Loss:  0.8041913509368896
iteration:  230 Loss:  0.8080997467041016
iteration:  240 Loss:  0.7551628351211548
iterat

In [10]:
vae_model.eval

with torch.no_grad():
    mu_list = []
    logvar_list = []
    for x,y in test_loader:
        
        _, mu, logvar = vae_model(x)
        mu_list.append(mu)
        logvar_list.append(logvar)

logvar_list = torch.stack(logvar_list)     
mu_list = torch.stack(mu_list)
print(mu_list.shape)
mu_list

torch.Size([120, 100, 64])


tensor([[[-2.7620e-05, -3.5650e-06,  4.9145e-05,  ...,  5.9949e-07,
          -1.5809e-05,  1.8803e-05],
         [-2.7620e-05, -3.5650e-06,  4.9145e-05,  ...,  5.9949e-07,
          -1.5809e-05,  1.8803e-05],
         [-7.0616e-04, -2.5848e-04,  3.9620e-04,  ..., -8.5785e-04,
          -6.0410e-04,  8.6610e-04],
         ...,
         [-2.7620e-05, -3.5650e-06,  4.9145e-05,  ...,  5.9949e-07,
          -1.5809e-05,  1.8803e-05],
         [-2.7620e-05, -3.5650e-06,  4.9145e-05,  ...,  5.9949e-07,
          -1.5809e-05,  1.8803e-05],
         [-2.7620e-05, -3.5650e-06,  4.9145e-05,  ...,  5.9949e-07,
          -1.5809e-05,  1.8803e-05]],

        [[-2.7620e-05, -3.5650e-06,  4.9145e-05,  ...,  5.9949e-07,
          -1.5809e-05,  1.8803e-05],
         [-2.7620e-05, -3.5650e-06,  4.9145e-05,  ...,  5.9949e-07,
          -1.5809e-05,  1.8803e-05],
         [-2.1464e-05, -2.9758e-05,  1.0743e-04,  ..., -6.2523e-05,
          -5.3479e-05,  2.8218e-05],
         ...,
         [-2.7620e-05, -3