# Unet GAN Finetuning

## Important note:
**Multi-worker data loading does not work in Jupyter notebooks, so it's disabled here. This file is meant as an overview of the training process. It works, but since there's no concurrency it takes approximately 4 times as long to train.**

This notebook can be used to follow the fine-tuning steps of the U-Net for our final model. The multithreaded version of this script can be found in the `training_scripts` subfolder of this repository.

## Dependencies

In [None]:
# Uncomment the line below if you haven't installed the required libraries
#!pip install -q wandb torch torchvision "numpy<2.0.0" pillow

## Imports

In [None]:
import os
import time
import wandb
import numpy as np
from PIL import Image
from dataclasses import dataclass, asdict
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

## Utilities

This cell contains utilities for converting from RGB to LAB colorspace using torch to speed up operations at training and inference time. The code and conversions were taken from skimage and modified to work with torch.

In [None]:
# Compute device shared by the cells below: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.jit.script
def rgb_to_lab(rgb_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert a channels-last RGB image to normalized CIELAB planes.

    Args:
        rgb_image: Tensor whose last dimension holds the R, G, B channels.
            If its maximum exceeds 1.0 it is treated as 0-255 data and
            divided by 255; otherwise it is used as-is.

    Returns:
        A tuple ``(l, ab)``. ``l`` is the lightness plane with a new leading
        dimension of size 1, scaled as ``L / 50 - 1``. ``ab`` stacks the a and
        b planes along a new leading dimension of size 2, scaled as
        ``ab / 110``.
    """
    with torch.no_grad():
        RGB_TO_XYZ = torch.tensor(
            [
                [0.412453, 0.357580, 0.180423],
                [0.212671, 0.715160, 0.072169],
                [0.019334, 0.119193, 0.950227],
            ],
            dtype=torch.float32,
            device=rgb_image.device,
        )
        XYZ_REF = torch.tensor(
            [0.95047, 1.0, 1.08883], dtype=torch.float32, device=rgb_image.device
        )

        # Heuristic range detection: anything above 1.0 is assumed to be 0-255 data.
        if rgb_image.max() > 1.0:
            rgb_image = rgb_image / 255.0

        # Undo the sRGB gamma curve to obtain linear RGB.
        mask = rgb_image > 0.04045
        rgb_linear = torch.where(
            mask, torch.pow(((rgb_image + 0.055) / 1.055), 2.4), rgb_image / 12.92
        )

        # Convert linear RGB to XYZ
        xyz = torch.matmul(rgb_linear, RGB_TO_XYZ.t())

        # Normalize XYZ by the reference white point
        xyz_scaled = xyz / XYZ_REF

        # XYZ to LAB conversion
        epsilon = 0.008856
        kappa = 903.3

        # The branch test must use the white-point-normalized values, the same
        # quantity both branches operate on; testing the raw XYZ picks the
        # wrong branch for values near the threshold.
        f = torch.where(
            xyz_scaled > epsilon,
            xyz_scaled.pow(1 / 3),
            (kappa * xyz_scaled + 16) / 116,
        )

        x, y, z = f[..., 0], f[..., 1], f[..., 2]
        l = (116 * y - 16).unsqueeze(0)
        a = (500 * (x - y)).unsqueeze(0)
        b = (200 * (y - z)).unsqueeze(0)
        ab = torch.cat([a, b], dim=0)

        # Rescale: L from [0, 100] to [-1, 1]; a/b divided by 110.
        l = (l / 50.0) - 1.0
        ab = ab / 110.0

        return l, ab


@torch.jit.script
def lab_to_rgb(l: torch.Tensor, ab: torch.Tensor) -> torch.Tensor:
    """Convert normalized LAB planes back to a channels-last RGB batch.

    Args:
        l: Lightness scaled as ``L / 50 - 1``. A 3-D tensor gets a batch
            dimension prepended; otherwise it is indexed as (N, 1, H, W).
        ab: Chroma scaled as ``ab / 110``. A 3-D tensor gets a batch
            dimension prepended; channel 0 is a, channel 1 is b.

    Returns:
        Tensor of shape (N, H, W, 3) with values clamped to [0, 1].
    """
    with torch.no_grad():
        target = l.device
        white_point = torch.tensor(
            [0.95047, 1.0, 1.08883], dtype=torch.float32, device=target
        ).view(1, 3, 1, 1)
        to_rgb = torch.tensor(
            [
                [3.24048134, -1.53715152, -0.49853633],
                [-0.96925495, 1.87599, 0.04155593],
                [0.05564664, -0.20404134, 1.05731107],
            ],
            dtype=torch.float32,
            device=target,
        )

        # Promote single images to a batch of one.
        lum = l.unsqueeze(0) if l.dim() == 3 else l
        chroma = ab.unsqueeze(0) if ab.dim() == 3 else ab

        # Undo the normalization applied on the way in.
        lum = (lum + 1.0) * 50.0
        chroma = chroma * 110.0

        fy = (lum + 16) / 116
        fx = chroma[:, 0:1] / 500 + fy
        fz = fy - chroma[:, 1:2] / 200
        f_xyz = torch.cat([fx, fy, fz], dim=1)

        # Invert the piecewise LAB companding, then rescale by the white point.
        cubic_part = f_xyz.pow(3)
        linear_part = (f_xyz - 16 / 116) / 7.787
        xyz = torch.where(f_xyz > 0.2068966, cubic_part, linear_part) * white_point

        n, _, h, w = xyz.shape
        flat = xyz.view(n, 3, -1)
        linear_rgb = torch.bmm(to_rgb.expand(n, -1, -1), flat).view(n, 3, h, w)

        # Re-apply the sRGB gamma curve.
        gamma_part = 1.055 * linear_rgb.pow(1 / 2.4) - 0.055
        srgb = torch.where(linear_rgb > 0.0031308, gamma_part, 12.92 * linear_rgb)

        return srgb.clamp(0, 1).permute(0, 2, 3, 1)


def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are excluded from the count.
        if param.requires_grad:
            total += param.numel()
    return total

## Dataset Loading

In [None]:
class ColorizationDataset(Dataset):
    """Dataset of ``.jpg`` files under a directory tree, served as LAB planes.

    Each item is the ``(l_chan, ab_chan)`` pair produced by ``rgb_to_lab``.
    """

    def __init__(self, root_dir, transform=None):
        """Collect every file ending in ``.jpg`` below ``root_dir``.

        Args:
            root_dir: Directory walked recursively for images.
            transform: Optional callable applied to the RGB PIL image before
                it is converted to a float32 array.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        for subdir, _, files in os.walk(root_dir):
            for file in files:
                if file.endswith(".jpg"):
                    self.image_paths.append(os.path.join(subdir, file))

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return its ``(l_chan, ab_chan)`` tensors."""
        img_path = self.image_paths[idx]
        # Force three channels: grayscale, CMYK or RGBA JPEGs would otherwise
        # yield arrays whose last axis is not (R, G, B). The context manager
        # closes the file handle as soon as the pixels are decoded.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        image = torch.from_numpy(np.array(image).astype(np.float32))
        l_chan, ab_chan = rgb_to_lab(image)
        return l_chan, ab_chan


class FastColorizationDataLoader(DataLoader):
    """DataLoader that moves every batch to the GPU on a dedicated CUDA stream.

    When CUDA is not available, batches are yielded unchanged on the CPU so
    the notebook's CPU fallback keeps working.
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers, pin_memory):
        """Build the loader.

        Args:
            dataset: Dataset to draw samples from.
            batch_size: Number of samples per batch.
            shuffle: Whether to reshuffle every epoch.
            num_workers: Accepted for interface parity with the script
                version but intentionally ignored (see comment below).
            pin_memory: Forwarded to ``DataLoader``.
        """
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            # Multiprocessing is not supported in Jupyter Notebooks, so we disable it
            #num_workers=num_workers,
            pin_memory=pin_memory,
        )
        # Only create a CUDA stream when a GPU exists; torch.cuda.Stream()
        # raises on CPU-only hosts.
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None

    def __iter__(self):
        """Yield each batch as a list of tensors (on the GPU when available)."""
        iterator = super().__iter__()
        if self.stream is None:
            for data in iterator:
                yield list(data)
            return
        for data in iterator:
            # The yield sits inside the stream context, so this stream stays
            # current while the consumer processes the batch.
            with torch.cuda.stream(self.stream):
                yield [item.cuda(non_blocking=True) for item in data]

