Implement a DDIM inversion method. Get the (inverted) latents for a pair of real images and plot the generated images obtained via linear
interpolation of the latents corresponding to these images.

In [None]:
# Importing the necessary libraries
import torch  # Import PyTorch library
import torch.nn as nn  # Import neural network module
import torch.optim as optim  # Import optimization module
from torchvision import datasets, transforms  # Import datasets and transforms
from torchvision.utils import save_image, make_grid  # Import utility to save images
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision  # Import torchvision library
import matplotlib.pyplot as plt  # Import plotting library
import os  # Import os module for file operations
import numpy as np  # Import numpy library

from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import matplotlib.pyplot as plt
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.models import inception_v3
import numpy as np
import shutil
from PIL import Image  # Import PIL for image processing
import scipy
from torchvision.models import resnet50
from tqdm.notebook import trange
from pathlib import Path
import math
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F

### Hyper-parameters

In [None]:

# Expose only GPUs 0-2 to this process.
# NOTE(review): this runs after `import torch`; presumably it still takes effect
# because CUDA has not been initialised yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_EPOCHS = 600  # Number of epochs passed to train_ddim
IMG_SIZE = 128  # Side length used by the Resize/CenterCrop transforms and for sampled noise
TRAIN_DDIM = True  # True: train and save a model; False: load previously saved weights
BATCH_SIZE = 128  # Batch size of the training DataLoader
LR = 1e-4  # Learning rate given to AdamW
STEP_GAMMA = 0.3  # Decay factor given to the StepLR scheduler
SCALING_FACTOR = 5000  # Constant inside the log() of the sinusoidal time-embedding frequencies
BUTTERFLY = "butterfly"
ANIMAL = "animal"
# Dataset selector read from the DATASET environment variable; defaults to "butterfly".
dataset = os.getenv('DATASET', BUTTERFLY)  # butterfly or animal

NOISE_STEPS = 450  # Length of the beta schedule, i.e. number of diffusion levels
SAMPLING_STEPS = NOISE_STEPS  # Number of DDIM sampling steps (here: one per diffusion level)

In [None]:
# Per-dataset output locations. `log_dir` is bound to the same string as LOG_DIR.
LOG_DIR = log_dir = f'DDIM/tensorboard/{dataset}'  # TensorBoard event files
IMG_DIR = f'DDIM/generated_images/{dataset}'  # Sample grids saved during training


def recreate_directory(dir_path):
    """Leave `dir_path` as a freshly created, empty directory.

    Anything already present at `dir_path` is removed with its whole
    content first; missing parent directories are created as needed.
    """
    already_present = os.path.exists(dir_path)
    if already_present:
        # Wipe the old tree so no stale files survive.
        shutil.rmtree(dir_path)
    os.makedirs(dir_path)

# The checkpoint directory is created if missing and never wiped.
os.makedirs('DDIM/models', exist_ok=True)

# A training run starts from empty log and image directories.
if TRAIN_DDIM:
    recreate_directory(LOG_DIR)
    recreate_directory(IMG_DIR)

In [None]:

# TensorBoard writer used by the logging helpers; events are written to LOG_DIR
writer = SummaryWriter(LOG_DIR)


def log_losses_to_tensorboard(epoch, loss):
    """Record `loss` under the 'Loss/Total' tag with `epoch` as the step."""
    writer.add_scalar(tag='Loss/Total', scalar_value=loss, global_step=epoch)


def log_gradients_to_tensorboard(model, epoch):
    """Log the L2 gradient norm of every parameter and the overall norm.

    Parameters without a gradient are skipped. Each norm is written under
    'Gradients/<parameter name>'; the square root of the summed squared
    norms is written under 'Gradients/total_norm'. `epoch` is the step.
    """
    squared_sum = 0
    for param_name, parameter in model.named_parameters():
        grad = parameter.grad
        if grad is None:
            continue
        grad_norm = grad.norm(2).item()
        squared_sum += grad_norm ** 2
        writer.add_scalar(f'Gradients/{param_name}', grad_norm, epoch)
    writer.add_scalar('Gradients/total_norm', squared_sum ** 0.5, epoch)

### Prepare Datasets

