In [None]:
!pip install peft

In [None]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import time
import copy
import shutil

import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.resnet import BasicBlock

from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import AdamW

from torch.nn.utils import spectral_norm

from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor, CLIPModel
from sklearn.cluster import KMeans
from torchvision.transforms import ToPILImage
from PIL import Image

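## K-means
Used to reduce each prompt's dataset to its most representative images: every image is embedded with CLIP, the embeddings are clustered with K-means, and the image closest to each cluster centroid is kept.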

In [None]:
class ImageClusterSelector:
    """Keep only the most representative images of a folder.

    Images in ``input_dir`` are embedded with CLIP's image encoder, clustered
    with K-means, and for every cluster the image closest to its centroid is
    copied to ``output_dir``.
    """

    def __init__(self, n_clusters, input_dir, output_dir, device=None):
        """
        Initialize the image cluster selector.

        Args:
            n_clusters (int): Number of clusters for K-means; an upper bound on
                the number of images that will be selected
            input_dir (str): Directory containing input images
            output_dir (str): Directory to save selected images
            device (str, optional): Device to use for computation; defaults to
                "cuda" when available, otherwise "cpu"
        """
        self.n_clusters = n_clusters
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize CLIP model
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # Image preprocessing (CLIP's normalization statistics)
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711)
            )
        ])

    def extract_embedding(self, image_path):
        """Return the flattened CLIP image embedding of ``image_path``.

        Returns None (after printing a warning) when the file cannot be opened
        or decoded, so one bad file does not abort the whole folder. Errors
        raised by the CLIP model itself are not swallowed.
        """
        try:
            with Image.open(image_path) as img:
                pixels = self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
        except Exception as err:  # best-effort: skip unreadable / corrupt files
            print(f"Warning: skipping {image_path}: {err}")
            return None
        with torch.no_grad():
            embedding = self.clip_model.get_image_features(pixel_values=pixels)
        return embedding.cpu().numpy().flatten()

    def get_image_paths(self):
        """Return the sorted paths of files in ``input_dir`` with an image extension.

        Sorting makes the embedding order, and therefore the seeded K-means
        result, independent of the file-system listing order.

        Raises:
            ValueError: If no such file exists.
        """
        valid_extensions = ('.png', '.jpg', '.jpeg', '.webp')
        image_paths = [
            os.path.join(self.input_dir, name)
            for name in sorted(os.listdir(self.input_dir))
            if name.lower().endswith(valid_extensions)
        ]
        if not image_paths:
            raise ValueError(f"No valid images found in {self.input_dir}")
        return image_paths

    def process_images(self):
        """Embed, cluster and copy the representative images.

        Returns:
            list[str]: Paths (inside ``input_dir``) of the selected images, at
            most one per cluster.

        Raises:
            ValueError: If ``input_dir`` has no image files or none of them
                could be read.
        """
        image_paths = self.get_image_paths()

        # Extract embeddings, remembering which paths produced one
        embeddings_list = []
        valid_paths = []
        for img_path in tqdm(image_paths, desc="Processing images"):
            embedding = self.extract_embedding(img_path)
            if embedding is not None:
                embeddings_list.append(embedding)
                valid_paths.append(img_path)

        if not valid_paths:
            raise ValueError(f"No readable images found in {self.input_dir}")
        embeddings = np.array(embeddings_list)

        # K-means needs at least as many samples as clusters
        self.n_clusters = min(self.n_clusters, len(valid_paths))

        kmeans = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
        kmeans.fit(embeddings)

        # For each cluster keep the sample nearest to its centroid
        selected_images = []
        for cluster_id in range(self.n_clusters):
            cluster_indices = np.where(kmeans.labels_ == cluster_id)[0]
            if cluster_indices.size == 0:
                # Can happen with many duplicate embeddings; nothing to select.
                continue
            distances = np.linalg.norm(
                embeddings[cluster_indices] - kmeans.cluster_centers_[cluster_id],
                axis=1)
            closest_index = cluster_indices[np.argmin(distances)]
            selected_images.append(valid_paths[closest_index])

        # Copy selected images
        os.makedirs(self.output_dir, exist_ok=True)
        for img_path in selected_images:
            output_path = os.path.join(self.output_dir, os.path.basename(img_path))
            shutil.copy2(img_path, output_path)
        print(f"Selected {len(selected_images)} images\n")
        return selected_images

def process_all_folders_in_directory(parent_dir, n_clusters, output_dir,
                                     exclude=("original_images_ordered",)):
    """
    Run ImageClusterSelector on every subfolder of a parent directory.

    For each subfolder ``parent_dir/<name>`` not listed in ``exclude``, the
    selected images are copied into ``output_dir/<name>``. Subfolders are
    processed in sorted order.

    Args:
        parent_dir (str): Parent directory containing folders to process
        n_clusters (int): Number of clusters for image selection
        output_dir (str): Directory under which one output folder per
            processed subfolder is created
        exclude (Iterable[str]): Subfolder names to skip. The default skips
            ``original_images_ordered``, the folder of unedited source images.
            A missing excluded folder is not an error.
    """
    excluded = set(exclude)
    folders = sorted(
        name for name in os.listdir(parent_dir)
        if os.path.isdir(os.path.join(parent_dir, name)) and name not in excluded
    )

    for folder in folders:
        input_dir = os.path.join(parent_dir, folder)
        folder_output_dir = os.path.join(output_dir, folder)
        os.makedirs(folder_output_dir, exist_ok=True)

        print(f"Processing folder: {folder}")
        selector = ImageClusterSelector(
            n_clusters=n_clusters,
            input_dir=input_dir,
            output_dir=folder_output_dir
        )
        selector.process_images()

# Main script: select representative images from every subfolder
if __name__ == "__main__":
    parent_directory = "/kaggle/input/multiprompt"   # Parent directory containing subfolders
    output_directory = "/kaggle/working/images/"     # Selected images go to <output>/<subfolder>/
    n_clusters = 400                                 # At most this many images kept per subfolder

    os.makedirs(output_directory, exist_ok=True)
    process_all_folders_in_directory(
        parent_dir=parent_directory,
        n_clusters=n_clusters,
        output_dir=output_directory,
    )

## LoRA
Implementing LoRA with the `peft` library, following the paper's guidelines: adapters are added to the convolutional layers of the downsampling, transformer and upsampling blocks.

In [None]:
class GaussianNoise(nn.Module):
    """Additive zero-mean Gaussian noise, applied only in training mode.

    Args:
        sigma (float): Standard deviation of the noise.
    """

    def __init__(self, sigma=0.1):
        super().__init__()
        self.sigma = sigma

    def forward(self, x):
        # Identity in eval mode so inference is deterministic.
        if not self.training:
            return x
        return x + self.sigma * torch.randn_like(x)

class Discriminator(nn.Module):
    """Convolutional discriminator returning a 1-channel score map.

    Four stride-2 4x4 convolutions (64 -> 128 -> 256 -> 512 channels, with
    InstanceNorm on all but the first) each followed by LeakyReLU(0.2), then
    a stride-1 4x4 convolution to a single channel. No pooling is applied, so
    the output keeps a spatial grid of scores.

    Args:
        input_channels (int): Number of channels of the input images.
    """

    def __init__(self, input_channels=3):
        super().__init__()
        widths = [input_channels, 64, 128, 256, 512]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if stage > 0:  # the first stage is not normalized
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Conv2d(widths[-1], 1, 4, stride=1, padding=1))
        # Same module order/indices as a hand-written Sequential, so
        # checkpoints keyed "model.<idx>.*" load unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class TextGuidedGenerator(nn.Module):
    """Image-to-image generator conditioned on a text prompt.

    Pipeline (spatial sizes shown for a 224x224 input):
      1. frozen ResNet-50 conv1/bn1/relu/maxpool + layer1 + layer2 -> (B, 512, 28, 28)
      2. concat the CLIP text embedding broadcast over all positions
                                                   -> (B, 512 + text_embedding_dim, 28, 28)
      3. stride-2 conv                             -> (B, 256, 14, 14)
      4. nn.Transformer over the 14*14 tokens (same sequence as src and tgt)
      5. transposed-conv upsampling + ResNet BasicBlock -> (B, 128, 28, 28)
      6. three transposed-conv upsamplings, 7x7 conv, Tanh
                                                   -> (B, output_channels, 224, 224)

    NOTE(review): ``downsample_resnet`` parameters are frozen, but its
    BatchNorm layers still update their running statistics whenever the
    generator is in ``train()`` mode; confirm that is intended.
    """

    def __init__(self, input_channels=3, output_channels=3, text_embedding_dim=512, dropout_rate=0.1, device=None):
        """
        Args:
            input_channels (int): Currently unused (the ResNet stem expects 3);
                kept for interface compatibility.
            output_channels (int): Channels of the generated image.
            text_embedding_dim (int): Size of the text embedding concatenated to
                the image features; must match the CLIP text encoder's hidden size.
            dropout_rate (float): Currently unused (the transformer uses 0.1);
                kept for interface compatibility.
            device (str or torch.device, optional): Device for the CLIP text
                encoder. Defaults to "cuda" when available, else "cpu".
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # CLIP text encoder (frozen; only used to embed prompts)
        self.clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model.requires_grad_(False)

        # 1. Initial downsampling + first two ResNet-50 stages (frozen)
        resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.downsample_resnet = nn.Sequential(
            resnet.conv1,      # First downsampling
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,    # downsampling
            resnet.layer1,     # First ResNet block
            resnet.layer2      # Second ResNet block
        )
        self.downsample_resnet.requires_grad_(False)

        # 3. DS before transformer (image features + broadcast text embedding)
        self.pre_transformer_ds = nn.Conv2d(512 + text_embedding_dim, 256, 3, stride=2, padding=1)

        # 4. Transformer block
        self.transformer = nn.Transformer(
            d_model=256,
            nhead=8,
            num_encoder_layers=3,
            num_decoder_layers=3,
            dim_feedforward=1024,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )

        # 5. US after transformer
        self.post_transformer_us = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2)
        )

        # 6. ResNet block post-transformer
        self.post_transformer_resnet = BasicBlock(128, 128)

        # 7. Final upsampling with three upsampling steps
        self.final_upsample = nn.Sequential(
            # First upsampling
            nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # Second upsampling
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            # Third upsampling
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            # Final conv, output in [-1, 1]
            nn.Conv2d(32, output_channels, 7, padding=3),
            nn.Tanh()
        )

    def encode_text(self, text):
        """Embed prompts as the mean of CLIP's last hidden state over real tokens.

        Args:
            text (list[str]): One prompt per sample.

        Returns:
            Tensor of shape (len(text), hidden_size) on the CLIP model's device.
        """
        device = self.clip_model.device
        tokens = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(device)
        hidden = self.clip_model(**tokens).last_hidden_state
        # Average only over non-padding tokens. A plain mean would include the
        # padding added for shorter prompts, so a prompt's embedding would change
        # depending on the other prompts in its batch.
        mask = tokens["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    def forward(self, x, text_prompt):
        """
        Args:
            x: Image batch of shape (B, 3, H, W).
            text_prompt: A single prompt (str, applied to every sample) or a
                list/tuple with one prompt per sample.

        Returns:
            Generated images in [-1, 1]; the same size as ``x`` for a 224x224 input.
        """
        batch_size = x.size(0)

        # Text embedding
        if isinstance(text_prompt, (list, tuple)):
            prompts = list(text_prompt)
        else:
            prompts = [text_prompt] * batch_size
        text_embedding = self.encode_text(prompts)

        x = self.downsample_resnet(x)  # (B, 512, 28, 28) for 224x224 input

        # 3. Concatenate the text embedding, broadcast over every spatial position
        text_map = text_embedding[:, :, None, None].expand(-1, -1, x.size(2), x.size(3))
        x = torch.cat([x, text_map], dim=1)

        x = self.pre_transformer_ds(x)  # 28 -> 14 for 224x224 input

        # 4. Transformer over spatial tokens: (B, C, H, W) -> (B, H*W, C)
        b, c, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)
        # The same token sequence is used as encoder input (src) and decoder input (tgt).
        x = self.transformer(x, x)
        x = x.transpose(1, 2).reshape(b, c, h, w)

        # 5. Post-transformer US (14 -> 28)
        x = self.post_transformer_us(x)

        # 6. Post-transformer ResNet block
        x = self.post_transformer_resnet(x)

        # 7. Final upsampling (28 -> 224)
        return self.final_upsample(x)

class TextGuidedImageDataset(Dataset):
    """Paired (source, target, prompt) samples for a single editing prompt.

    Pairs are formed by file name: a file is used only if a regular file
    with the same name exists in both ``source_dir`` and ``target_dir``.
    """

    def __init__(self, source_dir, target_dir, text_prompt, transform=None):
        """
        Args:
            source_dir (str): Directory of input images.
            target_dir (str): Directory of edited images; names must match ``source_dir``.
            text_prompt (str): Prompt returned with every sample.
            transform (callable, optional): Applied to both source and target images.

        Raises:
            ValueError: If the two directories share no file names.
        """
        self.source_dir = source_dir
        self.target_dir = target_dir
        self.text_prompt = text_prompt
        self.transform = transform
        common = set(os.listdir(source_dir)) & set(os.listdir(target_dir))
        # Sorted so sample order (and hence random_split with a fixed seed) is
        # reproducible; set iteration order of strings varies between runs.
        self.images = sorted(
            name for name in common
            if os.path.isfile(os.path.join(source_dir, name))
            and os.path.isfile(os.path.join(target_dir, name))
        )
        if len(self.images) == 0:
            raise ValueError("No matching images found.")
        print(f"Found {len(self.images)} matching images for prompt: {text_prompt}")

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _load_rgb(path):
        """Open an image file, convert it to RGB and close the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        img_name = self.images[idx]
        source_image = self._load_rgb(os.path.join(self.source_dir, img_name))
        target_image = self._load_rgb(os.path.join(self.target_dir, img_name))
        if self.transform:
            source_image = self.transform(source_image)
            target_image = self.transform(target_image)
        # Ensure text prompt is passed as string
        return source_image, target_image, str(self.text_prompt)

def train_text_guided_gan(generator, discriminator, train_loader, val_loader, num_epochs, device, save_dir="models"):
    """Train a text-guided generator adversarially with an L1 reconstruction term.

    For each batch the discriminator is updated with a least-squares (MSE) GAN
    loss. The generator is then updated with ``GAN loss + 20 * L1(fake, target)``.
    After each epoch the mean L1 loss on ``val_loader`` is reported and both
    cosine learning-rate schedules are stepped.

    Args:
        generator (nn.Module): Called as ``generator(source, prompts)``. If it
            provides ``save_pretrained`` (e.g. a PEFT model), that is used to
            save its adapter files to ``save_dir``.
        discriminator (nn.Module): Maps an image batch to a score map.
        train_loader, val_loader (DataLoader): Yield ``(source, target, prompts)``.
        num_epochs (int): Number of epochs.
        device: Device the image tensors are moved to.
        save_dir (str): Output directory for adapters and ``LoRA_V6.pth``.

    Returns:
        nn.Module: The trained generator (the same object passed in).
    """
    os.makedirs(save_dir, exist_ok=True)
    criterion_gan = nn.MSELoss()   # least-squares GAN objective
    criterion_pixel = nn.L1Loss()
    pixel_weight = 20              # weight of L1 relative to the adversarial term

    optimizer_g = torch.optim.AdamW(generator.parameters(), lr=1e-5, betas=(0.5, 0.999), weight_decay=1e-4)
    optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=1e-5, betas=(0.5, 0.999), weight_decay=1e-4)
    scheduler_g = CosineAnnealingLR(optimizer_g, T_max=num_epochs, eta_min=1e-6)
    scheduler_d = CosineAnnealingLR(optimizer_d, T_max=num_epochs, eta_min=1e-6)
    train_metrics = {'g_loss': [], 'd_loss': [], 'val_loss': []}

    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()
        epoch_g_losses = []
        epoch_d_losses = []
        for source, target, text_prompt in tqdm(train_loader):
            real = target.to(device)
            source = source.to(device)
            # Ensure text_prompt is a list of strings
            if isinstance(text_prompt, torch.Tensor):
                text_prompt = text_prompt.tolist()

            # --- Discriminator step (generator output detached) ---
            optimizer_d.zero_grad()
            fake = generator(source, text_prompt)
            pred_real = discriminator(real)
            pred_fake = discriminator(fake.detach())
            # One-sided label smoothing: real targets are 0.9, fake targets stay 0.
            real_labels = torch.ones_like(pred_real) * 0.9
            fake_labels = torch.zeros_like(pred_fake)
            loss_d_real = criterion_gan(pred_real, real_labels)
            loss_d_fake = criterion_gan(pred_fake, fake_labels)
            loss_d = (loss_d_real + loss_d_fake) * 0.5
            loss_d.backward()
            optimizer_d.step()

            # --- Generator step: fool the updated discriminator, match the target ---
            optimizer_g.zero_grad()
            pred_fake = discriminator(fake)
            loss_g_gan = criterion_gan(pred_fake, torch.ones_like(pred_fake))
            loss_g_pixel = criterion_pixel(fake, real) * pixel_weight
            loss_g = loss_g_gan + loss_g_pixel
            loss_g.backward()
            optimizer_g.step()

            epoch_g_losses.append(loss_g.item())
            epoch_d_losses.append(loss_d.item())

        train_metrics['g_loss'].extend(epoch_g_losses)
        train_metrics['d_loss'].extend(epoch_d_losses)

        # Validation (pixel L1 only)
        generator.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for source, target, text_prompt in val_loader:
                fake = generator(source.to(device), text_prompt)
                total_val_loss += criterion_pixel(fake, target.to(device)).item()

        # Guard against an empty validation split (division by zero).
        avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else float('nan')
        train_metrics['val_loss'].append(avg_val_loss)

        print(f"Epoch {epoch+1}/{num_epochs}: G_loss={np.mean(epoch_g_losses):.4f}, "
              f"D_loss={np.mean(epoch_d_losses):.4f}, Val_loss={avg_val_loss:.4f}")

        scheduler_g.step()
        scheduler_d.step()

    # Save adapter files when the generator supports it (plain nn.Modules do
    # not), then always save the full checkpoint.
    if hasattr(generator, "save_pretrained"):
        generator.save_pretrained(save_dir)
    torch.save({
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'epoch': num_epochs-1,
        'train_metrics': train_metrics
    }, os.path.join(save_dir, "LoRA_V6.pth"))

    return generator

### Add LoRA
Find the layers to apply LoRA to, and assign each group of layers its LoRA rank and alpha.

In [None]:
# Load the pre-trained generator; here it is used to enumerate module names
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = "/kaggle/input/v6/pytorch/default/1/V6.pth"
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
base_generator = TextGuidedGenerator(device=device)
base_generator.load_state_dict(checkpoint['generator'])

# Names of every nn.Conv2d in the generator (ConvTranspose2d is a different
# class and is not included). The CLIP text encoder contains no Conv2d.
# Kept under the historical name `linear_layers`, which later cells use.
linear_layers = [n for n, m in base_generator.named_modules() if isinstance(m, nn.Conv2d)]
rank_numbers = []

# Per-layer LoRA rank (layer_dict) and alpha (alpha_dict)
layer_dict = {}
alpha_dict = {}

n_down1 = 0
n_down2 = 0
n_trans = 0
n_up = 0
for layer in linear_layers:
    if layer.startswith("downsample_resnet.5."):
        # Second downsampling: ResNet layer2
        layer_dict[layer] = 4
        alpha_dict[layer] = 8
        n_down2 += 1
    elif layer.startswith("downsample_resnet."):
        # First downsampling: ResNet conv1 (index 0) and layer1 (index 4)
        layer_dict[layer] = 8
        alpha_dict[layer] = 16
        n_down1 += 1
    elif "transformer" in layer:
        # Convs around the transformer (pre_transformer_ds, post_transformer_resnet)
        layer_dict[layer] = 4
        alpha_dict[layer] = 8
        n_trans += 1
    else:
        # Upsampling path
        layer_dict[layer] = 8
        alpha_dict[layer] = 16
        n_up += 1

print(f"Layers first downsampling {n_down1}")
print(f"Layers second downsampling {n_down2}")
print(f"Layers into transformer {n_trans}")
print(f"Layers upsampling {n_up}") # not distinct upsampling since they have the same LoRA rank

### Training the LoRA Model
Apply LoRA to the selected layers and train the new model.

In [None]:
def count_changed_weights(model_before, model_after):
    """Count how many scalar parameter values differ between two models.

    Args:
        model_before (nn.Module): Reference model, e.g. a deep copy taken
            before training.
        model_after (nn.Module): Model to compare; must have the same parameter
            names, order and shapes.

    Returns:
        tuple[int, int]: (number of changed scalars, total number of scalars).

    Raises:
        ValueError: If parameter counts, names or shapes do not match.
    """
    params_before = list(model_before.named_parameters())
    params_after = list(model_after.named_parameters())
    # zip() would silently drop extra parameters, so check the counts first.
    if len(params_before) != len(params_after):
        raise ValueError(
            f"Parameter count mismatch: {len(params_before)} vs {len(params_after)}")

    changed_count = 0
    total_count = 0
    with torch.no_grad():
        for (name1, param1), (name2, param2) in zip(params_before, params_after):
            if name1 != name2 or param1.shape != param2.shape:
                raise ValueError(f"Parameter mismatch: {name1} vs {name2}")
            changed_count += torch.ne(param1, param2.to(param1.device)).sum().item()
            total_count += param1.numel()

    return changed_count, total_count


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # pixels -> [-1, 1]
])

# Define source directory and text-guided target directories with prompts
source_dir = '/kaggle/input/multiprompt/original_images_ordered'
style_configs = [
    ('/kaggle/working/images/sculpture', "transform it into a sculpture"),
]

# Create one paired dataset per (target folder, prompt)
datasets = []
for target_dir, prompt in style_configs:
    dataset = TextGuidedImageDataset(source_dir, target_dir, prompt, transform)
    datasets.append(dataset)

full_dataset = ConcatDataset(datasets)
print(f"Combined dataset contains {len(full_dataset)} images\n")

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

# Load the pre-trained V6 checkpoint (generator + discriminator)
checkpoint_path = "/kaggle/input/v6/pytorch/default/1/V6.pth"
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
discriminator = Discriminator()
discriminator.load_state_dict(checkpoint['discriminator'])
discriminator = discriminator.to(device)
print(f"Loaded pre-trained discriminator model from {checkpoint_path}")

# LoRA on the Conv2d layers selected above, with per-layer rank/alpha
config = LoraConfig(
    target_modules=linear_layers,
    init_lora_weights="gaussian",
    lora_dropout=0.3,
    rank_pattern=layer_dict,
    alpha_pattern=alpha_dict,
)

# Build the generator, restore the pre-trained V6 weights, then add LoRA.
# Without the load_state_dict call, LoRA would be trained on top of a freshly
# initialized generator instead of the pre-trained one.
base_generator = TextGuidedGenerator(device=device).to(device)
base_generator.load_state_dict(checkpoint['generator'])
lora_generator = get_peft_model(base_generator, config)
print(f"Loaded pre-trained generator model from {checkpoint_path}")
lora_generator = lora_generator.to(device)
lora_generator.print_trainable_parameters()

# Copy the model before training to count how many weights change
old_model = copy.deepcopy(lora_generator)

# Train the model
print("\nStarting training...")
start_time = time.time()
lora_generator = train_text_guided_gan(
    lora_generator,
    discriminator,
    train_loader,
    val_loader,
    num_epochs=100,
    device=device,
    save_dir="/kaggle/working/model"
)
end_time = (time.time() - start_time)/60
print(f"Training completed in {end_time:.2f} minutes")

# Compare the model weights after training
changed_weights, total_weights = count_changed_weights(old_model, lora_generator)
print(f"Number of weights changed: {changed_weights} out of {total_weights}")
print(f"Percentage of weights changed: {(changed_weights / total_weights) * 100:.2f}%")

### Test Model
Test the fine-tuned model.

In [None]:
def load_model(generator, discriminator, checkpoint_path, device, adapter_dir=None):
    """Restore a LoRA fine-tuned generator and its discriminator from disk.

    Args:
        generator (nn.Module): Freshly built base generator with the
            architecture used during training. PEFT wraps it.
        discriminator (nn.Module): Its weights are restored in place from
            ``checkpoint['discriminator']``.
        checkpoint_path (str): ``.pth`` file holding 'generator' and
            'discriminator' state dicts.
        device: Device to map the checkpoint to and move the model to.
        adapter_dir (str, optional): Directory written by ``save_pretrained``.
            Defaults to the directory containing ``checkpoint_path``.

    Returns:
        PeftModel: The generator with its adapters loaded, in eval mode.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    discriminator.load_state_dict(checkpoint['discriminator'])

    if adapter_dir is None:
        adapter_dir = os.path.dirname(checkpoint_path)
    # Rebuild the PEFT wrapper around the base generator
    peft_model = PeftModel.from_pretrained(generator, adapter_dir)
    # The checkpoint stores the full state dict of the trained PEFT model,
    # including the base weights the adapters were trained on; restore them too.
    peft_model.load_state_dict(checkpoint['generator'])
    peft_model = peft_model.to(device)
    peft_model.eval()

    print(f"Model loaded from: {checkpoint_path}")
    return peft_model


def process_single_image(image_path, lora_generator, text_prompt, transform, device):
    """Run one image through the generator with a text prompt.

    Args:
        image_path (str): Path of the input image (converted to RGB).
        lora_generator (nn.Module): Called as ``lora_generator(batch, text_prompt)``;
            its output is assumed to be in [-1, 1].
        text_prompt (str): Edit instruction.
        transform (callable): PIL image -> tensor preprocessing (should match training).
        device: Device for the input tensor.

    Returns:
        PIL.Image.Image: The generated image.
    """
    with Image.open(image_path) as img:
        image_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)

    lora_generator.eval()
    with torch.no_grad():
        lora_image_tensor = lora_generator(image_tensor, text_prompt)

    # Denormalize from [-1, 1] to [0, 1] before converting to PIL
    lora_image_tensor = lora_image_tensor.squeeze(0).cpu()
    return ToPILImage()(torch.clamp((lora_image_tensor + 1) / 2, 0, 1))


def visualize_images(original_image, lora_image, prompt):
    """Show the input and generated images side by side, titled with the prompt."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = ((original_image, "Input"), (lora_image, "Lora Model"))
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(title)

    # Figure-level title: plt.title() would replace the "Lora Model" title
    # of the current (last) axes.
    fig.suptitle(prompt)
    plt.show()


# --- Inference: apply the fine-tuned LoRA generator to one image with several prompts ---
image_path = '/kaggle/input/obamaset/obama.jpg'
image = Image.open(image_path).convert('RGB')
original_image = image.copy()  # untouched copy shown next to each result

# Same preprocessing as training: pixels mapped to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
discriminator = Discriminator().to(device)

# Restore the fine-tuned generator (and discriminator) from the saved checkpoint
checkpoint_path = "/kaggle/working/model/LoRA_V6.pth"
test_gen = TextGuidedGenerator().to(device)
lora_generator = load_model(test_gen, discriminator, checkpoint_path=checkpoint_path, device=device)

# Generate images
lora_generator.eval()

prompt_list = ['Transform it into a sculpture', 'Blonde hair', 'Van gogh style']
for prompt in prompt_list:
    lora_image = process_single_image(
        image_path, lora_generator, text_prompt=prompt, transform=transform, device=device
    )
    visualize_images(original_image, lora_image, prompt)

### Our LoRA Implementation
Show that the paper's idea also works when adapted to our model, by implementing a custom multi-head attention module with separate query/key/value/output projections that LoRA can target.

In [None]:
class LoRACompatibleMultiheadAttention(nn.Module):
    """Multi-head attention built from four separate ``nn.Linear`` projections.

    ``nn.MultiheadAttention`` packs Q/K/V into a single ``in_proj_weight``, so
    it has no separate q/k/v ``nn.Linear`` modules for LoRA to target. Here
    ``q_proj``, ``k_proj``, ``v_proj`` and ``out_proj`` are plain linear
    layers called through their ``forward``.

    Inputs are batch-first: ``query`` is (batch, tgt_len, embed_dim) and
    ``key``/``value`` are (batch, src_len, embed_dim).
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.scaling = self.head_dim ** -0.5
        # Report injected LoRA layers once, not on every forward call.
        self._lora_reported = False

    def _split_heads(self, x, proj):
        """Project ``x`` and reshape (B, L, E) -> (B, num_heads, L, head_dim)."""
        batch_size, length, _ = x.size()
        return proj(x).view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, attn_mask=None):
        """
        Args:
            query: (batch, tgt_len, embed_dim).
            key, value: (batch, src_len, embed_dim); src_len may differ from tgt_len.
            attn_mask (Tensor, optional): Broadcastable to
                (batch, num_heads, tgt_len, src_len). Float masks are added to the
                attention logits. Boolean masks mark positions to ignore with True.

        Returns:
            Tensor of shape (batch, tgt_len, embed_dim).
        """
        batch_size, tgt_len, embed_dim = query.size()

        q = self._split_heads(query, self.q_proj)
        k = self._split_heads(key, self.k_proj)
        v = self._split_heads(value, self.v_proj)

        # Check (once) whether LoRA has been injected into the projections
        if not self._lora_reported and hasattr(self.q_proj, "lora_A"):
            print("LoRA weights are active")
            self._lora_reported = True

        # Scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
            else:
                attn_weights = attn_weights + attn_mask
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Merge heads back: (B, H, tgt_len, head_dim) -> (B, tgt_len, E)
        attn_output = torch.matmul(attn_weights, v).transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, tgt_len, embed_dim)

        # Final projection
        return self.out_proj(attn_output)


class CustomTransformer(nn.Module):
    """Stack of post-norm Transformer encoder layers using LoRACompatibleMultiheadAttention.

    Input and output are batch-first: (batch, seq_len, embed_dim).
    """

    def __init__(self, embed_dim, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dropout=dropout,
                activation='gelu',
                dim_feedforward=1024
            )
            for _ in range(num_layers)
        ])

        # Replace MultiheadAttention in TransformerEncoderLayer with custom attention
        for layer in self.layers:
            layer.self_attn = LoRACompatibleMultiheadAttention(embed_dim, num_heads, dropout)

    @staticmethod
    def _encoder_layer_forward(layer, x):
        """Apply one post-norm (norm_first=False) encoder layer via its sub-modules.

        ``nn.TransformerEncoderLayer.forward`` calls ``self_attn`` with
        ``nn.MultiheadAttention``-specific keyword arguments and indexes a tuple
        result. The custom attention does not support that call, so the layer's
        sub-modules are applied explicitly here.
        """
        attn = layer.self_attn(x, x, x)
        x = layer.norm1(x + layer.dropout1(attn))
        ff = layer.linear2(layer.dropout(layer.activation(layer.linear1(x))))
        return layer.norm2(x + layer.dropout2(ff))

    def forward(self, x):
        for layer in self.layers:
            x = self._encoder_layer_forward(layer, x)
        return x


class TextGuidedGeneratorWithLoRA(TextGuidedGenerator):
    """TextGuidedGenerator whose ``transformer`` is wrapped with PEFT LoRA adapters.

    NOTE(review): ``self.transformer`` is an ``nn.Transformer``. Its attention
    modules expose ``out_proj`` but no ``q_proj``/``k_proj``/``v_proj``
    submodules, so presumably only ``out_proj`` matches the config below.
    Verify with ``self.transformer.print_trainable_parameters()``.
    NOTE(review): wrapping presumably renames the ``transformer.*`` state-dict
    keys, so a plain TextGuidedGenerator checkpoint may not load key-for-key.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # LoRA configuration for the transformer only
        self.lora_config = LoraConfig(
            r=4,
            lora_alpha=8,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
            lora_dropout=0.2,
            bias="lora_only",
        )
        self.transformer = get_peft_model(self.transformer, self.lora_config)

    def _set_adapter_trainable(self, flag):
        """Set requires_grad on every nn.Linear whose module name contains 'lora'."""
        lora_linears = (
            module for name, module in self.named_modules()
            if "lora" in name and isinstance(module, nn.Linear)
        )
        for module in lora_linears:
            module.requires_grad_(flag)

    def disable_adapter(self):
        """Freeze LoRA layers (requires_grad=False); their forward contribution is unchanged."""
        self._set_adapter_trainable(False)

    def enable_adapter(self):
        """Make LoRA layers trainable again (requires_grad=True)."""
        self._set_adapter_trainable(True)

### Training
Train the new model on the dataset with a distillation loss that keeps its outputs close to those of the frozen pre-trained model.

In [None]:
# Instantiate the pre-trained (teacher) model and the fine-tuned (student) model
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define transform (pixels -> [-1, 1])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Define source directory and text-guided target directories with prompts
source_dir = '/kaggle/input/multiprompt/original_images_ordered'
style_configs = [
    ('/kaggle/working/images/sculpture', "transform it into a sculpture"),
]

# Create datasets with text prompts
datasets = []
for target_dir, prompt in style_configs:
    dataset = TextGuidedImageDataset(source_dir, target_dir, prompt, transform)
    datasets.append(dataset)

full_dataset = ConcatDataset(datasets)
print(f"Combined dataset contains {len(full_dataset)} images\n")

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)


def _load_generator_weights(model, state_dict):
    """Load a plain TextGuidedGenerator state dict into a TextGuidedGeneratorWithLoRA.

    After get_peft_model() wraps ``self.transformer``, its keys are expected
    to live under ``transformer.base_model.model.*``, and LoRA-wrapped layers
    keep their original weights under ``.base_layer``. This matches PEFT's LoRA
    layout; verify against ``model.state_dict().keys()``. Checkpoint keys are
    renamed accordingly before a non-strict load. Keys that still do not match
    are reported, except the new LoRA tensors, which are expected to be missing.
    """
    own_keys = set(model.state_dict().keys())
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith("transformer.") and key not in own_keys:
            key = "transformer.base_model.model." + key[len("transformer."):]
            if key not in own_keys:
                head, _, leaf = key.rpartition(".")
                key = f"{head}.base_layer.{leaf}"
        remapped[key] = value
    result = model.load_state_dict(remapped, strict=False)
    missing = [k for k in result.missing_keys if "lora_" not in k]
    if missing or result.unexpected_keys:
        print(f"Warning: {len(missing)} missing and {len(result.unexpected_keys)} unexpected "
              f"keys while loading generator weights: {missing[:5]} {result.unexpected_keys[:5]}")
    return result


# Path to the saved pre-trained model weights
pretrained_weights_path = "/kaggle/input/v6/pytorch/default/1/V6.pth"
checkpoint = torch.load(pretrained_weights_path, map_location=device, weights_only=True)
if "generator" not in checkpoint:
    raise KeyError(f"Key 'generator' not found in the checkpoint {pretrained_weights_path}")

# Teacher: pre-trained model, frozen and in eval mode
pretrained_model = TextGuidedGeneratorWithLoRA(device=device).to(device)
_load_generator_weights(pretrained_model, checkpoint["generator"])
pretrained_model.eval()
pretrained_model.requires_grad_(False)
print(f"Pre-trained model weights loaded from {pretrained_weights_path}")

# Student: same pre-trained weights, fine-tuned with LoRA
model = TextGuidedGeneratorWithLoRA(device=device).to(device)
_load_generator_weights(model, checkpoint["generator"])
model.enable_adapter()

# Define loss functions
task_loss_fn = nn.MSELoss()  # Loss for the target task
distillation_loss_fn = nn.MSELoss()  # Knowledge regularization loss

# Optimize every parameter that still requires grad: the LoRA adapters plus
# any generator layers that were not frozen (not LoRA parameters only).
optimizer = AdamW(
    [param for name, param in model.named_parameters() if param.requires_grad],
    lr=1e-4,)

# Training loop
num_epochs = 100
lambda_distill = 0.5  # Weight for the distillation loss
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    for source_images, target_images, text_prompts in train_loader:
        source_images = source_images.to(device)
        target_images = target_images.to(device)

        # Teacher outputs (no gradient)
        with torch.no_grad():
            pretrained_outputs = pretrained_model(source_images, text_prompts)

        # LoRA fine-tuned outputs
        fine_tuned_outputs = model(source_images, text_prompts)

        # Task loss + distillation loss towards the teacher
        task_loss = task_loss_fn(fine_tuned_outputs, target_images)
        distillation_loss = distillation_loss_fn(fine_tuned_outputs, pretrained_outputs)
        loss = task_loss + lambda_distill * distillation_loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Validation loop
    model.eval()
    val_loss = 0.0
    val_distill = 0.0
    with torch.no_grad():
        for source_images, target_images, text_prompts in val_loader:
            source_images = source_images.to(device)
            target_images = target_images.to(device)

            pretrained_outputs = pretrained_model(source_images, text_prompts)
            fine_tuned_outputs = model(source_images, text_prompts)

            task_loss = task_loss_fn(fine_tuned_outputs, target_images)
            distillation_loss = distillation_loss_fn(fine_tuned_outputs, pretrained_outputs)

            val_distill += distillation_loss.item()
            val_loss += (task_loss + lambda_distill * distillation_loss).item()

    # Logging: per-batch averages (guarding against empty loaders)
    n_train = max(len(train_loader), 1)
    n_val = max(len(val_loader), 1)
    print(f"Epoch {epoch+1}/{num_epochs}: Train_loss={train_loss / n_train:.4f}, "
          f"Distillation_loss={val_distill / n_val:.4f}, Val_loss={val_loss / n_val:.4f}")

# Save the fine-tuned model
save_path = "text_guided_generator_lora_with_distillation.pth"
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

### Test
Test the new fine-tuned model to see whether it performs better.

In [None]:
# Load the fine-tuned model state
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TextGuidedGeneratorWithLoRA(device=device).to(device)
model.load_state_dict(torch.load("/kaggle/working/text_guided_generator_lora_with_distillation.pth", map_location=device, weights_only=True))
model.eval()  # Set model to evaluation mode
print("Model loaded successfully.")

# Load a test image
test_image_path = "/kaggle/input/obamaset/obama.jpg"
test_image = Image.open(test_image_path).convert("RGB")

# Define transformations (same as used in training)
transform = transforms.Compose([
    transforms.ToTensor(),          # Convert to tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])
test_image_tensor = transform(test_image).unsqueeze(0).to(device)  # Add batch dimension


def _to_pil(output):
    """Convert a (1, 3, H, W) generator output in [-1, 1] to a PIL image."""
    image = output.squeeze(0).cpu()          # Remove batch dimension
    image = (image * 0.5 + 0.5).clamp(0, 1)  # Rescale to [0, 1]
    return transforms.ToPILImage()(image)


# First prompt: fine-tuned model
test_prompt = "Transform it into a sculpture"
with torch.no_grad():
    generated_image = _to_pil(model(test_image_tensor, test_prompt))
generated_image.save("generated_image.jpg")

# Second prompt: uses `pretrained_model` (the teacher from the training
# cell), not the fine-tuned `model`.
# NOTE(review): confirm this comparison is intended.
test_prompt2 = "van gogh style"
with torch.no_grad():
    generated_image2 = _to_pil(pretrained_model(test_image_tensor, test_prompt2))
generated_image2.save("generated_image2.jpg")

# Display both generated images side by side, labelled with the model used
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
panels = [
    (generated_image, "Fine-tuned: Sculpture"),
    (generated_image2, f"Pre-trained: {test_prompt2}"),
]
for axis, (img, title) in zip(ax, panels):
    axis.imshow(img)
    axis.axis("off")  # Remove axes for better visualization
    axis.set_title(title, fontsize=14)

plt.tight_layout()
plt.show()