In [1]:
from balm.data import load_dataset, DataCollator
from balm.models import (
    BalmForMaskedLM,
    BalmMoEForMaskedLM,
    BalmExpertChoiceMoEForMaskedLM,
    BalmHybridMoEForMaskedLM,
)
from balm.tokenizer import Tokenizer
# from balm.train import Trainer

from datasets import load_dataset, DataCollator
from transformers import Trainer, TrainingArguments




In [2]:
# Build the tokenizer from the vocabulary file next to this notebook.
tokenizer = Tokenizer(vocab="./vocab.json")

In [3]:
def remove_sep(txt):
    """Return `txt` with every "</s>" token swapped for a "<cls><cls>" pair."""
    pieces = txt.split("</s>")
    return "<cls><cls>".join(pieces)


# All three splits point at the same small test file.
data_files = {
    "train": "./balm/test_data/test_1k.txt",
    "test": "./balm/test_data/test_1k.txt",
    "eval": "./balm/test_data/test_1k.txt",
}

# The import cell pulls `load_dataset` from `balm.data` and then again from
# `datasets`, so the later import wins. This call passes `preprocess_fn`, which
# is the `balm.data` interface, so bind the name explicitly here.
# NOTE(review): the cleaner fix is to drop the duplicate `datasets` import in
# the import cell — confirm nothing else in the notebook relies on it.
from balm.data import load_dataset

# Load the text files, passing `remove_sep` as the `preprocess_fn` hook.
dataset = load_dataset("text", data_files=data_files, preprocess_fn=remove_sep)

In [4]:
def _tokenize_text(example):
    """Tokenize the "text" field of whatever `dataset.map` passes in."""
    return tokenizer(
        example["text"],
        padding=True,
        truncation=True,
        max_length=320,
    )


# Map the tokenizer over the dataset, passing remove_columns="text".
tokenized_dataset = dataset.map(_tokenize_text, remove_columns="text")

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

In [5]:
# Collator built around the tokenizer; later handed to the Trainer as `data_collator`.
# NOTE(review): `DataCollator` is imported from both `balm.data` and `datasets`
# in the import cell — confirm which one is actually bound here. The recorded
# training run ends in a Tensor-vs-dict TypeError, which may come from how this
# collator handles the batches the Trainer passes to it — verify.
collator = DataCollator(tokenizer=tokenizer)

In [6]:
# # matched to ESM-2 8M
# model = BalmForMaskedLM(
#     embed_dim=320,
#     ffn_dim=320*4,
#     num_layers=6,
#     num_heads=20,
#     vocab_size=tokenizer.vocab_size,
# )

# # matched to ESM-2 8M
# model = BalmMoEForMaskedLM(
#     embed_dim=320,
#     ffn_dim=320*4,
#     num_experts=8,
#     num_shared_experts=0,
#     num_layers=6,
#     num_heads=20,
#     alternate_sparsity=True,
#     router_top_k=1,
#     expert_capacity=128,
#     router_z_loss_coef=0.01,
#     router_aux_loss_coef=0.01,
#     vocab_size=tokenizer.vocab_size,
# )

# Expert-choice MoE masked LM with sizes matched to ESM-2 8M.
# NOTE(review): the parameter count recorded for this configuration is
# 41,935,073, so "8M" presumably describes the backbone dimensions rather than
# the total size with all experts — confirm.
model = BalmExpertChoiceMoEForMaskedLM(
    # vocabulary and transformer sizes
    vocab_size=tokenizer.vocab_size,
    embed_dim=320,
    ffn_dim=320 * 4,
    num_layers=6,
    num_heads=20,
    # expert / router settings
    num_experts=8,
    num_shared_experts=0,
    expert_capacity=128,
    alternate_sparsity=False,
    router_z_loss_coef=0.01,
)

In [7]:
# Show the model's parameter count as the cell output.
model.num_parameters

41935073

In [8]:
# trainer = Trainer(
#     model=model,
#     data_collator=collator,
#     train_dataset=tokenized_dataset["train"],
#     eval_dataset=tokenized_dataset["eval"],
#     epochs=1,
#     logging_steps=10,
#     eval_steps=50,
#     warmup_steps=50,
#     per_device_train_batch_size=32,
#     # use_cpu=True,
# )

# Hugging Face training configuration for a short CPU-only run.
# NOTE(review): `eval_steps=50` is set but no evaluation strategy is given, so
# periodic evaluation may never be triggered under the library defaults —
# confirm against the installed transformers version.
training_args = TrainingArguments(
    # output and logging directories
    output_dir="~/Desktop/training",
    logging_dir="~/Desktop/training/log",
    # optimisation schedule
    learning_rate=1e-4,
    warmup_steps=100,
    max_steps=2000,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    # logging / evaluation cadence
    logging_steps=10,
    eval_steps=50,
    # runtime
    use_cpu=True,
    report_to="none",
)

# Hand the model, arguments, collator and tokenized splits to the Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    data_collator=collator,
)



In [9]:
# Start training with the Trainer configured in the previous cell.
# NOTE(review): the recorded run of this cell stops at step 0 with
# "TypeError: expected Tensor as element 0 in argument 0, but got dict".
# The cause is not visible in this notebook; the collator / Trainer pairing is
# the first thing to check.
trainer.train()

  0%|          | 0/2000 [00:00<?, ?it/s]

TypeError: expected Tensor as element 0 in argument 0, but got dict