# **Exploring Latent Space via VAE**

## **Important Libraries**

### **Install**

In [None]:
# Install the `uv` package manager by piping Astral's installer script into `sh`.
# NOTE(review): this executes a remote script without verification — confirm
# this is acceptable for the environment the notebook runs in.
!curl -LsSf https://astral.sh/uv/install.sh | sh

In [None]:
# Install the third-party packages imported below into the system interpreter
# (quiet output, no download cache).
# NOTE(review): no versions are pinned, so a re-run may resolve newer releases.
!uv pip install -q --no-cache-dir --system lightning torchmetrics umap-learn

### **Import**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

import lightning as L
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint

from torchmetrics.image import LearnedPerceptualImagePatchSimilarity

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import median_abs_deviation

import umap
import math
import random

import cv2
import imageio

import os
import gc
import shutil
import urllib.request
from tqdm.auto import tqdm

import time
from datetime import datetime

from IPython.display import Image as IImage
from IPython.display import display
from google.colab.patches import cv2_imshow

import warnings

In [None]:
warnings.filterwarnings("ignore")

%matplotlib inline
plt.rcParams['axes.facecolor'] = 'lightgray'
plt.rcParams['mathtext.fontset'] = 'cm'
plt.rcParams['font.family'] = 'STIXGeneral'

## **Configuration**

In [None]:
# Root folder for every artifact written by this notebook.
EXPERIMENT_DIR = "experiment"

# Create the root and its sub-folders up front; existing folders are left alone.
for _folder in (
    "experiment",
    "experiment/training",
    "experiment/dataset",
    "experiment/model",
):
    os.makedirs(_folder, exist_ok=True)

In [None]:
# Default `h_dim` of the VAE: channel count of the deepest encoder conv layer
# (shallower layers use h_dim // 8, // 4, // 2).
H_DIM = 256
# Default `z_dim` of the VAE: size of the latent vector.
Z_DIM = 64
# Batch size shared by every DataLoader of the data module.
BATCH_SIZE = 128
# Passed to the Trainer as `max_epochs`.
MAX_EPOCH = 16
# Learning rate handed to AdamW.
LEARNING_RATE = 3.1e-4
# Multiplier applied to the KL term in VAELoss.
KLD_WEIGHT = 0.0249673
# `gamma` of ExponentialLR; the scheduler is stepped once per training epoch.
SCHEDULER_GAMMA = 0.978654321
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Sample budget for the latent-space exploration: stop threshold of the
# encoding loop and row count of its random fallback matrix.
MAX_SAMPLE = 23140

In [None]:
# Draw a fresh seed on every run (exclusive upper bound = int32 maximum) and
# echo it so the run can be reproduced later.
_SEED_UPPER_BOUND = 2147483647
SEED = int(np.random.randint(_SEED_UPPER_BOUND))
print(f"Random seed: {SEED}")

In [None]:
# Name of the logged metric that ModelCheckpoint watches to pick the best model.
METRIC_TO_MONITOR = "val_lpips"
# "min": a lower monitored value counts as better.
METRIC_MODE = "min"

## **Dataset**

### **Utils**

In [None]:
# Side length (pixels) of the square images fed to the VAE.
IMAGE_SIZE = 128

# Preprocessing applied to every image: square IMAGE_SIZE x IMAGE_SIZE tensor
# with values in [-1, 1] (matching the decoder's Tanh output range).
IMAGE_TRANSFORM = transforms.Compose(
    [
        # Resize(int) only scales the *shorter* edge to IMAGE_SIZE and keeps
        # the aspect ratio, so non-square inputs (aligned CelebA faces are
        # 178x218) would stay non-square and no longer match the
        # (IMAGE_SIZE // 16)^2 feature map the VAE's Linear layers assume.
        transforms.Resize(IMAGE_SIZE),
        # Crop the longer edge so the result is exactly square.
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel: [0, 1] -> [-1, 1].
        transforms.Normalize(
            (0.5, 0.5, 0.5),
            (0.5, 0.5, 0.5)
        ),
    ]
)

