# flax_conv_lstm.ipynb

A convolutional LSTM that predicts the next image for the block manipulation task.  
This version uses Flax's [ConvLSTMCell](https://flax.readthedocs.io/en/latest/_modules/flax/linen/recurrent.html#ConvLSTMCell) rather than `conv_general_dilated`.

In [1]:
%cd ..
from flax_conv_lstm import *
import jax
import jax.numpy as jnp
# from jax import lax
from jax import random
from jax.tree_util import tree_map
import numpy as np
from matplotlib import pyplot as plt
from torchvision.utils import make_grid
import torchvision.transforms.functional as F
# from torch.utils import data
import torch
# import math
from functools import partial
# from jax.tree_util import Partial
from flax import linen as nn
from flax.linen.recurrent import ConvLSTMCell
import optax
import time

/home/z/projects/language_network/conv_lstm_test/jax


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
# configuration and data loader
config = read_config('rnn_config.yaml')
dataset = JaxDataset(config['data_path'])
print("Configuration dictionary:")
print(config)
data_loader = NumpyLoader(dataset, batch_size=config['batch_size'], shuffle=True)

Configuration dictionary:
{'data_path': '/media/z/Data/datasets/language_network/groupA1_traindataset_256x256.h5', 'model_path': '/media/z/Data/datasets/language_network/saved_params.pkl', 'batch_size': 32, 'n_epochs': 500, 'learning_rate': 0.0005, 'h_channels': 12, 'inp_kernel_size': 3, 'hid_kernel_size': 3, 'trans_kernel_size': 5, 's': 2, 'p': 0, 'd': 1, 'kd': 1}


In [3]:
# explore the data_loader object
datum = next(data_loader._get_iterator())
vision, motor, language, mask, lang_mask = datum
print(f"Vision shape: {vision.shape}. \t Vision type: {type(vision)}")
print(f"Motor shape: {motor.shape}. \t Motor type: {type(motor)}")
print(f"Language shape: {language.shape}. \t Language type: {type(language)}.")
print(f"Mask shape: {mask.shape}. \t Mask type: {type(mask)}.")
print(f"Lang_mask shape: {lang_mask.shape}. \t Lang_mask type: {type(lang_mask)}.")
config['vision'] = vision[:, 0:1, :, :, :]  # for shape reference when creating parameters

Vision shape: (32, 50, 256, 256, 3). 	 Vision type: <class 'numpy.ndarray'>
Motor shape: (32, 50, 60). 	 Motor type: <class 'numpy.ndarray'>
Language shape: (32, 5, 20). 	 Language type: <class 'numpy.ndarray'>.
Mask shape: (32, 50). 	 Mask type: <class 'numpy.ndarray'>.
Lang_mask shape: (32, 5). 	 Lang_mask type: <class 'numpy.ndarray'>.


In [4]:
# Model initialization
h_channels = config['h_channels']
kernel_size = (config['hid_kernel_size'], config['hid_kernel_size'])
strides = config['s']

conv_lstm_cell = ConvLSTMCell(h_channels,
                              kernel_size,
                              strides=1)

key = jax.random.key(23)
key, *l1_keys = random.split(key, 4)
inp_shape = vision[:, 0, :, :, :].shape
carry = conv_lstm_cell.initialize_carry(l1_keys[0], inp_shape)
carry = (0.05 * random.normal(random.key(1), carry[0].shape),  # non-zero initialization
         0.05 * random.normal(random.key(2), carry[1].shape))

conv_out, conv_params = conv_lstm_cell.init_with_output(l1_keys[1],
                                                        carry, vision[:, 0, :, :, :])

In [5]:
# jax.tree_util.tree_map(jnp.shape, conv_out)
jax.tree_util.tree_map(jnp.mean, conv_params)

{'params': {'hh': {'bias': Array(0., dtype=float32),
   'kernel': Array(0.00212648, dtype=float32)},
  'ih': {'bias': Array(0., dtype=float32),
   'kernel': Array(0.01056828, dtype=float32)}}}

In [6]:
print(conv_lstm_cell.tabulate(key, carry,
                        vision[:, 0, :, :, :],
                       compute_flops=False, compute_vjp_flops=False))