## Model Settings

In [None]:
@dataclass
class ModelSettings:
    """Hyperparameters and paths for the GAN fine-tuning run.

    ``total_image_count`` and ``num_steps`` are plain attributes (not
    dataclass fields), so ``asdict`` does not include them; they stay ``None``
    until ``set_total_image_count`` is called.
    """

    # Start from pretrained generator weights found at ``pretrained_path``.
    finetune: bool = True
    pretrained_path: str = "model_unet_final.pth"
    # Directory scanned for training images.
    root_dir: str = "img_data"
    # Weights of the L1 and adversarial terms in the generator loss.
    l1_loss_weight: float = 100.0
    gan_loss_weight: float = 1.0
    # Adam learning rates and betas for generator / discriminator.
    generator_lr: float = 2e-4
    discriminator_lr: float = 2e-4
    beta1: float = 0.5
    beta2: float = 0.999
    # Number of images held out of the training split for validation.
    validation_image_count: int = 12
    batch_size: int = 32
    warmup_steps: int = 1000
    # Run validation and save a checkpoint every this many training steps.
    validation_steps: int = 1000
    model_name: str = "Unet-GAN"
    gan_mode: str = "vanilla"
    # Maximum number of example images logged per validation pass. The
    # validation function reads this attribute, which was previously missing.
    # NOTE(review): the default mirrors validation_image_count; confirm the
    # intended value.
    display_imgs: int = 12

    def set_total_image_count(self, count):
        """Record the dataset size and derive the number of training steps.

        ``num_steps`` is the number of full batches in the training split
        (total images minus the validation images).
        """
        self.total_image_count = count
        self.num_steps = (self.total_image_count - self.validation_image_count) // self.batch_size

    def __post_init__(self):
        self.total_image_count = None
        self.num_steps = None

    def create_run_name(self):
        """Return the run name built from model name, batch size and step count."""
        if self.finetune:
            return f"{self.model_name}_pretrained_bs{self.batch_size}_steps{self.num_steps}"
        return f"{self.model_name}_bs{self.batch_size}_steps{self.num_steps}"

