# Perceptual loss

The paper does not state whether VGG-16 or VGG-19 is used. I'll assume VGG-19.

In [1]:
from flaxmodels import VGG19

In [2]:
# VGG-19 module whose activations feed the perceptual loss computed further down.
vgg19 = VGG19(
    output='activations',   # apply() results are indexed by layer name later in the notebook
    pretrained='imagenet',  # presumably loads ImageNet-trained weights -- confirm in flaxmodels docs
    include_head=False,     # presumably omits the classification head -- confirm in flaxmodels docs
)

In [3]:
import jax
import jax.numpy as jnp
import random

In [4]:
# Fixed seed so that "Restart Kernel -> Run All" reproduces the same numbers;
# edit the value by hand to try a different random draw.
seed = 0
# Split once so the dummy batch is sampled from its own subkey and `key`
# itself is never consumed for sampling before later cells use it.
key, qc_key = jax.random.split(jax.random.PRNGKey(seed))
# Standard-normal dummy batch of shape (16, 256, 256, 3), used as the example
# input when the VGG variables are initialised
# (presumably batch x height x width x channels -- TODO confirm).
qc = jax.random.normal(qc_key, shape=(16, 256, 256, 3))

In [5]:
# Create the VGG variables: `key` is passed as the RNG and `qc` as the example input to `init`.
vgg_params = vgg19.init(key, qc)

In [6]:
# Derive one independent subkey per dummy batch and keep `key` only for
# further splitting, so that no key is both consumed for sampling and
# carried forward (JAX keys must not be reused).
key, x_key, subkey = jax.random.split(key, 3)
# Two standard-normal stand-in batches of identical shape; presumably `x` plays
# the role of the reference image and `output` the model prediction -- TODO confirm.
x = jax.random.normal(x_key, shape=(16, 256, 256, 3))
output = jax.random.normal(subkey, shape=(16, 256, 256, 3))

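**TODO: verify normalization.** The cell below assumes the images lie in $[-1, 1]$ and rescales them to $[0, 1]$ via $(x + 1)/2$ before passing them to VGG; confirm that this is the input range the pretrained model expects.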

In [7]:
# Run both batches through VGG with the same parameters. Each batch is first
# mapped with (img + 1) / 2, which sends [-1, 1] onto [0, 1]
# (assumes VGG expects inputs in [0, 1] -- still to be verified).
x_vgg, output_vgg = (
    vgg19.apply(vgg_params, (img + 1) / 2)
    for img in (x, output)
)

**TODO: verify in the paper** that the VGG layers and per-layer weights chosen below match the ones used there.

In [8]:
# Activation layers entering the perceptual loss, each paired with its weight.
layer_weight_pairs = (
    ('relu3_3', 1e-3),
    ('relu4_3', 5e-3),
    ('relu5_3', 20e-3),
)
# Parallel lists, in matching order, consumed by the loss cell below.
vgg_layers = [name for name, _ in layer_weight_pairs]
vgg_weights = [weight for _, weight in layer_weight_pairs]

In [9]:
def mse(x, y):
    """Return the mean of the element-wise squared difference between `x` and `y`.

    The mean is taken over all elements, so the result is a scalar.
    """
    residual = x - y
    return jnp.mean(residual * residual)

In [10]:
# Perceptual loss: weighted sum, over the selected layers, of the MSE between
# the activations of the two batches at that layer.
p_loss = sum(
    weight * mse(x_vgg[name], output_vgg[name])
    for name, weight in zip(vgg_layers, vgg_weights)
)
p_loss

DeviceArray(0.06175933, dtype=float32)