# runner.ipynb

A notebook to use the methods in `jax_conv_lstm.py`

In [1]:
import torch
from torch import nn
import torch.optim as optim
from torch_conv_lstm import *

In [2]:
path = "/media/z/Data/datasets/language_network/groupA1_traindataset_256x256.h5"
config = read_config('rnn_config.yaml')

data_loader = TorchDataLoader(path, config['batch_size'], shuffle=True)

In [3]:
datum = next(data_loader._get_iterator())
vision, motor, language, mask, lang_mask = datum
print(f"Vision shape: {vision.shape}. \t Vision type: {type(vision)}")
print(f"Motor shape: {motor.shape}. \t Motor type: {type(motor)}")
print(f"Language shape: {language.shape}. \t Language type: {type(language)}.")
print(f"Mask shape: {mask.shape}. \t Mask type: {type(mask)}.")
print(f"Lang_mask shape: {lang_mask.shape}. \t Lang_mask type: {type(lang_mask)}.")

Vision shape: torch.Size([35, 50, 3, 256, 256]). 	 Vision type: <class 'torch.Tensor'>
Motor shape: torch.Size([35, 50, 60]). 	 Motor type: <class 'torch.Tensor'>
Language shape: torch.Size([35, 5, 20]). 	 Language type: <class 'torch.Tensor'>.
Mask shape: torch.Size([35, 50]). 	 Mask type: <class 'torch.Tensor'>.
Lang_mask shape: torch.Size([35, 5]). 	 Lang_mask type: <class 'torch.Tensor'>.


---
Test the `ConvLSTMCell` class

In [4]:
# test conv_lstm cell
inp_channels = 3
hid_size = 128  # desired hidden state size
pad, r  = padding_fun(vision.shape[3],
                      hid_size,
                      config['s'],
                      config['inp_kernel_size'],
                      config['kd'])

conv_lstm = ConvLSTMCell(inp_channels,
                         config['h_channels'],
                         config['inp_kernel_size'],
                         config['hid_kernel_size'],
                         inp_stride=config['s'],
                         inp_padding=pad,
                         ik_dilation=config['kd'])

image_batch = vision[:, 0, :, :, :].to(torch.float32)

img_height = image_batch.shape[2]
img_width = image_batch.shape[3]
h, c = conv_lstm.init_hidden_from_normal(vision.shape[0], (img_height, img_width))

print(f"h shape: {h.shape}")
print(f"c shape: {c.shape}")

h, c = conv_lstm(image_batch, (h, c))

print(f"h shape: {h.shape}")
print(f"c shape: {c.shape}")


h shape: torch.Size([35, 12, 128, 128])
c shape: torch.Size([35, 12, 128, 128])
inp_convs: torch.Size([35, 48, 128, 128])
hid_convs: torch.Size([35, 48, 128, 128])
ic: torch.Size([35, 12, 128, 128])
h shape: torch.Size([35, 12, 128, 128])
c shape: torch.Size([35, 12, 128, 128])


---
Test the `PredictorCell` class

In [5]:
input_channels = 3  # the RGB channels
hid_size = 128  # desired hidden state size
pad, r  = padding_fun(vision.shape[3],
                      hid_size,
                      config['s'],
                      config['inp_kernel_size'],
                      config['kd'])
conv_params = {'input_channels': input_channels,
               'hidden_channels': config['h_channels'],
               'inp_kernel_size': config['inp_kernel_size'],
               'hid_kernel_size': config['hid_kernel_size'],
               'inp_stride': config['s'],
               'inp_padding': pad,
               'ik_dilation': config['kd'],
               'bias': True,
              }

pad_t, r_t = padding_fun(vision.shape[3],
                         hid_size,
                         config['s'],
                         config['trans_kernel_size'],
                         1)
conv_params_t = {'kernel_size': config['trans_kernel_size'],
                 'ik_dilation': 1,
                 'inp_padding': pad_t,
                 'output_padding': r_t,
                 'bias': True,
                }

pred_cell = PredictorCell(conv_params, conv_params_t)

image_batch = vision[:, 0, :, :, :].to(torch.float32)

img_height = image_batch.shape[2]
img_width = image_batch.shape[3]
h, c = pred_cell.init_hidden(vision.shape[0], (img_height, img_width))

x_next, h, c = pred_cell(image_batch, (h, c))

