# Install libraries

In [None]:
!pip install torch-fidelity
!pip install clean-fid

# Import libraries

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.utils import spectral_norm
from PIL import Image
import os
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.resnet import BasicBlock
from transformers import CLIPTextModel, CLIPTokenizer
import numpy as np
from tqdm import tqdm
import time
from torch.utils.data import ConcatDataset
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt
import zipfile
from cleanfid import fid
from torchvision.transforms.functional import to_pil_image
import shutil

# Functions for GAN

In [None]:
# Additive Gaussian noise, applied only in training mode
class GaussianNoise(nn.Module):
    """Add zero-mean Gaussian noise with standard deviation ``sigma``.

    Noise is sampled fresh on every forward pass while the module is in
    training mode; in eval mode the input object is returned unchanged.
    """

    def __init__(self, sigma=0.1):
        super().__init__()
        self.sigma = sigma

    def forward(self, x):
        # Identity outside of training.
        if not self.training:
            return x
        perturbation = torch.randn_like(x) * self.sigma
        return x + perturbation

# Discriminator
class Discriminator(nn.Module):
    """Convolutional discriminator producing a spatial map of scores.

    Four 4x4 stride-2 convolutions (64 -> 128 -> 256 -> 512 channels), each
    followed by LeakyReLU(0.2), with InstanceNorm after every one except the
    first. A final 4x4 stride-1 convolution maps to a single channel, so the
    output is a map of unnormalised scores (one per spatial location).
    """

    def __init__(self, input_channels=3):
        super().__init__()
        layers = [
            nn.Conv2d(input_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        in_ch = 64
        for out_ch in (128, 256, 512):
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]
            in_ch = out_ch
        layers.append(nn.Conv2d(in_ch, 1, 4, stride=1, padding=1))
        # Layer order is fixed so state_dict keys (model.0, model.2, ...) stay
        # compatible with saved checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class TextGuidedGenerator(nn.Module):
    """Image-to-image generator conditioned on a text prompt.

    Pipeline (spatial sizes for a 224x224 input, as produced by the training
    transform in this notebook):
      1. frozen ResNet-50 stem + layer1 + layer2         -> (B, 512, 28, 28)
      2. concat the broadcast CLIP text embedding         -> (B, 512 + D, 28, 28)
      3. stride-2 conv                                    -> (B, 256, 14, 14)
      4. nn.Transformer over the 14*14 spatial tokens     -> (B, 256, 14, 14)
      5. spectral-normalised transposed conv (x2)         -> (B, 128, 28, 28)
      6. ResNet BasicBlock                                -> (B, 128, 28, 28)
      7. three transposed convs (x8) + 7x7 conv + tanh    -> (B, out, 224, 224)

    Args:
        input_channels: accepted for API compatibility; unused (the ResNet stem
            defines the expected input channels).
        output_channels: number of channels of the generated image.
        text_embedding_dim: width of the pooled CLIP text embedding; must match
            the CLIP text model's hidden size.
        dropout_rate: accepted for API compatibility; unused (the transformer
            uses a fixed dropout of 0.1).
        device: device the CLIP text encoder is moved to at construction time.
    """

    def __init__(self, input_channels=3, output_channels=3, text_embedding_dim=512, dropout_rate=0.1, device='cuda'):
        super().__init__()

        # CLIP text encoder, frozen: it only supplies prompt embeddings.
        self.clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model.requires_grad_(False)

        # 1. ResNet-50 stem + first two stages (ImageNet weights), frozen.
        resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.downsample_resnet = nn.Sequential(
            resnet.conv1,      # stride-2 conv
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,    # stride-2 pooling
            resnet.layer1,     # first ResNet stage
            resnet.layer2      # second ResNet stage (512 output channels)
        )
        for param in self.downsample_resnet.parameters():
            param.requires_grad = False
        # NOTE(review): requires_grad=False does not freeze the BatchNorm layers
        # above; in train() mode they still use batch statistics and update
        # their running mean/var. Confirm whether that is intended.

        # 3. Stride-2 conv over [image features ; text embedding] before the transformer
        self.pre_transformer_ds = nn.Conv2d(512 + text_embedding_dim, 256, 3, stride=2, padding=1)

        # 4. Encoder-decoder transformer over spatial tokens of width 256
        self.transformer = nn.Transformer(
            d_model=256,
            nhead=8,
            num_encoder_layers=3,
            num_decoder_layers=3,
            dim_feedforward=1024,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )

        # 5. x2 upsampling after the transformer
        self.post_transformer_us = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2)
        )

        # 6. ResNet block post-transformer
        self.post_transformer_resnet = BasicBlock(128, 128)

        # 7. Three x2 upsampling steps, then a 7x7 conv to the output channels
        self.final_upsample = nn.Sequential(
            # First upsampling
            nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # Second upsampling
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            # Third upsampling
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            # Final Conv
            nn.Conv2d(32, output_channels, 7, padding=3),
            nn.Tanh()
        )

    def encode_text(self, text):
        """Embed prompts with CLIP and mean-pool over the real (non-padding) tokens.

        Args:
            text: a prompt string or a list/tuple of prompt strings.

        Returns:
            Tensor of shape (num_prompts, hidden_size) on the CLIP model's device.
        """
        if isinstance(text, str):
            text = [text]
        device = self.clip_model.device
        tokens = self.tokenizer(list(text), padding=True, truncation=True, return_tensors="pt").to(device)
        # CLIP is frozen, so no autograd graph is needed for its forward pass.
        with torch.no_grad():
            hidden = self.clip_model(**tokens).last_hidden_state
        mask = tokens.get("attention_mask")
        if mask is None:
            return hidden.mean(dim=1)
        # Masked mean: padding positions must not contribute. Otherwise a
        # prompt's embedding would depend on the longest prompt in its batch,
        # and differ from the unpadded single-prompt embedding used at inference.
        mask = mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def forward(self, x, text_prompt):
        """Translate the image batch ``x`` according to ``text_prompt``.

        Args:
            x: image batch; presumably (B, 3, H, W) normalised to [-1, 1] like
               the training transform (TODO confirm for other callers).
            text_prompt: a single prompt string applied to every image, or a
               list/tuple with one prompt per image.

        Returns:
            Generated images from the final tanh, i.e. values in [-1, 1].
        """
        batch_size = x.size(0)

        # Text embedding: one row per image in the batch
        if isinstance(text_prompt, (list, tuple)):
            text_embedding = self.encode_text(text_prompt)
        else:
            text_embedding = self.encode_text([text_prompt] * batch_size)

        # 1. Frozen ResNet features
        x = self.downsample_resnet(x)

        # 2. Broadcast the text embedding over every spatial position and concatenate
        text_embedding = text_embedding.unsqueeze(-1).unsqueeze(-1)
        text_embedding = text_embedding.expand(-1, -1, x.size(2), x.size(3))
        x = torch.cat([x, text_embedding], dim=1)

        # 3. Downsample before the transformer (28 -> 14 for a 224 input)
        x = self.pre_transformer_ds(x)

        # 4. Transformer over flattened spatial tokens: (B, C, H, W) -> (B, H*W, C)
        b, c, h, w = x.shape
        x = x.view(b, c, h*w).permute(0, 2, 1)
        x = self.transformer(x, x)
        x = x.permute(0, 2, 1).view(b, c, h, w)

        # 5. Post-transformer upsampling
        x = self.post_transformer_us(x)

        # 6. Post-transformer ResNet block
        x = self.post_transformer_resnet(x)

        # 7. Final upsampling to image resolution
        x = self.final_upsample(x)

        return x
        
