## Transformer Encoder Reference.

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

In [2]:
from datasets import load_dataset, Value
from tokenizers import Tokenizer

### Load in the IMDB dataset.

In [3]:
imdb_train = load_dataset("imdb", cache_dir="../data", split="train")
imdb_test = load_dataset("imdb", cache_dir="../data", split="test")

Found cached dataset imdb (/home/zongyf02/projects/mlax/examples/Encoder/../data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
Found cached dataset imdb (/home/zongyf02/projects/mlax/examples/Encoder/../data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


In [4]:
seq_len = 512
tokenizer = Tokenizer.from_pretrained("roberta-base")
tokenizer.enable_truncation(seq_len)
tokenizer.enable_padding(length=seq_len)

In [5]:
def tokenization(batch):
    encodings = tokenizer.encode_batch(batch["text"])
    batch["ids"] = [encoding.ids for encoding in encodings]
    batch["mask"] = [
        [not bool(i) for i in encoding.attention_mask] for encoding in encodings
    ]
    del batch["text"]
    return batch

In [6]:
imdb_train_tokenized = imdb_train.map(
    tokenization, batched=True, batch_size=1024
)
imdb_test_tokenized = imdb_test.map(
    tokenization, batched=True, batch_size=1024
)

Loading cached processed dataset at /home/zongyf02/projects/mlax/examples/data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-83e2b24d1a51517b.arrow
Loading cached processed dataset at /home/zongyf02/projects/mlax/examples/data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-ae4f513eadcbf358.arrow


In [7]:
imdb_train_tokenized.set_format(type="torch")
imdb_test_tokenized.set_format(type="torch")

### Prepare dataloaders.

In [8]:
batch_size = 128
train_dataloader = DataLoader(imdb_train_tokenized, batch_size=batch_size, shuffle=True, num_workers=0)
test_dataloader = DataLoader(imdb_test_tokenized, batch_size=batch_size, num_workers=0)

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


### Build Encoder model.

In [10]:
class RotaryEmbed(nn.Module):
    def __init__(self, seq_len, embed_dim):
        super().__init__()
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, embed_dim, 2, dtype=torch.float) / embed_dim)
        )
        pos = torch.arange(seq_len, dtype=torch.float)
        pos_enc = torch.einsum("p, f -> p f", pos, inv_freq)
        self.seq_len = seq_len
        self.register_buffer("sin_enc", pos_enc.sin())
        self.register_buffer("cos_enc", pos_enc.cos())

    def forward(self, x):
        sin_enc = torch.stack([self.sin_enc, self.sin_enc], dim=-1).reshape(
            self.sin_enc.shape[:-1] + (-1,)
        )
        cos_enc = torch.stack([self.cos_enc, self.cos_enc], dim=-1).reshape(
            self.cos_enc.shape[:-1] + (-1,)
        )
        rotated_x = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1).reshape(
            x.shape
        )
        return x * cos_enc + rotated_x * sin_enc

class Model(nn.Module):
    def __init__(
        self,
        vocab_size,
        seq_len,
        embed_dim=256,
        num_heads=8,
        ff_depth=1024,
        act_fn=nn.functional.gelu,
        dropout=0.1
    ):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.rotary = RotaryEmbed(seq_len, embed_dim)
        self.encoder1 = nn.TransformerEncoderLayer(
            embed_dim, num_heads, ff_depth, dropout, act_fn,
            batch_first=True, norm_first=True
        )
        self.encoder2 = nn.TransformerEncoderLayer(
            embed_dim, num_heads, ff_depth, dropout, act_fn,
            batch_first=True, norm_first=True
        )
        self.fc = nn.Linear(embed_dim * seq_len, 1)

    def forward(self, ids, mask):
        embeds = self.embeddings(ids)
        embeds = self.rotary(embeds)
        acts = self.encoder1(embeds, src_key_padding_mask=mask)
        acts = self.rotary(acts)
        acts = self.encoder2(acts, src_key_padding_mask=mask)
        acts = torch.reshape(acts, (len(acts), -1))
        acts = self.fc(acts)
        return torch.squeeze(acts)

