# Unet Pretraining

## Important note:
**Multithreading does not work in Jupyter notebooks, so it's disabled here. This file is meant as an overview of the training process; it works, but since there's no concurrency it takes approximately 4 times as long to train.**

This notebook can be used to follow the pretraining steps of the base U-Net for our final model.

## Dependencies

In [None]:
# Uncomment the line below if you haven't installed the required libraries
#!pip install -q wandb torch torchvision "numpy<2.0.0" pillow

## Imports

In [None]:
import gc
import os
import time
from dataclasses import dataclass, asdict
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import wandb
from PIL import Image
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader

## Utilities

This cell contains utilities for converting from RGB to LAB colorspace using torch to speed up operations at training and inference time. The code and conversions were taken from skimage and modified to work with torch.

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@torch.jit.script
def rgb_to_lab(rgb_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert a channels-last RGB image to normalised CIELAB planes.

    Args:
        rgb_image: Tensor whose last dimension holds the three RGB channels.
            If its maximum exceeds 1.0 the values are treated as 0-255 and
            divided by 255; otherwise they are used as-is.

    Returns:
        ``(l, ab)``. ``l`` is the lightness with a new leading dimension of
        size 1, rescaled as ``L / 50 - 1``. ``ab`` stacks the a and b planes
        along a new leading dimension of size 2, rescaled as ``ab / 110``.
    """
    with torch.no_grad():
        RGB_TO_XYZ = torch.tensor(
            [
                [0.412453, 0.357580, 0.180423],
                [0.212671, 0.715160, 0.072169],
                [0.019334, 0.119193, 0.950227],
            ],
            dtype=torch.float32,
            device=rgb_image.device,
        )
        # Reference white used to normalise XYZ before the LAB transfer function.
        XYZ_REF = torch.tensor(
            [0.95047, 1.0, 1.08883], dtype=torch.float32, device=rgb_image.device
        )

        # Heuristic range detection: anything above 1.0 is taken to be 0-255.
        # NOTE(review): a 0-255 image whose brightest value is <= 1 is not
        # rescaled by this check.
        if rgb_image.max() > 1.0:
            rgb_image = rgb_image / 255.0

        # Undo the sRGB gamma curve to obtain linear RGB.
        mask = rgb_image > 0.04045
        rgb_linear = torch.where(
            mask, torch.pow(((rgb_image + 0.055) / 1.055), 2.4), rgb_image / 12.92
        )

        # Convert linear RGB to XYZ
        xyz = torch.matmul(rgb_linear, RGB_TO_XYZ.t())

        # Normalize XYZ values by the reference white
        xyz_scaled = xyz / XYZ_REF

        # XYZ to LAB conversion
        epsilon = 0.008856
        kappa = 903.3

        # The threshold is tested on the white-point-normalised values, the
        # same quantity both branches are evaluated on. Testing the raw `xyz`
        # would pick the wrong branch for X and Z values near the threshold.
        f = torch.where(
            xyz_scaled > epsilon,
            xyz_scaled.pow(1 / 3),
            (kappa * xyz_scaled + 16) / 116,
        )

        x, y, z = f[..., 0], f[..., 1], f[..., 2]
        l = (116 * y - 16).unsqueeze(0)
        a = (500 * (x - y)).unsqueeze(0)
        b = (200 * (y - z)).unsqueeze(0)
        ab = torch.cat([a, b], dim=0)

        # Rescale: L / 50 - 1 and ab / 110 (lab_to_rgb applies the inverse).
        l = (l / 50.0) - 1.0
        ab = ab / 110.0

        return l, ab


@torch.jit.script
def lab_to_rgb(l: torch.Tensor, ab: torch.Tensor) -> torch.Tensor:
    """Turn normalised LAB planes back into a channels-last RGB batch.

    Args:
        l: Lightness rescaled as ``L / 50 - 1``. A 3-D tensor is treated as a
            single image and gets a batch dimension prepended.
        ab: The a/b planes rescaled as ``ab / 110``, with a at channel index 0
            and b at index 1. A 3-D tensor is likewise promoted to a batch.

    Returns:
        RGB values clamped to ``[0, 1]`` with the channel dimension moved
        last, i.e. ``(batch, height, width, 3)``.
    """
    with torch.no_grad():
        # Promote single images to a batch of one.
        if l.dim() == 3:
            l = l.unsqueeze(0)
        if ab.dim() == 3:
            ab = ab.unsqueeze(0)

        # Reference white, shaped to broadcast over (batch, 3, H, W).
        white_point = torch.tensor(
            [0.95047, 1.0, 1.08883], dtype=torch.float32, device=l.device
        ).view(1, 3, 1, 1)
        xyz_to_rgb = torch.tensor(
            [
                [3.24048134, -1.53715152, -0.49853633],
                [-0.96925495, 1.87599, 0.04155593],
                [0.05564664, -0.20404134, 1.05731107],
            ],
            dtype=torch.float32,
            device=l.device,
        )

        # Undo the rescaling of L and ab.
        lightness = (l + 1.0) * 50.0
        chroma = ab * 110.0

        # Recover the per-axis f(.) values of the LAB definition.
        f_y = (lightness + 16) / 116
        f_x = chroma[:, 0:1] / 500 + f_y
        f_z = f_y - chroma[:, 1:2] / 200
        f_xyz = torch.cat([f_x, f_y, f_z], dim=1)

        # Invert f(.): cube above the threshold, linear segment below it.
        above = f_xyz > 0.2068966
        xyz = torch.where(above, f_xyz.pow(3), (f_xyz - 16 / 116) / 7.787)
        xyz = xyz * white_point

        # Apply the 3x3 colour matrix to every pixel via a batched matmul.
        n_images, _, rows, cols = xyz.shape
        flat_xyz = xyz.view(n_images, 3, -1)
        linear = torch.bmm(xyz_to_rgb.expand(n_images, -1, -1), flat_xyz)
        linear = linear.view(n_images, 3, rows, cols)

        # Re-apply the sRGB gamma curve.
        bright = linear > 0.0031308
        encoded = torch.where(
            bright, 1.055 * linear.pow(1 / 2.4) - 0.055, 12.92 * linear
        )

        return encoded.clamp(0, 1).permute(0, 2, 3, 1)


def count_parameters(model):
    """Return the number of scalar elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

## Dataset Loading

In [None]:
class ColorizationDataset(Dataset):
    """Dataset of ``.jpg`` files found recursively under a root directory.

    Each item is the ``(l_chan, ab_chan)`` pair returned by ``rgb_to_lab`` for
    one image.

    Args:
        root_dir: Directory that is walked recursively for files whose name
            ends with ``.jpg`` (case-sensitive).
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is converted to a float32 array.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.walk yields entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.image_paths = sorted(
            os.path.join(subdir, file)
            for subdir, _, files in os.walk(root_dir)
            for file in files
            if file.endswith(".jpg")
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # JPEGs may be greyscale or CMYK; force three RGB channels so the
        # colour conversion always sees a last dimension of size 3. The
        # context manager closes the file once the pixels have been read.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        image = torch.from_numpy(np.array(image).astype(np.float32))
        l_chan, ab_chan = rgb_to_lab(image)
        return l_chan, ab_chan


class FastColorizationDataLoader(DataLoader):
    """DataLoader that moves each batch to the GPU on a side CUDA stream.

    When CUDA is unavailable no stream is created and batches are yielded
    unchanged, so the loader can still be iterated on a CPU-only machine.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data each time the loader is iterated.
        num_workers: Accepted for interface compatibility but deliberately not
            forwarded: multiprocessing is not supported in Jupyter notebooks,
            so loading happens in the main process.
        pin_memory: Whether to pin host memory; only honoured when CUDA is
            available.
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers, pin_memory):
        use_cuda = torch.cuda.is_available()
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            # num_workers is intentionally left at the DataLoader default.
            pin_memory=bool(pin_memory) and use_cuda,
        )
        # A CUDA stream can only be created when a CUDA device exists.
        self.stream = torch.cuda.Stream() if use_cuda else None

    def __iter__(self):
        for data in super().__iter__():
            if self.stream is None:
                yield list(data)
                continue

            # Issue the host-to-device copies on the side stream, but leave the
            # stream context *before* yielding. A generator suspended inside
            # the `with` block would leave the side stream current for the
            # consumer's own work.
            with torch.cuda.stream(self.stream):
                batch = [item.cuda(non_blocking=True) for item in data]

            # Make the consumer's stream wait for the copies, and tell the
            # caching allocator that these tensors are used on that stream.
            consumer_stream = torch.cuda.current_stream()
            consumer_stream.wait_stream(self.stream)
            for item in batch:
                item.record_stream(consumer_stream)
            yield batch