# Module-level configuration and dataset used by the cells below.
settings = ModelSettings()
dataset = ColorizationDataset(root_dir=settings.root_dir)
# The dataset size determines how many training steps are run.
settings.set_total_image_count(len(dataset))

## Unet (Generator)

This is the U-Net model used in pretraining. We're going to plug it into the GAN fine-tuning.

In [None]:
class CrissCrossAttention(nn.Module):
    """Residual attention block applied along image columns and rows.

    A single 1x1 convolution produces query, key and value maps with
    ``in_dim // 8`` channels each. For every column (and then every row) a
    softmax-normalized channel-by-channel affinity matrix is computed from
    query and key and applied to the value. The two results are summed,
    scaled by the learnable ``gamma`` (initialized to zero), projected back
    to ``in_dim`` channels and added to the input.
    """

    def __init__(self, in_dim):
        super().__init__()
        reduced = in_dim // 8
        self.channel_in = in_dim
        self.channel_out = reduced
        # Combined QKV projection
        self.qkv_conv = nn.Conv2d(in_dim, 3 * reduced, 1)

        self.gamma = nn.Parameter(torch.zeros(1))
        self.out_conv = nn.Conv2d(reduced, in_dim, 1)

    def _attend_along(self, query, key, value, forward_perm, inverse_perm):
        """Run attention over one spatial axis and restore (B, C', H, W) layout.

        ``forward_perm`` moves the attended axis last and the other spatial
        axis next to the batch so both can be folded together;
        ``inverse_perm`` undoes that permutation.
        """
        folded_shape = query.permute(forward_perm).shape
        rows = folded_shape[0] * folded_shape[1]
        length = folded_shape[3]
        q, k, v = (
            t.permute(forward_perm).contiguous().view(rows, self.channel_out, length)
            for t in (query, key, value)
        )
        weights = F.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)
        mixed = torch.bmm(weights, v)
        return mixed.view(folded_shape).permute(inverse_perm)

    def forward(self, x):
        query, key, value = self.qkv_conv(x).chunk(3, dim=1)

        # Attention over the height axis, one sequence per (batch, column).
        along_height = self._attend_along(
            query, key, value, (0, 3, 1, 2), (0, 2, 3, 1)
        )
        # Attention over the width axis, one sequence per (batch, row).
        along_width = self._attend_along(
            query, key, value, (0, 2, 1, 3), (0, 2, 1, 3)
        )

        combined = self.gamma * (along_height + along_width)

        # Project back to the original channel dimension and add the residual.
        return self.out_conv(combined) + x


class UnetBlock(nn.Module):
    """One nesting level of the U-Net: downsample, run the submodule, upsample.

    Non-outermost blocks return the block input concatenated with the block
    output along the channel axis (the skip connection); the outermost block
    returns the Tanh-activated output only.

    Args:
        nf: Output channels of the up-convolution (and the default input
            channel count).
        ni: Output channels of the down-convolution.
        submodule: Inner block placed between the down and up paths.
        input_c: Input channels; defaults to ``nf``.
        dropout: Append ``Dropout(0.5)`` to intermediate blocks.
        innermost: Build the bottleneck block (no submodule, no down norm).
        outermost: Build the outermost block (no norms, Tanh output).
        use_attention: Allow attention on this block's output.
    """

    def __init__(
        self,
        nf,
        ni,
        submodule=None,
        input_c=None,
        dropout=False,
        innermost=False,
        outermost=False,
        use_attention=True,
    ):
        super().__init__()
        self.outermost = outermost
        self.use_attention = use_attention

        in_channels = nf if input_c is None else input_c
        # The down-convolution is built before the up-convolution so the
        # weight initialization order is the same in every configuration.
        down_conv = nn.Conv2d(
            in_channels, ni, kernel_size=4, stride=2, padding=1, bias=False
        )

        if outermost:
            up_conv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4, stride=2, padding=1)
            stages = [down_conv, submodule, nn.ReLU(True), up_conv, nn.Tanh()]
        elif innermost:
            up_conv = nn.ConvTranspose2d(
                ni, nf, kernel_size=4, stride=2, padding=1, bias=False
            )
            stages = [
                nn.LeakyReLU(0.2, True),
                down_conv,
                nn.ReLU(True),
                up_conv,
                nn.BatchNorm2d(nf),
            ]
        else:
            up_conv = nn.ConvTranspose2d(
                ni * 2, nf, kernel_size=4, stride=2, padding=1, bias=False
            )
            stages = [
                nn.LeakyReLU(0.2, True),
                down_conv,
                nn.BatchNorm2d(ni),
                submodule,
                nn.ReLU(True),
                up_conv,
                nn.BatchNorm2d(nf),
            ]
            if dropout:
                stages.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*stages)
        # Attention is skipped on the outermost block and when nf exceeds 512.
        self.use_attention_this_layer = not outermost and nf <= 512 and use_attention
        if self.use_attention_this_layer:
            self.attention = CrissCrossAttention(nf)

    def forward(self, x):
        result = self.model(x)
        if self.outermost:
            return result
        if self.use_attention_this_layer:
            result = self.attention(result)
        # Skip connection: stack the block input with its output.
        return torch.cat([x, result], 1)


