# Import required libraries

In [None]:
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.utils as vutils

# Dataloader

In [None]:
# Set device to GPU or CPU.
# GPU index 3 is preferred, but it is only selected when the host actually
# exposes that many CUDA devices; otherwise fall back to the first GPU, or to
# the CPU when CUDA is unavailable, instead of failing later on `.to(device)`.
_PREFERRED_GPU_INDEX = 3
if not torch.cuda.is_available():
    device = torch.device('cpu')
elif torch.cuda.device_count() > _PREFERRED_GPU_INDEX:
    device = torch.device(f'cuda:{_PREFERRED_GPU_INDEX}')
else:
    device = torch.device('cuda:0')

# Number of images per batch produced by the data loaders below.
BATCH_SIZE = 1

# Define the ImageDataset class
class ImageDataset(Dataset):
    """Dataset that yields ``(image, path)`` pairs for the files of a directory.

    Args:
        image_dir: Directory whose regular files are treated as images.
        transform: Optional callable applied to each RGB PIL image before it
            is returned.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # os.listdir() returns entries in arbitrary order and also lists
        # sub-directories (e.g. ``.ipynb_checkpoints``). Keep regular files only
        # and sort them so indexing is reproducible across runs.
        candidates = (os.path.join(image_dir, name) for name in os.listdir(image_dir))
        self.image_paths = sorted(path for path in candidates if os.path.isfile(path))
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx``.

        Returns:
            A tuple ``(image, img_path)``: the RGB image (transformed when a
            transform was given) and the path it was read from.
        """
        img_path = self.image_paths[idx]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, img_path  # Return both image and path

# Define transformation pipeline
# Preprocessing applied to every loaded image: convert the PIL image to a
# tensor, then normalize each of the three channels with mean 0.5 and std 0.5,
# which maps ToTensor's [0, 1] output onto [-1, 1].
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
transform_pipeline = transforms.Compose(_preprocessing_steps)

def load_images(folder_path, batch_size=None, shuffle=True, num_workers=4):
    """Build a DataLoader over the images stored in ``folder_path``.

    Args:
        folder_path: Directory passed to ``ImageDataset``.
        batch_size: Images per batch; ``None`` (default) uses ``BATCH_SIZE``.
        shuffle: Whether the loader reshuffles the data on every pass.
        num_workers: Number of worker processes used for loading; adjust this
            to the number of available CPU cores.

    Returns:
        A ``DataLoader`` yielding ``(image_batch, path_batch)`` pairs, with
        images preprocessed by ``transform_pipeline``.
    """
    dataset = ImageDataset(folder_path, transform=transform_pipeline)
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE if batch_size is None else batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies; requesting it
        # without CUDA is useless, so enable it only when a GPU is present.
        pin_memory=torch.cuda.is_available(),
    )

# Load datasets
# Image folders for the two domains; edit these paths to point at your data.
SOURCE_IMAGE_DIR = '/home/umang.shikarvar/instaformer/wb_small_airshed/images'
TARGET_IMAGE_DIR = '/home/umang.shikarvar/instaformer/delhi_ncr_small/images'

source = load_images(SOURCE_IMAGE_DIR)  # loader over the source-domain images
target = load_images(TARGET_IMAGE_DIR)  # loader over the target-domain images

# CycleGAN

## Model

In [None]:
# Define models
class ConvolutionalBlock(nn.Module):
    """Convolution followed by InstanceNorm2d and an optional ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        is_downsampling: When true, use ``nn.Conv2d`` with reflect padding;
            otherwise use ``nn.ConvTranspose2d``.
        add_activation: When true, finish with an in-place ReLU; otherwise
            with ``nn.Identity``.
        **kwargs: Forwarded to the convolution (kernel size, stride, ...).
    """

    def __init__(self, in_channels, out_channels, is_downsampling=True, add_activation=True, **kwargs):
        super().__init__()
        # Only the convolution type differs between the two modes; the
        # normalization and activation that follow are shared.
        if is_downsampling:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        norm_layer = nn.InstanceNorm2d(out_channels)
        activation = nn.ReLU(inplace=True) if add_activation else nn.Identity()
        self.conv = nn.Sequential(conv_layer, norm_layer, activation)

    def forward(self, x):
        """Apply convolution, normalization and activation to ``x``."""
        out = self.conv(x)
        return out

