In [1]:
from balm.data import load_dataset, DataCollator
from balm.models import BalmMoEForMaskedLM
from balm.tokenizer import Tokenizer
from balm.train import Trainer

In [2]:
# Instantiate the tokenizer from the vocab file at ./vocab.json
# (the path is relative to the notebook's working directory).
tokenizer = Tokenizer(vocab="./vocab.json")

In [3]:
def remove_sep(txt):
    """Swap every ``</s>`` separator in ``txt`` for a ``<cls><cls>`` pair.

    Despite the name, the separator is not simply dropped: each occurrence
    is replaced by two consecutive ``<cls>`` markers.

    Args:
        txt: The input string to rewrite.

    Returns:
        A new string in which every ``</s>`` has become ``<cls><cls>``.
    """
    separator, replacement = "</s>", "<cls><cls>"
    # Splitting on the separator and re-joining with the replacement rewrites
    # every non-overlapping occurrence, left to right.
    return replacement.join(txt.split(separator))


# Split name -> path of the text file backing that split.
# NOTE(review): "test" and "eval" point at the same file (test_1k.txt), so
# the evaluation loss is not measured on data distinct from the test split.
data_files = dict(
    train="./balm/test_data/test.txt",
    test="./balm/test_data/test_1k.txt",
    eval="./balm/test_data/test_1k.txt",
)

# Load all three splits as "text" data, passing remove_sep as `preprocess_fn`.
dataset = load_dataset(
    "text",
    preprocess_fn=remove_sep,
    data_files=data_files,
)

In [4]:
# Run the tokenizer over the "text" field of the dataset (the recorded output
# shows one progress bar per split) and pass remove_columns="text" so the raw
# column is not carried into the tokenized dataset.
tokenized_dataset = dataset.map(
    lambda example: tokenizer(
        example["text"], padding=True, truncation=True, max_length=320
    ),
    remove_columns="text",
)

  0%|          | 0/66792 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [5]:
# Build the data collator around the tokenizer; it is handed to the Trainer
# as `data_collator` further down.
# NOTE(review): presumably this is where masked-LM masking happens -- confirm in balm.data.
collator = DataCollator(tokenizer=tokenizer)



In [6]:
# Mixture-of-experts masked-LM configuration used for this run
# (12,692,257 parameters according to the `num_parameters` output below).
# Alternative model class left disabled by the author: BalmMoERoPEForMaskedLM
# (it is not imported in the visible cells; add the import before swapping it in).
model = BalmMoEForMaskedLM(
    embed_dim=256,
    ffn_dim=1024,
    num_experts=4,
    num_shared_experts=0,
    num_layers=8,
    num_heads=8,
    alternate_sparsity=True,
    router_top_k=1,
    # Disabled option: router_class=ExpertChoiceRouter
    # (ExpertChoiceRouter is not imported in the visible cells).
    expert_capacity=128,
    router_z_loss_coef=0.01,
    # router_aux_loss_coef is not passed, so whatever default the model class
    # defines applies; the training log still reports a non-zero router_aux_loss.
    vocab_size=tokenizer.vocab_size,  # vocabulary size is taken from the tokenizer
)

In [7]:
# Display the model's `num_parameters` attribute (12692257 in the recorded run).
model.num_parameters

12692257

In [8]:
# Single-epoch, CPU-only training setup with periodic logging and evaluation.
trainer = Trainer(
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],  # the "test" split is not used by the trainer
    epochs=1,
    logging_steps=10,  # the recorded output shows one log line every 10 steps
    eval_steps=50,  # the recorded output shows an <<< EVAL >>> line every 50 steps
    warmup_steps=50,  # the recorded lr rises until step 50 and decays afterwards
    per_device_train_batch_size=32,
    # per_device_eval_batch_size is not passed, so the Trainer's default applies.
    use_cpu=True,
    # compute_metrics is not passed; no `compute_metrics` function is defined
    # in the visible cells.
)

In [9]:
# Check which device the trainer selected; CPU is expected because use_cpu=True.
trainer.device

device(type='cpu')

