<div style="background-color:yellow; text-align:center; padding:40px;">
<h1  style="color:red;" > MMI-714 : Generative Models for Multimedia </h1>   
<h2  style="color:red;" > Final Report </h2>
<br>
<h3  style="color:red;  font-style:italic;" > Exploration of Intuitive Physics through 
Latent Space Disentanglement</h3>
<br>
<h4  style="color:red;" >Turgay Yıldız</h4>
<br>
<h4  style="color:red;" >Cognitive Sciences,  Middle East Technical University (METU)</h4>
</div>


In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms as transforms
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.checkpoint import checkpoint 
from torch.amp import GradScaler, autocast
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
from torch.nn.functional import relu
from torchvision import models
from torch.amp import GradScaler, autocast
from PIL import Image
from einops.layers.torch import Rearrange
from einops import repeat
from torch import Tensor
import tqdm
import pandas as pd
import cv2
import os 
import pygame
import pymunk
from pymunk import pygame_util

In [None]:
#path = "/home/turgay/Turgay/Academic/2024-2025/Fall/Generative_Models/Final_Project/Data/6_pairs/" 

In [None]:
#path_data = path + "concatenated_data.npy"

In [None]:
#X = np.load(path_data, mmap_mode="r")

In [None]:
# NOTE(review): `X` is only defined if the commented-out `np.load(...)` cell above is
# re-enabled; on a fresh kernel (Restart & Run All) this cell raises NameError.
# This array-based pipeline appears superseded by `ImageDataset` below.
# Reshape into (3529 samples, 6 entries per sample, 3, 224, 224); 3529 is hard-coded.
X  =  X.reshape(3529, 6, 3, 224, 224)

In [None]:
# Split the 6 entries of each sample into two triplets.
# NOTE(review): the meaning of each entry (image / color / order) is inferred from the names only — confirm.
X_image =  X[:, 0, :, :, :] 
X_color =  X[:, 1, :, :, :] 
X_order =  X[:, 2, :, :, :] 

X_image2 =  X[:, 3, :, :, :]
X_color2 =  X[:, 4, :, :, :]
X_order2 =  X[:, 5, :, :, :]

In [None]:
X_image.shape

In [None]:
def plot_img(data, row, col, size1, size2, c_map = None): 
    """Show one or more images in a ``row`` x ``col`` grid.

    Args:
        data: a single image when ``row == col == 1``; otherwise an indexable
            sequence of images (e.g. an array/tensor whose first axis indexes images).
        row, col: grid dimensions.
        size1, size2: figure width and height in inches (``figsize``).
        c_map: optional matplotlib colormap (relevant for single-channel images).

    Grid cells beyond ``len(data)`` are left blank instead of raising IndexError
    (previously only the multi-row layout handled this).
    """
    # squeeze=False always yields a 2-D array of Axes, so every grid shape
    # (1x1, 1xN, Nx1, NxM) is handled by the same flattened loop.
    fig, axes = plt.subplots(row, col, figsize=(size1, size2), squeeze=False)
    axes = axes.flatten()

    if row == 1 and col == 1:
        # A 1x1 grid receives the image itself (not a sequence) and gets no title.
        axes[0].imshow(data, cmap=c_map)
        axes[0].set_axis_off()
    else:
        for k, ax in enumerate(axes):
            if k < len(data):
                ax.imshow(data[k], cmap=c_map)
                ax.set_title(f"Image {k}")
            ax.set_axis_off()  # also hides the unused trailing axes

    plt.tight_layout()  
    plt.show()


