# Data Preparation for Diffusion Models

This notebook covers the following steps:
1. Importing necessary libraries
2. Defining data preparation functions
3. Preparing and cleaning CelebA and Flowers102 datasets
4. Splitting data into training, validation, and test sets
5. Saving prepared datasets

In [None]:
import os
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CelebA, Flowers102
from tqdm import tqdm

In [None]:
def prepare_data(dataset_path, is_flowers=False, image_size=64, batch_size=32,
                 *, num_workers=4, augment=True):
    """Build a torchvision dataset and a shuffling DataLoader over it.

    Each image is resized to ``image_size`` x ``image_size``, optionally
    flipped horizontally at random, converted to a tensor, and normalized per
    channel with mean 0.5 and std 0.5 (i.e. ``(x - 0.5) / 0.5``).

    No ``split`` argument is passed to either dataset class, so each one
    loads whatever split torchvision uses by default.

    Args:
        dataset_path: Root directory handed to the dataset class. The data is
            downloaded there when missing (``download=True``).
        is_flowers: If true, load ``Flowers102``; otherwise load ``CelebA``.
        image_size: Side length, in pixels, of the square output images.
        batch_size: Number of samples per DataLoader batch.
        num_workers: Number of DataLoader worker processes. Defaults to 4,
            the value that used to be hard-coded; pass 0 to load in the main
            process (useful where multiprocessing is unavailable).
        augment: If true (the default, matching the previous behavior),
            ``RandomHorizontalFlip`` is part of the transform. Pass ``False``
            when the loader output is going to be materialized once and
            saved, otherwise each stored image is permanently flipped or not
            at random.

    Returns:
        A ``(dataloader, dataset)`` tuple.
    """
    steps = [transforms.Resize((image_size, image_size))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    transform = transforms.Compose(steps)

    dataset_cls = Flowers102 if is_flowers else CelebA
    dataset = dataset_cls(root=dataset_path, download=True, transform=transform)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    return dataloader, dataset

def clean_data(dataloader):
    """Collect every sample whose image contains only finite values.

    The whole cleaned dataset is accumulated in memory, so this is only
    suitable when the result fits in RAM.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches. Both are
            expected to be tensors indexed by sample along dim 0, and
            ``images`` must have at least two dimensions.

    Returns:
        A ``(images, labels)`` tuple of tensors holding the kept samples,
        concatenated along dim 0 in iteration order. If ``dataloader`` yields
        no batches, two empty tensors are returned instead of letting
        ``torch.cat`` fail on an empty list.
    """
    kept_images = []
    kept_labels = []

    for images, labels in tqdm(dataloader, desc="Cleaning data"):
        # One boolean per sample: True when no element is NaN or +/-inf.
        # Flattening everything after the batch dim avoids relying on
        # tuple-valued ``dim`` support in ``Tensor.all`` and works for any
        # image rank, not just 4-D batches.
        mask = torch.isfinite(images).flatten(start_dim=1).all(dim=1)

        # Apply the same mask to both tensors so pairs stay aligned.
        kept_images.append(images[mask])
        kept_labels.append(labels[mask])

    if not kept_images:
        return torch.empty(0), torch.empty(0, dtype=torch.long)

    return torch.cat(kept_images, dim=0), torch.cat(kept_labels, dim=0)

def split_data(images, labels, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
               generator=None):
    """Randomly partition paired tensors into train/validation/test subsets.

    Args:
        images: Tensor of samples, indexed by sample along dim 0.
        labels: Tensor of labels with the same size along dim 0 as ``images``.
        train_ratio: Fraction of samples assigned to the training subset.
        val_ratio: Fraction of samples assigned to the validation subset.
        test_ratio: Fraction of samples assigned to the test subset.
        generator: Optional ``torch.Generator`` used for the shuffle. Pass a
            seeded generator to make the split reproducible; when omitted the
            global torch RNG is used, as before.

    Returns:
        A ``(train, val, test)`` tuple of ``torch.utils.data.Subset`` objects
        over a ``TensorDataset`` of ``(images, labels)``.

    Raises:
        ValueError: If any ratio is negative or the ratios do not sum to 1
            (within 1e-5).
    """
    ratios = (train_ratio, val_ratio, test_ratio)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    # ``not (... < tol)`` also rejects NaN, which a ``>=`` comparison would not.
    if not abs(sum(ratios) - 1.0) < 1e-5:
        raise ValueError("Ratios must sum to 1")
    if any(ratio < 0 for ratio in ratios):
        raise ValueError("Ratios must be non-negative")

    dataset = torch.utils.data.TensorDataset(images, labels)
    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)
    # The test subset absorbs whatever the two truncations above left over,
    # so the three sizes always add up to ``total_size``.
    test_size = total_size - train_size - val_size

    lengths = [train_size, val_size, test_size]
    if generator is None:
        train_dataset, val_dataset, test_dataset = random_split(dataset, lengths)
    else:
        train_dataset, val_dataset, test_dataset = random_split(
            dataset, lengths, generator=generator
        )

    return train_dataset, val_dataset, test_dataset

