In [None]:
import torch
from torch.utils.data import DataLoader
from dataset import SingleDataset
import torchvision.transforms as transforms
from vae import VAE
import matplotlib.pyplot as plt
import numpy as np
import os

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset. batch_size=1 because each sample's latent vector is saved to its own .npy file below.
dataset = SingleDataset(transform=transforms.Compose([transforms.ToTensor()]))
data_loader = DataLoader(dataset, shuffle=True, num_workers=6, batch_size=1)

# Model parameters
batch_size = 1024  # passed to the VAE constructor only; the loader above uses 1
latent_dim = 10
output_dir = "shrinath_new"

# Load the model from checkpoint.
checkpoint_path = f"./weights/lat_{latent_dim}.ckpt"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
# Previously the checkpoint was loaded but never applied, so the model ran with random weights.
# NOTE(review): assumes either a Lightning-style checkpoint (weights under "state_dict")
# or a bare state dict saved with torch.save(model.state_dict()) — confirm against training code.
state_dict = checkpoint.get("state_dict", checkpoint)
model = VAE(latent_dim, batch_size)
model.load_state_dict(state_dict)
model = model.eval().to(device)  # evaluation mode for inference

# Map each mechanism folder name to an integer label.
# os.listdir returns entries in arbitrary order, so sort to keep labels stable across machines/runs.
folders = sorted(
    f for f in os.listdir('mechanisms') if os.path.isdir(os.path.join('mechanisms', f))
)
folder_dict = {folder_name: index for index, folder_name in enumerate(folders)}

# np.save does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

# Process and save data: one file per sample holding [latent z | numeric description].
for batch_num, (images, description) in enumerate(data_loader):
    # description is a batch of one string; presumably "<root>/<mech_type>/<space-separated fields>"
    # — TODO confirm against SingleDataset.
    path_parts = description[0].split('/')
    mech_type = path_parts[1]
    fields = [x for x in path_parts[2].split(' ') if x]

    # Replace the mechanism-name token with its integer label so every field parses as a float.
    if mech_type.startswith("T"):
        # NOTE(review): for "T*" mechanisms the token to replace is assumed to sit 5th from the end.
        fields[-5] = float(folder_dict[mech_type])
    else:
        fields[fields.index(mech_type)] = float(folder_dict[mech_type])

    description = np.array([float(x) for x in fields]).reshape(1, -1)
    images = images.to(device)

    # Forward pass through the VAE encoder
    with torch.no_grad():
        x = model.encoder(images)
        # Encoder output is split into mean and log-variance halves of width latent_dim.
        mean, logvar = x[:, :model.latent_dim], x[:, model.latent_dim:]
        # NOTE(review): reparameterize presumably samples around `mean`, making saved vectors
        # stochastic; use `mean` directly if deterministic features are wanted.
        z = model.reparameterize(mean, logvar)
    z = z.cpu().numpy()
    z_description = np.concatenate((z, description), axis=1)

    np.save(os.path.join(output_dir, f'{batch_num}.npy'), z_description)


In [None]:
!zip -r shrinath.zip shrinath # Zipping all 

In [None]:
# Define a function to plot original and reconstructed images
def plot_images(original_images, reconstructed_images):
    """Show originals (top row) above their reconstructions (bottom row).

    Args:
        original_images: sequence of image tensors in channel-first layout
            (as produced by ``transforms.ToTensor``), already on the CPU.
        reconstructed_images: sequence of image tensors paired index-wise
            with ``original_images``, same layout.

    Only the first ``min(len(original_images), len(reconstructed_images))``
    pairs are drawn. Nothing is drawn when that count is zero.
    """
    n_images = min(len(original_images), len(reconstructed_images))
    if n_images == 0:
        return  # plt.subplots(2, 0) would raise ValueError

    # squeeze=False keeps `axes` 2-D even for a single pair; otherwise
    # plt.subplots(2, 1) returns a 1-D array and axes[0, i] raises IndexError.
    fig, axes = plt.subplots(2, n_images, figsize=(12, 4), squeeze=False)

    rows = (("Original", original_images), ("Reconstructed", reconstructed_images))
    for row, (title, images) in enumerate(rows):
        for i in range(n_images):
            ax = axes[row, i]
            ax.imshow(images[i].permute(1, 2, 0))  # channel-first -> channel-last for imshow
            ax.set_title(title)
            ax.axis("off")

    plt.show()

# Model parameters
batch_size = 16
latent_dim = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = SingleDataset(transform=transforms.Compose([transforms.ToTensor()]))
# Use `batch_size` so the loader and the VAE constructor stay in sync.
data_loader = DataLoader(dataset, shuffle=True, num_workers=6, batch_size=batch_size, drop_last=True)

# Load the model from checkpoint.
checkpoint_path = f"./weights/lat_{latent_dim}.ckpt"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
# Previously the checkpoint was loaded but never applied, so reconstructions came from random weights.
# NOTE(review): assumes either a Lightning-style checkpoint (weights under "state_dict")
# or a bare state dict — confirm against training code.
state_dict = checkpoint.get("state_dict", checkpoint)
model = VAE(latent_dim, batch_size)
model.load_state_dict(state_dict)
model = model.eval().to(device)  # evaluation mode for inference

# Take one shuffled batch and compare the inputs with their reconstructions.
first_batch = next(iter(data_loader), None)
if first_batch is None:
    # With drop_last=True the loader yields nothing when the dataset is smaller than one batch.
    raise ValueError(f"dataset has fewer than batch_size={batch_size} samples")
images, _ = first_batch  # labels are not needed here

images = images.to(device)

# Forward pass; the model returns a 3-tuple whose first element is the reconstruction.
with torch.no_grad():
    reconstructed_images, _, _ = model(images)

images = images.cpu()
reconstructed_images = reconstructed_images.cpu()

# Plot the original and reconstructed images
plot_images(images, reconstructed_images)