In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and linearly embed each patch.

    Input:  (batch, in_channels, H, W) with H and W divisible by ``patch_size``.
    Output: (batch, (H // patch_size) * (W // patch_size), embed_dim).
    """

    def __init__(self, in_channels = 3, patch_size = 16, embed_dim = 64):
        # nn.Module.__init__ must run before any attribute is assigned on the module.
        super().__init__()
        self.patch_size = patch_size

        self.projection = nn.Sequential(
            # Cut the image into patch_size x patch_size tiles and flatten each tile into
            # a vector of length patch_size * patch_size * in_channels.
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return patch embeddings of shape (batch, num_patches, embed_dim)."""
        return self.projection(x)

In [None]:
class ImageDataset(Dataset):
    """Pairs each image in ``img_dir`` with the CSV file at the same sorted position in ``csv_dir``.

    Files are matched purely by sorted filename order, so both directories must hold the
    same number of files whose names sort into the same order.

    ``__getitem__`` returns ``(img_tensor, csv_tensor)``:
      * ``img_tensor``: float32 in [0, 1] with shape (3, H, W), obtained by *reshaping*
        (not permuting) the (H, W, 3) RGB pixel array.
      * ``csv_tensor``: float32 tensor of the CSV values (pandas treats the first row as header).
    """

    def __init__(self, img_dir, csv_dir):
        self.img_dir = img_dir
        self.csv_dir = csv_dir

        self.img_files = sorted(f for f in os.listdir(img_dir) if f.endswith(('.png', '.jpg', '.jpeg')))
        self.csv_files = sorted(f for f in os.listdir(csv_dir) if f.endswith('.csv'))

        assert len(self.img_files) == len(self.csv_files), "Mismatch between image and CSV files count." 

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        csv_path = os.path.join(self.csv_dir, self.csv_files[idx])

        # Context manager closes the file right away instead of whenever the PIL object is
        # garbage-collected (file handles otherwise pile up across DataLoader iterations).
        with Image.open(img_path) as img:
            img_array = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]

        height, width = img_array.shape[:2]
        # NOTE(review): reshape, not permute(2, 0, 1): the shape is (3, H, W) but the memory
        # layout is still HWC. simulate_batch() and the plotting cells use the matching
        # reshape, so all of them must change together if this is switched to a true permute.
        img_tensor = torch.from_numpy(img_array).reshape(3, height, width)

        csv_data   = pd.read_csv(csv_path).values
        csv_tensor = torch.tensor(csv_data, dtype=torch.float32)

        return img_tensor, csv_tensor

In [None]:
# NOTE(review): hard-coded absolute, user-specific paths — consider a configurable base directory.
img_path = "/home/turgay/falling_objects_dataset/img_files/"
csv_path = "/home/turgay/falling_objects_dataset/csv_files/"

In [None]:
#print(os.listdir(img_path)[:3])

In [None]:
# Image/CSV pairs, matched by sorted filename order inside ImageDataset.
dataset = ImageDataset(img_dir=img_path, csv_dir=csv_path)

In [None]:
class MyDataset(Dataset):
    """Dataset over an in-memory or memory-mapped image array, scaled by 1/255 per item.

    Args:
        X_image: array indexable by sample along its first axis (e.g. an
            ``np.load(..., mmap_mode="r")`` array); values presumably in [0, 255].
        transform: optional callable applied to each float32 tensor.
    """

    def __init__(self, X_image,   transform=None):
        super(MyDataset, self).__init__()
        # Keep the raw array and scale lazily in __getitem__: dividing the whole array here
        # materialized a full float64 copy in RAM and defeated memory-mapping.
        self.X_image   = X_image
        self.transform = transform

    def __len__(self):
        return len(self.X_image)

    def __getitem__(self, idx):
        # Same arithmetic as scaling up front (NumPy divide, then cast to float32),
        # applied to one sample at a time.
        img = torch.tensor(self.X_image[idx] / 255.0, dtype=torch.float32)

        if self.transform:
            img = self.transform(img)

        return img


In [None]:
# Training-time augmentation pipeline. Nothing in this notebook passes it to a dataset
# (the commented-out MyDataset cell uses transform=None).
transform = transforms.Compose(
    [
        # crop 80-100% of the image area, then resize the crop to 224x224
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        # mirror left-right with probability 0.5
        transforms.RandomHorizontalFlip(p=0.5),
        # rotate by an angle drawn from [-15, 15] degrees
        transforms.RandomRotation(degrees=15),
        # perturb brightness/contrast/saturation by up to 20% and hue by up to 0.1
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]
)


In [None]:
#dataset    =  MyDataset(X_image, transform=None)

In [None]:
train_size =  int(0.8 * len(dataset))
val_size   =  len(dataset) - train_size

In [None]:
train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) 

In [None]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=16, shuffle=True)

In [None]:
for i, j in train_loader:
    print(i.shape)
    break

In [None]:
i.max(), i.min()

In [None]:
i  =  i.reshape(-1, 512, 512, 3)

In [None]:
plot_img(i[:10], 2, 5, 5, 5) 

In [None]:
j.shape

<div style="background-color:yellow;    color:red;      text-align:center;    padding:5px;">
<h2>  Physics Model </h2>
</div>

In [None]:
def simulate_with_dynamics(csv_tensor, iterations=1000):
    """Simulate falling boxes described by ``csv_tensor`` and return one rendered frame.

    Each row is ``(pos_x, pos_y, width, height, mass, color_r, color_g, color_b, time)``.
    One dynamic box is created per row on a 512x512 canvas that has a static ground
    segment along its bottom edge and gravity (0, 9.8). The world is stepped at 1/60 s,
    and the frame is captured at step ``round(time * 50)``, using ``time`` from the LAST row.

    Args:
        csv_tensor: iterable of 9-value rows (torch tensor or NumPy array).
        iterations: last step index to simulate (steps 0..iterations). The loop previously
            ignored this and always used 1000, which is still the default.

    Returns:
        float32 NumPy array of shape (512, 512, 3) with values in [0, 1], or ``None``
        if the capture step is never reached.

    Raises:
        ValueError: if ``csv_tensor`` has no rows.
    """
    # Plain floats for pymunk; this also reads tensors that require grad or live on GPU.
    rows = [[float(v) for v in row] for row in csv_tensor]
    if not rows:
        raise ValueError("csv_tensor must contain at least one block row")

    pygame.init()
    try:
        space = pymunk.Space()
        space.gravity = (0, 9.8)

        screen_width, screen_height = 512, 512
        screen = pygame.Surface((screen_width, screen_height))
        draw_options = pymunk.pygame_util.DrawOptions(screen)

        # Static ground along the bottom edge of the canvas.
        ground = pymunk.Segment(space.static_body, (0, screen_height), (screen_width, screen_height), 0)
        space.add(ground)

        for pos_x, pos_y, width, height, mass, color_r, color_g, color_b, _time in rows:
            moment = pymunk.moment_for_box(mass, (width, height))
            body = pymunk.Body(mass, moment)
            body.position = (pos_x, pos_y)

            shape = pymunk.Poly.create_box(body, (width, height))
            # NOTE(review): pymunk's pygame DrawOptions reads colors on a 0-255 scale;
            # confirm the scale of the CSV color columns.
            shape.color = (color_r, color_g, color_b, 1.0)
            space.add(body, shape)

        # Capture step from the last row's time, as before. Rounded because an exact float
        # comparison never fires for non-integral time * 50.
        # NOTE(review): steps are 1/60 s but time is scaled by 50 — confirm the intended unit.
        capture_step = int(round(rows[-1][-1] * 50))

        for iteration in range(iterations + 1):
            screen.fill((255, 255, 255))     # clear to white
            space.debug_draw(draw_options)   # draw the physics objects

            if iteration == capture_step:
                img_array = pygame.surfarray.array3d(screen)     # (width, height, 3)
                img_array = np.transpose(img_array, (1, 0, 2))  # -> (height, width, 3)
                return img_array.astype(np.float32) / 255.0

            space.step(1 / 60.0)

        return None
    finally:
        # Always shut pygame down; the old early return skipped this.
        pygame.quit()

