## Bi-Directional LSTM on HuggingFace datasets and tokenizers.

In [1]:
import jax
from jax import (
    lax,
    nn,
    random,
    numpy as jnp
)
import optax
import numpy as np
from datasets import load_dataset
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

In [2]:
from mlax import Module
from mlax.nn import Bias, Embed, Linear, Series, SeriesRng
# Local python file containing a bidirectional LSTM layer with output projection.
from lstm import BiLSTMBlock

### Load in the SNLI dataset.

In [3]:
snli_train = load_dataset("snli", cache_dir="../data", split="train").filter(
    lambda d: d["label"] != -1
)
snli_test = load_dataset("snli", cache_dir="../data", split="test").filter(
    lambda d: d["label"] != -1
)

### Tokenize datasets using a pretrained tokenizer.

In [4]:
seq_len = 128
tokenizer = Tokenizer.from_pretrained("roberta-base")
tokenizer.enable_truncation(seq_len)
tokenizer.enable_padding(length=seq_len)

In [5]:
def tokenization(batch):
    encodings = tokenizer.encode_batch(
        list(zip(batch["premise"], batch["hypothesis"]))
    )
    batch["ids"] = [encoding.ids for encoding in encodings]
    batch["type_ids"] = [encoding.type_ids for encoding in encodings]
    batch["mask"] = [
        [bool(i) for i in encoding.attention_mask] for encoding in encodings
    ]
    return batch

In [6]:
snli_train_tokenized = snli_train.map(
    tokenization, batched=True, batch_size=1024, remove_columns=["premise", "hypothesis"], 
)
snli_test_tokenized = snli_test.map(
    tokenization, batched=True, batch_size=1024, remove_columns=["premise", "hypothesis"]
)
snli_train_tokenized.set_format(type="numpy")
snli_test_tokenized.set_format(type="numpy")

### Prepare dataloaders.

In [7]:
def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  elif isinstance(batch[0], dict):
    res = {}
    for key in batch[0]:
      res[key] = numpy_collate([d[key] for d in batch])
    return res
  else:
    return np.array(batch)

batch_size = 256
train_dataloader = DataLoader(
    snli_train_tokenized, batch_size, shuffle=True, collate_fn=numpy_collate, num_workers=0
)
test_dataloader = DataLoader(
    snli_test_tokenized, batch_size, collate_fn=numpy_collate, num_workers=0
)
print(len(train_dataloader), len(test_dataloader))

2146 39


### Build LSTM model.

In [8]:
class Model(Module):
    def __init__(self, rng, vocab_size, embed_size=192, dropout_rate=0.1):
        super().__init__()
        keys_iter = iter([random.fold_in(rng, i) for i in range(7)])
        
        self.embed = Embed(next(keys_iter), vocab_size, embed_size)
        self.type_embed = Embed(next(keys_iter), 2, embed_size)
        self.lstms = SeriesRng([
            BiLSTMBlock(next(keys_iter), embed_size, dropout_rate=dropout_rate),
            BiLSTMBlock(next(keys_iter), embed_size, dropout_rate=dropout_rate),
            BiLSTMBlock(next(keys_iter), embed_size, dropout_rate=dropout_rate)
        ])
        self.fc = Series([
            Linear(next(keys_iter), 3), Bias(next(keys_iter), -1)
        ])

    def setup(self, xm):
        pass

    def forward(self, xm, rng, inference_mode=False, batch_axis_name=()):
        ids, type_ids, mask = xm
        embeddings, self.embed = self.embed(
            ids, None, inference_mode, batch_axis_name
        )
        type_embeddings, self.type_embed = self.type_embed(
            type_ids, None, inference_mode, batch_axis_name
        )
        embeddings = embeddings + type_embeddings
        (activations, _), self.lstms = self.lstms(
            (embeddings, mask), rng, inference_mode, batch_axis_name
        )
        activations = jnp.reshape(activations, (-1,))
        activations, self.fc = self.fc(
            activations, None, inference_mode, batch_axis_name
        )
        return activations

rng1 = random.PRNGKey(0)
rng1, rng2 = random.fold_in(rng1, 0), random.fold_in(rng1, 1)
model = Model(rng1, tokenizer.get_vocab_size())

# Induce lazy initialization
for batch in train_dataloader:
    ids, type_ids, mask = batch["ids"], batch["type_ids"], batch["mask"]
    activations, _ = model(
        (ids[0], type_ids[0], mask[0]), rng2, inference_mode=True
    )
    print(activations)
    print(activations.dtype)
    break

[-0.00568352  0.020241    0.01358531]
float32


### Define loss function.

