## Transformer Encoder using HuggingFace datasets and tokenizers.

In [1]:
import jax
from jax import (
    lax,
    random,
    nn,
    numpy as jnp
)
import optax
import numpy as np

In [2]:
from mlax import Module, Parameter
from mlax.nn import (
    Embed,
    Linear,
    Bias,
    Series
)

In [3]:
from encoder import EncoderBlock

In [4]:
from datasets import load_dataset
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

### Load in the IMDB dataset.

In [5]:
imdb_train = load_dataset("imdb", cache_dir="../data", split="train")
imdb_test = load_dataset("imdb", cache_dir="../data", split="test")

Found cached dataset imdb (/home/zongyf02/projects/mlax/examples/Encoder/../data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
Found cached dataset imdb (/home/zongyf02/projects/mlax/examples/Encoder/../data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


### Tokenize datasets using a pretrained tokenizer.

In [6]:
seq_len = 512
tokenizer = Tokenizer.from_pretrained("roberta-base")
tokenizer.enable_truncation(seq_len)
tokenizer.enable_padding(length=seq_len)

In [7]:
def tokenization(batch):
    encodings = tokenizer.encode_batch(batch["text"])
    batch["ids"] = [encoding.ids for encoding in encodings]
    batch["mask"] = [
        [bool(i) for i in encoding.attention_mask] for encoding in encodings
    ]
    del batch["text"]
    return batch

In [8]:
imdb_train_tokenized = imdb_train.map(
    tokenization, batched=True, batch_size=None
)
imdb_test_tokenized = imdb_test.map(
    tokenization, batched=True, batch_size=None
)

Loading cached processed dataset at /home/zongyf02/projects/mlax/examples/data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-4d102de10fe8f91d.arrow
Loading cached processed dataset at /home/zongyf02/projects/mlax/examples/data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-6357fc30ae0156bd.arrow


### Prepare dataloaders.

In [9]:
imdb_train_tokenized.set_format(type="numpy")
imdb_test_tokenized.set_format(type="numpy")

In [10]:
def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  elif isinstance(batch[0], dict):
    res = {}
    for key in batch[0]:
      res[key] = numpy_collate([d[key] for d in batch])
    return res
  else:
    return np.array(batch)

batch_size = 128
train_dataloader = DataLoader(
    imdb_train_tokenized, batch_size, shuffle=True, collate_fn=numpy_collate, num_workers=0
)
test_dataloader = DataLoader(
    imdb_test_tokenized, batch_size, collate_fn=numpy_collate, num_workers=0
)
print(len(train_dataloader), len(test_dataloader))

196 196


### Build Encoder model.

In [11]:
class RotaryEmbed(Module):
    def __init__(self, seq_len, embed_dim):
        super().__init__()
        inv_freq = 1.0 / (
            10000.0 ** (jnp.arange(0, embed_dim, 2, dtype=jnp.float32) / embed_dim)
        )
        pos = jnp.arange(seq_len, dtype=jnp.float32)
        pos_enc = lax.dot_general(pos, inv_freq, (((), ()), ((), ())))

        self.seq_len = seq_len
        self.sin_enc = Parameter(trainable=False, data=lax.sin(pos_enc))
        self.cos_enc = Parameter(trainable=False, data=lax.cos(pos_enc))
    
    def init(self, x):
        pass

    def apply(self, x, rng=None, inference_mode=False, batch_axis_name=()):
        shape = self.sin_enc.data.shape[:-1] + (-1,)
        sin = jnp.stack(
            [self.sin_enc.data, self.sin_enc.data], axis=-1
        ).reshape(shape)
        cos = jnp.stack(
            [self.cos_enc.data, self.cos_enc.data], axis=-1
        ).reshape(shape)
        rotated_x = jnp.stack(
            [-x[..., 1::2], x[..., ::2]], axis=-1
        ).reshape(x.shape)
        return x * cos + rotated_x * sin

class Model(Module):
    def __init__(
        self,
        rng,
        vocab_size,
        seq_len,
        embed_dim=256,
        num_heads=8,
        ff_depth=1024,
        act_fn=nn.gelu,
        dropout=0.1
    ):
        super().__init__()
        keys_iter = iter([random.fold_in(rng, i) for i in range(5)])

        self.embed = Embed(
            next(keys_iter), vocab_size, embed_dim
        )
        self.rotary = RotaryEmbed(seq_len, embed_dim)

        self.encoder1 = EncoderBlock(
            next(keys_iter), num_heads, ff_depth, act_fn, dropout
        )
        self.encoder2 = EncoderBlock(
            next(keys_iter), num_heads, ff_depth, act_fn, dropout
        )

        self.fc = Series([
            Linear(next(keys_iter), 1),
            Bias(next(keys_iter), -1)
        ])
    
    def init(self, x):
        pass

    def apply(self, xm, rng, inference_mode=False, batch_axis_name=()):
        ids, mask = xm
        embeddings, self.embed = self.embed(
            ids, None, inference_mode, batch_axis_name
        )
        embeddings, self.rotary = self.rotary(
            embeddings, None, inference_mode, batch_axis_name
        )
        activations, self.encoder1 = self.encoder1(
            (embeddings, mask), random.fold_in(rng, 0),
            inference_mode, batch_axis_name
        )
        activations, self.rotary = self.rotary(
            activations, None, inference_mode, batch_axis_name
        )
        activations, self.encoder2 = self.encoder2(
            (activations, mask), random.fold_in(rng, 1),
            inference_mode, batch_axis_name
        )
        activations = jnp.reshape(activations, (-1,))
        activations, self.fc = self.fc(
            activations, None, inference_mode, batch_axis_name
        )
        return jnp.squeeze(activations)

rng1 = random.PRNGKey(0)
rng1, rng2 = random.fold_in(rng1, 0), random.fold_in(rng1, 1)
model = Model(rng1, tokenizer.get_vocab_size(), seq_len, dropout=0.2)

# Induce lazy initialization
for batch in train_dataloader:
    ids, mask = batch["ids"], batch["mask"]
    activations, _ = model((ids[0], mask[0]), rng2, inference_mode=True)
    print(activations)
    print(activations.dtype)
    break

-3.5083568
float32


### Define loss function.

In [12]:
def loss_fn(batched_preds, batched_targets):
    return optax.sigmoid_binary_cross_entropy(
        batched_preds, batched_targets
    ).mean() # Optax returns per-example loss, this returns the mean batch loss

### Define optimizer using Optax.

In [13]:
optimizer = optax.adamw(5e-4, weight_decay=1e-2)
optim_state = optimizer.init(model.filter())

### Define training and testing steps.

In [14]:
@jax.jit
def train_step(X, y, rng, model, optim_state):
    def _model_loss(X, y, rng, trainables, non_trainables):
        model = trainables.combine(non_trainables)
        preds, model = jax.vmap(
            model.__call__,
            in_axes = (0, None, None, None),
            out_axes = (0, None),
            axis_name = "N"
        )(X, rng, False, "N")
        return loss_fn(preds, y), model

    # Find batch loss and gradients with resect to trainables
    trainables, non_trainables = model.partition()
    (loss, model), gradients = jax.value_and_grad(
        _model_loss,
        argnums=3, # gradients wrt trainables (argument 2 of model_training_loss)
        has_aux=True # model is auxiliary data, loss is the true ouput
    )(X, y, rng, trainables, non_trainables)

    # Get new gradients and optimizer state
    trainables, non_trainables = model.partition()
    gradients, optim_state = optimizer.update(
        gradients, optim_state, trainables
    )

    # Update parameters with new gradients
    trainables = optax.apply_updates(gradients, trainables)
    return loss, trainables.combine(non_trainables), optim_state

In [15]:
@jax.jit
def test_step(X, y, rng, model):
    preds, _ = jax.vmap(
        model.__call__,
        in_axes = (0, None, None, None),
        out_axes = (0, None),
        axis_name = "N"
    )(X, rng, True, "N")
    accurate = (jnp.round(nn.sigmoid(preds)) == y).astype(jnp.int32).sum()
    return loss_fn(preds, y), accurate

### Define training and testing loops.

In [16]:
def train_epoch(dataloader, rng, model, optim_state):
    train_loss = 0.0
    for i, batch in enumerate(dataloader):
        _rng = random.fold_in(rng, i)
        ids, mask, y = batch["ids"], batch["mask"], batch["label"]
        loss, model, optim_state = train_step(
            (ids, mask), y, _rng, model, optim_state
        )
        train_loss += loss

    print(f"Train loss: {train_loss / len(dataloader)}") 
    return model, optim_state

In [17]:
def test(dataloader, rng, model):
    test_loss, accurate = 0.0, 0
    for batch in dataloader:
        ids, mask, y = batch["ids"], batch["mask"], batch["label"]
        loss, acc = test_step((ids, mask), y, rng, model)
        test_loss += loss
        accurate += acc

    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [18]:
def train_loop(
    train_dataloader,
    test_dataloader,
    rng,
    model,
    optim_state,
    epochs,
    test_every
):
    for i in range(epochs):
        _rng = random.fold_in(rng, i)
        epoch = i + 1
        print(f"Epoch {epoch}\n----------------")
        model, optim_state = train_epoch(
            train_dataloader, _rng, model, optim_state
        )
        if (epoch % test_every == 0):
            test(test_dataloader, _rng, model)
        print(f"----------------")
    return model, optim_state

## Train Encoder on the IMDB dataset.

In [19]:
with jax.default_matmul_precision("float32"):
    new_params, new_optim_state = train_loop(
        train_dataloader,
        test_dataloader,
        rng2,
        model,
        optim_state,
        30, 5
    )

Epoch 1
----------------
Train loss: 1.1788954734802246
----------------
Epoch 2
----------------
Train loss: 0.10283064842224121
----------------
Epoch 3
----------------
Train loss: 0.01246965117752552
----------------
Epoch 4
----------------
Train loss: 0.0016183697152882814
----------------
Epoch 5
----------------
Train loss: 0.0004987806896679103
Test loss: 0.7959541082382202, accuracy: 0.8016799688339233
----------------
Epoch 6
----------------
Train loss: 0.00029921604436822236
----------------
Epoch 7
----------------
Train loss: 0.00021634899894706905
----------------
Epoch 8
----------------
Train loss: 0.0001659040863160044
----------------
Epoch 9
----------------
Train loss: 0.0001315609406447038
----------------
Epoch 10
----------------
Train loss: 0.00010687522444641218
Test loss: 0.8817386627197266, accuracy: 0.8104400038719177
----------------
Epoch 11
----------------
Train loss: 8.837754285195842e-05
----------------
Epoch 12
----------------
Train loss: 7.371755