In [10]:
# Start training.
# NOTE: the recorded output below ends in a KeyboardInterrupt after step 1600
# of 2087, so the logged losses come from an incomplete run; re-run to
# completion before relying on final metrics.
trainer.train()

  0%|          | 0/2087 [00:00<?, ?it/s]

step 10   | loss: 3.0536 | lm_loss: 3.0292 | router_z_loss: 0.0223 | router_aux_loss: 0.0022 | lr: 0.000080
step 20   | loss: 2.8129 | lm_loss: 2.7974 | router_z_loss: 0.0134 | router_aux_loss: 0.0021 | lr: 0.000160
step 30   | loss: 2.6459 | lm_loss: 2.6383 | router_z_loss: 0.0056 | router_aux_loss: 0.0019 | lr: 0.000240
step 40   | loss: 2.4792 | lm_loss: 2.4734 | router_z_loss: 0.0039 | router_aux_loss: 0.0019 | lr: 0.000320
step 50   | loss: 2.2892 | lm_loss: 2.2847 | router_z_loss: 0.0026 | router_aux_loss: 0.0019 | lr: 0.000400


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 2.2572
step 60   | loss: 2.1936 | lm_loss: 2.1902 | router_z_loss: 0.0018 | router_aux_loss: 0.0017 | lr: 0.000398
step 70   | loss: 2.0466 | lm_loss: 2.0437 | router_z_loss: 0.0014 | router_aux_loss: 0.0016 | lr: 0.000396
step 80   | loss: 2.1281 | lm_loss: 2.1253 | router_z_loss: 0.0012 | router_aux_loss: 0.0016 | lr: 0.000394
step 90   | loss: 2.0680 | lm_loss: 2.0652 | router_z_loss: 0.0012 | router_aux_loss: 0.0016 | lr: 0.000392
step 100  | loss: 2.0606 | lm_loss: 2.0583 | router_z_loss: 0.0009 | router_aux_loss: 0.0014 | lr: 0.000390


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 2.0198
step 110  | loss: 1.9348 | lm_loss: 1.9323 | router_z_loss: 0.0012 | router_aux_loss: 0.0013 | lr: 0.000388
step 120  | loss: 1.9943 | lm_loss: 1.9921 | router_z_loss: 0.0009 | router_aux_loss: 0.0013 | lr: 0.000386
step 130  | loss: 2.0311 | lm_loss: 2.0289 | router_z_loss: 0.0009 | router_aux_loss: 0.0013 | lr: 0.000384
step 140  | loss: 1.9543 | lm_loss: 1.9521 | router_z_loss: 0.0008 | router_aux_loss: 0.0014 | lr: 0.000382
step 150  | loss: 1.9783 | lm_loss: 1.9763 | router_z_loss: 0.0007 | router_aux_loss: 0.0013 | lr: 0.000380


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.9263
step 160  | loss: 2.0417 | lm_loss: 2.0399 | router_z_loss: 0.0006 | router_aux_loss: 0.0012 | lr: 0.000378
step 170  | loss: 1.9179 | lm_loss: 1.9161 | router_z_loss: 0.0006 | router_aux_loss: 0.0012 | lr: 0.000376
step 180  | loss: 1.9228 | lm_loss: 1.9208 | router_z_loss: 0.0007 | router_aux_loss: 0.0012 | lr: 0.000374
step 190  | loss: 1.8958 | lm_loss: 1.8940 | router_z_loss: 0.0006 | router_aux_loss: 0.0012 | lr: 0.000373
step 200  | loss: 1.8689 | lm_loss: 1.8672 | router_z_loss: 0.0005 | router_aux_loss: 0.0011 | lr: 0.000371


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.8689
step 210  | loss: 1.8817 | lm_loss: 1.8800 | router_z_loss: 0.0006 | router_aux_loss: 0.0011 | lr: 0.000369
step 220  | loss: 1.8877 | lm_loss: 1.8862 | router_z_loss: 0.0005 | router_aux_loss: 0.0011 | lr: 0.000367
step 230  | loss: 1.7258 | lm_loss: 1.7238 | router_z_loss: 0.0009 | router_aux_loss: 0.0011 | lr: 0.000365
step 240  | loss: 1.8583 | lm_loss: 1.8568 | router_z_loss: 0.0004 | router_aux_loss: 0.0011 | lr: 0.000363
step 250  | loss: 1.8388 | lm_loss: 1.8372 | router_z_loss: 0.0005 | router_aux_loss: 0.0011 | lr: 0.000361


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.7606
step 260  | loss: 1.7136 | lm_loss: 1.7119 | router_z_loss: 0.0007 | router_aux_loss: 0.0011 | lr: 0.000359
step 270  | loss: 1.7217 | lm_loss: 1.7201 | router_z_loss: 0.0006 | router_aux_loss: 0.0011 | lr: 0.000357
step 280  | loss: 1.5930 | lm_loss: 1.5913 | router_z_loss: 0.0006 | router_aux_loss: 0.0011 | lr: 0.000355
step 290  | loss: 1.6847 | lm_loss: 1.6833 | router_z_loss: 0.0005 | router_aux_loss: 0.0010 | lr: 0.000353
step 300  | loss: 1.6752 | lm_loss: 1.6737 | router_z_loss: 0.0005 | router_aux_loss: 0.0010 | lr: 0.000351


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.6471
step 310  | loss: 1.5373 | lm_loss: 1.5358 | router_z_loss: 0.0005 | router_aux_loss: 0.0010 | lr: 0.000349
step 320  | loss: 1.6273 | lm_loss: 1.6259 | router_z_loss: 0.0004 | router_aux_loss: 0.0010 | lr: 0.000347
step 330  | loss: 1.6251 | lm_loss: 1.6236 | router_z_loss: 0.0005 | router_aux_loss: 0.0010 | lr: 0.000345
step 340  | loss: 1.5961 | lm_loss: 1.5948 | router_z_loss: 0.0004 | router_aux_loss: 0.0009 | lr: 0.000343
step 350  | loss: 1.5413 | lm_loss: 1.5399 | router_z_loss: 0.0004 | router_aux_loss: 0.0009 | lr: 0.000341


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.5161
step 360  | loss: 1.4912 | lm_loss: 1.4897 | router_z_loss: 0.0005 | router_aux_loss: 0.0010 | lr: 0.000339
step 370  | loss: 1.4821 | lm_loss: 1.4806 | router_z_loss: 0.0005 | router_aux_loss: 0.0010 | lr: 0.000337
step 380  | loss: 1.4715 | lm_loss: 1.4701 | router_z_loss: 0.0004 | router_aux_loss: 0.0009 | lr: 0.000335
step 390  | loss: 1.5111 | lm_loss: 1.5097 | router_z_loss: 0.0005 | router_aux_loss: 0.0009 | lr: 0.000333
step 400  | loss: 1.3675 | lm_loss: 1.3661 | router_z_loss: 0.0005 | router_aux_loss: 0.0009 | lr: 0.000331


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.3839
step 410  | loss: 1.2987 | lm_loss: 1.2974 | router_z_loss: 0.0004 | router_aux_loss: 0.0009 | lr: 0.000329
step 420  | loss: 1.3694 | lm_loss: 1.3680 | router_z_loss: 0.0005 | router_aux_loss: 0.0009 | lr: 0.000327
step 430  | loss: 1.3263 | lm_loss: 1.3250 | router_z_loss: 0.0004 | router_aux_loss: 0.0009 | lr: 0.000325
step 440  | loss: 1.2712 | lm_loss: 1.2700 | router_z_loss: 0.0004 | router_aux_loss: 0.0008 | lr: 0.000323
step 450  | loss: 1.3931 | lm_loss: 1.3919 | router_z_loss: 0.0003 | router_aux_loss: 0.0008 | lr: 0.000321


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.2469
step 460  | loss: 1.2199 | lm_loss: 1.2186 | router_z_loss: 0.0005 | router_aux_loss: 0.0008 | lr: 0.000319
step 470  | loss: 1.2565 | lm_loss: 1.2554 | router_z_loss: 0.0004 | router_aux_loss: 0.0007 | lr: 0.000318
step 480  | loss: 1.2974 | lm_loss: 1.2963 | router_z_loss: 0.0004 | router_aux_loss: 0.0007 | lr: 0.000316
step 490  | loss: 1.2020 | lm_loss: 1.2009 | router_z_loss: 0.0004 | router_aux_loss: 0.0007 | lr: 0.000314
step 500  | loss: 1.0745 | lm_loss: 1.0734 | router_z_loss: 0.0003 | router_aux_loss: 0.0008 | lr: 0.000312


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.1442
step 510  | loss: 1.2005 | lm_loss: 1.1995 | router_z_loss: 0.0003 | router_aux_loss: 0.0007 | lr: 0.000310
step 520  | loss: 1.0898 | lm_loss: 1.0887 | router_z_loss: 0.0004 | router_aux_loss: 0.0007 | lr: 0.000308
step 530  | loss: 1.0098 | lm_loss: 1.0087 | router_z_loss: 0.0004 | router_aux_loss: 0.0007 | lr: 0.000306
step 540  | loss: 1.0514 | lm_loss: 1.0504 | router_z_loss: 0.0003 | router_aux_loss: 0.0007 | lr: 0.000304
step 550  | loss: 1.0452 | lm_loss: 1.0441 | router_z_loss: 0.0003 | router_aux_loss: 0.0007 | lr: 0.000302


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.0375
step 560  | loss: 1.0334 | lm_loss: 1.0324 | router_z_loss: 0.0003 | router_aux_loss: 0.0007 | lr: 0.000300
step 570  | loss: 0.9741 | lm_loss: 0.9731 | router_z_loss: 0.0003 | router_aux_loss: 0.0007 | lr: 0.000298
step 580  | loss: 0.8857 | lm_loss: 0.8847 | router_z_loss: 0.0004 | router_aux_loss: 0.0007 | lr: 0.000296
step 590  | loss: 0.8963 | lm_loss: 0.8954 | router_z_loss: 0.0003 | router_aux_loss: 0.0006 | lr: 0.000294
step 600  | loss: 1.0102 | lm_loss: 1.0092 | router_z_loss: 0.0003 | router_aux_loss: 0.0006 | lr: 0.000292


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.9410
step 610  | loss: 0.8983 | lm_loss: 0.8974 | router_z_loss: 0.0003 | router_aux_loss: 0.0006 | lr: 0.000290
step 620  | loss: 0.9484 | lm_loss: 0.9475 | router_z_loss: 0.0003 | router_aux_loss: 0.0006 | lr: 0.000288
step 630  | loss: 0.9542 | lm_loss: 0.9534 | router_z_loss: 0.0003 | router_aux_loss: 0.0006 | lr: 0.000286
step 640  | loss: 0.8591 | lm_loss: 0.8583 | router_z_loss: 0.0002 | router_aux_loss: 0.0006 | lr: 0.000284
step 650  | loss: 0.8702 | lm_loss: 0.8693 | router_z_loss: 0.0003 | router_aux_loss: 0.0006 | lr: 0.000282


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.8686
step 660  | loss: 0.8906 | lm_loss: 0.8898 | router_z_loss: 0.0002 | router_aux_loss: 0.0006 | lr: 0.000280
step 670  | loss: 0.7473 | lm_loss: 0.7465 | router_z_loss: 0.0002 | router_aux_loss: 0.0006 | lr: 0.000278
step 680  | loss: 0.8081 | lm_loss: 0.8074 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000276
step 690  | loss: 0.8927 | lm_loss: 0.8919 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000274
step 700  | loss: 0.7832 | lm_loss: 0.7824 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000272


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.8411
step 710  | loss: 0.8310 | lm_loss: 0.8303 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000270
step 720  | loss: 0.8526 | lm_loss: 0.8519 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000268
step 730  | loss: 0.7445 | lm_loss: 0.7439 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000266
step 740  | loss: 0.8015 | lm_loss: 0.8009 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000265
step 750  | loss: 0.7679 | lm_loss: 0.7672 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000263


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.7626
step 760  | loss: 0.7632 | lm_loss: 0.7625 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000261
step 770  | loss: 0.8589 | lm_loss: 0.8583 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000259
step 780  | loss: 0.7840 | lm_loss: 0.7833 | router_z_loss: 0.0002 | router_aux_loss: 0.0005 | lr: 0.000257
step 790  | loss: 0.6988 | lm_loss: 0.6981 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000255
step 800  | loss: 0.7631 | lm_loss: 0.7625 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000253


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.7367
step 810  | loss: 0.7945 | lm_loss: 0.7939 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000251
step 820  | loss: 0.7315 | lm_loss: 0.7309 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000249
step 830  | loss: 0.7177 | lm_loss: 0.7170 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000247
step 840  | loss: 0.7881 | lm_loss: 0.7875 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000245
step 850  | loss: 0.6879 | lm_loss: 0.6873 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000243


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.7112
step 860  | loss: 0.6358 | lm_loss: 0.6352 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000241
step 870  | loss: 0.6643 | lm_loss: 0.6637 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000239
step 880  | loss: 0.7159 | lm_loss: 0.7153 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000237
step 890  | loss: 0.6668 | lm_loss: 0.6663 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000235
step 900  | loss: 0.7259 | lm_loss: 0.7254 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000233


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.6755
step 910  | loss: 0.6078 | lm_loss: 0.6072 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000231
step 920  | loss: 0.7332 | lm_loss: 0.7328 | router_z_loss: 0.0002 | router_aux_loss: 0.0003 | lr: 0.000229
step 930  | loss: 0.6398 | lm_loss: 0.6394 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000227
step 940  | loss: 0.6681 | lm_loss: 0.6676 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000225
step 950  | loss: 0.6065 | lm_loss: 0.6060 | router_z_loss: 0.0002 | router_aux_loss: 0.0003 | lr: 0.000223


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.6547
step 960  | loss: 0.5821 | lm_loss: 0.5816 | router_z_loss: 0.0002 | router_aux_loss: 0.0003 | lr: 0.000221
step 970  | loss: 0.6279 | lm_loss: 0.6274 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000219
step 980  | loss: 0.5765 | lm_loss: 0.5761 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000217
step 990  | loss: 0.5916 | lm_loss: 0.5911 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000215
step 1000 | loss: 0.5391 | lm_loss: 0.5385 | router_z_loss: 0.0002 | router_aux_loss: 0.0004 | lr: 0.000213


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.6166
step 1010 | loss: 0.7458 | lm_loss: 0.7453 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000211
step 1020 | loss: 0.6217 | lm_loss: 0.6212 | router_z_loss: 0.0002 | router_aux_loss: 0.0003 | lr: 0.000210
step 1030 | loss: 0.6857 | lm_loss: 0.6852 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000208
step 1040 | loss: 0.6114 | lm_loss: 0.6109 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000206
step 1050 | loss: 0.6508 | lm_loss: 0.6504 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000204


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.6060
step 1060 | loss: 0.6093 | lm_loss: 0.6089 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000202
step 1070 | loss: 0.6084 | lm_loss: 0.6079 | router_z_loss: 0.0002 | router_aux_loss: 0.0003 | lr: 0.000200
step 1080 | loss: 0.5897 | lm_loss: 0.5893 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000198
step 1090 | loss: 0.5534 | lm_loss: 0.5530 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000196
step 1100 | loss: 0.5950 | lm_loss: 0.5946 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000194


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5926
step 1110 | loss: 0.6147 | lm_loss: 0.6142 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000192
step 1120 | loss: 0.5588 | lm_loss: 0.5584 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000190
step 1130 | loss: 0.5464 | lm_loss: 0.5460 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000188
step 1140 | loss: 0.6081 | lm_loss: 0.6077 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000186
step 1150 | loss: 0.5565 | lm_loss: 0.5561 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000184


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5646
step 1160 | loss: 0.6685 | lm_loss: 0.6681 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000182
step 1170 | loss: 0.5713 | lm_loss: 0.5709 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000180
step 1180 | loss: 0.5139 | lm_loss: 0.5135 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000178
step 1190 | loss: 0.5018 | lm_loss: 0.5013 | router_z_loss: 0.0002 | router_aux_loss: 0.0003 | lr: 0.000176
step 1200 | loss: 0.5785 | lm_loss: 0.5781 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000174


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5634
step 1210 | loss: 0.5025 | lm_loss: 0.5021 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000172
step 1220 | loss: 0.5544 | lm_loss: 0.5540 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000170
step 1230 | loss: 0.5392 | lm_loss: 0.5389 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000168
step 1240 | loss: 0.5098 | lm_loss: 0.5095 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000166
step 1250 | loss: 0.5462 | lm_loss: 0.5458 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000164


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5296
step 1260 | loss: 0.5786 | lm_loss: 0.5782 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000162
step 1270 | loss: 0.5145 | lm_loss: 0.5141 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000160
step 1280 | loss: 0.4265 | lm_loss: 0.4262 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000158
step 1290 | loss: 0.4655 | lm_loss: 0.4652 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000157
step 1300 | loss: 0.5542 | lm_loss: 0.5538 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000155


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5303
step 1310 | loss: 0.4460 | lm_loss: 0.4457 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000153
step 1320 | loss: 0.4598 | lm_loss: 0.4595 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000151
step 1330 | loss: 0.5490 | lm_loss: 0.5487 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000149
step 1340 | loss: 0.5337 | lm_loss: 0.5333 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000147
step 1350 | loss: 0.5546 | lm_loss: 0.5542 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000145


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5086
step 1360 | loss: 0.4608 | lm_loss: 0.4604 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000143
step 1370 | loss: 0.4373 | lm_loss: 0.4369 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000141
step 1380 | loss: 0.4945 | lm_loss: 0.4942 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000139
step 1390 | loss: 0.4415 | lm_loss: 0.4412 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000137
step 1400 | loss: 0.5207 | lm_loss: 0.5204 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000135


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5101
step 1410 | loss: 0.4903 | lm_loss: 0.4899 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000133
step 1420 | loss: 0.3747 | lm_loss: 0.3743 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000131
step 1430 | loss: 0.5029 | lm_loss: 0.5026 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000129
step 1440 | loss: 0.5878 | lm_loss: 0.5875 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000127
step 1450 | loss: 0.5326 | lm_loss: 0.5323 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000125


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4885
step 1460 | loss: 0.5024 | lm_loss: 0.5021 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000123
step 1470 | loss: 0.4735 | lm_loss: 0.4732 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000121
step 1480 | loss: 0.4241 | lm_loss: 0.4237 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000119
step 1490 | loss: 0.5409 | lm_loss: 0.5406 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000117
step 1500 | loss: 0.6167 | lm_loss: 0.6163 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000115


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4712
step 1510 | loss: 0.5165 | lm_loss: 0.5162 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000113
step 1520 | loss: 0.4622 | lm_loss: 0.4618 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000111
step 1530 | loss: 0.4397 | lm_loss: 0.4394 | router_z_loss: 0.0001 | router_aux_loss: 0.0003 | lr: 0.000109
step 1540 | loss: 0.4269 | lm_loss: 0.4266 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000107
step 1550 | loss: 0.4106 | lm_loss: 0.4103 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000105


Evaluating: : 0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4748
step 1560 | loss: 0.4558 | lm_loss: 0.4555 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000103
step 1570 | loss: 0.5932 | lm_loss: 0.5929 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000102
step 1580 | loss: 0.5024 | lm_loss: 0.5020 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000100
step 1590 | loss: 0.4417 | lm_loss: 0.4413 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000098
step 1600 | loss: 0.5512 | lm_loss: 0.5508 | router_z_loss: 0.0001 | router_aux_loss: 0.0002 | lr: 0.000096


Evaluating: : 0it [00:00, ?it/s]

KeyboardInterrupt: 