class TextGuidedImageDataset(Dataset):
    """Paired (source, target, prompt) dataset for one editing style.

    Pairs are formed from file names present in BOTH ``source_dir`` and
    ``target_dir``; every pair shares the same ``text_prompt``.

    Args:
        source_dir: directory with the unedited images.
        target_dir: directory with the edited images (matched by file name).
        text_prompt: prompt describing the edit; returned as ``str`` per item.
        transform: optional callable applied independently to source and target.

    Raises:
        ValueError: if the two directories share no file names.
    """

    def __init__(self, source_dir, target_dir, text_prompt, transform=None):
        self.source_dir = source_dir
        self.target_dir = target_dir
        self.text_prompt = text_prompt
        self.transform = transform
        # Sorted so the index -> file mapping is identical in every process:
        # iteration order of a set of str depends on per-process hash
        # randomisation, which would make dataset splits non-reproducible.
        shared = set(os.listdir(source_dir)) & set(os.listdir(target_dir))
        self.images = sorted(shared)
        if not self.images:
            raise ValueError("No matching images found.")
        print(f"Found {len(self.images)} matching images for prompt: {text_prompt}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(source, target, prompt)`` for the ``idx``-th shared file name."""
        img_name = self.images[idx]
        source_image = self._load_rgb(os.path.join(self.source_dir, img_name))
        target_image = self._load_rgb(os.path.join(self.target_dir, img_name))
        if self.transform:
            source_image = self.transform(source_image)
            target_image = self.transform(target_image)
        return source_image, target_image, str(self.text_prompt)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')


def _mean_val_l1(generator, val_loader, criterion, device):
    """Return the mean per-batch ``criterion`` loss of ``generator`` on ``val_loader`` (NaN if it is empty)."""
    generator.eval()
    total, batches = 0.0, 0
    with torch.no_grad():
        for source, target, text_prompt in val_loader:
            fake = generator(source.to(device), text_prompt)
            total += criterion(fake, target.to(device)).item()
            batches += 1
    return total / batches if batches else float('nan')


def train_text_guided_gan(generator, discriminator, train_loader, val_loader, num_epochs, device, save_dir="models"):
    """Train the text-guided generator / discriminator pair.

    Each step first updates the discriminator with a least-squares (MSE)
    adversarial loss against smoothed labels (real -> 0.9, fake -> 0). Then it
    updates the generator with an MSE adversarial term (target 1) plus a
    20x-weighted L1 reconstruction term. Both learning rates follow a cosine
    schedule over ``num_epochs``. After each epoch the mean L1 loss on
    ``val_loader`` is recorded.

    Args:
        generator: model called as ``generator(source, prompts)``.
        discriminator: model returning a score map for an image batch.
        train_loader / val_loader: iterables of ``(source, target, prompts)``.
        num_epochs: number of epochs (also the cosine schedule length).
        device: device the batches are moved to.
        save_dir: directory for ``final_model.pth`` (created if missing).

    Returns:
        dict with per-step ``'g_loss'`` / ``'d_loss'`` lists and a per-epoch
        ``'val_loss'`` list. The same dict is stored in the saved checkpoint.
    """
    os.makedirs(save_dir, exist_ok=True)
    criterion_gan = nn.MSELoss()
    criterion_pixel = nn.L1Loss()
    optimizer_g = torch.optim.AdamW(generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=1e-1)
    optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999), weight_decay=1e-1)
    scheduler_g = CosineAnnealingLR(optimizer_g, T_max=num_epochs, eta_min=1e-6)
    scheduler_d = CosineAnnealingLR(optimizer_d, T_max=num_epochs, eta_min=1e-6)
    train_metrics = {'g_loss': [], 'd_loss': [], 'val_loss': []}

    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()

        for source, target, text_prompt in tqdm(train_loader):
            real = target.to(device)
            source = source.to(device)
            # Prompts must be a list of strings for the generator.
            if isinstance(text_prompt, torch.Tensor):
                text_prompt = text_prompt.tolist()

            # --- Discriminator step ---
            optimizer_d.zero_grad()
            fake = generator(source, text_prompt)
            pred_real = discriminator(real)
            pred_fake = discriminator(fake.detach())
            # Label smoothing on real targets only; fake targets are exactly 0.
            real_labels = torch.full_like(pred_real, 0.9)
            fake_labels = torch.zeros_like(pred_fake)
            loss_d_real = criterion_gan(pred_real, real_labels)
            loss_d_fake = criterion_gan(pred_fake, fake_labels)
            loss_d = (loss_d_real + loss_d_fake) * 0.5
            loss_d.backward()
            optimizer_d.step()

            # --- Generator step ---
            optimizer_g.zero_grad()
            # Re-score the non-detached fake with the freshly updated discriminator.
            pred_fake = discriminator(fake)
            loss_g_gan = criterion_gan(pred_fake, torch.ones_like(pred_fake))
            loss_g_pixel = criterion_pixel(fake, real) * 20
            loss_g = loss_g_gan + loss_g_pixel
            loss_g.backward()
            optimizer_g.step()

            train_metrics['g_loss'].append(loss_g.item())
            train_metrics['d_loss'].append(loss_d.item())

        # Validation (pixel L1 only)
        avg_val_loss = _mean_val_l1(generator, val_loader, criterion_pixel, device)
        train_metrics['val_loss'].append(avg_val_loss)

        steps = len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}: G_loss={np.mean(train_metrics['g_loss'][-steps:]):.4f}, "
              f"D_loss={np.mean(train_metrics['d_loss'][-steps:]):.4f}, Val_loss={avg_val_loss:.4f}")

        scheduler_g.step()
        scheduler_d.step()

    # Save final model
    torch.save({
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'epoch': num_epochs-1,
        'train_metrics': train_metrics
    }, os.path.join(save_dir, "final_model.pth"))

    return train_metrics

