Goal
- Spatial transformer network (STN): resample an image through a sampling grid produced by an affine (2x3) transform.
- References
    - Reference STN layer implementation: https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
    - PyTorch bilinear grid sampler: https://github.com/pytorch/pytorch/blob/f064c5aa33483061a48994608d890b968ae53fb5/aten/src/THNN/generic/SpatialGridSamplerBilinear.c


In [None]:
# `IPython.core.display` re-exports of display/HTML are deprecated; use IPython.display.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
# Environment for CUDA / XLA; set before importing jax below so it is picked up.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# TF_CPP_VMODULE takes comma-separated `module=level` pairs (no leading '=').
os.environ['TF_CPP_VMODULE'] = 'bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# Append the CUDA library dir to any existing LD_LIBRARY_PATH. Python does not
# expand shell-style '${...}', so assigning that literal would clobber the
# variable. Skipping duplicates keeps the cell idempotent on re-runs.
# NOTE: the dynamic loader of an already-running process generally reads
# LD_LIBRARY_PATH only at startup, so this mostly affects child processes.
_ld_paths = [p for p in os.environ.get('LD_LIBRARY_PATH', '').split(':') if p]
if '/usr/local/cuda/lib64' not in _ld_paths:
    _ld_paths.append('/usr/local/cuda/lib64')
os.environ['LD_LIBRARY_PATH'] = ':'.join(_ld_paths)

import jax
import jax.numpy as np
from jax import grad, jit, vmap, device_put, random
from flax import linen as nn
from jax.scipy.stats import dirichlet

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
print(jax.local_device_count())
print(jax.devices())

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.tri as tri
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman' 
# `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` is the supported equivalent.
cmap = plt.get_cmap('bwr')


from gpax import *

In [None]:
import torch
import torchvision

# https://stackoverflow.com/questions/66577151/http-error-when-trying-to-download-mnist-data
# The original MNIST host became unreliable. Older torchvision releases store
# full download URLs in `MNIST.resources`, so repoint those at the S3 mirror.
# Newer releases store bare filenames in `resources` and build each URL as
# `mirror + filename` for every entry of `MNIST.mirrors` (which already lists
# this S3 mirror). Rewriting `resources` there would produce invalid URLs, so
# only patch when `mirrors` is absent.
new_mirror = 'https://ossci-datasets.s3.amazonaws.com/mnist'
if not hasattr(torchvision.datasets.MNIST, 'mirrors'):
    torchvision.datasets.MNIST.resources = [
        ('/'.join([new_mirror, url.split('/')[-1]]), md5)
        for url, md5 in torchvision.datasets.MNIST.resources
    ]

# Per-sample transform: image -> array with a trailing channel axis, scaled by 1/255.
# NOTE(review): `np` is jax.numpy in this notebook. This transform only runs
# when indexing `dataset[i]`; the cells below read `dataset.data` directly.
transforms = torchvision.transforms.Compose([
    lambda x: np.asarray(x)[...,np.newaxis] / 255.
])
dataset = torchvision.datasets.MNIST('./data', train=True, transform=transforms, download=True)

In [None]:
# Imported explicitly rather than relying on `from gpax import *` re-exporting it.
import itertools

patch_shape = (20,14)
h, w = patch_shape
# First MNIST training image as a (28, 28, 1) array scaled to [0, 1].
S = np.array(dataset.data[:1].numpy()).reshape(28,28,1)/255.
image_shape = S.shape
H, W, _ = S.shape

patches = extract_patches_2d(S, patch_shape)
# Number of top-left patch positions along rows (Py) and columns (Px).
Py, Px = (H-h+1), (W-w+1)

P = Py*Px
hi = np.arange(Py)
wi = np.arange(Px)
# (row, col) of every patch's top-left corner, in row-major order.
hwi = np.array(list(itertools.product(hi, wi)))

fig, ax = plt.subplots(1,1,figsize=(5,5))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(S, cmap='Greys')
# scatter expects (x, y) = (col, row), while hwi stores (row, col).
ax.scatter(hwi[:,1], hwi[:,0])

print(patches.shape, Py, Px)

In [None]:
# One 2x3 affine matrix per patch position, from the shared scale `s` (not
# mapped over) and the per-position translations `t`. Keep only the first
# Px + 1 of them; assuming `t` is row-major like `hwi`, that is one full row
# of patch positions plus the first of the next row.
s, t = extract_patches_2d_scal_transl(image_shape, patch_shape)
A = vmap(trans2x3_from_scal_transl, in_axes=(None, 0), out_axes=0)(s, t)[:Px + 1]

# Transform S once per matrix, keeping the source sampling grids for plotting.
fn = vmap(spatial_transform_details, in_axes=(0, None, None), out_axes=0)
T, Gs = fn(A, S, patch_shape)