In [None]:
class CelebADataModule(L.LightningDataModule):
    """Lightning data module around ``torchvision.datasets.CelebA``.

    The raw CelebA files are fetched from a GitHub release mirror into
    ``{EXPERIMENT_DIR}/dataset/celeba`` and then handed to torchvision, which
    is asked to finish the preparation (``download=True``).

    NOTE(review): training uses split "all"; in torchvision that split is
    expected to contain the "valid" images as well — confirm this overlap is
    intended.
    """

    def __init__(self):
        super().__init__()
        # Logical role -> torchvision CelebA split name.
        self.split = {
            "train": "all",
            "val": "valid",
        }
        self.train_dataset = None
        self.val_dataset = None
        self.transform = IMAGE_TRANSFORM
        self.dataset = datasets.CelebA
        self.url_download = (
            "https://github.com/reshalfahsi/latent-space-vae"
            "/releases/download/dataset"
        )
        self.download_files = [
            "identity_CelebA.txt",
            "img_align_celeba.zip",
            "list_attr_celeba.txt",
            "list_bbox_celeba.txt",
            "list_eval_partition.txt",
            "list_landmarks_align_celeba.txt",
        ]

    def _build_dataset(self, split_key, download=False):
        """Instantiate the CelebA dataset for ``self.split[split_key]``."""
        return self.dataset(
            root=f"{EXPERIMENT_DIR}/dataset",
            split=self.split[split_key],
            transform=self.transform,
            download=download,
        )

    def prepare_data(self):
        """Fetch any missing raw file, then let torchvision finish the setup.

        Each file is downloaded to a ``.part`` name and renamed only once the
        transfer completed, so an interrupted download is retried on the next
        call instead of leaving a truncated file that looks complete. The
        torchvision preparation step is repeated whenever the extracted image
        folder is missing, which also recovers from an interrupted extraction.
        """
        celeba_dir = f"{EXPERIMENT_DIR}/dataset/celeba"
        os.makedirs(celeba_dir, exist_ok=True)

        for filename in self.download_files:
            target = f"{celeba_dir}/{filename}"
            if os.path.exists(target):
                continue
            partial = f"{target}.part"
            try:
                urllib.request.urlretrieve(
                    f"{self.url_download}/{filename}",
                    partial,
                )
                os.replace(partial, target)
            finally:
                # Only still present when the download or rename failed.
                if os.path.exists(partial):
                    os.remove(partial)

        if not os.path.isdir(f"{celeba_dir}/img_align_celeba"):
            self._build_dataset("train", download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        * ``"fit"``       -> train and validation datasets
        * ``"validate"``  -> validation dataset only
        * ``None``        -> train dataset only (used for latent exploration)
        """
        if stage == "fit" or stage is None:
            self.train_dataset = self._build_dataset("train")
        # "validate" was previously ignored, leaving val_dataset as None.
        if stage in ("fit", "validate"):
            self.val_dataset = self._build_dataset("val")

    def generation_dataloader(self):
        """Shuffled loader over the training dataset, loaded in-process."""
        return DataLoader(
            self.train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
        )

    def train_dataloader(self):
        """Shuffled training loader with two worker processes."""
        return DataLoader(
            self.train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=2,
        )

    def val_dataloader(self):
        """Ordered validation loader with two worker processes."""
        return DataLoader(
            self.val_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=2,
        )

### **Load**

In [None]:
# Single shared data module instance, used both by the training loop and by
# the latent-space exploration cells.
DATASET = CelebADataModule()

## **Model**

### **Utils**

In [None]:
class AvgMeter(object):
    """Collects tensor scores and reports the mean of the most recent ones.

    Args:
        num: size of the trailing window averaged by ``show`` (default 40).
    """

    def __init__(self, num=40):
        self.num = num
        self.reset()

    def reset(self):
        """Forget every recorded score."""
        self.scores = []

    def update(self, val):
        """Record one more score (a tensor)."""
        self.scores.append(val)

    def show(self):
        """Return the mean of the last ``num`` recorded scores as a tensor."""
        # Index of the first score inside the window; never negative.
        first_kept = max(len(self.scores) - self.num, 0)
        window = self.scores[first_kept:]
        return torch.stack(window).mean()

### **VAE**

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    Args:
        h_dim: channel count of the deepest convolution; shallower stages use
            ``h_dim // 8``, ``h_dim // 4`` and ``h_dim // 2``.
        z_dim: size of the latent vector.
        image_size: side length of the input; the flattened feature size is
            computed as ``h_dim * (image_size // 16) ** 2``.
        in_channels: number of image channels.
    """

    def __init__(
        self,
        h_dim=H_DIM,
        z_dim=Z_DIM,
        image_size=IMAGE_SIZE,
        in_channels=3,
    ):
        super().__init__()

        # Spatial side after the four stride-2 convolutions, and the size of
        # the flattened feature vector that the Linear layers work on.
        feat_side = image_size // 16
        flat_features = h_dim * feat_side * feat_side

        # Channel widths from the image down to the bottleneck.
        widths = [in_channels, h_dim // 8, h_dim // 4, h_dim // 2, h_dim]

        # Encoder: (Conv 4x4, stride 2, pad 1 -> SiLU) x 4, then flatten.
        encoder_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            encoder_layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )
            encoder_layers.append(nn.SiLU())
        encoder_layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*encoder_layers)

        # Heads producing the mean and log-variance of q(z | x).
        self.fc_mu = nn.Linear(flat_features, z_dim)
        self.fc_logvar = nn.Linear(flat_features, z_dim)

        # Decoder: project z back to a feature map, then mirror the encoder.
        self.fc_decode = nn.Linear(z_dim, flat_features)

        mirrored = widths[::-1]
        last_stage = len(mirrored) - 2
        decoder_layers = [nn.Unflatten(1, (h_dim, feat_side, feat_side))]
        for stage, (c_in, c_out) in enumerate(
            zip(mirrored[:-1], mirrored[1:])
        ):
            decoder_layers.append(
                nn.ConvTranspose2d(
                    c_in, c_out, kernel_size=4, stride=2, padding=1
                )
            )
            # Hidden stages use SiLU; the output stage squashes to [-1, 1].
            decoder_layers.append(
                nn.Tanh() if stage == last_stage else nn.SiLU()
            )
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(log_var / 2)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map latent vectors ``z`` to images in the Tanh range."""
        # fc_decode yields a flat (batch, h_dim * side * side) vector; the
        # Unflatten layer at the head of self.decoder restores the
        # (batch, h_dim, side, side) layout the transposed convs expect.
        flat = self.fc_decode(z)
        return self.decoder(flat)

    def forward(self, x):
        """Encode, sample and decode; returns ``(x_reconst, mu, log_var)``."""
        mu, log_var = self.encode(x)
        x_reconst = self.decode(self.reparameterize(mu, log_var))
        return x_reconst, mu, log_var

### **Loss Function**

In [None]:
class VAELoss(nn.Module):
    """Weighted sum of a reconstruction term and a KL-divergence term.

    Args:
        kld_weight: multiplier applied to the KL term before it is added to
            the reconstruction loss.
    """

    def __init__(self, kld_weight=KLD_WEIGHT):
        super().__init__()
        self.kld_weight = kld_weight

    def forward(self, reconst_x, x, mu, log_var):
        """Return ``(total, reconstruction, kld)`` losses as scalar tensors."""
        # Mean squared error over every element of the batch.
        reconst_loss = nn.functional.mse_loss(reconst_x, x)

        # KL(q(z|x) || N(0, I)) terms, averaged over batch and latent dims.
        kld_terms = 1 + log_var - mu.pow(2) - log_var.exp()
        kld_loss = -0.5 * torch.mean(kld_terms)

        loss = reconst_loss + self.kld_weight * kld_loss
        return loss, reconst_loss, kld_loss

### **Wrapper**

In [None]:
class VAEWrapper(L.LightningModule):
    """LightningModule that trains ``VAE`` with ``VAELoss``.

    Optimization is manual: ``training_step`` runs backward/step itself and
    ``on_train_epoch_end`` steps the LR scheduler. Validation reports LPIPS
    between reconstructions and inputs under the key ``"val_lpips"``.

    Per-epoch history lists (``train_loss``, ``train_reconst_loss``,
    ``train_kld_loss``, ``val_lpips_list``) are filled from ``AvgMeter``
    recorders, i.e. each entry is the mean of the most recent values the
    recorder holds for that epoch.
    """

    def __init__(self):
        super().__init__()

        self.model = VAE()
        self.loss = VAELoss()

        self.batch_size = BATCH_SIZE
        self.lr = LEARNING_RATE
        self.max_epoch = MAX_EPOCH

        # Perceptual metric, used for validation only.
        self.val_lpips = LearnedPerceptualImagePatchSimilarity()
        self.val_lpips.eval()

        self.val_lpips_recorder = AvgMeter()
        self.val_lpips_list = list()

        self.train_loss_recorder = AvgMeter()
        self.train_reconst_loss_recorder = AvgMeter()
        self.train_kld_loss_recorder = AvgMeter()

        self.train_loss = list()
        self.train_reconst_loss = list()
        self.train_kld_loss = list()

        # The optimizer and scheduler are stepped by hand (see below).
        self.automatic_optimization = False

    def forward(self, x):
        """Return only the reconstruction of ``x``."""
        reconst_images, _, _ = self.model(x)
        return reconst_images

    def training_step(self, batch, batch_idx):
        """One manual optimization step; logs and records the three losses."""
        images, _ = batch

        reconst_images, mu, log_var = self.model(images)
        loss, reconst_loss, kld_loss = self.loss(
            reconst_images, images, mu, log_var
        )

        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        for name, value, recorder in (
            ("train_loss", loss, self.train_loss_recorder),
            (
                "train_reconst_loss",
                reconst_loss,
                self.train_reconst_loss_recorder,
            ),
            ("train_kld_loss", kld_loss, self.train_kld_loss_recorder),
        ):
            value = value.detach().cpu()
            self.log(name, value, prog_bar=True)
            recorder.update(value)

    def on_train_epoch_end(self):
        """Store the epoch's losses, refresh the plots, step the scheduler."""
        for recorder, history in (
            (self.train_loss_recorder, self.train_loss),
            (self.train_reconst_loss_recorder, self.train_reconst_loss),
            (self.train_kld_loss_recorder, self.train_kld_loss),
        ):
            history.append(recorder.show().detach().cpu().numpy())
            recorder.reset()

        self._plot_evaluation_metrics()

        scheduler = self.lr_schedulers()
        scheduler.step()

    def validation_step(self, batch, batch_idx):
        """Reconstruct the batch and log its LPIPS (skipped in sanity check)."""
        images, _ = batch

        reconst_images = self(images)

        # The pre-training sanity check must not pollute the metric history.
        if self.trainer.sanity_checking:
            return

        # Calling the metric returns the value for *this* batch (and also
        # accumulates state). Logging update()+compute() instead would log a
        # running mean over everything seen so far.
        lpips = self.val_lpips(reconst_images, images).detach().cpu()

        self.log("val_lpips", lpips, prog_bar=True)
        self.val_lpips_recorder.update(lpips)

    def on_validation_epoch_end(self):
        """Store the epoch's LPIPS and clear all per-epoch metric state."""
        if self.trainer.sanity_checking:
            return

        if self.val_lpips_recorder.scores:
            mean = self.val_lpips_recorder.show()
            self.val_lpips_list.append(mean.detach().cpu().numpy())
        self.val_lpips_recorder.reset()

        # Without this the metric keeps accumulating across epochs.
        self.val_lpips.reset()

    @staticmethod
    def _save_curves(path, title, ylabel, curves):
        """Plot ``(values, color, label)`` curves on a new figure and save it.

        The figure is closed afterwards so no pyplot state leaks between
        epochs or into later notebook cells.
        """
        fig, ax = plt.subplots()
        for values, color, label in curves:
            ax.plot(values, color=color, label=label)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid()
        fig.savefig(path)
        plt.close(fig)

    def _plot_evaluation_metrics(self):
        """Write the loss and LPIPS history plots under ``training/``."""
        # VAE Loss
        self._save_curves(
            os.path.join(EXPERIMENT_DIR, "training/VAE_loss_plot.png"),
            "VAE Loss Curves",
            "Loss",
            (
                (self.train_loss, "r", "loss"),
                (self.train_reconst_loss, "g", "reconst_loss"),
                (self.train_kld_loss, "b", "kld_loss"),
            ),
        )

        # LPIPS
        self._save_curves(
            os.path.join(EXPERIMENT_DIR, "training/VAE_lpips_plot.png"),
            "LPIPS Curves",
            "LPIPS",
            ((self.val_lpips_list, "b", "lpips_score"),),
        )

    def configure_optimizers(self):
        """AdamW over the VAE parameters with an exponential LR decay."""
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
        )

        scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=SCHEDULER_GAMMA,
        )

        return [optimizer], [scheduler]

