In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np


In [None]:

# VAE model
class VAE(nn.Module):
    """Variational autoencoder whose decoder emits a probability vector.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent code. The decoder maps a sampled code
    to ``output_dim`` non-negative values that sum to 1 (softmax over dim 1).

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Size of the latent code.
        output_dim: Length of the probability vector produced per sample.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim):
        super().__init__()

        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling is not disabled in eval mode, so outputs stay stochastic
        after ``model.eval()``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to probability vectors (each row sums to 1)."""
        h = F.relu(self.fc2(z))
        return self.softmax(self.fc3(h))

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: Tensor whose elements can be regrouped into rows of
                ``input_dim`` features, e.g. ``(batch, 1, 28, 28)`` for
                ``input_dim == 784``.

        Returns:
            Tuple ``(pi, mu, logvar)``: decoder probabilities of shape
            ``(batch, output_dim)`` and the posterior parameters of shape
            ``(batch, latent_dim)``.
        """
        # Flatten to the width the first layer was built with instead of a
        # hard-coded 784; reshape (unlike view) also accepts non-contiguous
        # inputs.
        flat = x.reshape(-1, self.fc1.in_features)
        mu, logvar = self.encode(flat)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Custom Likelihood Loss
class CustomLikelihoodLoss(nn.Module):
    """Negative mixture log-likelihood plus a Gaussian KL penalty.

    For every sample ``i`` the reconstruction term is
    ``-log(sum_m pi[i, m] * L_im[i, m])``; the terms are summed over the
    batch and added to the KL divergence between ``N(mu, exp(logvar))`` and
    the standard normal.

    Args:
        L_im: Array-like of shape ``(num_samples, num_components)`` holding
            one row of per-component likelihoods per sample.
        eps: Lower bound applied to the per-sample mixture likelihood before
            taking the log, so a zero likelihood does not produce ``inf``.
    """

    def __init__(self, L_im, eps=1e-12):
        super().__init__()
        # as_tensor accepts lists, arrays and tensors alike; the clone keeps
        # the stored values independent of the caller's array.
        L_im = torch.as_tensor(L_im, dtype=torch.float32).detach().clone()
        # A buffer follows the module through .to(device); non-persistent so
        # the (potentially dataset-sized) table stays out of state_dict().
        self.register_buffer("L_im", L_im, persistent=False)
        self.eps = eps

    def forward(self, pi, mu, logvar, indices=None):
        """Compute the summed loss for one batch.

        Args:
            pi: ``(batch, num_components)`` mixture weights.
            mu: Posterior means.
            logvar: Posterior log-variances.
            indices: Optional integer index (LongTensor or list of ints)
                selecting the rows of ``L_im`` that belong to the rows of
                ``pi``. When omitted, ``L_im`` is used as a whole and must
                broadcast against ``pi``.

        Returns:
            Scalar tensor: reconstruction loss + KL divergence.
        """
        L_im = self.L_im if indices is None else self.L_im[indices]

        # Reconstruction loss (negative log-likelihood)
        inner_sum = torch.sum(pi * L_im, dim=1)
        recon_loss = -torch.sum(torch.log(inner_sum.clamp_min(self.eps)))

        # KL divergence
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return recon_loss + kl_div



In [None]:
# Seed every source of randomness used below (placeholder L_im, weight
# initialisation, batch shuffling) so Restart & Run All reproduces results.
SEED = 42
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = DataLoader(mnist_data, batch_size=128, shuffle=True)

# Generate L_im (for demonstration, we'll use random values):
# one row per dataset sample, one column per class.
L_im = rng.random((len(mnist_data), 10))  # 10 classes in MNIST

# Initialize model and optimizer
input_dim = 784  # 28x28 MNIST images
hidden_dim = 400
latent_dim = 20
output_dim = 10  # 10 classes in MNIST

model = VAE(input_dim, hidden_dim, latent_dim, output_dim)
criterion = CustomLikelihoodLoss(L_im)
optimizer = optim.Adam(model.parameters(), lr=1e-3)



In [None]:
# Training loop
class _IndexedDataset(torch.utils.data.Dataset):
    """Wrap a map-style dataset so every item also carries its integer index."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image, label = self.base[index]
        return image, label, index


