In [1]:
from balm.data import load_dataset, DataCollator
from balm.models import (
    BalmForMaskedLM,
    BalmMoEForMaskedLM,
    BalmExpertChoiceMoEForMaskedLM,
    BalmHybridMoEForMaskedLM,
)
from balm.tokenizer import Tokenizer
from balm.train import Trainer

In [2]:
tokenizer = Tokenizer(vocab="./vocab.json")

In [3]:
def remove_sep(txt):
    return txt.replace("</s>", "<cls><cls>")


data_files = {
    "train": "./balm/test_data/test.txt",
    "test": "./balm/test_data/test_1k.txt",
    "eval": "./balm/test_data/test_1k.txt",
}

dataset = load_dataset("text", data_files=data_files, preprocess_fn=remove_sep)

In [4]:
tokenized_dataset = dataset.map(
    lambda x: tokenizer(
        x["text"],
        padding=True,
        truncation=True,
        max_length=320,
    ),
    remove_columns="text"
)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [5]:
collator = DataCollator(tokenizer=tokenizer)

In [6]:
# model = BalmForMaskedLM(
#     embed_dim=256,
#     ffn_dim=1024,
#     num_layers=8,
#     num_heads=8,
#     vocab_size=tokenizer.vocab_size,
# )

# model = BalmMoEForMaskedLM(
#     embed_dim=256,
#     ffn_dim=1024,
#     num_experts=4,
#     num_shared_experts=0,
#     num_layers=8,
#     num_heads=8,
#     alternate_sparsity=True,
#     router_top_k=1,
#     expert_capacity=128,
#     router_z_loss_coef=0.01,
#     router_aux_loss_coef=0.01,
#     vocab_size=tokenizer.vocab_size,
# )

model = BalmExpertChoiceMoEForMaskedLM(
    embed_dim=256,
    ffn_dim=1024,
    num_experts=4,
    num_shared_experts=1,
    num_layers=8,
    num_heads=8,
    alternate_sparsity=False,
    expert_capacity=128,
    router_z_loss_coef=0.01,
    vocab_size=tokenizer.vocab_size,
)

In [7]:
model.num_parameters

18980641

In [8]:
trainer = Trainer(
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    epochs=1,
    logging_steps=10,
    eval_steps=50,
    warmup_steps=50,
    per_device_train_batch_size=32,
    # use_cpu=True,
)

In [9]:
trainer.device

device(type='cpu')

In [10]:
trainer.train()

  0%|          | 0/31 [00:00<?, ?it/s]

step 10 | loss: 2.8997 | lm_loss: 2.8862 | router_z_loss: 0.0135 | lr: 0.000080
step 20 | loss: 2.7035 | lm_loss: 2.6962 | router_z_loss: 0.0073 | lr: 0.000160
step 30 | loss: 2.5979 | lm_loss: 2.5938 | router_z_loss: 0.0041 | lr: 0.000240
Training complete