## Model Settings

In [None]:
@dataclass
class ModelSettings:
    """Hyper-parameters and bookkeeping for a training or fine-tuning run.

    ``total_image_count`` and ``num_steps`` are plain attributes rather than
    dataclass fields, so ``dataclasses.asdict`` does not include them. Both
    stay ``None`` until ``set_total_image_count`` is called.
    """

    finetune: bool = False
    pretrained_model_path: str = ""
    root_dir: str = "img_data"
    finetune_learning_rate: float = 0.00035
    validation_image_count: int = 1024
    batch_size: int = 64
    learning_rate: float = 0.0007  # Peak LR
    # Minimum LR. None means "derive as learning_rate / 10" in __post_init__,
    # so the floor follows an overridden learning_rate instead of being frozen
    # to the class-level default.
    min_lr: Optional[float] = None
    weight_decay: float = 1e-5
    warmup_steps: int = 1000  # Linear warmup over n steps
    validation_steps: int = 1000  # Validate every n steps
    display_imgs: int = 12
    early_stopping_patience: int = 5
    loss_function: str = "L1Loss"
    optimizer: str = "AdamW"
    model_name: str = "Unet"

    def __post_init__(self):
        if self.min_lr is None:
            self.min_lr = self.learning_rate / 10
        self.total_image_count = None
        self.num_steps = None

    def set_total_image_count(self, count):
        """Record the dataset size and derive the number of training steps.

        ``num_steps`` is the number of whole batches in the images left after
        setting aside ``validation_image_count`` of them.
        """
        self.total_image_count = count
        self.num_steps = (self.total_image_count - self.validation_image_count) // self.batch_size

    def create_run_name(self):
        """Build a run name that encodes the main hyper-parameters."""
        prefix = "finetune_" if self.finetune else ""
        learning_rate = self.finetune_learning_rate if self.finetune else self.learning_rate
        return f"{prefix}{self.model_name}_lr{learning_rate}_bs{self.batch_size}_steps{self.num_steps}_loss{self.loss_function}_opt{self.optimizer}"

