## Transformer Encoder using HuggingFace datasets and tokenizers.

In [1]:
import jax
from jax import (
    random,
    nn,
    numpy as jnp
)
from functools import partial
import optax
import numpy as np

In [2]:
from mlax import Module, is_trainable
from mlax.nn import (
    Embed,
    Linear,
    Bias,
    Series
)
from mlax.nn.functional import dropout

In [3]:
from encoder import EncoderBlock

In [4]:
from datasets import load_dataset
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

### Load in the IMDB dataset.

In [5]:
imdb_train = load_dataset("imdb", cache_dir="../data", split="train")
imdb_test = load_dataset("imdb", cache_dir="../data", split="test")

Found cached dataset imdb (/home/zongyf02/projects/mlax/examples/Encoder/../data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
Found cached dataset imdb (/home/zongyf02/projects/mlax/examples/Encoder/../data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


### Tokenize datasets using a pretrained tokenizer.

In [51]:
tokenizer = Tokenizer.from_pretrained("roberta-base")
tokenizer.enable_truncation(512)
tokenizer.enable_padding(length=512)

In [52]:
def tokenization(batch):
    encodings = tokenizer.encode_batch(batch["text"])
    batch["ids"] = [encoding.ids for encoding in encodings]
    batch["mask"] = [encoding.attention_mask for encoding in encodings]
    del batch["text"]
    return batch

In [53]:
imdb_train_tokenized = imdb_train.map(
    tokenization, batched=True, batch_size=None
)
imdb_test_tokenized = imdb_test.map(
    tokenization, batched=True, batch_size=None
)

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

### Prepare dataloaders.

In [54]:
imdb_train_tokenized.set_format(type="numpy")
imdb_test_tokenized.set_format(type="numpy")

In [55]:
def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  elif isinstance(batch[0], dict):
    res = {}
    for key in batch[0]:
      res[key] = numpy_collate([d[key] for d in batch])
    return res
  else:
    return np.array(batch)

batch_size = 128
train_dataloader = DataLoader(
    imdb_train_tokenized, batch_size, shuffle=True, collate_fn=numpy_collate, num_workers=0
)
test_dataloader = DataLoader(
    imdb_test_tokenized, batch_size, collate_fn=numpy_collate, num_workers=0
)
print(len(train_dataloader), len(test_dataloader))

196 196


### Build Encoder model.

In [56]:
# Binary classification model with learnable positional embedding and 2 encoders
class Model(Module):
    def __init__(
        self,
        rng,
        vocab_size,
        seq_len,
        feature_embed_dim = 248,
        pos_embed_dim = 8,
        num_heads = 8,
        ff_depth = 1024,
        act_fn=nn.gelu,
        dropout=0.2
    ):
        super().__init__()
        rngs_iter = iter(random.split(rng, 6))

        self.feature_embeddings = Embed(
            next(rngs_iter), vocab_size, feature_embed_dim
        )
        self.pos_embeddings = Embed(
            next(rngs_iter), seq_len, pos_embed_dim
        )

        model_depth = (feature_embed_dim + pos_embed_dim)
        self.encoder1 = EncoderBlock(
            next(rngs_iter), model_depth, num_heads, ff_depth, act_fn, dropout
        )
        self.encoder2 = EncoderBlock(
            next(rngs_iter), model_depth, num_heads, ff_depth, act_fn, dropout
        )

        self.fc = Series([
            Linear(next(rngs_iter), 1),
            Bias(next(rngs_iter), -1)
        ])

        self.dropout = dropout
    
    @partial(
        jax.vmap,
        in_axes = (None, 0, None, None),
        out_axes = (0, None),
        axis_name = "batch"
    ) # Add leading batch dimension
    def __call__(self, x, rng, inference_mode=False):
        ids, mask = x
        rng1, rng2, rng3 = random.split(rng, 3)

        # Feature and positional embeddings are appended
        embeddings, self.feature_embeddings = self.feature_embeddings(
            ids, None, inference_mode
        )
        pos_embeddings, self.pos_embeddings = self.pos_embeddings(
            jnp.arange((len(ids))), None, inference_mode
        )
        embeddings = jnp.append(
            embeddings, pos_embeddings, axis=1
        )
        if not inference_mode:
            embeddings = dropout(embeddings, rng1, self.dropout)

        # Encoders
        activations, self.encoder1 = self.encoder1(
            (embeddings, mask),
            rng2,
            inference_mode
        )
        activations, self.encoder2 = self.encoder2(
            (activations, mask),
            rng3,
            inference_mode
        )

        # Dense layer
        activations = jnp.reshape(activations, (-1,))
        activations, self.fc = self.fc(activations, None, inference_mode)
        return activations, self

### Define loss function.

In [57]:
rng1, rng2 = random.split(random.PRNGKey(0))
model = Model(rng1, tokenizer.get_vocab_size(), 512)

# Induce lazy weight initialization
for batch in train_dataloader:
    acts, model = model((batch["ids"], batch["mask"]), rng2, False)
    print(acts.shape)
    break

(128, 1)


In [58]:
def loss_fn(
    preds: jnp.array,
    targets: np.array
):
    return optax.sigmoid_binary_cross_entropy(
        preds,
        targets
    ).mean() # Optax returns per-example loss, this returns the mean batch loss

In [59]:
def model_training_loss(
    x_batch: np.array,
    y_batch: np.array,
    rng: jax.Array,
    trainables,
    non_trainables
):
    model = trainables.combine(non_trainables)
    preds, model = model(x_batch, rng, False)
    return loss_fn(jnp.squeeze(preds), y_batch), model

@jax.jit
def model_inference_loss(
    x_batch: np.array,
    y_batch: np.array,
    rng: jax.Array,
    model: Module
):
    preds, _ = model(x_batch, rng, True)
    preds = jnp.squeeze(preds)
    return loss_fn(preds, y_batch), preds

### Define optimizer using Optax.

In [66]:
optimizer = optax.adamw(1e-4, weight_decay=1e-2)
optim_state = optimizer.init(model.filter(is_trainable))

### Define training step.

In [67]:
@jax.jit
def train_step(
    x_batch: np.array, 
    y_batch: np.array,
    rng: jax.Array,
    model: Module,
    optim_state
):
    # Find batch loss and gradients with repect to trainables
    (loss, model), gradients = jax.value_and_grad(
        model_training_loss,
        argnums=3, # gradients wrt trainables (argument 3 of model_training_loss)
        has_aux=True # model is auxiliary data, loss is the true ouput
    )(x_batch, y_batch, rng, *model.partition())

    # Get new gradients and optimizer state
    trainables, non_trainables = model.partition()
    gradients, optim_state = optimizer.update(gradients, optim_state, trainables)

    # Update model_weights with new gradients
    trainables = optax.apply_updates(gradients, trainables)
    return loss, trainables.combine(non_trainables), optim_state

### Define training and testing loops.

In [68]:
def train_epoch(
    dataloader,
    rng,
    model,
    optim_state
):
    train_loss = 0.0
    for batch in dataloader:
        x_batch = (batch["ids"], batch["mask"])
        y_batch = batch["label"]
        sub_rng, rng = random.split(rng)
        loss, model, optim_state = train_step(
            x_batch, y_batch, sub_rng,
            model,
            optim_state
        )
        train_loss += loss

    print(f"Train loss: {train_loss / len(dataloader)}") 
    return model, optim_state

In [69]:
def test(
    dataloader,
    rng,
    model
):
    test_loss, accuracy = 0.0, 0
    for batch in dataloader:
        x_batch = (batch["ids"], batch["mask"])
        y_batch = batch["label"]
        sub_rng, rng = random.split(rng)
        loss, preds = model_inference_loss(
            x_batch, y_batch, sub_rng, model
        )
        test_loss += loss
        accuracy += (nn.sigmoid(preds).round() == y_batch).sum()
    
    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accuracy / len(dataloader.dataset)}")

In [70]:
def train_loop(
    train_dataloader,
    test_dataloader,
    rng,
    model,
    optim_state,
    epochs, test_every
):
    for i in range(epochs):
        epoch = i + 1
        print(f"Epoch {epoch}\n----------------")
        model, optim_state = train_epoch(
            train_dataloader,
            rng,
            model,
            optim_state
        )
        if (epoch % test_every == 0):
            test(test_dataloader, rng, model)
        print(f"----------------")

    return model, optim_state

## Train Encoder on the IMDB dataset.

In [71]:
new_model, new_optim_state = train_loop(
    train_dataloader,
    test_dataloader,
    rng2,
    model,
    optim_state,
    10, 1
)

Epoch 1
----------------
Train loss: 0.9046181440353394
Test loss: 0.599520206451416, accuracy: 0.6882399916648865
----------------
Epoch 2
----------------
Train loss: 1.960471272468567
Test loss: 0.7225207686424255, accuracy: 0.7914800047874451
----------------
Epoch 3
----------------
Train loss: 0.6404978036880493
Test loss: 0.5163156390190125, accuracy: 0.8023999929428101
----------------
Epoch 4
----------------
Train loss: 0.7001040577888489
Test loss: 0.6478968262672424, accuracy: 0.80103999376297
----------------
Epoch 5
----------------
Train loss: 0.7255018949508667
Test loss: 0.5033421516418457, accuracy: 0.834119975566864
----------------
Epoch 6
----------------
Train loss: 0.32198628783226013
Test loss: 0.43705442547798157, accuracy: 0.8501200079917908
----------------
Epoch 7
----------------
Train loss: 0.20895497500896454
Test loss: 0.4127040207386017, accuracy: 0.8517199754714966
----------------
Epoch 8
----------------
Train loss: 0.14484575390815735
Test loss: 0.4