In [None]:
from balm.data import load_dataset, DataCollator
from balm.models import (
    BalmForMaskedLM,
    BalmMoEForMaskedLM,
    BalmExpertChoiceMoEForMaskedLM,
    BalmHybridMoEForMaskedLM,
)
from balm.tokenizer import Tokenizer
from balm.train import Trainer

In [None]:
# Tokenizer built from a vocab file given as a relative path (resolved against
# the notebook's working directory); its vocab_size is reused to size the model.
tokenizer = Tokenizer(vocab="./vocab.json")

In [None]:
def remove_sep(txt):
    """Swap every ``</s>`` separator in *txt* for two ``<cls>`` tokens.

    Despite the name, the separator is not simply dropped: each ``</s>``
    becomes ``<cls><cls>``. Text without ``</s>`` is returned unchanged.
    Passed to ``load_dataset`` as its ``preprocess_fn``.
    """
    pieces = txt.split("</s>")
    return "<cls><cls>".join(pieces)


# Raw text file for each split.
# NOTE(review): "test" and "eval" point at the same file, so a score on
# "test" would not be independent of the eval set monitored during training.
data_files = dict(
    train="./balm/test_data/test.txt",
    test="./balm/test_data/test_1k.txt",
    eval="./balm/test_data/test_1k.txt",
)

# Load the plain-text splits; remove_sep is handed to load_dataset as its
# preprocessing hook.
dataset = load_dataset(
    "text",
    data_files=data_files,
    preprocess_fn=remove_sep,
)

In [None]:
# Tokenize the "text" field of every example (truncation enabled with
# max_length=320), then drop the raw "text" column so only the tokenizer
# outputs remain.
tokenized_dataset = dataset.map(
    lambda example: tokenizer(
        example["text"],
        max_length=320,
        truncation=True,
        padding=True,
    ),
    remove_columns="text",
)

In [None]:
# Collator handed to the Trainer to assemble batches from the tokenized
# examples. It receives the tokenizer, presumably for pad/mask token ids
# (TODO confirm against balm.data.DataCollator).
collator = DataCollator(tokenizer=tokenizer)

In [11]:
# Which architecture to train: "dense", "moe", or "expert_choice".
# Every variant reuses the backbone hyperparameters below, matched to ESM-2 8M,
# so switching variants only requires changing this one value.
MODEL_VARIANT = "dense"

# Backbone hyperparameters shared by all variants (matched to ESM-2 8M).
backbone_kwargs = dict(
    embed_dim=320,
    ffn_dim=320 * 4,
    num_layers=6,
    num_heads=20,
    vocab_size=tokenizer.vocab_size,
)

if MODEL_VARIANT == "dense":
    model = BalmForMaskedLM(**backbone_kwargs)
elif MODEL_VARIANT == "moe":
    # Sparse MoE variant with top-1 routing and alternating sparse layers.
    model = BalmMoEForMaskedLM(
        **backbone_kwargs,
        num_experts=8,
        num_shared_experts=0,
        alternate_sparsity=True,
        router_top_k=1,
        expert_capacity=128,
        router_z_loss_coef=0.01,
        router_aux_loss_coef=0.01,
    )
elif MODEL_VARIANT == "expert_choice":
    # Expert-choice MoE variant (no router_top_k / aux-loss settings).
    model = BalmExpertChoiceMoEForMaskedLM(
        **backbone_kwargs,
        num_experts=8,
        num_shared_experts=0,
        alternate_sparsity=False,
        expert_capacity=128,
        router_z_loss_coef=0.01,
    )
else:
    raise ValueError(
        f"Unknown MODEL_VARIANT {MODEL_VARIANT!r}; "
        "expected 'dense', 'moe', or 'expert_choice'"
    )

In [12]:
# Display the model's parameter count. It is read as an attribute, not
# called, so it is presumably a property (TODO confirm). The recorded output
# was produced with the dense BalmForMaskedLM config.
model.num_parameters

6294113

In [13]:
trainer = Trainer(
    # --- model and data ---
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    # --- optimization schedule ---
    epochs=1,
    warmup_steps=50,
    per_device_train_batch_size=32,
    # --- reporting cadence (in optimizer steps) ---
    logging_steps=10,
    eval_steps=50,
    # use_cpu=True,  # presumably forces CPU training; left disabled
)