In [9]:
def loss_fn(batched_preds, batched_targets):
    return optax.softmax_cross_entropy_with_integer_labels(
        batched_preds, batched_targets
    ).mean() # Optax returns per-example loss, this returns the mean batch loss

### Define optimizer using Optax.

In [10]:
optimizer = optax.adamw(learning_rate=5e-4, weight_decay=1e-2)
optim_state = optimizer.init(model.filter())

### Define training and testing steps.

In [11]:
@jax.jit
def train_step(X, y, rng, model, optim_state):
    def _model_loss(X, y, rng, trainables, non_trainables):
        model = trainables.combine(non_trainables)
        preds, model = jax.vmap(
            model.__call__,
            in_axes = (0, None, None, None),
            out_axes = (0, None),
            axis_name = "N"
        )(X, rng, False, "N")
        return loss_fn(preds, y), model

    # Find batch loss and gradients with resect to trainables
    trainables, non_trainables = model.partition()
    (loss, model), gradients = jax.value_and_grad(
        _model_loss,
        argnums=3, # gradients wrt trainables (argument 2 of model_training_loss)
        has_aux=True # model is auxiliary data, loss is the true ouput
    )(X, y, rng, trainables, non_trainables)

    # Get new gradients and optimizer state
    trainables, non_trainables = model.partition()
    gradients, optim_state = optimizer.update(
        gradients, optim_state, trainables
    )

    # Update parameters with new gradients
    trainables = optax.apply_updates(gradients, trainables)
    return loss, trainables.combine(non_trainables), optim_state

In [12]:
@jax.jit
def test_step(X, y, rng, model):
    preds, _ = jax.vmap(
        model.__call__,
        in_axes = (0, None, None, None),
        out_axes = (0, None),
        axis_name = "N"
    )(X, rng, True, "N")
    accurate = (jnp.argmax(preds, axis=1) == y).sum()
    return loss_fn(preds, y), accurate

### Define training and testing loops.

In [13]:
def train_epoch(dataloader, rng, model, optim_state):
    train_loss = 0.0
    for i, batch in enumerate(dataloader):
        _rng = random.fold_in(rng, i)
        ids, type_ids, mask = batch["ids"], batch["type_ids"], batch["mask"]
        y = batch["label"]
        loss, model, optim_state = train_step(
            (ids, type_ids, mask), y, _rng, model, optim_state
        )
        train_loss += loss

    print(f"Train loss: {train_loss / len(dataloader)}") 
    return model, optim_state

In [14]:
def test(dataloader, rng, model):
    test_loss, accurate = 0.0, 0
    for batch in dataloader:
        ids, type_ids, mask = batch["ids"], batch["type_ids"], batch["mask"]
        y = batch["label"]
        loss, acc = test_step((ids, type_ids, mask), y, rng, model)
        test_loss += loss
        accurate += acc

    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [15]:
def train_loop(
    train_dataloader,
    test_dataloader,
    rng,
    model,
    optim_state,
    epochs,
    test_every
):
    for i in range(epochs):
        _rng = random.fold_in(rng, i)
        epoch = i + 1
        print(f"Epoch {epoch}\n----------------")
        model, optim_state = train_epoch(
            train_dataloader, _rng, model, optim_state
        )
        if (epoch % test_every == 0):
            test(test_dataloader, _rng, model)
        print(f"----------------")
    return model, optim_state

## Train LSTM on the SNLI dataset.

In [16]:
with jax.default_matmul_precision("float32"):
    new_model, new_optim_state = train_loop(
        train_dataloader,
        test_dataloader,
        rng2,
        model,
        optim_state,
        10, 1
    )

Epoch 1
----------------
Train loss: 0.8552581071853638
Test loss: 0.7622200846672058, accuracy: 0.6709080338478088
----------------
Epoch 2
----------------
Train loss: 0.7262314558029175
Test loss: 0.7073829770088196, accuracy: 0.6969666481018066
----------------
Epoch 3
----------------
Train loss: 0.6693261861801147
Test loss: 0.6739569902420044, accuracy: 0.7186482548713684
----------------
Epoch 4
----------------
Train loss: 0.6268875598907471
Test loss: 0.6520060896873474, accuracy: 0.728318452835083
----------------
Epoch 5
----------------
Train loss: 0.5936973094940186
Test loss: 0.625606119632721, accuracy: 0.7462337613105774
----------------
Epoch 6
----------------
Train loss: 0.5648378133773804
Test loss: 0.6147147417068481, accuracy: 0.7496946454048157
----------------
Epoch 7
----------------
Train loss: 0.5376637578010559
Test loss: 0.612553596496582, accuracy: 0.749287486076355
----------------
Epoch 8
----------------
Train loss: 0.5122790932655334
Test loss: 0.6054