# Initialize dataset and train GAN

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training preprocessing: resize, 224x224 centre crop, scale pixels to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Source directory and (target directory, prompt) pairs, one per editing style
source_dir = '/kaggle/input/multiprompt/original_images_ordered'
style_configs = [
    ('/kaggle/input/multiprompt/albino', "convert to albino style"),
    ('/kaggle/input/multiprompt/blonde', "make the hair blonde"),
    ('/kaggle/input/multiprompt/gogh', "apply van gogh style"),
    ('/kaggle/input/multiprompt/old', "make the person old"),
    ('/kaggle/input/multiprompt/pink', "make the hair pink")]

# One paired dataset per style, concatenated into a single dataset
datasets = [
    TextGuidedImageDataset(source_dir, target_dir, prompt, transform)
    for target_dir, prompt in style_configs
]
full_dataset = ConcatDataset(datasets)
print(f"Combined dataset contains {len(full_dataset)} images")

# 80/10/10 split. The test size takes the remainder so the three lengths always
# sum to len(full_dataset); random_split raises if they do not.
# The fixed seed makes the split reproducible given the same dataset ordering.
SPLIT_SEED = 42
train_size = int(0.8 * len(full_dataset))
val_size = int(0.1 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print(len(train_dataset), len(val_dataset), len(test_dataset))

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

# Initialize models
generator = TextGuidedGenerator(device=device).to(device)
discriminator = Discriminator().to(device)

# Set to True to train from scratch. When False, no training runs here and the
# inference section loads a checkpoint instead.
RUN_TRAINING = False
if RUN_TRAINING:
    start_time = time.time()
    train_metrics = train_text_guided_gan(
        generator,
        discriminator,
        train_loader,
        val_loader,
        num_epochs=100,
        device=device
    )
    end_time = (time.time() - start_time)/60
    print(f"Training completed in {end_time:.2f} minutes")

# Inference helper functions

In [None]:
def load_model(generator, discriminator, checkpoint_path, device):
    """Restore generator and discriminator weights from a checkpoint file.

    The checkpoint must be a dict holding ``'generator'`` and
    ``'discriminator'`` state dicts (the layout written by
    ``train_text_guided_gan``). Tensors are mapped onto ``device``.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on torch versions that support it).
    state = torch.load(checkpoint_path, map_location=device)
    for key, module in (('generator', generator), ('discriminator', discriminator)):
        module.load_state_dict(state[key])
    print(f"Model loaded from checkpoint: {checkpoint_path}")

def process_single_image(image_path, generator, text_prompt, transform, device):
    """Run ``generator`` on a single image file with a text prompt.

    Args:
        image_path: path of the input image (converted to RGB).
        generator: model called as ``generator(batch, text_prompt)``; switched
            to eval mode here.
        text_prompt: prompt string.
        transform: callable turning the PIL image into a (C, H, W) tensor.
        device: device used for the forward pass.

    Returns:
        ``(original_image, generated_image)`` as PIL images. The generator
        output is mapped with (x + 1) / 2 and clamped to [0, 1] first.
    """
    pil_input = Image.open(image_path).convert('RGB')
    untouched = pil_input.copy()

    # Add a leading batch dimension of size 1.
    batch = transform(pil_input).unsqueeze(0).to(device)

    generator.eval()
    with torch.no_grad():
        output = generator(batch, text_prompt)

    # Undo the [-1, 1] normalisation and convert back to PIL.
    output = output.squeeze(0).cpu()
    rescaled = torch.clamp((output + 1) / 2, 0, 1)
    return untouched, ToPILImage()(rescaled)

def visualize_images(original_image, generated_image, side_by_side=False):
    """Display the generated image, optionally next to the input image.

    Args:
        original_image: input image (anything ``plt.imshow`` accepts); shown
            only when ``side_by_side`` is True.
        generated_image: generated image to display.
        side_by_side: if True, show input and output in a 1x2 grid titled
            "Input" / "Output". The default (False) shows only the generated image.
    """
    if side_by_side:
        _, axes = plt.subplots(1, 2, figsize=(10, 5))
        panels = ((original_image, "Input"), (generated_image, "Output"))
        for ax, (img, title) in zip(axes, panels):
            ax.imshow(img)
            ax.axis("off")
            ax.set_title(title)
    else:
        plt.imshow(generated_image)
        plt.axis('off')

    plt.show()

# Inference

In [None]:
# path of input image
image_path = '/kaggle/input/obamaa/obama.jpg'

# Alternative: the checkpoint written by the training cell above.
#checkpoint_path = '/kaggle/working/models/final_model.pth'
checkpoint_path = '/kaggle/input/e2gan-v6/pytorch/default/2/e2gan-v6-2'
text_prompt = "make the hair pink"

# Prompts used in training:
# ["convert to albino style", "make the hair blonde", "apply van gogh style", "make the person old", "make the hair pink"]

# NOTE(review): unlike the training transform, this one does not resize or
# centre-crop, so the image is processed at its native resolution. The output
# size may differ from the input unless height and width are multiples of 16.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = TextGuidedGenerator(device=device).to(device)
discriminator = Discriminator().to(device)

load_model(generator, discriminator, checkpoint_path, device)

# Wall-clock latency of one call: image loading, preprocessing, forward pass and
# conversion back to PIL. The first call may include CUDA warm-up.
# perf_counter is monotonic and high-resolution, unlike time.time.
start_time = time.perf_counter()
original_image, generated_image = process_single_image(image_path, generator, text_prompt, transform, device)
end_time = time.perf_counter() - start_time

print(f"Inference time: {end_time * 1000:.1f} ms")

visualize_images(original_image, generated_image)

# Compute FID on the test set

In [None]:
def generate_image(original_tensor, generator, text_prompt, device):
    """Generate an edited image from a preprocessed input tensor and a prompt.

    Args:
        original_tensor: one image tensor without a batch dimension,
            preprocessed like the dataset transform.
        generator: model called as ``generator(batch, text_prompt)``; switched
            to eval mode here.
        text_prompt: prompt string.
        device: device used for the forward pass.

    Returns:
        PIL image of the generator output, mapped with (x + 1) / 2 and clamped
        to [0, 1].
    """
    batch = original_tensor.unsqueeze(0).to(device)  # add batch dimension
    generator.eval()
    with torch.no_grad():
        generated_tensor = generator(batch, text_prompt)
    # Denormalise, then pass the (C, H, W) tensor straight to to_pil_image, as
    # is done for the reference images. A float HWC NumPy round-trip is avoided
    # because float-array handling varies across torchvision versions.
    created = torch.clamp((generated_tensor.squeeze(0).cpu() + 1) / 2, 0, 1)
    return to_pil_image(created)

# Prompts and the matching output folder names (same order).
prompts = ["convert to albino style", "make the hair blonde",
           "apply van gogh style", "make the person old",
           "make the hair pink"]

styles = ['albino', 'blonde', 'gogh', 'old', 'pink']

# Maximum number of test triples per style used for FID.
MAX_PER_STYLE = 30
WORK_DIR = '/kaggle/working'

# Bucket the first MAX_PER_STYLE test triples of each prompt.
prompt_to_idx = {p: i for i, p in enumerate(prompts)}
test_lists = [[] for _ in prompts]
for source_image, target_image, prompt in test_dataset:
    idx = prompt_to_idx.get(prompt)
    if idx is not None and len(test_lists[idx]) < MAX_PER_STYLE:
        test_lists[idx].append((source_image, target_image, prompt))
    # Each item loads two images from disk; stop as soon as every bucket is full.
    if all(len(bucket) >= MAX_PER_STYLE for bucket in test_lists):
        break
test0, test1, test2, test3, test4 = test_lists

print(len(test0),len(test1),len(test2),len(test3),len(test4))

# Generate one image per test triple (original, modified, prompt), then zip
# each style folder for download.
start_index = 1001
for style_folder, prompt, bucket in zip(styles, prompts, test_lists):
    out_dir = f'{WORK_DIR}/{style_folder}'
    # compute_fid reads every image in the folder, so leftovers from an earlier
    # run would contaminate the score.
    shutil.rmtree(out_dir, ignore_errors=True)
    os.makedirs(out_dir, exist_ok=True)

    for i, (original_tensor, _, _) in enumerate(bucket):
        generated_image = generate_image(original_tensor, generator, prompt, device)
        generated_image.save(f'{out_dir}/{start_index + i}.png')

    zip_file_path = f'{WORK_DIR}/{style_folder}.zip'
    with zipfile.ZipFile(zip_file_path, 'w') as zipf:
        for root, _, files in os.walk(out_dir):
            for file in files:
                zipf.write(os.path.join(root, file), arcname=os.path.join(style_folder, file))

# NOTE(review): FID estimates a feature covariance (2048-dim with clean-fid's
# default Inception features) from at most MAX_PER_STYLE images per set, so
# these scores are very noisy; compare them only relative to each other.
fid_scores = []

# Compute FID between the generated images and the diffusion-edited targets.
for style_folder, bucket in zip(styles, test_lists):
    diff_dir = f'{WORK_DIR}/diff_{style_folder}'
    shutil.rmtree(diff_dir, ignore_errors=True)
    os.makedirs(diff_dir, exist_ok=True)

    for i, (_, modified_tensor, _) in enumerate(bucket):
        # Save the reference (target) image, denormalised to [0, 1].
        diff_image_path = os.path.join(diff_dir, f"diff_{style_folder}_{i}.png")
        diff_image = to_pil_image(torch.clamp((modified_tensor + 1) / 2, 0, 1))
        diff_image.save(diff_image_path)

    # FID for this style
    score = fid.compute_fid(diff_dir, f'{WORK_DIR}/{style_folder}')
    fid_scores.append(score)
    print(f"FID for {style_folder}: {score}")

    # Remove the temporary reference folder
    shutil.rmtree(diff_dir)

# Aggregate across styles with the median.
median_fid = np.median(fid_scores)
print(f"The overall FID score is {median_fid}")
print(fid_scores)