In [14]:
# Device the Trainer selected. The recorded run reports CPU even though
# use_cpu was not passed, so presumably no accelerator was available.
trainer.device

device(type='cpu')

In [15]:
# Run the training loop, with periodic logging and evaluation per the Trainer
# settings. The recorded run below was stopped manually (KeyboardInterrupt)
# around step 1840 of 2087, so it does not reflect a completed epoch.
trainer.train()

Training:   0%|          | 0/2087 [00:00<?, ?step/s]

step 10   | loss: 3.1020 | lr: 0.000080
step 20   | loss: 2.7443 | lr: 0.000160
step 30   | loss: 2.6527 | lr: 0.000240
step 40   | loss: 2.6453 | lr: 0.000320
step 50   | loss: 2.5886 | lr: 0.000400


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 2.6825
step 60   | loss: 2.5415 | lr: 0.000398
step 70   | loss: 2.2838 | lr: 0.000396
step 80   | loss: 2.2063 | lr: 0.000394
step 90   | loss: 2.1203 | lr: 0.000392
step 100  | loss: 2.1287 | lr: 0.000390


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 2.1314
step 110  | loss: 1.9825 | lr: 0.000388
step 120  | loss: 1.9320 | lr: 0.000386
step 130  | loss: 1.9641 | lr: 0.000384
step 140  | loss: 1.9640 | lr: 0.000382
step 150  | loss: 1.8985 | lr: 0.000380


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 1.9031
step 160  | loss: 1.8407 | lr: 0.000378
step 170  | loss: 1.7660 | lr: 0.000376
step 180  | loss: 1.6834 | lr: 0.000374
step 190  | loss: 1.6126 | lr: 0.000373
step 200  | loss: 1.6124 | lr: 0.000371


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 1.6210
step 210  | loss: 1.5491 | lr: 0.000369
step 220  | loss: 1.5128 | lr: 0.000367
step 230  | loss: 1.4577 | lr: 0.000365
step 240  | loss: 1.4858 | lr: 0.000363
step 250  | loss: 1.4321 | lr: 0.000361


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 1.4053
step 260  | loss: 1.3326 | lr: 0.000359
step 270  | loss: 1.4082 | lr: 0.000357
step 280  | loss: 1.2744 | lr: 0.000355
step 290  | loss: 1.2189 | lr: 0.000353
step 300  | loss: 1.1393 | lr: 0.000351


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 1.2633
step 310  | loss: 1.1223 | lr: 0.000349
step 320  | loss: 1.2093 | lr: 0.000347
step 330  | loss: 1.2347 | lr: 0.000345
step 340  | loss: 1.1423 | lr: 0.000343
step 350  | loss: 1.1128 | lr: 0.000341


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 1.1003
step 360  | loss: 1.0664 | lr: 0.000339
step 370  | loss: 1.0037 | lr: 0.000337
step 380  | loss: 1.0089 | lr: 0.000335
step 390  | loss: 1.0358 | lr: 0.000333
step 400  | loss: 0.9334 | lr: 0.000331


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.9782
step 410  | loss: 0.9920 | lr: 0.000329
step 420  | loss: 0.8711 | lr: 0.000327
step 430  | loss: 0.9032 | lr: 0.000325
step 440  | loss: 0.8752 | lr: 0.000323
step 450  | loss: 0.9360 | lr: 0.000321


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.8778
step 460  | loss: 0.8293 | lr: 0.000319
step 470  | loss: 0.9083 | lr: 0.000318
step 480  | loss: 0.8401 | lr: 0.000316
step 490  | loss: 0.8241 | lr: 0.000314
step 500  | loss: 0.7942 | lr: 0.000312


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.8291
step 510  | loss: 0.8690 | lr: 0.000310
step 520  | loss: 0.7452 | lr: 0.000308
step 530  | loss: 0.7309 | lr: 0.000306
step 540  | loss: 0.6372 | lr: 0.000304
step 550  | loss: 0.7340 | lr: 0.000302


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.7571
step 560  | loss: 0.7693 | lr: 0.000300
step 570  | loss: 0.6750 | lr: 0.000298
step 580  | loss: 0.6693 | lr: 0.000296
step 590  | loss: 0.6618 | lr: 0.000294
step 600  | loss: 0.6492 | lr: 0.000292


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.7068
step 610  | loss: 0.6381 | lr: 0.000290
step 620  | loss: 0.7049 | lr: 0.000288
step 630  | loss: 0.7305 | lr: 0.000286
step 640  | loss: 0.6417 | lr: 0.000284
step 650  | loss: 0.6464 | lr: 0.000282


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.6615
step 660  | loss: 0.6438 | lr: 0.000280
step 670  | loss: 0.5524 | lr: 0.000278
step 680  | loss: 0.5674 | lr: 0.000276
step 690  | loss: 0.6030 | lr: 0.000274
step 700  | loss: 0.6130 | lr: 0.000272


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.6271
step 710  | loss: 0.5672 | lr: 0.000270
step 720  | loss: 0.6053 | lr: 0.000268
step 730  | loss: 0.5970 | lr: 0.000266
step 740  | loss: 0.5577 | lr: 0.000265
step 750  | loss: 0.5638 | lr: 0.000263


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.6003
step 760  | loss: 0.5937 | lr: 0.000261
step 770  | loss: 0.6897 | lr: 0.000259
step 780  | loss: 0.5801 | lr: 0.000257
step 790  | loss: 0.4468 | lr: 0.000255
step 800  | loss: 0.5985 | lr: 0.000253


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.5756
step 810  | loss: 0.5855 | lr: 0.000251
step 820  | loss: 0.6151 | lr: 0.000249
step 830  | loss: 0.5341 | lr: 0.000247
step 840  | loss: 0.5753 | lr: 0.000245
step 850  | loss: 0.5145 | lr: 0.000243


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.5505
step 860  | loss: 0.5101 | lr: 0.000241
step 870  | loss: 0.5521 | lr: 0.000239
step 880  | loss: 0.5879 | lr: 0.000237
step 890  | loss: 0.5239 | lr: 0.000235
step 900  | loss: 0.5155 | lr: 0.000233


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.5304
step 910  | loss: 0.4332 | lr: 0.000231
step 920  | loss: 0.5433 | lr: 0.000229
step 930  | loss: 0.5131 | lr: 0.000227
step 940  | loss: 0.5850 | lr: 0.000225
step 950  | loss: 0.5835 | lr: 0.000223


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.5268
step 960  | loss: 0.4257 | lr: 0.000221
step 970  | loss: 0.5519 | lr: 0.000219
step 980  | loss: 0.4773 | lr: 0.000217
step 990  | loss: 0.4797 | lr: 0.000215
step 1000 | loss: 0.4934 | lr: 0.000213


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.5149
step 1010 | loss: 0.6018 | lr: 0.000211
step 1020 | loss: 0.5279 | lr: 0.000210
step 1030 | loss: 0.5409 | lr: 0.000208
step 1040 | loss: 0.4465 | lr: 0.000206
step 1050 | loss: 0.5558 | lr: 0.000204


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4867
step 1060 | loss: 0.4211 | lr: 0.000202
step 1070 | loss: 0.4368 | lr: 0.000200
step 1080 | loss: 0.5023 | lr: 0.000198
step 1090 | loss: 0.4219 | lr: 0.000196
step 1100 | loss: 0.5235 | lr: 0.000194


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.5042
step 1110 | loss: 0.4913 | lr: 0.000192
step 1120 | loss: 0.4468 | lr: 0.000190
step 1130 | loss: 0.4980 | lr: 0.000188
step 1140 | loss: 0.6059 | lr: 0.000186
step 1150 | loss: 0.4300 | lr: 0.000184


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4681
step 1160 | loss: 0.6297 | lr: 0.000182
step 1170 | loss: 0.5152 | lr: 0.000180
step 1180 | loss: 0.3479 | lr: 0.000178
step 1190 | loss: 0.4011 | lr: 0.000176
step 1200 | loss: 0.4574 | lr: 0.000174


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4680
step 1210 | loss: 0.3992 | lr: 0.000172
step 1220 | loss: 0.4974 | lr: 0.000170
step 1230 | loss: 0.5163 | lr: 0.000168
step 1240 | loss: 0.3897 | lr: 0.000166
step 1250 | loss: 0.5078 | lr: 0.000164


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4421
step 1260 | loss: 0.4292 | lr: 0.000162
step 1270 | loss: 0.4816 | lr: 0.000160
step 1280 | loss: 0.3368 | lr: 0.000158
step 1290 | loss: 0.3588 | lr: 0.000157
step 1300 | loss: 0.4232 | lr: 0.000155


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4475
step 1310 | loss: 0.4150 | lr: 0.000153
step 1320 | loss: 0.3775 | lr: 0.000151
step 1330 | loss: 0.4169 | lr: 0.000149
step 1340 | loss: 0.5353 | lr: 0.000147
step 1350 | loss: 0.3301 | lr: 0.000145


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4350
step 1360 | loss: 0.4034 | lr: 0.000143
step 1370 | loss: 0.3313 | lr: 0.000141
step 1380 | loss: 0.4552 | lr: 0.000139
step 1390 | loss: 0.4441 | lr: 0.000137
step 1400 | loss: 0.4393 | lr: 0.000135


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4329
step 1410 | loss: 0.4099 | lr: 0.000133
step 1420 | loss: 0.3280 | lr: 0.000131
step 1430 | loss: 0.4643 | lr: 0.000129
step 1440 | loss: 0.4185 | lr: 0.000127
step 1450 | loss: 0.4268 | lr: 0.000125


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4182
step 1460 | loss: 0.3901 | lr: 0.000123
step 1470 | loss: 0.3078 | lr: 0.000121
step 1480 | loss: 0.3784 | lr: 0.000119
step 1490 | loss: 0.5291 | lr: 0.000117
step 1500 | loss: 0.4727 | lr: 0.000115


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4229
step 1510 | loss: 0.5057 | lr: 0.000113
step 1520 | loss: 0.3756 | lr: 0.000111
step 1530 | loss: 0.3378 | lr: 0.000109
step 1540 | loss: 0.4084 | lr: 0.000107
step 1550 | loss: 0.4142 | lr: 0.000105


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4152
step 1560 | loss: 0.4156 | lr: 0.000103
step 1570 | loss: 0.5196 | lr: 0.000102
step 1580 | loss: 0.4002 | lr: 0.000100
step 1590 | loss: 0.3420 | lr: 0.000098
step 1600 | loss: 0.3743 | lr: 0.000096


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4134
step 1610 | loss: 0.4204 | lr: 0.000094
step 1620 | loss: 0.3372 | lr: 0.000092
step 1630 | loss: 0.2907 | lr: 0.000090
step 1640 | loss: 0.3942 | lr: 0.000088
step 1650 | loss: 0.3021 | lr: 0.000086


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.3983
step 1660 | loss: 0.4938 | lr: 0.000084
step 1670 | loss: 0.4348 | lr: 0.000082
step 1680 | loss: 0.3812 | lr: 0.000080
step 1690 | loss: 0.3805 | lr: 0.000078
step 1700 | loss: 0.3646 | lr: 0.000076


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.3968
step 1710 | loss: 0.3181 | lr: 0.000074
step 1720 | loss: 0.3736 | lr: 0.000072
step 1730 | loss: 0.4564 | lr: 0.000070
step 1740 | loss: 0.4923 | lr: 0.000068
step 1750 | loss: 0.3680 | lr: 0.000066


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.4009
step 1760 | loss: 0.3021 | lr: 0.000064
step 1770 | loss: 0.3456 | lr: 0.000062
step 1780 | loss: 0.3505 | lr: 0.000060
step 1790 | loss: 0.4145 | lr: 0.000058
step 1800 | loss: 0.3243 | lr: 0.000056


Evaluating:   0%|          | 0/31 [00:00<?, ?step/s]

<<< EVAL >>> loss: 0.3980
step 1810 | loss: 0.3906 | lr: 0.000054
step 1820 | loss: 0.3631 | lr: 0.000052
step 1830 | loss: 0.3979 | lr: 0.000050
step 1840 | loss: 0.3852 | lr: 0.000049


KeyboardInterrupt: 