In [None]:
# Class (not an instance) that the training loop instantiates.
MODEL = VAEWrapper

# Location of the best checkpoint inside the experiment folder.
BEST_MODEL_PATH = os.path.join(EXPERIMENT_DIR, "model", "VAE_best.ckpt")

## **Training**

In [None]:
def _train_loop():
    """Seed everything, build the model and trainer, and run the fit.

    The best checkpoint (by METRIC_TO_MONITOR / METRIC_MODE) is written to
    the experiment's ``model`` folder as ``VAE_best``. When BEST_MODEL_PATH
    already exists, the fit is started from that checkpoint.
    """
    seed_everything(SEED, workers=True)

    print("\n=================[ Variational Auto Encoder ]=================\n")

    model = MODEL()

    best_checkpoint = ModelCheckpoint(
        monitor=METRIC_TO_MONITOR,
        dirpath=os.path.join(EXPERIMENT_DIR, "model"),
        mode=METRIC_MODE,
        filename="VAE_best",
    )

    # Continue from an earlier best checkpoint when there is one.
    resume_from = BEST_MODEL_PATH if os.path.exists(BEST_MODEL_PATH) else None

    trainer = Trainer(
        accelerator="auto",
        devices=1,
        max_epochs=MAX_EPOCH,
        logger=False,
        callbacks=[best_checkpoint],
        log_every_n_steps=5,
    )
    trainer.fit(model, ckpt_path=resume_from, datamodule=DATASET)