In [None]:
# Shape of the first sample's block-parameter table.
j[0].shape

In [None]:
# Re-render the first training sample from its CSV parameters.
rec_from_parameters   =  simulate_with_dynamics(j[0])

In [None]:
# NOTE(review): [0] selects only the first pixel row of the (H, W, 3) frame, not the whole image.
rec_from_parameters[0].min(), rec_from_parameters[0].max()

In [None]:
i[0].shape, rec_from_parameters.shape

In [None]:
# Side-by-side comparison: dataset image vs. the image re-simulated from its parameters.
# NOTE(review): relies on `i` having been reshaped to (N, 512, 512, 3) in an earlier cell.
two_imgs   =  np.zeros((2, 512, 512, 3))

two_imgs[0]  =  i[0]
two_imgs[1]  =  rec_from_parameters

In [None]:
plot_img(two_imgs, 1, 2, 8, 6) 

<div style="background-color:yellow;    color:red;      text-align:center;    padding:5px;">
<h2>  Physics Model for Batches </h2>
</div>

In [None]:
def simulate_batch(csv_tensors, batch_size=0, gravity=(0, 9.8)):
    """Render one simulated frame per parameter set and stack them into a batch tensor.

    For each element of ``csv_tensors`` (rows of
    ``(pos_x, pos_y, width, height, mass, color_r, color_g, color_b, time)``), this builds
    a pymunk world on a 512x512 canvas with a static ground along the bottom edge. It
    steps the world at 1/60 s and captures the frame at step ``round(time * 50)``, using
    the FIRST row's ``time``. A parameter set that never reaches its capture step within
    steps 0..1000, or that has no rows, yields an all-zero image.

    Args:
        csv_tensors: iterable of per-sample (n_blocks, 9) parameter tables (tensor or array).
            Values are read with ``float()``, so no gradient flows back through this function.
        batch_size: minimum number of images returned; missing entries are zero-padded.
            Defaults to 0 (no padding).
        gravity: gravity vector. It defaults to the (0, 9.8) used by simulate_with_dynamics;
            this function previously left gravity at zero. Pass (0, 0) for the old behaviour.

    Returns:
        float32 tensor of shape (max(len(csv_tensors), batch_size), 3, 512, 512).
    """
    screen_width, screen_height = 512, 512
    images = []

    pygame.init()
    try:
        for csv_tensor in csv_tensors:
            rows = [[float(v) for v in row] for row in csv_tensor]
            frame = _simulate_single_frame(rows, screen_width, screen_height, gravity) if rows else None
            if frame is None:
                frame = np.zeros((screen_height, screen_width, 3), dtype=np.float32)
            images.append(frame)
    finally:
        pygame.quit()  # quit even if a simulation raises

    # NOTE(review): reshape (not transpose) of the (N, H, W, 3) frames into (N, 3, H, W).
    # This mirrors ImageDataset's reshape, so the MSE against dataset images compares matching
    # pixels, but neither tensor is truly channels-first; change both together if ever fixed.
    result = np.array(images, dtype=np.float32).reshape(-1, 3, screen_height, screen_width)

    if len(images) < batch_size:
        padding = np.zeros((batch_size - len(images), 3, screen_height, screen_width), dtype=np.float32)
        result = np.concatenate((result, padding), axis=0)

    return torch.tensor(result, dtype=torch.float32)


def _simulate_single_frame(rows, screen_width, screen_height, gravity):
    """Simulate one list of 9-float block rows; return the captured (H, W, 3) float32 frame or None."""
    screen = pygame.Surface((screen_width, screen_height))
    space = pymunk.Space()
    space.gravity = gravity
    draw_options = pymunk.pygame_util.DrawOptions(screen)

    # Static ground along the bottom edge of the canvas.
    ground = pymunk.Segment(space.static_body, (0, screen_height), (screen_width, screen_height), 0)
    space.add(ground)

    for pos_x, pos_y, width, height, mass, color_r, color_g, color_b, _time in rows:
        moment = pymunk.moment_for_box(mass, (width, height))
        body = pymunk.Body(mass, moment)
        body.position = (pos_x, pos_y)
        shape = pymunk.Poly.create_box(body, (width, height))
        shape.color = (color_r, color_g, color_b, 1.0)
        space.add(body, shape)

    # Capture step from the first row's time (assumed identical across rows). Rounded because
    # predicted times are arbitrary floats and an exact comparison would almost never fire.
    capture_step = int(round(rows[0][-1] * 50))

    for iteration in range(1001):  # steps 0..1000 inclusive, as before
        screen.fill((255, 255, 255))
        space.debug_draw(draw_options)
        if iteration == capture_step:
            frame = np.transpose(pygame.surfarray.array3d(screen), (1, 0, 2))  # -> (H, W, 3)
            return frame.astype(np.float32) / 255.0
        space.step(1 / 60.0)

    return None

    

In [None]:
j.shape

In [None]:
# Render every sample of the inspected batch from its ground-truth parameters.
rec_imgs   =   simulate_batch(j, 16) 

In [None]:
rec_imgs.shape