class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock`` modules.

    The network is built inside-out: one innermost block, ``n_down - 5``
    blocks at ``num_filters * 8`` channels, three blocks that halve the
    channel count each time, and finally the outermost block mapping
    ``input_c`` input channels to ``output_c`` output channels.
    """

    def __init__(
        self,
        input_c=1,
        output_c=2,
        n_down=8,
        num_filters=64,
        use_attention=True,
        use_dropout=False,
    ):
        super().__init__()
        widest = num_filters * 8

        # Bottleneck.
        nested = UnetBlock(
            widest, widest, innermost=True, use_attention=use_attention
        )

        # Constant-width blocks around the bottleneck (the only ones with dropout).
        for _ in range(n_down - 5):
            nested = UnetBlock(
                widest,
                widest,
                submodule=nested,
                dropout=use_dropout,
                use_attention=use_attention,
            )

        # Three blocks whose outer width is half of their inner width.
        for inner_width in (widest, widest // 2, widest // 4):
            nested = UnetBlock(
                inner_width // 2,
                inner_width,
                submodule=nested,
                use_attention=use_attention,
            )

        self.model = UnetBlock(
            output_c,
            widest // 8,
            input_c=input_c,
            submodule=nested,
            outermost=True,
            use_attention=use_attention,
        )

    def forward(self, x):
        return self.model(x)

# Discriminator

This discriminator architecture was found to be the most effective at producing favorable images. It analyzes the input images at multiple scales and predicts whether the given image is 'fake' or 'real'.

In [None]:
class SpectralNorm(nn.Module):
    """Wrap a module and divide one of its weights by an estimate of its spectral norm.

    The wrapped module's parameter ``name`` is replaced by three parameters:
    ``<name>_bar`` (the trainable raw weight) and ``<name>_u`` / ``<name>_v``
    (the power-iteration vectors, created with ``requires_grad=False``).
    Every forward pass refreshes ``u``/``v`` and sets ``<name>`` on the
    wrapped module to ``<name>_bar / sigma``.

    Args:
        module: Module owning the parameter to normalize.
        name: Name of that parameter.
        power_iterations: Power-iteration steps per forward pass.
    """

    def __init__(self, module, name="weight", power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        self._make_params()

    def _update_u_v(self):
        """Run the power iteration and install the normalized weight."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        w_mat = w.view(height, -1)
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = l2normalize(torch.mv(w_mat.data, v.data))

        # u and v are treated as constants here. Callers may flip
        # requires_grad on every parameter of the network; without detach,
        # sigma would then back-propagate into u/v and the optimizer would
        # start updating the power-iteration vectors.
        sigma = u.detach().dot(w_mat.mv(v.detach()))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _make_params(self):
        """Replace the original parameter with ``_u``, ``_v`` and ``_bar``."""
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        # Remove the original parameter so that forward() can install a plain
        # tensor under the same attribute name.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


def l2normalize(v, eps=1e-12):
    """Return ``v`` divided by its norm; ``eps`` guards against division by zero."""
    return v / (v.norm() + eps)


