In [None]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torchvision
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.transforms import v2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import pickle
from PIL import Image
from tqdm import tqdm

In [None]:
# ——— Image geometry ———
IMG_SIZE = 512   # images are resized to IMG_SIZE × IMG_SIZE ...
CROP_SIZE = 256  # ... then cropped to CROP_SIZE × CROP_SIZE

NUM_CHANELS = 3  # image channels (RGB); name kept as-is because later cells use it
BATCH_SIZE = 1

NUM_FEATURES = 64  # base channel width of the generator and discriminator

NUM_EPOCHS = 200
LEARNING_RATE = 0.0002

# File under ./checkpoints loaded before training; set to None to start from scratch.
CHECKPOINT = "cyclegan_epoch005.pt"

# ——— Reproducibility ———
# Seed every RNG the notebook draws from: Python `random` (night-image sampling),
# NumPy, and torch (weight init, random crops, DataLoader shuffling; torch.manual_seed
# seeds all devices, including CUDA).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
day_dir = "./data/day"
night_dir = "./data/night"

# Preview one night image. os.listdir order is arbitrary and the folder may contain
# non-image entries, so take the first image file in sorted order (reproducible).
preview_exts = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
preview_names = [
    name for name in sorted(os.listdir(night_dir))
    if name.lower().endswith(preview_exts)
]
if not preview_names:
    raise FileNotFoundError(f"No image files found in {night_dir!r}")
img = plt.imread(os.path.join(night_dir, preview_names[0]))

fig, ax = plt.subplots()
ax.imshow(img)
ax.axis('off')
ax.set_title('sample image')
print(f'Image dimensions {img.shape}')
plt.show()