_train_loop()

In [None]:
# Show the loss and LPIPS curves that were written to disk during training.
for _plot_file in (
    "training/VAE_loss_plot.png",
    "training/VAE_lpips_plot.png",
):
    cv2_imshow(cv2.imread(os.path.join(EXPERIMENT_DIR, _plot_file)))

## **Explore the Latent Space!!**

In [None]:
# @title **Prepare**

# Best effort: without data the exploration cells fall back to random latents.
try:
    DATASET.prepare_data()
    DATASET.setup()
    generation_loader = DATASET.generation_dataloader()
except Exception as e:
    print(f"Cannot prepare data: {e}")
    generation_loader = None

# Every latent tensor below is moved to DEVICE, so place the restored VAE
# there explicitly instead of relying on where the checkpoint happens to load.
model = VAEWrapper.load_from_checkpoint(
    BEST_MODEL_PATH,
    map_location=DEVICE,
).model.to(DEVICE)
model.eval()
print()

### **Random Sampling**

In [None]:
# Decode standard-normal latent draws into a sheet of sample images

sample_size = 64

# Target size of each saved sample and the transform that produces it
output_size = (256, 256)
resize_transform = transforms.Resize(output_size)

with torch.no_grad():
    z = torch.randn(sample_size, Z_DIM).to(DEVICE)
    generated_images = model.decode(z)

    # Map the Tanh range [-1, 1] to [0, 1], then upscale
    generated_images_norm = resize_transform(0.5 * (generated_images + 1.))

    save_image(
        generated_images_norm.cpu(),
        f"{EXPERIMENT_DIR}/generated_images.png",
    )

    print("Generated samples saved.")

