# convolver.ipynb

Convolve input images from the robot with a given filter.

In [None]:
import h5py
import jax.numpy as jnp
from jax import lax
import numpy as np
from matplotlib import pyplot as plt
#import matplotlib.animation as animation
#from IPython.display import HTML
from torchvision.utils import make_grid
import torchvision.transforms.functional as F
from torch.utils import data
import torch

In [None]:
# load data
# 64x64
#path = "/mnt/bucket/TaniU/Members/prasanna/dataset_5x8/5x6comp1train_dataset0901.h5"
# 256x256
path = "/mnt/bucket/TaniU/Members/prasanna/prasanna_data/datasets/groupA1_traindataset_256x256.h5"
#path = "/mnt/bucket/TaniU/Members/prasanna/prasanna_data/datasets/groupA1_testdataset_256x256.h5"

# quick information
# The file is opened only for the duration of this summary; the context
# manager closes the handle again instead of leaving it open in the kernel.
with h5py.File(path, 'r') as f:
    keys = list(f.keys())
    for key in keys:
        print(f"'{key}' \t dataset shape: {f[key].shape}")

# dataset class
class JaxDataset(data.Dataset):
    """ A Torch Dataset class with the HDF5 datasets.

    The HDF5 file is opened read-only in the constructor and is kept open,
    because the five datasets are indexed on demand in ``__getitem__``.
    Call :meth:`close` once the dataset is no longer needed; items cannot be
    read after that.

    :param data_path: path of the HDF5 dataset.
    """
    def __init__(self, data_path):
        self.data_path = data_path
        # Keep the handle on the instance so it can be closed explicitly.
        self._file = h5py.File(data_path, 'r')
        self.lang_mask = self._file['lang_mask']
        self.language = self._file['language']
        self.mask = self._file['mask']
        self.motor = self._file['motor']
        self.vision = self._file['vision']

    def close(self):
        """ Close the underlying HDF5 file if one is attached.

        Calling this more than once is harmless.
        """
        handle = getattr(self, '_file', None)
        if handle is not None:
            handle.close()
            self._file = None

    def __len__(self):
        """ Returns the size of the first axis of the 'vision' dataset. """
        return self.vision.shape[0]

    def __getitem__(self, index):
        """ Returns the training data corresponding to 'index'.

        :param index: index of the data to return
        :type index: int
        :return: vision, motor, language, mask, lang_mask
        :rtype: tuple with the result of indexing each HDF5 dataset (h5py
            yields NumPy data here, not JAX arrays)
        """
        return (self.vision[index], self.motor[index], self.language[index],
                self.mask[index], self.lang_mask[index])

# Build the dataset object for the HDF5 file selected in `path`.
dataset = JaxDataset(path)