# Row i of L_im belongs to sample i of the dataset, so each mini-batch must
# know which dataset rows it contains. data_loader only yields
# (image, label), hence this index-aware loader over the same dataset.
indexed_loader = DataLoader(
    _IndexedDataset(data_loader.dataset),
    batch_size=data_loader.batch_size,
    shuffle=True,
)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, _, sample_idx) in enumerate(indexed_loader):
        optimizer.zero_grad()
        pi, mu, logvar = model(data)
        # `criterion` holds L_im for the entire dataset, which cannot be
        # multiplied with a single mini-batch of pi; build the loss from the
        # L_im rows that match this batch instead.
        batch_criterion = CustomLikelihoodLoss(L_im[sample_idx.numpy()])
        loss = batch_criterion(pi, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item() / len(data):.4f}')

    # The loss is a sum over samples, so divide by the dataset size.
    print(f'====> Epoch: {epoch+1} Average loss: {train_loss / len(data_loader.dataset):.4f}')



In [None]:
# Report per-class statistics of the fitted pi for the first few batches
max_batches = 5  # number of batches to summarise
model.eval()
with torch.no_grad():
    for i, (data, _) in enumerate(data_loader):
        if not i < max_batches:
            break
        pi, _, _ = model(data)

        # Statistics are taken across the batch (dim 0), one value per class.
        mean_pi = pi.mean(dim=0)
        min_pi = pi.min(dim=0).values
        max_pi = pi.max(dim=0).values
        # Each row of pi is a softmax output, so the row sums should be ~1.
        mean_row_sum = pi.sum(dim=1).mean()

        print(f"\nBatch {i+1}:")
        print(f"Mean pi: {mean_pi.numpy()}")
        print(f"Min pi: {min_pi.numpy()}")
        print(f"Max pi: {max_pi.numpy()}")
        print(f"Sum of pi (should be close to 1): {mean_row_sum.item():.4f}")



In [None]:
# Visualize reconstructions
import matplotlib.pyplot as plt

def plot_reconstructions(model, data, n=10):
    """Show input images above the probability vector the model emits for them.

    The model output is a vector of per-class probabilities, not an image,
    so the bottom row is drawn as a bar chart rather than reshaped to 28x28.

    Args:
        model: Module returning ``(pi, mu, logvar)`` when called on a batch.
        data: Batch of images; ``data[i].squeeze()`` must be a 2-D image.
        n: Maximum number of samples to draw (clamped to ``len(data)``).

    Raises:
        ValueError: If there is nothing to plot (``n < 1`` or empty batch).

    Note:
        Leaves ``model`` in eval mode.
    """
    n = min(n, len(data))
    if n < 1:
        raise ValueError("plot_reconstructions needs at least one sample to plot")

    model.eval()
    with torch.no_grad():
        pi, _, _ = model(data[:n])
    pi = pi.cpu()

    # squeeze=False keeps `axes` two-dimensional even when n == 1.
    fig, axes = plt.subplots(2, n, figsize=(1.5*n, 3), squeeze=False)
    class_ids = list(range(pi.shape[1]))
    for i in range(n):
        # Original
        axes[0, i].imshow(data[i].squeeze(), cmap='gray')
        axes[0, i].axis('off')

        # Model output (visualized as per-class probabilities)
        axes[1, i].bar(class_ids, pi[i].numpy())
        axes[1, i].set_ylim(0, 1)
        axes[1, i].set_xticks(class_ids)
        axes[1, i].tick_params(labelsize=6)
        if i > 0:
            # Only the first panel keeps its y tick labels; all share 0..1.
            axes[1, i].set_yticklabels([])

    plt.tight_layout()
    plt.show()

# Pull one batch from data_loader (the MNIST training split loaded above)
dataiter = iter(data_loader)
first_batch = next(dataiter)
images, labels = first_batch

# Draw the images together with the model output for them
plot_reconstructions(model=model, data=images)