<a href="https://colab.research.google.com/github/vijaygwu/classideas/blob/main/VAE1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
original_dim = 784      # input/output width of the VAE (flattened image vector)
intermediate_dim = 256  # width of the hidden layer in encoder and decoder
latent_dim = 2          # width of the latent code
batch_size = 100
epochs = 50
learning_rate = 0.01

# MNIST pipeline: convert each image to a tensor, then flatten it to 1-D.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(torch.flatten),
])

# Only the training split requests a download; the test split reuses ./data.
train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST('./data', train=False, transform=transform)

# Training batches are reshuffled every epoch; test batches keep dataset order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)

# VAE Model
class VAE(nn.Module):
    """Variational autoencoder with a single hidden layer on each side.

    The encoder maps an input vector to the mean and log-variance of a
    diagonal Gaussian over the latent code; the decoder maps a latent sample
    back to a vector of per-element values in (0, 1) via a sigmoid.

    Args:
        input_dim: Width of the input (and reconstructed output) vectors.
            Defaults to the module-level ``original_dim``.
        hidden_dim: Width of the hidden layer used by both encoder and
            decoder. Defaults to the module-level ``intermediate_dim``.
        z_dim: Width of the latent code. Defaults to the module-level
            ``latent_dim``.
    """

    def __init__(self, input_dim=None, hidden_dim=None, z_dim=None):
        super().__init__()

        # Resolve defaults at call time so a bare ``VAE()`` still follows the
        # module-level hyperparameters, while explicit sizes let the model be
        # built (and tested) without depending on those globals.
        if input_dim is None:
            input_dim = original_dim
        if hidden_dim is None:
            hidden_dim = intermediate_dim
        if z_dim is None:
            z_dim = latent_dim

        # Encoder: shared hidden layer followed by two parallel heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, z_dim)
        self.fc_logvar = nn.Linear(hidden_dim, z_dim)

        # Decoder: mirrors the encoder back up to the input width.
        self.fc3 = nn.Linear(z_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent Gaussian for input ``x``."""
        h = torch.relu(self.fc1(x))
        return self.fc_mean(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, 1).

        Writing the sample as a deterministic function of ``mu`` and
        ``logvar`` plus external noise keeps it differentiable with respect
        to both encoder outputs.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(std**2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent code ``z`` to a reconstruction with values in (0, 1)."""
        h = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Build the network and move its parameters to the selected device
# (nn.Module.to returns the module itself, so `model` is the same object).
model = VAE()
model = model.to(device)

# Adam over every model parameter, using the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Loss function
def loss_function(recon_x, x, mu, logvar):
    """Return the summed VAE loss: reconstruction error plus KL penalty.

    Args:
        recon_x: Reconstructed values in [0, 1], same shape as ``x``.
        x: Target values in [0, 1].
        mu: Mean of the latent Gaussian.
        logvar: Log-variance of the latent Gaussian.

    Returns:
        A scalar tensor. Both terms are summed (not averaged) over every
        element they are given.
    """
    # Binary cross-entropy between reconstruction and target, summed.
    reconstruction_term = nn.functional.binary_cross_entropy(
        recon_x, x, reduction='sum'
    )

    # Closed-form KL divergence of N(mu, exp(logvar)) from N(0, 1), summed.
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(per_element)

    return reconstruction_term + kl_term

# Training loop
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print the mean loss.

    Args:
        epoch: Epoch number; used only in the printed progress line.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. The printed value is the accumulated summed loss divided by
    the number of samples in the training dataset.
    """
    model.train()
    train_loss = 0

    # Labels are not needed for the reconstruction objective, so they are dropped.
    for inputs, _ in train_loader:
        inputs = inputs.to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(inputs)
        batch_loss = loss_function(reconstruction, inputs, mu, logvar)
        batch_loss.backward()
        optimizer.step()

        train_loss += batch_loss.item()

    print(f"Epoch: {epoch} \t Loss: {train_loss / len(train_loader.dataset)}")

# Run the training: call train() once per epoch, with epoch numbers running
# from 1 through `epochs` inclusive so the printed progress is 1-based.
for epoch in range(1, epochs + 1):
    train(epoch)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 96929573.67it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 23567255.61it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 104236622.00it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3891834.27it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






Epoch: 1 	 Loss: 192.29559204101562
Epoch: 2 	 Loss: 168.89347835286458
Epoch: 3 	 Loss: 164.20819223632813
Epoch: 4 	 Loss: 161.56139915364582
Epoch: 5 	 Loss: 159.78053754882814
Epoch: 6 	 Loss: 158.4600842122396
Epoch: 7 	 Loss: 157.44611544596353
Epoch: 8 	 Loss: 156.64754497070314
Epoch: 9 	 Loss: 155.96837330729167
Epoch: 10 	 Loss: 155.36627189127603
Epoch: 11 	 Loss: 154.8243448079427
Epoch: 12 	 Loss: 154.328567578125
Epoch: 13 	 Loss: 153.87484010416668
Epoch: 14 	 Loss: 153.4072259440104
Epoch: 15 	 Loss: 153.1374827311198
Epoch: 16 	 Loss: 152.78847809244792
Epoch: 17 	 Loss: 152.48030992838542
Epoch: 18 	 Loss: 152.15377875976563
Epoch: 19 	 Loss: 151.8882744466146
Epoch: 20 	 Loss: 151.62802752278645
Epoch: 21 	 Loss: 151.3557314127604
Epoch: 22 	 Loss: 151.21542177734375
Epoch: 23 	 Loss: 150.95098434244792
Epoch: 24 	 Loss: 150.7164887044271
Epoch: 25 	 Loss: 150.52951546223957
Epoch: 26 	 Loss: 150.37076280924478
Epoch: 27 	 Loss: 150.17106214192708
Epoch: 28 	 Loss: 1