In [55]:
import torch
import torch.nn as nn

In [56]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 32x32 images.

    Encoder: three stride-2 convolutions shrink the spatial size 32 -> 16 -> 8 -> 4,
    the 128x4x4 feature map is flattened and projected to a 1024-d hidden vector,
    and two linear heads predict the mean and log-variance of a 256-d Gaussian latent.
    Decoder: mirrors that path back to a 3x32x32 image and ends with a sigmoid,
    so reconstructions lie in the unit interval.

    The flattened feature count (128 * 4 * 4) is fixed, so inputs whose spatial size
    does not reduce to 4x4 after the three convolutions fail at the encoder's Linear.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Encoder: (conv -> ReLU) x 3, then flatten + projection to 1024 features.
        # Layers are created in the same order as a literal Sequential would create
        # them, so parameter init and state_dict keys (encode.0/2/4/7) are unchanged.
        encoder_layers = []
        for c_in, c_out in ((3, 32), (32, 64), (64, 128)):
            encoder_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            encoder_layers.append(nn.ReLU())
        encoder_layers.append(nn.Flatten())
        encoder_layers.append(nn.Linear(128 * 4 * 4, 1024))
        self.encode = nn.Sequential(*encoder_layers)

        # Latent heads: mean and log-variance of the 256-d Gaussian posterior.
        self.mu = nn.Linear(1024, 256)
        self.logvar = nn.Linear(1024, 256)

        # Latent vector -> flat 128*4*4 feature map for the transposed convolutions.
        # NOTE(review): there is no activation between these two Linear layers, so
        # together they collapse to a single affine map. Inserting a ReLU would add
        # capacity but would shift the state_dict indices (linear_decode.1 -> .2).
        self.linear_decode = nn.Sequential(
            nn.Linear(256, 1024),
            nn.Linear(1024, 128 * 4 * 4),
        )

        # Decoder: each stride-2 transposed conv doubles the spatial size
        # (4 -> 8 -> 16 -> 32). The final sigmoid squashes pixels into the unit interval.
        decoder_layers = []
        for c_in, c_out in ((128, 64), (64, 32)):
            decoder_layers += [nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1), nn.ReLU()]
        decoder_layers += [nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1), nn.Sigmoid()]
        self.decode = nn.Sequential(*decoder_layers)

    def encoder(self, x):
        """Map a batch of images to the 1024-d hidden vectors shared by both latent heads."""
        return self.encode(x)

    def decoder(self, x):
        """Map latent codes of shape (batch, 256) to images of shape (batch, 3, 32, 32)."""
        flat = self.linear_decode(x)
        batch = flat.size(0)
        feature_map = flat.view(batch, 128, 4, 4)
        return self.decode(feature_map)

    def reparameterize(self, mean, logvar):
        """Sample z = mean + sigma * eps with eps ~ N(0, I) (the reparameterization trick).

        Keeping the noise separate lets gradients flow back into ``mean`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mean + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``. ``reconstruction`` has the decoder's
            output shape (batch, 3, 32, 32). ``mu`` and ``logvar`` have shape (batch, 256).
        """
        hidden = self.encoder(x)
        mu, logvar = self.mu(hidden), self.logvar(hidden)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

# Smoke test: build an untrained model and push a dummy batch through it to confirm
# the reconstruction has the same shape as the input.
model = VAE()
dummy_batch = torch.randn(size=(32, 3, 32, 32))
with torch.no_grad():  # shape check only, so no autograd graph is needed
    decode, mu, logvar = model(dummy_batch)

assert decode.shape == dummy_batch.shape, f"unexpected reconstruction shape {tuple(decode.shape)}"
assert mu.shape == logvar.shape == (dummy_batch.size(0), 256)

# Output shape of the decoded image (displayed as the cell's result)
decode.shape


torch.Size([32, 3, 32, 32])


In [None]:
# Loss function
def loss_function(recon_x, x, mean, logvar):
    """VAE objective: summed binary cross-entropy reconstruction loss plus KL divergence.

    Args:
        recon_x: Reconstructed images, already in [0, 1]. The VAE decoder defined in
            this notebook ends with ``nn.Sigmoid``, so no extra activation is applied
            here (applying a second sigmoid would squash predictions into ~(0.5, 0.73)).
        x: Target images. Binary cross-entropy requires values in [0, 1]. Must have the
            same number of elements as ``recon_x``.
        mean: Latent means produced by the encoder.
        logvar: Latent log-variances produced by the encoder (same shape as ``mean``).

    Returns:
        Scalar tensor: BCE summed over every pixel in the batch, plus the KL term
        summed over every latent dimension in the batch.
    """
    # With reduction='sum' the layout is irrelevant. Flattening with reshape works
    # for any image size and for non-contiguous tensors (unlike a fixed .view).
    BCE = nn.functional.binary_cross_entropy(
        recon_x.reshape(-1), x.reshape(-1), reduction='sum'
    )

    # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), summed over all elements.
    KL = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return BCE + KL

# Training the VAE
learning_rate = 1e-3  # Adam step size (matches PyTorch's Adam default); tune as needed
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE().to(device)
# `optim` is not imported in the imports cell, so reference the optimizer via torch.optim.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Train loop
def train(epoch):
    """Run one training epoch over the notebook-level ``train_loader``.

    Relies on the globals ``model``, ``optimizer``, ``device`` and ``loss_function``.
    ``loss_function`` returns a sum over the batch, so the printed figure (total loss
    divided by the dataset size) is the average loss per training example.
    """
    model.train()
    running_total = 0
    for images, _labels in train_loader:
        images = images.to(device)

        optimizer.zero_grad()
        reconstruction, mu, log_variance = model(images)
        batch_loss = loss_function(reconstruction, images, mu, log_variance)
        batch_loss.backward()

        running_total += batch_loss.item()
        optimizer.step()

    per_sample = running_total / len(train_loader.dataset)
    print(f"Epoch: {epoch} Average loss: {per_sample:.4f}")