In [None]:
# Training-time preprocessing, applied independently to day and night images:
# resize -> random crop -> random horizontal flip, as in the reference CycleGAN setup.
# Photometric/geometric policies such as AutoAugment are deliberately not used:
# Invert/Solarize/Brightness/Equalize change exactly the illumination cues that
# separate the day and night domains, and rotate/shear add black borders; the
# discriminators would learn those artifacts as "real".
transform = v2.Compose(
    [
        # PIL -> uint8 image tensor; resizing in uint8 keeps bicubic overshoot clamped.
        v2.ToImage(),
        v2.Resize((IMG_SIZE, IMG_SIZE), transforms.InterpolationMode.BICUBIC),
        v2.RandomCrop((CROP_SIZE, CROP_SIZE)),
        v2.RandomHorizontalFlip(),
        # uint8 [0, 255] -> float32 [0, 1]; replaces the deprecated v2.ToTensor().
        v2.ToDtype(torch.float32, scale=True),
        # [0, 1] -> [-1, 1], the range of the generators' Tanh output.
        v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [None]:
import glob
import random
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class UnpairedImageDataset(Dataset):
    """Unpaired day/night image dataset for CycleGAN training.

    Each item pairs a day image selected by index with a night image drawn
    uniformly at random, so the two folders need not be aligned or equal in size.

    Args:
        root_day: directory containing day-domain images.
        root_night: directory containing night-domain images.
        transform: optional callable applied independently to each PIL RGB image.
        extensions: file suffixes (matched case-insensitively) treated as images;
            other files (e.g. ``.txt``, ``Thumbs.db``) and sub-directories are ignored.

    Raises:
        ValueError: if either directory contains no matching image files.
    """

    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")

    def __init__(self, root_day, root_night, transform=None, extensions=IMG_EXTENSIONS):
        self.files_day = self._list_images(root_day, extensions)
        self.files_night = self._list_images(root_night, extensions)
        # Fail here with a clear message instead of a ZeroDivisionError /
        # IndexError deep inside the DataLoader later.
        if not self.files_day:
            raise ValueError(f"No image files found in {root_day!r}")
        if not self.files_night:
            raise ValueError(f"No image files found in {root_night!r}")
        self.transform = transform
        # Length is the larger domain: every day image is visited each epoch
        # (wrapping around when the day set is smaller), while night images are
        # sampled at random and are therefore not all guaranteed to appear.
        self.length = max(len(self.files_day), len(self.files_night))

    @staticmethod
    def _list_images(root, extensions):
        """Return sorted paths of regular files in ``root`` whose suffix is in ``extensions``."""
        suffixes = tuple(ext.lower() for ext in extensions)
        return sorted(
            path
            for path in glob.glob(os.path.join(root, "*"))
            if os.path.isfile(path) and path.lower().endswith(suffixes)
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an in-memory RGB copy, closing the file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Day images are walked in order (wrapping around); night images are random.
        path_day = self.files_day[idx % len(self.files_day)]
        path_night = random.choice(self.files_night)

        img_day = self._load_rgb(path_day)
        img_night = self._load_rgb(path_night)

        if self.transform:
            img_day = self.transform(img_day)
            img_night = self.transform(img_night)

        return {"day": img_day, "night": img_night}

dataset = UnpairedImageDataset(day_dir, night_dir, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

# Pull a single batch as a smoke test of the loading pipeline and tensor shapes.
batch = next(iter(dataloader))
print(batch["day"].shape, batch["night"].shape)

In [None]:
class ResidualBlock(nn.Module):
    """Shape-preserving residual block: ``x + body(x)``.

    ``body`` is two (ReflectionPad(1) -> 3×3 conv without bias -> InstanceNorm)
    stages with a ReLU between them. Layer indices inside ``self.block`` are
    0..6, so ``state_dict`` keys are ``block.1.weight`` and ``block.5.weight``.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels) -> None:
        super().__init__()
        layers = []
        for is_last in (False, True):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(channels, channels, kernel_size=3, stride=1, bias=False))
            layers.append(nn.InstanceNorm2d(channels))
            if not is_last:
                layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

In [None]:
# Conv: out = floor((in + 2 x padding - kernel_size) / stride) + 1
# TransposeConv: out = (in - 1) x stride - 2 x padding + kernel_size + output_padding

class Generator(nn.Module):
    """ResNet-style CycleGAN generator mapping an image to an image of the same size.

    Layout of ``self.model``, whose indices must stay stable so saved state dicts load:
    a 7×7 reflection-padded stem, then two stride-2 downsampling convs
    (NUM_FEATURES -> 2x -> 4x channels), then ``n_blocks`` ResidualBlocks.
    These are followed by two stride-2 transposed convs back to NUM_FEATURES,
    a 7×7 head to NUM_CHANELS channels, and a Tanh, so outputs lie in (-1, 1).
    Spatial size is preserved when H and W are divisible by 4.

    Args:
        n_blocks: number of residual blocks at the bottleneck (1/4) resolution.
    """

    def __init__(self, n_blocks=6) -> None:
        super().__init__()
        nf = NUM_FEATURES

        # Stem: keep resolution, lift NUM_CHANELS -> nf channels.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(NUM_CHANELS, nf, kernel_size=7, stride=1, bias=False),
            nn.InstanceNorm2d(nf),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage halves H/W and doubles channels (nf -> 2nf -> 4nf).
        for mult in (1, 2):
            layers += [
                nn.Conv2d(nf * mult, nf * mult * 2, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(nf * mult * 2),
                nn.ReLU(inplace=True),
            ]

        # Bottleneck: residual blocks at 1/4 resolution.
        layers += [ResidualBlock(nf * 4) for _ in range(n_blocks)]

        # Upsampling: each stage doubles H/W and halves channels (4nf -> 2nf -> nf).
        for mult in (4, 2):
            layers += [
                nn.ConvTranspose2d(
                    nf * mult,
                    nf * mult // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                    bias=False,
                ),
                nn.InstanceNorm2d(nf * mult // 2),
                nn.ReLU(inplace=True),
            ]

        # Head: project back to image channels and squash to (-1, 1).
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(nf, NUM_CHANELS, kernel_size=7, stride=1),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
# Conv: out = floor((in + 2 x padding - kernel_size) / stride) + 1
# TransposeConv: out = (in - 1) x stride - 2 x padding + kernel_size + output_padding

class Discriminator(nn.Module):
    """PatchGAN discriminator returning a map of raw (pre-sigmoid) logits.

    Four 4×4 conv stages (strides 2, 2, 2, 1) widen channels from NUM_FEATURES
    to 8 × NUM_FEATURES. Each stage is followed by InstanceNorm and LeakyReLU(0.2).
    A final 4×4 conv collapses the result to one channel. Each output logit sees
    a 70×70 input patch, and a 256×256 input yields a 30×30 map.

    No sigmoid is applied: pair the output with a logits-based loss or call
    ``torch.sigmoid`` explicitly.
    """

    def __init__(self) -> None:
        super().__init__()
        # (out_channels, stride) per normalised stage; only the first conv keeps a bias.
        stages = [
            (NUM_FEATURES, 2),
            (NUM_FEATURES * 2, 2),
            (NUM_FEATURES * 4, 2),
            (NUM_FEATURES * 8, 1),
        ]
        layers = []
        in_ch = NUM_CHANELS
        for i, (out_ch, stride) in enumerate(stages):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1, bias=(i == 0)),
                nn.InstanceNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            in_ch = out_ch
        # Final projection to a single-channel logit map.
        layers.append(nn.Conv2d(in_ch, 1, kernel_size=4, stride=1, padding=1, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sanity-check the discriminator's output shape on one real batch.
# This is only a shape probe, so skip building an autograd graph.
D = Discriminator().to(device)
x = batch["day"].to(device)        # [B, 3, 256, 256]
with torch.no_grad():
    pred = D(x)                    # → [B, 1, 30, 30] for 256×256 input

print(x.shape, ": D(x) →", pred.shape)


In [None]:
import os
import itertools
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import make_grid, save_image
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# ——— Hyper‑params ———
# NOTE: NUM_EPOCHS here overrides the value set in the config cell.
NUM_EPOCHS   = 10
LR           = LEARNING_RATE   # reuse the config-cell learning rate (2e-4)
BETA1, BETA2 = 0.5, 0.999
LAMBDA_CYCLE = 10

# ——— Losses ———
adv_loss   = nn.BCELoss()      # expects probabilities in [0, 1]
cycle_loss = nn.L1Loss()

# ——— Models ———
G   = Generator(9).to(device)  # X→Y
F   = Generator(9).to(device)  # Y→X
D_Y = Discriminator().to(device)
D_X = Discriminator().to(device)

# ——— Optimizers ———
opt_G  = optim.Adam(itertools.chain(G.parameters(), F.parameters()), lr=LR, betas=(BETA1, BETA2))
opt_DY = optim.Adam(D_Y.parameters(), lr=LR, betas=(BETA1, BETA2))
opt_DX = optim.Adam(D_X.parameters(), lr=LR, betas=(BETA1, BETA2))

# ——— Optional resume ———
# Loaded after the optimizers exist so discriminator weights and Adam moments can be
# restored as well; resuming trained generators against fresh discriminators and
# optimizers restarts the adversarial game from scratch.
if CHECKPOINT is not None:
    # weights_only=True: the checkpoints saved by the training loop hold only state dicts
    # and an int epoch, so refuse to unpickle arbitrary objects.
    ckpt = torch.load(f"./checkpoints/{CHECKPOINT}", map_location=device, weights_only=True)
    G.load_state_dict(ckpt["G"])
    F.load_state_dict(ckpt["F"])
    # Partial checkpoints (e.g. generators only) may lack these keys; restore what is present.
    for key, target in (
        ("D_X", D_X),
        ("D_Y", D_Y),
        ("opt_G", opt_G),
        ("opt_DX", opt_DX),
        ("opt_DY", opt_DY),
    ):
        if key in ckpt:
            target.load_state_dict(ckpt[key])

# ——— Logging setup ———
os.makedirs("samples", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/CycleGAN")  # optional TensorBoard

In [None]:
@torch.no_grad()
def sample_images(epoch, G, F, dataloader, device):
    """Translate one batch in both directions and save/log a comparison grid.

    Grid layout (per batch element):
        [ real_X | fake_Y | rec_X ]
        [ real_Y | fake_X | rec_Y ]

    Args:
        epoch: epoch number, used in the output file name and as the TensorBoard step.
        G: X→Y (day→night) generator.
        F: Y→X (night→day) generator.
        dataloader: yields dicts with "day" and "night" tensors normalised to [-1, 1].
        device: device the models live on.

    Side effects:
        Writes ``samples/epoch_{epoch:03d}.png`` and logs the grid to the global ``writer``.
    """
    # Visualisation only: @torch.no_grad avoids building and holding an autograd
    # graph through four generator passes.
    batch = next(iter(dataloader))
    real_X, real_Y = batch["day"].to(device), batch["night"].to(device)

    fake_Y = G(real_X)
    fake_X = F(real_Y)
    rec_X = F(fake_Y)
    rec_Y = G(fake_X)

    row1 = torch.cat([real_X, fake_Y, rec_X], dim=3)
    row2 = torch.cat([real_Y, fake_X, rec_Y], dim=3)
    grid = torch.cat([row1, row2], dim=2)

    # Inputs are Normalize(0.5, 0.5)-scaled and G/F end in Tanh, so pixels lie in [-1, 1].
    # Map to [0, 1] once: save_image and TensorBoard both clip to [0, 1], which would
    # otherwise discard every negative (darker-than-mid-grey) value.
    grid = ((grid + 1) / 2).clamp(0, 1).cpu()

    save_image(grid, f"samples/epoch_{epoch:03d}.png")

    # add_images is the batched (NCHW) counterpart of add_image.
    writer.add_images("CycleGAN/Results", grid, epoch, dataformats="NCHW")


# ——— Training loop ———
# The discriminators end in a plain Conv2d (no Sigmoid), i.e. they output logits.
# BCEWithLogitsLoss fuses sigmoid + BCE in a numerically stable way, whereas
# torch.sigmoid(...) followed by nn.BCELoss saturates for large |logit| and then
# passes (near-)zero gradient.
adv_loss_logits = nn.BCEWithLogitsLoss()

for epoch in range(1, NUM_EPOCHS + 1):
    loop = tqdm(dataloader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")
    for batch in loop:
        real_X = batch["day"].to(device)
        real_Y = batch["night"].to(device)

        # — Generator step —
        # The discriminators act as fixed critics here: freeze them so backward()
        # skips their parameter gradients (their own zero_grad() would discard them anyway).
        D_X.requires_grad_(False)
        D_Y.requires_grad_(False)
        opt_G.zero_grad()

        fake_Y = G(real_X)
        fake_X = F(real_Y)

        # adversarial: G and F want their outputs classified as real (target 1)
        pred_fake_Y = D_Y(fake_Y)
        pred_fake_X = D_X(fake_X)
        valid_Y = torch.ones_like(pred_fake_Y)
        valid_X = torch.ones_like(pred_fake_X)

        loss_GAN_XY = adv_loss_logits(pred_fake_Y, valid_Y)
        loss_GAN_YX = adv_loss_logits(pred_fake_X, valid_X)

        # cycle consistency: X→Y→X and Y→X→Y should reconstruct the inputs
        rec_X = F(fake_Y)
        rec_Y = G(fake_X)
        loss_cycle = (cycle_loss(rec_X, real_X) + cycle_loss(rec_Y, real_Y)) * LAMBDA_CYCLE

        # identity: a generator fed an image already in its target domain should leave it unchanged
        idt_Y = G(real_Y)
        idt_X = F(real_X)
        loss_idt = (cycle_loss(idt_Y, real_Y) + cycle_loss(idt_X, real_X)) * (LAMBDA_CYCLE * 0.5)

        # total
        loss_G = loss_GAN_XY + loss_GAN_YX + loss_cycle + loss_idt
        loss_G.backward()
        opt_G.step()

        D_X.requires_grad_(True)
        D_Y.requires_grad_(True)

        # — Discriminator Y step — (fakes are detached so no gradient reaches G)
        opt_DY.zero_grad()
        pred_real_Y = D_Y(real_Y)
        loss_D_realY = adv_loss_logits(pred_real_Y, valid_Y)
        pred_fake_Y2 = D_Y(fake_Y.detach())
        loss_D_fakeY = adv_loss_logits(pred_fake_Y2, torch.zeros_like(pred_fake_Y2))
        loss_DY = 0.5 * (loss_D_realY + loss_D_fakeY)
        loss_DY.backward()
        opt_DY.step()

        # — Discriminator X step — (fakes are detached so no gradient reaches F)
        opt_DX.zero_grad()
        pred_real_X = D_X(real_X)
        loss_D_realX = adv_loss_logits(pred_real_X, valid_X)
        pred_fake_X2 = D_X(fake_X.detach())
        loss_D_fakeX = adv_loss_logits(pred_fake_X2, torch.zeros_like(pred_fake_X2))
        loss_DX = 0.5 * (loss_D_realX + loss_D_fakeX)
        loss_DX.backward()
        opt_DX.step()

        # — Progress bar update —
        loop.set_postfix({
            "L_G": loss_G.item(),
            "L_DY": loss_DY.item(),
            "L_DX": loss_DX.item()
        })

    # — After each epoch: sample & checkpoint —
    sample_images(epoch, G, F, dataloader, device)

    ckpt = {
        "epoch": epoch,
        "G": G.state_dict(),
        "F": F.state_dict(),
        "D_X": D_X.state_dict(),
        "D_Y": D_Y.state_dict(),
        "opt_G": opt_G.state_dict(),
        "opt_DX": opt_DX.state_dict(),
        "opt_DY": opt_DY.state_dict(),
    }
    torch.save(ckpt, f"checkpoints/cyclegan_epoch{epoch:03d}.pt")

    print(f"=> Saved samples/epoch_{epoch:03d}.png and checkpoint.")


In [None]:
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

# 1) Pick the device. G and F are the models built in the training cells;
#    this cell does not reload any weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Deterministic preprocessing for inference: same resize / crop sizes and
#    [-1, 1] normalisation as training, with a center crop instead of a random one.
transform_infer = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE), transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(CROP_SIZE),
        transforms.ToTensor(),                                   # PIL -> float CHW in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ]
)

def denormalize(tensor):
    """Invert Normalize(mean=0.5, std=0.5): map values from [-1, 1] back to [0, 1]."""
    scaled = tensor * 0.5
    return scaled + 0.5

def infer(path, model, device):
    """Translate a single image file with ``model``.

    Args:
        path: image file path; converted to RGB and preprocessed with ``transform_infer``.
        model: callable applied to a [1, 3, H, W] tensor (e.g. G or F), run under
            ``torch.no_grad``; assumed to return a [1, 3, H, W] tensor in [-1, 1].
        device: device the input tensor is moved to (should match the model's device).

    Returns:
        H×W×3 numpy array: the model output denormalised to the [0, 1] scale.
    """
    with Image.open(path) as src:
        rgb = src.convert("RGB")
    inp = transform_infer(rgb).unsqueeze(0).to(device)    # [1, 3, H, W]
    with torch.no_grad():
        out = model(inp)
    chw = out.cpu().squeeze(0)                            # [3, H, W]
    hwc = denormalize(chw).permute(1, 2, 0)
    return hwc.numpy()                                    # H×W×3 numpy

# 3) Run and display: day→night with G, and the full cycle day→night→day via F(G(x)).
input_path  = "data/day/0.jpg"
out_day2night = infer(input_path, G, device)
# infer() simply calls model(x), so a composed callable yields the cycle
# reconstruction F(G(x)); calling F on the day image alone would not be a cycle.
out_cycleback = infer(input_path, lambda t: F(G(t)), device)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panels = [("Day→Night", out_day2night), ("Day→Night→Day (cycle)", out_cycleback)]
for ax, (title, image) in zip(axes, panels):
    ax.set_title(title)
    ax.axis("off")
    ax.imshow(image)
plt.show()
