# AdaAttN Style Transfer

This notebook demonstrates the use of the AdaAttN model for style transfer.

## Imports

In [None]:

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import utils.data as data
from utils.eval import compute_ssim, plot_results
from models import AdaAttN

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


## Dataset and Dataloaders

In [None]:
# Validation-time preprocessing: resize each image to 128x128, then convert it to a tensor.
val_tf = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
    ]
)

# get_dataloaders hands back six values; only the second and fifth are kept
# (bound here as the content and style validation loaders).
(
    _,
    content_validloader,
    _,
    _,
    style_validloader,
    _,
) = data.get_dataloaders(bs=64, valid_tf=val_tf)


## Load Pretrained Models

In [None]:
# Instantiate the decoder and VGG networks that AdaAttnModel is built from.
decoder = AdaAttN.Decoder()
vgg = AdaAttN.VGG()

# Load the pretrained weights. map_location=device remaps the stored tensors
# onto the selected device, so checkpoints that were saved from CUDA tensors
# still load when the notebook runs on a CPU-only machine.
# NOTE: loading is unconditional -- a missing checkpoint file raises here.
decoder.load_state_dict(
    torch.load("models/output/AdaAttN/decoder.pth", map_location=device)
)
vgg.load_state_dict(
    torch.load("models/output/AdaAttN/vgg_normalised.pth", map_location=device)
)
# Keep only the first 31 child modules of the VGG network.
vgg = nn.Sequential(*list(vgg.children())[:31])

decoder.to(device)
vgg.to(device)

model = AdaAttN.AdaAttnModel(decoder, vgg, in_planes=512, key_planes=512)
model.to(device)


## Model Evaluation Functions

In [None]:
def evaluate_model(model, content_loader, style_loader):
    """Evaluate a style-transfer model over paired content/style loaders.

    Batches are drawn in lockstep with ``zip``, so iteration stops at the
    shorter of the two loaders. The first batch is also passed to
    ``plot_results`` for visual inspection.

    Args:
        model: Callable as ``model(content_images, style_images)`` and
            providing ``calc_content_loss`` and ``calc_style_loss``.
        content_loader: Iterable yielding batches of content images.
        style_loader: Iterable yielding ``(style_images, style_labels)``
            batches.

    Returns:
        Tuple ``(avg_ssim, avg_content_loss, avg_style_loss)``. Each value is
        the total accumulated over all batches divided by the number of style
        samples seen.

    Raises:
        ValueError: If the loaders produced no samples.
    """
    model.eval()
    total_samples = 0
    running_ssim = 0.0
    running_content_loss, running_style_loss = 0.0, 0.0
    with torch.no_grad():
        for content, style in zip(content_loader, style_loader):
            content_images = content.to(device)
            style_images, style_labels = style
            style_images, style_labels = style_images.to(device), style_labels.to(device)

            stylized_images = model(content_images, style_images)

            content_loss = model.calc_content_loss(stylized_images, content_images)
            style_loss = model.calc_style_loss(stylized_images, style_images)

            # NOTE(review): if the loss helpers already average over the batch,
            # dividing the running totals by the sample count below understates
            # the mean -- confirm the reduction they use.
            running_content_loss += content_loss.item()
            running_style_loss += style_loss.item()

            # Accumulate SSIM inside the loop so every batch contributes to the
            # average, not just the final one.
            # NOTE(review): assumes compute_ssim returns a batch-summed score,
            # since the total is divided by the sample count -- confirm.
            running_ssim += compute_ssim(content_images, stylized_images)

            # Plot only the first batch.
            if total_samples == 0:
                plot_results(content_images, style_images, style_labels, stylized_images)
            total_samples += style_labels.size(0)

    # Fail with a clear message instead of a NameError / ZeroDivisionError.
    if total_samples == 0:
        raise ValueError("evaluate_model received no samples to evaluate")

    avg_ssim = running_ssim / total_samples
    avg_content_loss = running_content_loss / total_samples
    avg_style_loss = running_style_loss / total_samples
    return avg_ssim, avg_content_loss, avg_style_loss


## Results

In [None]:
# Evaluate the model on the validation loaders and report the averaged metrics.
AdaAttN_ssim, AdaAttN_content_loss, AdaAttN_style_loss = evaluate_model(
    model, content_validloader, style_validloader
)
print("--- AdaAttN results ---")
# Plain ".4f" (no space sign-flag) so the value is not printed with an extra
# leading blank, matching the other metric lines.
print(f"Average SSIM = {AdaAttN_ssim:.4f}")
print(f"Average content loss = {AdaAttN_content_loss:.4f}")
print(f"Average style loss = {AdaAttN_style_loss:.6f}")