### **Visualize Sampled Image**

In [None]:
# Show the sample sheet written to disk by the random-sampling cell.
display(IImage(filename=f"{EXPERIMENT_DIR}/generated_images.png"))

In [None]:
# @title **Pick an Attribute!**
# Attribute names; the position in this list is the attribute's label column.
attr_names = [
    "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes",
    "Bald", "Bangs", "Big_Lips", "Big_Nose",
    "Black_Hair", "Blond_Hair", "Blurry", "Brown_Hair",
    "Bushy_Eyebrows", "Chubby", "Double_Chin", "Eyeglasses",
    "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones",
    "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes",
    "No_Beard", "Oval_Face", "Pale_Skin", "Pointy_Nose",
    "Receding_Hairline", "Rosy_Cheeks", "Sideburns", "Smiling",
    "Straight_Hair", "Wavy_Hair", "Wearing_Earrings", "Wearing_Hat",
    "Wearing_Lipstick", "Wearing_Necklace", "Wearing_Necktie", "Young",
]

# Name -> column index, derived from the list so the two cannot drift apart.
attr_dict = {name: column for column, name in enumerate(attr_names)}
# Sentinel outside the valid column range: selects no real attribute.
attr_dict["Random"] = 99

attribute = "Male" #@param ["5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes", "Bald", "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair", "Blurry", "Brown_Hair", "Bushy_Eyebrows", "Chubby", "Double_Chin", "Eyeglasses", "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones", "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes", "No_Beard", "Oval_Face", "Pale_Skin", "Pointy_Nose", "Receding_Hairline", "Rosy_Cheeks", "Sideburns", "Smiling", "Straight_Hair", "Wavy_Hair", "Wearing_Earrings", "Wearing_Hat", "Wearing_Lipstick", "Wearing_Necklace", "Wearing_Necktie", "Young", "Random"]
positive_attribute_index = attr_dict[attribute]

### **Collect and Process Latent Vectors**

In [None]:
# Collect latent vectors and the matching labels (1 = has the attribute,
# 0 = does not) for at most MAX_SAMPLE images.
all_latent_vectors = list()
all_labels = list()

try:
    # Raises for the "Random" choice (index outside attr_names), which sends
    # execution to the random-latent fallback below.
    print(
        f"Selected positive attribute: {attr_names[positive_attribute_index]}"
    )

    collected = 0
    with torch.no_grad():
        for images, labels in tqdm(generation_loader):
            images = images.to(DEVICE)
            # Column of the label matrix that belongs to the chosen attribute
            positive_labels = labels[:, positive_attribute_index]

            # Encode images and use the sampled (reparameterized) z
            mu, log_var = model.encode(images)
            z = model.reparameterize(mu, log_var)

            all_latent_vectors.append(z.cpu().numpy())
            all_labels.append(positive_labels.cpu().numpy())

            # Count real samples (the last batch may be short) and stop as
            # soon as the budget is reached. The previous test,
            # `i * BATCH_SIZE > MAX_SAMPLE`, ran after appending and so let
            # extra batches through.
            collected += images.shape[0]
            if collected >= MAX_SAMPLE:
                break

    # Trim the overshoot of the final batch so exactly MAX_SAMPLE rows remain
    # at most, matching the size of the fallback below.
    all_latent_vectors = np.concatenate(all_latent_vectors, axis=0)[:MAX_SAMPLE]
    all_labels = np.concatenate(all_labels, axis=0)[:MAX_SAMPLE]
