In [12]:
import os
from urllib.parse import urlparse
#!/usr/bin/env python3
# Neural style transfer in PyTorch.

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import models, transforms
from PIL import Image
import requests
from io import BytesIO


# Define the VGGFeatures class
class VGGFeatures(nn.Module):
    """Frozen, truncated VGG-19 that returns activations of selected layers.

    The ``features`` stack of a pretrained VGG-19 is cut right after the
    deepest requested layer, switched to eval mode, and all of its
    parameters are excluded from gradient computation.

    Args:
        layers: Iterable of integer indices into ``vgg19().features`` whose
            outputs :meth:`forward` should return. Duplicates are ignored.

    Raises:
        ValueError: If ``layers`` is empty.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = sorted(set(layers))
        if not self.layers:
            # Fail early with a readable message instead of the opaque
            # "max() arg is an empty sequence" raised by the slice below.
            raise ValueError("layers must contain at least one layer index.")

        # Per-channel statistics applied to every input before the network
        # (presumably the ImageNet statistics the pretrained weights expect).
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

        # Newer torchvision releases deprecate ``pretrained=True`` in favour
        # of the ``weights`` enum; keep the old keyword as a fallback so the
        # class still works where the enum does not exist.
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            backbone = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.vgg19(pretrained=True)

        # Layers past the deepest requested one are never needed.
        self.model = backbone.features[:max(self.layers) + 1]
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, input):
        """Run ``input`` through the network and collect requested outputs.

        Args:
            input: Image tensor accepted by ``transforms.Normalize`` and the
                VGG feature stack (assumed ``(N, 3, H, W)`` — confirm with
                callers).

        Returns:
            dict mapping each requested layer index that exists in the
            truncated model to that layer's output tensor.
        """
        x = self.normalize(input)
        features = {}
        for i, layer in enumerate(self.model):
            x = layer(x)
            if i in self.layers:
                features[i] = x
        return features


# Define the ContentLoss class
class ContentLoss(nn.Module):
    """Mean-squared-error distance between a feature map and a fixed target.

    Args:
        target: Reference feature tensor. It is detached on construction so
            no gradient ever flows back into whatever produced it.
    """

    def __init__(self, target):
        super().__init__()
        frozen_target = target.detach()
        self.target = frozen_target

    def forward(self, input):
        """Return the MSE between ``input`` and the stored target."""
        reference = self.target
        loss = F.mse_loss(input, reference)
        return loss


# Define the StyleLoss class
class StyleLoss(nn.Module):
    """Mean-squared-error distance between Gram matrices of feature maps.

    Args:
        target_feature: 4-D feature tensor of the style image. Its Gram
            matrix is computed once and stored detached from the graph.
    """

    def __init__(self, target_feature):
        super().__init__()
        self.target = self.gram_matrix(target_feature).detach()

    @staticmethod
    def gram_matrix(input):
        """Return the normalized Gram matrix of a 4-D feature tensor.

        The tensor of size ``(a, b, c, d)`` (presumably batch, channels,
        height, width) is flattened to ``(a * b, c * d)``, multiplied by its
        own transpose, and divided by the total number of elements.

        Args:
            input: 4-D tensor.

        Returns:
            Tensor of shape ``(a * b, a * b)``.

        Raises:
            ValueError: If ``input`` does not have exactly four dimensions
                (raised by the size unpacking).
        """
        a, b, c, d = input.size()
        # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
        # tensors (e.g. transposed or channels-last feature maps), whereas
        # ``reshape`` copies only when it has to.
        features = input.reshape(a * b, c * d)
        G = torch.mm(features, features.t())
        # Normalize so layers with many elements do not dominate the loss.
        return G.div(a * b * c * d)

    def forward(self, input):
        """Return the MSE between ``input``'s Gram matrix and the target's."""
        G = self.gram_matrix(input)
        return F.mse_loss(G, self.target)


# Define the function to load and preprocess the image

def load_image(image_source, target_size):
    """Load an image and return it as a square, batched float tensor.

    The image is converted to RGB, resized so that its shorter side equals
    ``target_size`` (aspect ratio preserved up to integer truncation),
    center-cropped to ``target_size`` x ``target_size``, converted with
    ``transforms.ToTensor`` and given a leading batch dimension.

    Args:
        image_source: ``http``/``https`` URL or path of an existing local file.
        target_size: Side length, in pixels, of the square output.

    Returns:
        Tensor of shape ``(1, 3, target_size, target_size)``.

    Raises:
        ValueError: If ``image_source`` is neither an http(s) URL nor an
            existing path.
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    # Check if the source is a URL or a local file
    if urlparse(image_source).scheme in ('http', 'https'):
        # A timeout keeps a stalled server from blocking forever, and the
        # status check stops an HTML error page from reaching PIL as "image"
        # bytes.
        response = requests.get(image_source, timeout=30)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as raw:
            image = raw.convert("RGB")
    elif os.path.exists(image_source):
        # The context manager closes the underlying file handle; ``convert``
        # returns an independent, fully loaded copy.
        with Image.open(image_source) as raw:
            image = raw.convert("RGB")
    else:
        raise ValueError("Invalid image source. Must be a URL or a file path.")

    # Scale so the shorter side becomes ``target_size``.
    aspect_ratio = image.width / image.height
    if aspect_ratio > 1:
        # Image is wider than tall
        new_height = target_size
        new_width = int(target_size * aspect_ratio)
    else:
        # Image is taller than wide (or square)
        new_width = target_size
        new_height = int(target_size / aspect_ratio)

    # Resize, center crop to a square, and convert to a tensor
    transform = transforms.Compose([
        transforms.Resize((new_height, new_width)),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
    ])
    # unsqueeze(0) adds the batch dimension.
    return transform(image).unsqueeze(0)


