# Imports and Consts

In [None]:
# Notebook-level configuration.
IS_IN_COLAB: bool = False  # environment flag; not read anywhere in the visible cells
DATASET_PATH: str = "../Dataset/bedroom/"  # root directory intended for LSUNBedroomDataset

In [None]:
import os
from enum import Enum

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from tqdm.auto import tqdm

# Functions

In [None]:
class LSUNBedroomDataset(Dataset):
    """Dataset of every image file found (recursively) under a root directory.

    Each item is the image loaded as RGB via PIL, passed through ``transform``
    when one was supplied.
    """

    def __init__(self, root_dir, transform=None, extensions=(".png", ".jpg", ".jpeg")):
        """Collect image paths under *root_dir*.

        Args:
            root_dir: Directory to walk; subdirectories are included. A missing
                directory yields an empty dataset (``os.walk`` reports nothing).
            transform: Optional callable applied to each loaded RGB image.
            extensions: File suffixes to accept, compared case-insensitively
                (so ``photo.JPG`` is picked up as well as ``photo.jpg``).
        """
        self.root_dir = root_dir
        self.transform = transform
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so the index -> file mapping is reproducible: os.walk yields
        # entries in filesystem-dependent order.
        self.image_paths = sorted(
            os.path.join(subdir, name)
            for subdir, _, files in os.walk(self.root_dir)
            for name in files
            if name.lower().endswith(suffixes)
        )

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image *idx* as RGB and apply the transform, if any."""
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file; convert() loads the
        # pixel data into a new image, so the result stays valid afterwards.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image


class ModelType(Enum):
    """Denoising network architectures that ``getModel`` can build."""

    UNet = 1


class NoiseSchedulerType(Enum):
    """Noise schedulers that ``getScheduler`` can build."""

    DDPMSched = 1


class OptimizerType(Enum):
    """Optimizers that ``getOptimizer`` can build."""

    Adam = 1


def namestr(obj, namespace=None) -> list[str]:
    """Return every name in *namespace* that is bound to *obj* (identity match).

    Args:
        obj: Object to look up.
        namespace: Mapping of names to values, e.g. ``globals()`` or
            ``locals()``. When omitted, this module's global namespace is
            searched.

    Returns:
        Matching names in the namespace's iteration order; empty if *obj* is
        not bound to any name there.
    """
    if namespace is None:
        namespace = globals()
    return [name for name, value in namespace.items() if value is obj]


def getOptimizer(optimzierType: OptimizerType, model, params):
    match optimzierType:
        case OptimizerType.Adam:
            return torch.optim.Adam(model.parameters(), **params)


def getScheduler(schedulerType: NoiseSchedulerType, params):
    match schedulerType:
        case NoiseSchedulerType.DDPMSched:
            return DDPMScheduler(**params)


def getModel(modelType: ModelType, device):
    match modelType:
        case ModelType.UNet:
            model = UNet2DModel(
                sample_size=32,  # the target image resolution
                in_channels=3,  # the number of input channels, 3 for RGB images
                out_channels=3,  # the number of output channels
                layers_per_block=1,
                block_out_channels=(32, 64, 128),
                down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
                up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"),
            )
            model.to(device)
            return model


def trainModel(model, num_epochs, dataloader, optimizer, device, noise_scheduler):
    """Train *model* to predict the noise that *noise_scheduler* adds to images.

    For each batch, one random timestep is drawn per image, Gaussian noise is
    mixed in with ``noise_scheduler.add_noise``, and the model is updated on
    the MSE between its predicted noise and the true noise.

    Args:
        model: Network called as ``model(noisy_images, timesteps).sample``.
        num_epochs: Number of passes over *dataloader*.
        dataloader: Iterable of image tensors (batch dimension first).
        optimizer: Optimizer over *model*'s parameters.
        device: Device each batch is moved to.
        noise_scheduler: Scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.

    Returns:
        list[float]: Mean training loss of each epoch (``nan`` for an epoch in
        which *dataloader* produced no batches).
    """
    # Read through `.config`; direct attribute access is deprecated in diffusers.
    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    epoch_losses = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0

        for batch in tqdm(dataloader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            optimizer.zero_grad()

            batch = batch.to(device)
            timesteps = torch.randint(
                0, num_train_timesteps, (batch.size(0),), device=device, dtype=torch.long
            )
            noise = torch.randn_like(batch)
            noisy_images = noise_scheduler.add_noise(batch, noise, timesteps)

            noise_pred = model(noisy_images, timesteps).sample

            loss = F.mse_loss(noise_pred, noise)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Report the epoch average rather than the last batch's (noisy) loss;
        # also avoids referencing an unbound `loss` when no batches were seen.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1} completed. Mean loss: {mean_loss:.6f}")

    return epoch_losses


def generateImages(
    model,
    device,
    noise_scheduler,
    num_epochs,
    output_dir,
    num_images,
    *,
    batch_size=16,
    image_size=32,
    channels=3,
):
    """Sample images with the full reverse-diffusion loop and save them as PNGs.

    Args:
        model: Noise-prediction network called as ``model(x, timesteps).sample``.
        device: Device on which sampling runs.
        noise_scheduler: Scheduler whose ``step(...)`` performs one reverse
            step; all ``config.num_train_timesteps`` steps are run.
        num_epochs: Training epoch count, used only in the default folder name.
        output_dir: Folder to write images into. If falsy (``None``/``""``),
            ``generated_images_{num_epochs}_epochs`` is used.
        num_images: Number of images to generate.
        batch_size: How many images are denoised together per reverse pass.
        image_size: Height and width of the sampled images; should match the
            resolution the model was trained on.
        channels: Channel count of the sampled images.

    Returns:
        The directory the images were written to.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Honour the caller's folder; fall back to the epoch-based default name.
    if not output_dir:
        output_dir = f"generated_images_{num_epochs}_epochs"
    os.makedirs(output_dir, exist_ok=True)

    # File-name tag for the model, e.g. "UNet2DModel".
    model_name = type(model).__name__
    num_train_timesteps = noise_scheduler.config.num_train_timesteps

    model.eval()
    saved = 0
    with torch.no_grad():
        for start in tqdm(range(0, num_images, batch_size), desc="Generating Images"):
            count = min(batch_size, num_images - start)
            # Start from pure Gaussian noise and denoise t = T-1 down to 0,
            # processing the whole batch with one model call per step.
            sample = torch.randn(count, channels, image_size, image_size, device=device)
            for t in reversed(range(num_train_timesteps)):
                timesteps = torch.full((count,), t, device=device, dtype=torch.long)
                noise_pred = model(sample, timesteps).sample
                sample = noise_scheduler.step(noise_pred, t, sample).prev_sample

            # Save each batch as soon as it is done instead of holding all images.
            for image in sample.cpu():
                saved += 1
                save_image(
                    image,
                    os.path.join(output_dir, f"generated_image_{model_name}_{saved}.png"),
                )

    print(f"{num_images} images generated and saved in {output_dir}")
    return output_dir

# Experiments

## Prepare the dataset

In [None]:
# Define transforms. The resize target must equal the UNet's sample_size (32 in
# getModel) and the 32x32 noise generateImages samples from; otherwise the model
# is trained at one resolution and sampled at another.
IMAGE_SIZE = 32
transform = transforms.Compose(
    [transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor()]
)

# Create dataset from the configured DATASET_PATH constant.
dataset = LSUNBedroomDataset(root_dir=DATASET_PATH, transform=transform)
if len(dataset) == 0:
    # A shuffled DataLoader over an empty dataset fails with an obscure sampler
    # error; fail early with a message that names the folder instead.
    raise FileNotFoundError(f"No .png/.jpg/.jpeg images found under {DATASET_PATH!r}")
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

## UNet2D

In [None]:
# Parameters

In [None]:
# Train

In [1]:
# Evaluate and Generate images