# Module-level run configuration and the full image dataset. The dataset size
# is stored on the settings so the number of training steps can be derived.
settings = ModelSettings()
dataset = ColorizationDataset(root_dir=settings.root_dir)
settings.set_total_image_count(len(dataset))
print(f"Total images: {settings.total_image_count}")

## Unet Model

This is the U-Net model used for training. It can be initialized with or without Criss Cross self-attention. Criss Cross attention is a modification of the standard transformer attention that only attends to the pixels in the same row and column as the pixel in question, which drastically reduces the memory requirements. After two iterations of attention the model has at least indirectly attended to every pixel.

In [None]:
class CrissCrossAttention(nn.Module):
    """Residual attention block applied along the two spatial axes.

    The input is projected to queries, keys and values with
    ``in_dim // 8`` channels each. For every column (and separately for every
    row) an affinity matrix of size ``channel_out x channel_out`` is formed
    from queries and keys, soft-maxed over its last dimension and applied to
    the values. The two results are summed, scaled by the learnable ``gamma``
    (initialised to zero), projected back to ``in_dim`` channels and added to
    the input.

    Args:
        in_dim: Number of input (and output) channels.
    """

    def __init__(self, in_dim):
        super().__init__()
        reduced = in_dim // 8
        self.channel_in = in_dim
        self.channel_out = reduced
        # One 1x1 convolution produces queries, keys and values together.
        self.qkv_conv = nn.Conv2d(in_dim, 3 * reduced, 1)

        self.gamma = nn.Parameter(torch.zeros(1))
        self.out_conv = nn.Conv2d(reduced, in_dim, 1)

    @staticmethod
    def _attend(q, k, v):
        """Softmax(q @ k^T) @ v for batches of (channels, length) matrices."""
        weights = F.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)
        return torch.bmm(weights, v)

    def forward(self, x):
        B, C, H, W = x.size()
        c = self.channel_out

        query, key, value = self.qkv_conv(x).chunk(3, dim=1)

        def by_column(t):
            # (B, c, H, W) -> (B*W, c, H): one matrix per image column.
            return t.permute(0, 3, 1, 2).contiguous().view(B * W, c, H)

        def by_row(t):
            # (B, c, H, W) -> (B*H, c, W): one matrix per image row.
            return t.permute(0, 2, 1, 3).contiguous().view(B * H, c, W)

        # The original code labels the per-column pass "horizontal" and the
        # per-row pass "vertical".
        out_h = self._attend(by_column(query), by_column(key), by_column(value))
        out_v = self._attend(by_row(query), by_row(key), by_row(value))

        # Restore the (B, c, H, W) layout for both passes.
        out_h = out_h.view(B, W, c, H).permute(0, 2, 3, 1)
        out_v = out_v.view(B, H, c, W).permute(0, 2, 1, 3)

        # gamma scales the attention term before the output projection, so the
        # projection's bias is still added when gamma is zero.
        projected = self.out_conv(self.gamma * (out_h + out_v))

        return projected + x