In [None]:
# NOTE(review): `params` is only defined by the model forward pass further down;
# on a fresh kernel (Restart & Run All) this cell raises NameError.
img_recon             =   simulate_batch(params, 16) 

In [None]:
img_recon.shape

<div style="background-color:yellow;    color:red;      text-align:center;    padding:5px;">
<h2>  Encoder : ViT </h2>
</div>

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention over a (batch, tokens, embed_dim) sequence.

    Wraps ``nn.MultiheadAttention``; query, key and value are all the input itself,
    and the wrapped module applies its own input/output projections.
    """

    def __init__(self, embed_dim, n_heads, dropout):
        super().__init__()

        self.n_heads = n_heads
        # batch_first=True: callers pass (batch, tokens, embed_dim). With the default
        # (False) dim 0 is treated as the sequence, i.e. tokens attended across the batch.
        self.att = torch.nn.MultiheadAttention(embed_dim=embed_dim,
                                               num_heads=n_heads,
                                               dropout=dropout,
                                               batch_first=True)
        # NOTE(review): q/k/v are unused in forward (nn.MultiheadAttention has its own
        # projections). They are kept only so existing checkpoints' state_dict keys still
        # load; they receive no gradient.
        self.q = torch.nn.Linear(embed_dim, embed_dim)
        self.k = torch.nn.Linear(embed_dim, embed_dim)
        self.v = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Return the self-attention output, same shape as ``x``."""
        attn_output, _ = self.att(x, x, x)
        return attn_output

In [None]:
#Attention(embed_dim=256, n_heads=4, dropout=0.)(torch.ones((1, 196, 256))).shape

In [None]:
class PreNorm(nn.Module):
    """Layer-normalise the input, then hand it to the wrapped callable.

    ``forward(x, **kwargs)`` computes ``fn(LayerNorm(x), **kwargs)``; the norm is taken
    over the last dimension, whose size must be ``embed_dim``.
    """

    def __init__(self, embed_dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
    

In [None]:
#norm = PreNorm(embed_dim=256, fn=Attention(embed_dim=256, n_heads=4, dropout=0.))

In [None]:
#norm(torch.ones((1, 196, 256))).shape

In [None]:
class FeedForward(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps (..., embed_dim) to (..., embed_dim) through a ``hidden_dim``-wide layer (PINN
    uses hidden_dim = 2 * embed_dim). Submodules are registered as "0".."4", which the
    state_dict keys depend on.
    """

    def __init__(self, embed_dim, hidden_dim, dropout = 0.):
        # Build both projections first (in the same order as before, so seeded weight
        # initialisation is unchanged), then assemble the sequence.
        expand = nn.Linear(embed_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, embed_dim)
        layers = [expand, nn.GELU(), nn.Dropout(dropout), contract, nn.Dropout(dropout)]
        super().__init__(*layers)

In [None]:
#ff = FeedForward(embed_dim=256, hidden_dim=512)

In [None]:
#ff(torch.ones((1, 196, 256))).shape

In [None]:
class ResidualAdd(nn.Module):
    """Residual wrapper: ``forward(x, **kwargs)`` returns ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add. The previous `x += res` modified fn's output in place: when fn
        # returns its own argument (e.g. nn.Identity) that mutated the caller's input tensor,
        # and it can trip autograd's in-place checks when fn's last op saves its output.
        return self.fn(x, **kwargs) + x

In [None]:
#residual_att = ResidualAdd(Attention(embed_dim=256, n_heads=4, dropout=0.))

In [None]:
#residual_att(torch.ones((1, 196, 256))).shape

<div style="background-color:yellow;    color:red;      text-align:center;    padding:5px;">
<h2>  Full-Model </h2>
</div>

In [None]:
class PINN(nn.Module):
    """Transformer-encoder VAE that maps an image to a table of block parameters.

    Pipeline:
      1. patch embedding + learned positional embedding;
      2. ``n_layers`` pre-norm Transformer blocks;
      3. flattened tokens -> two MLP heads giving ``mu`` and ``logvar`` (size ``latent_dim``);
      4. reparameterised ``z`` -> MLP decoder;
      5. output table of shape (n_blocks, n_block_params), Softplus so every entry is positive.

    Args:
        ch: input channels.
        img_size: input height and width; inputs must be exactly img_size x img_size,
            because the heads expect num_patches * embed_dim features.
        patch_size: patch side length (img_size should be divisible by it).
        embed_dim: token width.
        latent_dim: size of the latent vector ``z``.
        n_layers: number of Transformer blocks.
        dropout: dropout used in the attention and feed-forward blocks.
        heads: attention heads (must divide embed_dim).
        n_blocks, n_block_params: shape of the decoded table. The defaults 5 x 9 reproduce
            the previously hard-coded 45 outputs.
    """

    def __init__(self, ch=3, img_size=224, patch_size=16, embed_dim=64, latent_dim = 64, n_layers=6, dropout=0.1, heads=8,
                 n_blocks=5, n_block_params=9):
        super(PINN, self).__init__()

        self.channels       = ch
        self.height         = img_size
        self.width          = img_size
        self.patch_size     = patch_size
        self.n_layers       = n_layers
        self.embed_dim      = embed_dim
        self.latent_dim     = latent_dim
        self.n_blocks       = n_blocks
        self.n_block_params = n_block_params

        num_patches      = (img_size // patch_size) ** 2
        self.num_patches = num_patches

        self.patch_embedding = PatchEmbedding(in_channels=ch, patch_size=patch_size, embed_dim=embed_dim)

        # One extra position is allocated (ViT class-token slot), but no class token is used;
        # encode() only adds the first n positions.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

        # Transformer encoder: pre-norm attention and feed-forward, each with a residual.
        self.encoder_layers = nn.ModuleList([
            nn.Sequential(
                ResidualAdd(PreNorm(embed_dim, Attention(embed_dim, n_heads=heads, dropout=dropout))),
                ResidualAdd(PreNorm(embed_dim, FeedForward(embed_dim, 2 * embed_dim, dropout=dropout))),
            )
            for _ in range(n_layers)
        ])

        flat_dim = num_patches * embed_dim
        # NOTE(review): both heads end in LeakyReLU, so mu and logvar are squashed to tiny
        # negative values (slope 0.01). A VAE head is usually linear. Dropping the final
        # activation would not change the state_dict, but it would change trained behaviour.
        self.fc_encoder_mu     = self._make_encoder_head(flat_dim, latent_dim)
        self.fc_encoder_logvar = self._make_encoder_head(flat_dim, latent_dim)

        self.fc_decoder = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 2),
            nn.LeakyReLU(),
            nn.BatchNorm1d(latent_dim * 2),

            nn.Linear(latent_dim * 2, latent_dim),
            nn.LeakyReLU(),
            nn.BatchNorm1d(latent_dim),

            nn.Linear(latent_dim, n_blocks * n_block_params),
            nn.Softplus(),  # keeps every decoded parameter positive
        )

    @staticmethod
    def _make_encoder_head(in_features, latent_dim):
        """Build the BN -> Linear -> LeakyReLU MLP used for both the mu and logvar heads."""
        return nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Linear(in_features, 512),
            nn.LeakyReLU(),

            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.LeakyReLU(),

            nn.BatchNorm1d(256),
            nn.Linear(256, latent_dim),
            nn.LeakyReLU(),
        )

    def encode(self, x):
        """Encode images (batch, ch, img_size, img_size) into ``(mu, logvar)``, each (batch, latent_dim).

        Raises:
            ValueError: if the input does not produce exactly ``num_patches`` patches.
        """
        x = self.patch_embedding(x)
        b, n, _ = x.shape
        if n != self.num_patches:
            # The flattening below needs exactly num_patches tokens. The old reshape(-1, ...)
            # silently folded a wrong-size input into the batch dimension when sizes divided evenly.
            raise ValueError(
                f"expected {self.num_patches} patches, got {n}; "
                f"input must be {self.height}x{self.width}"
            )

        x = x + self.pos_embedding[:, :n]

        for layer in self.encoder_layers:
            x = layer(x)

        x = x.reshape(b, self.num_patches * self.embed_dim)

        mu     = self.fc_encoder_mu(x)
        logvar = self.fc_encoder_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latents (batch, latent_dim) into a (batch, n_blocks, n_block_params) table."""
        x = self.fc_decoder(z)
        return x.reshape(-1, self.n_blocks, self.n_block_params)

    def forward(self, x):
        """Return ``(params, kl_loss)``: decoded parameter tables and the batch-mean KL to N(0, I)."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        params = self.decode(z)  # already (batch, n_blocks, n_block_params)

        # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, then averaged over the batch.
        kl_loss = -0.5 * torch.sum(1 + logvar - mu ** 2 - torch.exp(logvar), dim=-1)
        return params, kl_loss.mean()