except Exception as e:
    # Deliberate best effort: explore unlabeled random latents instead.
    print(f"Cannot use attr_names: {e}")
    all_latent_vectors = np.random.randn(MAX_SAMPLE, Z_DIM)
    all_labels = None

In [None]:
# UMAP gets its own freshly drawn seed (any non-negative int32).
random_state = random.randint(0, np.iinfo(np.int32).max)

umap_settings = dict(
    n_components=2,
    n_neighbors=25,
    min_dist=0.1,
    random_state=random_state,
)
reducer = umap.UMAP(**umap_settings)

# Labels (when available) are passed as the target; None fits unsupervised.
embedding_latent = reducer.fit_transform(all_latent_vectors, y=all_labels)

In [None]:
# Bounding box of the 2D UMAP embedding
embedding_x = embedding_latent[:, 0]
embedding_y = embedding_latent[:, 1]
x_min, x_max = embedding_x.min(), embedding_x.max()
y_min, y_max = embedding_y.min(), embedding_y.max()

print(f"x_min: {x_min}, x_max: {x_max}")
print(f"y_min: {y_min}, y_max: {y_max}")

print()

# Regular grid over the bounding box
grid_size = 15
grid_x = np.linspace(x_min, x_max, grid_size)
grid_y = np.linspace(y_min, y_max, grid_size)

# All (x, y) grid points, row by row: y is the slow axis, x the fast one,
# so point number i * grid_size + j is (grid_x[j], grid_y[i]).
mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
grid_points_2d = np.stack([mesh_x.ravel(), mesh_y.ravel()], axis=1)

# Map the 2D points back into the latent space. UMAP's inverse transform is
# only an approximation.
grid_points_latent = reducer.inverse_transform(grid_points_2d)

grid_points_latent_tensor = torch.tensor(
    grid_points_latent, dtype=torch.float32
).to(DEVICE)

# Decode every grid point into an image
with torch.no_grad():
    generated_grid_images = model.decode(grid_points_latent_tensor)

# Tanh range [-1, 1] -> [0, 1], then bring the batch to the CPU
generated_grid_images_norm = 0.5 * (generated_grid_images + 1.)
generated_grid_images_norm = generated_grid_images_norm.cpu()

# Tile the batch into one big image: view it as (row, col, C, H, W), flip the
# rows so that larger y ends up at the top, then merge rows into the height
# axis and columns into the width axis.
n_channels, tile_h, tile_w = generated_grid_images_norm.shape[1:]
tiles = generated_grid_images_norm.reshape(
    grid_size, grid_size, n_channels, tile_h, tile_w
).flip(0)
final_image_grid = tiles.permute(2, 0, 3, 1, 4).reshape(
    n_channels, grid_size * tile_h, grid_size * tile_w
)

# Save the grid of images
save_image(
    final_image_grid,
    f"{EXPERIMENT_DIR}/latent_space_grid.png",
)

print("Latent space grid of generated images saved.\n")

In [None]:
# Overlay a pair of axes and an origin marker on the saved grid image.
image_path = os.path.join(EXPERIMENT_DIR, "latent_space_grid.png")
img = cv2.imread(image_path)

height, width, _ = img.shape

# The axes cross at the centre of the image.
center_x, center_y = width // 2, height // 2
intersection_point = (center_x, center_y)

# Horizontal axis: left edge to right edge at mid-height.
start_point_h, end_point_h = (0, center_y), (width, center_y)
# Vertical axis: top edge to bottom edge at mid-width.
start_point_v, end_point_v = (center_x, 0), (center_x, height)

# OpenCV colors are BGR.
line_color = (0, 0, 255)  # red
line_thickness = 2

for _start, _end in (
    (start_point_h, end_point_h),
    (start_point_v, end_point_v),
):
    cv2.line(img, _start, _end, line_color, line_thickness)

# Origin label settings
text_to_display = "O"
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 1
font_thickness = 2
text_color = (255, 0, 0)  # blue (BGR)

# Measure the label so it can be centred on the intersection.
(text_width, text_height), baseline = cv2.getTextSize(
    text_to_display, font, font_scale, font_thickness
)
# putText anchors at the bottom-left corner of the text.
text_x = intersection_point[0] - text_width // 2
text_y = intersection_point[1] + text_height // 2