class UnetBlock(nn.Module):
    """One nesting level of the U-Net: downsample, inner block, upsample.

    Args:
        nf: Channels produced by this block's up-convolution (and, unless
            ``input_c`` is given, channels expected at its input).
        ni: Channels produced by this block's down-convolution.
        submodule: The next-deeper block; unused when ``innermost``.
        input_c: Input channel count, defaulting to ``nf``.
        dropout: Append ``Dropout(0.5)`` to the up path (middle blocks only).
        innermost: Build the bottleneck block, which has no submodule.
        outermost: Build the top block, which ends in ``Tanh`` and has no
            skip connection. Takes precedence over ``innermost``.
        use_attention: Allow a ``CrissCrossAttention`` on the block output;
            it is only added when the block is not outermost and ``nf <= 512``.
    """

    def __init__(
        self,
        nf,
        ni,
        submodule=None,
        input_c=None,
        dropout=False,
        innermost=False,
        outermost=False,
        use_attention=True,
    ):
        super().__init__()
        self.outermost = outermost
        self.use_attention = use_attention
        in_channels = nf if input_c is None else input_c

        # Both convolutions halve / double the spatial size (k=4, s=2, p=1).
        # The down-convolution is created before the up-convolution, matching
        # the order in which their weights are randomly initialised.
        down_conv = nn.Conv2d(
            in_channels, ni, kernel_size=4, stride=2, padding=1, bias=False
        )
        # Only the innermost block feeds the up-convolution without a skip
        # concatenation; every other block sees twice the channels. The
        # outermost up-convolution is the only one that keeps its bias.
        takes_skip_input = bool(outermost) or not innermost
        up_conv = nn.ConvTranspose2d(
            ni * 2 if takes_skip_input else ni,
            nf,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=bool(outermost),
        )

        if outermost:
            layers = [down_conv, submodule, nn.ReLU(True), up_conv, nn.Tanh()]
        elif innermost:
            layers = [
                nn.LeakyReLU(0.2, True),
                down_conv,
                nn.ReLU(True),
                up_conv,
                nn.BatchNorm2d(nf),
            ]
        else:
            layers = [
                nn.LeakyReLU(0.2, True),
                down_conv,
                nn.BatchNorm2d(ni),
                submodule,
                nn.ReLU(True),
                up_conv,
                nn.BatchNorm2d(nf),
            ]
            if dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)
        # Attention is skipped on the outermost block and on very wide blocks.
        self.use_attention_this_layer = not outermost and nf <= 512 and use_attention
        if self.use_attention_this_layer:
            self.attention = CrissCrossAttention(nf)

    def forward(self, x):
        output = self.model(x)
        if self.outermost:
            return output
        if self.use_attention_this_layer:
            output = self.attention(output)
        # Skip connection: concatenate the block input with its output.
        return torch.cat([x, output], 1)