print(f"h shape: {h.shape}")
print(f"c shape: {c.shape}")
print(f"x_next shape: {x_next.shape}")

h shape: torch.Size([35, 12, 128, 128])
c shape: torch.Size([35, 12, 128, 128])
x_next shape: torch.Size([35, 3, 256, 256])


---
Test the `Predictor` class

In [6]:
input_channels = 3  # the RGB channels
hid_size = 128  # desired hidden state size
T = vision.shape[1]  # number of images to predict
pad, r  = padding_fun(vision.shape[3],
                      hid_size,
                      config['s'],
                      config['inp_kernel_size'],
                      config['kd'])
conv_params = {'input_channels': input_channels,
               'hidden_channels': config['h_channels'],
               'inp_kernel_size': config['inp_kernel_size'],
               'hid_kernel_size': config['hid_kernel_size'],
               'inp_stride': config['s'],
               'inp_padding': pad,
               'ik_dilation': config['kd'],
               'bias': True,
              }

pad_t, r_t = padding_fun(vision.shape[3],
                         hid_size,
                         config['s'],
                         config['trans_kernel_size'],
                         1)
conv_params_t = {'kernel_size': config['trans_kernel_size'],
                 'ik_dilation': 1,
                 'inp_padding': pad_t,
                 'output_padding': r_t,
                 'bias': True,
                }
image_batch = vision[:, 0, :, :, :].to(torch.float32)

predictor = Predictor(conv_params, conv_params_t)

pred_sequence = predictor(image_batch, T)

print(f"image_batch shape: {image_batch.shape}")
print(f"pred_sequence shape: {pred_sequence.shape}")

image_batch shape: torch.Size([35, 3, 256, 256])
pred_sequence shape: torch.Size([35, 50, 3, 256, 256])


---
### Perform backpropagation

In [4]:
n_epochs = 10
input_channels = 3  # the RGB channels
hid_size = 128  # desired hidden state size
T = vision.shape[1]  # number of images to predict

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

pad, r  = padding_fun(vision.shape[3],
                      hid_size,
                      config['s'],
                      config['inp_kernel_size'],
                      config['kd'])
conv_params = {'input_channels': input_channels,
               'hidden_channels': config['h_channels'],
               'inp_kernel_size': config['inp_kernel_size'],
               'hid_kernel_size': config['hid_kernel_size'],
               'inp_stride': config['s'],
               'inp_padding': pad,
               'ik_dilation': config['kd'],
               'bias': True,
              }

pad_t, r_t = padding_fun(vision.shape[3],
                         hid_size,
                         config['s'],
                         config['trans_kernel_size'],
                         1)
conv_params_t = {'kernel_size': config['trans_kernel_size'],
                 'ik_dilation': 1,
                 'inp_padding': pad_t,
                 'output_padding': r_t,
                 'bias': True,
                }

predictor = Predictor(conv_params, conv_params_t).to(device)

loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.SGD(predictor.parameters(), lr=config['learning_rate'])


cuda:0


In [5]:
accum_loss = 0.0
losses = []
for epoch in range(n_epochs):
    for datum in data_loader:
        vision, motor, language, mask, lang_mask = datum
        vision = vision.to(device, dtype=torch.float32)
        first_images_batch = vision[:, 0, :, :, :].detach().clone().to(device)
        optimizer.zero_grad()
        predictions = predictor(first_images_batch, vision.shape[1])
        loss = loss_fn(vision, predictions)
        loss.backward()
        optimizer.step()
        accum_loss += loss.item()
    losses.append(accum_loss / len(data_loader)
    if epoch % 5 == 0:
        print(f"Epoch {epoch}, loss={losses[-1]}")

Epoch 0, loss=0.3999221192465888
Epoch 5, loss=2.3724786374304028


In [8]:
# visualize some predictions
example_index = 0
datum = next(data_loader._get_iterator())
vision, motor, language, mask, lang_mask = datum
vision = vision.to(device, dtype=torch.float32)
first_images_batch = vision[:, 0, :, :, :].detach().clone().to(device)
predictions = predictor(first_images_batch, vision.shape[1])
vision_ex = predictions[example_index, :, :, :, :]

OutOfMemoryError: CUDA out of memory. Tried to allocate 28.00 MiB (GPU 0; 23.46 GiB total capacity; 20.99 GiB already allocated; 47.12 MiB free; 21.47 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF