In [1]:
from balm.data import load_dataset, DataCollator
from balm.models.balm import BalmForMaskedLM
from balm.models.balm_moe import BalmMoEForMaskedLM
from balm.models.balm_moe_rope import BalmMoERoPEForMaskedLM
from balm.tokenizer import Tokenizer

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader



In [2]:
# Build the tokenizer from the vocabulary file at ./vocab.json
# (path is relative to the current working directory).
tokenizer = Tokenizer(
    vocab="./vocab.json",
)

In [3]:
def remove_sep(txt):
    """Return ``txt`` with every ``</s>`` occurrence swapped for ``<cls><cls>``.

    Despite the name, the ``</s>`` marker is not dropped: each occurrence is
    substituted by two consecutive ``<cls>`` markers. Text without ``</s>``
    is returned unchanged.

    Args:
        txt: The input string.

    Returns:
        A new string with the substitution applied.
    """
    pieces = txt.split("</s>")
    return "<cls><cls>".join(pieces)


# Every split ("train", "test", "eval") points at the same sample file.
data_files = dict.fromkeys(
    ("train", "test", "eval"),
    "./balm/test_data/test_1k.txt",
)

# Load the files as plain text; remove_sep is handed over as the preprocessing hook.
dataset = load_dataset(
    "text",
    preprocess_fn=remove_sep,
    data_files=data_files,
)

In [4]:
def _tokenize_example(example):
    """Run the tokenizer on one record's ``text`` field.

    Padding and truncation are both switched on, with ``max_length=320``.
    """
    return tokenizer(
        example["text"],
        max_length=320,
        truncation=True,
        padding=True,
    )


# Optional extra argument, currently disabled: remove_columns="text"
tokenized_dataset = dataset.map(_tokenize_example)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [5]:
# Collator built around the same tokenizer; the training loop calls it
# directly on each batch's "input_ids".
collator = DataCollator(
    tokenizer=tokenizer,
)



In [6]:
# Shuffled mini-batches of 32 drawn from the tokenized "train" split.
# No collate_fn is passed here; collation happens inside the training loop.
train_dataloader = DataLoader(
    dataset=tokenized_dataset["train"],
    shuffle=True,
    batch_size=32,
)



In [8]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [11]:
# Build the masked-LM model and move it onto `device` in one expression.
# Alternative kept for reference: BalmMoERoPEForMaskedLM(<same keyword arguments>).
model = BalmMoEForMaskedLM(
    vocab_size=tokenizer.vocab_size,
    embed_dim=256,
    ffn_dim=1024,
    num_layers=8,
    num_heads=8,
    num_experts=4,
    expert_capacity=128,
    router_z_loss_coef=0.01,
    router_aux_loss_coef=0.01,
).to(device)

# Alternative kept for reference:
# model =  BalmForMaskedLM(
#     num_layers=6,
#     num_heads=8,
#     vocab_size=tokenizer.vocab_size,
#     max_length=320,
#     attention_dropout=0.1,
#     attention_batch_first=True,
#     layer_norm_eps=1e-5,
# )

# Wrap for data-parallel execution only when at least two GPUs are visible.
# NOTE(review): after wrapping, attributes of the underlying model are reached
# through `model.module` — confirm any later cell that reads model attributes.
if torch.cuda.device_count() >= 2:
    model = nn.DataParallel(model)

# Adam over all model parameters with a fixed learning rate of 4e-4.
optimizer = optim.Adam(params=model.parameters(), lr=4e-4)



In [12]:
# Display the parameter count. nn.DataParallel does not forward custom
# attributes of the module it wraps, so when the model was wrapped for
# multi-GPU execution the attribute is read from the underlying `.module`.
(model.module if isinstance(model, nn.DataParallel) else model).num_parameters

18982689

In [13]:
# Training loop: n_epochs passes over train_dataloader, one optimizer step per
# batch, with the loss components printed every 10 steps.
n_epochs = 10

model.train()
n_steps = 0

# One progress-bar tick per optimizer step. The context manager closes the bar
# when the loop finishes or a step raises, so no stale bar is left behind.
with tqdm(total=len(train_dataloader) * n_epochs) as pbar:
    for epoch in range(n_epochs):
        for examples in train_dataloader:
            optimizer.zero_grad()

            # The collator receives the batch's "input_ids" and returns a
            # mapping that holds "input_ids" and, optionally, "labels" and
            # "key_padding_mask".
            collated = collator(examples["input_ids"])

            input_ids = collated["input_ids"].to(device)
            labels = collated.get("labels", None)
            if labels is not None:
                labels = labels.to(device)
            key_padding_mask = collated.get("key_padding_mask", None)
            if key_padding_mask is not None:
                key_padding_mask = key_padding_mask.to(device)
            # The collator's "attention_mask" is deliberately not fetched: the
            # forward call below never received it, so moving it to the device
            # on every step was wasted work.

            outputs = model(
                input_ids=input_ids,
                labels=labels,
                key_padding_mask=key_padding_mask,
            )
            # NOTE(review): under nn.DataParallel the gathered losses may be
            # per-replica vectors rather than scalars — confirm before relying
            # on the multi-GPU path.
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            pbar.update(1)
            pbar.refresh()  # force a redraw after every step
            n_steps += 1
            if n_steps % 10 == 0:
                print(
                    f"step: {n_steps}, total Loss: {loss.item():.4f}, LM loss: {outputs.lm_loss.item():.4f}, router z loss: {outputs.router_z_loss.item():.4f}, router aux loss: {outputs.router_aux_loss.item():.4f}  "
                )

  0%|          | 0/320 [00:00<?, ?it/s]

step: 10, total Loss: 2.6946, LM loss: 2.6728, router z loss: 0.0118, router aux loss: 0.0100  
step: 20, total Loss: 2.5289, LM loss: 2.5089, router z loss: 0.0098, router aux loss: 0.0101  
step: 30, total Loss: 2.2836, LM loss: 2.2651, router z loss: 0.0085, router aux loss: 0.0100  
step: 40, total Loss: 2.1535, LM loss: 2.1364, router z loss: 0.0070, router aux loss: 0.0100  
step: 50, total Loss: 2.1300, LM loss: 2.1136, router z loss: 0.0063, router aux loss: 0.0100  
step: 60, total Loss: 2.1334, LM loss: 2.1179, router z loss: 0.0055, router aux loss: 0.0100  
step: 70, total Loss: 1.9980, LM loss: 1.9835, router z loss: 0.0044, router aux loss: 0.0101  
step: 80, total Loss: 2.1106, LM loss: 2.0962, router z loss: 0.0044, router aux loss: 0.0101  
step: 90, total Loss: 1.9247, LM loss: 1.9100, router z loss: 0.0047, router aux loss: 0.0100  
step: 100, total Loss: 1.9263, LM loss: 1.9122, router z loss: 0.0040, router aux loss: 0.0100  
step: 110, total Loss: 1.9217, LM loss: