# In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt
from tkinter import Tk, filedialog

def select_image(title):
    """Open a native file dialog and return the path of the chosen image.

    Args:
        title: Window title shown on the dialog.

    Returns:
        The selected file path, or ``""`` if the user cancelled the dialog.
    """
    root = Tk()
    root.withdraw()  # Hide the main window; only the dialog should appear.
    try:
        file_path = filedialog.askopenfilename(
            parent=root,
            title=title,
            # Tk expects a space-separated pattern list; a single
            # "*.jpg;*.jpeg;*.png" pattern is only understood by the
            # Windows native dialog and matches nothing elsewhere.
            filetypes=[("Image Files", "*.jpg *.jpeg *.png")],
        )
    finally:
        # Tear down the hidden root so repeated calls do not accumulate
        # Tk interpreters / invisible windows.
        root.destroy()
    # Normalize any empty "cancelled" result to an empty string.
    return file_path or ""

def load_image(img_path, max_size=400, shape=None):
    """Load an image file as a normalized (1, 3, H, W) float tensor.

    Args:
        img_path: Path of the image to load; it is converted to RGB.
        max_size: Upper bound on the longer side. Larger images are
            downscaled with their aspect ratio preserved; smaller images keep
            their original size.
        shape: Optional explicit output size, either ``(height, width)``
            (tuple or list) or a single int for a square output. When given it
            overrides ``max_size``.

    Returns:
        Float tensor of shape (1, 3, H, W), normalized per channel with mean
        (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225) — the ImageNet
        statistics expected by torchvision's pretrained models.
    """
    # Use a context manager so the underlying file handle is closed;
    # convert() loads the pixel data into a new, independent image.
    with Image.open(img_path) as img:
        image = img.convert("RGB")

    if shape is not None:
        size = tuple(shape) if isinstance(shape, (tuple, list)) else (shape, shape)
    else:
        width, height = image.size  # PIL reports (width, height)
        # Scale so the longer side is at most max_size, never upscaling, and
        # keep the aspect ratio instead of squashing the image into a square.
        scale = min(1.0, max_size / max(width, height))
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    in_transform = transforms.Compose([
        # A (h, w) sequence makes Resize produce exactly that output size.
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # convert("RGB") guarantees exactly 3 channels; add the batch dimension.
    return in_transform(image).unsqueeze(0)

def get_features(image, model, layers=None):
    """Run ``image`` through ``model`` and capture selected intermediate outputs.

    Args:
        image: Input tensor fed to the first child module.
        model: Module whose direct children (``model._modules``) are applied
            one after another in insertion order, e.g. an ``nn.Sequential``.
        layers: Mapping from child name to the key under which that child's
            output is stored. When omitted, children '0', '5', '10', '19',
            '21' and '28' are captured as conv1_1, conv2_1, conv3_1, conv4_1,
            conv4_2 and conv5_1.

    Returns:
        Dict mapping each requested key to the tensor produced by that child.
        Every child is executed, whether or not it is captured.

    NOTE(review): the captured tensor is the live output object; if a later
    child modifies its input in place (e.g. ``ReLU(inplace=True)``), the
    stored value reflects that modification. Confirm this is intended.
    """
    wanted = (
        layers
        if layers is not None
        else {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1', '21': 'conv4_2', '28': 'conv5_1'}
    )
    captured = {}
    activation = image
    for child_name, child in model._modules.items():
        activation = child(activation)
        if child_name not in wanted:
            continue
        captured[wanted[child_name]] = activation
    return captured

def gram_matrix(tensor):
    """Return the (d, d) Gram matrix of a single 4-D feature map.

    ``tensor`` is expected to be (1, d, h, w). Only a batch of 1 is supported:
    with any other batch size the element count no longer matches and
    ``view`` raises. Entry (i, j) is the unnormalized inner product of the
    flattened activations of channels i and j.
    """
    _, channels, height, width = tensor.size()
    flat = tensor.view(channels, height * width)  # one row per channel
    return flat @ flat.t()

# Select images from file dialog
print("Select the Content Image")
content_path = select_image("Select Content Image")

print("Select the Style Image")
style_path = select_image("Select Style Image")

# An empty path means a dialog was cancelled; stop with a clear message
# rather than letting the image loader fail on "".
if not content_path or not style_path:
    raise SystemExit("No image selected; aborting neural style transfer.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# VGG19 is used purely as a fixed feature extractor: freeze its weights so
# backward() computes gradients for the target image only, instead of also
# computing (and never using) .grad for every VGG parameter.
vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(device).eval()
vgg.requires_grad_(False)

content = load_image(content_path).to(device)
# Resize the style image to the content image's (H, W).
style = load_image(style_path, shape=(content.shape[-2], content.shape[-1])).to(device)

# The reference features are constants of the optimization. Computing them
# without autograd keeps them out of every iteration's graph, so a plain
# backward() is sufficient (no retain_graph needed).
with torch.no_grad():
    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
    style_grams = {layer: gram_matrix(feat) for layer, feat in style_features.items()}

# The target starts as a copy of the content image and is the only tensor
# being optimized. `content` is already on `device`; calling .to() after
# requires_grad_ could produce a non-leaf tensor that the optimizer rejects.
target = content.clone().requires_grad_(True)

style_weights = {'conv1_1': 1.0, 'conv2_1': 0.75, 'conv3_1': 0.5, 'conv4_1': 0.25, 'conv5_1': 0.1}
content_weight = 1e4
style_weight = 1e2
num_steps = 300

optimizer = optim.Adam([target], lr=0.003)

print("Starting Neural Style Transfer...")
for i in range(num_steps):
    target_features = get_features(target, vgg)
    content_loss = torch.nn.functional.mse_loss(target_features['conv4_2'], content_features['conv4_2'])
    style_loss = 0
    for layer, layer_weight in style_weights.items():
        target_feature = target_features[layer]
        _, d, h, w = target_feature.shape
        target_gram = gram_matrix(target_feature)
        layer_style_loss = torch.nn.functional.mse_loss(target_gram, style_grams[layer])
        # Each layer's weighted Gram loss is divided by its feature-map size.
        style_loss += layer_weight * layer_style_loss / (d * h * w)

    total_loss = content_weight * content_loss + style_weight * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if i % 50 == 0:  # Print progress every 50 iterations
        print(f"Iteration {i}/{num_steps}, Total Loss: {total_loss.item()}")

# Convert output to an image. The optimized tensor is in the normalized space
# produced by load_image (Normalize with the mean/std below), so undo that
# normalization before clamping to the displayable [0, 1] range; clamping the
# raw normalized values shifts colors and clips much of the range.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
target_image = target.detach().cpu().squeeze(0) * imagenet_std + imagenet_mean
target_image = torch.clamp(target_image, 0, 1)
target_image = transforms.ToPILImage()(target_image)

# Save & display the final image
target_image.save("output.jpg")
print("Image saved as output.jpg")

plt.imshow(target_image)
plt.axis("off")
plt.show()


# Output: Select the Content Image
