## Transformer Encoder Reference.

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

In [2]:
from datasets import load_dataset
from tokenizers import Tokenizer

### Load in the IMDB dataset.

In [4]:
imdb_train = load_dataset("imdb", cache_dir="../data", split="train")
imdb_test = load_dataset("imdb", cache_dir="../data", split="test")

Found cached dataset imdb (/home/zongyf02/projects/mlax/examples/Encoder/../data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
Found cached dataset imdb (/home/zongyf02/projects/mlax/examples/Encoder/../data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


In [5]:
tokenizer = Tokenizer.from_pretrained("roberta-base")
tokenizer.enable_truncation(512)
tokenizer.enable_padding(length=512)

In [6]:
def tokenization(batch):
    encodings = tokenizer.encode_batch(batch["text"])
    batch["ids"] = [encoding.ids for encoding in encodings]
    batch["mask"] = [encoding.attention_mask for encoding in encodings]
    del batch["text"]
    return batch

In [7]:
imdb_train_tokenized = imdb_train.map(
    tokenization, batched=True, batch_size=None
)
imdb_test_tokenized = imdb_test.map(
    tokenization, batched=True, batch_size=None
)

Loading cached processed dataset at /home/zongyf02/projects/mlax/examples/data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-b4fde54fc2b7fb90.arrow
Loading cached processed dataset at /home/zongyf02/projects/mlax/examples/data/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-6c2db6e39d7c9c83.arrow


In [8]:
imdb_train_tokenized.set_format(type="torch")
imdb_test_tokenized.set_format(type="torch")

### Prepare dataloaders.

In [9]:
batch_size = 128
train_dataloader = DataLoader(imdb_train_tokenized, batch_size=batch_size, shuffle=True, num_workers=0)
test_dataloader = DataLoader(imdb_test_tokenized, batch_size=batch_size, num_workers=0)

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

### Build Encoder model.

In [10]:
class Model(nn.Module):
    def __init__(
        self,
        vocab_size,
        seq_len,
        feature_embed_dim = 248,
        pos_embed_dim = 8,
        num_heads = 8,
        ff_depth = 1024,
        act_fn=nn.functional.gelu,
        dropout=0.1
    ):
        super().__init__()
        self.feature_embeddings = nn.Embedding(vocab_size, feature_embed_dim)
        self.pos_embeddings = nn.Embedding(seq_len, pos_embed_dim)
        self.pos_idx = nn.Parameter(torch.arange(seq_len), requires_grad=False)
        self.dropout = nn.Dropout(dropout)

        embed_dim = feature_embed_dim + pos_embed_dim
        self.encoder1 = nn.TransformerEncoderLayer(
            embed_dim,
            num_heads,
            ff_depth,
            dropout,
            act_fn,
            batch_first=True
        )
        self.encoder2 = nn.TransformerEncoderLayer(
            embed_dim,
            num_heads,
            ff_depth,
            dropout,
            act_fn,
            batch_first=True
        )

        self.fc = nn.Linear(embed_dim * seq_len, 1)

    def forward(self, ids, mask):
        embeddings = self.feature_embeddings(ids)
        pos_embeddings = self.pos_embeddings(
            torch.broadcast_to(self.pos_idx, ids.shape)
        )
        embeddings = torch.cat((embeddings, pos_embeddings), dim=-1)
        embeddings = self.dropout(embeddings)

        activations = self.encoder1(embeddings, src_key_padding_mask=mask)
        activations = self.encoder2(embeddings, src_key_padding_mask=mask)
        activations = torch.reshape(activations, (len(activations), -1))
        activations = self.fc(activations)
        return torch.squeeze(activations)

model = Model(tokenizer.get_vocab_size(), 512)
print(model)
model = torch.jit.script(model)

Model(
  (feature_embeddings): Embedding(50265, 248)
  (pos_embeddings): Embedding(512, 8)
  (dropout): Dropout(p=0.1, inplace=False)
  (encoder1): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (linear1): Linear(in_features=256, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=1024, out_features=256, bias=True)
    (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (encoder2): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (linear1): Linear(in_features=256, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1



### Define loss function and optimizer.

In [16]:
bce = torch.jit.script(nn.BCEWithLogitsLoss())
adamw = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2)

### Define training and testing loops.

In [17]:
def train(dataloader, model, loss_fn, optimizer, device):
    model.to(device)
    model.train()

    train_loss = 0.0
    for batch in dataloader:
        loss = loss_fn(
            model(
                batch["ids"].to(device),
                torch.logical_not(batch["mask"].type(torch.bool)).to(device)
            ),
            batch["label"].type(torch.float).to(device)
        )
        train_loss += loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Train loss: {train_loss / len(dataloader)}")

In [18]:
def test(dataloader, model, loss_fn, device):
    model.to(device)
    model.eval()

    test_loss, accurate = 0.0, 0
    with torch.no_grad():
        for batch in dataloader:
            y = batch["label"].type(torch.float).to(device)
            pred = model(
                batch["ids"].to(device),
                torch.logical_not(batch["mask"].type(torch.bool)).to(device)
            )
            test_loss += loss_fn(pred, y).item()
            accurate += (torch.sigmoid(pred).round() == y).sum().item()
    print(f"Test loss: {test_loss / len(dataloader)}, accuracy: {accurate / len(dataloader.dataset)}")

In [19]:
def train_loop(
    train_dataloader,
    test_dataloader,
    model, loss_fn, optimizer,
    device,
    epochs, test_every):
    for i in range(epochs):
        epoch = (i + 1)
        print(f"Epoch {epoch}\n----------------")
        train(train_dataloader, model, loss_fn, optimizer, device)
        if (epoch % test_every == 0):
            test(test_dataloader, model, loss_fn, device)
        print(f"----------------")

## Train Encoder on the IMDB dataset.

In [20]:
train_loop(train_dataloader, test_dataloader, model, bce, adamw, device, 10, 1)

Epoch 1
----------------
Train loss: 0.026605423539876938
Test loss: 1.117602876862701, accuracy: 0.80508
----------------
Epoch 2
----------------
Train loss: 0.024192167446017265
Test loss: 1.2074286858646237, accuracy: 0.79644
----------------
Epoch 3
----------------
Train loss: 0.03228404372930527
Test loss: 1.3317240836973092, accuracy: 0.80936
----------------
Epoch 4
----------------
Train loss: 0.030083132907748222
Test loss: 1.3468441015907697, accuracy: 0.81212
----------------
Epoch 5
----------------
Train loss: 0.022066554054617882
Test loss: 1.4045106285080617, accuracy: 0.81892
----------------
Epoch 6
----------------
Train loss: 0.014758866280317307
Test loss: 1.4636627993717486, accuracy: 0.80676
----------------
Epoch 7
----------------
Train loss: 0.01156957633793354
Test loss: 1.4085638086710657, accuracy: 0.8092
----------------
Epoch 8
----------------
Train loss: 0.009076050482690334
Test loss: 1.4104419175763518, accuracy: 0.81156
----------------
Epoch 9
----