In [None]:
#model   =   PINN(ch=3, img_size=512, patch_size=16, embed_dim=64, latent_dim = 64, n_layers=6, dropout=0.1, heads=8).to(device)

In [None]:
#mu, logvar      =    model.encode(torch.randn(10, 3, 512, 512).to(device) )

In [None]:
#z               =    model.reparameterize(mu, logvar)

In [None]:
#z.shape

In [None]:
#params            =    model.decode(z) 

In [None]:
#params.shape

In [None]:
# Use the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
device

In [None]:
# 512x512 inputs with 16x16 patches -> (512 // 16) ** 2 = 1024 tokens of width 64.
model   =   PINN(ch=3, img_size=512, patch_size=16, embed_dim=64, latent_dim = 64, n_layers=6, dropout=0.1, heads=8).to(device)

In [None]:
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Number of trainable parameters in the model:", num_params)

In [None]:
learning_rate = 0.0001

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Mean-squared error, used for both the image-reconstruction and the parameter-regression terms.
criterion = nn.MSELoss()

In [None]:
# NOTE(review): hard-coded absolute, user-specific paths — consider deriving them from a configurable base directory.
path_model   =  "/home/turgay/Turgay/Academic/2024-2025/Fall/Generative_Models/Final_Project/Final_Report/codes/weights/PINN_transformers_weights_0.pth"
path_losses  =  "/home/turgay/Turgay/Academic/2024-2025/Fall/Generative_Models/Final_Project/Final_Report/codes/weights/PINN_transformers_losses_0.pth"

In [None]:
torch.save({
            'model_state_dict': model.state_dict(),
            'best_loss'       : 9999999999999,
        }, path_model)   

In [None]:
torch.save({
            'train_loss_rec'   : [],
            'train_loss_kl'    : [],
            'train_loss_param' : [],
            'train_loss'       : [],
    
            'val_loss_rec'     : [],
            'val_loss_kl'      : [],
            'val_loss_param'   : [],
            'val_loss'         : [],
    
            'epochs'           : [],
    
        }, path_losses)    

In [None]:
checkpoint_weights   =   torch.load(path_model, weights_only=True) 
checkpoint_losses    =   torch.load(path_losses, weights_only=True)  

model.load_state_dict(checkpoint_weights['model_state_dict'])

best_loss   =  checkpoint_weights['best_loss'] 

