# Import required libraries

In [None]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

# Dataloader

In [4]:
# Define the ImageDataset class
class ImageDataset(Dataset):
    """Unlabelled image dataset backed by the files of a single directory.

    Args:
        folder_path: directory whose regular, non-hidden files are the images.
        transform: optional callable applied to each RGB PIL image; when it is
            falsy the PIL image itself is returned.
    """

    def __init__(self, folder_path, transform=None):
        # Sorted so the index -> file mapping is deterministic (os.listdir order
        # is arbitrary). Sub-directories and hidden entries (e.g.
        # '.ipynb_checkpoints', '.DS_Store') are skipped because Image.open
        # cannot read them and would crash mid-epoch.
        self.file_paths = sorted(
            os.path.join(folder_path, file_name)
            for file_name in os.listdir(folder_path)
            if not file_name.startswith(".")
            and os.path.isfile(os.path.join(folder_path, file_name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        # The context manager releases the file handle PIL opens lazily;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)  # Apply transformations if provided
        return image

# Preprocessing applied to every RGB PIL image:
#   ToTensor  -> float tensor of shape (C, H, W) with values in [0, 1]
#   Normalize -> (v - 0.5) / 0.5 per channel, i.e. values in [-1, 1], the same
#                range as the generator's Tanh output.
transform_pipeline = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

def load_images(folder_path, batch_size=None, *, shuffle=True, num_workers=4, pin_memory=None):
    """Build a DataLoader over every image in ``folder_path``.

    Images go through ``ImageDataset`` with ``transform_pipeline``.

    Args:
        folder_path: directory containing the image files.
        batch_size: images per batch. ``None`` (default) uses the notebook-level
            ``BATCH_SIZE`` if it exists, otherwise 5. The fallback matters because
            ``BATCH_SIZE`` is assigned in a later cell than the one that first
            calls this function, which previously raised NameError when the
            notebook was run top to bottom.
        shuffle: reshuffle the images every epoch (default True).
        num_workers: worker processes for loading; adjust to your CPU cores.
        pin_memory: page-lock host batches for faster GPU copies. ``None``
            (default) enables it only when CUDA is available, since pinning is
            pointless (and warns) on CPU-only machines.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding image tensors.
    """
    if batch_size is None:
        batch_size = globals().get("BATCH_SIZE", 5)
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    dataset = ImageDataset(folder_path, transform=transform_pipeline)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

# Build the two unpaired loaders, in this order:
#   source - images to be translated (the generator's input domain)
#   target - images of the domain the generator should imitate
source, target = (
    load_images(folder)
    for folder in (
        '/home/umang.shikarvar/instaformer/wb_small_airshed/images',
        '/home/umang.shikarvar/instaformer/delhi_ncr_small/images',
    )
)

# Model

In [5]:
class EncoderBlock(nn.Module):
    """U-Net encoder stage: two 3x3 conv (padding 1) + ReLU layers, then a 2x2 max-pool.

    ``forward`` returns ``(pooled, features)``:
    - ``features`` is the pre-pool activation with ``out_ch`` channels, used as
      the skip connection.
    - ``pooled`` is that activation max-pooled with kernel 2 / stride 2.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for conv_in in (in_ch, out_ch):
            layers.extend([nn.Conv2d(conv_in, out_ch, 3, padding=1), nn.ReLU(inplace=True)])
        self.conv = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        features = self.conv(x)
        pooled = self.pool(features)
        return pooled, features

class Bottleneck(nn.Module):
    """Innermost U-Net stage: two 3x3 conv (padding 1) + ReLU layers.

    Maps ``in_ch`` channels to ``out_ch`` channels; the padding keeps the
    spatial size unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        first_conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        first_act = nn.ReLU(inplace=True)
        second_conv = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        second_act = nn.ReLU(inplace=True)
        self.block = nn.Sequential(first_conv, first_act, second_conv, second_act)

    def forward(self, x):
        out = self.block(x)
        return out

class DecoderBlock(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsample, concat skip, two 3x3 conv + ReLU.

    Args:
        in_ch: channels of the incoming (deeper) feature map.
        out_ch: channels produced by the upsampling and by the whole block. The
            skip tensor must also carry ``out_ch`` channels, because the first
            conv expects ``out_ch * 2`` inputs after concatenation.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, 2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch * 2, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, skip):
        """Upsample ``x``, match it to ``skip``'s spatial size, concatenate, and convolve."""
        x = self.up(x)
        # When an encoder input had an odd height/width, MaxPool2d floored it, so
        # the 2x upsample comes out smaller than the skip and torch.cat would
        # fail. Pad (or, for a negative difference, crop) to the skip's size.
        # This is a no-op when the sizes already match.
        diff_h = skip.shape[-2] - x.shape[-2]
        diff_w = skip.shape[-1] - x.shape[-1]
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, skip], dim=1)  # Concatenate with skip connection
        return self.conv(x)

class Generator(nn.Module):
    """U-Net generator: 4 encoder stages, a bottleneck, 4 decoder stages and a Tanh head.

    Shape comments assume a 640x640 RGB input. ``forward`` returns
    ``(image, [s1, s2, s3, s4])``, where the list holds the pre-pool encoder
    activations (64, 128, 256, 512 channels). ``encoder`` returns only that list.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = EncoderBlock(3, 64)    # 640x640x3→320x320x64
        self.enc2 = EncoderBlock(64, 128)  # 320x320x64→160x160x128
        self.enc3 = EncoderBlock(128, 256) # 160x160x128→80x80x256
        self.enc4 = EncoderBlock(256, 512) # 80x80x256→40x40x512

        # Bottleneck
        self.bottleneck = Bottleneck(512, 1024) # 40x40x512→40x40x1024

        # Decoder
        self.dec4 = DecoderBlock(1024, 512) # 40x40x1024→80x80x512
        self.dec3 = DecoderBlock(512, 256) # 80x80x512→160x160x256
        self.dec2 = DecoderBlock(256, 128) # 160x160x256→320x320x128
        self.dec1 = DecoderBlock(128, 64) # 320x320x128→640x640x64

        self.out = nn.Sequential(
            nn.Conv2d(64, 3, 1), # 640x640x64→640x640x3
            nn.Tanh() # Normalize to [-1, 1]
        )

    def _encode(self, x):
        """Run the four encoder stages; return (deepest pooled map, [s1, s2, s3, s4])."""
        skips = []
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            x, skip = stage(x)  # skip: pre-pool features, x: pooled input to next stage
            skips.append(skip)
        return x, skips

    def encoder(self, x):
        """Encoder-only pass; returns the skip features [s1, s2, s3, s4]."""
        _, skips = self._encode(x)
        return skips

    def forward(self, x):
        # Encoder: x ends at 40x40x512; skips are 640/320/160/80 px with 64/128/256/512 ch
        x, skips = self._encode(x)

        # Bottleneck
        x = self.bottleneck(x)  # 40x40x1024

        # Decoder: deepest stage pairs with the deepest skip (dec4<->s4 ... dec1<->s1)
        for stage, skip in zip((self.dec4, self.dec3, self.dec2, self.dec1), reversed(skips)):
            x = stage(x, skip)

        return self.out(x), skips

class HEncoder(nn.Module):
    """Projection head turning encoder feature maps into per-position embeddings.

    One ``Linear(C_i, output_dim) -> ReLU`` MLP is built per entry of
    ``input_channels``. Each is applied independently to every spatial position
    of the matching feature map.

    Args:
        input_channels: channel count C_i of each feature map, in the order the
            maps are passed to ``forward``.
        output_dim: embedding size D produced per position (default 300).
    """

    def __init__(self, input_channels, output_dim=300):
        super().__init__()
        heads = [
            nn.Sequential(nn.Linear(channels, output_dim), nn.ReLU())
            for channels in input_channels
        ]
        self.proj = nn.ModuleList(heads)

    def forward(self, features):
        """Map each (B, C_i, H, W) tensor to (B, H*W, D); returns a list in input order."""
        projected = []
        for head, fmap in zip(self.proj, features):
            # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C): one row per spatial position
            tokens = fmap.flatten(2).transpose(1, 2)
            projected.append(head(tokens))
        return projected

class Discriminator(nn.Module):
    """PatchGAN-style discriminator producing a map of real/fake scores.

    Layers:
    - Four stride-2 4x4 convs (3→64→128→256→512 channels) with LeakyReLU(0.2);
      every conv except the first is followed by InstanceNorm2d.
    - A final stride-1 4x4 conv down to 1 channel.

    For a 640x640 input the output is 1x39x39.
    """

    def __init__(self):
        super().__init__()
        layers = [nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        layers.append(nn.Conv2d(512, 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

# Loss functions

In [6]:
def NCELoss(f_q, f_k, patch_count=4096, tau=0.07):
    """PatchNCE (InfoNCE) loss between position-aligned query and key patch features.

    For every image, ``N = min(patch_count, S)`` *distinct* patch positions are
    sampled:
    - Positive: the query at each sampled position is scored against the key at
      the same position.
    - Negatives: the keys at the other N - 1 sampled positions of the same image.

    Args:
        f_q: (B, S, D) query patch features.
        f_k: (B, S, D) key patch features, aligned with ``f_q`` along S.
        patch_count: maximum number of positions sampled per image.
        tau: softmax temperature.

    Returns:
        Scalar cross-entropy averaged over the B * N sampled queries.
    """
    B, S, _ = f_q.shape
    N = min(patch_count, S)  # N patches per image

    # Sample N distinct positions per image (prefix of a random permutation).
    # Sampling with replacement (torch.randint) could pick a position twice,
    # putting an exact copy of the positive among the negatives.
    idx = torch.rand(B, S, device=f_q.device).argsort(dim=1)[:, :N]  # (B, N)
    batch_indices = torch.arange(B, device=f_q.device).unsqueeze(1).expand(B, N)  # (B, N)

    # L2 normalization is per-row, so normalizing only the sampled rows is
    # equivalent to normalizing all S rows first, and much cheaper.
    f_q_sampled = F.normalize(f_q[batch_indices, idx], dim=2)  # (B, N, D)
    f_k_sampled = F.normalize(f_k[batch_indices, idx], dim=2)  # (B, N, D)

    # Positive logits: cosine similarity of corresponding patches
    l_pos = (f_q_sampled * f_k_sampled).sum(dim=2, keepdim=True)  # (B, N, 1)

    # Negative logits: every other sampled patch of the same image; the diagonal
    # (the positive pair) is masked out with a large negative value.
    logits_neg = torch.bmm(f_q_sampled, f_k_sampled.transpose(1, 2))  # (B, N, N)
    mask = torch.eye(N, dtype=torch.bool, device=f_q.device).unsqueeze(0)  # (1, N, N)
    logits_neg = logits_neg.masked_fill(mask, -1e9)

    # Column 0 holds the positive, so the target class is 0 for every query.
    logits = torch.cat([l_pos, logits_neg], dim=2) / tau  # (B, N, 1+N)
    labels = torch.zeros(B * N, dtype=torch.long, device=f_q.device)

    return F.cross_entropy(logits.reshape(-1, 1 + N), labels)

def contrastive_loss(embeddings_x, embeddings_gx, nce_loss_fn):
    """
    Compute multilayer contrastive loss as the mean of per-layer NCE losses.

    Args:
        embeddings_x: List of input features from H(G_enc(x)) (used as keys)
        embeddings_gx: List of output features from H(G_enc(G(x))) (used as queries)
        nce_loss_fn: A function like NCELoss, called as nce_loss_fn(query, key)

    Returns:
        Average of nce_loss_fn over all layers.

    Raises:
        ValueError: if the two lists differ in length or are empty. A silent
            zip truncation would otherwise average over the wrong layer count.
    """
    num_layers = len(embeddings_x)
    if num_layers != len(embeddings_gx):
        raise ValueError(
            f"layer count mismatch: {num_layers} input vs {len(embeddings_gx)} output embeddings"
        )
    if num_layers == 0:
        raise ValueError("contrastive_loss needs at least one layer of embeddings")

    total_loss = 0.0
    for f_x, f_gx in zip(embeddings_x, embeddings_gx):
        total_loss += nce_loss_fn(f_gx, f_x)  # Query=G(x), Key=x

    return total_loss / num_layers

def adversarial_loss(D, real_images, fake_images):
    """Least-squares GAN (LSGAN) losses for one batch.

    Args:
        D: discriminator (any callable returning a score tensor).
        real_images: batch from the real/target distribution.
        fake_images: generator output. It is detached for the discriminator
            term and used as-is for the generator term.

    Returns:
        (d_loss, g_loss) where
            d_loss = 0.5 * mean((D(real) - 1)^2) + 0.5 * mean(D(fake)^2)
            g_loss = 0.5 * mean((D(fake) - 1)^2)
    """
    # Discriminator targets: 1 for real, 0 for fake. Detaching keeps this term
    # from back-propagating into the generator.
    real_term = 0.5 * ((D(real_images) - 1) ** 2).mean()
    fake_term = 0.5 * (D(fake_images.detach()) ** 2).mean()

    # Generator target: make D score the (non-detached) fakes as real.
    gen_term = 0.5 * ((D(fake_images) - 1) ** 2).mean()

    return real_term + fake_term, gen_term

def identity_loss(G, real_target_images):
    """Mean absolute (L1) difference between G(y) and y for target-domain images y.

    Encourages the generator to leave images that are already in the target
    domain unchanged.

    Args:
        G: generator; element 0 of its return value is taken as the output image.
        real_target_images: batch of real target-domain images.

    Returns:
        Scalar L1 loss.
    """
    reconstructed = G(real_target_images)[0]
    residual = reconstructed - real_target_images
    return residual.abs().mean()

# Hyperparameters

In [7]:
BATCH_SIZE = 5
EPOCHS = 200

# Set device. The second GPU ('cuda:1') is preferred. On a single-GPU machine
# that index does not exist and .to(device) would raise, so fall back to
# 'cuda:0'. Use the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

# Initialize models
G = Generator().to(device)
H = HEncoder([64, 128, 256, 512]).to(device)  # one head per generator encoder skip (channel counts)
D = Discriminator().to(device)

# Define optimizers (Adam, lr 2e-4, beta1 0.5 for all three networks)
g_optimizer = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
h_optimizer = optim.Adam(H.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_optimizer = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Learning rate schedulers: halve the learning rate every 20 calls to .step()
g_scheduler = optim.lr_scheduler.StepLR(g_optimizer, step_size=20, gamma=0.5)
h_scheduler = optim.lr_scheduler.StepLR(h_optimizer, step_size=20, gamma=0.5)
d_scheduler = optim.lr_scheduler.StepLR(d_optimizer, step_size=20, gamma=0.5)

# Hyperparameters
lambda_Y = 1  # Scaling factor for identity loss

# Training loop

In [None]:
# Initialize TensorBoard writer and output directories
log_dir = "/home/umang.shikarvar/instaformer/CUT_logs"
ckpt_dir = "/home/umang.shikarvar/instaformer/CUT_gen"
os.makedirs(log_dir, exist_ok=True)
# torch.save does not create missing directories. Create this up front so the
# first checkpoint (epoch 10) cannot crash the run after hours of training.
os.makedirs(ckpt_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)

try:
    for epoch in range(EPOCHS):
        total_d_loss = 0.0
        total_g_loss = 0.0
        total_contrastive_loss = 0.0
        total_identity_loss = 0.0
        num_batches = 0

        # zip stops at the shorter of the two loaders
        for real_x, real_y in zip(source, target):
            real_x, real_y = real_x.to(device), real_y.to(device)

            # One generator pass serves both steps. d_optimizer only updates D's
            # parameters, so G(x) and its encoder features are unchanged for the
            # generator step; a second forward pass would be wasted work.
            G_x, features_x = G(real_x)

            # ==============================
            # 1. Train Discriminator
            # ==============================
            d_loss, _ = adversarial_loss(D, real_y, G_x.detach())  # Detach to avoid backward through G

            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ==============================
            # 2. Train Generator & Projection Head
            # ==============================
            features_G_x = G.encoder(G_x)  # Encode G(x)

            # Project features using H (projection head)
            embeddings_x = H(features_x)
            embeddings_gx = H(features_G_x)

            contrastive_loss_val = contrastive_loss(embeddings_x, embeddings_gx, NCELoss)

            _, g_loss = adversarial_loss(D, real_y, G_x)  # Generator adversarial loss (uses the updated D)
            identity_loss_val = identity_loss(G, real_y)

            total_loss = g_loss + contrastive_loss_val + lambda_Y * identity_loss_val

            g_optimizer.zero_grad()
            h_optimizer.zero_grad()
            total_loss.backward()
            g_optimizer.step()
            h_optimizer.step()

            # Accumulate losses
            total_d_loss += d_loss.item()
            total_g_loss += g_loss.item()
            total_contrastive_loss += contrastive_loss_val.item()
            total_identity_loss += identity_loss_val.item()
            num_batches += 1

        # Scheduler steps (once per epoch)
        d_scheduler.step()
        g_scheduler.step()
        h_scheduler.step()

        # Report per-batch means so values are comparable across dataset sizes;
        # max() guards against an empty loader.
        n = max(num_batches, 1)
        avg_d_loss = total_d_loss / n
        avg_g_loss = total_g_loss / n
        avg_contrastive_loss = total_contrastive_loss / n
        avg_identity_loss = total_identity_loss / n

        # TensorBoard logging
        writer.add_scalar("Loss/Discriminator", avg_d_loss, epoch)
        writer.add_scalar("Loss/Generator", avg_g_loss, epoch)
        writer.add_scalar("Loss/Contrastive", avg_contrastive_loss, epoch)
        writer.add_scalar("Loss/Identity", avg_identity_loss, epoch)

        print(f"Epoch [{epoch+1}/{EPOCHS}] - D Loss: {avg_d_loss:.4f}, "
              f"G Loss: {avg_g_loss:.4f}, "
              f"Contrastive Loss: {avg_contrastive_loss:.4f}, "
              f"Identity Loss: {avg_identity_loss:.4f}")

        # Checkpoint the generator every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save(G.state_dict(), os.path.join(ckpt_dir, f"generator_CUT_{epoch+1}.pth"))
finally:
    # Flush TensorBoard events even if training is interrupted or crashes.
    writer.close()

print("Training complete!")