In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

### **Load Dateset**

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)


train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False) 

### **Model**

In [3]:
class ConvolutionalVAE(nn.Module):
    def __init__(self, latent_dim):
        super(ConvolutionalVAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.flatten = nn.Flatten() # Flatten the output of the convolutional encoder
        self.fc_mu = nn.Linear(32 * 7 * 7, latent_dim)  # Adjust the input size based on your image size and encoder layers
        self.fc_logvar = nn.Linear(32 * 7 * 7, latent_dim)
        self.unflatten = nn.Unflatten(1, (32, 7, 7)) # Unflatten before the decoder
        self.linear_decoder = nn.Linear(latent_dim, 32 * 7 * 7) 
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )
        

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        hidden = self.encoder(x)
        hidden = self.flatten(hidden) # Flatten for the fully connected layers
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        z = self.reparameterize(mu, logvar)
        z = self.linear_decoder(z)
        z = self.unflatten(z) # Unflatten before the convolutional decoder
        decoded = self.decoder(z)
        return decoded, mu, logvar


def loss_function(reconstructed_x, x, mu, logvar):
    reconstruction_loss = nn.MSELoss()(reconstructed_x, x)
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_loss + kl_divergence



### **Training Loop**

In [4]:
vae_model = ConvolutionalVAE(latent_dim = 64)
optimizer = torch.optim.Adam(vae_model.parameters(), lr=0.001)

epoch = 1
for i in range(epoch):
    ite = 0
    for x,y in train_loader:
        
        optimizer.zero_grad()
        
        outputs, mu, logvar  = vae_model(x)
        loss = loss_function(outputs, x, mu, logvar)
        
        loss.backward()
        optimizer.step()
        
        ite += 1
        
        if ite % 10 == 0:
            print("iteration: ", ite , "Loss: ", float(loss))

iteration:  10 Loss:  2.306786298751831
iteration:  20 Loss:  1.557590126991272
iteration:  30 Loss:  1.3530007600784302
iteration:  40 Loss:  1.2610061168670654
iteration:  50 Loss:  1.1733235120773315
iteration:  60 Loss:  1.123488426208496
iteration:  70 Loss:  1.0798426866531372
iteration:  80 Loss:  1.0956699848175049
iteration:  90 Loss:  1.0120761394500732
iteration:  100 Loss:  0.9585146903991699
iteration:  110 Loss:  0.9695087671279907
iteration:  120 Loss:  0.935856819152832
iteration:  130 Loss:  0.9504994750022888
iteration:  140 Loss:  0.8869339227676392
iteration:  150 Loss:  0.8392549753189087
iteration:  160 Loss:  0.8579514026641846
iteration:  170 Loss:  0.8511385321617126
iteration:  180 Loss:  0.8703644871711731
iteration:  190 Loss:  0.8119795322418213
iteration:  200 Loss:  0.8327888250350952
iteration:  210 Loss:  0.8342608213424683
iteration:  220 Loss:  0.8262317776679993
iteration:  230 Loss:  0.8024864792823792
iteration:  240 Loss:  0.8170163631439209
itera

### **Evaluation**

In [5]:
vae_model.eval

with torch.no_grad():
    mu_list = []
    logvar_list = []
    for x,y in test_loader:
        
        _, mu, logvar = vae_model(x)
        mu_list.append(mu)
        logvar_list.append(logvar)

logvar_list = torch.stack(logvar_list)     
mu_list = torch.stack(mu_list)
print(mu_list.shape)
mu_list

torch.Size([120, 100, 64])


tensor([[[-1.4248e-05,  2.8832e-05,  1.4438e-05,  ...,  1.3050e-06,
          -9.5947e-06,  2.5988e-05],
         [-1.4248e-05,  2.8832e-05,  1.4438e-05,  ...,  1.3050e-06,
          -9.5947e-06,  2.5988e-05],
         [-1.4248e-05,  2.8832e-05,  1.4438e-05,  ...,  1.3050e-06,
          -9.5947e-06,  2.5988e-05],
         ...,
         [ 3.5342e-04,  5.5507e-05,  4.0163e-04,  ..., -8.1336e-04,
          -5.3792e-04, -2.8998e-04],
         [-1.4248e-05,  2.8832e-05,  1.4438e-05,  ...,  1.3050e-06,
          -9.5947e-06,  2.5988e-05],
         [-1.4248e-05,  2.8832e-05,  1.4438e-05,  ...,  1.3050e-06,
          -9.5947e-06,  2.5988e-05]],

        [[-9.0247e-06,  1.6637e-05,  1.1014e-05,  ...,  1.4341e-05,
          -1.2198e-05,  1.6358e-05],
         [-1.4248e-05,  2.8832e-05,  1.4438e-05,  ...,  1.3050e-06,
          -9.5947e-06,  2.5988e-05],
         [-1.4248e-05,  2.8832e-05,  1.4438e-05,  ...,  1.3050e-06,
          -9.5947e-06,  2.5988e-05],
         ...,
         [-1.4248e-05,  2