class Discriminator(nn.Module):
    """Stack of spectrally-normalized convolution stages ending in a 1-channel map.

    Args:
        input_c: Number of input channels.
        num_filters: Channels of the first stage; later stages double this,
            capped at ``8 * num_filters``.
        n_layers: Number of stages before the final 1-channel convolution.
    """

    def __init__(self, input_c, num_filters=64, n_layers=5):
        super().__init__()
        layers = [self.get_layers(input_c, num_filters, norm=False)]
        # Width of the most recent stage; initialized here so the final
        # convolution is well-defined even when n_layers == 1 and the loop
        # below never runs.
        nf = num_filters
        for i in range(1, n_layers):
            nf_prev = num_filters * min(2 ** (i - 1), 8)
            nf = num_filters * min(2**i, 8)
            # The last intermediate stage keeps the resolution (stride 1).
            stride = 1 if i == n_layers - 1 else 2
            layers.append(self.get_layers(nf_prev, nf, s=stride))
        layers.append(self.get_layers(nf, 1, s=1, norm=False, act=False))
        self.model = nn.Sequential(*layers)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Build one stage: spectral-norm conv, optional InstanceNorm and LeakyReLU.

        The convolution carries a bias only when no normalization follows it.
        """
        layers = [SpectralNorm(nn.Conv2d(ni, nf, k, s, p, bias=not norm))]
        if norm:
            layers += [nn.InstanceNorm2d(nf)]
        if act:
            layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class MultiScaleDiscriminator(nn.Module):
    """Runs ``num_D`` independent discriminators on progressively downsampled input.

    The sub-discriminators are registered as ``layer_0`` ... ``layer_{num_D-1}``.
    ``forward`` returns one list per scale, each holding the output of every
    stage of that scale's discriminator (the final prediction is the last
    element).
    """

    def __init__(self, input_c, num_filters=64, n_layers=5, num_D=3):
        super().__init__()
        self.num_D = num_D
        self.n_layers = n_layers

        for scale in range(num_D):
            setattr(
                self, f"layer_{scale}", Discriminator(input_c, num_filters, n_layers)
            )

        # Halves the spatial resolution between consecutive scales.
        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False
        )

    def singleD_forward(self, model, input):
        """Feed ``input`` through the stages of ``model``, collecting every output."""
        activations = []
        current = input
        for stage in model:
            current = stage(current)
            activations.append(current)
        return activations

    def forward(self, input):
        per_scale = []
        scaled_input = input
        final_scale = self.num_D - 1
        for scale in range(self.num_D):
            branch = getattr(self, f"layer_{scale}")
            per_scale.append(self.singleD_forward(branch.model, scaled_input))
            # No need to downsample after the coarsest scale has been evaluated.
            if scale != final_scale:
                scaled_input = self.downsample(scaled_input)
        return per_scale

## GAN Loss

The GAN loss takes the logits of the discriminator and produces a loss that the generator can use to improve its image outputs.

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss against constant (smoothed) real/fake target labels.

    Args:
        gan_mode: ``"vanilla"`` uses ``BCEWithLogitsLoss``; ``"lsgan"`` uses
            ``MSELoss``.
        real_label: Target value for predictions that should count as real.
        fake_label: Target value for predictions that should count as fake.

    Raises:
        NotImplementedError: If ``gan_mode`` is not one of the modes above.
    """

    def __init__(self, gan_mode="vanilla", real_label=0.9, fake_label=0.1):
        super().__init__()
        # Buffers follow the module across devices via .to().
        self.register_buffer("real_label", torch.tensor(real_label))
        self.register_buffer("fake_label", torch.tensor(fake_label))
        self.gan_mode = gan_mode
        if gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        else:
            raise NotImplementedError(f"gan_mode {gan_mode} not implemented")

    def get_labels(self, preds, target_is_real):
        """Return the target label broadcast to the shape of ``preds``."""
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    def forward(self, prediction, target_is_real):
        """Return the loss of ``prediction`` against the selected label.

        Implemented as ``forward`` (rather than overriding ``__call__``) so
        that the standard ``nn.Module`` call path, including hooks, applies.
        """
        target_tensor = self.get_labels(prediction, target_is_real)
        return self.loss(prediction, target_tensor)

# Main GAN Model

An amalgamation of the above networks that puts the generator and discriminator into a zero-sum game. The generator attempts to generate images that fool the discriminator, and the discriminator attempts to classify images from the generator as 'fake'. The result of this game is a (hopefully) more effective generator model that produces higher-quality colorizations. However, the GAN architecture will never converge to a final solution and can sometimes suffer mode collapse. Because of this, every `n` steps the metrics are logged and images of the generator's output are saved, so that it can later be decided which iteration of the generator model to use.

