In [None]:
import torch
import torchvision
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn.functional as F

from pprint import pprint
import numpy as np
import torchinfo
from matplotlib import pyplot as plt
import time
import cv2
from PIL import Image
import time
from tqdm import tqdm
import torchinfo 
from pprint import pprint

# This setup is the same for all of the code.

# Dataset class implementation was done by Anshul EE20BTECH1104
# Residual block was written by Dhruv Srikanth EE20BTECH11014
# Various functions such as mse, mae, and ssim were written by Dhruv Srikanth and Anshul Gupta
# The basic model architecture was taken from the internet and later modified by Utkarsh Doshi E20BTECH11052

# Reference implementation: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py

In [None]:
# Anshul EE20BTECH11004

channels_img = 3
batch_size = 32
# NOTE: machine-specific absolute path; point this at your local copy of the dataset.
data_dir = r'C:/Users/utkar/Desktop/ivp/FFHQ64x64/'
#data_dir_new = r'C:/Users/utkar/Desktop/ML/Dataset/Celeb_dataset/500_img'

# ToTensor scales pixels to [0, 1]; CenterCrop guarantees 64x64 inputs.
# No Normalize step: the model is trained directly on [0, 1] images.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop((64, 64)),
])

# ImageFolder expects data_dir/<class_name>/<image files>; the class labels it
# produces are ignored by the training loop.
dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# Fall back to the CPU when CUDA is unavailable instead of failing later on .to('cuda').
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    pin_memory=device.type == "cuda",  # pinned host memory only helps host->GPU copies
    shuffle=True,
)

In [None]:
# Print NVIDIA GPU status (driver version, memory usage, running processes) via IPython's shell escape.
!nvidia-smi

In [None]:

## calculate the variance of the dataset
#variance = torch.zeros((channels_img, 64, 64))
#for images, _ in train_loader:
#    variance += np.var(images).mean(dim=0)
#variance /= len(train_loader)
#
#print(f"Dataset variance: {variance}")
#

In [None]:
# Anshul EE20BTECH11004
# Visual check of one training image split into a 2x2 box-blurred (low-frequency)
# version and the residual detail (image minus blur).

img = dataset[10500][0].permute(1, 2, 0)  # [C, H, W] -> [H, W, C] for cv2 / matplotlib
img_blur = torch.tensor(cv2.blur(np.array(img), (2, 2)))
img_d = img - img_blur

# Panel 1: original | blurred | detail (x3 for visibility), side by side.
imgs = np.concatenate((img, img_blur, img_d * 3), 1)

# Panel 2: blur + detail (adds back up to the original); panel 3: the detail alone.
# Only the first panel gets the wide 15x5 figure.
for panel, figsize in ((imgs, (15, 5)), (img_d + img_blur, None), (img_d, None)):
    if figsize is not None:
        plt.figure(figsize=figsize)
    plt.axis('off')
    plt.imshow(panel)
    plt.show()

In [None]:


