In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

# Hyperparameters
latent_size = 10
batch_size = 128
learning_rate = 0.001
num_epochs = 10

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Data loading
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

In [8]:
class VAE(nn.Module):
    # ... (Implement encoder, decoder, reparameterize, and forward methods)
    def __init__(self):
        super(VAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Flatten(),  # Flatten input to 784
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),  # Adjust output dimension as needed
        )
        self.fc_mu = nn.Linear(128, latent_size)
        self.fc_var = nn.Linear(128, latent_size)
        
         # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, 784),
            nn.ReLU(),
            nn.Unflatten(1, (28, 28)),  # Reshape to 28x28
            nn.Sequential(  # Pixel-wise layers
                nn.Linear(28, 28),
                nn.ReLU(),
                nn.Linear(28, 28),
                nn.ReLU(),
                nn.Linear(28, 28),
                nn.Sigmoid()  # Output pixel values between 0 and 1
            )
        )
    def encode(self, x):
        return self.fc_mu(self.encoder(x)), self.fc_var(self.encoder(x))

    def decode(self, z):
        return self.decoder(z)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    reconstruction_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_loss + kl_divergence

In [9]:
# Initialize model and optimizer
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for i, (x, _) in enumerate(train_loader):
        x = x.to(device)

        recon_batch, mu, logvar = model(x)
        loss = loss_function(recon_batch, x, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

RuntimeError: unflatten: Provided sizes [28, 28] don't multiply up to the size of dim 1 (128) in the input tensor