In [None]:
class MainModel(nn.Module):
    """Generator + multi-scale discriminator with their losses and optimizers.

    Args:
        net_G: Generator to train; a fresh ``Unet`` is built when ``None``.
        lr_G: Adam learning rate for the generator.
        lr_D: Adam learning rate for the discriminator.
        beta1: Adam beta1 for both optimizers.
        beta2: Adam beta2 for both optimizers.
        lambda_L1: Weight of the L1 term in the generator loss.
        lambda_GAN: Weight of the adversarial term in the generator loss.
        gan_mode: Forwarded to ``GANLoss``.
    """

    def __init__(
        self,
        net_G=None,
        lr_G=1e-4,
        lr_D=4e-4,
        beta1=0.5,
        beta2=0.999,
        lambda_L1=100.0,
        lambda_GAN=1.0,
        gan_mode="vanilla"
    ):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        self.lambda_GAN = lambda_GAN

        if net_G is None:
            self.net_G = Unet(input_c=1, output_c=2, n_down=8, num_filters=64).to(self.device)
        else:
            self.net_G = net_G.to(self.device)

        # The discriminator sees L and ab stacked together, hence 3 input channels.
        self.net_D = MultiScaleDiscriminator(input_c=3, num_filters=64, n_layers=5, num_D=3).to(self.device)
        self.GANcriterion = GANLoss(gan_mode=gan_mode).to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

        # Most recent loss values, populated by the compute_* / backward_* methods.
        self.loss_G = None
        self.loss_G_GAN = None
        self.loss_G_L1 = None
        self.loss_D = None
        self.loss_D_real = None
        self.loss_D_fake = None

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, l_chan, ab_chan):
        """Store the current batch on the model's device."""
        self.l_chan = l_chan.to(self.device)
        self.ab_chan = ab_chan.to(self.device)

    def forward(self):
        """Run the generator on the stored L channel; the result is kept in ``fake_color``."""
        self.fake_color = self.net_G(self.l_chan)

    def backward_D(self):
        """Back-propagate the discriminator loss.

        The discriminator returns a nested list (one list of stage outputs
        per scale), which cannot be handed to the criterion directly; the
        per-scale aggregation in ``compute_D_loss`` is used instead.
        """
        self.compute_D_loss().backward()

    def backward_G(self):
        """Back-propagate the generator loss (adversarial + weighted L1).

        Uses ``compute_G_loss`` for the same reason as ``backward_D``.
        ``loss_G_GAN`` therefore holds the unweighted adversarial term.
        """
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()

    def optimize(self):
        """Run one discriminator update followed by one generator update."""
        self.forward()

        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # Freeze the discriminator so the generator step does not accumulate
        # gradients in it.
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

    def compute_D_loss(self):
        """Return the discriminator loss averaged over all scales.

        Only the final output of each scale is scored. The generated colors
        are detached so no gradient reaches the generator. Also updates
        ``loss_D_real``, ``loss_D_fake`` and ``loss_D``.
        """
        real_image = torch.cat([self.l_chan, self.ab_chan], dim=1)
        fake_image = torch.cat([self.l_chan, self.fake_color.detach()], dim=1)

        real_preds = self.net_D(real_image)
        fake_preds = self.net_D(fake_image)

        self.loss_D_real = 0
        self.loss_D_fake = 0
        for real_scale_preds, fake_scale_preds in zip(real_preds, fake_preds):
            self.loss_D_real += self.GANcriterion(real_scale_preds[-1], True)
            self.loss_D_fake += self.GANcriterion(fake_scale_preds[-1], False)

        self.loss_D_real /= len(real_preds)
        self.loss_D_fake /= len(fake_preds)
        self.loss_D = (self.loss_D_real + self.loss_D_fake) * 0.5

        return self.loss_D

    def compute_G_loss(self):
        """Return ``lambda_GAN * adversarial + lambda_L1 * L1`` for the generator.

        Also updates ``loss_G_GAN`` (unweighted, averaged over scales) and
        ``loss_G_L1`` (already multiplied by ``lambda_L1``).
        """
        fake_image = torch.cat([self.l_chan, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)

        self.loss_G_GAN = 0
        for scale_preds in fake_preds:
            self.loss_G_GAN += self.GANcriterion(scale_preds[-1], True)
        self.loss_G_GAN /= len(fake_preds)

        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab_chan) * self.lambda_L1
        loss_G = self.loss_G_GAN * self.lambda_GAN + self.loss_G_L1
        return loss_G

    def train(self, mode=True):
        """Switch all sub-networks to training (or evaluation) mode.

        Keeps the ``nn.Module.train(mode)`` signature and returns ``self``
        so that ``model.train(False)`` and call chaining work as usual.
        """
        super().train(mode)
        return self

    def eval(self):
        """Switch all sub-networks to evaluation mode and return ``self``."""
        return self.train(False)

## Validation Function

