In [None]:
!pip install torch_fidelity

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import save_image
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.utils import spectral_norm
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
import time
import json
import torch_fidelity

In [None]:
# Mount Google Drive (Colab only) so the /content/drive/MyDrive/ATML/... paths used by later cells resolve.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Generator No LoRA
class ImprovedGenerator(nn.Module):
    """Image-to-image generator: frozen ResNet-50 stem -> Transformer bottleneck -> conv decoder.

    The encoder is ResNet-50 (ImageNet weights) up to ``layer2`` with all parameters
    frozen. A strided conv halves the feature map. Each spatial position becomes a
    256-dim token for an encoder-decoder ``nn.Transformer``. The decoder then upsamples
    4 times by 2x and a 7x7 conv + Tanh produces the output, so values lie in [-1, 1].
    Assuming ResNet-50's usual stride of 8 through layer2, a 224x224 input comes back
    as 224x224.

    Attribute names and construction order are kept stable because checkpoints are
    loaded into this module by parameter name.
    """

    def __init__(self, input_channels=3, output_channels=3, dropout_rate=0.1):
        super().__init__()
        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        # Pretrained stem through layer2 (512 output channels); not trained here.
        self.initial_layers = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
        )
        self.initial_layers.requires_grad_(False)

        self.noise_layer = GaussianNoise(0.01)
        self.dropout = nn.Dropout2d(dropout_rate)
        # Plain (non-spectral-normalized) conv: 512 -> 256 channels, spatial size halved.
        self.ds_transformer = nn.Conv2d(512, 256, 3, stride=2, padding=1)

        # Encoder-decoder Transformer over flattened spatial tokens of width 256.
        self.transformer = nn.Transformer(
            d_model=256,
            nhead=8,
            num_encoder_layers=3,
            num_decoder_layers=3,
            dim_feedforward=1024,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
        )

        def upsample_block(in_ch, out_ch):
            # 2x spatial upsampling, then BatchNorm, LeakyReLU and channel dropout.
            return nn.Sequential(
                spectral_norm(nn.ConvTranspose2d(in_ch, out_ch, 3, stride=2, padding=1, output_padding=1)),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
                nn.Dropout2d(dropout_rate),
            )

        self.us_transformer = spectral_norm(nn.ConvTranspose2d(256, 256, 3, stride=2, padding=1, output_padding=1))
        self.decoder_blocks = nn.ModuleList(
            [upsample_block(in_ch, out_ch) for in_ch, out_ch in ((256, 256), (256, 128), (128, 64))]
        )
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, output_channels, 7, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map an image batch of shape (B, C, H, W) to a Tanh-bounded image batch."""
        grid = self.ds_transformer(self.initial_layers(self.noise_layer(x)))

        batch, channels, height, width = grid.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = grid.flatten(2).transpose(1, 2)
        # The same token sequence serves as both source and target.
        tokens = self.transformer(tokens, tokens)
        # (B, H*W, C) -> (B, C, H, W)
        out = tokens.transpose(1, 2).reshape(batch, channels, height, width)

        out = self.us_transformer(self.dropout(out))
        for block in self.decoder_blocks:
            upsampled = block(out)
            # Residual only when shapes agree. NOTE(review): every block here is a
            # stride-2 ConvTranspose2d that doubles H and W, so this never fires.
            out = upsampled + out if upsampled.shape == out.shape else upsampled
        return self.final_conv(out)

# Gaussian Noise Layer
class GaussianNoise(nn.Module):
    """Additive zero-mean Gaussian noise that is active only in training mode.

    Args:
        sigma: standard deviation of the noise added independently to every element.
    """

    def __init__(self, sigma=0.1):
        super().__init__()
        self.sigma = sigma

    def forward(self, x):
        # Identity in eval mode, so inference is deterministic.
        if not self.training:
            return x
        return x + self.sigma * torch.randn_like(x)


In [None]:
# Test Function
def test(generator, input_image, device='cuda'):
    """
    Run the generator on a single image and return the result as a PIL image.

    The input is converted to RGB, turned into a tensor and normalized to [-1, 1]
    (per-channel mean 0.5, std 0.5), the same normalization used for training data.
    The generator's output is mapped back from [-1, 1] to [0, 1] by inverting that
    normalization before it is converted to a PIL image.

    Parameters:
    - generator (nn.Module): The generator model. It is left in eval mode.
    - input_image (PIL.Image): The input image to be processed.
    - device (str | torch.device): The device to run the model on (e.g., 'cuda' or 'cpu').

    Returns:
    - output_image (PIL.Image): The generated output image.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # convert('RGB') guards against grayscale / RGBA / palette images, which would
    # otherwise yield a tensor with the wrong number of channels.
    input_tensor = transform(input_image.convert('RGB')).unsqueeze(0).to(device)

    generator.eval()
    with torch.no_grad():
        output_tensor = generator(input_tensor)

    # Invert Normalize(0.5, 0.5): pixel = value * 0.5 + 0.5, mapping Tanh's [-1, 1]
    # onto [0, 1]. Clamping the raw output directly would zero every negative value.
    output_tensor = (output_tensor.squeeze(0).cpu() * 0.5 + 0.5).clamp(0, 1)
    return transforms.ToPILImage()(output_tensor)


