In [1]:
from balm.data import load_dataset, DataCollator
from balm.models.balm import BalmForMaskedLM
from balm.models.balm_moe import BalmMoEForMaskedLM
from balm.models.balm_moe_rope import BalmMoERoPEForMaskedLM
from balm.tokenizer import Tokenizer

from tqdm.auto import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

In [2]:
# Build the tokenizer from the vocabulary file (path is relative to the
# current working directory).
vocab_path = "./vocab.json"
tokenizer = Tokenizer(vocab=vocab_path)

In [3]:
def remove_sep(txt):
    """Replace every ``</s>`` separator in ``txt`` with two ``<cls>`` tokens.

    Despite the name, the separator is not simply dropped: each occurrence of
    ``"</s>"`` becomes ``"<cls><cls>"``. Text containing no separator is
    returned unchanged.
    """
    pieces = txt.split("</s>")
    return "<cls><cls>".join(pieces)


# Every split points at the same small test file. This is a smoke-test setup:
# "test"/"eval" numbers computed from it would be measured on training data.
data_files = dict.fromkeys(
    ("train", "test", "eval"),
    "./balm/test_data/test_1k.txt",
)

# Load the splits as plain text with remove_sep as the preprocessing hook
# (presumably applied to each raw line — confirm in balm.data.load_dataset).
dataset = load_dataset(
    "text",
    preprocess_fn=remove_sep,
    data_files=data_files,
)

In [4]:
# Tokenize the "text" column, requesting truncation at max_length=320.
# The original "text" column is kept (no remove_columns is passed).
# NOTE(review): map() is not called with batched=True, so padding=True may only
# pad each example to its own length — confirm against balm's map()/Tokenizer.
tokenized_dataset = dataset.map(
    lambda example: tokenizer(
        example["text"],
        max_length=320,
        truncation=True,
        padding=True,
    )
)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [5]:
# Batch collator. It is called manually inside the training loop on each batch's
# "input_ids"; the loop reads "input_ids" from its output and, when present,
# "labels" and "attention_mask". Presumably this is where MLM masking happens —
# confirm in balm.data.DataCollator.
collator = DataCollator(tokenizer=tokenizer)



In [10]:
# Batches of 32 tokenized training examples, reshuffled on every pass.
# No collate_fn is given: the training loop applies `collator` itself.
train_dataloader = DataLoader(
    dataset=tokenized_dataset["train"],
    shuffle=True,
    batch_size=32,
)



In [21]:
# Seed torch's global RNG before building the model so that weight initialisation
# is reproducible. The DataLoader was created without an explicit generator, so
# its shuffle order (and any torch-based masking in the collator) also becomes
# reproducible.
SEED = 42
torch.manual_seed(SEED)

# Mixture-of-experts masked-LM (RoPE variant). To compare against the non-RoPE
# model, swap in BalmMoEForMaskedLM with the same keyword arguments.
model = BalmMoERoPEForMaskedLM(
    embed_dim=256,
    ffn_dim=1024,
    num_experts=4,
    num_layers=8,
    num_heads=8,
    expert_capacity=128,
    router_z_loss_coef=0.01,
    router_aux_loss_coef=0.01,
    vocab_size=tokenizer.vocab_size,
)

optimizer = optim.Adam(model.parameters(), lr=4e-4)



In [22]:
# Display the model's parameter count (presumably the total number of
# parameters); num_parameters is read as an attribute, not called.
model.num_parameters

18982689

In [23]:
# Train for n_epochs full passes over the training set, logging the total,
# LM, and router losses every `log_every` optimizer steps.
n_epochs = 10
log_every = 5  # optimizer steps between loss log lines

model.train()
n_steps = 0
# The context manager closes the progress bar when training finishes (or raises).
# Log lines go through pbar.write so they don't break the bar's rendering, which
# a bare print() inside the loop does.
with tqdm(total=len(train_dataloader) * n_epochs) as pbar:
    for epoch in range(n_epochs):
        for examples in train_dataloader:
            optimizer.zero_grad()
            # Batches are collated here rather than via the DataLoader's collate_fn.
            collated = collator(examples["input_ids"])
            # NOTE(review): the collator's "attention_mask" is passed as
            # key_padding_mask. PyTorch's key_padding_mask convention is
            # True = ignore, the opposite of the usual attention_mask (1 = attend);
            # confirm which convention BALM's collator and model use.
            outputs = model(
                input_ids=collated["input_ids"],
                labels=collated.get("labels", None),
                key_padding_mask=collated.get("attention_mask", None),
            )
            # Total loss returned by the model; this is what gets backpropagated.
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            pbar.update(1)
            n_steps += 1
            if n_steps % log_every == 0:
                pbar.write(
                    f"step: {n_steps}, total Loss: {loss.item():.4f}, LM loss: {outputs.lm_loss.item():.4f}, router z loss: {outputs.router_z_loss.item():.4f}, router aux loss: {outputs.router_aux_loss.item():.4f}  "
                )

  0%|          | 0/320 [00:00<?, ?it/s]

step: 5, total Loss: 2.9079, LM loss: 2.8777, router z loss: 0.0200, router aux loss: 0.0102  
step: 10, total Loss: 2.7819, LM loss: 2.7577, router z loss: 0.0139, router aux loss: 0.0102  
step: 15, total Loss: 2.6750, LM loss: 2.6553, router z loss: 0.0093, router aux loss: 0.0103  
step: 20, total Loss: 2.6537, LM loss: 2.6370, router z loss: 0.0064, router aux loss: 0.0103  
step: 25, total Loss: 2.6363, LM loss: 2.6218, router z loss: 0.0042, router aux loss: 0.0103  
step: 30, total Loss: 2.6445, LM loss: 2.6313, router z loss: 0.0030, router aux loss: 0.0102  
step: 35, total Loss: 2.6384, LM loss: 2.6260, router z loss: 0.0021, router aux loss: 0.0102  
step: 40, total Loss: 2.6215, LM loss: 2.6099, router z loss: 0.0015, router aux loss: 0.0101  
step: 45, total Loss: 2.6362, LM loss: 2.6248, router z loss: 0.0012, router aux loss: 0.0101  
step: 50, total Loss: 2.6394, LM loss: 2.6281, router z loss: 0.0011, router aux loss: 0.0102  
step: 55, total Loss: 2.6453, LM loss: 2.