In [1]:
from balm.config import BalmConfig
from balm.data import load_dataset, DataCollator
from balm.models import (
    BalmForMaskedLM,
    BalmMoEForMaskedLM,
    BalmExpertChoiceMoEForMaskedLM,
    BalmHybridMoEForMaskedLM,
)
from balm.tokenizer import Tokenizer
from balm.train import Trainer

In [2]:
# Build the tokenizer from the local vocabulary file. The same instance is
# reused for dataset encoding, the data collator, and the model's vocab_size.
tokenizer = Tokenizer(
    vocab="./vocab.json",
)

In [3]:
def remove_sep(txt):
    """Replace every ``</s>`` separator in ``txt`` with two ``<cls>`` tokens.

    Despite the name, nothing is removed: each ``</s>`` occurrence is swapped
    for the literal ``<cls><cls>``. Text without a separator is returned
    unchanged. This function is passed as ``preprocess_fn`` when the dataset
    is loaded in this notebook.

    NOTE(review): presumably ``</s>`` joins the two chains of a paired
    sequence in the raw files; confirm against the data format.
    """
    separator = "</s>"
    replacement = "<cls><cls>"
    pieces = txt.split(separator)
    return replacement.join(pieces)


# Available train/test/eval splits, keyed by a short name. Switch datasets by
# changing DATASET_NAME instead of commenting dict literals in and out.
# NOTE: "test_1k" is a small smoke-test file reused for every split, so its
# eval/test metrics measure fit to the training data, not generalization.
DATA_FILE_OPTIONS = {
    "test_1k": {
        "train": "./balm/test_data/test_1k.txt",
        "test": "./balm/test_data/test_1k.txt",
        "eval": "./balm/test_data/test_1k.txt",
    },
    "paired": {
        "train": "../train-test-eval_paired/train.txt",
        "test": "../train-test-eval_paired/test.txt",
        "eval": "../train-test-eval_paired/eval.txt",
    },
    "jaffe_plusHD_clust0.9": {
        "train": "../jaffe-plusHD_clust0.9_split/train.txt",
        "test": "../jaffe-plusHD_clust0.9_split/test.txt",
        "eval": "../jaffe-plusHD_clust0.9_split/eval.txt",
    },
}
DATASET_NAME = "test_1k"

# Copy so that any mutation downstream cannot alter the options table.
data_files = dict(DATA_FILE_OPTIONS[DATASET_NAME])

# NOTE(review): preprocess_fn presumably runs remove_sep on each raw line
# before it is stored; confirm in balm.data.load_dataset.
dataset = load_dataset("text", data_files=data_files, preprocess_fn=remove_sep)

In [4]:
# Encode every split with the shared tokenizer (truncation enabled,
# max_length=320) and drop the raw "text" column from the result.
tokenized_dataset = dataset.map(
    lambda example: tokenizer(
        example["text"], padding=True, truncation=True, max_length=320
    ),
    remove_columns="text",
)

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

In [5]:
# Batch-assembly collator; it is handed to the Trainer as `data_collator`.
collator = DataCollator(
    tokenizer=tokenizer,
)

In [6]:
# Dense model matched to ESM-2 8M: 6 layers, 320-d embeddings, 20 attention
# heads (16 dims per head), and a 4x feed-forward expansion (320 * 4 = 1280).
config = BalmConfig(
    embed_dim=320,
    ffn_dim=320 * 4,
    num_layers=6,
    num_heads=20,
    vocab_size=tokenizer.vocab_size,
)
model = BalmForMaskedLM(config=config)

# Alternative mixture-of-experts models, kept for reference.
# NOTE(review): unlike the dense model above, these pass hyperparameters as
# keyword arguments rather than a BalmConfig; confirm the constructors still
# accept this form before re-enabling them.

# # matched to ESM-2 8M
# model = BalmMoEForMaskedLM(
#     embed_dim=320,
#     ffn_dim=320*4,
#     num_experts=8,
#     num_shared_experts=0,
#     num_layers=6,
#     num_heads=20,
#     alternate_sparsity=True,
#     router_top_k=1,
#     expert_capacity=128,
#     router_z_loss_coef=0.01,
#     router_aux_loss_coef=0.01,
#     vocab_size=tokenizer.vocab_size,
# )

# # matched to ESM-2 8M
# model = BalmExpertChoiceMoEForMaskedLM(
#     embed_dim=320,
#     ffn_dim=320 * 4,
#     num_experts=8,
#     num_shared_experts=0,
#     num_layers=6,
#     num_heads=20,
#     alternate_sparsity=False,
#     expert_capacity=128,
#     router_z_loss_coef=0.01,
#     vocab_size=tokenizer.vocab_size,
# )

In [7]:
# Display `model.num_parameters`. It is read as an attribute (not called) and
# evaluates to an int. Presumably it is the total parameter count; confirm in
# balm.models.
model.num_parameters

6294113

In [8]:
# Short smoke-test run over one epoch: log every 5 steps, evaluate every 10,
# checkpoint every 15, with a 10-step learning-rate warmup.
trainer = Trainer(
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    output_dir="./training_runs/save_tests",
    epochs=1,
    logging_steps=5,
    eval_steps=10,
    warmup_steps=10,
    save_steps=15,
    per_device_train_batch_size=32,
    # use_cpu=True,
    # use_wandb=True,
    # NOTE(review): wandb_project is set while use_wandb stays commented out;
    # presumably it is ignored until that flag is enabled. Confirm.
    wandb_project="test_wandb_logging",
    # wandb_entity="bryanbriney",
    run_name="save_test_001",
)

In [9]:
# Run training with the settings configured on `trainer` above.
trainer.train()

Training:   0%|          | 0/31 [00:00<?, ?step/s]

step 5     | loss: 3.1904 | lr: 0.000200
step 10    | loss: 2.7761 | lr: 0.000400


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<< EVAL >> | loss: 2.8147 | accuracy: 0.2133 | perplexity: 15.2694
step 15    | loss: 2.6738 | lr: 0.000305
<< SAVING MODEL CHECKPOINT >>
step 20    | loss: 2.6841 | lr: 0.000210


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<< EVAL >> | loss: 2.7219 | accuracy: 0.2241 | perplexity: 13.9600
step 25    | loss: 2.6536 | lr: 0.000114
step 30    | loss: 2.6383 | lr: 0.000019


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<< EVAL >> | loss: 2.7187 | accuracy: 0.2252 | perplexity: 13.8820
<< SAVING MODEL CHECKPOINT >>
<< SAVING FINAL MODEL >>

Training complete