train_loss         =  checkpoint_losses['train_loss']
train_loss_param   =  checkpoint_losses['train_loss_param']
train_loss_rec     =  checkpoint_losses['train_loss_rec']
train_loss_kl      =  checkpoint_losses['train_loss_kl']

val_loss          =  checkpoint_losses['val_loss'] 
val_loss_param    =  checkpoint_losses['val_loss_param'] 
val_loss_rec      =  checkpoint_losses['val_loss_rec'] 
val_loss_kl       =  checkpoint_losses['val_loss_kl'] 

epochs         =  checkpoint_losses['epochs'] 

In [None]:
# Gradient scaler for mixed-precision training.
# NOTE(review): GradScaler() defaults to device 'cuda'; on a CPU-only machine PyTorch disables it with a warning.
scaler = GradScaler()

In [None]:
# Early stopping: stop after `patience` consecutive epochs without validation improvement.
patience      =   10  
counter       =   0
#------------------------------------
# Endpoints of the KL-weight (beta) schedules defined below.
beta_start    =   0.000001
beta_end      =   10

num_epochs    =   10
#------------------------------------
# NOTE(review): mu_values / logvar_values are never filled in the visible code.
mu_values     =   [] 
logvar_values =   [] 

In [None]:
def linear_schedule(epoch, start=None, end=None, total=None):
    """Linearly interpolate the KL weight from ``start`` (epoch 0) to ``end`` (epoch ``total``).

    ``start``, ``end`` and ``total`` default to the notebook globals ``beta_start``,
    ``beta_end`` and ``num_epochs`` (read at call time, as before).
    NOTE(review): the training loop uses a constant beta, not this schedule.
    """
    start = beta_start if start is None else start
    end = beta_end if end is None else end
    total = num_epochs if total is None else total
    return start + ((end - start) / total) * epoch
#---------------------------------------------------------------------------------#
def exponential_schedule(epoch, start=None, end=None, total=None):
    """Geometrically interpolate the KL weight from ``start`` (epoch 0) to ``end`` (epoch ``total``).

    ``start`` must be non-zero. Defaults come from ``beta_start``, ``beta_end`` and
    ``num_epochs`` (read at call time, as before).
    """
    start = beta_start if start is None else start
    end = beta_end if end is None else end
    total = num_epochs if total is None else total
    return start * ((end / start) ** (epoch / total))

In [None]:
# NOTE(review): `num` is not used anywhere in the visible code.
num  =  20

In [None]:
# Redefines num_epochs (already set to 10 above).
# NOTE(review): `beta` is reassigned inside the training loop, so this value has no effect there.
num_epochs   =   10
beta         =   1 

In [None]:
for batch in tqdm.tqdm(train_loader):

    images, csv   = batch
    
    images   =  images.to(device)
    csv      =  csv.to(device) 
    break

In [None]:
params , kl_loss      =   model(images) 

In [None]:
params.shape, kl_loss.shape

In [None]:
params[0]

In [None]:
img_recon             =   simulate_batch(params) 

In [None]:
img_recon.shape

In [None]:
# Loss weights, constant across epochs (the beta schedules defined above are not used here).
alpha   =   1000      # weight of the parameter-regression term
beta    =   0.0001    # weight of the KL term

# Mixed precision only applies on CUDA. A 'cuda' autocast on a CPU-only machine just warns
# and disables itself, so follow the actual device instead.
use_amp = device.type == "cuda"

for epoch in range(num_epochs):

    model.train()

    total_loss_train      =   0.0
    total_loss_train_kl   =   0.0
    total_loss_train_rec  =   0.0
    total_loss_train_prm  =   0.0

    for batch in tqdm.tqdm(train_loader):

        images, csv   = batch

        images   =  images.to(device)
        csv      =  csv.to(device)

        # Clear the previous step's gradients; without this they accumulated over every batch.
        optimizer.zero_grad(set_to_none=True)

#--------------------------------------------------------------------------------------------------------------------------

        with autocast(device_type=device.type, enabled=use_amp):

            params , kl_loss      =   model(images)

            # NOTE(review): simulate_batch builds its output from a NumPy array, so img_recon has
            # no autograd link to params: img_rec_loss is logged but does not drive training.
            img_recon             =   simulate_batch(params, images.shape[0]).to(device)

            img_rec_loss          =   criterion(img_recon, images)
            param_loss            =   criterion(csv, params)

            loss    =    img_rec_loss   +   param_loss * alpha   +    kl_loss * beta

#--------------------------------------------------------------------------------------------------------------------------
        scaler.scale(loss).backward()
        #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  
        scaler.step(optimizer)
        scaler.update()

        total_loss_train      +=   loss.item()
        total_loss_train_kl   +=   kl_loss.item()
        total_loss_train_rec  +=   img_rec_loss.item()
        total_loss_train_prm  +=   param_loss.item()

    total_loss_train     /= len(train_loader)
    total_loss_train_kl  /= len(train_loader)
    total_loss_train_rec /= len(train_loader)
    total_loss_train_prm /= len(train_loader)

    train_loss.append(total_loss_train)
    train_loss_kl.append(total_loss_train_kl)
    train_loss_rec.append(total_loss_train_rec)
    train_loss_param.append(total_loss_train_prm)   # history list loaded from the checkpoint

    total_loss_train_tensor = torch.tensor(total_loss_train)

    if torch.isnan(total_loss_train_tensor):
        print("nan value is encountered !")

        break

    print( "-------------------------------------------------------------------------------")
    print(f"|  Epoch [{epoch+1}/{num_epochs}]          |        Total Train Loss : {total_loss_train:.4f}           |")
    print( "-------------------------------------------------------------------------------")

    print( "-------------------------------------------------------------------------------")
    print(f"|  Epoch [{epoch+1}/{num_epochs}]          |        Total KL Loss : {total_loss_train_kl:.4f}              |")
    print( "-------------------------------------------------------------------------------")

    print( "-------------------------------------------------------------------------------")
    print(f"|  Epoch [{epoch+1}/{num_epochs}]          |        Total REC Loss : {total_loss_train_rec:.4f}              |")
    print( "-------------------------------------------------------------------------------")

    print( "-------------------------------------------------------------------------------")
    print(f"|  Epoch [{epoch+1}/{num_epochs}]          |        Total PRM Loss : {total_loss_train_prm:.4f}              |")
    print( "-------------------------------------------------------------------------------")