In [None]:
def validate_model(model, test_dataloader):
    """Evaluate ``model`` on ``test_dataloader`` and log the results to wandb.

    Logs up to ``settings.display_imgs`` example images (falling back to
    ``settings.validation_image_count`` when that attribute is absent) as a
    grayscale / predicted / ground-truth grid, plus the per-image average
    generator and discriminator losses and the elapsed time. The model is
    left in evaluation mode.

    Args:
        model: Object exposing ``setup_input``, ``forward``,
            ``compute_D_loss``, ``compute_G_loss`` and ``fake_color``.
        test_dataloader: Iterable of ``(l_chan, ab_chan)`` batches.

    Returns:
        Sum of the average generator and discriminator validation losses.
    """
    cuda_available = torch.cuda.is_available()
    # torch.cuda.synchronize() raises on hosts without CUDA, so only call it
    # when a GPU is present.
    if cuda_available:
        torch.cuda.synchronize()
    val_start_time = time.time()
    model.eval()

    display_limit = getattr(settings, "display_imgs", settings.validation_image_count)
    weighted_loss_G = 0.0
    weighted_loss_D = 0.0
    total_images = 0
    logged_images = 0

    with torch.no_grad():
        for l_chan, ab_chan in test_dataloader:
            l_chan, ab_chan = l_chan.to(device), ab_chan.to(device)

            model.setup_input(l_chan, ab_chan)
            model.forward()

            loss_D = model.compute_D_loss()
            loss_G = model.compute_G_loss()

            # Weight each batch loss by its size so a smaller final batch
            # does not distort the average.
            batch_images = l_chan.size(0)
            weighted_loss_G += loss_G.item() * batch_images
            weighted_loss_D += loss_D.item() * batch_images
            total_images += batch_images

            if logged_images < display_limit:
                num_samples = min(display_limit - logged_images, l_chan.shape[0])
                l_chan_samples = l_chan[:num_samples]
                output_samples = model.fake_color[:num_samples]
                target_samples = ab_chan[:num_samples]

                # Zeroed ab channels give the grayscale rendering of the input.
                greyscale_samples = lab_to_rgb(
                    l_chan_samples, torch.zeros_like(output_samples)
                )
                output_rgb_samples = lab_to_rgb(l_chan_samples, output_samples)
                target_rgb_samples = lab_to_rgb(l_chan_samples, target_samples)

                wandb.log(
                    {
                        "Examples": wandb.Image(
                            np.vstack(
                                [
                                    np.hstack(greyscale_samples.cpu().numpy()),
                                    np.hstack(output_rgb_samples.cpu().numpy()),
                                    np.hstack(target_rgb_samples.cpu().numpy()),
                                ]
                            ),
                            caption="Top: Grayscale, Middle: Predicted, Bottom: True",
                        )
                    },
                    commit=False,
                )

                logged_images += num_samples

    if cuda_available:
        torch.cuda.synchronize()
    # Guard against an empty dataloader.
    divisor = total_images if total_images else 1
    avg_val_loss_G = weighted_loss_G / divisor
    avg_val_loss_D = weighted_loss_D / divisor
    val_time = time.time() - val_start_time
    print(
        f"Average Validation Loss G: {avg_val_loss_G:.4f}, D: {avg_val_loss_D:.4f}, Validation Time: {val_time:.4f}s"
    )
    wandb.log(
        {
            "Validation Loss G": avg_val_loss_G,
            "Validation Loss D": avg_val_loss_D,
            "Validation Time": val_time,
        },
        commit=False,
    )
    return avg_val_loss_G + avg_val_loss_D

## Training Function