n_rows = len(A)
fig, axs = plt.subplots(n_rows, 2, figsize=(10, 5 * n_rows))
for row_axes, A_i, T_i, Gs_i in zip(axs, A, T, Gs):
    plt_spatial_transform(row_axes, Gs_i, S, T_i)
    row_axes[0].set_title(f'{A_i}')
fig.tight_layout()

In [None]:
# Apply every patch transform to S and tile the results next to the directly
# extracted patches (same tiling and padding), then look at the difference.
# If the transforms reproduce the patches, the difference is ~0 everywhere.
A = vmap(trans2x3_from_scal_transl, (None, 0), 0)(s, t)
T = vmap(spatial_transform, (0, None, None), 0)(A, S, patch_shape)
ims_from_T = make_im_grid(T, im_per_row=Px, pad_value=0.1)
ims_from_patches = make_im_grid(patches, im_per_row=Px, pad_value=0.1)
ims_diff = ims_from_T - ims_from_patches
print(ims_from_T.min(), ims_from_T.max())
print(ims_from_patches.min(), ims_from_patches.max())
print(ims_diff.min(), ims_diff.max())

# The difference is signed: use a diverging colormap with limits symmetric
# about 0, so positive/negative errors are distinguishable and zero is white.
diff_lim = float(np.abs(ims_diff).max())
diff_lim = diff_lim if diff_lim > 0 else 1.0  # avoid vmin == vmax on an exact match
fig, ax = plt.subplots(1,1,figsize=(20,20))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(ims_diff, cmap='bwr', vmin=-diff_lim, vmax=diff_lim)



In [None]:
# Tile all extracted patches into a single image, Px patches per row
# (pad_value=.2 presumably fills the gaps between tiles).
ims_from_patches = make_im_grid(patches, im_per_row=Px, pad_value=.2)
fig, ax = plt.subplots(figsize=(20, 20))
ax.set(xticks=[], yticks=[])
ax.imshow(ims_from_patches, cmap='Greys')



In [None]:

def plt_spatial_transform(axs, Gs, S, T):
    """Draw a spatial transform as a (source, target) pair of panels.

    Both images are drawn over the normalised square [-1, 1] x [-1, 1]
    (``extent=(-1, 1, 1, -1)``, origin at the top-left).

    Args:
        axs: sequence of two matplotlib axes; ``axs[0]`` gets the source,
            ``axs[1]`` the target.
        Gs: source sampling grid. Row 0 holds x and row 1 holds y
            coordinates, one column per target pixel, reshaped here to
            ``(h, w)`` (row-major order assumed; confirm against
            ``spatial_transform_details``).
        S: source image, drawn with ``Gs`` overlaid as red '+' markers.
        T: transformed target image whose first two dims are ``(h, w)``,
            drawn with its regular target grid overlaid as red '+' markers.
    """
    h, w = T.shape[0], T.shape[1]
    # Target grid: x varies across the w columns and y down the h rows, so
    # both arrays have shape (h, w) and align with the pixels of T.
    Xt, Yt = np.meshgrid(np.linspace(-1, 1, w),
                         np.linspace(-1, 1, h))
    Xs = Gs[0, :].reshape((h, w))
    Ys = Gs[1, :].reshape((h, w))

    ax = axs[0]
    ax.set_xticks([]); ax.set_yticks([])
    ax.scatter(Xs, Ys, marker='+', c='r', s=50)
    ax.imshow(S, cmap='Greys', extent=(-1,1,1,-1), origin='upper')

    ax = axs[1]
    ax.set_xticks([]); ax.set_yticks([])
    ax.scatter(Xt, Yt, marker='+', c='r', s=30)
    ax.imshow(T, cmap='Greys', extent=(-1,1,1,-1), origin='upper')



# Demo: three hand-written 2x3 affine transforms with scale 0.5 and
# x-translations 0, 0.5 and 1 (presumably mapping normalised target coords
# to source coords, as in the STN references), each sampled into a 14x14 target.
Tsize = (14, 14)

A = np.stack([
    np.array([[.5,0,0],[0,.5,0]]),
    np.array([[.5,0,.5],[0,.5,0]]),
    np.array([[.5,0,1],[0,.5,0]]),
])

# Normalise to [0, 1] exactly like the patch-extraction cell. This cell
# overwrites the global `S`, and cells re-run afterwards compare transforms
# of `S` against `patches` extracted from the normalised image.
S = np.array(dataset.data[:1].numpy()).reshape(28,28,1)/255.
fn = vmap(spatial_transform_details, (0, None, None), 0)
T, Gs = fn(A, S, Tsize)


fig, axs = plt.subplots(len(A), 2, figsize=(10,5*len(A)))
for i in range(len(A)):
    plt_spatial_transform(axs[i,:], Gs[i], S, T[i])
    axs[i,0].set_title(f'{A[i]}')
    
    
fig.tight_layout()