#--------------------------------------------------------------------------------------------------------------------------
#--------------------------------------------------------------------------------------------------------------------------
    model.eval()

    total_loss_val      =   0.0
    total_loss_val_kl   =   0.0
    total_loss_val_rec  =   0.0
    total_loss_val_prm  =   0.0

    with torch.no_grad():

        for batch in tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):

            # Validation batches are (images, csv) pairs too. The old code used the whole pair
            # as `images` and compared against the last *training* batch's csv.
            images, csv   = batch

            images   =  images.to(device)
            csv      =  csv.to(device)
#--------------------------------------------------------------------------------------------------------------------------

            params , kl_loss      =   model(images)

            img_recon             =   simulate_batch(params, images.shape[0]).to(device)

            img_rec_loss          =   criterion(img_recon, images)
            param_loss            =   criterion(csv, params)

            loss    =    img_rec_loss   +   param_loss * alpha   +    kl_loss * beta

            total_loss_val      +=   loss.item()
            total_loss_val_kl   +=   kl_loss.item()
            total_loss_val_rec  +=   img_rec_loss.item()
            total_loss_val_prm  +=   param_loss.item()

    total_loss_val     /= len(val_loader)
    total_loss_val_kl  /= len(val_loader)
    total_loss_val_rec /= len(val_loader)
    total_loss_val_prm /= len(val_loader)

    val_loss.append(total_loss_val)
    val_loss_kl.append(total_loss_val_kl)
    val_loss_rec.append(total_loss_val_rec)
    val_loss_param.append(total_loss_val_prm)   # history list loaded from the checkpoint

    # Global epoch index, monotonic across resumed runs.
    epochs.append(len(epochs) + 1)

    print( "-------------------------------------------------------------------------------")
    print(f"|  Epoch [{epoch+1}/{num_epochs}]         |       Total Validation Loss : {total_loss_val:.4f}           |")
    print( "-------------------------------------------------------------------------------")

    print( "-------------------------------------------------------------------------------")
    print(f"|  Epoch [{epoch+1}/{num_epochs}]         |       Total KL Loss : {total_loss_val_kl:.4f}              |")
    print( "-------------------------------------------------------------------------------")

    print( "-------------------------------------------------------------------------------")
    print(f"|  Epoch [{epoch+1}/{num_epochs}]         |       Total REC Loss : {total_loss_val_rec:.4f}             |")
    print( "-------------------------------------------------------------------------------")

    print( "-------------------------------------------------------------------------------")
    print(f"|  Epoch [{epoch+1}/{num_epochs}]         |       Total PRM Loss : {total_loss_val_prm:.4f}             |")
    print( "-------------------------------------------------------------------------------")

#--------------------------------------------------------------------------------------------------------------------------
    if len(val_loss) >= 2:

        res   =   (  (val_loss[-2] - val_loss[-1]) / val_loss[-2] ) * 100 
        print( "-------------------------------------------------------------------------------")
        print(f"|              Change in loss is      %   {res:.2f}                               |")
        print( "-------------------------------------------------------------------------------")

    if total_loss_val < best_loss:
        print("*************...saving best model *************")
        best_loss = total_loss_val 
        torch.save({
            'model_state_dict': model.state_dict(),
            'best_loss': best_loss,
        }, path_model)   

    # Persist the actual history (previously empty lists were written every epoch).
    torch.save({
            'train_loss_rec'   : train_loss_rec,
            'train_loss_kl'    : train_loss_kl,
            'train_loss_param' : train_loss_param,
            'train_loss'       : train_loss,

            'val_loss_rec'     : val_loss_rec,
            'val_loss_kl'      : val_loss_kl,
            'val_loss_param'   : val_loss_param,
            'val_loss'         : val_loss,

            'epochs'           : epochs,

        }, path_losses)

#--------------------------------------------------------------------------------------------------------------------------
    # Reset patience on improvement. The first recorded epoch has nothing to compare against,
    # so it no longer counts as a non-improvement.
    if len(val_loss) < 2 or val_loss[-1] < val_loss[-2]:
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping!")
            break

    torch.cuda.empty_cache()


In [None]:
fig, ax = plt.subplots(1,1, figsize=(6,6))

ax.plot(train_loss, "b-", label="Train Loss")
ax.plot(val_loss, "r-", label="Validation Loss")
ax.set_title("Training curves")
ax.set_xlabel("Epochs")
# Plotted values are the weighted total (image MSE + alpha * parameter MSE + beta * KL), not a plain MSE.
ax.set_ylabel("Total loss")
ax.legend();

In [None]:
with torch.no_grad():
	for img in val_loader:
		img    = img.to(device)
		recon, kl  = model(img)
		break

In [None]:
img.shape, recon.shape

