# runner.ipynb

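A notebook that exercises the classes defined in `torch_conv_lstm.py` (`ConvLSTMCell`, `PredictorCell`, `Predictor`), trains a `Predictor` on the vision data, visualizes its predictions, and saves/loads a checkpoint.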

In [None]:
import torch
from torch import nn
import torch.optim as optim
from torch_conv_lstm import *
import time

In [None]:
# Settings used throughout the notebook (batch size, kernel sizes, ...) are
# read from the YAML config.
config = read_config('rnn_config.yaml')

# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
path = "/media/z/Data/datasets/language_network/groupA1_traindataset_256x256.h5"

data_loader = TorchDataLoader(
    path,
    config['batch_size'],
    shuffle=True,
)

In [None]:
# Pull a single batch to inspect what the loader yields. iter() is the public
# iteration protocol (the training loop also iterates the loader directly),
# so there is no need to reach into the private _get_iterator() helper.
datum = next(iter(data_loader))
vision, motor, language, mask, lang_mask = datum

# Report the shape and Python type of every element of the batch.
print(f"Vision shape: {vision.shape}. \t Vision type: {type(vision)}")
print(f"Motor shape: {motor.shape}. \t Motor type: {type(motor)}")
print(f"Language shape: {language.shape}. \t Language type: {type(language)}.")
print(f"Mask shape: {mask.shape}. \t Mask type: {type(mask)}.")
print(f"Lang_mask shape: {lang_mask.shape}. \t Lang_mask type: {type(lang_mask)}.")

---
Test the `ConvLSTMCell` class

In [None]:
# test conv_lstm cell
inp_channels = 3
hid_size = 128  # desired hidden state size

# padding_fun returns two values; only the first (the padding) is used in
# this cell.
pad, r = padding_fun(vision.shape[3], hid_size, config['s'],
                     config['inp_kernel_size'], config['kd'])

conv_lstm = ConvLSTMCell(
    inp_channels,
    config['h_channels'],
    config['inp_kernel_size'],
    config['hid_kernel_size'],
    inp_stride=config['s'],
    inp_padding=pad,
    ik_dilation=config['kd'],
)

# Index 0 along the second axis of `vision`, cast to float32, is the cell input.
image_batch = vision[:, 0].to(torch.float32)

img_height, img_width = image_batch.shape[2:4]
h, c = conv_lstm.init_hidden_from_normal(vision.shape[0], (img_height, img_width))

# State shapes straight after initialisation ...
print(f"h shape: {h.shape}")
print(f"c shape: {c.shape}")

h, c = conv_lstm(image_batch, (h, c))

# ... and after one step of the cell.
print(f"h shape: {h.shape}")
print(f"c shape: {c.shape}")


---
Test the `PredictorCell` class

In [None]:
input_channels = 3  # the RGB channels
hid_size = 128  # desired hidden state size

# First parameter set: built from the input-kernel settings of the config.
pad, r = padding_fun(vision.shape[3], hid_size, config['s'],
                     config['inp_kernel_size'], config['kd'])
conv_params = {
    'input_channels': input_channels,
    'hidden_channels': config['h_channels'],
    'inp_kernel_size': config['inp_kernel_size'],
    'hid_kernel_size': config['hid_kernel_size'],
    'inp_stride': config['s'],
    'inp_padding': pad,
    'ik_dilation': config['kd'],
    'bias': True,
}

# Second parameter set (presumably for a transposed convolution - confirm in
# torch_conv_lstm.py): uses 'trans_kernel_size' and a fixed dilation of 1.
pad_t, r_t = padding_fun(vision.shape[3], hid_size, config['s'],
                         config['trans_kernel_size'], 1)
conv_params_t = {
    'kernel_size': config['trans_kernel_size'],
    'ik_dilation': 1,
    'inp_padding': pad_t,
    'output_padding': r_t,
    'bias': True,
}

pred_cell = PredictorCell(conv_params, conv_params_t)

# Index 0 along the second axis of `vision`, cast to float32, is the cell input.
image_batch = vision[:, 0].to(torch.float32)

img_height, img_width = image_batch.shape[2:4]
h, c = pred_cell.init_hidden(vision.shape[0], (img_height, img_width))

# One step of the cell returns an output plus the updated (h, c) state.
x_next, h, c = pred_cell(image_batch, (h, c))

print(f"h shape: {h.shape}")
print(f"c shape: {c.shape}")
print(f"x_next shape: {x_next.shape}")

---
Test the `Predictor` class

In [None]:
input_channels = 3  # the RGB channels
hid_size = 128  # desired hidden state size
T = vision.shape[1]  # number of images to predict

# First parameter set: built from the input-kernel settings of the config.
pad, r = padding_fun(vision.shape[3], hid_size, config['s'],
                     config['inp_kernel_size'], config['kd'])
conv_params = {
    'input_channels': input_channels,
    'hidden_channels': config['h_channels'],
    'inp_kernel_size': config['inp_kernel_size'],
    'hid_kernel_size': config['hid_kernel_size'],
    'inp_stride': config['s'],
    'inp_padding': pad,
    'ik_dilation': config['kd'],
    'bias': True,
}

# Second parameter set: uses 'trans_kernel_size' and a fixed dilation of 1.
pad_t, r_t = padding_fun(vision.shape[3], hid_size, config['s'],
                         config['trans_kernel_size'], 1)
conv_params_t = {
    'kernel_size': config['trans_kernel_size'],
    'ik_dilation': 1,
    'inp_padding': pad_t,
    'output_padding': r_t,
    'bias': True,
}

# The predictor is called with the index-0 slice along the second axis of
# `vision` and asked for T steps.
image_batch = vision[:, 0].to(torch.float32)

predictor = Predictor(conv_params, conv_params_t)

pred_sequence = predictor(image_batch, T)

print(f"image_batch shape: {image_batch.shape}")
print(f"pred_sequence shape: {pred_sequence.shape}")

---
### Perform backpropagation

In [None]:
# build the network

n_epochs = config['n_epochs']
input_channels = 3  # the RGB channels. This is not configurable.
hid_size = 128  # desired hidden state size
T = vision.shape[1]  # number of images to predict

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"device: {device}")

# First parameter set: built from the input-kernel settings of the config.
pad, r = padding_fun(vision.shape[3], hid_size, config['s'],
                     config['inp_kernel_size'], config['kd'])
conv_params = {
    'input_channels': input_channels,
    'hidden_channels': config['h_channels'],
    'inp_kernel_size': config['inp_kernel_size'],
    'hid_kernel_size': config['hid_kernel_size'],
    'inp_stride': config['s'],
    'inp_padding': pad,
    'ik_dilation': config['kd'],
    'bias': True,
}

# Second parameter set: uses 'trans_kernel_size' and a fixed dilation of 1.
pad_t, r_t = padding_fun(vision.shape[3], hid_size, config['s'],
                         config['trans_kernel_size'], 1)
conv_params_t = {
    'kernel_size': config['trans_kernel_size'],
    'ik_dilation': 1,
    'inp_padding': pad_t,
    'output_padding': r_t,
    'bias': True,
}

# Model on the chosen device, mean-squared-error loss, plain SGD with the
# learning rate from the config.
predictor = Predictor(conv_params, conv_params_t)
predictor = predictor.to(device)

loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.SGD(predictor.parameters(), lr=config['learning_rate'])


In [None]:
# Run the epochs

# Keep the loss history when this cell is re-run, so training can be continued
# and the whole curve plotted afterwards.
if 'losses' not in locals():
    losses = []

warmup_epochs = 1  # leading epochs left out of the timing (presumably to hide one-off start-up cost)
log_every = 1      # print the mean loss every `log_every` epochs

# The warm-up can only be excluded when at least one epoch runs after it.
# Otherwise the clock covers every epoch, and the summary must say so instead
# of reporting zero or a negative number of epochs.
skip_warmup = n_epochs > warmup_epochs
timed_epochs = n_epochs - warmup_epochs if skip_warmup else n_epochs

start_time = time.time()

for epoch in range(n_epochs):
    accum_loss = 0.0
    if skip_warmup and epoch == warmup_epochs:
        # restart the clock once the warm-up epochs are done
        start_time = time.time()
    for datum in data_loader:
        vision, motor, language, mask, lang_mask = datum
        vision = vision.to(device, dtype=torch.float32)
        # Only the index-0 slice along the second axis is fed to the model;
        # `vision` already lives on `device`, so no further transfer is needed.
        first_images_batch = vision[:, 0].detach().clone()
        optimizer.zero_grad()
        predictions = predictor(first_images_batch, vision.shape[1])
        # (input, target) argument order; the MSE value itself is symmetric.
        loss = loss_fn(predictions, vision)
        loss.backward()
        optimizer.step()
        accum_loss += loss.item()
    # mean of the per-batch losses for this epoch
    losses.append(accum_loss / len(data_loader))
    if epoch % log_every == 0:
        print(f"Epoch {epoch}, loss={losses[-1]}")

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Completed {timed_epochs} epochs in {elapsed_time:.2f} seconds")

- 2 epochs in 78 seconds with no compilation.
- 2 epochs in 132 seconds with "inductor" compilation, no warmup, batch_size=32.
- 2 epochs in 133 seconds with "inductor" compilation, 1 epoch warmup, batch_size=32.
- 2 epochs in 53 seconds with "cudagraphs" compilation, 1 epoch warmup, batch_size=16.
- 3 epochs in 79 seconds with "cudagraphs" compilation, 1 epoch warmup, batch_size=16.


In [None]:
# visualize some predictions
from matplotlib import pyplot as plt
from torchvision.utils import make_grid
# Aliased as TF so it cannot be confused with (or overwrite) an `F` that refers
# to torch.nn.functional.
import torchvision.transforms.functional as TF
import numpy as np


def show(imgs):
    """Display one image tensor, or a list of image tensors, in a single row.

    Args:
        imgs: an image tensor or a list of image tensors; each one is
            converted with torchvision's ``to_pil_image`` before plotting.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(10, 10))
    for i, img in enumerate(imgs):
        img = img.detach()
        img = TF.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        # hide ticks and tick labels; only the picture matters here
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


example_index = 1

# iter() is the public way to obtain a batch; avoid the private _get_iterator().
datum = next(iter(data_loader))
vision, motor, language, mask, lang_mask = datum
vision = vision.to(device, dtype=torch.float32)
first_images_batch = vision[:, 0].detach().clone()

# Pure inference: no autograd graph is needed for plotting.
with torch.no_grad():
    predictions = predictor(first_images_batch, vision.shape[1])
vision_ex = predictions[example_index]

print(f"For index {example_index}, the vision data has shape {vision_ex.shape}")

# Rescale with x / 2 + 0.5, then clamp: predictions are unbounded, and values
# outside [0, 1] would be corrupted by the 8-bit conversion in to_pil_image.
imgs = (vision_ex / 2. + 0.5).clamp(0., 1.)
# Already a tensor, so move it to the CPU directly instead of wrapping it in
# torch.tensor(), which triggers a copy-construct warning.
grid = make_grid(imgs.detach().cpu())

show(grid)

In [None]:
# plot the losses
# `losses` holds one mean loss value per completed epoch. Explicit figure/axes
# objects are used instead of the implicit pyplot state, and the trailing
# semicolon keeps the Text(...) repr out of the cell output.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training loss per epoch');

In [None]:
# plot the losses
# Same figure as pyplot's implicit interface would draw, but going through the
# current axes object explicitly.
loss_axes = plt.gca()
loss_axes.plot(losses)
loss_axes.set_xlabel('Epoch')
loss_axes.set_ylabel('Loss')

---
Save and load the model

In [None]:
# Checkpoint file location, read from the 'model_path' entry of the config;
# the save and load cells below both use it.
save_name = config['model_path']

In [None]:
# Save model and optimizer state dictionaries along with other info
# Everything is bundled into one dictionary and written to a single file.
checkpoint_state = {
    'model_state_dict': predictor.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'losses': losses,
}
torch.save(checkpoint_state, save_name)

In [None]:
# Load the checkpoint (run after first cell of backpropagation)
# That cell defines `device`, `predictor` and `optimizer`, all needed here.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine can also be restored on a CPU-only one.
checkpoint = torch.load(save_name, map_location=device)
predictor.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Restore the loss history so further training appends to the saved curve.
losses = checkpoint['losses']


In [None]:
# NOTE(review): stray scratch arithmetic; nothing else in the notebook uses
# this value - confirm whether the cell can be dropped.
26 / 61