# In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# In [None]:
# VQ-VAE Model Definition
class VQVAE(nn.Module):
    """Convolutional VQ-VAE: encoder -> vector quantizer -> decoder.

    The encoder applies two kernel-4 / stride-2 / padding-1 convolutions
    (each halves an even spatial size), producing ``embedding_dim``
    channels. The decoder mirrors it with two transposed convolutions
    (each doubles the spatial size) and ends in ``Tanh``, so reconstructions
    lie in [-1, 1].

    Args:
        in_channels: channels of the input images (and of the reconstruction).
        hidden_channels: width of the intermediate conv layer.
        num_embeddings: codebook size passed to ``VectorQuantizer``.
        embedding_dim: encoder output channels / codebook vector length.
        commitment_cost: commitment-loss weight passed to ``VectorQuantizer``.
    """

    def __init__(self, in_channels, hidden_channels, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        # Layers are built in the same order as before (encoder, decoder,
        # codebook) so parameter initialization consumes the RNG identically.
        down_layers = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, embedding_dim, kernel_size=4, stride=2, padding=1),
        ]
        self.encoder = nn.Sequential(*down_layers)

        up_layers = [
            nn.ConvTranspose2d(embedding_dim, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_channels, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*up_layers)

        self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)

    def forward(self, x):
        """Return ``(reconstruction, vq_loss)`` for the input batch ``x``."""
        latent = self.encoder(x)
        latent_q, vq_loss, _indices = self.vq_layer(latent)
        return self.decoder(latent_q), vq_loss

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    Args:
        num_embeddings: number of codebook vectors.
        embedding_dim: length of each codebook vector.
        commitment_cost: weight of the encoder commitment term in the loss.
        permute_to_channels_last: how ``x`` is split into vectors.

            * ``False`` (default, legacy behaviour): ``x`` is reshaped directly
              to ``(-1, embedding_dim)``. For an NCHW input this groups
              ``embedding_dim`` consecutive *memory* elements, i.e. runs along
              the spatial axes of a single channel, rather than the channel
              vector at each spatial position. It stays the default so models
              trained with this layout keep producing the same outputs.
            * ``True``: dim 1 is moved last before flattening, so each
              position's ``embedding_dim`` channel values are quantized together
              (the standard VQ-VAE formulation). Requires
              ``x.size(1) == embedding_dim``. A model trained with the legacy
              layout must be retrained to use this mode.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, permute_to_channels_last=False):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)
        self.commitment_cost = commitment_cost
        self.permute_to_channels_last = permute_to_channels_last

    def forward(self, x):
        """Quantize ``x`` against the codebook.

        Returns:
            ``(quantized, loss, encoding_indices)``:

            * ``quantized`` has ``x``'s shape; its value is the codebook vectors.
              Gradients flow straight through to ``x``.
            * ``loss`` is ``q_latent_loss + commitment_cost * e_latent_loss``,
              where both terms are MSE.
            * ``encoding_indices`` is a 1-D tensor of chosen codebook rows, in
              flattening order.

        Raises:
            ValueError: channels-last mode and ``x.size(1) != embedding_dim``.
        """
        # getattr keeps whole-module pickles from before this flag loadable.
        channels_last = getattr(self, "permute_to_channels_last", False)
        if channels_last:
            if x.dim() < 2 or x.size(1) != self.embedding_dim:
                raise ValueError(
                    f"permute_to_channels_last expects x.size(1) == embedding_dim "
                    f"({self.embedding_dim}), got shape {tuple(x.shape)}"
                )
            # (B, C, *spatial) -> (B, *spatial, C): one channel vector per row below.
            z = x.permute(0, *range(2, x.dim()), 1)
        else:
            z = x

        # Quantize: pick the nearest (Euclidean) codebook row for every vector.
        # reshape (rather than view) also accepts non-contiguous inputs.
        flattened = z.reshape(-1, self.embedding_dim)
        distances = torch.cdist(flattened, self.embedding.weight)
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices).reshape(z.shape)
        if channels_last:
            # Undo the permute: (B, *spatial, C) -> (B, C, *spatial).
            quantized = quantized.permute(0, x.dim() - 1, *range(1, x.dim() - 1))

        # Loss: codebook term pulls embeddings toward the (frozen) encoder output;
        # commitment term pulls the encoder output toward the (frozen) codes.
        e_latent_loss = F.mse_loss(quantized.detach(), x)
        q_latent_loss = F.mse_loss(quantized, x.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        quantized = x + (quantized - x).detach()  # Straight-through estimator

        return quantized, loss, encoding_indices

# Function to load and transform image
def load_and_transform_image(image_path, transform):
    """Load an image file as RGB, apply ``transform`` and add a batch dimension.

    Args:
        image_path: path to the image file.
        transform: callable applied to the RGB ``PIL.Image``. Its result must
            support ``unsqueeze(0)`` (e.g. a tensor from ``ToTensor``).

    Returns:
        ``transform(image)`` with a new leading dimension of size 1.
    """
    # Use a context manager so the underlying file handle is released right
    # away instead of lingering until garbage collection. convert() returns a
    # new, fully loaded image, so it stays valid after the source is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return transform(image).unsqueeze(0)  # Add batch dimension

# Interpolation function
def interpolate(model, image1, image2, num_steps=10, device=None):
    """Decode a linear path between the quantized latents of two images.

    Both images are encoded and passed through ``model.vq_layer``. The
    decoder then runs on ``(1 - alpha) * q1 + alpha * q2`` for ``num_steps``
    evenly spaced ``alpha`` values in [0, 1], with both endpoints included.
    The blended latents are decoded as-is, without re-quantizing them.

    Args:
        model: module exposing ``encoder``, ``vq_layer`` and ``decoder``.
            ``vq_layer`` must return a 3-tuple whose first item is the
            quantized latent.
        image1, image2: input batches. Presumably (1, C, H, W) as produced by
            ``load_and_transform_image``, since dim 0 of each decoded frame is
            squeezed. TODO confirm for other callers.
        num_steps: number of interpolation frames.
        device: device to run on. Defaults to the device of the model's first
            parameter (or ``image1``'s device if the model has none). The
            function therefore no longer depends on a module-level ``device``.

    Returns:
        A list of ``num_steps`` decoded tensors, each with dim 0 squeezed.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else image1.device

    # Switch to eval for inference, but restore the caller's mode afterwards
    # rather than leaving the model permanently in eval mode.
    was_training = model.training
    model.eval()
    image1 = image1.to(device)
    image2 = image2.to(device)

    interpolated_images = []
    try:
        with torch.no_grad():
            # Encode both images
            latent1 = model.encoder(image1)
            latent2 = model.encoder(image2)

            # Quantize both latents
            quantized1, _, _ = model.vq_layer(latent1)
            quantized2, _, _ = model.vq_layer(latent2)

            # Interpolate between quantized latents
            for alpha in np.linspace(0.0, 1.0, num_steps).tolist():
                interpolated_latent = (1 - alpha) * quantized1 + alpha * quantized2
                decoded_image = model.decoder(interpolated_latent)
                interpolated_images.append(decoded_image.squeeze(0))
    finally:
        model.train(was_training)

    return interpolated_images

# Parameters
num_steps = 10
output_dir = 'path_to_output_directory'  # Output directory for saving results
os.makedirs(output_dir, exist_ok=True)

# Path to the saved model
model_save_path = os.path.join(output_dir, "vqvae_model.pth")

# Load the trained model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Architecture hyper-parameters. hidden_channels, embedding_dim and
# num_embeddings determine tensor shapes, so they must match the values the
# checkpoint was trained with or load_state_dict will reject it.
hidden_channels, embedding_dim, num_embeddings = 128, 64, 512
commitment_cost = 0.25

model = VQVAE(
    in_channels=3,
    hidden_channels=hidden_channels,
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    commitment_cost=commitment_cost,
)
model.to(device)  # nn.Module.to moves parameters in place and returns self
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (newer PyTorch versions also accept weights_only=True).
model.load_state_dict(torch.load(model_save_path, map_location=device))

# Transformations: resize to 64x64, convert to a [0, 1] tensor, then map to
# [-1, 1] via (t - 0.5) / 0.5 so inputs match the decoder's Tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Image folders for interpolation
folder_paths = [
    'path_to_folder_1',  # Replace with actual folder path
    'path_to_folder_2'   # Replace with actual folder path
]

# Collect all image paths. The extension check is case-insensitive so files
# such as "photo.JPG" are not silently skipped.
all_image_paths = []
for folder_path in folder_paths:
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                all_image_paths.append(os.path.join(root, file))
# os.walk order depends on the filesystem; sort so that a seeded `random`
# yields the same pairs on every machine.
all_image_paths.sort()

# Randomly select pairs of images
num_pairs = 5
if len(all_image_paths) < 2:
    raise ValueError(
        f"Need at least 2 images to interpolate, found {len(all_image_paths)} "
        f"under {folder_paths}"
    )
# Each pair consists of two *distinct* randomly chosen images. Zipping the
# path list with itself would only ever produce (p, p) pairs, i.e. an image
# interpolated with itself.
image_pairs = [tuple(random.sample(all_image_paths, 2)) for _ in range(num_pairs)]

# Interpolate and save images: one row of `num_steps` decoded frames per pair,
# shown inline and written to `output_dir`.
for idx, (image1_path, image2_path) in enumerate(image_pairs):
    image1, image2 = (
        load_and_transform_image(path, transform)
        for path in (image1_path, image2_path)
    )

    interpolated_images = interpolate(model, image1, image2, num_steps=num_steps)

    # normalize=True with value_range=(-1, 1) rescales the Tanh output range
    # to [0, 1] for display and saving.
    grid = make_grid(
        interpolated_images,
        nrow=num_steps,
        normalize=True,
        value_range=(-1, 1),
    )

    # Visualize interpolation (CHW -> HWC for matplotlib)
    plt.figure(figsize=(20, 4))
    plt.axis('off')
    plt.imshow(grid.cpu().permute(1, 2, 0))
    plt.title(f"Interpolation {idx + 1}")
    plt.show()

    # Save interpolated images
    save_image(grid, os.path.join(output_dir, f"interpolation_{idx + 1}.png"))