class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvolutionalBlock``s whose output is added to the input.

    Args:
        channels: Number of channels, identical for input and output.
    """

    def __init__(self, channels):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1)
        # The first block ends with a ReLU; the second one does not.
        activated = ConvolutionalBlock(channels, channels, add_activation=True, **conv_args)
        linear = ConvolutionalBlock(channels, channels, add_activation=False, **conv_args)
        self.block = nn.Sequential(activated, linear)

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked blocks."""
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Image-to-image generator: 7x7 stem, two stride-2 stages, residual
    blocks, two transposed-conv stages, and a 7x7 output conv with tanh.

    Args:
        img_channels: Channels of both the input and the generated image.
        num_features: Width of the stem; deeper stages use 2x and 4x this.
        num_residuals: Number of residual blocks at the 4x width.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=6):
        super().__init__()
        width_1x = num_features
        width_2x = num_features * 2
        width_4x = num_features * 4

        stem_args = dict(kernel_size=7, stride=1, padding=3, padding_mode="reflect")
        down_args = dict(kernel_size=3, stride=2, padding=1)
        up_args = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

        self.initial_layer = nn.Sequential(
            nn.Conv2d(img_channels, width_1x, **stem_args),
            nn.InstanceNorm2d(width_1x),
            nn.ReLU(inplace=True),
        )
        self.downsampling_layers = nn.ModuleList([
            ConvolutionalBlock(c_in, c_out, is_downsampling=True, **down_args)
            for c_in, c_out in ((width_1x, width_2x), (width_2x, width_4x))
        ])
        self.residual_layers = nn.Sequential(
            *[ResidualBlock(width_4x) for _ in range(num_residuals)]
        )
        self.upsampling_layers = nn.ModuleList([
            ConvolutionalBlock(c_in, c_out, is_downsampling=False, **up_args)
            for c_in, c_out in ((width_4x, width_2x), (width_2x, width_1x))
        ])
        self.last_layer = nn.Conv2d(width_1x, img_channels, **stem_args)

    def forward(self, x):
        """Translate ``x`` and return an image squashed to [-1, 1] by tanh."""
        out = self.initial_layer(x)
        stages = (*self.downsampling_layers, self.residual_layers, *self.upsampling_layers)
        for stage in stages:
            out = stage(out)
        return torch.tanh(self.last_layer(out))