cv2.putText(
    img,
    text_to_display,
    (text_x, text_y),
    font,
    font_scale,
    text_color,
    font_thickness,
    cv2.LINE_AA,
)

# Overwrite the grid image with the annotated version.
cv2.imwrite(image_path, img)

### **The 2D Latent Space Visualization**

In [None]:
# Display the latent-space grid image as currently saved on disk.
display(IImage(filename=f"{EXPERIMENT_DIR}/latent_space_grid.png"))

In [None]:
# @title **Latent Space Traversal**

x_pos = 0.0 #@param {type:"slider", min:-1, max:1, step:1e-15}
y_pos = 0.0 #@param {type:"slider", min:-1, max:1, step:1e-15}


def _slider_to_embedding(slider_value, low, high):
    """Map a slider value in [-1, 1] linearly onto [low, high]."""
    unit = (slider_value + 1.) / 2.
    return unit * (high - low) + low


# Slider positions -> coordinates inside the embedding's bounding box.
x_pos = _slider_to_embedding(x_pos, x_min, x_max)
y_pos = _slider_to_embedding(y_pos, y_min, y_max)

# Lift the 2D point back into the latent space and decode it.
traversed_latent_vector = reducer.inverse_transform(
    np.array([[x_pos, y_pos]])
)

with torch.no_grad():
    latent_tensor = torch.tensor(
        traversed_latent_vector, dtype=torch.float32
    ).to(DEVICE)
    traversed_image = model.decode(latent_tensor)

# Upscale first, then map the Tanh range [-1, 1] to [0, 1].
output_size = (256, 256)
resize_transform = transforms.Resize(output_size)
traversed_image = 0.5 * (resize_transform(traversed_image) + 1.)

save_image(
    traversed_image.cpu(),
    f"{EXPERIMENT_DIR}/traversed_image.png",
)

### **Image at the Location**

In [None]:
# Show the image decoded at the slider position chosen in the traversal cell.
display(IImage(filename=f"{EXPERIMENT_DIR}/traversed_image.png"))

### **Latent Vector Interpolation**

In [None]:
# Stay None when the statistics below cannot be computed.
expected_positive_vector = None
expected_negative_vector = None

try:
    # Project each label group into the 2D embedding.
    has_attribute = all_labels == 1
    lacks_attribute = all_labels == 0
    positive_latent_vectors = reducer.transform(
        all_latent_vectors[has_attribute]
    )
    negative_latent_vectors = reducer.transform(
        all_latent_vectors[lacks_attribute]
    )

    # Robust centre (median) and spread (median absolute deviation) per group.
    expected_positive_vector = np.median(positive_latent_vectors, axis=0)
    expected_negative_vector = np.median(negative_latent_vectors, axis=0)
    dispersion_positive_vector = median_abs_deviation(
        positive_latent_vectors, axis=0
    )
    dispersion_negative_vector = median_abs_deviation(
        negative_latent_vectors, axis=0
    )

    attribute_label = attr_names[positive_attribute_index]

    print(f"Expected Latent Vector for {attribute_label}:")
    print(expected_positive_vector)
    print(f"Dispersion of Latent Vector for {attribute_label}:")
    print(dispersion_positive_vector)

    print()

    print(f"Expected Latent Vector for Not {attribute_label}:")
    print(expected_negative_vector)
    print(f"Dispersion of Latent Vector for Not {attribute_label}:")
    print(dispersion_negative_vector)

    # Direction from the "without" centre to the "with" centre.
    difference_vector = expected_positive_vector - expected_negative_vector
    print(
        f"\nDifference Vector ({attribute_label} - Not {attribute_label}):"
    )
    print(difference_vector)
except Exception as e:
    print(
        f"Error during latent vector calculation: {e}\n"
        "Not generating pre-generated vectors"
    )

In [None]:
def _jitter_around(center, dispersion):
    """Shift a 2D point by uniform noise in [-dispersion/2, +dispersion/2]."""
    return center + (
        (dispersion)
        * np.random.rand(2)
        - (dispersion / 2.)
    )


def _random_embedding_point():
    """Draw a 2D point uniformly inside the embedding's bounding box."""
    point = np.random.rand(2)
    point[0] = (x_max - x_min) * point[0] + x_min
    point[1] = (y_max - y_min) * point[1] + y_min
    return point


def _lift_to_latent(point_2d):
    """Map one 2D embedding point back to a latent vector via UMAP."""
    return reducer.inverse_transform([point_2d])[0]