class Unet(nn.Module):
    """U-Net assembled from nested ``UnetBlock`` modules, built inside-out.

    With the defaults the network has ``n_down`` blocks in total: one
    innermost block, ``n_down - 5`` middle blocks at ``num_filters * 8``
    channels, three blocks that each halve the channel width, and one
    outermost block.

    Args:
        input_c: Channels of the network input.
        output_c: Channels of the network output.
        n_down: Total number of nested blocks (see above).
        num_filters: Base channel width used by the outermost block.
        use_attention: Forwarded to every ``UnetBlock``.
        use_dropout: Enables dropout in the ``n_down - 5`` middle blocks.
    """

    def __init__(
        self,
        input_c=1,
        output_c=2,
        n_down=8,
        num_filters=64,
        use_attention=True,
        use_dropout=False,
    ):
        super().__init__()
        widest = num_filters * 8

        # Bottleneck.
        core = UnetBlock(
            nf=widest,
            ni=widest,
            innermost=True,
            use_attention=use_attention,
        )

        # Constant-width middle blocks (the only ones that may use dropout).
        for _ in range(n_down - 5):
            core = UnetBlock(
                nf=widest,
                ni=widest,
                submodule=core,
                dropout=use_dropout,
                use_attention=use_attention,
            )

        # Three blocks that halve the channel width on the way outwards.
        width = widest
        for _ in range(3):
            narrower = width // 2
            core = UnetBlock(
                nf=narrower,
                ni=width,
                submodule=core,
                use_attention=use_attention,
            )
            width = narrower

        # Outermost block maps input_c -> output_c.
        self.model = UnetBlock(
            nf=output_c,
            ni=width,
            input_c=input_c,
            submodule=core,
            outermost=True,
            use_attention=use_attention,
        )

    def forward(self, x):
        return self.model(x)

## Learning Rate

We used a custom learning rate schedule: a linear warmup followed by cosine decay.

In [None]:
def lr_lambda(current_step: int):
    """Learning-rate multiplier: linear warmup, then cosine decay to a floor.

    Reads ``warmup_steps``, ``num_steps``, ``learning_rate`` and ``min_lr``
    from the module-level ``settings``.

    Args:
        current_step: Scheduler step count, starting at 0.

    Returns:
        A float multiplier relative to ``settings.learning_rate``. It rises
        linearly from 0 to 1 over the warmup, then follows a half cosine down
        to ``min_lr / learning_rate`` at ``num_steps`` and stays there.
    """
    warmup_steps = settings.warmup_steps
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))

    # Guard against a zero-length (or negative) decay phase, which would
    # otherwise divide by zero when num_steps <= warmup_steps.
    decay_steps = max(1, settings.num_steps - warmup_steps)
    # Clamp so that steps past num_steps hold the floor instead of letting the
    # cosine climb back up.
    progress = min(1.0, (current_step - warmup_steps) / decay_steps)
    cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
    return float(
        (settings.min_lr + (settings.learning_rate - settings.min_lr) * cosine_decay)
        / settings.learning_rate
    )

## Validation function