In [None]:
def _sync_cuda():
    """Wait for pending CUDA work; a no-op on hosts without CUDA."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _checkpoint_state(model, completed_steps, loss_G, loss_D):
    """Build the dictionary saved for every checkpoint."""
    return {
        "step": completed_steps,
        "model_state_dict": model.state_dict(),
        "optimizer_G_state_dict": model.opt_G.state_dict(),
        "optimizer_D_state_dict": model.opt_D.state_dict(),
        "loss_G": loss_G,
        "loss_D": loss_D,
    }


def train_model(
    model,
    train_dataloader,
    test_dataloader,
    num_steps,
    validation_steps,
):
    """Alternate discriminator and generator updates for ``num_steps`` steps.

    Every ``validation_steps`` steps the model is validated and saved to
    ``checkpoint_unet_gan_<step>.pth``; the checkpoint with the lowest
    validation loss so far is also written to ``best_checkpoint_unet_gan.pth``.
    After the loop a final validation runs and the model state dict is saved
    to ``model_final_unet_gan.pth``.

    Args:
        model: ``MainModel``-like object with ``net_G``, ``net_D``, ``opt_G``,
            ``opt_D`` and the ``compute_*`` loss methods.
        train_dataloader: Iterable of ``(l_chan, ab_chan)`` training batches;
            restarted whenever it is exhausted.
        test_dataloader: Dataloader handed to ``validate_model``.
        num_steps: Total number of optimization steps.
        validation_steps: Interval between validation/checkpoint passes.
    """
    _sync_cuda()
    start_time = time.time()
    model.train()
    previous_step_end = start_time
    best_val_loss = float("inf")

    train_iter = iter(train_dataloader)

    for step in range(num_steps):
        try:
            l_chan, ab_chan = next(train_iter)
        except StopIteration:
            # Epoch finished: start over with a fresh iterator.
            train_iter = iter(train_dataloader)
            l_chan, ab_chan = next(train_iter)

        l_chan, ab_chan = l_chan.to(device), ab_chan.to(device)

        # Update Discriminator
        model.set_requires_grad(model.net_D, True)
        model.opt_D.zero_grad(set_to_none=True)
        model.setup_input(l_chan, ab_chan)
        model.forward()
        loss_D = model.compute_D_loss()
        loss_D.backward()
        torch.nn.utils.clip_grad_norm_(model.net_D.parameters(), max_norm=1.0)
        model.opt_D.step()

        # Update Generator (discriminator frozen so it accumulates no gradients)
        model.set_requires_grad(model.net_D, False)
        model.opt_G.zero_grad(set_to_none=True)
        loss_G = model.compute_G_loss()
        loss_G.backward()
        torch.nn.utils.clip_grad_norm_(model.net_G.parameters(), max_norm=1.0)
        model.opt_G.step()

        # Convert to Python floats once; logging tensors would keep the
        # autograd graph alive.
        loss_G_value = loss_G.item()
        loss_D_value = loss_D.item()

        _sync_cuda()
        step_end_time = time.time()
        # Time since the previous step ended (includes any validation pass
        # that ran in between).
        step_time = step_end_time - previous_step_end
        previous_step_end = step_end_time

        completed_steps = step + 1
        total_time_spent = step_end_time - start_time
        avg_step_time = total_time_spent / completed_steps
        etc_seconds = avg_step_time * (num_steps - completed_steps)
        etc_hours = int(etc_seconds // 3600)
        etc_minutes = int((etc_seconds % 3600) // 60)
        # Steps remaining until the next validation, counted from the steps
        # already completed.
        time_to_next_checkpoint = avg_step_time * (
            validation_steps - (completed_steps % validation_steps)
        )

        if model.loss_G_GAN is not None:
            loss_G_GAN = model.loss_G_GAN.item()
        else:
            loss_G_GAN = 0.0

        if model.loss_G_L1 is not None:
            loss_G_L1 = model.loss_G_L1.item()
        else:
            loss_G_L1 = 0.0

        print(
            f"Step [{completed_steps}/{num_steps}], "
            f"Loss G: {loss_G_value:.4f}, "
            f"Loss D: {loss_D_value:.4f}, "
            f"Step Time: {step_time:.4f}s, "
            f"ETC: {etc_hours}h {etc_minutes}m, "
            f"Next Checkpoint: {time_to_next_checkpoint/60:.2f} minutes"
        )

        log_dict = {
            "Generator Loss": loss_G_value,
            "Generator Loss GAN": loss_G_GAN,
            "Generator Loss L1": loss_G_L1,
            "Discriminator Loss": loss_D_value,
            "Discriminator Loss Real": model.loss_D_real.item(),
            "Discriminator Loss Fake": model.loss_D_fake.item(),
            "Step": completed_steps,
            "Step Time": step_time,
            "ETC (hours)": etc_seconds / 3600,
            "Time to Next Checkpoint (minutes)": time_to_next_checkpoint / 60,
        }

        wandb.log(log_dict)

        if completed_steps % validation_steps == 0:
            model.eval()
            val_loss = validate_model(model, test_dataloader)
            model.train()

            state = _checkpoint_state(
                model, completed_steps, loss_G_value, loss_D_value
            )
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(state, "best_checkpoint_unet_gan.pth")

            # Every checkpoint is kept so a generator iteration can be chosen later.
            torch.save(state, f"checkpoint_unet_gan_{completed_steps}.pth")

    _sync_cuda()
    total_time = time.time() - start_time

    print(f"Total Time: {total_time:.4f}s")
    wandb.log({"Total Time": total_time})

    validate_model(model, test_dataloader)

    torch.save(model.state_dict(), "model_final_unet_gan.pth")

## Begin Training!

In [None]:
if __name__ == "__main__":
    #multiprocessing.freeze_support()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    wandb.init(project="unet-gan-colorizer")

    wandb.config.update(asdict(settings))
    wandb.run.name = settings.create_run_name()

    # Hold out the configured number of images for validation.
    train_size = settings.total_image_count - settings.validation_image_count
    test_size = settings.validation_image_count
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size]
    )

    train_dataloader = FastColorizationDataLoader(
        train_dataset,
        batch_size=settings.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    test_dataloader = FastColorizationDataLoader(
        test_dataset,
        batch_size=settings.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    pretrained_unet = Unet(input_c=1, output_c=2, n_down=8, num_filters=64)
    if settings.finetune:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        checkpoint = torch.load(settings.pretrained_path, map_location=device)
        prefix = "net_G."
        if "net_G" in checkpoint:
            # The model was saved as part of MainModel
            generator_state = {
                k.removeprefix(prefix): v for k, v in checkpoint["net_G"].items()
            }
        else:
            # Either a full training checkpoint ("model_state_dict") or a bare
            # state dict.
            source_state = checkpoint.get("model_state_dict", checkpoint)
            generator_state = {
                k.removeprefix(prefix): v
                for k, v in source_state.items()
                if k.startswith(prefix)
            }
            if not generator_state:
                # No "net_G."-prefixed keys: treat it as a plain Unet state
                # dict instead of loading an empty dictionary.
                generator_state = dict(source_state)
        pretrained_unet.load_state_dict(generator_state)
        print(f"Loaded pretrained model from {settings.pretrained_path}")

    model = MainModel(
        net_G=pretrained_unet,
        lr_G=settings.generator_lr,
        lr_D=settings.discriminator_lr,
        beta1=settings.beta1,
        beta2=settings.beta2,
        lambda_L1=settings.l1_loss_weight,
        lambda_GAN=settings.gan_loss_weight,
        gan_mode=settings.gan_mode,
    ).to(device)
    wandb.watch(model)

    print(model)

    print(f"Generator Parameters: {sum(p.numel() for p in model.net_G.parameters())}")
    print(f"Discriminator Parameters: {sum(p.numel() for p in model.net_D.parameters())}")

    train_model(
        model,
        train_dataloader,
        test_dataloader,
        settings.num_steps,
        settings.validation_steps,
    )