# Turn the 2D group centres into latent vectors. When the centres are not
# available (e.g. the "Random" attribute, or a failure in the previous cell),
# fall back to two random points of the embedding.
try:
    _positive_2d = _jitter_around(
        expected_positive_vector, dispersion_positive_vector
    )
    print(f"Compressed positive vector: {_positive_2d}")
    _positive_latent = _lift_to_latent(_positive_2d)

    _negative_2d = _jitter_around(
        expected_negative_vector, dispersion_negative_vector
    )
    print(f"Compressed negative vector: {_negative_2d}")
    _negative_latent = _lift_to_latent(_negative_2d)
# `Exception` (not a bare except) so Ctrl-C / kernel interrupts still stop
# the cell instead of silently switching to random vectors.
except Exception as e:
    print(
        "Warning: no pre-generated vectors "
        "at hand. Using random vectors instead."
    )
    print(f"Reason: {e!r}")

    _positive_2d = _random_embedding_point()
    print(f"Compressed positive vector: {_positive_2d}")
    _positive_latent = _lift_to_latent(_positive_2d)

    _negative_2d = _random_embedding_point()
    print(f"Compressed negative vector: {_negative_2d}")
    _negative_latent = _lift_to_latent(_negative_2d)

# From here on the names hold full latent vectors (not 2D points).
expected_positive_vector = _positive_latent
expected_negative_vector = _negative_latent

print()

print(f"Expected Positive Vector: {expected_positive_vector}")
print()
print(f"Expected Negative Vector: {expected_negative_vector}")

In [None]:
# Linear interpolation in latent space, from the negative vector (alpha = 0)
# to the positive vector (alpha = 1).
num_interpolation_steps = 10  # Number of steps for interpolation

blend_weights = [
    step / (num_interpolation_steps - 1)
    for step in range(num_interpolation_steps)
]
interpolation_vectors = [
    (1 - alpha) * expected_negative_vector
    + alpha * expected_positive_vector
    for alpha in blend_weights
]

# Stack into one (steps, latent_dim) float32 tensor on the target device
interpolation_vectors_tensor = torch.tensor(
    np.array(interpolation_vectors), dtype=torch.float32
).to(DEVICE)

In [None]:
# Generate images from the interpolated latent vectors
model.eval()
with torch.no_grad():
    interpolated_images = model.decode(interpolation_vectors_tensor)

    # Map the Tanh range [-1, 1] to [0, 1]
    interpolated_images_norm = 0.5 * (interpolated_images + 1.)

    # Upscale every frame
    output_size = (512, 512)
    resize_transform = transforms.Resize(output_size)
    resized_interpolated_images = resize_transform(interpolated_images_norm)

    # The images are already in [0, 1]. `normalize=True` would additionally
    # min-max stretch them over the whole batch, altering contrast and making
    # this sheet disagree with the other saved images and the GIF frames.
    save_image(
        resized_interpolated_images.cpu(),
        f"{EXPERIMENT_DIR}/interpolated_images.png",
        nrow=num_interpolation_steps, # Arrange images in a single row
    )

print(
    f"Interpolated images saved to {EXPERIMENT_DIR}/interpolated_images.png"
)

In [None]:
# `interpolated_images` is a torch tensor of shape (N, C, H, W): N frames,
# C channels, H height, W width, with values in the Tanh range [-1, 1].

# Map to [0, 1]
interpolated_images_norm = 0.5 * (interpolated_images + 1.)

# Upscale every frame
output_size = (512, 512)
resize_transform = transforms.Resize(output_size)
resized_interpolated_images = resize_transform(interpolated_images_norm)

# Move to the CPU and reorder to (N, H, W, C), the layout imageio expects.
frames_float = resized_interpolated_images.cpu().permute(0, 2, 3, 1).numpy()

# Convert [0, 1] floats to uint8. Round to the nearest level (a plain cast
# truncates, darkening every pixel by up to one level) and clip so nothing
# can wrap around.
interpolated_images_np = np.clip(
    np.rint(frames_float * 255), 0, 255
).astype(np.uint8)

# Save the frames as an endlessly looping GIF, one frame per second
gif_path = (
    f"{EXPERIMENT_DIR}/interpolated_animation.gif"
)
imageio.mimsave(
    gif_path, interpolated_images_np, loop=0, fps=1,
)
print(f"GIF saved to {gif_path}")

### **Interpolation Visualization**

In [None]:
# Show the interpolation sheet (negative -> positive, left to right).
display(IImage(filename=f"{EXPERIMENT_DIR}/interpolated_images.png"))

In [None]:
# Display the GIF in the notebook. Read the bytes inside a `with` block so
# the file handle is closed instead of being left to the garbage collector.
with open(gif_path, "rb") as gif_file:
    display(IImage(gif_file.read()))