If wandb is used this logs sample images. 
An example of the wandb log can be viewed [here](https://wandb.ai/danielpwarren/unet-colorizer/runs/9r7ls909)

In [None]:
def _log_validation_examples(l_chan, outputs, ab_chan, num_samples):
    """Log a grid of greyscale / predicted / target images to wandb."""
    l_chan_samples = l_chan[:num_samples]
    output_samples = outputs[:num_samples]
    target_samples = ab_chan[:num_samples]

    # Zero ab channels give the greyscale rendering of the lightness input.
    l_rgb_samples = lab_to_rgb(l_chan_samples, torch.zeros_like(output_samples))
    output_rgb_samples = lab_to_rgb(l_chan_samples, output_samples)
    target_rgb_samples = lab_to_rgb(l_chan_samples, target_samples)

    def as_row(samples):
        # Place the images of one kind side by side.
        return np.hstack([sample.detach().cpu().numpy() for sample in samples])

    stacked_images = np.vstack(
        (as_row(l_rgb_samples), as_row(output_rgb_samples), as_row(target_rgb_samples))
    )

    wandb.log(
        {
            "Examples": wandb.Image(
                stacked_images,
                caption="For each column: Top: Grayscale, Middle: Predicted, Bottom: True",
            )
        },
        commit=False,
    )


def validate_model(model, test_dataloader, criterion):
    """Evaluate ``model`` on the validation loader and log the result to wandb.

    The model is switched to eval mode for the duration of the call and then
    returned to whatever mode it was in before. Without that restore, a
    training loop that calls this function would keep running in eval mode.

    Args:
        model: Network mapping the lightness batch to predicted ab channels.
        test_dataloader: Iterable yielding ``(l_chan, ab_chan)`` batches.
        criterion: Loss callable taking ``(outputs, ab_chan)``.

    Returns:
        The mean of the per-batch losses.

    Raises:
        ZeroDivisionError: If the dataloader yields no batches.
    """
    torch.cuda.empty_cache()
    gc.collect()
    val_start_time = time.time()
    was_training = model.training
    model.eval()
    val_loss = 0.0
    num_batches = 0
    logged_images = 0

    try:
        with torch.no_grad():
            for l_chan, ab_chan in test_dataloader:
                with autocast():
                    outputs = model(l_chan)
                    loss = criterion(outputs, ab_chan)

                val_loss += loss.item()
                num_batches += 1

                if logged_images < settings.display_imgs:
                    num_samples = min(
                        settings.display_imgs - logged_images, l_chan.shape[0]
                    )
                    _log_validation_examples(l_chan, outputs, ab_chan, num_samples)
                    logged_images += num_samples
    finally:
        # Restore the caller's train/eval mode even if validation fails.
        model.train(was_training)

    # Average over the batches actually seen; dividing by
    # total_images / batch_size is only exact when every batch is full.
    avg_val_loss = val_loss / num_batches
    val_time = time.time() - val_start_time
    print(
        f"Average Validation Loss: {avg_val_loss:.4f}, Validation Time: {val_time:.4f}s"
    )
    wandb.log(
        {"Average Validation Loss": avg_val_loss, "Validation Time": val_time},
        commit=False,
    )
    return avg_val_loss

## Training function

Calls validation function every `settings.validation_steps` steps, keeps track of validation improvement, and saves the model with the best validation loss. If there are `settings.early_stopping_patience` validation steps without improvement the model will stop training.

In [None]:
def _save_checkpoint(path, step, model, optimizer, scheduler, loss):
    """Write a resumable checkpoint (model, optimizer, scheduler, last loss)."""
    torch.save(
        {
            "step": step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "loss": loss,
        },
        path,
    )


def train_model(
    model,
    train_dataloader,
    test_dataloader,
    criterion,
    optimizer,
    scheduler,
    num_steps,
    validation_steps,
    scaler=None,
):
    """Run mixed-precision training with periodic validation and early stopping.

    Every ``validation_steps`` steps the model is validated. An improved
    validation loss writes ``best_checkpoint_unet.pth``; otherwise a patience
    counter is incremented and training stops once it reaches
    ``settings.early_stopping_patience``. A rolling ``checkpoint_unet.pth`` is
    written after each validation that does not trigger early stopping, and
    the final weights are saved to ``model_unet_final.pth``.

    Args:
        model: Network to train.
        train_dataloader: Iterable of ``(l_chan, ab_chan)`` batches; it is
            restarted when exhausted.
        test_dataloader: Validation loader passed to ``validate_model``.
        criterion: Loss callable taking ``(outputs, ab_chan)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped once per training step.
        num_steps: Maximum number of optimisation steps.
        validation_steps: Validate every this many steps.
        scaler: Optional ``GradScaler``; a fresh one is created when omitted.
    """
    if scaler is None:
        scaler = GradScaler()

    torch.cuda.synchronize()
    start_time = time.time()
    model.train()
    running_loss = 0.0
    steps_done = 0
    prev_step_end = start_time
    best_val_loss = float("inf")
    patience_counter = 0

    train_iter = iter(train_dataloader)

    for step in range(num_steps):
        try:
            l_chan, ab_chan = next(train_iter)
        except StopIteration:
            # Epoch finished: start another pass over the training data.
            train_iter = iter(train_dataloader)
            l_chan, ab_chan = next(train_iter)

        l_chan, ab_chan = l_chan.to("cuda"), ab_chan.to("cuda")

        with autocast():
            outputs = model(l_chan)
            loss = criterion(outputs, ab_chan)

        scaler.scale(loss).backward()

        # Gradients are still multiplied by the loss scale at this point, so
        # they must be unscaled before clipping. Otherwise max_norm is compared
        # against scaled gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        step_end_time = time.time()
        step_time = step_end_time - prev_step_end
        prev_step_end = step_end_time

        loss_value = loss.item()
        running_loss += loss_value
        steps_done = step + 1

        total_time_spent = step_end_time - start_time
        avg_step_time = total_time_spent / steps_done
        etc_seconds = avg_step_time * (num_steps - steps_done)
        etc_hours = int(etc_seconds // 3600)
        etc_minutes = int((etc_seconds % 3600) // 60)
        # Steps remaining until (step + 1) is next a multiple of validation_steps.
        steps_to_validation = (-steps_done) % validation_steps
        time_to_next_checkpoint = avg_step_time * steps_to_validation

        print(
            f"Step [{step + 1}/{num_steps}], "
            f"Loss: {loss_value:.4f}, "
            f"Step Time: {step_time:.4f}s, "
            f"ETC: {etc_hours}h {etc_minutes}m, "
            f"Next Checkpoint: {time_to_next_checkpoint/60:.2f} minutes"
        )
        wandb.log(
            {
                "Training Loss": loss_value,
                "Step": step + 1,
                "Step Time": step_time,
                "Learning Rate": scheduler.get_last_lr()[0],
                "ETC (hours)": etc_seconds / 3600,
                "Time to Next Checkpoint (minutes)": time_to_next_checkpoint / 60,
            }
        )

        if steps_done % validation_steps == 0:
            val_loss = validate_model(model, test_dataloader, criterion)
            # Validation switches the model to eval mode; make sure training
            # resumes with BatchNorm/Dropout in training mode.
            model.train()

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                _save_checkpoint(
                    "best_checkpoint_unet.pth",
                    steps_done,
                    model,
                    optimizer,
                    scheduler,
                    loss,
                )
            else:
                patience_counter += 1
                if patience_counter >= settings.early_stopping_patience:
                    print(
                        f"Early stopping triggered after {patience_counter} validations without improvement."
                    )
                    break

            _save_checkpoint(
                "checkpoint_unet.pth", steps_done, model, optimizer, scheduler, loss
            )

    torch.cuda.synchronize()
    # Average over the steps that actually ran; early stopping can end the
    # loop before num_steps is reached.
    avg_train_loss = running_loss / max(1, steps_done)
    total_time = time.time() - start_time

    print(f"Average Training Loss: {avg_train_loss:.4f}, Total Time: {total_time:.4f}s")
    wandb.log({"Average Training Loss": avg_train_loss, "Total Time": total_time})

    validate_model(model, test_dataloader, criterion)

    torch.save(model.state_dict(), "model_unet_final.pth")

## Begin training

All that's left now is to train the model! Live stats can be viewed on the wandb page if in use.

In [None]:
if __name__ == "__main__":
    # Start the wandb run and record the hyper-parameters with it.
    wandb.init(project="unet-colorizer")
    wandb.config.update(asdict(settings))
    wandb.run.name = settings.create_run_name()

    print(f"Using device: {device}")
    print(f"Total images: {settings.total_image_count}")
    print(f"Total val images: {settings.validation_image_count}")

    # Randomly hold out the validation images; the rest is used for training.
    test_size = settings.validation_image_count
    train_size = settings.total_image_count - test_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size]
    )

    train_dataloader = FastColorizationDataLoader(
        train_dataset,
        batch_size=settings.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    test_dataloader = FastColorizationDataLoader(
        test_dataset,
        batch_size=settings.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    model = Unet().to(device)

    # When fine-tuning, start from previously saved weights.
    if settings.finetune:
        print(f"Loading pre-trained model from {settings.pretrained_model_path}")
        model.load_state_dict(torch.load(settings.pretrained_model_path))

    print(f"The model has {count_parameters(model):,} trainable parameters")
    wandb.watch(model)

    print(model)

    # Loss and optimizer classes are looked up by the names in the settings.
    scaler = GradScaler()
    criterion = getattr(nn, settings.loss_function)()

    optimizer = getattr(optim, settings.optimizer)(
        model.parameters(),
        lr=(
            settings.finetune_learning_rate
            if settings.finetune
            else settings.learning_rate
        ),
        weight_decay=settings.weight_decay,
    )

    scheduler = LambdaLR(optimizer, lr_lambda)

    print("Starting training...")
    train_model(
        model,
        train_dataloader,
        test_dataloader,
        criterion,
        optimizer,
        scheduler,
        settings.num_steps,
        settings.validation_steps,
    )