In [None]:
import matplotlib.pyplot as plt

plt.figure(dpi=250)

fig, ax = plt.subplots(2, 7, figsize=(15, 4))

for i in range(7):
	ax[0, i].imshow(img[i].reshape(224,224,3).cpu().numpy())
	ax[1, i].imshow(recon[i].reshape(224, 224, 3).cpu().numpy())
	ax[0, i].axis('OFF')
	ax[1, i].axis('OFF')
plt.show()


In [None]:
e = model.eval()

In [None]:

plt.figure(dpi=250)

fig, ax = plt.subplots(1, 1, figsize=(4, 4))

ax.imshow(img[1].reshape(224,224,3).cpu().numpy())

plt.show()


In [None]:
mu, logvar = model.encode(img[1].reshape(1, 3, 224, 224))

In [None]:
z = model.reparameterize(mu, logvar)

In [None]:
z.shape

In [None]:
def latent_space_traversal(model, z, latent_dim, steps=10, range_val=3, render=None):
    """Sweep each latent dimension of ``z`` and plot the decoder's output on a grid.

    For every ``dim < latent_dim``, entry ``dim`` of the first latent vector is set to
    each of ``steps`` evenly spaced values in ``[-range_val, range_val]`` (other entries
    unchanged) and decoded. Rows of the grid are dimensions; columns are steps.

    Args:
        model: object with ``parameters()`` and ``decode(z)``. It should be in eval mode,
            since PINN's decoder contains BatchNorm1d, which rejects a single sample in train mode.
        z: latent tensor of shape (batch, latent_dim); only row 0 is used.
        latent_dim: number of leading dimensions to traverse.
        steps: number of values per dimension.
        range_val: half-width of the traversal range.
        render: optional callable mapping the decoded output (a batch of one) to an image,
            e.g. ``lambda p: simulate_batch(p, 1)[0].reshape(512, 512, 3).numpy()``.
            Without it, the first decoded item (PINN: the 5 x 9 parameter table) is shown as a heat map.
    """
    device = next(model.parameters()).device

    # Traverse only the first latent vector; detach so in-place edits never touch autograd history.
    z_base = z[:1].detach().to(device)

    # squeeze=False keeps `axs` 2-D even when latent_dim or steps is 1.
    fig, axs = plt.subplots(latent_dim, steps, figsize=(steps * 2, latent_dim * 2), squeeze=False)

    for dim in range(latent_dim):
        for step, val in enumerate(torch.linspace(-range_val, range_val, steps)):

            z_traversal = z_base.clone()
            # z is (batch, latent_dim): the old index [0, 0, dim] raised IndexError.
            z_traversal[0, dim] = val

            with torch.no_grad():
                decoded = model.decode(z_traversal)

            # The decoder returns parameters, not pixels, so the old 224x224x3 reshape could not work.
            picture = render(decoded) if render is not None else decoded[0]
            if isinstance(picture, torch.Tensor):
                picture = picture.detach().cpu().numpy()

            axs[dim, step].imshow(picture)
            axs[dim, step].axis("off")

In [None]:
# NOTE(review): 64 latent dims x 10 default steps = 640 subplots in one figure; consider a smaller latent_dim.
latent_space_traversal(model, z, 64) 

In [None]:
# Illustrate the reparameterisation trick z = mu + eps * sigma with synthetic values.
# Names carry a `sim_` prefix so this cell does not overwrite the model's `mu`,
# `logvar` and `z` computed above.

B, D = 1000, 128  # number of samples; D (latent dimension) is not used below

rng = np.random.default_rng(0)  # seeded so the figure is reproducible

sim_mu     = rng.normal(0, 1, B)       # mean ~ N(0, 1)
sim_logvar = rng.normal(-1, 0.5, B)    # log-variance ~ N(-1, 0.5^2)
sim_std    = np.exp(0.5 * sim_logvar)  # sigma = exp(logvar / 2)
sim_eps    = rng.normal(0, 1, B)       # eps ~ N(0, 1)
sim_z      = sim_mu + sim_eps * sim_std

panels = [
    (sim_mu,     'blue',   'Distribution of mu (Mean)'),
    (sim_logvar, 'orange', 'Distribution of logvar (Log-Variance)'),
    (sim_std,    'green',  'Distribution of sigma (Standard Deviation)'),
    (sim_eps,    'red',    'Distribution of epsilon (Random Noise)'),
    (sim_z,      'purple', 'Distribution of z (Latent Variable)'),
]

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
for ax, (values, color, title) in zip(axes.flat, panels):
    ax.hist(values, bins=50, alpha=0.7, color=color, density=True)
    ax.set_title(title)
    ax.set_xlabel('Value')
    # density=True normalises the histogram, so the y-axis is a probability density, not a count.
    ax.set_ylabel('Density')
axes[1, 2].set_axis_off()  # only five panels are needed

plt.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Softplus(x) = log(1 + exp(x)): a smooth, strictly positive approximation of ReLU,
# used as the last activation of the PINN decoder.
softplus = nn.Softplus()
x = torch.linspace(-10, 10, 100)   # 100 evenly spaced inputs in [-10, 10]
y = softplus(x)

fig, ax = plt.subplots()
ax.plot(x.numpy(), y.numpy(), label='Softplus(x)')
ax.set_title("Softplus Activation Function")
ax.set_xlabel("Input (x)")
ax.set_ylabel("Output (Softplus(x))")
ax.axhline(0, color='black', linewidth=0.5, linestyle='--')  # reference axes through the origin
ax.axvline(0, color='black', linewidth=0.5, linestyle='--')
ax.legend()
plt.show()
