<a href="https://colab.research.google.com/github/pejmanrasti/From_Shallow_to_Deep/blob/main/6_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
from keras.callbacks import ModelCheckpoint, EarlyStopping

# Import Tensorboard
from torch.utils.tensorboard import SummaryWriter



In [None]:
# Define the VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for square RGB images.

    The encoder applies five stride-2 convolutions (kernel 4, padding 1), each
    of which halves the spatial resolution, so an ``image_size`` x
    ``image_size`` input becomes a 1024-channel map of side
    ``image_size // 32``. The decoder mirrors this with five stride-2
    transposed convolutions and a final sigmoid, so reconstructions have the
    same shape as the input and values in (0, 1).

    Args:
        image_size: Height and width of the input images. Must be a positive
            multiple of 32 so the decoder output matches the input exactly.
        latent_dim: Dimensionality of the latent code. Defaults to 2048, the
            width of the original latent layers.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 32.

    Note:
        ``fc1`` and ``fc4`` each hold ``1024 * (image_size // 32) ** 2 * 2048``
        weights with the default ``latent_dim`` -- about 537M each for
        512x512 inputs -- so large image sizes need a lot of memory.
    """

    def __init__(self, image_size, latent_dim=2048):
        super(VAE, self).__init__()
        if image_size <= 0 or image_size % 32 != 0:
            raise ValueError(
                "image_size must be a positive multiple of 32, got {}".format(image_size)
            )

        # Side length of the encoder output: five halvings of image_size.
        # (Using image_size // 16 here does not match the five stride-2
        # layers and makes fc1 reject the flattened features.)
        self.feature_size = image_size // 32
        flat_dim = 1024 * self.feature_size ** 2

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )

        self.fc1 = nn.Linear(flat_dim, 2048)      # shared hidden layer
        self.fc2 = nn.Linear(2048, latent_dim)    # posterior mean head (mu)
        self.fc3 = nn.Linear(2048, latent_dim)    # posterior log-variance head
        # Project the latent code back to the encoder's feature-map size so the
        # decoder can upsample it to exactly image_size. Starting from a 1x1
        # map, five stride-2 layers could only ever reach 32x32.
        self.fc4 = nn.Linear(latent_dim, flat_dim)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for images ``x``."""
        h = self.encoder(x)
        h = h.view(h.size(0), -1)
        # Non-linearity between fc1 and the heads; without it fc1 and fc2/fc3
        # compose into a single linear map and the hidden layer adds nothing.
        h = torch.relu(self.fc1(h))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` of shape ``(batch, latent_dim)`` to images."""
        h = torch.relu(self.fc4(z))
        h = h.view(h.size(0), 1024, self.feature_size, self.feature_size)
        return self.decoder(h)

    def forward(self, x):
        """Encode, sample a latent code, decode; returns ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar



In [None]:
# Build the training data pipeline: every image found under `folder` is
# resized to 512x512 and converted to a tensor, then served in shuffled
# batches of 64. (torchvision's ImageFolder treats each sub-directory of
# `folder` as one class; the labels are ignored by the training loop.)
folder = 'path/to/folder'
preprocess_steps = [transforms.Resize((512, 512)), transforms.ToTensor()]
transform = transforms.Compose(preprocess_steps)
dataset = ImageFolder(root=folder, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

In [None]:
# Define the loss function and optimizer
image_size = 512  # must match transforms.Resize((512, 512)) in the data-loading cell
vae = VAE(image_size)

# Both loss terms are summed over their per-sample dimensions (pixels /
# latent units) and averaged over the batch, so they live on the same
# per-sample scale. Combining a per-pixel *mean* MSE with a KL *summed* over
# the whole batch over-weights the KL term by a factor of
# batch_size * C * H * W, which pushes the model toward ignoring its input.


def reconstruction_loss(recon, target):
    """Sum of squared errors per sample, averaged over the batch."""
    return nn.functional.mse_loss(recon, target, reduction='sum') / target.size(0)


def kl_loss(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)) summed over latent dims, averaged over the batch."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.size(0)


optimizer = optim.Adam(vae.parameters())

In [None]:
# Define early stopping and model checkpointing.
# Keras' EarlyStopping / ModelCheckpoint callbacks only hook into
# keras.Model.fit: they are not callable as `cb(loss, model)`, have no
# `.early_stop` attribute and cannot save a torch module. These small
# PyTorch-native helpers provide exactly the interface the training loop uses:
# `early_stopping(loss, model)`, `early_stopping.early_stop` and
# `checkpoint(loss, model)`.


class TorchEarlyStopping:
    """Set ``early_stop`` once the loss has not improved for ``patience`` calls.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Minimum decrease of the loss that counts as improvement.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, loss, model=None):
        # `model` is accepted (and unused) so both helpers share one call shape.
        loss = float(loss)  # accepts Python numbers and 0-d tensors
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


class BestModelCheckpoint:
    """Save ``model.state_dict()`` to ``path`` whenever the loss reaches a new minimum."""

    def __init__(self, path):
        self.path = path
        self.best_loss = float('inf')

    def __call__(self, loss, model):
        loss = float(loss)
        if loss < self.best_loss:
            self.best_loss = loss
            torch.save(model.state_dict(), self.path)


early_stopping = TorchEarlyStopping(patience=5)
checkpoint = BestModelCheckpoint('vae.pth')

# TensorBoard event writer used by the training loop for images and losses;
# passing log_dir=None keeps the library's default run directory.
writer = SummaryWriter(log_dir=None)

In [None]:
# Training loop
num_epochs = 100

# `device` was never defined in this notebook: pick the GPU when available and
# move the model there. nn.Module.to() modifies the module in place, so the
# optimizer built earlier from vae.parameters() keeps tracking the same
# parameters.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae.to(device)

for epoch in range(num_epochs):
    vae.train()
    running_loss = 0.0
    num_batches = 0
    for images, _ in dataloader:
        optimizer.zero_grad()
        images = images.to(device)
        recon_images, mu, logvar = vae(images)
        loss = reconstruction_loss(recon_images, images) + kl_loss(mu, logvar)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("dataloader yielded no batches; check the image folder")

    # Average over the epoch so early stopping / checkpointing react to the
    # whole epoch rather than to the noisy loss of its final batch.
    epoch_loss = running_loss / num_batches

    # Log once per epoch (a handful of the last batch's reconstructions plus
    # the mean loss) instead of writing every batch under the same step.
    writer.add_images('Generated Images', recon_images[:8].detach(), epoch)
    writer.add_scalar('Loss', epoch_loss, epoch)

    # Check for early stopping and save best model
    early_stopping(epoch_loss, vae)
    checkpoint(epoch_loss, vae)
    if early_stopping.early_stop:
        print("Early stopping at epoch {}".format(epoch))
        break
    # Print the loss every 10 epochs
    if (epoch + 1) % 10 == 0:
        print("Epoch [{}/{}], Loss: {:.4f}".format(epoch + 1, num_epochs, epoch_loss))

In [None]:
# Restore the weights that the checkpoint step of the training loop wrote to
# 'vae.pth', then flush and release the TensorBoard writer.
best_state_dict = torch.load('vae.pth')
vae.load_state_dict(best_state_dict)
writer.close()