In [None]:
class ImageDataset(Dataset):
    """Dataset of every PNG/JPEG file found under a directory tree.

    Files are discovered recursively at construction time. Items are
    returned as ``(image, 0)``: the image converted to RGB (and passed
    through `transform` when one is given) plus a constant dummy label.
    """

    def __init__(self, root_dir, transform=None):
        # Directory containing all images (including subfolders)
        self.root_dir = root_dir
        self.transform = transform  # Transformations to apply to images
        self.image_files = []  # List to store all image file paths

        # Traverse through all subfolders
        for folder, _, names in os.walk(root_dir):
            for name in names:
                if name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_files.append(os.path.join(folder, name))

        # os.walk yields entries in a filesystem-dependent order; sorting
        # makes the index -> file mapping reproducible across machines.
        self.image_files.sort()

    def __len__(self):
        return len(self.image_files)  # Return the total number of images

    def __getitem__(self, idx):
        img_path = self.image_files[idx]  # Get image path

        # convert() returns a new, fully loaded image, so the file opened by
        # Image.open can be closed right away instead of lingering until
        # garbage collection (matters with many files / DataLoader workers).
        with Image.open(img_path) as source:
            image = source.convert('RGB')

        if self.transform:
            image = self.transform(image)  # Apply transformations if any

        return image, 0