class AttentionBlock(nn.Module):
    """SAGAN-style spatial self-attention with a learned residual gate.

    Every spatial position attends over all H*W positions of the same feature
    map. The attended features are scaled by ``gamma`` (initialised to 0, so the
    block starts out as an identity) and added back to the input.

    Args:
        in_channels: channel count C of the input. Query/key projections use
            ``in_channels // 8`` channels; the value projection keeps C.
        kernel_size: kernel size of the query/key/value convolutions. ``'same'``
            padding keeps H and W unchanged, so the attention reshapes are valid
            for any kernel size (without it only ``kernel_size=1`` worked).
    """

    def __init__(self, in_channels, kernel_size=1):
        super(AttentionBlock, self).__init__()
        # padding='same' (stride 1) preserves H x W; for the default kernel_size=1 it is zero padding.
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=kernel_size, padding='same')
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=kernel_size, padding='same')
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, padding='same')
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape [B, C, H, W]; returns the same shape."""
        batch_size, channels, height, width = x.size()
        n_positions = height * width

        # Project features onto query [B, N, C//8], key [B, C//8, N] and value [B, C, N].
        proj_query = self.query_conv(x).view(batch_size, -1, n_positions).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, n_positions)
        proj_value = self.value_conv(x).view(batch_size, -1, n_positions)

        # energy[b, i, j] = <query_i, key_j>; softmax over j gives attention weights [B, N, N].
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)

        # out[:, :, i] = sum_j attention[i, j] * value[:, :, j]  -> [B, C, N]
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)

        # Gamma-scaled attention output plus the residual (identity) connection.
        return self.gamma * out + x


In [None]:
# Dhruv Srikanth EE20BTECH11041

# residual block
class ResidualBlock(nn.Module):
    """Shape-preserving residual block computing ``x + branch(x)``.

    ``branch`` is conv3x3 -> ReLU -> dilated conv3x3 (dilation 2) -> ReLU -> conv3x3.
    All convolutions are bias-free, and their padding keeps height and width unchanged.

    Args:
        in_c: number of input channels.
        out_c: number of channels produced by the branch. Must equal ``in_c``
            because the branch output is added directly to the input.

    Raises:
        ValueError: if ``in_c != out_c``.
    """

    def __init__(self, in_c, out_c):
        super(ResidualBlock, self).__init__()
        if in_c != out_c:
            # The skip connection is a plain identity, so the branch must keep the channel count.
            raise ValueError(
                f"ResidualBlock needs in_c == out_c for its identity skip, got {in_c} and {out_c}"
            )
        self.ResBlock = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False),
            nn.ReLU(True),
            # Input is the previous conv's output (out_c channels). The original declared
            # in_c here, which only worked because in_c == out_c.
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=2, bias=False, dilation=2),
            nn.ReLU(True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, bias=False),
        )

    def forward(self, x):
        return x + self.ResBlock(x)
    

In [None]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook quantiser used as the VQ-VAE bottleneck.

    Each D-dimensional vector at every spatial position of the input is replaced
    by its closest codebook entry (squared Euclidean distance). Gradients reach the
    encoder through the straight-through estimator.

    Args:
        num_embeddings: codebook size K.
        embedding_dim: code dimension D; must match the channel count of the input.
        beta: weight of the commitment term in the returned VQ loss.
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25):
        super(VectorQuantizer, self).__init__()
        self.K = num_embeddings
        self.D = embedding_dim
        self.beta = beta

        # Codebook of K vectors, initialised uniformly in [-1/K, 1/K].
        self.embedding = nn.Embedding(self.K, self.D)
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)

    def forward(self, latents):
        """Quantise ``latents`` of shape [B, D, H, W] against the codebook.

        Returns:
            ``(quantized, vq_loss)``. ``quantized`` has shape [B, D, H, W], and
            ``vq_loss = beta * commitment_loss + codebook_loss``.
        """
        # Move channels last so each row of the flat view is one D-dim latent vector.
        z = latents.permute(0, 2, 3, 1).contiguous()             # [B, H, W, D]
        z_flat = z.view(-1, self.D)                                # [BHW, D]
        codebook = self.embedding.weight                           # [K, D]

        # Squared Euclidean distance expanded as |z|^2 + |e|^2 - 2 z.e  -> [BHW, K]
        z_sq = torch.sum(z_flat ** 2, dim=1, keepdim=True)
        e_sq = torch.sum(codebook ** 2, dim=1)
        cross = torch.matmul(z_flat, codebook.t())
        dist = z_sq + e_sq - 2 * cross

        # Index of the nearest code for every latent vector.
        nearest = torch.argmin(dist, dim=1)                        # [BHW]

        # One-hot rows times the codebook select the chosen code vectors.
        one_hot = F.one_hot(nearest, num_classes=self.K).to(torch.get_default_dtype())
        z_q = torch.matmul(one_hot, codebook).view(z.shape)        # [B, H, W, D]

        # Commitment term pulls encoder outputs toward the (frozen) codes;
        # codebook term pulls the codes toward the (frozen) encoder outputs.
        commitment_loss = F.mse_loss(z_q.detach(), z)
        codebook_loss = F.mse_loss(z_q, z.detach())
        vq_loss = commitment_loss * self.beta + codebook_loss

        # Straight-through estimator: the forward value is z_q, and the gradient w.r.t. z is the identity.
        z_q = z + (z_q - z).detach()
        return z_q.permute(0, 3, 1, 2).contiguous(), vq_loss

# The model architecture below was changed iteratively by Utkarsh Doshi (major contributor), Anshul, and Dhruv.
# Here we agreed on the idea of decreasing the number of codebook vectors.

In [None]:
# VQ-VAE
class VQVAE(nn.Module):
    """Vector-quantised VAE: ``encoder`` -> ``vq_layer`` (codebook lookup) -> ``decoder``.

    * Encoder: for every entry of ``hidden_dims``, a conv3x3 followed by a stride-2
      conv3x3 (halving H and W for even sizes). Then 3 x (ResidualBlock,
      AttentionBlock), and finally a 1x1 conv projecting to ``embedding_dim`` channels.
    * Decoder: the mirror image. A 1x1 conv back to ``hidden_dims[-1]`` channels,
      then 3 x (ResidualBlock, AttentionBlock). Next, for every entry of the
      reversed ``hidden_dims``, a conv3x3 followed by a stride-2 ConvTranspose
      (doubling H and W). Finally a 1x1 conv + Tanh producing ``in_c`` channels in [-1, 1].

    Args:
        in_c: channels of the input image and of the reconstruction.
        embedding_dim: dimensionality D of each codebook vector.
        num_embeddings: codebook size K.
        hidden_dims: encoder widths, one downsampling stage per entry
            (default ``[128, 256]``). The list passed in is not modified.
        beta: commitment-loss weight forwarded to ``VectorQuantizer``.
        img_size: stored as ``self.img_size``; not used by the layers.

    Raises:
        ValueError: if ``hidden_dims`` is empty.
    """

    def __init__(self,
                 in_c,
                 embedding_dim,
                 num_embeddings,
                 hidden_dims=None,
                 beta=0.25,
                 img_size=64,
                 ):
        super(VQVAE, self).__init__()

        self.embedding_dim = embedding_dim
        self.num_embedding = num_embeddings
        self.img_size = img_size
        self.beta = beta

        # Work on a copy: the original reversed the caller's list in place.
        hidden_dims = [128, 256] if hidden_dims is None else list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one width")

        # Construction order (encoder, codebook, decoder) matches the original, so
        # random initialisation and state_dict key order are unchanged.
        self.encoder = self._build_encoder(in_c, embedding_dim, hidden_dims)
        self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, self.beta)
        self.decoder = self._build_decoder(in_c, embedding_dim, hidden_dims)

    @staticmethod
    def _build_encoder(in_c, embedding_dim, hidden_dims):
        """Build the downsampling encoder: image [B, in_c, H, W] -> latents [B, embedding_dim, h, w]."""
        modules = []
        for hidden_dim in hidden_dims:
            modules.append(nn.Sequential(
                nn.Conv2d(in_c, out_channels=hidden_dim, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                # NOTE(review): two consecutive LeakyReLUs (an AttentionBlock used to sit
                # here). Kept so Sequential indices / checkpoint keys stay unchanged.
                nn.LeakyReLU(),
                nn.Conv2d(hidden_dim, out_channels=hidden_dim, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU(),
            ))
            in_c = hidden_dim

        for _ in range(3):
            modules.append(ResidualBlock(in_c, in_c))
            modules.append(AttentionBlock(in_c))

        modules.append(nn.LeakyReLU())
        modules.append(nn.Sequential(
            nn.Conv2d(in_c, embedding_dim, kernel_size=1, stride=1),
            nn.LeakyReLU(),
        ))
        modules.append(nn.Identity())
        return nn.Sequential(*modules)

    @staticmethod
    def _build_decoder(out_c, embedding_dim, hidden_dims):
        """Build the upsampling decoder: latents [B, embedding_dim, h, w] -> image [B, out_c, H, W]."""
        width = hidden_dims[-1]  # channel width at the encoder's bottleneck
        modules = [nn.Sequential(
            nn.Conv2d(embedding_dim, width, kernel_size=1, stride=1),
            nn.LeakyReLU(),
        )]

        for _ in range(3):
            modules.append(ResidualBlock(width, width))
            modules.append(AttentionBlock(width))

        for hidden_dim in reversed(hidden_dims):
            modules.append(nn.Sequential(
                nn.Conv2d(width, out_channels=hidden_dim, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                # NOTE(review): duplicated LeakyReLU mirrors the encoder; kept for checkpoint compatibility.
                nn.LeakyReLU(),
                nn.ConvTranspose2d(hidden_dim, out_channels=hidden_dim, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(),
            ))
            width = hidden_dim

        # Project from the last decoder width (was hidden_dims[1], wrong unless len == 2)
        # back to the input channel count (was hard-coded to 3).
        modules.append(nn.Sequential(
            nn.Conv2d(width, out_channels=out_c, kernel_size=1, stride=1, padding=0),
            nn.Tanh(),
        ))
        return nn.Sequential(*modules)

    def encode(self, x):
        """Run the encoder only; returns continuous (un-quantised) latents."""
        return self.encoder(x)

    def decode(self, z):
        """Run the decoder only on (quantised) latents ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``[reconstruction, x, vq_loss]``, the argument order ``loss_f`` expects."""
        encodings = self.encoder(x)
        quantized_vec, vq_loss = self.vq_layer(encodings)
        return [self.decoder(quantized_vec), x, vq_loss]

    def loss_f(self, *args, **kwargs):
        """Combine the outputs of ``forward`` into training losses.

        Args:
            *args: ``(recons, inputs, vq_loss)``, as returned by ``forward``.

        Returns:
            ``(recons_loss + vq_loss, recons_loss, vq_loss)``, where ``recons_loss``
            is the pixel-wise MSE between the reconstruction and the input.
        """
        recons, inputs, vq_loss = args[0], args[1], args[2]
        recons_loss = F.mse_loss(recons, inputs)
        return recons_loss + vq_loss, recons_loss, vq_loss

    def sample(self, num_samples, current_device):
        """Unconditional sampling is not implemented (would need a prior over codes)."""
        raise Warning('VQVAE sampler is not implemented.')

    def generate(self, x):
        """Return only the reconstruction of ``x``."""
        return self.forward(x)[0]


In [None]:
# in_c, embedding_dim, num_embeddings, hidden_dims, beta=0.25, img_size=64,
# Build a fresh model: 3 input channels, 64-dim codes, 256-entry codebook, on the active device.
vqvae = VQVAE(in_c=3, embedding_dim=64, num_embeddings=256).to(device)
# vqvae = torch.load('C:/Users/utkar/Desktop/ivp/models tried/day-7/vqvae_day7_epoch100.pt')

In [None]:
# Smoke test: push one random 64x64 image through encoder -> quantiser -> decoder
# and check that the reconstruction has the input's shape.
x = torch.rand((1, 3, 64, 64)).to(device)  # use the configured device, not a hard-coded 'cuda'
print(x.shape)
with torch.no_grad():  # inference only; no autograd graph needed
    encode_vec = vqvae.encoder(x)
    encoded = vqvae.vq_layer(encode_vec)[0]
    decoded = vqvae.decoder(encoded)
assert decoded.shape == x.shape, f"reconstruction shape {tuple(decoded.shape)} != input shape {tuple(x.shape)}"
decoded.shape

In [None]:
# Layer-by-layer summary (output shapes, parameter counts) for a single 3x64x64 input.
model_summary = torchinfo.summary(vqvae, input_size=(1, 3, 64, 64))
pprint(model_summary)

In [None]:
# 

num_training_updates = 15000  # NOTE(review): not used in the visible cells
p = r"C:\Users\utkar\Desktop\ivp\models tried\day-8\vqvae_day8_epoch35.pt"
# Checkpoints in this notebook are whole pickled modules (torch.save(vqvae, PATH)).
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects them, so
# opt out explicitly. Only load trusted files: unpickling can execute arbitrary code.
# map_location moves the weights onto the active device even if they were saved on a GPU.
vqvae = torch.load(p, map_location=device, weights_only=False)
# Hyperparameters from the reference implementation, kept for documentation only:
# num_hiddens = 128
# num_residual_hiddens = 32
# num_residual_layers = 2
#
# embedding_dim = 64
# num_embeddings = 512
# commitment_cost = 0.25
decay = 0.99  # NOTE(review): not used in the visible cells

# Evaluation mode for the preview below; train_vqvae switches back with model.train().
vqvae.eval()

learning_rate = 1e-4
optimizer = optim.Adam(vqvae.parameters(), lr=learning_rate)
criterion = nn.MSELoss()  # passed to train_vqvae, which computes its loss via model.loss_f instead

from torchmetrics import MeanSquaredError
from torchmetrics import MeanAbsoluteError
from torchmetrics import StructuralSimilarityIndexMeasure
import torch

def mse(img1, img2):
    """Pixel-wise error of ``img2`` (prediction) against ``img1`` (reference).

    Fresh torchmetrics objects are created on every call, so nothing accumulates
    across calls.

    Returns:
        tuple ``(mean squared error, mean absolute error)``.
    """
    reference, prediction = img1, img2
    return (MeanSquaredError()(prediction, reference),
            MeanAbsoluteError()(prediction, reference))


from torchmetrics import StructuralSimilarityIndexMeasure
import torch
def ssim(img1, img2):
    """Structural similarity of ``img2`` (prediction) w.r.t. ``img1`` (reference).

    Uses ``data_range=1.0``, i.e. images are expected on a [0, 1] scale.
    """
    metric = StructuralSimilarityIndexMeasure(data_range=1.0)
    return metric(img2, img1)
    
def display_func(n=10001):
    """Reconstruct dataset sample ``n`` with the current ``vqvae`` and show it.

    Prints ``(MSE, MAE)`` and SSIM between the input and its reconstruction, then
    plots the input (left) next to the reconstruction (right).
    Uses the notebook globals ``train_loader`` and ``vqvae``.

    Args:
        n: index into ``train_loader.dataset`` (default 10001, as before).
    """
    # Follow the model's device instead of assuming CUDA.
    model_device = next(vqvae.parameters()).device
    with torch.no_grad():
        img = torch.unsqueeze(train_loader.dataset[n][0], 0).to(model_device)
        out = vqvae(img)[0]

        img_cpu, out_cpu = img.cpu(), out.cpu()
        print(mse(img_cpu, out_cpu))
        print(ssim(img_cpu, out_cpu))
        # The decoder ends in Tanh ([-1, 1]); clip to [0, 1] for display, which is what
        # imshow would do anyway, but without its clipping warning.
        pic = np.transpose(np.concatenate((img_cpu, out_cpu.clamp(0, 1)), 3)[0], (1, 2, 0))
        plt.imshow(pic)
        plt.show()

display_func()

In [None]:
from tqdm import tqdm

def train_vqvae(model, train_loader, optimizer, criterion, device, epochs=10,
                save_every=5,
                checkpoint_prefix=r'C:/Users/utkar/Desktop/ivp/models tried/day-8/vqvae_day8_epoch',
                show_samples=True):
    """Train a VQ-VAE and return its per-epoch loss history.

    Each step calls ``model(data)`` and feeds the outputs to ``model.loss_f``,
    which must return ``(total_loss, recon_loss, vq_loss)``.

    Args:
        model: the network to train (e.g. ``VQVAE``).
        train_loader: DataLoader yielding ``(images, labels)``; labels are ignored.
        optimizer: optimiser over ``model``'s parameters.
        criterion: unused (the loss comes from ``model.loss_f``); kept for
            backward compatibility with existing calls.
        device: device the batches are moved to.
        epochs: number of passes over ``train_loader``.
        save_every: pickle ``model`` after every epoch whose 0-based index is a
            multiple of this value; ``0`` or ``None`` disables checkpointing.
        checkpoint_prefix: checkpoint path prefix. The 0-based epoch index and
            ``'.pt'`` are appended.
        show_samples: call the notebook-level ``display_func()`` after each epoch.

    Returns:
        ``(train_loss, recon_loss, vq_loss)``: lists with one mean value per epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches (epoch means would divide by zero).
    """
    n_batches = len(train_loader)
    if n_batches == 0:
        raise ValueError("train_loader yields no batches")

    model.train()

    train_loss, recon_loss, vq_loss = [], [], []

    for epoch in range(epochs):
        epoch_train_loss = 0.0
        epoch_recon_loss = 0.0
        epoch_vq_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch_idx, (data, _) in enumerate(progress_bar):
            data = data.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss, recon_loss_batch, vq_loss_batch = model.loss_f(*outputs)
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()
            epoch_recon_loss += recon_loss_batch.item()
            epoch_vq_loss += vq_loss_batch.item()

            # Running means over the batches seen so far in this epoch.
            seen = batch_idx + 1
            progress_bar.set_postfix({
                "Train Loss": epoch_train_loss / seen,
                "Recon Loss": epoch_recon_loss / seen,
                "VQ Loss": epoch_vq_loss / seen,
            })

        train_loss.append(epoch_train_loss / n_batches)
        recon_loss.append(epoch_recon_loss / n_batches)
        vq_loss.append(epoch_vq_loss / n_batches)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss[-1]:.4f}, Recon Loss: {recon_loss[-1]:.4f}, VQ Loss: {vq_loss[-1]:.4f}")
        if show_samples:
            display_func()
        if save_every and epoch % save_every == 0:
            # Save the model actually being trained; the original saved the global
            # `vqvae`, which is wrong whenever a different model is passed in.
            torch.save(model, checkpoint_prefix + str(epoch) + '.pt')
    return train_loss, recon_loss, vq_loss


In [None]:
# Train for 40 epochs. Keep the per-epoch loss history (total, reconstruction, VQ) in
# named variables rather than relying on this cell's Out[] value.
train_loss_hist, recon_loss_hist, vq_loss_hist = train_vqvae(
    vqvae, train_loader, optimizer, criterion, device, epochs=40
)

In [None]:
# Final checkpoint after training: pickles the whole module (architecture + weights).
PATH = r'C:/Users/utkar/Desktop/ivp/models tried/day-8/vqvae_day8.pt'
torch.save(obj=vqvae, f=PATH)

In [None]:
def display_func1(n=None):
    """Show input | reconstruction | 4x-amplified residual for one dataset sample.

    Uses the notebook globals ``train_loader`` and ``vqvae``.

    Args:
        n: index into ``train_loader.dataset``. When None, a uniformly random
            index over the whole dataset is used.
    """
    if n is None:
        # Sample from the real dataset size. The hard-coded 50000 raised IndexError on
        # smaller datasets and ignored any images past index 50000.
        n = np.random.randint(0, len(train_loader.dataset))
    # Follow the model's device instead of assuming CUDA.
    model_device = next(vqvae.parameters()).device
    with torch.no_grad():
        img = torch.unsqueeze(train_loader.dataset[n][0], 0).to(model_device)
        out = vqvae(img)[0]
        img_d = img - out
        pic = np.transpose(np.concatenate((img.cpu(), out.cpu(), img_d.cpu()*4), 3)[0], (1, 2, 0))
        plt.figure(figsize=(15, 5))
        plt.imshow(pic)
        plt.show()
display_func1()

In [None]:
# NOTE(review): this REDEFINES display_func from the training cell. train_vqvae looks
# the name up at call time, so after running this cell the per-epoch preview shows a
# random sample without metrics. Consider renaming if both versions are needed.
def display_func(n=None):
    """Print the sample index, then show input (left) next to its reconstruction (right).

    Uses the notebook globals ``train_loader`` and ``vqvae``.

    Args:
        n: index into ``train_loader.dataset``. When None, a uniformly random
            index over the whole dataset is used.
    """
    if n is None:
        # Sample from the real dataset size instead of a hard-coded 50000.
        n = np.random.randint(0, len(train_loader.dataset))
    # Follow the model's device instead of assuming CUDA.
    model_device = next(vqvae.parameters()).device
    with torch.no_grad():
        print(n)
        img = torch.unsqueeze(train_loader.dataset[n][0], 0).to(model_device)
        out = vqvae(img)[0]
        pic = np.transpose(np.concatenate((img.cpu(), out.cpu()), 3)[0], (1, 2, 0))
        plt.imshow(pic)
        plt.show()