In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Load the VGG19 convolutional feature extractor (its `.features` submodule) with
# ImageNet weights. torchvision >= 0.13 deprecates `pretrained=True` in favour of the
# `weights=` enum; fall back to the old flag on versions that predate the enum.
if hasattr(models, "VGG19_Weights"):
    vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
else:
    vgg19 = models.vgg19(pretrained=True).features

# Freeze model weights: only the target image is optimized, never the network.
vgg19.requires_grad_(False)

# Move the model to GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Assign the result instead of leaving `vgg19.to(device)` as the cell's last expression,
# which would dump the entire module repr as cell output. eval() is the conventional
# mode for a fixed, pretrained feature extractor.
vgg19 = vgg19.to(device).eval()

In [None]:
def load_image(img_path, max_size=400, shape=None):
    """Load an image file as a normalized, batched tensor on the global ``device``.

    Args:
        img_path: Path to the image file (anything ``PIL.Image.open`` accepts).
        max_size: Upper bound, in pixels, on the *longer* side of the result.
            Images already within the bound keep their native size (no upscaling).
        shape: Optional explicit output size forwarded to ``transforms.Resize``,
            e.g. ``content.shape[2:]`` to obtain exactly that (H, W). When given,
            it overrides ``max_size``.

    Returns:
        A float tensor of shape (1, 3, H, W), normalized with the per-channel
        mean/std below, moved to ``device``.
    """
    # Use a context manager so the file handle is closed; convert() returns an
    # in-memory copy that stays valid after the file is closed.
    with Image.open(img_path) as img:
        image = img.convert('RGB')

    if shape is not None:
        size = shape
    else:
        width, height = image.size  # PIL reports (width, height)
        # Scale so the longer side is at most max_size, never enlarging the image.
        # (An int passed to Resize sets the *shorter* side instead, which let the longer
        # side exceed max_size and upscaled images smaller than max_size.)
        scale = min(1.0, max_size / max(width, height))
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # convert('RGB') guarantees 3 channels; add the batch dimension: (3, H, W) -> (1, 3, H, W).
    image = in_transform(image).unsqueeze(0)
    return image.to(device)

In [None]:
# Function to extract features from images
def get_features(image, model, layers=None):
    """Feed ``image`` through ``model``'s children in order and keep chosen activations.

    Args:
        image: Input tensor handed to the first child module.
        model: Module whose registered children are applied one after another
            (e.g. ``vgg19`` above); children are matched by their registered names.
        layers: Mapping ``{child_name: output_key}``. When omitted, the VGG19
            layer selection used for the style/content losses is used.

    Returns:
        dict ``{output_key: activation}`` for every requested child that was reached.

    NOTE(review): the stored tensor is the child's output object itself; if a later
    child modifies its input in place (torchvision's VGG uses ReLU(inplace=True) —
    confirm for your model), the stored activation reflects that modification.
    """
    wanted = layers if layers is not None else {
        '0': 'conv1_1',
        '5': 'conv2_1',
        '10': 'conv3_1',  # convX_Y (X: conv block of the VGG, Y: layer within the block)
        '19': 'conv4_1',
        '21': 'conv4_2',  # content representation
        '28': 'conv5_1',
    }

    captured = {}
    activation = image
    # Iterate the raw registry (not named_children(), which skips repeated module
    # instances) so every registered child runs exactly as the model's forward would.
    for child_name, child in model._modules.items():
        activation = child(activation)
        if child_name not in wanted:
            continue
        captured[wanted[child_name]] = activation
    return captured

In [None]:
# Function to calculate the Gram matrix
def gram_matrix(tensor):
    """Return the (unnormalized) Gram matrix of a single feature map.

    Args:
        tensor: Activations of shape (1, d, h, w); the batch size must be 1.

    Returns:
        A (d, d) tensor ``F @ F.T`` where ``F`` is ``tensor`` flattened to (d, h*w).

    Raises:
        ValueError: if ``tensor`` is not 4-D with a batch size of 1.
    """
    if tensor.dim() != 4 or tensor.size(0) != 1:
        raise ValueError(
            f"gram_matrix expects a (1, d, h, w) tensor, got shape {tuple(tensor.shape)}"
        )
    _, d, h, w = tensor.size()
    # reshape (unlike view) also accepts non-contiguous inputs, copying only if needed.
    features = tensor.reshape(d, h * w)
    return torch.mm(features, features.t())

In [None]:
# Content image first: its spatial size (H, W) is then reused so the style image
# is resized to exactly the same height and width.
content = load_image("path_to_your_content_image.jpg")
style = load_image(
    "path_to_your_style_image.jpg",
    shape=content.shape[-2:],  # trailing (H, W) of the batched content tensor
)

In [None]:
# Initialize the target image as a copy of the content image; it is the only tensor
# the optimizer updates.
# Order matters: move to the device *before* enabling gradients. Calling .to() on a
# tensor that already requires grad returns a non-leaf copy whenever a move actually
# happens, and torch.optim refuses to optimize non-leaf tensors. detach() keeps the
# target independent of any graph `content` might carry.
target = content.detach().clone().to(device).requires_grad_(True)
optimizer = optim.Adam([target], lr=0.003)

In [None]:
# Per-layer weights for the style loss; the earliest layers are weighted most heavily here.
style_weights = dict(
    conv1_1=1.0,
    conv2_1=0.75,
    conv3_1=0.2,
    conv4_1=0.2,
    conv5_1=0.2,
)

# Relative weights of the content and style terms in the total loss (tune as needed).
content_weight = 1
style_weight = 1e6

In [None]:
# Define training loop
# The content and style images never change, so their features and the style Gram
# matrices are computed once here instead of on every step. vgg19's parameters are
# frozen and neither image requires grad, so no_grad just makes that explicit.
with torch.no_grad():
    content_features = get_features(content, vgg19)
    style_features = get_features(style, vgg19)
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_weights}

epochs = 300  # number of optimization steps on `target`
for i in range(epochs):
    target_features = get_features(target, vgg19)

    # Style loss: weighted MSE between Gram matrices, each divided by the layer's d*h*w.
    style_loss = 0
    for layer, weight in style_weights.items():
        target_feature = target_features[layer]
        _, d, h, w = target_feature.shape
        target_gram = gram_matrix(target_feature)
        layer_style_loss = weight * torch.mean((target_gram - style_grams[layer]) ** 2)
        style_loss += layer_style_loss / (d * h * w)

    # Content loss: MSE between the conv4_2 activations of target and content.
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

    total_loss = content_weight * content_loss + style_weight * style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if i % 50 == 0:
        print(f"Epoch {i}, Total Loss: {total_loss.item()}")