goal
- spatial transformer network
- references
    - https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
    - torch grid sampler https://github.com/pytorch/pytorch/blob/f064c5aa33483061a48994608d890b968ae53fb5/aten/src/THNN/generic/SpatialGridSamplerBilinear.c


In [None]:
# Notebook setup: widen the page, configure CUDA/XLA/TensorFlow-logging
# environment variables, import JAX/flax/matplotlib, and set plot defaults.

# `IPython.display` is the public import location; importing `display`
# from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os

# The environment variables below are assigned before `jax` is imported.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# NOTE(review): the leading '=' looks like a leftover from a shell
# `VAR=value` assignment — confirm the intended value is 'bfc_allocator=1'.
os.environ['TF_CPP_VMODULE'] = '=bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# Append the CUDA library directory to whatever LD_LIBRARY_PATH already
# holds. Python performs no shell-style expansion, so assigning the string
# '${LD_LIBRARY_PATH}:...' would store that placeholder verbatim as a path
# entry. The membership check keeps the cell idempotent when re-run.
_cuda_lib_dir = '/usr/local/cuda/lib64'
_ld_entries = [
    entry
    for entry in os.environ.get('LD_LIBRARY_PATH', '').split(os.pathsep)
    if entry
]
if _cuda_lib_dir not in _ld_entries:
    _ld_entries.append(_cuda_lib_dir)
os.environ['LD_LIBRARY_PATH'] = os.pathsep.join(_ld_entries)

import jax
import jax.numpy as np  # NOTE: `np` is jax.numpy in this notebook, not NumPy
from jax import grad, jit, vmap, device_put, random
from flax import linen as nn
from jax.scipy.stats import dirichlet

# Report which backend and devices JAX picked up.
# NOTE(review): `jax.lib.xla_bridge` is not part of JAX's documented public
# API; `jax.default_backend()` reports the same platform name — consider
# switching if this import fails on your JAX version.
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
print(jax.local_device_count())
print(jax.devices())

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.tri as tri
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman'
# `plt.get_cmap` is used because `plt.cm.get_cmap` was deprecated and later
# removed from Matplotlib.
cmap = plt.get_cmap('bwr')


# Project-local helpers. `homogeneous_grid`, `grid_sample` and
# `spatial_transform`, used in later cells, are presumably provided by this
# star import — they are not defined anywhere else in this notebook.
from gpax import *

In [None]:
import torch
import torchvision

# Redirect torchvision's MNIST download URLs to an alternate mirror, keeping
# each resource's file name and md5 checksum.
# https://stackoverflow.com/questions/66577151/http-error-when-trying-to-download-mnist-data
new_mirror = 'https://ossci-datasets.s3.amazonaws.com/mnist'
torchvision.datasets.MNIST.resources = [
    (f"{new_mirror}/{url.rsplit('/', 1)[-1]}", md5)
    for url, md5 in torchvision.datasets.MNIST.resources
]


def _to_scaled_array(x):
    """Convert a sample to an array, add a trailing axis, and divide by 255."""
    # `np` is whatever the notebook bound to that name earlier.
    return np.asarray(x)[..., np.newaxis] / 255.


transforms = torchvision.transforms.Compose([_to_scaled_array])
dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms,
    download=True,
)

In [None]:

# Take the first raw image from the dataset, give it an explicit trailing
# channel axis, and also build a batch of three identical copies.
im = dataset.data[0]
S = np.reshape(np.array(im.numpy()), (28, 28, 1))
ims = np.stack([S, S, S])
Tsize = (28, 28)

## 

height, width = Tsize
# Target grid from the external helper, plus a regular [-1, 1] mesh that is
# only used for plotting below.
Gt = homogeneous_grid(height, width)
Xt, Yt = np.meshgrid(
    np.linspace(-1, 1, width),
    np.linspace(-1, 1, height),
)
# 2x3 matrix applied to the target grid: 0.5 on the diagonal, zero last
# column (presumably a 0.5x scaling with no translation — assumes Gt stacks
# rows as (x, y, 1); TODO confirm against `homogeneous_grid`).
A = np.array([
    [.5, 0, 0],
    [0, .5, 0],
])
Gs = np.matmul(A, Gt)
Xs_flat = Gs[0]
Ys_flat = Gs[1]
Xs = Xs_flat.reshape(Tsize)
Ys = Ys_flat.reshape(Tsize)
# Sample the source image at the transformed coordinates.
T = grid_sample(S, (Xs, Ys))

# Left panel: source image with the transformed grid overlaid.
# Right panel: sampled result with the regular target mesh overlaid.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
_panels = ((S, Xs, Ys), (T, Xt, Yt))
for ax, (_img, _px, _py) in zip(axs, _panels):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.scatter(_px, _py, c='b')
    ax.imshow(_img, cmap='Greys', extent=(-1, 1, -1, 1))


In [None]:

# Batch inputs: the stacked images and three matrices built from A
# (A scaled by 0.5, A itself, and A scaled by 1.5).
Xu = ims
As = np.stack([0.5 * A, A, 1.5 * A])

# Map `spatial_transform` over the leading axis of the matrices and of the
# images; the third argument, (28, 28), is passed unchanged to every call.
spatial_transform_vmap = vmap(
    spatial_transform,
    in_axes=(0, 0, None),
    out_axes=0,
)
Ts = spatial_transform_vmap(As, Xu, (28, 28))

# One panel per transformed image, without tick marks.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))

for i, ax in enumerate(axs):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(Ts[i], cmap='Greys')



