# Representational Geodesic

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import pyrtools as pt
import plenoptic as po
from plenoptic.tools import to_numpy
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# Use the GPU when CUDA is available, otherwise the CPU.
# NOTE(review): neither `device` nor `dtype` is referenced by the live code visible
# in this notebook (only by commented-out lines) - confirm they are still needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype  = torch.float32
# Bare expression: as the last statement of the cell it displays the installed PyTorch version.
torch.__version__

## Translation

In [None]:
from torchvision.transforms.functional import center_crop

image_size = 64

# Load the image, blur and downsample it (step of 4 along both axes), then build
# a 20-step translation sequence from its first entry.
einstein = po.load_images('../data/256/einstein.pgm')
einstein = po.tools.conv.blur_downsample(einstein, step=(4, 4))
vid = po.tools.translation_sequence(einstein[0], n_steps=20)

# Keep the central window of side image_size // 2, then rescale with bounds (0, 1)
# (presumably mapping intensities into [0, 1] - TODO confirm against po.tools.rescale).
vid = po.tools.rescale(center_crop(vid, image_size // 2), 0, 1)

# The first and last frames are the anchors of the geodesic; slicing (rather than
# integer indexing) keeps the leading dimension.
imgA, imgB = vid[:1], vid[-1:]

pt.image_stats(to_numpy(imgA))
pt.image_stats(to_numpy(imgB))
print(imgA.shape)
print(vid.shape)

# convention: full names for numpy arrays, shorthand for torch tensors
video = to_numpy(vid).squeeze()
print(video.shape)
# pt.animshow(video, zoom=4)
# `video` is already squeezed, so iterating it directly yields one array per frame.
pt.imshow(list(video), zoom=4);

### Spectral models
Computing a geodesic to reveal excess invariance of the global Fourier magnitude representation.

In [None]:
import torch.fft
class Fourier(nn.Module):
    """Fourier-domain representation of a batch of images.

    The real-input FFT is taken over dimensions 2 and 3 of the input, so the
    input needs at least four dimensions (presumably batch, channel, height,
    width - TODO confirm against callers).

    Parameters
    ----------
    representation : str
        Which view of the spectrum ``forward`` returns:

        - ``'amp'``: magnitude of the spectrum,
        - ``'phase'``: angle of the spectrum,
        - ``'rectangular'``: the complex spectrum itself,
        - ``'polar'``: magnitude and angle concatenated along dimension 1.

    Raises
    ------
    ValueError
        If ``representation`` is not one of the options listed above.
    """

    _REPRESENTATIONS = ('amp', 'phase', 'rectangular', 'polar')

    def __init__(self, representation = 'amp'):
        super().__init__()
        # Fail at construction time: an unknown name used to make `forward`
        # fall through every branch and silently return None.
        if representation not in self._REPRESENTATIONS:
            raise ValueError(
                f"unknown representation {representation!r}, "
                f"expected one of {self._REPRESENTATIONS}")
        self.representation = representation

    def spectrum(self, x):
        """Return the real-input FFT of ``x`` over dimensions 2 and 3."""
        return torch.fft.rfftn(x, dim=(2, 3))

    def forward(self, x):
        """Return the view of the spectrum of ``x`` selected at construction."""
        # Computed once and shared by all branches ('polar' needs it twice).
        spec = self.spectrum(x)
        if self.representation == 'amp':
            return torch.abs(spec)
        if self.representation == 'phase':
            return torch.angle(spec)
        if self.representation == 'rectangular':
            return spec
        if self.representation == 'polar':
            return torch.cat((torch.abs(spec), torch.angle(spec)), dim=1)
        # Only reachable if the attribute was overwritten after construction.
        raise ValueError(
            f"unknown representation {self.representation!r}, "
            f"expected one of {self._REPRESENTATIONS}")

# Representation used for the geodesic below: the magnitude of the Fourier spectrum.
model = Fourier('amp')
# model = Fourier('polar') # note: need pytorch>=1.8 to take gradients through torch.angle 

In [None]:
# class Normalize(nn.Module):
#     def __init__(self):
#         super().__init__()
#     def forward(self, x):
#         return torch.div(x, x.pow(2).sum(dim=(1,2,3), keepdim=True).pow(.5))
# model = Normalize()

In [None]:
# One step per frame transition of the reference video (presumably so that the
# geodesic has as many frames as `video` - TODO confirm).
n_steps = len(video)-1
# Geodesic between the two anchor frames in the representation of `model`,
# started from the 'bridge' initialization.
moog = po.synth.Geodesic(imgA, imgB, model, n_steps, init='bridge')
# NOTE(review): the meaning of lmbda / mu / nu is defined by Geodesic.synthesize
# and is not visible here - check its documentation before tuning them.
moog.synthesize(max_iter=500, learning_rate=.01, lmbda=.1, mu=1, nu=0.01)

In [None]:
# Optimization diagnostics provided by the Geodesic object.
moog.plot_loss();
# The translation video is passed as well, presumably as a reference path to
# compare against - TODO confirm against plot_deviation_from_line.
moog.plot_deviation_from_line(vid);

In [None]:
# One entry of `moog.step_energy` per recorded iteration.
energy_per_step = [po.to_numpy(e) for e in moog.step_energy]
energy_mean = [e.mean() for e in moog.step_energy]

# Reference level: squared distance between the end-point representations divided
# by n_steps ** 2, i.e. the energy of each step if that distance were covered in
# n_steps equal steps.
endpoint_gap = moog._analyze(moog.xA) - moog._analyze(moog.xB)
reference_energy = torch.norm(endpoint_gap) ** 2 / moog.n_steps ** 2

plt.plot(energy_per_step, alpha=.2)
plt.plot(energy_mean, 'r-', label='path energy')
plt.axhline(reference_energy)
plt.legend()
plt.title('evolution of representation step energy')
plt.xlabel('iteration')
plt.ylabel('step energy')
plt.yscale('log')
plt.show()

In [None]:
# Plot the recorded jerkiness history when there is one; otherwise plot the
# jerkiness of the final path.
try:
    # Probe only. Keeping the plotting code out of the `try` means a genuine
    # plotting error is no longer swallowed and mistaken for "no history".
    moog.step_jerkiness[0]
except (AttributeError, IndexError, TypeError):
    # No usable history: attribute missing, container empty, or not indexable.
    plt.plot(moog.calculate_path_jerkiness())
    plt.title('final representation step jerkiness')
else:
    plt.plot([po.to_numpy(j) for j in moog.step_jerkiness]);
    # The label used to read 'path energy', copied from the step energy plot.
    plt.plot([j.mean() for j in moog.step_jerkiness], 'r--', label='path jerkiness');
    plt.legend()
    plt.title('evolution of representation step jerkiness')
    plt.ylabel('step jerkiness')
    plt.xlabel('iteration')
    plt.yscale('log')
    plt.show()

In [None]:
# Plot, across recorded iterations, the second component (index 1) of every
# entry of `moog.dev_from_line`.
# NOTE(review): what the two components of each entry represent is not visible
# here - confirm against Geodesic before interpreting the curves.
# plt.plot(torch.stack([d[0] for d in moog.dev_from_line]));
plt.plot(torch.stack([d[1] for d in moog.dev_from_line]));

plt.title('evolution of distance from representation line')
plt.ylabel('distance from representation line')
plt.xlabel('iteration step')
# plt.yscale('log')
plt.show()

In [None]:
# numpy copies of the two synthesized sequences, singleton dimensions removed
pixelfade, geodesic = (to_numpy(seq.squeeze())
                       for seq in (moog.pixelfade, moog.geodesic))

# Frame 5 of each sequence, side by side.
fig = pt.imshow(
    [video[5], pixelfade[5], geodesic[5]],
    title=['video', 'pixelfade', 'geodesic'],
    col_wrap=3,
    zoom=4,
);

# Three positions at 3/4, 1/2 and 1/4 of the size of the last axis.
size = geodesic.shape[-1]
half, quarter = size//2, size//4
h, m , l = (half + quarter, half, half - quarter)

# Mark those positions on the first panel only.
# for a in fig.get_axes()[0]:
a = fig.get_axes()[0]
for line in (h, m, l):
    a.axhline(line, lw=2)

# For each marked position, show index `row` of axis 1 across all of axis 0
# (assumes the arrays are (frame, height, width) after squeezing - TODO confirm).
for row in (l, m, h):
    pt.imshow([seq[:, row] for seq in (video, pixelfade, geodesic)],
              title=None, col_wrap=3, zoom=4)

### Physiologically inspired models

In [None]:
# Rebind `model` to plenoptic's OnOff model, built with a 31x31 kernel and
# `pretrained=True`.
model = po.simul.OnOff(kernel_size=(31,31), pretrained=True)
# Show the model's response to the first anchor image.
po.imshow(model(imgA), zoom=8);
# po.imshow(model.conv.weight, zoom=28, vrange='auto0');

In [None]:
n_steps = 10

# Geodesic between the anchors in the representation of the current `model`,
# started from the 'bridge' initialization.
moog = po.synth.Geodesic(imgA, imgB, model, n_steps, init='bridge')

# The total count used to be a bare expression in the middle of the cell, so it
# was computed and thrown away; print it so it is actually reported.
print('# param (total)', sum(p.numel() for p in moog.parameters()))

print('shape trainable param', '# trainable param')
trainable = [p for p in moog.parameters() if p.requires_grad]
# Last expression of the cell: displayed as (list of shapes, number of elements).
[p.shape for p in trainable], sum(p.numel() for p in trainable)

In [None]:
# Prefer the AdaBelief optimizer when the optional `adabelief_pytorch` package is
# usable; otherwise fall back to the optimizer name 'Adam'.
try:
    import adabelief_pytorch
    from adabelief_pytorch import AdaBelief
    print(adabelief_pytorch.__version__)
    optimizer = AdaBelief([moog.x], lr=0.001, eps=1e-16, betas=(0.9,0.999),
                          weight_decouple=True, rectify=False, print_change_log=False)
except (ImportError, AttributeError, TypeError):
    # ImportError: the package is not installed.
    # AttributeError / TypeError: an installed release that lacks `__version__`
    # or does not accept some of the keyword arguments used above.
    # A bare `except:` would also have swallowed KeyboardInterrupt.
    optimizer = 'Adam'

In [None]:
# Run the optimization with the optimizer chosen above (an AdaBelief instance, or
# the string 'Adam' as fallback).
# NOTE(review): nu=0 presumably switches off the term weighted by `nu` - confirm
# against Geodesic.synthesize.
moog.synthesize(optimizer=optimizer, nu=0)

In [None]:
# Optimization diagnostics provided by the Geodesic object (unlike the Fourier
# section, no reference video is passed to plot_deviation_from_line here).
moog.plot_loss()
moog.plot_deviation_from_line();

In [None]:
# try:
#     moog.animate_distance_from_line(vid).save("../logs/distfromline_frontend_translation.mp4")
# except:
#     print('generating the animation takes time, therefore we dont do it by default')

In [None]:
# moog.dev_from_line[0][1]

In [None]:
# plt.plot(po.to_numpy(torch.stack(moog., 0)[:, 1:-1]))
plt.plot(torch.stack([d[0] for d in moog.dev_from_line]));
# plt.plot(torch.stack([d[1][1:-1] for d in moog.dev_from_line]));

plt.title('evolution of distance from representation line')
plt.ylabel('distance from representation line')
plt.xlabel('iteration step')
plt.yscale('log')
plt.show()

In [None]:
# Per-step energies recorded at every iteration, and their mean over the path.
recorded = moog.step_energy
plt.plot([po.to_numpy(energy) for energy in recorded])
plt.plot([energy.mean() for energy in recorded], 'r--', label='path energy')

# Reference level: squared distance between the end-point representations divided
# by n_steps ** 2, i.e. the energy of each step if that distance were covered in
# n_steps equal steps.
rep_a = moog._analyze(moog.xA)
rep_b = moog._analyze(moog.xB)
plt.axhline(torch.norm(rep_a - rep_b) ** 2 / moog.n_steps ** 2)

plt.legend()
plt.title('evolution of representation step energy')
plt.xlabel('iteration')
plt.ylabel('step energy')
plt.yscale('log')
plt.show()

In [None]:
# Plot the recorded jerkiness history when there is one; otherwise plot the
# jerkiness of the final path.
try:
    # Probe only. Keeping the plotting code out of the `try` means a genuine
    # plotting error is no longer swallowed and mistaken for "no history".
    moog.step_jerkiness[0]
except (AttributeError, IndexError, TypeError):
    # No usable history: attribute missing, container empty, or not indexable.
    plt.plot(moog.calculate_path_jerkiness())
    plt.title('final representation step jerkiness')
else:
    plt.plot([po.to_numpy(j) for j in moog.step_jerkiness]);
    # The label used to read 'path energy', copied from the step energy plot.
    plt.plot([j.mean() for j in moog.step_jerkiness], 'r--', label='path jerkiness');
    plt.legend()
    plt.title('evolution of representation step jerkiness')
    plt.ylabel('step jerkiness')
    plt.xlabel('iteration')
    plt.yscale('log')
    plt.show()

In [None]:
# numpy copies of the synthesized geodesic and of the pixel-space fade, with
# singleton dimensions removed.
geodesic  = po.to_numpy(moog.geodesic).squeeze()
pixelfade = po.to_numpy(moog.pixelfade).squeeze()
# Sanity check: both sequences must have the same shape to be compared frame by frame.
assert geodesic.shape == pixelfade.shape
# Last expression of the cell: displays the common shape.
geodesic.shape

In [None]:
# Show the geodesic, its difference with the pixelfade, and the pixelfade itself,
# each as a row of frames preceded by a printed caption.
rows = (
    ('geodesic', geodesic),
    ('diff', geodesic - pixelfade),
    ('pixelfade', pixelfade),
)
for caption, frames in rows:
    print(caption)
    pt.imshow(list(frames), vrange='auto1', title=None, zoom=4)

In [None]:
# checking that the range constraint is met
# Overlay the normalized value histograms of the three sequences, to compare
# their ranges.
for name, signal in (('video', video),
                     ('pixelfade', pixelfade),
                     ('geodesic', geodesic)):
    plt.hist(signal.flatten(), histtype='step', density=True, label=name)
plt.yscale('log')
plt.title('signal value histogram')
plt.legend(loc=1)
plt.show()

## vgg16 translation / rotation / scaling  

In [None]:
# Load the two anchor images in colour (as_gray=False); judging by the file names,
# the first is an affine-transformed version of the second.
imgA = po.load_images('../data/frontwindow_affine.jpeg', as_gray=False)
imgB = po.load_images('../data/frontwindow.jpeg', as_gray=False)
# imgA = torchvision.transforms.functional.center_crop(imgA, 224)
# imgB = torchvision.transforms.functional.center_crop(imgB, 224)
# torch.manual_seed()
# imgA = torchvision.transforms.RandomCrop(224)(imgA)
# imgB = torchvision.transforms.RandomCrop(224)(imgB)
# Fixed 224 x 224 crop taken at the same location in both images: `u` is the
# offset along the second-to-last axis, `l` the offset along the last axis.
u = 300
l = 90
imgA = imgA[..., u:u+224, l:l+224]
imgB = imgB[..., u:u+224, l:l+224]
po.imshow([imgA, imgB], as_rgb=True);
# Element-wise difference between the two crops.
diff = imgA - imgB
po.imshow(diff);
pt.image_compare(po.to_numpy(imgA, True), po.to_numpy(imgB, True));
# pt.image_stats(po.to_numpy(diff, True));

In [None]:
# imgA = torch.tensor(imageA, dtype=dtype).unsqueeze(0).unsqueeze(0)
# imgB = torch.tensor(imageB, dtype=dtype).unsqueeze(0).unsqueeze(0)

# # print(imgA.shape)
# # from plenoptic.tools.straightness import make_straight_line
# # n_steps = 11
# # video = make_straight_line(imgA, imgB, n_steps)
# # print(video.shape)
# # pt.image_stats(po.to_numpy(video).squeeze())
# # pt.animshow(po.to_numpy(video).squeeze(), zoom=2)

# imgA = torch.stack([imgA, imgA, imgA], dim=1).squeeze(2)
# imgB = torch.stack([imgB, imgB, imgB], dim=1).squeeze(2)
# print(imgA.shape)
# po.imshow([imgA, imgB], as_rgb=True, zoom=2);

# # color_img = po.load_images('../data/color_wheel.jpg', as_gray=False)
# # color_img = po.blur_downsample(color_img)
# # color_img = po.blur_downsample(color_img)[..., 11:-11, 11:-11]
# # color_img = po.blur_downsample(color_img)
# # color_img = po.blur_downsample(color_img)
# # imgA = po.rescale(color_img)
# # imgB = torch.transpose(imgA.clone(), 2, 3)
# # po.imshow([imgA, imgB], as_rgb=True, zoom=2);

In [None]:
from torchvision import models
# Create a class that takes the nth layer output of a given model
class NthLayer(torch.nn.Module):
    """Wrap a model to get the response of one of its intermediate layers.

    Works for Resnet18 or VGG16.

    The input is first normalized channel-wise, then passed through the wrapped
    model's layers in order, up to and including layer ``layer``.
    """
    def __init__(self, model, layer=None):
        """
        Parameters
        ----------
        model: PyTorch model
            Either exposes a ``features`` sequence (VGG16) or the attributes
            ``conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc`` (resnet18).
        layer: int, optional
            Zero-based index of the layer whose response is returned. Defaults
            to the last layer.

        Raises
        ------
        ValueError
            If ``layer`` is not a valid index into the list of layers.
        """
        super().__init__()

        # TODO
        # is centering appropriate???
        # Channel-wise mean / std: the standard ImageNet statistics.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
        try:
            # then this is VGG16
            features = list(model.features)
        except AttributeError:
            # then it's resnet18
            # `fc` is a linear layer: the pooled activations must be flattened
            # first, as the model's own forward pass does. Indices up to and
            # including `avgpool` are unaffected by the inserted Flatten.
            features = ([model.conv1, model.bn1, model.relu, model.maxpool]
                        + list(model.layer1) + list(model.layer2)
                        + list(model.layer3) + list(model.layer4)
                        + [model.avgpool, nn.Flatten(1), model.fc])
        # Evaluation mode: fixes the behaviour of batch-norm / dropout layers.
        self.features = nn.ModuleList(features).eval()

        last = len(self.features) - 1
        if layer is None:
            # Was `len(self.features)`, an index the forward loop never reached,
            # so the default configuration returned None.
            layer = last
        elif not 0 <= layer <= last:
            raise ValueError(
                f"layer must be between 0 and {last}, got {layer}")
        self.layer = layer

    def forward(self, x):
        """Return the response of layer ``self.layer`` to the normalized ``x``."""
        x = self.normalize(x)
        for ii, mdl in enumerate(self.features):
            x = mdl(x)
            if ii >= self.layer:
                break
        # Always return a tensor: if `self.layer` is past the end, this is the
        # output of the last layer rather than None.
        return x

# different potential models of human visual perception of distortions
# resnet18 = NthLayer(models.resnet18(pretrained=True), layer=3)

# choosing what layer representation to study
# for l in range(len(models.vgg16().features)):
#     print(f'({l}) ', models.vgg16().features[l])   
#     y = NthLayer(models.vgg16(pretrained=True), layer=l)(imgA) 
    # print("dim", torch.numel(y), "shape ", y.shape,)

# Load the pretrained network once and share it between the wrappers instead of
# loading the same weights three times.
vgg16_pretrained = models.vgg16(pretrained=True)
# Indices of the first three max-pooling layers in torchvision's VGG16 `features`.
vgg_pool1 = NthLayer(vgg16_pretrained, layer=4)
vgg_pool2 = NthLayer(vgg16_pretrained, layer=9)
# Was 17, which is the convolution that follows the third pooling layer.
vgg_pool3 = NthLayer(vgg16_pretrained, layer=16)

In [None]:
# out of curiosity, if we are going to use a classifier
# I wonder how stable the predicted label is along the geodesic

In [None]:
# Build the classifier once and switch it to evaluation mode: a freshly
# constructed module is in training mode, where VGG16's dropout layers make the
# output random from one call to the next.
classifier = models.vgg16(pretrained=True).eval()
# Only the output values are needed, so skip gradient tracking.
with torch.no_grad():
    predA = po.to_numpy(classifier(imgA))[0]
    predB = po.to_numpy(classifier(imgB))[0]

# One curve per image: classifier output for every class index.
plt.plot(predA);
plt.plot(predB);

In [None]:
import ast

# NOTE(review): machine-specific absolute path - this cell only runs where the
# label file has been downloaded to this exact location.
with open("/Users/aldebaran/Downloads/imagenet1000_clsidx_to_labels.txt") as f:
    # The file holds a dict literal mapping class index to label. literal_eval
    # parses it without executing arbitrary code, unlike eval.
    idx2label = ast.literal_eval(f.read())

# Labels of the five highest-scoring classes of each image, lowest score first.
for idx in np.argsort(predA)[-5:]:
    print(idx2label[idx])
for idx in np.argsort(predB)[-5:]:
    print(idx2label[idx])

In [None]:
# Geodesic between the two colour crops in the representation of `vgg_pool3`,
# leaving the number of steps and the initialization at their defaults.
moog = po.synth.Geodesic(imgA, imgB, vgg_pool3)
# Displayed summary: number of pixels in an anchor image, number of elements and
# shape of its representation, and shapes of the trainable parameters.
torch.numel(imgA), torch.numel(moog.model(imgA)), moog.model(imgA).shape, [p.shape for p in moog.parameters() if p.requires_grad]

In [None]:
# Only 50 iterations here to keep the run short;
# this should be run for longer on a GPU
# NOTE(review): nu=0 presumably switches off the term weighted by `nu` - confirm
# against Geodesic.synthesize.
moog.synthesize(max_iter=50, learning_rate=.001, mu=1, nu=0)

In [None]:
# Deviation-from-line diagnostic provided by the Geodesic object; the loss is
# plotted separately with matplotlib.
# moog.plot_loss()
moog.plot_deviation_from_line();

In [None]:
# Raw loss history stored on the Geodesic object (linear axes).
plt.plot(moog.loss)

In [None]:
# Plot the recorded jerkiness history when there is one; otherwise plot the
# jerkiness of the final path.
try:
    # Probe only. Keeping the plotting code out of the `try` means a genuine
    # plotting error is no longer swallowed and mistaken for "no history".
    moog.step_jerkiness[0]
except (AttributeError, IndexError, TypeError):
    # No usable history: attribute missing, container empty, or not indexable.
    plt.plot(moog.calculate_path_jerkiness())
    plt.title('final representation step jerkiness')
else:
    plt.plot([po.to_numpy(j) for j in moog.step_jerkiness]);
    # The label used to read 'path energy', copied from the step energy plot.
    plt.plot([j.mean() for j in moog.step_jerkiness], 'r--', label='path jerkiness');
    plt.legend()
    plt.title('evolution of representation step jerkiness')
    plt.ylabel('step jerkiness')
    plt.xlabel('iteration')
    plt.yscale('log')
    plt.show()

In [None]:
# The two full sequences, rendered in colour.
for sequence in (moog.geodesic, moog.pixelfade):
    po.imshow(sequence, as_rgb=True, zoom=2, title=None, vrange='auto0')

# Difference between the sequences with the first and last entries (the shared
# anchor images) dropped.
inner_diff = (moog.geodesic - moog.pixelfade)[1:-1]

# per channel difference
for channel in (slice(0, 1), slice(1, 2), slice(2, None)):
    po.imshow([inner_diff[:, channel]], zoom=2, title=None, vrange='auto1')

# exaggerated color difference
po.imshow([po.rescale(inner_diff)], as_rgb=True, zoom=2, title=None);

In [None]:
# TODO pick better anchor frames, here too small motion?
# TODO investigate misbehaviour jerkiness while loss smoothly decreases