[3m                              ConvLSTMCell Summary                              [0m
┏━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath[0m[1m [0m┃[1m [0m[1mmodule      [0m[1m [0m┃[1m [0m[1minputs          [0m[1m [0m┃[1m [0m[1moutputs         [0m[1m [0m┃[1m [0m[1mparams          [0m[1m [0m┃
┡━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│      │ ConvLSTMCell │ - -              │ - -              │                  │
│      │              │ [2mfloat32[0m[32,256,… │ [2mfloat32[0m[32,256,… │                  │
│      │              │   -              │   -              │                  │
│      │              │ [2mfloat32[0m[32,256,… │ [2mfloat32[0m[32,256,… │                  │
│      │              │ -                │ -                │                  │
│      │              │ [2mfloat32[0m[32,256,… │ [2mfloat32[0m[32,256,… │                  │
├──────┼─────

In [7]:
conv_trans = nn.ConvTranspose(3, config['trans_kernel_size'])
trans_out, trans_params = conv_trans.init_with_output(l1_keys[2],
                                                      jnp.ones(conv_out[1].shape))

In [8]:
def prediction_step(params, carry, x):
    new_carry, h = conv_lstm_cell.apply(params['conv_params'], carry, x)
    new_x = conv_trans.apply(params['trans_params'], h)
    return new_carry, x

def prediction_n_steps(params, carry, vision):
    n = vision.shape[1]
    x_pred = jnp.zeros_like(vision)
    x = vision[:, 0, :, :, :]
    x_pred = x_pred.at[:, 0, :, :, :].set(x)
    for i in range(1, n):
        carry, x = prediction_step(params, carry, x)
        x_pred = x_pred.at[:, i, :, :, :].set(x)
    return x_pred

In [9]:
params = {'conv_params': conv_params,
          'trans_params': trans_params}

@jax.jit
def mse(params, carry, vision):
    x_pred = prediction_n_steps(params, carry, vision)
    # print(f"Mean x_pred: {jnp.mean(x_pred)}")
    return jnp.mean(optax.l2_loss(x_pred, vision))
    # return jnp.mean((x_pred - vision)**2)

x_pred = prediction_n_steps(params, carry, vision)
print(x_pred.shape)
print(mse(params, carry, vision))

(32, 50, 256, 256, 3)
0.028330727


In [None]:
tx = optax.adam(learning_rate=config['learning_rate'])
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

start_time = time.time()
warmup_epochs = 1

for epoch in range(config['n_epochs']): #range(config['n_epochs']):
    if epoch == warmup_epochs:
        start_time = time.time()
    for datum in data_loader:
        vision, motor, language, mask, lang_mask = datum
        loss_val, grads = loss_grad_fn(params, carry, vision)
        # print(tree_map(jnp.mean, grads))
        # print(f"Loss: {loss_val}")
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
    if epoch % 2 == 0:
        print(f"Loss at epoch {epoch}: {loss_val}")

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Completed {config['n_epochs'] - warmup_epochs} epochs in {elapsed_time:.2f} seconds")

Loss at epoch 0: 0.03069593571126461
Loss at epoch 2: 0.026686973869800568
Loss at epoch 4: 0.03228026628494263
Loss at epoch 6: 0.031136872246861458
Loss at epoch 8: 0.028465382754802704
Loss at epoch 10: 0.028439421206712723
Loss at epoch 12: 0.03022477589547634


In [None]:
def sgd_update(params, carry, vision, lr):
    loss_val, grads = jax.value_and_grad(mse)(params, carry, vision)
    new_params = tree_map(
        lambda p, g: p - lr * g, params, grads
    )
    return loss_val, new_params

lr = config['learning_rate']
for epoch in range(config['n_epochs']):
    for datum in data_loader:
        vision, motor, language, mask, lang_mask = datum
        loss_val, params = sgd_update(params, carry, vision, lr)
    if epoch % 4 == 0:
        print(f"Epoch {epoch}, loss: {loss_val}")

In [None]:
# visualize some predictions

prediction = prediction_n_steps(params, carry, vision)
prediction = prediction.transpose(0, 1, 4, 2, 3)

pt_prediction = torch.from_numpy(np.asarray(prediction)) # don't use in real code
example_index = 8
vision_ex = pt_prediction[example_index, :, :, :, :]
print(f"For index {example_index}, the vision data has shape {vision_ex.shape}")

imgs = vision_ex / 2. + 0.5
grid = make_grid(torch.tensor(imgs))

def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(10,10))
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

show(grid)

5m 14s for 100 epochs without jit

2m 52s for 100 epochs with jit (no warmup)  
102m for 4000 epochs with jit

In [None]:
# plot the losses
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')