---
The next cell arranges the images of one example into a grid. The code comes from
the PyTorch [visualization utilities](https://pytorch.org/vision/stable/auto_examples/plot_visualization_utils.html#sphx-glr-auto-examples-plot-visualization-utils-py) documentation.

It is assumed that vision has index 0 in each example, and that the visual data has
shape `[N, 3, H, W]`, i.e. N RGB images of size HxW.  
Moreover, those images have data in the range [-1, 1].

In [None]:
# Pick one entry of the dataset; its first element is the vision data.
example_index = 12
example = dataset[example_index]
vision = example[0]
print(f"For index {example_index}, the vision data has shape {vision.shape}")

# Shift the values by +0.5 after halving them, then tile the frames into one
# grid image.
imgs = 0.5 + vision / 2.
frames = torch.tensor(imgs)
grid = make_grid(frames)

def show(imgs, figsize=(10, 10)):
    """Display one image tensor, or a list of them, side by side in one row.

    Each image is detached, converted with ``to_pil_image`` and drawn on its
    own axes with ticks and tick labels removed.

    :param imgs: a single image tensor or a list of image tensors
    :param figsize: size of the matplotlib figure passed to ``plt.subplots``
    :type figsize: tuple(float, float)
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    # squeeze=False keeps `axs` two-dimensional even for a single image.
    _, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
    for ax, img in zip(axs[0], imgs):
        pil_img = F.to_pil_image(img.detach())
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

# Draw the grid of frames assembled with make_grid.
show(grid)

Let's convolve all the images in the example with a fixed filter, using `lax.conv_general_dilated` as in [the tutorial](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html).

In [None]:
# 2D kernel - HWIO layout
# Alternative 3x3 patterns that were tried (kept for reference):
#   [[1, 1, 0],
#    [1, 0,-1],
#    [0,-1,-1]]
#
#   [[1,  .5,-1],
#    [1,   0,-1],
#    [1, -.5,-1]]
pattern = jnp.array([[ 1, -1,  1],
                     [-1,  0, -1],
                     [ 1, -1,  1]])

# Broadcasting the (3, 3, 1, 1) pattern against the zero tensor copies the
# same spatial pattern into every (input, output) channel pair.
kernel = jnp.zeros((3, 3, 3, 3), dtype=jnp.float32) + pattern[:, :, None, None]

print("Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);

In [None]:
# Translate the layout strings into dimension numbers for the convolution.
# For the two shape arguments only the number of dimensions matters.
layout_spec = ('NCHW', 'HWIO', 'NCHW')  # the important bit: lhs, rhs, out
dn = lax.conv_dimension_numbers(vision.shape, kernel.shape, layout_spec)
print(dn)

In [None]:
# Convolve every frame of the example with the kernel; stride 1, no dilation
# and 'SAME' padding.
out = lax.conv_general_dilated(
    vision,                   # lhs = image tensor
    kernel,                   # rhs = conv kernel tensor
    window_strides=(1, 1),
    padding='SAME',
    lhs_dilation=(1, 1),      # image dilation
    rhs_dilation=(1, 1),      # kernel dilation
    dimension_numbers=dn,     # lhs, rhs, out dimension permutation
)
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(4,4))
first_channel = np.array(out)[0, 0]
plt.imshow(first_channel);

In [None]:
# Move the convolution result from JAX to PyTorch for plotting.
pt_out = torch.from_numpy(np.asarray(out))  # shouldn't do in real code
# NOTE(review): this rescaling mirrors the one used for the input images;
# the convolution output is not guaranteed to lie in [-1, 1] - confirm that
# the resulting display range is what is wanted.
imgs = pt_out / 2. + 0.5
# `imgs` is already a new tensor, so it is handed to make_grid directly;
# wrapping it in torch.tensor() again only adds a copy and a UserWarning.
grid = make_grid(imgs)

def show(imgs, figsize=(12, 12)):
    """Display one image tensor, or a list of them, side by side in one row.

    Each image is detached, converted with ``to_pil_image`` and drawn on its
    own axes with ticks and tick labels removed.

    :param imgs: a single image tensor or a list of image tensors
    :param figsize: size of the matplotlib figure passed to ``plt.subplots``
    :type figsize: tuple(float, float)
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    # squeeze=False keeps `axs` two-dimensional even for a single image.
    _, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
    for ax, img in zip(axs[0], imgs):
        pil_img = F.to_pil_image(img.detach())
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

# Draw the grid built from the convolution output.
show(grid)

---
Quick example of implementing transpose convolution

In [None]:
import jax
from jax import lax
import jax.numpy as jnp

In [None]:

# First, the direct convolution

in_chan = 256  # number of input channels
out_chan = 64  # number of output channels
batch_size = 32
kernel = jnp.ones(shape=(in_chan, out_chan, 3, 3))
images = jnp.ones(shape=(batch_size, in_chan, 6, 6))

# Layout of the input, the kernel and the output, in that order.
dn = lax.conv_dimension_numbers(
    images.shape,
    kernel.shape,
    ('NCHW', 'IOHW', 'NCHW'),
)

# Unit stride, no dilation and no padding ('VALID').
out = lax.conv_general_dilated(
    images,                 # lhs = image tensor
    kernel,                 # rhs = conv kernel tensor
    window_strides=(1, 1),
    padding='VALID',
    lhs_dilation=(1, 1),    # image dilation
    rhs_dilation=(1, 1),    # kernel dilation
    dimension_numbers=dn,
)

print(f"Output shape: {out.shape}")


#help(jax.lax.conv_transpose)

In [None]:
# Next, the corresponding transpose convolution
# NOTE(review): the input keeps `in_chan` channels and the kernel is reused
# with the same 'IOHW' spec, so this reproduces the spatial growth (4 -> 6)
# of a transposed convolution without swapping channel roles - confirm that
# this is the intent.
images = jnp.ones(shape=(batch_size, in_chan, 4, 4))

# Layout of the input, the kernel and the output, in that order.
dn = lax.conv_dimension_numbers(
    images.shape,
    kernel.shape,
    ('NCHW', 'IOHW', 'NCHW'),
)

out = lax.conv_general_dilated(
    images,                      # lhs = image tensor
    kernel,                      # rhs = conv kernel tensor
    window_strides=(1, 1),
    padding=[(2,2), (2,2)],      # explicit (low, high) padding per spatial dim
    lhs_dilation=(1, 1),         # image dilation
    rhs_dilation=(1, 1),         # kernel dilation
    dimension_numbers=dn,
)

print(f"Output shape: {out.shape}")



In [None]:
# Scratch arithmetic, evaluates to 64.
# NOTE(review): presumably a size check for the 20 -> 64 transposed
# convolution in the following cells - confirm.
(20 + 10 - 5 - 4) * 3 + 1

In [None]:

# First, the direct convolution

in_chan = 256  # number of input channels
out_chan = 64  # number of output channels
batch_size = 32
kernel = jnp.ones(shape=(in_chan, out_chan, 5, 5))
images = jnp.ones(shape=(batch_size, in_chan, 64, 64))

# Layout of the input, the kernel and the output, in that order.
dn = lax.conv_dimension_numbers(
    images.shape,
    kernel.shape,
    ('NCHW', 'IOHW', 'NCHW'),
)

# Stride 3, kernel dilation 2 and an explicit padding of 2 on every side.
out = lax.conv_general_dilated(
    images,                      # lhs = image tensor
    kernel,                      # rhs = conv kernel tensor
    window_strides=(3, 3),
    padding=[(2,2), (2,2)],      # explicit (low, high) padding per spatial dim
    lhs_dilation=(1, 1),         # image dilation
    rhs_dilation=(2, 2),         # kernel dilation
    dimension_numbers=dn,
)

print(f"Output shape: {out.shape}")


#help(jax.lax.conv_transpose)

In [None]:
# Next, the corresponding transpose convolution
# NOTE(review): the kernel dilation is (1, 1) here while the direct
# convolution used (2, 2), and the channel roles are not swapped; the
# settings reproduce the 20 -> 64 spatial size - confirm that only the shape
# is meant to correspond.
images = jnp.ones(shape=(batch_size, in_chan, 20, 20))

# Layout of the input, the kernel and the output, in that order.
dn = lax.conv_dimension_numbers(
    images.shape,
    kernel.shape,
    ('NCHW', 'IOHW', 'NCHW'),
)

# The stride of the direct convolution shows up as input (lhs) dilation.
out = lax.conv_general_dilated(
    images,                      # lhs = image tensor
    kernel,                      # rhs = conv kernel tensor
    window_strides=(1, 1),
    padding=[(5,5), (5,5)],      # explicit (low, high) padding per spatial dim
    lhs_dilation=(3, 3),         # image dilation
    rhs_dilation=(1, 1),         # kernel dilation
    dimension_numbers=dn,
)

print(f"Output shape: {out.shape}")

In [None]:
import math

def padding(input_dim, output_dim, stride, kernel_size, dilation):
    """Calculate the padding for a convolution.

    The padding is calculated to attain the desired input and output
    dimensions. The result is verified against the output-size formula
    ``floor((input_dim + 2*pad - dilation*(kernel_size - 1) - 1) / stride) + 1``
    before it is returned.

    :param input_dim: dimension of the (square) input
    :type input_dim: int
    :param output_dim: height or width of the square convolution output
    :type output_dim: int
    :param stride: stride; a fractional value such as ``1/3`` is accepted too
    :type stride: int or float
    :param kernel_size: kernel size
    :type kernel_size: int
    :param dilation: dilation
    :type dilation: int
    :returns: padding for the convolution, residual for the transpose convolution
    :rtype: int, int
    :raises ValueError: if the required padding is negative, or if the
        computed padding does not reproduce ``output_dim``
    """
    err_msg = "kernel, stride, dilation and input/output sizes do not match"
    # Extent of the dilated kernel beyond its first element.
    kernel_span = dilation * (kernel_size - 1)
    pad = math.ceil(0.5 * (
        stride * (output_dim - 1) - input_dim + kernel_span + 1))
    if pad < 0:
        raise ValueError(err_msg)

    # Padded input length that remains once the first kernel position is placed.
    covered = input_dim + 2 * pad - kernel_span - 1
    r = covered % stride
    # Verify that the padding is correct. An explicit exception is used
    # instead of `assert` so the check also runs under `python -O` and fails
    # the same way as the negative-padding case.
    if math.floor(covered / stride) + 1 != output_dim:
        raise ValueError(err_msg)
    return pad, r
    
# Stride is passed as the fraction 1/3 here.
# NOTE(review): presumably this models the 20 -> 64 transposed convolution
# from the earlier cell - confirm.
padding(20, 64, 1/3, 5, 2)