In [10]:
import json
import os

import wandb

from balm.config import BalmConfig
from balm.data import load_dataset, DataCollator
from balm.models import (
    BalmForMaskedLM,
    BalmMoEForMaskedLM,
    BalmExpertChoiceMoEForMaskedLM,
    BalmHybridMoEForMaskedLM,
)
from balm.tokenizer import Tokenizer
from balm.train import Trainer



In [2]:
# Build the tokenizer from the vocabulary file; the relative path is resolved
# against the kernel's current working directory.
tokenizer = Tokenizer(vocab="./vocab.json")



In [3]:
def remove_sep(txt: str) -> str:
    """Replace every ``</s>`` separator with a pair of ``<cls>`` tokens.

    Despite the name, the separator is *replaced*, not removed. This is
    passed to ``load_dataset`` as ``preprocess_fn`` below.

    Args:
        txt: A piece of raw input text (presumably one line of the data file).

    Returns:
        ``txt`` with every ``</s>`` turned into ``<cls><cls>``; text that
        contains no separator comes back unchanged.
    """
    return "<cls><cls>".join(txt.split("</s>"))


# Smoke-test configuration: train, test and eval all read the same
# 1k-line test file.
data_files = {
    split: "./balm/test_data/test_1k.txt"
    for split in ("train", "test", "eval")
}

# Alternative: paired train/test/eval split.
# data_files = {
#     split: f"../train-test-eval_paired/{split}.txt"
#     for split in ("train", "test", "eval")
# }

# Alternative: paired longitudinal HD data (clust90 split).
# NOTE(review): these are .csv files but are loaded with the "text" builder
# below, so each CSV row would be read as raw text; `~` may also not be
# expanded by the loader (consider os.path.expanduser) — verify before use.
# data_files = {
#     split: f"~/shared/Sarah/training-data/paired-longitudinalHD/data/clust90-split/jaffe_longHD_clust90_{split}.csv"
#     for split in ("train", "test", "eval")
# }

dataset = load_dataset("text", data_files=data_files, preprocess_fn=remove_sep)

In [4]:
# Encode every split, truncating to at most 320 tokens, and drop the raw
# "text" column once it has been tokenized.
tokenized_dataset = dataset.map(
    lambda example: tokenizer(
        example["text"],
        max_length=320,
        truncation=True,
        padding=True,
    ),
    remove_columns="text",
)

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

In [5]:
# Batch collator from balm, replacing the earlier Hugging Face
# DataCollatorForLanguageModeling (mlm_probability=0.15).
# NOTE(review): confirm balm's collator applies an equivalent masking rate.
collator = DataCollator(tokenizer=tokenizer)



In [6]:
# Dense model sized to match ESM-2 8M: 6 layers, 20 attention heads,
# 320-d embeddings and a feed-forward block 4x the embedding width.
config = BalmConfig(
    num_layers=6,
    num_heads=20,
    embed_dim=320,
    ffn_dim=4 * 320,
    vocab_size=tokenizer.vocab_size,
)
model = BalmForMaskedLM(config=config)

# Mixture-of-experts alternatives at the same scale.
# NOTE(review): these still use the older keyword-argument constructor, while
# the dense model above now takes `config=`. Confirm these classes accept
# these kwargs (or move the settings into BalmConfig) before uncommenting.
#
# model = BalmMoEForMaskedLM(
#     embed_dim=320,
#     ffn_dim=320*4,
#     num_experts=8,
#     num_shared_experts=0,
#     num_layers=6,
#     num_heads=20,
#     alternate_sparsity=True,
#     router_top_k=1,
#     expert_capacity=128,
#     router_z_loss_coef=0.01,
#     router_aux_loss_coef=0.01,
#     vocab_size=tokenizer.vocab_size,
# )
#
# model = BalmExpertChoiceMoEForMaskedLM(
#     embed_dim=320,
#     ffn_dim=320 * 4,
#     num_experts=8,
#     num_shared_experts=0,
#     num_layers=6,
#     num_heads=20,
#     alternate_sparsity=False,
#     expert_capacity=128,
#     router_z_loss_coef=0.01,
#     vocab_size=tokenizer.vocab_size,
# )

In [7]:
# Total parameter count of the configured model. Accessed without a call:
# the cell output is a plain int, so this is an attribute/property here.
model.num_parameters

6294113

In [8]:
# Short save/eval smoke-test run: one epoch over the 1k-line test data with
# frequent logging, evaluation and checkpointing.
trainer = Trainer(
    model=model,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    data_collator=collator,
    # optimisation schedule
    epochs=1,
    per_device_train_batch_size=32,
    warmup_steps=10,
    # logging / evaluation / checkpoint cadence, in training steps
    logging_steps=5,
    eval_steps=10,
    save_steps=15,
    # outputs and experiment tracking (W&B logging is off unless use_wandb=True)
    output_dir="./training_runs/save_tests",
    run_name="save_test_001",
    wandb_project="test_wandb_logging",
    # use_wandb=True,
    # wandb_entity="bryanbriney",
    # use_cpu=True,
)




In [9]:
# Where checkpoints are written; the relative output_dir passed to the
# Trainer is reported back as an absolute path.
trainer.output_dir

'/Users/bryanbriney/git/BALM/training_runs/save_tests'

In [9]:
# Point W&B at the experiment's *project*. The previous version assigned the
# run name ("save_test_001") here, which would file each run under a project
# named after that single run. It also relied on `os` having leaked into the
# kernel; `os` is now imported in the top import cell.
# NOTE(review): assumes the Trainer keeps its `wandb_project` kwarg as an
# attribute of the same name; otherwise this falls back to the run name
# (the previous behaviour). The variable must be set before W&B initialises
# a run — presumably inside trainer.train(); confirm.
os.environ["WANDB_PROJECT"] = getattr(trainer, "wandb_project", None) or trainer.run_name
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mbryanbriney[0m ([33mthebrineylab[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
# Run the training loop. It logs, evaluates and checkpoints at the step
# intervals configured on the Trainer, then saves the final model.
trainer.train()

Training:   0%|          | 0/31 [00:00<?, ?step/s]

step 5     | loss: 3.1558 | lr: 0.000200
step 10    | loss: 2.8548 | lr: 0.000400


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<< EVAL >> | loss: 2.8366 | accuracy: 0.2139 | perplexity: 15.6163
step 15    | loss: 2.6944 | lr: 0.000305
<< SAVING MODEL CHECKPOINT >>
step 20    | loss: 2.7200 | lr: 0.000210


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<< EVAL >> | loss: 2.7276 | accuracy: 0.2231 | perplexity: 14.0248
step 25    | loss: 2.6147 | lr: 0.000114
step 30    | loss: 2.6676 | lr: 0.000019


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<< EVAL >> | loss: 2.7190 | accuracy: 0.2226 | perplexity: 13.9186
<< SAVING MODEL CHECKPOINT >>
<< SAVING FINAL MODEL >>

Training complete