model = Model(tokenizer.get_vocab_size(), seq_len, dropout=0.2)
print(model)

Model(
  (embeddings): Embedding(50265, 256)
  (rotary): RotaryEmbed()
  (encoder1): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (linear1): Linear(in_features=256, out_features=1024, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
    (linear2): Linear(in_features=1024, out_features=256, bias=True)
    (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
  )
  (encoder2): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (linear1): Linear(in_features=256, out_features=1024, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
    (linear2): Linear(in_features=1024, out_fe

### Define loss function and optimizer.

In [11]:
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

### Define training and testing steps.

In [12]:
@torch.compile
def train_step(X, mask, y):
    with torch.enable_grad():
        loss = loss_fn(model(X, mask), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()

In [13]:
@torch.compile
def test_step(X, mask, y):
    with torch.no_grad():
        preds = model(X, mask)
        loss = loss_fn(preds, y)
    accurate = (torch.sigmoid(preds).round() == y).type(torch.int).sum()
    return loss.item(), accurate.item()

### Define training and testing loops.

In [14]:
def train(dataloader):
    model.train()
    train_loss = 0.0
    for batch in dataloader:
        X = batch["ids"].to(device)
        mask = batch["mask"].to(device)
        y = batch["label"].type(torch.float).to(device)
        train_loss += train_step(X, mask, y)

    print(f"Train loss: {train_loss / len(dataloader)}")

In [15]:
def test(dataloader):
    model.eval()
    test_loss, accurate = 0.0, 0
    for batch in dataloader:
        X = batch["ids"].to(device)
        mask = batch["mask"].to(device)
        y = batch["label"].type(torch.float).to(device)
        loss, acc = test_step(X, mask, y)
        test_loss += loss
        accurate += acc
    
    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [16]:
def train_loop(
    train_dataloader,
    test_dataloader,
    epochs,
    test_every
):
    model.to(device)
    for i in range(epochs):
        epoch = (i + 1)
        print(f"Epoch {epoch}\n----------------")
        train(train_dataloader)
        if (epoch % test_every == 0):
            test(test_dataloader)
        print(f"----------------")

## Train Encoder on the IMDB dataset.

In [17]:
train_loop(train_dataloader, test_dataloader, 30, 5)

Epoch 1
----------------




Train loss: 1.5207095638829835
----------------
Epoch 2
----------------
Train loss: 0.36088826895064235
----------------
Epoch 3
----------------
Train loss: 0.16279021430076385
----------------
Epoch 4
----------------
Train loss: 0.0859131528045602
----------------
Epoch 5
----------------
Train loss: 0.04383768909373226


  return torch._transformer_encoder_layer_fwd(
  return torch._transformer_encoder_layer_fwd(
   function: '<graph break in test_step>' (/tmp/ipykernel_13480/3487953640.py:7)
   reasons:  ___stack0 == 1.8548640012741089
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
  return torch._transformer_encoder_layer_fwd(
  return torch._transformer_encoder_layer_fwd(


Test loss: 1.4226764223101187, accuracy: 0.7324
----------------
Epoch 6
----------------
Train loss: 0.041457442265022924
----------------
Epoch 7
----------------
Train loss: 0.048255698595015446
----------------
Epoch 8
----------------
Train loss: 0.1875324243042922
----------------
Epoch 9
----------------
Train loss: 0.26736389376147063
----------------
Epoch 10
----------------
Train loss: 0.21252334950971702
Test loss: 2.67449328242516, accuracy: 0.8102
----------------
Epoch 11
----------------
Train loss: 0.17429787407932704
----------------
Epoch 12
----------------
Train loss: 0.10291768461592236
----------------
Epoch 13
----------------
Train loss: 0.08850410728447884
----------------
Epoch 14
----------------
Train loss: 0.0574245941491488
----------------
Epoch 15
----------------
Train loss: 0.05901914488610075
Test loss: 3.973797012044459, accuracy: 0.79796
----------------
Epoch 16
----------------
Train loss: 0.06335012340183929
----------------
Epoch 17
-----------