## Prepare CelebA Dataset

In [None]:
# CelebA pipeline: build the loader, drop samples with non-finite pixels,
# then split what remains into train/validation/test subsets.
celeba_loader, celeba_dataset = prepare_data(dataset_path="./data", is_flowers=False)
cleaned_celeba_images, cleaned_celeba_labels = clean_data(dataloader=celeba_loader)

celeba_train, celeba_val, celeba_test = split_data(
    images=cleaned_celeba_images,
    labels=cleaned_celeba_labels,
)

# Report the size of each subset.
print(f"CelebA - Train: {len(celeba_train)}, Validation: {len(celeba_val)}, Test: {len(celeba_test)}")

## Prepare Flowers102 Dataset

In [None]:
# Flowers102 pipeline: build the loader, drop samples with non-finite pixels,
# then split what remains into train/validation/test subsets.
flowers_loader, flowers_dataset = prepare_data(dataset_path="./data", is_flowers=True)
cleaned_flowers_images, cleaned_flowers_labels = clean_data(dataloader=flowers_loader)

flowers_train, flowers_val, flowers_test = split_data(
    images=cleaned_flowers_images,
    labels=cleaned_flowers_labels,
)

# Report the size of each subset.
print(f"Flowers102 - Train: {len(flowers_train)}, Validation: {len(flowers_val)}, Test: {len(flowers_test)}")

## Save Prepared Datasets

In [None]:
def save_dataset(dataset, filename):
    """Stack every ``(image, label)`` pair of ``dataset`` and save it to disk.

    Args:
        dataset: Iterable of ``(image, label)`` pairs. Images must be tensors
            of identical shape. Labels may be tensors or plain Python numbers;
            numbers are converted with ``torch.as_tensor`` so they can be
            stacked.
        filename: Destination passed to ``torch.save``.

    The saved object is a dict with two keys: ``'images'`` (the images stacked
    along a new dim 0) and ``'labels'`` (the labels stacked the same way). An
    empty ``dataset`` is saved as two empty tensors rather than failing in
    ``torch.stack``; this can happen for small datasets where a split size
    truncates to zero.
    """
    images, labels = [], []
    for img, lbl in dataset:
        images.append(img)
        # No-op for tensors; wraps bare ints/floats so torch.stack accepts them.
        labels.append(torch.as_tensor(lbl))

    if images:
        payload = {
            'images': torch.stack(images),
            'labels': torch.stack(labels),
        }
    else:
        payload = {
            'images': torch.empty(0),
            'labels': torch.empty(0, dtype=torch.long),
        }

    torch.save(payload, filename)

# Write each CelebA split to its own file in the working directory.
for split, filename in (
    (celeba_train, "celeba_train.pt"),
    (celeba_val, "celeba_val.pt"),
    (celeba_test, "celeba_test.pt"),
):
    save_dataset(split, filename)

# Write each Flowers102 split to its own file in the working directory.
for split, filename in (
    (flowers_train, "flowers_train.pt"),
    (flowers_val, "flowers_val.pt"),
    (flowers_test, "flowers_test.pt"),
):
    save_dataset(split, filename)

print("All datasets have been saved.")