class Discriminator(nn.Module):
    """Convolutional discriminator producing a single-channel score map.

    Args:
        in_channels: Channels of the input image.
        features: Widths of the successive 4x4 convolutions. The first width
            is used by the initial layer (no normalization); every later width
            gets Conv2d + InstanceNorm2d + LeakyReLU.
    """

    # An immutable tuple is used as the default so it cannot be shared and
    # mutated between instances.
    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.initial_layer = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
        )
        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Only the final configured width uses stride 1. Deciding by
            # position (not by comparing the value with features[-1]) keeps
            # this correct when the same width appears more than once.
            stride = 1 if index == last_index else 2
            layers.append(
                nn.Conv2d(in_channels, feature, kernel_size=4, stride=stride, padding=1, padding_mode="reflect"),
            )
            layers.append(nn.InstanceNorm2d(feature))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_channels = feature
        # Final projection to one output channel; no activation is applied.
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-activated) score map for image batch ``x``."""
        x = self.initial_layer(x)
        return self.model(x)

## Image Generation

In [None]:
# Location of the pre-trained CycleGAN generator checkpoint.
cyclegan_weights_path = '/home/umang.shikarvar/instaformer/wb_CG_gen/generator_CG_200.pth'  # path to model

# Build the generator (the Generator class must already be defined above)
# and place it on the selected device.
generator_g = Generator(img_channels=3)
generator_g = generator_g.to(device)

# Restore the pre-trained weights, mapping stored tensors onto `device`.
generator_g.load_state_dict(torch.load(cyclegan_weights_path, map_location=device))

# Inference only: switch the module to evaluation mode.
generator_g.eval()

# Function to denormalize images (assuming they were normalized to [-1, 1])
def denormalize(tensor):
    """Map values from the [-1, 1] range back to [0, 1] (x * 0.5 + 0.5)."""
    half = 0.5
    scaled = tensor * half
    return scaled + half

# Function to save generated images
def save_image(tensor, path):
    """Write a [-1, 1]-normalized image tensor to ``path``.

    The tensor is denormalized, clamped to [0, 1], and then handed to
    ``torchvision.utils.save_image``.
    """
    image = denormalize(tensor)
    image = image.clamp(0, 1)  # keep values inside the valid pixel range
    vutils.save_image(image, path)

# Function to generate and save images with original file names
def generate_and_save_images(generator, dataloader, output_dir):
    """Run ``generator`` over ``dataloader`` and save every output as a PNG.

    Each generated image is written to ``output_dir`` under the base name of
    the input file it came from, with the extension replaced by ``.png``.

    Args:
        generator: Callable mapping an image batch to a generated image batch.
        dataloader: Iterable yielding ``(image_batch, path_batch)`` pairs.
        output_dir: Destination directory; created when it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)  # Create output directory if not exists

    for batch_inputs, batch_paths in dataloader:
        batch_inputs = batch_inputs.to(device)  # Move to device

        with torch.no_grad():
            generated_batch = generator(batch_inputs)  # Generate images

        # Save each item of the batch under its own file name. Using only the
        # first path would write the whole batch as a single grid image and
        # drop the remaining names whenever the batch size exceeds one.
        for generated_image, img_path in zip(generated_batch, batch_paths):
            original_filename = os.path.basename(img_path)
            filename_without_ext = os.path.splitext(original_filename)[0]

            save_path = os.path.join(output_dir, f"{filename_without_ext}.png")
            # Restore the batch dimension so a single image is saved exactly
            # as it was with a batch of size one.
            save_image(generated_image.unsqueeze(0), save_path)
            print(f"Saved: {save_path}")

# Folder that receives the CycleGAN-translated images.
output_dir = "/home/umang.shikarvar/instaformer/wb_CG/images"

# Translate every image from the `source` loader and write the results.
print("Generating and saving images...")
generate_and_save_images(
    generator=generator_g,
    dataloader=source,
    output_dir=output_dir,
)
print("Image generation complete.")

# CUT

## Model

In [None]:
class EncoderBlock(nn.Module):
    """Two 3x3 conv + ReLU layers followed by 2x2 max pooling.

    Args:
        in_ch: Channels of the input feature map.
        out_ch: Channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        first_conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        second_conv = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.conv = nn.Sequential(
            first_conv,
            nn.ReLU(inplace=True),
            second_conv,
            nn.ReLU(inplace=True),
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(pooled, features)``: the max-pooled output and the
        pre-pooling features."""
        features = self.conv(x)
        pooled = self.pool(features)
        return pooled, features

class Bottleneck(nn.Module):
    """Two 3x3 conv + ReLU layers that keep the spatial size unchanged.

    Args:
        in_ch: Channels of the input feature map.
        out_ch: Channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First conv maps in_ch -> out_ch, second maps out_ch -> out_ch.
        for conv_in in (in_ch, out_ch):
            layers.append(nn.Conv2d(conv_in, out_ch, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv + ReLU layers to ``x``."""
        out = self.block(x)
        return out

