In [1]:
import torch
import torch.nn as nn
import torchvision # for pretrained models
from torchvision import transforms, models # for pretrained models
from PIL import Image # Python Image Library for image processing
import matplotlib.pyplot as plt # for plotting
import numpy as np # for numerical calculations

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def get_image(path, img_transform, size = (300, 300)):
    """Load an image from disk and return it as a batched tensor on ``device``.

    Args:
        path: Anything accepted by ``PIL.Image.open`` (file path or file object).
        img_transform: Callable applied to the resized PIL image; assumed to
            return a tensor (e.g. a ``transforms.Compose`` ending in
            ``ToTensor``) — TODO confirm against the caller's transform.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        The transformed image with a leading batch dimension of 1, moved to the
        module-level ``device``.
    """
    # Context manager releases the file handle deterministically instead of
    # leaving it to the garbage collector.
    with Image.open(path) as raw:
        # Force 3-channel RGB: grayscale ("L"), palette ("P") or RGBA images
        # would otherwise produce tensors with a channel count that does not
        # match 3-channel (ImageNet-style) normalization / pretrained CNNs.
        img = raw.convert("RGB").resize(size, Image.LANCZOS)
    img = img_transform(img).unsqueeze(0)  # add batch dimension
    return img.to(device)

In [None]:
def get_gram(m):
    """Compute the normalized Gram matrix of a batch of feature maps.

    Args:
        m: Tensor of shape ``(batch_size, channels, height, width)``.

    Returns:
        Tensor of shape ``(batch_size * channels, batch_size * channels)``
        holding the inner products between every pair of flattened feature
        maps, divided by ``batch_size * channels * height * width``.
        For ``batch_size > 1`` the rows of all batch items are stacked together,
        so entries also pair channels from different batch items.
    """
    batch_size, channels, height, width = m.size()
    # reshape (not view): works for non-contiguous inputs (e.g. after
    # permute/transpose), where view would raise.
    features = m.reshape(batch_size * channels, height * width)
    gram = torch.mm(features, features.t())
    return gram.div(batch_size * channels * height * width)

In [None]:
# Denormalize the image
def denormalize_img(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization and return an HWC numpy image for plotting.

    Args:
        img: Tensor of shape ``(channels, height, width)`` or
            ``(1, channels, height, width)``. It may live on any device and may
            require grad.
        mean: Per-channel mean that was subtracted (defaults to ImageNet's).
        std: Per-channel standard deviation that was divided by (defaults to
            ImageNet's).

    Returns:
        ``np.ndarray`` of shape ``(height, width, channels)`` with values
        clipped to ``[0, 1]``.
    """
    # Detach from autograd and copy to host memory; ``Tensor.numpy()`` raises
    # for CUDA tensors and for tensors that require grad.
    img = img.detach().cpu()
    if img.dim() == 4:
        # Drop a leading batch dimension of size 1 (get_image adds one).
        img = img.squeeze(0)
    img = img.numpy().transpose(1, 2, 0)  # (channels, height, width) -> (height, width, channels)
    img = np.asarray(std) * img + np.asarray(mean)  # invert (x - mean) / std
    # Clip to the displayable range. No further rescaling: the image is already
    # in [0, 1], and an extra ``* 0.5 + 0.5`` would squash it into [0.5, 1].
    return np.clip(img, 0, 1)