# Perceptual loss

JAX port of LPIPS (Learned Perceptual Image Patch Similarity), as implemented in [Taming Transformers](https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py) and [richzhang/PerceptualSimilarity](https://github.com/richzhang/PerceptualSimilarity/blob/31bc1271ae6f13b7e281b9959ac24a5e8f2ed522/lpips/pretrained_networks.py).

In [1]:
from flaxmodels import VGG16

In [2]:
import random
import jax
import jax.numpy as jnp
import flax.linen as nn

We need PyTorch to load the pretrained weights of the additional (linear) layers, which are distributed as a PyTorch checkpoint.

In [3]:
import torch

In [4]:
# Draw a fresh seed on every run so the demo inputs differ between executions.
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)
# Derive one dedicated key per random tensor. A JAX key must never be consumed
# twice, so `key` itself is left untouched here and stays safe to use for any
# later random operation (e.g. parameter initialisation).
key, x_key, subkey = jax.random.split(key, 3)
# Two independent image-shaped batches of shape (8, 256, 256, 3).
x = jax.random.normal(x_key, shape=(8, 256, 256, 3))
target = jax.random.normal(subkey, shape=(8, 256, 256, 3))

In [5]:
# Is there a module for this?
def mse(x, y):
    """Return the mean of the element-wise squared difference between `x` and `y`."""
    residual = x - y
    return jnp.mean(residual ** 2)

In [6]:
class NetLinLayer(nn.Module):
    """Bias-free convolution that projects a feature map down to one channel.

    The kernel size is read from the class attribute `kernel_size`
    (1x1 by default).
    """

    kernel_size = (1, 1)

    def setup(self):
        # The attribute name `layer` is kept as-is; NOTE(review): it presumably
        # determines the name under which the kernel is stored in the params.
        self.layer = nn.Conv(
            features=1,
            kernel_size=self.kernel_size,
            strides=None,
            padding=0,
            use_bias=False,
        )

    def __call__(self, x):
        """Apply the convolution to `x` and return the result."""
        return self.layer(x)

In [7]:
class LPIPS(nn.Module):
    """Perceptual distance between two batches, computed from VGG16 activations.

    For each of the five selected activations, both inputs are normalised along
    the channel axis, their squared difference is passed through a
    `NetLinLayer`, and the result is averaged over the spatial axes. The five
    per-layer values are summed.
    """

    def setup(self):
        # We don't add a scaling layer because I think `VGG16` already includes it
        # To be verified
        self.feature_names = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']

        self.vgg = VGG16(output='activations', pretrained='imagenet', include_head=False)
        # One linear (1x1 conv) head per selected activation.
        self.lins = [NetLinLayer() for _ in self.feature_names]

    def __call__(self, x, t):
        """Return the summed per-layer distance between `x` and `t`.

        Both inputs are shifted and halved, `(v + 1) / 2`, before entering the
        backbone, i.e. values in [-1, 1] are mapped onto [0, 1].
        """
        acts_x = self.vgg((x + 1) / 2)
        acts_t = self.vgg((t + 1) / 2)

        total = None
        for lin, name in zip(self.lins, self.feature_names):
            unit_x = normalize_tensor(acts_x[name])
            unit_t = normalize_tensor(acts_t[name])
            sq_diff = (unit_x - unit_t) ** 2
            layer_score = spatial_average(lin(sq_diff), keepdims=True)
            # Accumulate in layer order, starting from the first layer's score.
            total = layer_score if total is None else total + layer_score
        return total

In [8]:
def normalize_tensor(x, eps=1e-10):
    """Divide `x` by its L2 norm taken along the last axis.

    `eps` is added to the norm before dividing, so an all-zero vector yields
    zeros instead of a division by zero.
    """
    # Use `-1` because we are channel-last
    squared_sum = jnp.sum(jnp.square(x), axis=-1, keepdims=True)
    return x / (jnp.sqrt(squared_sum) + eps)

In [9]:
def spatial_average(x, keepdims=True):
    """Average `x` over axes 1 and 2 (the spatial axes of a channel-last batch).

    Args:
        x: Array with at least three dimensions.
        keepdims: If true, the reduced axes are kept with size 1.

    Returns:
        The mean of `x` over axes 1 and 2.
    """
    # Mean over W, H. The axes are given as a tuple: that is the documented
    # form, and unlike a list it is hashable, which matters when `axis` is
    # handled as a static argument of the jitted reduction.
    return jnp.mean(x, axis=(1, 2), keepdims=keepdims)

In [10]:
# Build the LPIPS module; its parameters are created separately via `init`.
lpips = LPIPS()

In [11]:
# Create the module's variables, passing `x` for both inputs of `__call__`.
params = lpips.init(key, x, x)

In [12]:
# Score the random batch `x` against `target` and inspect the output shape.
res = lpips.apply(params, x, target)
res.shape

(8, 1, 1, 1)

**TODO**:
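- [ ] Load pretrained weights of linear layers (beware of the kernel layout: PyTorch stores convolution weights as `(out, in, H, W)`, whereas Flax expects `(H, W, in, out)`, so the weights must be transposed).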
- [ ] Create a test to compare results against Taming Transformers.

----

In [13]:
# Load the checkpoint onto the CPU and list the entries it contains.
# NOTE(review): `torch.load` unpickles the file, so only load checkpoints that
# come from a trusted source.
additional_weights = torch.load("taming/modules/autoencoder/lpips/vgg.pth", map_location=torch.device("cpu"))
additional_weights.keys()

odict_keys(['lin0.model.1.weight', 'lin1.model.1.weight', 'lin2.model.1.weight', 'lin3.model.1.weight', 'lin4.model.1.weight'])