In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

# Hyperparameters
latent_size = 10
batch_size = 32
learning_rate = 0.001
num_epochs = 10

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [26]:
# Data loading
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

In [32]:
class VAE(nn.Module):
    # ... (Implement encoder, decoder, reparameterize, and forward methods)
    def __init__(self):
        super(VAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Flatten(),  # Flatten input to 784
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),  # Adjust output dimension as needed
        )
        self.fc_mu = nn.Linear(128, latent_size)
        self.fc_var = nn.Linear(128, latent_size)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Sigmoid()  # Output pixel values between 0 and 1
        )

    def encode(self, x):
        return self.fc_mu(self.encoder(x)), self.fc_var(self.encoder(x))

    def decode(self, z):
        return self.decoder(z)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        recon_x = recon_x.view(-1, 1, 28, 28)  # Reshape to match input size
        return recon_x, mu, logvar

def loss_function(recon_x, x, mu, logvar):
    reconstruction_loss = F.mse_loss(recon_x, x, reduction='sum')
    # kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    kl_divergence = torch.tensor(0.0)
    return reconstruction_loss + kl_divergence

In [33]:
# Initialize model and optimizer
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
mse_loss = nn.MSELoss()

for epoch in range(num_epochs):
    for i, (x, _) in enumerate(train_loader):
        x = x.to(device)

        recon_batch, mu, logvar = model(x)
        loss = loss_function(recon_batch, x, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

Epoch [1/10], Step [100/469], Loss: 6181.0005
Epoch [1/10], Step [200/469], Loss: 5731.0181
Epoch [1/10], Step [300/469], Loss: 4543.7559
Epoch [1/10], Step [400/469], Loss: 3709.3159
Epoch [2/10], Step [100/469], Loss: 3023.7329
Epoch [2/10], Step [200/469], Loss: 2892.3960
Epoch [2/10], Step [300/469], Loss: 2617.1399
Epoch [2/10], Step [400/469], Loss: 2326.7622
Epoch [3/10], Step [100/469], Loss: 2186.1108
Epoch [3/10], Step [200/469], Loss: 1966.6953
Epoch [3/10], Step [300/469], Loss: 2130.7756
Epoch [3/10], Step [400/469], Loss: 1653.4248
Epoch [4/10], Step [100/469], Loss: 1798.9209
Epoch [4/10], Step [200/469], Loss: 1866.7827
Epoch [4/10], Step [300/469], Loss: 1908.2474
Epoch [4/10], Step [400/469], Loss: 1792.4895
Epoch [5/10], Step [100/469], Loss: 1545.4459
Epoch [5/10], Step [200/469], Loss: 1589.6298
Epoch [5/10], Step [300/469], Loss: 1456.7727
Epoch [5/10], Step [400/469], Loss: 1554.9406
Epoch [6/10], Step [100/469], Loss: 1545.1510
Epoch [6/10], Step [200/469], Loss