# Content and style image sources. load_image accepts http(s) URLs or local
# file paths; these two are local files despite the ``_url`` names.
content_url = "content.png"
style_url = "grid.png"

# Load content and style images
content_image = load_image(content_url, target_size=512)
style_image = load_image(style_url, target_size=512)

# Choose device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transfer images to device
content_image = content_image.to(device)
style_image = style_image.to(device)

# Indices into the VGG-19 feature stack used for the content and style losses
content_layers = [22]
style_layers = [1, 6, 11, 20, 29]
all_layers = style_layers + content_layers
vgg = VGGFeatures(all_layers).to(device)

# Extract the fixed target features
content_features = vgg(content_image)
style_features = vgg(style_image)

# Initialize loss functions
content_loss_fn = ContentLoss(content_features[content_layers[0]])
style_loss_fns = [StyleLoss(style_features[layer]) for layer in style_layers]

# The optimized variable is the image itself, initialized from the content image
output_image = content_image.clone()
optimizer = optim.Adam([output_image.requires_grad_()], lr=0.02)


def _to_pil(tensor):
    """Convert a single-image batch tensor to a PIL image clamped to [0, 1]."""
    # detach() so the snapshot never touches the autograd graph of the
    # tensor being optimized.
    return transforms.ToPILImage()(tensor.detach().cpu().squeeze(0).clamp(0, 1))


# Run the style transfer
num_steps = 500
style_weight = 1000000
content_weight = 1

for step in range(num_steps):
    optimizer.zero_grad()
    output_features = vgg(output_image)
    content_loss = content_loss_fn(output_features[content_layers[0]])
    style_loss = 0

    for fn, layer in zip(style_loss_fns, style_layers):
        style_loss += fn(output_features[layer])

    total_loss = content_weight * content_loss + style_weight * style_loss
    total_loss.backward()
    optimizer.step()

    if step % 50 == 0:
        # The logged loss was computed before this step's update; the saved
        # snapshot is the image after the update.
        print(f"Step {step}, Total loss: {total_loss.item()}")
        final_img = _to_pil(output_image)
        final_img.save(f"output{step}.jpg")

# Save or display the final image
final_img = _to_pil(output_image)
final_img.save("output.jpg")
# final_img.show()

Step 0, Total loss: 2028.26220703125
Step 50, Total loss: 498.89141845703125
Step 100, Total loss: 308.968017578125
Step 150, Total loss: 236.876708984375
Step 200, Total loss: 207.48672485351562
Step 250, Total loss: 187.77772521972656
Step 300, Total loss: 169.3457794189453
Step 350, Total loss: 159.0277099609375
Step 400, Total loss: 154.73887634277344
Step 450, Total loss: 155.9910430908203