class DecoderBlock(nn.Module):
    """Upsample by 2, concatenate a skip connection, then apply two 3x3 convs.

    Args:
        in_ch: Channels of the incoming (low-resolution) feature map.
        out_ch: Channels after upsampling; also the expected channel count of
            the skip tensor and of the block's output.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, 2)
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch * 2, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, skip):
        """Fuse upsampled ``x`` with ``skip`` and return the refined features.

        The output has the spatial size of ``skip``.
        """
        x = self.up(x)
        # Max pooling floors odd sizes, so the upsampled map can be smaller
        # than the skip tensor. Zero-pad (or crop, for negative differences)
        # to the skip's size so the concatenation cannot fail on inputs whose
        # sides are not divisible by the total downsampling factor.
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h != 0 or diff_w != 0:
            x = nn.functional.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x, skip], dim=1) # Concatenate with skip connection
        return self.conv(x)

class Generator(nn.Module):
    """Encoder-decoder generator with skip connections.

    Four encoder stages (each halving the resolution), a bottleneck, four
    decoder stages (each doubling the resolution and consuming the matching
    skip), and a 1x1 convolution with tanh. The shape comments below assume
    a 640x640 RGB input.
    """

    def __init__(self):
        super().__init__()
        # Encoder stages: (in channels, out channels)
        self.enc1 = EncoderBlock(3, 64)     # 640x640 -> 320x320
        self.enc2 = EncoderBlock(64, 128)   # 320x320 -> 160x160
        self.enc3 = EncoderBlock(128, 256)  # 160x160 -> 80x80
        self.enc4 = EncoderBlock(256, 512)  # 80x80 -> 40x40

        # Bottleneck at the lowest resolution (40x40), 512 -> 1024 channels
        self.bottleneck = Bottleneck(512, 1024)

        # Decoder stages mirror the encoder
        self.dec4 = DecoderBlock(1024, 512)  # 40x40 -> 80x80
        self.dec3 = DecoderBlock(512, 256)   # 80x80 -> 160x160
        self.dec2 = DecoderBlock(256, 128)   # 160x160 -> 320x320
        self.dec1 = DecoderBlock(128, 64)    # 320x320 -> 640x640

        # Project to 3 channels and squash to [-1, 1]
        self.out = nn.Sequential(
            nn.Conv2d(64, 3, 1),
            nn.Tanh(),
        )

    def _encode(self, x):
        """Run the four encoder stages; return the pooled output and the
        list of pre-pooling features ``[s1, s2, s3, s4]``."""
        skips = []
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            x, skip = stage(x)
            skips.append(skip)
        return x, skips

    def encoder(self, x):
        """Encoder-only pass: return the pre-pooling features of each stage."""
        _, skips = self._encode(x)
        return skips

    def forward(self, x):
        """Return ``(image, [s1, s2, s3, s4])``: the generated image and the
        encoder features used as skip connections."""
        x, skips = self._encode(x)

        x = self.bottleneck(x)

        # Decode from the deepest stage outwards, pairing dec4 with s4,
        # dec3 with s3, and so on.
        decoders = (self.dec4, self.dec3, self.dec2, self.dec1)
        for stage, skip in zip(decoders, reversed(skips)):
            x = stage(x, skip)

        return self.out(x), skips

class HEncoder(nn.Module):
    """Per-layer projection heads applied to every spatial location.

    Args:
        input_channels: Channel count of each feature map, in the order the
            feature maps are passed to ``forward``.
        output_dim: Embedding size produced by every head.
    """

    def __init__(self, input_channels, output_dim=256):
        super().__init__()
        # One Linear + ReLU head per feature level (C -> output_dim).
        heads = []
        for channels in input_channels:
            heads.append(nn.Sequential(nn.Linear(channels, output_dim), nn.ReLU()))
        self.proj = nn.ModuleList(heads)

    def forward(self, features):
        """Project each ``[B, C, H, W]`` feature map to ``[B, H*W, D]``.

        Returns:
            A list with one embedding tensor per (head, feature map) pair.
        """
        embeddings = []
        for head, fmap in zip(self.proj, features):
            batch, channels, height, width = fmap.shape
            # Channels last, then flatten the spatial grid into a sequence.
            tokens = fmap.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
            embeddings.append(head(tokens))
        return embeddings

class Discriminator(nn.Module):
    """Stack of 4x4 convolutions mapping an RGB image to a 1-channel score map.

    Four stride-2 convolutions (64, 128, 256, 512 channels) are followed by
    a stride-1 convolution to one channel; for a 640x640 input the output
    is 39x39. No activation is applied to the final output.
    """

    def __init__(self):
        super().__init__()
        slope = 0.2
        # First stage has no normalization: 3 -> 64 channels.
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(slope),
        ]
        # Remaining stride-2 stages: conv -> InstanceNorm -> LeakyReLU.
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(slope),
            ])
        # Final stride-1 projection to a single channel.
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for image batch ``x``."""
        scores = self.model(x)
        return scores

