In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from pathlib import Path

def download_file(url, filename, timeout=30):
    """Download ``url`` with an HTTP GET and save the body to ``filename``.

    Args:
        url: Address of the resource to fetch.
        filename: Destination path; an existing file is overwritten.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server replies with a 4xx/5xx status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an error status instead of saving the error page to disk
    # and letting the image decoder choke on it later.
    response.raise_for_status()
    Path(filename).write_bytes(response.content)

# Download images
base_image_path = "sf.jpg"
style_reference_image_path = "starry_night.jpg"
download_file("https://img-datasets.s3.amazonaws.com/sf.jpg", base_image_path)
download_file("https://img-datasets.s3.amazonaws.com/starry_night.jpg", style_reference_image_path)

# Get dimensions
# Only the size is needed here, so open the file in a context manager and
# release the handle right away instead of keeping it open for the whole run.
with Image.open(base_image_path) as original_img:
    original_width, original_height = original_img.size
# Work at a fixed height of 400 px and scale the width to keep the content
# image's aspect ratio.
img_height = 400
img_width = round(original_width * img_height / original_height)

def preprocess_image(image_path):
    """Load an image file as a normalized ``(1, 3, img_height, img_width)`` tensor.

    The image is resized to the module-level ``img_height`` x ``img_width``,
    converted to a float tensor in ``[0, 1]`` and normalized per channel with
    the mean/std constants below (the ImageNet statistics that go with the
    pretrained VGG19 weights).

    Args:
        image_path: Path of the image file to read.

    Returns:
        A float tensor with a leading batch axis of size 1.
    """
    transform = transforms.Compose([
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    with Image.open(image_path) as source:
        # Normalize is configured for exactly three channels; grayscale,
        # palette or RGBA files would otherwise fail with a shape mismatch.
        # convert() returns a loaded copy, so closing the file is safe.
        img = source.convert("RGB")
    return transform(img).unsqueeze(0)

def deprocess_image(tensor):
    """Turn a normalized image tensor back into a PIL image.

    Reverses the per-channel normalization applied by ``preprocess_image``,
    clamps to ``[0, 1]`` and converts to 8-bit height x width x channel data.

    Args:
        tensor: Image tensor of shape ``(1, 3, H, W)`` or ``(3, H, W)``, on
            any device. It may be attached to an autograd graph.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 pixel array.
    """
    # detach(): Tensor.numpy() refuses tensors that still track gradients.
    tensor = tensor.detach().cpu()
    if tensor.dim() == 4:
        # Drop only the batch axis. A bare squeeze() would also collapse a
        # height or width of 1 and break the permute below.
        tensor = tensor.squeeze(0)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    tensor = (tensor * std + mean).clamp(0, 1)
    # Channels-first -> channels-last, the layout Image.fromarray expects.
    tensor = tensor.permute(1, 2, 0)
    img = (tensor.numpy() * 255).astype(np.uint8)
    return Image.fromarray(img)

# Load VGG19 model
# VGG19_Weights is never imported on its own in this file, so reach it through
# the already-imported `models` namespace (a bare VGG19_Weights is a NameError).
# Only the convolutional `features` stack is kept, in eval mode.
vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()
# The network is a fixed feature extractor: only the image is optimized, so
# skip computing gradients for the weights during backward().
vgg19.requires_grad_(False)
if torch.cuda.is_available():
    vgg19 = vgg19.cuda()

class FeatureExtractor(nn.Module):
    """Run an image batch through a sequential network and collect named activations.

    ``self.layers`` maps child-module names of the wrapped model to the
    feature names used by the loss functions. Every entry points at the
    activation (ReLU) that follows the convolution it is named after.

    The earlier conv indices (0, 5, 10, 19, 28) only produced post-ReLU
    values because torchvision builds VGG with in-place ReLUs, which
    overwrite the captured conv output. Tapping the ReLU indices yields the
    same tensors with torchvision's VGG19 and stays correct if the
    activations are not in-place.

    Args:
        model: Optional module whose children are applied in order. Defaults
            to the module-level ``vgg19`` feature stack.
    """

    def __init__(self, model=None):
        super().__init__()
        self.layers = {
            '1': 'block1_conv1',
            '6': 'block2_conv1',
            '11': 'block3_conv1',
            '20': 'block4_conv1',
            '29': 'block5_conv1',
            '31': 'block5_conv2',
        }
        self.model = vgg19 if model is None else model
        # Deepest tapped layer; nothing after it contributes to the result.
        self._last_layer = max(self.layers, key=int)

    def forward(self, x):
        """Return ``{feature_name: activation}`` for every layer in ``self.layers``.

        Args:
            x: Input batch accepted by the first child of ``self.model``.

        Returns:
            Dict keyed by the values of ``self.layers``; each value is the
            output of the corresponding child module for the whole batch.
        """
        features = {}
        for name, layer in self.model.named_children():
            x = layer(x)
            if name in self.layers:
                features[self.layers[name]] = x
            if name == self._last_layer:
                # Skip the remaining layers: their outputs are never used.
                break
        return features

# Single shared extractor instance; compute_loss() calls it on every iteration.
feature_extractor = FeatureExtractor()

def content_loss(base, combination):
    """Sum of squared element-wise differences between two feature tensors.

    Args:
        base: Features of the content image.
        combination: Features of the generated image, same shape as ``base``.

    Returns:
        A scalar tensor.
    """
    difference = combination - base
    return difference.pow(2).sum()

def gram_matrix(x):
    """Return the channel-by-channel Gram matrix of a feature map.

    Args:
        x: Feature map of shape ``(C, H, W)``, or ``(1, C, H, W)`` with a
            singleton batch axis. ``compute_loss`` passes per-image slices,
            which are 3-D, so both layouts are accepted.

    Returns:
        A ``(C, C)`` tensor whose entry ``[i, j]`` is the dot product of the
        flattened channels ``i`` and ``j``.

    Raises:
        ValueError: If ``x`` is neither 3-D nor 4-D with batch size 1.
    """
    if x.dim() == 4:
        if x.size(0) != 1:
            raise ValueError(
                f"gram_matrix expects a single feature map, got batch size {x.size(0)}"
            )
        x = x.squeeze(0)
    elif x.dim() != 3:
        raise ValueError(
            f"gram_matrix expects a (C, H, W) or (1, C, H, W) tensor, got {tuple(x.shape)}"
        )
    # reshape (not view) so non-contiguous slices are handled as well.
    features = x.reshape(x.size(0), -1)
    return torch.mm(features, features.t())

def style_loss(style, combination):
    """Scaled squared distance between the Gram matrices of two feature maps.

    The squared difference is summed and divided by
    ``4 * channels**2 * size**2`` with ``channels`` fixed at 3 and ``size``
    equal to the module-level ``img_height * img_width``.

    Args:
        style: Features of the style reference image.
        combination: Features of the generated image.

    Returns:
        A scalar tensor.
    """
    gram_difference = gram_matrix(style) - gram_matrix(combination)
    pixel_count = img_height * img_width
    normalizer = 4.0 * (3 ** 2) * (pixel_count ** 2)
    return gram_difference.pow(2).sum() / normalizer

def total_variation_loss(x):
    """Total-variation regularizer penalizing differences between neighboring pixels.

    Args:
        x: Image batch of shape ``(N, C, H, W)``.

    Returns:
        A scalar tensor: the sum over all positions of
        ``(vertical_diff**2 + horizontal_diff**2) ** 1.25``.
    """
    # Both differences are taken on the common (H-1, W-1) grid so they can be
    # added element-wise; slicing only one axis each would give (H-1, W) and
    # (H, W-1), which cannot be summed.
    anchor = x[:, :, :-1, :-1]
    h_diff = anchor - x[:, :, 1:, :-1]
    w_diff = anchor - x[:, :, :-1, 1:]
    return torch.sum((h_diff.pow(2) + w_diff.pow(2)).pow(1.25))

# Style is measured on the first convolution of each of the five blocks.
style_layer_names = [f"block{block}_conv1" for block in range(1, 6)]
# Content is measured on a single deeper layer.
content_layer_name = "block5_conv2"

# Relative weights of the three loss terms summed in compute_loss().
content_weight = 2.5e-8
style_weight = 1e-6
total_variation_weight = 1e-6

def compute_loss(combination_image, base_image, style_reference_image):
    """Compute the weighted style-transfer loss for the generated image.

    The result is the sum of three terms:
    ``content_weight`` x content loss on ``content_layer_name``,
    ``style_weight`` x the mean style loss over ``style_layer_names``, and
    ``total_variation_weight`` x the total-variation loss of the image.

    Args:
        combination_image: Image being optimized, shape ``(1, 3, H, W)``.
        base_image: Content image, same shape as ``combination_image``.
        style_reference_image: Style image, same shape as ``combination_image``.

    Returns:
        A scalar tensor on the same device as ``combination_image``.
    """
    # The content and style targets are constants: extract them without
    # recording an autograd graph, so backward() only traverses the pass for
    # the generated image instead of a three-image batch.
    with torch.no_grad():
        target_features = feature_extractor(
            torch.cat([base_image, style_reference_image])
        )
    combination_features = feature_extractor(combination_image)

    # Zero scalar created on the image's own device and dtype, rather than
    # guessing the device from torch.cuda.is_available().
    loss = combination_image.new_zeros(())

    # Content loss: index 0 of the target batch is the base image.
    loss = loss + content_weight * content_loss(
        target_features[content_layer_name][0],
        combination_features[content_layer_name][0],
    )

    # Style loss: index 1 of the target batch is the style reference.
    per_layer_weight = style_weight / len(style_layer_names)
    for layer_name in style_layer_names:
        sl = style_loss(
            target_features[layer_name][1],
            combination_features[layer_name][0],
        )
        loss = loss + per_layer_weight * sl

    # Total variation loss
    loss = loss + total_variation_weight * total_variation_loss(combination_image)

    return loss

# Prepare images
base_image = preprocess_image(base_image_path)
style_reference_image = preprocess_image(style_reference_image_path)

if torch.cuda.is_available():
    base_image = base_image.cuda()
    style_reference_image = style_reference_image.cuda()

# The generated image starts as a copy of the content image. Gradients are
# enabled only once the tensor is on its final device: calling .cuda() on a
# tensor that already requires grad returns a non-leaf copy, which the
# optimizer rejects ("can't optimize a non-leaf Tensor").
combination_image = base_image.clone().requires_grad_(True)

# The image tensor itself is the only optimized parameter.
optimizer = optim.Adam(params=[combination_image], lr=1.0)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

iterations = 4000
for i in range(1, iterations + 1):
    optimizer.zero_grad()
    loss = compute_loss(combination_image, base_image, style_reference_image)
    loss.backward()
    optimizer.step()

    if i % 100 != 0:
        continue

    # Every 100th iteration: report the loss, save a snapshot of the current
    # image, then decay the learning rate by the scheduler's factor.
    print(f"Iteration {i}: loss={loss.item():.2f}")
    img = deprocess_image(combination_image.detach())
    img.save(f"combination_image_at_iteration{i}.png")
    scheduler.step()



RuntimeError: CUDA error: the launch timed out and was terminated
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
