# runner.ipynb

A notebook that exercises the methods defined in `jax_conv_lstm.py`.

In [None]:
%cd ..
from jax_conv_lstm import *
import time
import torch
from torchvision.utils import make_grid
import torchvision.transforms.functional as F
from matplotlib import pyplot as plt

In [None]:
# Read the run configuration, then build the dataset and a shuffled batch
# loader from the values it contains.
config = read_config('rnn_config.yaml')
dataset = JaxDataset(config['data_path'])
print("Configuration dictionary:", config, sep="\n")
data_loader = NumpyLoader(
    dataset,
    batch_size=config['batch_size'],
    shuffle=True,
)

In [None]:
# explore the data_loader object
# Pull one batch through the public iterator protocol (the same one the
# training loop's `for datum in data_loader` uses) instead of the loader's
# private `_get_iterator()` hook.
datum = next(iter(data_loader))
# Each batch unpacks into five arrays.
vision, motor, language, mask, lang_mask = datum
print(f"Vision shape: {vision.shape}. \t Vision type: {type(vision)}")
print(f"Motor shape: {motor.shape}. \t Motor type: {type(motor)}")
print(f"Language shape: {language.shape}. \t Language type: {type(language)}.")
print(f"Mask shape: {mask.shape}. \t Mask type: {type(mask)}.")
print(f"Lang_mask shape: {lang_mask.shape}. \t Lang_mask type: {type(lang_mask)}.")
# Store a 5-D slice that keeps only index 0 of axis 1 (that axis keeps
# length 1); it is used as a shape reference when creating parameters.
# NOTE(review): axis 1 is presumably the time axis -- confirm.
config['vision'] = vision[:, 0:1, :, :, :]

In [None]:
# Create the randomly initialised parameter collections from a seeded key.
key = random.key(34)
(params,
 conv_params,
 conv_params_t,
 h_shape) = create_random_params(key, config)
# Repackage the three parameter collections as named tuples.
params_nt, conv_params_nt, conv_params_t_nt = params_to_nt(
    params,
    conv_params,
    conv_params_t,
)

In [None]:
## Gradient descent
# initialize the hidden state and cell state
key = random.key(23456)
key_h, key_c = random.split(key)
h = 0.1 * jax.random.normal(key_h, h_shape)
c = 0.1 * jax.random.normal(key_c, h_shape)
lr = config['learning_rate']

# Keep the loss history when this cell is re-run in the same session.
if 'losses' not in locals():
    losses = []

n_epochs = config['n_epochs']
# The first `warmup_epochs` epochs are excluded from the timing.
warmup_epochs = 1
# The timer is only restarted when the loop reaches `epoch == warmup_epochs`.
# With n_epochs <= warmup_epochs that never happens and the elapsed time
# covers every epoch, so report all of them instead of a zero or negative
# epoch count.
timed_epochs = n_epochs - warmup_epochs if n_epochs > warmup_epochs else n_epochs
log_every = 1  # print the mean loss every `log_every` epochs

start_time = time.time()

# Run training epochs
for epoch in range(n_epochs):
    running_loss = 0.0
    if epoch == warmup_epochs:
        # Warmup is over: restart the clock so only later epochs are timed.
        start_time = time.time()
    for datum in data_loader:
        # Only `vision` is passed to the update; the other arrays are
        # unpacked but not used in this loop.
        vision, motor, language, mask, lang_mask = datum
        # The same initial `h` and `c` are passed for every batch; only the
        # parameters are carried over from one update to the next.
        loss_val, params_nt = sgd_update(params_nt,
                                         vision,
                                         h,
                                         c,
                                         conv_params_nt,
                                         conv_params_t_nt,
                                         lr)
        running_loss += loss_val
    # Mean loss over the batches of this epoch.
    losses.append(running_loss / len(data_loader))
    if epoch % log_every == 0:
        print(f"Epoch {epoch}, loss: {losses[-1]}")

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Completed {timed_epochs} epochs in {elapsed_time:.2f} seconds")

Reference timings for the training cell above (the warmup epoch is not counted):  
2 epochs in 100 seconds with jit, 1 epoch warmup, batch_size=32  
2 epochs in 98 seconds with jit, 1 epoch warmup, batch_size=16  
2 epochs in 205 seconds without jit, 1 epoch warmup, batch_size=16  


In [None]:
# visualize some predictions

# Run the model on the most recently unpacked `vision` batch.
prediction = prediction_n_steps(params_nt, vision, h, c, conv_params_nt, conv_params_t_nt)
pt_prediction = torch.from_numpy(np.asarray(prediction)) # don't use in real code
# Index of the example (along axis 0) whose prediction is displayed.
example_index = 3
vision_ex = pt_prediction[example_index, :, :, :, :]
print(f"For index {example_index}, the vision data has shape {vision_ex.shape}")

# Rescale for display.
# NOTE(review): presumably maps values from [-1, 1] to [0, 1] -- confirm
# against the dataset's normalisation.
imgs = vision_ex / 2. + 0.5
# `imgs` is already a freshly computed torch tensor, so it is passed straight
# to make_grid; wrapping it in torch.tensor() only made a redundant copy and
# triggered PyTorch's copy-construct warning.
grid = make_grid(imgs)

def show(imgs):
    """Draw one image tensor, or a list of them, side by side in one row.

    Args:
        imgs: A single tensor or a list of tensors. Each one is detached and
            converted to a PIL image before being drawn.

    The figure is 10x10 inches, with one column per image and all axis
    ticks and tick labels removed.
    """
    tensors = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps the axes array 2-D even for a single column.
    _, axes = plt.subplots(ncols=len(tensors), squeeze=False, figsize=(10, 10))
    for ax, tensor in zip(axes[0], tensors):
        pil_img = F.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

# Display the image grid built from the selected prediction.
show(grid)

In [None]:
# Draw the per-epoch mean training loss on the current axes.
loss_ax = plt.gca()
loss_ax.plot(losses)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')

---
Save and load the model

In [None]:
# Save the parameters of the model
# `pickle` is imported here so the cell does not depend on the star import
# at the top of the notebook happening to expose it.
import pickle

# The `with` statement sits at top level: the original cell indented it
# without an enclosing block, which is an IndentationError.
with open(config['model_path'], 'wb') as file:
    pickle.dump(params_nt, file)