## Image Generation

In [None]:
# Checkpoint holding the trained CUT generator parameters.
model_path = "/home/umang.shikarvar/instaformer/CUT_gen/generator_CUT_200.pth" # path to model

# Build the generator (the Generator class must already be defined above)
# and place it on the selected device.
generator_g = Generator()
generator_g = generator_g.to(device)

# Restore the trained parameters, mapping stored tensors onto `device`.
generator_g.load_state_dict(torch.load(model_path, map_location=device))

# Inference only: switch the module to evaluation mode.
generator_g.eval()

# Function to denormalize images (assuming they were normalized to [-1, 1])
def denormalize(tensor):
    """Undo the [-1, 1] normalization: returns ``tensor * 0.5 + 0.5``."""
    rescaled = tensor * 0.5
    rescaled = rescaled + 0.5
    return rescaled

# Function to save generated images
def save_image(tensor, path):
    """Denormalize ``tensor``, clamp it to [0, 1], and save it at ``path``
    via ``torchvision.utils.save_image``."""
    pixels = denormalize(tensor).clamp(0, 1)
    vutils.save_image(pixels, path)

# Function to generate and save images with original file names
def generate_and_save_images(generator, dataloader, output_dir):
    """Run the CUT ``generator`` over ``dataloader`` and save each output as PNG.

    Each generated image is written to ``output_dir`` under the base name of
    the input file it came from, with the extension replaced by ``.png``.

    Args:
        generator: Callable returning ``(generated_batch, extra)``; only the
            generated batch is used.
        dataloader: Iterable yielding ``(image_batch, path_batch)`` pairs.
        output_dir: Destination directory; created when it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)  # Create output directory if not exists

    for batch_inputs, batch_paths in dataloader:
        batch_inputs = batch_inputs.to(device)  # Move to device

        with torch.no_grad():
            generated_batch, _ = generator(batch_inputs)  # Generate images

        # Save each item of the batch under its own file name. Using only the
        # first path would write the whole batch as a single grid image and
        # drop the remaining names whenever the batch size exceeds one.
        for generated_image, img_path in zip(generated_batch, batch_paths):
            original_filename = os.path.basename(img_path)
            filename_without_ext = os.path.splitext(original_filename)[0]

            save_path = os.path.join(output_dir, f"{filename_without_ext}.png")
            # Restore the batch dimension so a single image is saved exactly
            # as it was with a batch of size one.
            save_image(generated_image.unsqueeze(0), save_path)
            print(f"Saved: {save_path}")

# Folder that receives the CUT-translated images.
output_dir = "/home/umang.shikarvar/instaformer/delhi_CUT/images"

# Translate every image from the `source` loader and write the results.
print("Generating and saving images...")
generate_and_save_images(
    generator=generator_g,
    dataloader=source,
    output_dir=output_dir,
)
print("Image generation complete.")