In [None]:
# Preprocessing shared by both datasets: random flip, resize + centre crop with
# IMG_SIZE, tensor conversion and per-channel normalisation.
transform = transforms.Compose([
    # Randomly flip images horizontally
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    # Convert images to PyTorch tensors and scale to [0, 1]
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel, i.e. [0, 1] is mapped to [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# NOTE(review): both datasets are constructed regardless of which one is
# selected, so both data directories are accessed here — confirm that is intended.
butterly_dataset = ImageDataset(root_dir="data/butterfly_data", transform=transform)

animal_dataset = datasets.ImageFolder(root='data/Animals_data/animals/animals',  # Specify the root directory of the dataset
                               transform=transform)  # Apply the defined transformations to the dataset

# Pick the dataset named by the `dataset` setting; anything other than BUTTERFLY selects the animals.
ds = butterly_dataset if dataset == BUTTERFLY else animal_dataset


# Shuffled training loader with two worker processes.
dataloader = torch.utils.data.DataLoader(ds, BATCH_SIZE, shuffle=True, num_workers=2)

### Define Model

In [None]:
class SinusoidalPositionalEmbedding(nn.Module):
    """Parameter-free sinusoidal encoding of a batch of time steps.

    For `embed_dim // 2` frequencies the module returns the sines followed
    by the cosines of ``time_step * frequency``, concatenated on the last
    axis.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, time_step):
        n_freqs = self.embed_dim // 2
        positions = torch.arange(n_freqs, device=time_step.device)
        # Frequencies decay geometrically: exp(-log(SCALING_FACTOR) * i / n_freqs)
        freqs = torch.exp(-math.log(SCALING_FACTOR) * positions / n_freqs)
        # Outer product: one row per time step, one column per frequency
        angles = time_step.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)


class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with time conditioning and a skip path.

    The time embedding is projected to `out_channels` and added to the
    feature map after the first normalisation. The input is added back
    before the final ReLU, going through a 1x1 convolution when the channel
    count changes.
    """

    def __init__(self, in_channels, out_channels, time_embed_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Projects the time embedding to one value per output channel
        self.time_dense = nn.Sequential(
            nn.Linear(time_embed_dim, out_channels),
            nn.ReLU(),
        )

        # The skip path only needs a projection when the channel count changes
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1)

    def forward(self, x, time_embed):
        skip = self.residual_conv(x)

        # Stage 1: conv -> norm -> add per-channel time bias -> ReLU
        hidden = self.norm1(self.conv1(x))
        time_bias = self.time_dense(time_embed)[..., None, None]
        hidden = self.relu(hidden + time_bias)

        # Stage 2: conv -> norm, then merge with the skip path
        hidden = self.norm2(self.conv2(hidden))
        return self.relu(hidden + skip)


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Channels are split into `num_heads` groups of `head_dim` channels. Within
    each head every spatial position attends to every other position using
    scaled dot-product attention; the result is projected by a 1x1
    convolution and added to the input (residual connection).

    The einsum subscripts contract over the channel axis for the scores and
    over the key-position axis for the output. Contracting any other pair of
    axes lets the softmax weights sum out to 1, which makes the layer ignore
    its query/key projections entirely.

    NOTE: parameter names and shapes are the usual query/key/value/out 1x1
    convolutions, so existing state dicts still load, but weights trained
    with a different attention computation should be retrained.

    Args:
        in_channels: number of channels of the input (and output) feature map.
        num_heads: number of attention heads; must divide `in_channels`.
    """

    def __init__(self, in_channels, num_heads=4):
        super(MultiHeadAttention, self).__init__()
        assert in_channels % num_heads == 0, "in_channels must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = in_channels // num_heads

        self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        self.softmax = nn.Softmax(dim=-1)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, C, H, W = x.size()
        n_positions = H * W
        # All three projections: (B, num_heads, head_dim, H*W)
        query = self.query(x).view(B, self.num_heads, self.head_dim, n_positions)
        key = self.key(x).view(B, self.num_heads, self.head_dim, n_positions)
        value = self.value(x).view(B, self.num_heads, self.head_dim, n_positions)

        # scores[b, h, q, k] = <query at position q, key at position k>,
        # contracted over the channel axis d -> (B, num_heads, H*W, H*W)
        attn_scores = torch.einsum('bhdq,bhdk->bhqk', query, key) * self.scale
        # Normalise over key positions
        attn = self.softmax(attn_scores)

        # Weighted sum of values over key positions -> (B, num_heads, head_dim, H*W)
        out = torch.einsum('bhqk,bhdk->bhdq', attn, value)
        out = out.reshape(B, C, H, W)  # Merge heads back into channels
        return self.out(out) + x  # Residual connection


class UNet(nn.Module):
    """Time-conditioned U-Net built from residual blocks with attention.

    Four encoder stages (channel widths 1x, 2x, 4x, 8x `base_filters`) are
    each followed by 2x2 max pooling, leading to a 16x bottleneck. The
    decoder upsamples by a factor of two per stage, concatenates the matching
    encoder output and reduces the width again. Attention is applied after
    the deepest encoder stage, the bottleneck and the first decoder stage.
    A 1x1 convolution maps to `output_channels`.
    """

    def __init__(self, input_channels=3, output_channels=3, base_filters=64, embed_dim=32):
        super().__init__()
        w1 = base_filters
        w2, w4, w8, w16 = (base_filters * mult for mult in (2, 4, 8, 16))

        self.time_embedding = SinusoidalPositionalEmbedding(embed_dim)

        # Encoder
        self.enc1 = ResidualBlock(input_channels, w1, embed_dim)
        self.enc2 = ResidualBlock(w1, w2, embed_dim)
        self.enc3 = ResidualBlock(w2, w4, embed_dim)
        self.enc4 = ResidualBlock(w4, w8, embed_dim)
        self.attn1 = MultiHeadAttention(w8)

        # Bottleneck
        self.bottleneck = ResidualBlock(w8, w16, embed_dim)
        self.attn2 = MultiHeadAttention(w16)

        # Decoder: every input width is (upsampled features + skip features)
        self.dec4 = ResidualBlock(w16, w8, embed_dim)
        self.attn3 = MultiHeadAttention(w8)
        self.dec3 = ResidualBlock(w4 + w8, w4, embed_dim)
        self.dec2 = ResidualBlock(w2 + w4, w2, embed_dim)
        self.dec1 = ResidualBlock(w1 + w2, w1, embed_dim)

        # Output projection
        self.final = nn.Conv2d(w1, output_channels, kernel_size=1)

        # Down-sampling, and the learned up-sampling used after the bottleneck
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.ConvTranspose2d(w16, w8, kernel_size=2, stride=2)

    @staticmethod
    def _grow_and_concat(coarse, skip):
        # Double the spatial size of `coarse`, then stack it with `skip` on the channel axis
        return torch.cat([F.interpolate(coarse, scale_factor=2), skip], dim=1)

    def forward(self, x, time_step):
        # One embedding vector per sample, shared by every residual block
        t_emb = self.time_embedding(time_step)

        # Encoder path
        enc1 = self.enc1(x, t_emb)
        enc2 = self.enc2(self.pool(enc1), t_emb)
        enc3 = self.enc3(self.pool(enc2), t_emb)
        enc4 = self.enc4(self.pool(enc3), t_emb)
        enc4 = self.attn1(enc4)

        # Bottleneck
        middle = self.bottleneck(self.pool(enc4), t_emb)
        middle = self.attn2(middle)

        # Decoder path
        dec4_input = torch.cat([self.upsample(middle), enc4], dim=1)
        dec4 = self.attn3(self.dec4(dec4_input, t_emb))
        dec3 = self.dec3(self._grow_and_concat(dec4, enc3), t_emb)
        dec2 = self.dec2(self._grow_and_concat(dec3, enc2), t_emb)
        dec1 = self.dec1(self._grow_and_concat(dec2, enc1), t_emb)

        return self.final(dec1)

In [None]:
# Print a summary of a freshly constructed UNet, traced with one dataset image
# and a dummy time step of 1.
images, labels = next(iter(DataLoader(ds, batch_size=1)))
t = torch.tensor([1])
summary(UNet(), input_data=[images, t])

### Diffusion

In [None]:
# Linear beta schedule: NOISE_STEPS values from 1e-4 to 0.02, stored on `device`.
betas = torch.linspace(1e-4, 0.02, NOISE_STEPS).to(device)
alphas = 1 - betas
# Running product of the alphas up to and including each step.
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Coefficients of the clean image and of the noise used by forward_diffusion and ddim_inversion.
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

In [None]:
def get_index_from_list(vals, t, x_shape):
    """Pick ``vals[t]`` per sample and shape it for broadcasting.

    The gathered values are returned with shape ``(len(t), 1, ..., 1)``,
    having as many axes as `x_shape`.
    """
    trailing_ones = [1] * (len(x_shape) - 1)
    picked = vals.gather(-1, t)
    return picked.reshape(t.shape[0], *trailing_ones)


def forward_diffusion(x0: torch.Tensor, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Noise `x0` to the levels given by `t`.

    Returns the pair ``(xt, noise)`` where `noise` is freshly drawn standard
    normal noise with the shape of `x0` and
    ``xt = sqrt_alphas_cumprod[t] * x0 + sqrt_one_minus_alphas_cumprod[t] * noise``.
    """
    noise = torch.randn_like(x0)
    # Per-sample coefficients, reshaped to broadcast over the image axes
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x0.shape)
    noise_scale = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x0.shape)
    return signal_scale * x0 + noise_scale * noise, noise


def ddim_sample(
    model: UNet,
    x_T: torch.Tensor,
    sampling_steps: int,
):
    """Deterministically denoise `x_T` with the DDIM update rule.

    `sampling_steps` time indices are spread evenly over
    ``[0, NOISE_STEPS - 1]`` and visited from the largest to the smallest.
    At each index the model's noise prediction is used to estimate the clean
    image, which is then re-noised to the next (lower) index. For the last
    visited index the target index is 0. The model is put in eval mode and
    the result is clamped to ``[-1, 1]``.
    """
    model.eval()
    # Descending integer time indices (other spacings are possible)
    timesteps = torch.linspace(0, NOISE_STEPS - 1,
                               sampling_steps).flip(0).int().to(device)
    final_position = len(timesteps) - 1

    sample = x_T
    for position, current in enumerate(timesteps):
        target = timesteps[position + 1] if position != final_position else 0

        abar_current = alphas_cumprod[current]
        abar_target = alphas_cumprod[target]

        # Noise prediction for the whole batch at the current index
        with torch.no_grad():
            eps = model(sample, current.repeat(sample.shape[0]))

        # Clean-image estimate implied by the predicted noise
        x0_estimate = (sample - torch.sqrt(1 - abar_current) * eps) \
            / torch.sqrt(abar_current)

        # DDIM step: re-noise the estimate to the target level with the same noise
        sample = torch.sqrt(abar_target) * x0_estimate + \
            torch.sqrt(1 - abar_target) * eps

    return torch.clamp(sample, -1, 1)

### Train DDIM

In [None]:
def generate_and_save_images(model: UNet, epoch_index, img_dir, display=False):
    """Sample a 5x5 grid of images and either show it or write it to disk.

    With ``display=True`` the grid is shown with matplotlib; otherwise it is
    saved as ``<img_dir>/generated_<epoch_index + 1>.png``. The model is left
    in eval mode.
    """
    grid_side = 5
    # No gradients are needed for sampling
    with torch.no_grad():
        model.eval()
        # One noise tensor per grid cell (grid_side * grid_side images)
        start_noise = torch.randn(
            grid_side * grid_side, 3, IMG_SIZE, IMG_SIZE).to(device)
        samples = ddim_sample(model, start_noise,
                              sampling_steps=SAMPLING_STEPS)

    if not display:
        # Write the grid to the image directory
        out_path = f'{img_dir}/generated_{epoch_index + 1}.png'
        grid = make_grid(samples, nrow=grid_side, padding=2, normalize=True)
        save_image(grid, out_path)
        return

    # Show the grid inline
    rescaled = (samples + 1) / 2.0
    grid = make_grid(rescaled, nrow=grid_side, padding=2, normalize=True)
    plt.figure(figsize=(10, 10))
    # Channels-last numpy array for imshow
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title(f'Generated Images - Epoch {epoch_index + 1}')
    plt.show()

In [None]:
def train_ddim(
    num_epochs: int
):
    """Train a UNet to predict the noise added by `forward_diffusion`.

    Each batch gets uniformly sampled time indices, is noised accordingly and
    the model is optimised with the MSE between predicted and true noise.
    After every epoch the learning-rate scheduler advances, gradient norms
    and the mean batch loss of the epoch are logged to TensorBoard, and every
    20th epoch a grid of samples is saved to IMG_DIR.

    Args:
        num_epochs: number of passes over `dataloader`.

    Returns:
        The trained model, wrapped in ``torch.nn.DataParallel``.

    Raises:
        RuntimeError: if `dataloader` yields no batches.
    """
    model = torch.nn.DataParallel(UNet().to(device))
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    # StepLR divides by step_size, so it must be at least 1 even for
    # num_epochs < 4.
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=max(1, num_epochs // 4), gamma=STEP_GAMMA)

    for epoch in trange(num_epochs):
        model.train()
        # Accumulate on-device to avoid a host sync per batch
        loss_sum = torch.zeros((), device=device)
        num_batches = 0
        for images, _ in dataloader:
            images = images.to(device)
            optimizer.zero_grad()

            # Sample t uniformly, directly on the training device
            t = torch.randint(
                0, NOISE_STEPS, (images.shape[0],), device=device)

            # Get noisy image and noise
            x_t, noise = forward_diffusion(images, t)

            # Predict noise
            predicted_noise = model(x_t, t)

            # Calculate loss
            loss = torch.nn.functional.mse_loss(predicted_noise, noise)

            loss.backward()
            optimizer.step()

            loss_sum += loss.detach()
            num_batches += 1

        if num_batches == 0:
            raise RuntimeError("dataloader yielded no batches; nothing to train on")

        # Mean of the per-batch losses: far less noisy than the last batch alone
        epoch_loss = (loss_sum / num_batches).item()

        scheduler.step()
        # Gradients logged here are those of the last batch of the epoch
        log_gradients_to_tensorboard(model, epoch)

        if epoch % 20 == 0:
            print(f"Epoch [{epoch}/{num_epochs}]  loss: {epoch_loss:.4f}")
            generate_and_save_images(model, epoch, img_dir=IMG_DIR)

        log_losses_to_tensorboard(epoch, epoch_loss)

    print("Training completed.")
    return model


# Checkpoint location for the selected dataset
ddim_model_path = f'DDIM/models/model_{dataset}.pth'

if TRAIN_DDIM:
    # Train from scratch and persist the (DataParallel-wrapped) weights
    model = train_ddim(num_epochs=NUM_EPOCHS)
    torch.save(model.state_dict(), ddim_model_path)
    print(f"Models saved in {ddim_model_path}")
else:
    # Wrap in DataParallel before loading so the parameter names match the
    # ones that were saved from the wrapped model.
    model = torch.nn.DataParallel(UNet().to(device))
    # map_location lets a checkpoint written on a GPU be restored on whatever
    # `device` is available now (e.g. a CPU-only machine).
    model.load_state_dict(
        torch.load(ddim_model_path, weights_only=True, map_location=device))

In [None]:
# Display (instead of saving) a grid of samples from the trained/loaded model.
generate_and_save_images(model, NUM_EPOCHS, img_dir=IMG_DIR, display=True)

### DDIM Inversion

In [None]:
def ddim_inversion(
    unet: nn.Module,
    x0: torch.Tensor,
    num_steps: int,
) -> torch.Tensor:
    """Map images to latents by running the DDIM update in the noising direction.

    Starting from `x0`, treated as the sample at time index 0, each step
    predicts the noise at index ``t``, estimates the clean image and re-noises
    it to index ``t + 1``. The result is the sample at index
    ``num_steps - 1``.

    Args:
        unet: noise-prediction model called as ``unet(x_t, t)``; it is put in
            eval mode, as in `ddim_sample`.
        x0: batch of images.
        num_steps: number of time indices to traverse. Use NOISE_STEPS to
            obtain latents at the level `ddim_sample` starts from; it must
            not exceed the length of the noise schedule.

    Returns:
        The inverted latents, same shape as `x0`.
    """
    # Inference behaviour for normalisation layers; without this a model left
    # in train mode would normalise with the statistics of this tiny batch.
    unet.eval()

    # Initialize x_t with x0
    x_t = x0

    # Moving from index t to t + 1 needs num_steps - 1 updates; no forward
    # pass is spent on the final index, whose prediction would be discarded.
    for t in range(num_steps - 1):
        t_tensor = torch.full(
            (x0.shape[0],), t, device=x0.device, dtype=torch.long)

        # Predict noise
        with torch.no_grad():
            predicted_noise = unet(x_t, t_tensor)

        # Current alpha and sigma
        alpha_t = sqrt_alphas_cumprod[t]
        sigma_t = sqrt_one_minus_alphas_cumprod[t]

        # Next timestep's alpha and sigma
        alpha_next = sqrt_alphas_cumprod[t + 1]
        sigma_next = sqrt_one_minus_alphas_cumprod[t + 1]

        # Predict x0
        pred_x0 = (x_t - sigma_t * predicted_noise) / alpha_t

        # DDIM update step
        x_t = alpha_next * pred_x0 + sigma_next * predicted_noise

    return x_t

In [None]:
def sample_images(ds, num_samples=2):
    """Return the first batch of `num_samples` items from a shuffled pass over `ds`, moved to `device`."""
    shuffled_batches = iter(
        DataLoader(ds, batch_size=num_samples, shuffle=True))
    picked_images, _ = next(shuffled_batches)
    return picked_images.to(device)

# Linear blending between two latents
def interpolate_latents(latent1, latent2, num_steps=10):
    """Stack `num_steps` linear blends going from `latent1` to `latent2`.

    The blend weights are evenly spaced in ``[0, 1]`` (both ends included),
    so the first entry equals `latent1` and the last equals `latent2`.
    """
    blends = []
    for weight in np.linspace(0, 1, num_steps):
        blends.append(latent1 * (1 - weight) + latent2 * weight)
    return torch.stack(blends)

In [None]:
# Draw two real images, invert them to latents, blend the two latents linearly
# and decode every blend back to an image.
real_images = sample_images(ds)
latents = ddim_inversion(model, real_images, num_steps=NOISE_STEPS)
latent1, latent2 = latents[0], latents[1]
# `latents` is rebound to the stack of interpolated latents (default num_steps=10)
latents = interpolate_latents(latent1, latent2)
generated_images = ddim_sample(model, latents, sampling_steps=SAMPLING_STEPS)


def plot_images(real_images, generated_images):
    """Show one row: first real image, the generated images, second real image."""
    # Frame the generated sequence with the two real images
    strip = torch.cat(
        (real_images[:1], generated_images, real_images[1:2]), dim=0)

    # Channels-last numpy frames for imshow
    frames = strip.cpu().permute(0, 2, 3, 1).numpy()
    fig, axes = plt.subplots(1, len(frames), figsize=(15, 2))

    for axis, frame in zip(axes, frames):
        axis.imshow((frame + 1) * 0.5)  # Assuming normalization [-1, 1]
        axis.axis('off')

    plt.suptitle("DDIM Inversion Interpolation")
    plt.show()


# Show the two real images with the decoded interpolations between them
plot_images(real_images, generated_images)