In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import wandb


# Start a Weights & Biases run; the training loop below logs its per-batch
# losses to this run through wandb.log.
# NOTE(review): project and entity are hard-coded to one account — confirm before sharing.
wandb.init(project="VQVAE", entity="soninidhiverma3")

# Residual block shared by the Encoder and the Decoder
class ResidualBlock(nn.Module):
    """Channel-preserving residual unit: ReLU -> 3x3 conv -> ReLU -> 1x1 conv, added to the input.

    Args:
        in_channels: number of channels of the input; the output has the same
            number because both convolutions map ``in_channels -> in_channels``.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Both convolutions are bias-free; padding=1 keeps the spatial size of
        # the 3x3 convolution unchanged so the skip connection can be added.
        stages = [
            nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

# Encoder: image -> 4x-downsampled feature map
class Encoder(nn.Module):
    """Convolutional encoder with two stride-2 convolutions followed by residual blocks.

    Args:
        in_channels: channels of the input image.
        num_hiddens: channels of the produced feature map (the first
            convolution uses ``num_hiddens // 2``).
        num_residual_layers: how many ``ResidualBlock`` modules are stacked at the end.
        num_residual_hiddens: accepted for signature compatibility; this
            implementation does not use it.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        half_hiddens = num_hiddens // 2
        # Each kernel=4 / stride=2 / padding=1 convolution halves height and width.
        self.initial_conv = nn.Conv2d(in_channels, half_hiddens, kernel_size=4, stride=2, padding=1)
        self.down_conv = nn.Conv2d(half_hiddens, num_hiddens, kernel_size=4, stride=2, padding=1)
        self.final_conv = nn.Conv2d(num_hiddens, num_hiddens, kernel_size=3, padding=1)
        residual_stack = [ResidualBlock(num_hiddens) for _ in range(num_residual_layers)]
        self.residual_blocks = nn.Sequential(*residual_stack)

    def forward(self, x):
        """Encode ``x``; no activation is applied after the last residual block."""
        hidden = torch.relu(self.initial_conv(x))
        hidden = torch.relu(self.down_conv(hidden))
        hidden = self.final_conv(hidden)
        return self.residual_blocks(hidden)

# Vector quantisation layer (nearest-codebook lookup with straight-through gradient)
class VectorQuantizer(nn.Module):
    """Replace every ``embedding_dim``-sized vector of the input by its nearest codebook entry.

    Args:
        num_embeddings: number of codebook entries.
        embedding_dim: length of each codebook vector.
        commitment_cost: weight of the commitment term in the returned loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Start the codebook in a small symmetric range around zero.
        init_bound = 1. / num_embeddings
        self.embedding.weight.data.uniform_(-init_bound, init_bound)

    def forward(self, inputs):
        """Quantise a channels-first tensor.

        Returns:
            A tuple ``(quantized, loss)`` where ``quantized`` has the layout of
            ``inputs`` and ``loss`` is the codebook term plus
            ``commitment_cost`` times the commitment term.
        """
        # Move channels last so that flattening yields one vector per row.
        channels_last = inputs.permute(0, 2, 3, 1).contiguous()
        flat = channels_last.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # Squared Euclidean distance expanded as |x|^2 + |e|^2 - 2 x.e
        sq_inputs = flat.pow(2).sum(dim=1, keepdim=True)
        sq_codes = codebook.pow(2).sum(dim=1)
        distances = sq_inputs + sq_codes - 2 * torch.matmul(flat, codebook.t())

        nearest = distances.argmin(dim=1)
        codes = self.embedding(nearest).view(channels_last.shape)

        # Commitment term moves the inputs toward the (frozen) codes;
        # codebook term moves the codes toward the (frozen) inputs.
        commitment_term = (codes.detach() - channels_last).pow(2).mean()
        codebook_term = (codes - channels_last.detach()).pow(2).mean()
        loss = codebook_term + self.commitment_cost * commitment_term

        # Straight-through estimator: forward value is the code, gradient flows to the input.
        straight_through = channels_last + (codes - channels_last).detach()
        return straight_through.permute(0, 3, 1, 2), loss

# Decoder: feature map -> 4x-upsampled image in [-1, 1]
class Decoder(nn.Module):
    """Residual blocks followed by two stride-2 transposed convolutions and a tanh.

    Args:
        out_channels: channels of the reconstructed image.
        num_hiddens: channels of the incoming feature map.
        num_residual_layers: how many ``ResidualBlock`` modules are applied first.
        num_residual_hiddens: accepted for signature compatibility; this
            implementation does not use it.
    """

    def __init__(self, out_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        half_hiddens = num_hiddens // 2
        # Each kernel=4 / stride=2 / padding=1 transposed convolution doubles height and width.
        self.up_conv_1 = nn.ConvTranspose2d(num_hiddens, half_hiddens, kernel_size=4, stride=2, padding=1)
        self.up_conv_2 = nn.ConvTranspose2d(half_hiddens, out_channels, kernel_size=4, stride=2, padding=1)
        residual_stack = [ResidualBlock(num_hiddens) for _ in range(num_residual_layers)]
        self.residual_blocks = nn.Sequential(*residual_stack)

    def forward(self, x):
        """Decode ``x``; the tanh bounds every output value to [-1, 1]."""
        hidden = self.residual_blocks(x)
        hidden = torch.relu(self.up_conv_1(hidden))
        return torch.tanh(self.up_conv_2(hidden))

# VQVAE Model
class VQVAE(nn.Module):
    """Encoder -> vector quantizer -> decoder autoencoder.

    The encoder produces ``num_hiddens`` channels while the quantizer works on
    vectors of length ``embedding_dim``. When the two differ, 1x1 convolutions
    project into and out of the codebook space so that exactly one codebook
    entry is selected per spatial location. Without the projection the
    quantizer's ``view(-1, embedding_dim)`` would silently cut every
    ``num_hiddens``-long vector into several unrelated pieces. When the two are
    equal the projections are identity modules and add no parameters.

    Note: for ``embedding_dim != num_hiddens`` the module gains the parameters
    ``pre_vq_conv.*`` and ``post_vq_conv.*``; checkpoints written by a version
    without them need ``load_state_dict(..., strict=False)``.

    Args:
        in_channels: channels of the input (and reconstructed) image.
        num_hiddens: channel width of the encoder output / decoder input.
        num_residual_layers: residual blocks used in encoder and decoder.
        num_residual_hiddens: forwarded unchanged to ``Encoder`` and ``Decoder``.
        num_embeddings: number of codebook entries.
        embedding_dim: length of each codebook vector.
        commitment_cost: weight of the commitment term of the VQ loss.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.encoder = Encoder(in_channels, num_hiddens, num_residual_layers, num_residual_hiddens)
        self.vq = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        self.decoder = Decoder(in_channels, num_hiddens, num_residual_layers, num_residual_hiddens)
        # Projections are created last so the initialisation of the three
        # modules above is unaffected by their presence.
        if embedding_dim == num_hiddens:
            self.pre_vq_conv = nn.Identity()
            self.post_vq_conv = nn.Identity()
        else:
            self.pre_vq_conv = nn.Conv2d(num_hiddens, embedding_dim, kernel_size=1)
            self.post_vq_conv = nn.Conv2d(embedding_dim, num_hiddens, kernel_size=1)

    def forward(self, x):
        """Reconstruct ``x``.

        Returns:
            ``(decoded, vq_loss)``: the reconstruction and the quantizer loss.
        """
        encoded = self.pre_vq_conv(self.encoder(x))
        quantized, vq_loss = self.vq(encoded)
        decoded = self.decoder(self.post_vq_conv(quantized))
        return decoded, vq_loss

# Data loading: resize, light augmentation, then scale pixel values to [-1, 1]
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(),
        # Pads by 4 pixels on every side, then takes a random 128x128 crop.
        transforms.RandomCrop(128, padding=4),
        # NOTE(review): ColorJitter is constructed without arguments; with
        # torchvision's defaults this is presumably a no-op — confirm.
        transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
dataset = datasets.ImageFolder(
    root='/home/planck/NIDHI_SONI/DL_Ass/ass4_data/Train_data',
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Training: build the model, then optimise reconstruction + VQ loss for 10 epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQVAE(
    in_channels=3,
    num_hiddens=128,
    num_residual_layers=2,
    num_residual_hiddens=64,
    num_embeddings=256,
    embedding_dim=64,
    commitment_cost=0.25,
).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

for epoch in range(10):  # Number of epochs
    for inputs, _ in dataloader:
        inputs = inputs.to(device)

        optimizer.zero_grad()
        outputs, vq_loss = model(inputs)
        recon_loss = criterion(outputs, inputs)
        loss = recon_loss + vq_loss
        loss.backward()
        optimizer.step()

        # One W&B log entry per optimisation step.
        wandb.log(
            {
                "Reconstruction Loss": recon_loss.item(),
                "VQ Loss": vq_loss.item(),
                "Total Loss": loss.item(),
            }
        )
    # The printed values are those of the last batch of the epoch.
    print(f'Epoch {epoch+1}, Reconstruction Loss: {recon_loss.item()}, VQ Loss: {vq_loss.item()}')


In [None]:

# Training setup: creates a new, untrained model together with its optimiser and loss.
# Running this cell replaces the `model` trained by the previous cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQVAE(
    in_channels=3,
    num_hiddens=128,
    num_residual_layers=2,
    num_residual_hiddens=64,
    num_embeddings=256,
    embedding_dim=64,
    commitment_cost=0.25,
).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()


In [None]:

from torchvision.utils import save_image
import os
import os
import matplotlib.pyplot as plt
from torchvision.utils import save_image

# Function to display images
def show_images(original, reconstructed, n=5):
    """Plot up to ``n`` originals (top row) above their reconstructions (bottom row).

    Args:
        original: batch of images, channels first; indexed along the first axis.
        reconstructed: batch of reconstructions with the same layout.
        n: maximum number of image pairs to draw. If either batch holds fewer
            than ``n`` images (for example a short last batch) only the
            available pairs are drawn instead of raising ``IndexError``.
    """
    n = min(n, len(original), len(reconstructed))
    if n <= 0:
        # Nothing to draw; a grid with zero columns is not a valid subplot layout.
        return

    def to_displayable(image):
        # `x * 0.5 + 0.5` maps [-1, 1] back to [0, 1]; the clamp removes tiny
        # excursions outside that range that imshow would otherwise warn about.
        # Assumes the images were normalised with mean 0.5 / std 0.5.
        return (image.detach().permute(1, 2, 0).cpu() * 0.5 + 0.5).clamp(0, 1).numpy()

    plt.figure(figsize=(10, 4))
    for i in range(n):
        # Display original images
        plt.subplot(2, n, i + 1)
        plt.imshow(to_displayable(original[i]))
        plt.title("Original")
        plt.axis("off")

        # Display reconstructed images
        plt.subplot(2, n, i + 1 + n)
        plt.imshow(to_displayable(reconstructed[i]))
        plt.title("Reconstructed")
        plt.axis("off")
    plt.show()

def train_dataloader(model, dataloader, optimizer, criterion, device, num_epochs=10, save_dir='/home/planck/NIDHI_SONI/DL_Ass/saved_models/save_images'):
    """Train ``model`` and, after every epoch, show and save a few reconstructions.

    Args:
        model: module whose forward returns ``(reconstruction, vq_loss)``.
        dataloader: iterable yielding ``(inputs, labels)`` batches; labels are ignored.
        optimizer: optimiser over the parameters of ``model``.
        criterion: reconstruction loss called as ``criterion(outputs, inputs)``.
        device: device the input batches are moved to.
        num_epochs: number of passes over ``dataloader``.
        save_dir: directory receiving ``epoch_<k>_reconstructions.png``; created if missing.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_dir, exist_ok=True)

    # Make sure the model is in training mode even if an earlier evaluation
    # step switched it to eval mode.
    model.train()

    for epoch in range(num_epochs):
        inputs = outputs = recon_loss = vq_loss = None
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            optimizer.zero_grad()
            outputs, vq_loss = model(inputs)
            recon_loss = criterion(outputs, inputs)
            loss = recon_loss + vq_loss
            loss.backward()
            optimizer.step()

        if outputs is None:
            # An empty dataloader leaves nothing to visualise or report.
            print(f'Epoch {epoch+1}: dataloader produced no batches, nothing to report')
            continue

        # Visualize and save images at the end of each epoch. The last batch
        # may hold fewer than five images, so the preview size follows it.
        preview = min(5, inputs.size(0))
        show_images(inputs[:preview], outputs[:preview], n=preview)
        save_image(outputs[:preview].detach().cpu(), os.path.join(save_dir, f'epoch_{epoch+1}_reconstructions.png'), nrow=5, normalize=True)

        # Reported losses are those of the last batch of the epoch.
        print(f'Epoch {epoch+1}, Reconstruction Loss: {recon_loss.item()}, VQ Loss: {vq_loss.item()}')


# Fresh model / optimiser / loss for the run that also saves per-epoch previews.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQVAE(
    in_channels=3,
    num_hiddens=128,
    num_residual_layers=2,
    num_residual_hiddens=64,
    num_embeddings=256,
    embedding_dim=64,
    commitment_cost=0.25,
).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

train_dataloader(
    model,
    dataloader,
    optimizer,
    criterion,
    device,
    num_epochs=10,
    save_dir='/home/planck/NIDHI_SONI/DL_Ass/saved_models/save_images',
)


In [None]:
from PIL import Image
import os
import matplotlib.pyplot as plt

def display_saved_images(directory, num_images_per_row=5):
    """Show every ``.png`` file of ``directory`` in a grid.

    Files are ordered naturally, i.e. embedded numbers compare numerically, so
    ``epoch_2_...`` comes before ``epoch_10_...`` (a plain string sort puts
    ``epoch_10_`` first).

    Args:
        directory: folder that is scanned (non-recursively) for ``.png`` files.
        num_images_per_row: number of grid columns; must be at least 1.

    Raises:
        ValueError: if ``num_images_per_row`` is smaller than 1.
    """
    import re

    if num_images_per_row < 1:
        raise ValueError("num_images_per_row must be at least 1")

    def natural_key(name):
        # re.split with a capture group alternates text / digit runs, so items
        # at the same position always have the same type and compare safely.
        return [int(part) if part.isdigit() else part for part in re.split(r'(\d+)', name)]

    # List all PNG files in the directory
    image_files = sorted((f for f in os.listdir(directory) if f.endswith('.png')), key=natural_key)
    if not image_files:
        # A figure with zero rows has zero height and cannot be created.
        print(f"No PNG images found in {directory}")
        return

    # Number of rows needed (ceiling division)
    num_rows = -(-len(image_files) // num_images_per_row)

    plt.figure(figsize=(num_images_per_row * 4, num_rows * 4))

    for index, file in enumerate(image_files):
        plt.subplot(num_rows, num_images_per_row, index + 1)
        img_path = os.path.join(directory, file)
        # The context manager closes the file handle once the pixels are drawn.
        with Image.open(img_path) as image:
            plt.imshow(image)
        plt.title(f"Image {index + 1}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Show the per-epoch reconstruction grids; this is the same folder that is
# passed as save_dir to train_dataloader above.
display_saved_images('/home/planck/NIDHI_SONI/DL_Ass/saved_models/save_images')


In [None]:
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


def show_batch(dataloader):
    """Draw the first batch produced by ``dataloader`` as one image grid.

    Does nothing when the dataloader yields no batch.
    """
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    images, _ = first_batch

    fig, ax = plt.subplots(figsize=(18, 10))
    ax.set_xticks([])
    ax.set_yticks([])

    # Grid is channels-first; move channels last for imshow and map the
    # values with x * 0.5 + 0.5 before drawing.
    grid = make_grid(images, nrow=16)
    ax.imshow(grid.permute(1, 2, 0) * 0.5 + 0.5)

# Display one batch drawn from the training dataloader.
show_batch(dataloader)


In [None]:
import os


directory = "/home/planck/NIDHI_SONI/DL_Ass/saved_models/"

# Make sure the target folder exists before writing into it.
os.makedirs(directory, exist_ok=True)

# Persist only the learned parameters (the state dict) of the trained VQ-VAE.
torch.save(model.state_dict(), os.path.join(directory, "vq_vae_model.pth"))


In [None]:
# Reads back the file written by the save cell above. That cell stores
# model.state_dict(), so despite its name `vq_vae_model` holds a state dict
# here, not a module that can be called on images.
vq_vae_model = torch.load("/home/planck/NIDHI_SONI/DL_Ass/saved_models/vq_vae_model.pth")


In [None]:
# Evaluation must be deterministic, so the random flip / crop / jitter used for
# training are left out; only the resize and the [-1, 1] scaling are kept.
eval_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# NOTE(review): the root below is the training folder; point it at a held-out
# split if one exists.
test_dataset = datasets.ImageFolder(root='/home/planck/NIDHI_SONI/DL_Ass/ass4_data/Train_data', transform=eval_transform)

# shuffle=False keeps the displayed batch identical from run to run.
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)



In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def visualize_real_and_generated(model, dataloader, device):
    """Show the first batch of ``dataloader`` next to the model's reconstructions.

    The model is switched to eval mode (and left there). Inputs and
    reconstructions are concatenated along the batch axis and drawn as a
    single grid with 16 images per row. Nothing is drawn when the dataloader
    yields no batch.
    """
    model.eval()
    with torch.no_grad():
        first_batch = next(iter(dataloader), None)
        if first_batch is None:
            return
        real, _ = first_batch
        real = real.to(device)
        reconstructed, _ = model(real)

        # Originals first, reconstructions after them.
        stacked = torch.cat([real, reconstructed])
        grid = make_grid(stacked.cpu(), nrow=16)

        # Display the images
        plt.figure(figsize=(24, 12))
        plt.imshow(grid.permute(1, 2, 0) * 0.5 + 0.5)
        plt.axis('off')
        plt.show()

# Compare inputs and reconstructions of the in-memory `model` on the first
# batch of the (unshuffled) test dataloader.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
visualize_real_and_generated(model, test_dataloader, device)


In [None]:

# Rebuild the architecture with the same hyperparameters that were used for
# training so the configuration of the reloaded model matches the saved one.
vq_vae_model = VQVAE(in_channels=3, num_hiddens=128, num_residual_layers=2, num_residual_hiddens=64, num_embeddings=256, embedding_dim=64, commitment_cost=0.25)

# Load the model state. map_location="cpu" lets a checkpoint written from a
# GPU be read on a machine without one; the model built above lives on the CPU.
state_dict = torch.load("/home/planck/NIDHI_SONI/DL_Ass/saved_models/vq_vae_model.pth", map_location="cpu")
vq_vae_model.load_state_dict(state_dict)
vq_vae_model.eval()




In [None]:
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def visualize_real_and_generated(model, dataloader, device):
    """Plot one batch of inputs together with the reconstructions of ``model``.

    Puts ``model`` into eval mode, runs it without gradient tracking on the
    first batch of ``dataloader`` and draws inputs followed by reconstructions
    in a grid of 16 images per row. An empty dataloader results in no plot.
    """

    def draw(batch):
        # make_grid returns a channels-first image; imshow needs channels last.
        canvas = make_grid(batch.cpu(), nrow=16).permute(1, 2, 0)
        plt.figure(figsize=(24, 12))
        plt.imshow(canvas * 0.5 + 0.5)
        plt.axis('off')
        plt.show()

    model.eval()
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            generated_images, _ = model(images)
            draw(torch.cat([images, generated_images]))
            return  # Only the first batch is shown

# Compare inputs and reconstructions on the first batch of the test dataloader.
# NOTE(review): this passes the in-memory `model`, not the `vq_vae_model` that
# the preceding cell restored from the checkpoint — confirm which one is meant
# (`vq_vae_model` would first have to be moved to `device`).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
visualize_real_and_generated(model, test_dataloader, device)


In [None]:
import os

# Folder that receives the checkpoint
directory = "/home/planck/NIDHI_SONI/DL_Ass/saved_models/"

# Create it on first use; exist_ok makes re-running the cell harmless
os.makedirs(directory, exist_ok=True)

# Write the parameters of the trained VQ-VAE (state dict only)
torch.save(
    model.state_dict(),
    os.path.join(directory, "vq_vae_model.pth"),
)


In [None]:
# Reads back the checkpoint written by the save cell above. The file contains
# model.state_dict(), so this rebinds `vq_vae_model` to a state dict rather
# than to a callable module.
vq_vae_model = torch.load("/home/planck/NIDHI_SONI/DL_Ass/saved_models/vq_vae_model.pth")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
#import wandb
from torchvision.utils import save_image
# Convolutional prior over the VQ-VAE latent map; its last layer emits
# `output_channels` channels so the output can be compared with the quantized latents
class PixelCNN(nn.Module):
    """Three 7x7 circular-padded convolutions with ReLU, followed by a 1x1 convolution.

    NOTE(review): the layers are ordinary ``nn.Conv2d`` modules without any
    masking, so every output position also sees its own input position and
    positions after it — unlike a conventional PixelCNN. Confirm this is intended.

    Args:
        input_dim: channels of the input feature map.
        dim: channel width of the three 7x7 convolutions.
        output_channels: channels produced by the final 1x1 convolution.
    """

    def __init__(self, input_dim, dim=256, output_channels=128):  # Adjust `output_channels` to match VQVAE output
        super().__init__()
        self.dim = dim
        self.input_dim = input_dim
        self.output_channels = output_channels

        stack = []
        in_channels = input_dim
        for _ in range(3):
            # kernel 7 with padding 3 keeps the spatial size unchanged.
            stack.append(nn.Conv2d(in_channels, dim, kernel_size=7, padding=3, padding_mode='circular'))
            stack.append(nn.ReLU())
            in_channels = dim
        # Final projection to the requested number of output channels.
        stack.append(nn.Conv2d(dim, output_channels, kernel_size=1))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        return self.layers(x)

# Train the latent-space network on quantized VQ-VAE latents
def train_pixelcnn(pixelcnn, dataloader, vqvae, optimizer, device, epochs=10):
    """Fit ``pixelcnn`` so that its output matches the quantized latents it is given.

    For every batch the images are encoded and quantized with ``vqvae``
    (without tracking gradients); ``pixelcnn`` is then trained with a mean
    squared error between its output and those same quantized latents.

    Args:
        pixelcnn: module being trained; its output must have the shape of its input.
        dataloader: iterable yielding ``(images, labels)`` batches; labels are ignored.
        vqvae: object exposing ``encoder`` and ``vq``; used as a frozen feature
            extractor and switched to eval mode.
        optimizer: optimiser over the parameters of ``pixelcnn``.
        device: device the image batches are moved to.
        epochs: number of passes over ``dataloader``.
    """
    # The loss module is stateless, so one instance serves every step.
    criterion = nn.MSELoss()
    pixelcnn.train()
    vqvae.eval()

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for images, _ in dataloader:
            images = images.to(device)

            # Encode and quantize images using VQVAE
            with torch.no_grad():
                encoded = vqvae.encoder(images)
                quantized, _ = vqvae.vq(encoded)

            optimizer.zero_grad()
            outputs = pixelcnn(quantized)
            loss = criterion(outputs, quantized)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            # Avoids a ZeroDivisionError when the dataloader yields nothing.
            print(f'Epoch {epoch+1}: dataloader produced no batches, nothing to report')
            continue

        print(f'Epoch {epoch+1}: Loss: {total_loss / num_batches}')


In [None]:
def generate_images(pixelcnn, vqvae, device, num_images=10, latent_size=32):
    """Produce images by decoding latents that were passed through ``pixelcnn``.

    Pipeline: Gaussian noise shaped like a latent feature map -> ``pixelcnn``
    -> ``vqvae.vq`` (snap to the codebook) -> ``vqvae.decoder``. The previous
    order (decode first, then tile the RGB output and feed it to ``pixelcnn``)
    handed ``3 * input_dim`` channels to a network expecting ``input_dim`` and
    could not run.

    Args:
        pixelcnn: latent-space network; ``pixelcnn.input_dim`` gives the
            number of latent channels.
        vqvae: object exposing ``vq`` and ``decoder``.
        device: device on which the latents are created.
        num_images: number of images to generate.
        latent_size: height and width of the latent map. The default of 32
            corresponds to 128x128 images downsampled by the encoder's two
            stride-2 convolutions.

    Returns:
        Tensor of decoded images, one per requested sample.
    """
    pixelcnn.eval()  # Set PixelCNN to evaluation mode
    with torch.no_grad():
        latent_vectors = torch.randn(
            num_images, pixelcnn.input_dim, latent_size, latent_size, device=device
        )
        refined_latents = pixelcnn(latent_vectors)
        # Snap to codebook entries so the decoder sees the kind of input it was trained on.
        quantized, _ = vqvae.vq(refined_latents)
        generated_images = vqvae.decoder(quantized)
    return generated_images


In [None]:
# Data for the PixelCNN stage. This rebinds `transform`, `dataset` and
# `dataloader`, replacing the objects created for VQ-VAE training.
# NOTE(review): unlike the earlier transform there is no Normalize step here,
# so images stay in ToTensor's output range — confirm this matches what the
# VQ-VAE used below expects.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(root='/home/planck/NIDHI_SONI/DL_Ass/ass4_data/Train_data', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# NOTE(review): this constructs a new VQVAE and no checkpoint is loaded into
# it in this cell, so its weights are freshly initialised; embedding_dim=128
# also differs from the 64 used for the model trained earlier — confirm
# whether the trained model was meant to be used here.
vqvae = VQVAE(
    in_channels=3,
    num_hiddens=128,
    num_residual_layers=2,
    num_residual_hiddens=64,
    num_embeddings=256,
    embedding_dim=128,
    commitment_cost=0.25
).to(device)

# input_dim=128 equals num_hiddens above; output_channels keeps its default of 128.
pixelcnn = PixelCNN(input_dim=128, dim=256).to(device)

# Only the PixelCNN parameters are optimised; the VQVAE is not passed to the optimiser.
optimizer = optim.Adam(pixelcnn.parameters(), lr=0.001)

In [None]:
train_pixelcnn(pixelcnn, dataloader, vqvae, optimizer, device)

# Generate and save images. The former '/path/to/save/...' targets were
# template placeholders pointing at a folder that is never created; the
# outputs now go to the folder used for the other saved models.
output_dir = "/home/planck/NIDHI_SONI/DL_Ass/saved_models/"
os.makedirs(output_dir, exist_ok=True)

generated_images = generate_images(pixelcnn, vqvae, device, num_images=10)
save_image(generated_images, os.path.join(output_dir, 'generated_images.png'), nrow=5, normalize=True)
torch.save(pixelcnn.state_dict(), os.path.join(output_dir, 'pixelcnn_model.pth'))