# Main
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained ImprovedGenerator; its weights live under the checkpoint's
# 'generator' key. NOTE(review): presumably saved from a DataParallel-wrapped model,
# hence the same wrapper here — confirm the keys carry the 'module.' prefix.
model_str = "/content/drive/MyDrive/ATML/Models/best_model.pth"
generator = torch.nn.DataParallel(ImprovedGenerator().to(device))
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
generator.load_state_dict(torch.load(model_str, map_location=device, weights_only=True)['generator'])

# Load an input image (forced to 3-channel RGB)
prefix_img = '/content/drive/MyDrive/ATML/'
input_image = Image.open(os.path.join(prefix_img, "obama.jpg")).convert('RGB')

# Run the test function
output_image = test(generator, input_image, device)

# Save and display the output image
result_dir = '/content/drive/MyDrive/ATML/test/'
os.makedirs(result_dir, exist_ok=True)
output_image.save(os.path.join(result_dir, "obama_generated.jpg"))
plt.imshow(output_image)
plt.axis('off')
plt.show()

In [None]:
class LoRALinear(nn.Module):
    """Bias-free linear layer plus an additive low-rank (LoRA-style) term.

    forward(x) = linear(x) + (alpha / r) * (x @ lora_A.T) @ lora_B.T

    ``lora_A`` has shape (r, in_features) and is Kaiming-uniform initialized.
    ``lora_B`` has shape (out_features, r) and starts at zero, so a freshly
    constructed module returns exactly ``linear(x)``.

    NOTE(review): ``linear`` is created here with default random initialization
    rather than wrapping an existing pretrained layer, and nothing in this class
    freezes it. Classic LoRA keeps a frozen pretrained base weight — confirm intent.
    """

    def __init__(self, in_features, out_features, r=4, alpha=1.0):
        super().__init__()
        self.r = r
        self.scaling = alpha / r
        self.linear = nn.Linear(in_features, out_features, bias=False)

        # Down-projection to rank r (random); the up-projection starts at zero.
        down_proj = torch.empty(r, in_features)
        nn.init.kaiming_uniform_(down_proj, a=5**0.5)
        self.lora_A = nn.Parameter(down_proj)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))

    def forward(self, x):
        base = self.linear(x)
        low_rank = (x @ self.lora_A.T) @ self.lora_B.T
        return base + self.scaling * low_rank

