In [None]:
import PIL
import torch
import torchvision

import numpy as np
import matplotlib.pyplot as plt

from data_utils import load_imagenet_val
from image_utils import preprocess_image
from utils import *
from style_modules import ContentLoss, StyleLoss, TotalVariationLoss
from style_utils import preprocess, style_transfer

#### Example 1

$ f_1: \mathbb{R} \times \mathbb{R} \to \mathbb{R} $

$ f_1(x, y) = 3x^2 + y^2 $

$ \frac{\partial f_1}{\partial x} = 6x $

$ \frac{\partial f_1}{\partial y} = 2y $

At the point used below, $(x, y) = (5, -7)$, these evaluate to $30$ and $-14$.

In [None]:
def f_1(x, y):
    """Example 1: evaluate f_1(x, y) = 3x^2 + y^2 through ``with_grad``.

    ``with_grad`` is not defined in this notebook; presumably it comes from the
    ``from utils import *`` line. The following cells read ``x.grad`` / ``y.grad``
    after calling this, so it is assumed to run the backward pass and return the
    value -- TODO confirm against utils.
    """
    return with_grad(3 * x ** 2 + y ** 2)

In [None]:
# Example 1 inputs as float32 leaf tensors. requires_grad=True tells autograd to
# track operations on them so their .grad can be populated by a backward pass.
x = torch.tensor([5.0], dtype=torch.float32, requires_grad=True)
y = torch.tensor([-7.0], dtype=torch.float32, requires_grad=True)

In [None]:
# Clear gradients left by an earlier run of this cell. Autograd accumulates into
# .grad, so without this a re-run would add to the previous result (relevant
# if with_grad back-propagates -- assumed, TODO confirm).
x.grad = None
y.grad = None
f_1(x, y)

In [None]:
# Gradient of f_1 with respect to x. With x = 5 and d f_1/dx = 6x this should be
# tensor([30.]), assuming with_grad ran the backward pass exactly once (TODO confirm).
x.grad

In [None]:
# Gradient of f_1 with respect to y. With y = -7 and d f_1/dy = 2y this should be
# tensor([-14.]), assuming with_grad ran the backward pass exactly once (TODO confirm).
y.grad

#### Example 2

$ f_2: \mathbb{R}^2 \to \mathbb{R} $

$ f_2(\langle v_1, v_2 \rangle) = 3v_1^2 + v_2^2 $

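$ \nabla f_2 = \langle \frac{\partial f_2}{\partial v_1}, \frac{\partial f_2}{\partial v_2} \rangle = \langle 6v_1, 2v_2 \rangle $

At the point used below, $v = \langle 5, -7 \rangle$, this is $\langle 30, -14 \rangle$.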

In [None]:
def f_2(v):
    """Example 2: evaluate f_2(<v_1, v_2>) = 3 v_1^2 + v_2^2 through ``with_grad``.

    Reads the components as ``v[0]`` and ``v[1]``. ``with_grad`` is not defined
    in this notebook (presumably from ``from utils import *``). The cells below
    read ``v.grad`` after calling this, so it is assumed to run the backward pass
    and return the value -- TODO confirm.
    """
    return with_grad(3 * v[0] ** 2 + v[1] ** 2)

In [None]:
# Example 2 input as a single float32 leaf tensor. requires_grad=True lets autograd
# populate v.grad after a backward pass.
v = torch.tensor([5.0, -7.0], dtype=torch.float32, requires_grad=True)

In [None]:
# Clear any gradient left by an earlier run of this cell. Autograd accumulates
# into .grad, so a re-run would otherwise add to the previous result (relevant if
# with_grad back-propagates -- assumed, TODO confirm).
v.grad = None
f_2(v)

In [None]:
# Gradient of f_2 at v = <5, -7>. From grad f_2 = <6 v_1, 2 v_2> this should be
# tensor([30., -14.]), assuming with_grad ran the backward pass once (TODO confirm).
v.grad

#### Image example

In [None]:
# Number of ImageNet validation images to load. Later cells iterate over
# X.shape[0], so this presumably also sets how many images are classified/plotted.
NUM_VAL_IMAGES = 16
# NOTE: this rebinds `y`, which was the Example 1 tensor, to the labels used to
# index class_names below. Re-run the Example 1 cells if you need the tensor again.
X, y, class_names = load_imagenet_val(num=NUM_VAL_IMAGES)

In [None]:
# Preview the first validation image with its ground-truth class name.
# plt.show() renders the figure without echoing the AxesImage repr.
plt.imshow(X[0])
plt.title(f"actual: {class_names[y[0]]}")
plt.show()

In [None]:
# Preprocessed copy of the first image as a leaf tensor that tracks gradients.
# NOTE(review): preprocess_image appears to return an H x W x 3 array (the CNN cell
# permutes (2, 0, 1) before inference) -- confirm against image_utils.
img_tensor = torch.tensor(preprocess_image(X[0]), requires_grad=True)

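$ f: \mathbb{R}^{h \times w \times 3} \to \mathbb{R} $

$ f(\langle x_1, \ldots, x_{h \cdot w \cdot 3} \rangle) = \sum_{i=1}^{h \cdot w \cdot 3} x_i^2 $

Its gradient is $ \nabla f = \langle 2x_1, \ldots, 2x_{h \cdot w \cdot 3} \rangle $, one entry per pixel channel, so the gradient has the same shape as the image.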

In [None]:
# Define scalar-valued function on image tensor: f(img) = sum of squared entries.
def f(img):
    """Return ``with_grad(sum(img ** 2))`` for the tensor ``img`` that is passed in.

    The function operates on its argument, not on the global ``img_tensor``, so it
    works for any tensor. ``with_grad`` is not defined in this notebook (presumably
    from ``from utils import *``) and is assumed to back-propagate -- TODO confirm.
    """
    return with_grad(img.square().sum())

In [None]:
# Clear any gradient left by an earlier run of this cell. Autograd accumulates
# into .grad, so a re-run would otherwise double it (relevant if with_grad
# back-propagates -- assumed, TODO confirm).
img_tensor.grad = None
f(img_tensor)

In [None]:
# Shape of the gradient of f at img_tensor. The gradient of sum(x_i^2) is 2 * x
# elementwise, so this should equal img_tensor.shape (assuming with_grad
# back-propagated -- TODO confirm).
img_tensor.grad.shape

#### Pre-trained CNN (SqueezeNet 1.1) on the validation images

In [None]:
# SqueezeNet 1.1 with ImageNet-pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of
# `weights=...`; it is kept here for compatibility with older versions.
cnn = torchvision.models.squeezenet1_1(pretrained=True)
# Models are constructed in training mode. eval() switches layers such as Dropout
# (used in SqueezeNet's classifier) to inference behaviour, so predictions are
# deterministic.
cnn.eval()

preds = []
# Classification needs no gradients; no_grad avoids building an autograd graph.
with torch.no_grad():
    for image in X:
        # H x W x C -> C x H x W, then add a leading batch dimension of 1.
        batch = torch.tensor(preprocess_image(image)).permute(2, 0, 1).unsqueeze(0)
        preds.append(int(cnn(batch).argmax()))

In [None]:
# Count predictions matching the ground-truth class index, then divide by the
# number of labels to report accuracy as a fraction.
n_correct = (np.asarray(preds) == y).sum()
print(f"accuracy of the cnn is {n_correct / len(y)}")

In [None]:
# One figure per validation image, titled with its true and predicted class names.
for idx, image in enumerate(X):
    plt.figure()
    plt.imshow(image)
    plt.title(f"actual: {class_names[y[idx]]}\npredicted: {class_names[preds[idx]]}")