In [None]:
class SuperMotherFatherGenerator(nn.Module):
    """ImprovedGenerator layout plus LoRALinear adapters on the ResNet features and Transformer output.

    Submodule names mirror ImprovedGenerator so its checkpoint can be copied in by
    parameter name (as fine_tune_with_lora does); the ``lora_*`` modules are extra.

    NOTE(review): ``lora_resnet2`` is constructed (its parameters appear in the
    state_dict) but is never used in ``forward``.
    NOTE(review): the base ``linear`` inside each LoRALinear is randomly initialized
    and, having 'lora' in its parameter name, is skipped when fine_tune_with_lora
    copies ImprovedGenerator weights — so the pretrained ResNet features pass through
    an untrained 512x512 projection at the start of fine-tuning.
    """

    def __init__(self, input_channels=3, output_channels=3, dropout_rate=0.1):
        super().__init__()
        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        # Pretrained stem through layer2 (512 output channels); not trained here.
        self.initial_layers = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
        )
        self.initial_layers.requires_grad_(False)

        # Adapters over the 512-channel ResNet features (see NOTE on lora_resnet2).
        self.lora_resnet1 = LoRALinear(512, 512, r=4, alpha=1.0)
        self.lora_resnet2 = LoRALinear(1024, 512, r=4, alpha=1.0)

        self.noise_layer = GaussianNoise(0.01)
        self.dropout = nn.Dropout2d(dropout_rate)
        # 512 -> 256 channels, spatial size halved.
        self.ds_transformer = nn.Conv2d(512, 256, 3, stride=2, padding=1)

        # Encoder-decoder Transformer over flattened spatial tokens of width 256,
        # followed by a token-wise LoRA adapter.
        self.transformer = nn.Transformer(
            d_model=256,
            nhead=8,
            num_encoder_layers=3,
            num_decoder_layers=3,
            dim_feedforward=1024,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
        )
        self.lora_transformer = LoRALinear(256, 256, r=4, alpha=1.0)

        def upsample_block(in_ch, out_ch):
            # 2x spatial upsampling, then BatchNorm, LeakyReLU and channel dropout.
            return nn.Sequential(
                spectral_norm(nn.ConvTranspose2d(in_ch, out_ch, 3, stride=2, padding=1, output_padding=1)),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
                nn.Dropout2d(dropout_rate),
            )

        self.us_transformer = spectral_norm(nn.ConvTranspose2d(256, 256, 3, stride=2, padding=1, output_padding=1))
        self.decoder_blocks = nn.ModuleList(
            [upsample_block(in_ch, out_ch) for in_ch, out_ch in ((256, 256), (256, 128), (128, 64))]
        )
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, output_channels, 7, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map an image batch of shape (B, C, H, W) to a Tanh-bounded image batch."""
        feats = self.initial_layers(self.noise_layer(x))

        # Apply lora_resnet1 per spatial position: (B, C, H, W) -> (B, H*W, C) and back.
        n, ch, fh, fw = feats.shape
        per_pixel = self.lora_resnet1(feats.flatten(2).transpose(1, 2))
        feats = per_pixel.transpose(1, 2).reshape(n, ch, fh, fw)

        grid = self.ds_transformer(feats)
        n, ch, gh, gw = grid.shape
        seq = grid.flatten(2).transpose(1, 2)
        # Same sequence as Transformer source and target, then the LoRA adapter.
        seq = self.lora_transformer(self.transformer(seq, seq))
        out = seq.transpose(1, 2).reshape(n, ch, gh, gw)

        out = self.us_transformer(self.dropout(out))
        for block in self.decoder_blocks:
            upsampled = block(out)
            # Residual only when shapes agree. NOTE(review): every block doubles H and W,
            # so this never fires.
            out = upsampled + out if upsampled.shape == out.shape else upsampled
        return self.final_conv(out)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

def _as_loader(data, batch_size, shuffle):
    """Return `data` unchanged if it is already a DataLoader, otherwise wrap it in one."""
    if isinstance(data, DataLoader):
        return data
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle)


def _load_matching_weights(model, source_state_dict):
    """Copy every non-LoRA tensor of `source_state_dict` whose name and shape match `model`."""
    target_state_dict = model.state_dict()
    print(f"Found: {len(source_state_dict)} items")
    matched = {}
    for name, param in source_state_dict.items():
        if name not in target_state_dict:
            continue
        if 'lora' not in name and target_state_dict[name].shape == param.shape:
            matched[name] = param
        else:
            print(f"Skipped (LoRA or shape mismatch): {name}")
    print(f"Loaded {len(matched)} weights out of {len(target_state_dict)}")
    target_state_dict.update(matched)
    model.load_state_dict(target_state_dict)
    print("Weights loaded successfully!")


def _freeze_frozen_batchnorm(model):
    """Put BatchNorm layers whose parameters are all frozen into eval mode.

    In train mode BatchNorm updates its running mean/var buffers even when its
    weight/bias have requires_grad=False, so "frozen" layers would otherwise drift.
    """
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            params = list(module.parameters(recurse=False))
            if params and not any(p.requires_grad for p in params):
                module.eval()


def _evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss of `model` over `loader` (eval mode, no gradients)."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            total += criterion(model(inputs), targets).item()
    return total / len(loader)


def fine_tune_with_lora(
    improved_generator_file: str,
    train_dataset: torch.utils.data.Dataset,
    val_dataset: torch.utils.data.Dataset = None,
    batch_size: int = 32,
    num_epochs: int = 10,
    lr: float = 1e-5,
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
    save_path: str = "/content/drive/MyDrive/ATML/Models/super_generator_finetuned.pth",
    freeze_norm_stats: bool = True,
):
    """
    Fine-tune the LoRA layers of SuperMotherFatherGenerator starting from ImprovedGenerator weights.

    Parameters:
    - improved_generator_file (str): Path to the ImprovedGenerator .pth checkpoint. The state
      dict is read from its 'generator' entry (optionally nested under 'state_dict').
    - train_dataset (Dataset | DataLoader): (input, target) pairs for training. A DataLoader is
      used as-is; a Dataset is wrapped in a shuffling DataLoader of size `batch_size`.
    - val_dataset (Dataset | DataLoader, optional): Validation pairs; validation is skipped when
      None or empty.
    - batch_size (int): Batch size used when a plain Dataset is passed.
    - num_epochs (int): Number of training epochs.
    - lr (float): Learning rate for Adam.
    - device (str): Device to use for training ('cuda' or 'cpu').
    - save_path (str): Where the fine-tuned state dict is written.
    - freeze_norm_stats (bool): Keep BatchNorm layers with frozen parameters in eval mode while
      training, so their running statistics are not updated.

    Returns:
    - The fine-tuned DataParallel-wrapped generator.

    Raises:
    - ValueError: if the training data yields no batches.
    """
    train_loader = _as_loader(train_dataset, batch_size, shuffle=True)
    val_loader = None if val_dataset is None else _as_loader(val_dataset, batch_size, shuffle=False)
    if len(train_loader) == 0:
        raise ValueError("train_dataset yields no batches")

    print("Loading Generator weights into SuperGenerator...")
    # Load on CPU so a GPU-saved checkpoint also works on CPU-only machines;
    # the model is moved to `device` afterwards.
    improved_checkpoint = torch.load(improved_generator_file, map_location='cpu', weights_only=True)['generator']
    improved_state_dict = improved_checkpoint['state_dict'] if 'state_dict' in improved_checkpoint else improved_checkpoint

    super_generator = torch.nn.DataParallel(SuperMotherFatherGenerator())
    _load_matching_weights(super_generator, improved_state_dict)

    print("Freezing non-LoRA parameters...")
    for name, param in super_generator.named_parameters():
        if 'lora' not in name:
            param.requires_grad = False

    super_generator = super_generator.to(device)

    trainable_params = [p for p in super_generator.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)
    criterion = nn.MSELoss()  # Example loss; replace with your task-specific loss

    print("Starting training...")
    for epoch in range(num_epochs):
        super_generator.train()
        if freeze_norm_stats:
            _freeze_frozen_batchnorm(super_generator)
        total_loss = 0.0

        train_loop = tqdm(train_loader, desc=f"Epoch [{epoch + 1}/{num_epochs}]", leave=False)
        for inputs, targets in train_loop:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            loss = criterion(super_generator(inputs), targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_loop.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}")

        if val_loader is not None and len(val_loader) > 0:
            avg_val_loss = _evaluate(super_generator, val_loader, criterion, device)
            print(f"Epoch [{epoch + 1}/{num_epochs}], Val Loss: {avg_val_loss:.4f}")

    print("Training complete! Saving fine-tuned model...")
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(super_generator.state_dict(), save_path)
    print(f"Model saved as '{os.path.basename(save_path)}'")
    return super_generator


In [None]:
# Image Pair Dataset
class ImagePairDataset(Dataset):
    """Image pairs that share a file name in a source and a target directory.

    Each item is ``(source_image, target_image)``. Both images are loaded as RGB and,
    when ``transform`` is given, passed through it independently.

    Args:
        source_dir: directory of input images.
        target_dir: directory of target images.
        transform: optional callable applied to each PIL image.

    Raises:
        ValueError: if no regular file name appears in both directories.
    """

    def __init__(self, source_dir, target_dir, transform=None):
        self.source_dir = source_dir
        self.target_dir = target_dir
        self.transform = transform
        shared_names = set(os.listdir(source_dir)) & set(os.listdir(target_dir))
        # Sorted so the index -> file mapping (and any random_split built on it) is
        # reproducible: iteration order of a set of str changes between interpreter
        # runs because of hash randomization. Sub-directories are skipped.
        self.images = sorted(
            name for name in shared_names
            if os.path.isfile(os.path.join(source_dir, name))
            and os.path.isfile(os.path.join(target_dir, name))
        )
        if len(self.images) == 0:
            raise ValueError("No matching images found.")
        print(f"Found {len(self.images)} matching images")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        source_image = self._load_rgb(os.path.join(self.source_dir, img_name))
        target_image = self._load_rgb(os.path.join(self.target_dir, img_name))
        if self.transform:
            source_image = self.transform(source_image)
            target_image = self.transform(target_image)
        return source_image, target_image

    @staticmethod
    def _load_rgb(path):
        # The context manager closes the file; convert() returns an independent loaded copy.
        with Image.open(path) as img:
            return img.convert('RGB')

In [None]:
# Dataset and transforms
SEED = 42  # fixes the train/val split
BATCH_SIZE = 32

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load full dataset
full_dataset = ImagePairDataset(
    source_dir='/content/drive/MyDrive/ATML/original_images',
    target_dir='/content/drive/MyDrive/ATML/blonde',
    transform=transform
)

# Split dataset into train and validation (seeded so the split is reproducible)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print("Starting training function time...")
print()
start_time = time.time()

model_str = "/content/drive/MyDrive/ATML/Models/best_model.pth"
fine_tune_with_lora(
    improved_generator_file=model_str,
    train_dataset=train_loader,  # pre-built DataLoaders are passed for both splits
    val_dataset=val_loader,
    batch_size=BATCH_SIZE,
    num_epochs=10,
    lr=1e-4,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
elapsed_minutes = (time.time() - start_time) / 60
print(f"Training time: {elapsed_minutes:.2f} minutes")

In [None]:
# Main
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned SuperMotherFatherGenerator. The file holds the state dict of the
# DataParallel-wrapped model saved by fine_tune_with_lora, so the same wrapper is used.
model_str = "/content/drive/MyDrive/ATML/Models/super_generator_finetuned.pth"
generator = torch.nn.DataParallel(SuperMotherFatherGenerator().to(device))
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
generator.load_state_dict(torch.load(model_str, map_location=device, weights_only=True))

# Load an input image (forced to 3-channel RGB)
prefix_img = '/content/drive/MyDrive/ATML/'
input_image = Image.open(os.path.join(prefix_img, "obama.jpg")).convert('RGB')

# Run the test function
output_image = test(generator, input_image, device)

# Save and display the output image
result_dir = '/content/drive/MyDrive/ATML/test/'
os.makedirs(result_dir, exist_ok=True)
output_image.save(os.path.join(result_dir, "super_generated.jpg"))
plt.imshow(output_image)
plt.axis('off')
plt.show()

In [None]:
def calculate_fid(real_images_dir, generated_images_dir, batch_size=50, device='cuda'):
    """Calculate the FID score between a set of real and a set of generated images.

    Args:
        real_images_dir (str): Directory containing the real images.
        generated_images_dir (str): Directory containing the generated images.
        batch_size (int, optional): Batch size for feature extraction. Defaults to 50.
        device (str | torch.device, optional): Requested device. CUDA is used only when it
            is requested (e.g. 'cuda', 'cuda:0') and actually available. Defaults to 'cuda'.

    Returns:
        float: The FID score.

    Raises:
        ValueError: if either input is a path to a single file rather than a directory.
    """
    for path in (real_images_dir, generated_images_dir):
        # FID compares statistics of two image *sets*; a single image file is not valid input.
        if os.path.isfile(path):
            raise ValueError(f"Expected a directory of images, got a file: {path}")

    # Accepts str or torch.device; falls back to CPU when CUDA is unavailable.
    use_cuda = torch.device(device).type == 'cuda' and torch.cuda.is_available()
    metrics_dict = torch_fidelity.calculate_metrics(
        input1=real_images_dir,
        input2=generated_images_dir,
        batch_size=batch_size,
        cuda=use_cuda,
        isc=False,  # Inception Score (ISC) is not needed
        fid=True,  # Calculate FID
    )
    return metrics_dict['frechet_inception_distance']


In [None]:
# Calculate FID between two *directories* of images. FID compares feature statistics of
# whole image sets, so single image files cannot be used. Replace with your directories.
real_images_dir = '/content/drive/MyDrive/ATML/blonde'  # real target-domain images
generated_images_dir = '/content/drive/MyDrive/ATML/test'  # generated images
# NOTE(review): the cells above save only two demo outputs into /test; fill this directory
# with generated images for a full evaluation set before trusting the score.
fid_score = calculate_fid(
    real_images_dir,
    generated_images_dir,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

print(f"FID score: {fid_score:.4f}")