In [3]:

from balm.data import load_dataset, DataCollator
from balm.models.balm import BalmForMaskedLM
from balm.models.balm_moe import BalmMoEForMaskedLM
from balm.models.balm_moe_rope import BalmMoERoPEForMaskedLM
from balm.tokenizer import Tokenizer
from balm.training.trainer import Trainer


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [4]:
# Vocabulary file for the tokenizer; relative to the notebook's working directory.
VOCAB_FILE = "./vocab.json"
tokenizer = Tokenizer(vocab=VOCAB_FILE)

In [5]:
def remove_sep(txt):
    """Replace every ``</s>`` in ``txt`` with ``<cls><cls>``.

    Despite the name, nothing is removed: each separator occurrence is
    substituted, left to right and non-overlapping. Passed as the
    ``preprocess_fn`` argument of ``load_dataset`` in this cell.

    Args:
        txt: raw text line.

    Returns:
        The text with all ``</s>`` occurrences substituted.
    """
    sep_token, replacement = "</s>", "<cls><cls>"
    pieces = txt.split(sep_token)
    return replacement.join(pieces)


# NOTE(review): the "test" and "eval" splits read the same file — confirm this
# is intended (fine for a smoke test, misleading for real evaluation).
HELDOUT_FILE = "./balm/test_data/test_1k.txt"

data_files = {
    "train": "./balm/test_data/test.txt",
    "test": HELDOUT_FILE,
    "eval": HELDOUT_FILE,
}

dataset = load_dataset(
    "text",
    data_files=data_files,
    preprocess_fn=remove_sep,
)

In [6]:
MAX_SEQ_LEN = 320


def tokenize_example(example):
    """Tokenize the ``"text"`` field of one dataset record (truncated to MAX_SEQ_LEN)."""
    return tokenizer(
        example["text"],
        padding=True,
        truncation=True,
        max_length=MAX_SEQ_LEN,
    )


# The raw "text" column is dropped once it has been tokenized.
tokenized_dataset = dataset.map(tokenize_example, remove_columns="text")

  0%|          | 0/66792 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [7]:
# Batch collator; its output is a dict with `input_ids`, `labels` (-100 at
# positions that carry no target) and `key_padding_mask` — see the `collated`
# inspection cell further down.
collator = DataCollator(tokenizer=tokenizer)

In [8]:
# Seed torch's global RNG so weight initialisation is reproducible on
# Restart & Run All. Later torch-based randomness (shuffling, and presumably
# the collator's masking — TODO confirm it uses torch's RNG) follows from it
# as long as cells run in order.
SEED = 42
torch.manual_seed(SEED)

# BalmMoERoPEForMaskedLM (imported above) takes the place of this class to
# compare model variants.
model = BalmMoEForMaskedLM(
    embed_dim=256,
    ffn_dim=1024,
    num_experts=4,
    num_layers=8,
    num_heads=8,
    expert_capacity=128,
    router_z_loss_coef=0.01,
    router_aux_loss_coef=0.01,
    vocab_size=tokenizer.vocab_size,
)


In [9]:
# Total parameter count; accessed as an attribute (the displayed value is an int).
model.num_parameters

18982689

## Trainer TODOs
- [x] train/eval dataloader
- [x] move everything to the correct device (cuda or cpu)
- [x] learning rate (including scheduler)
- [ ] gradient accumulation steps
- [ ] logging
- [ ] checkpointing
- [ ] distributed training
- [ ] gradient clipping



In [10]:
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    matthews_corrcoef,
    precision_recall_fscore_support,
)


def compute_metrics(eval_predictions):
    """Compute evaluation metrics for masked-LM predictions.

    Currently only perplexity is reported: exp of the mean token-level
    cross-entropy between the model's logits and the labels.

    Args:
        eval_predictions: object exposing a ``labels`` attribute and,
            optionally, a ``logits`` attribute (tensor or array-like).
            ``logits`` is assumed to be (..., vocab_size) with ``labels``
            matching the leading dims, e.g. (batch, seq_len) — TODO confirm
            against what the Trainer passes. Positions labelled -100 (as in
            the collator's ``labels`` output) are ignored.

    Returns:
        dict with key ``"perplexity"`` (float). The value is NaN when no
        logits are available or when every label is ignored.
    """
    logits = getattr(eval_predictions, "logits", None)
    if logits is None:
        # Perplexity is undefined without logits; report NaN instead of raising.
        return {"perplexity": float("nan")}

    logits = torch.as_tensor(logits)
    labels = torch.as_tensor(eval_predictions.labels)

    # F.cross_entropy applies log_softmax internally, so it must receive raw
    # logits — applying softmax first would normalise twice and skew the value.
    # Flatten to (N, vocab) / (N,) so the class axis is where cross_entropy
    # expects it regardless of how many leading (batch, seq, ...) dims exist.
    cross_entropy = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)).float(),
        labels.reshape(-1).long(),
        ignore_index=-100,
    )
    perplexity = torch.exp(cross_entropy)

    return {"perplexity": perplexity.item()}

In [11]:
trainer = Trainer(
    # model and data
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    # schedule: one epoch, log every 10 steps, evaluate every 50, 50 warmup steps
    epochs=1,
    logging_steps=10,
    eval_steps=50,
    warmup_steps=50,
    # batching / hardware
    per_device_train_batch_size=32,
    use_cpu=True,
    # currently disabled options:
    # per_device_eval_batch_size=32,
    # compute_metrics=compute_metrics,
)



In [14]:
# Device the Trainer resolved to; expected to be CPU given use_cpu=True above.
trainer.device

device(type='cpu')

In [15]:
# Run training with the configuration above; the log below shows a line every
# 10 steps and an eval loss every 50 steps.
trainer.train()

  0%|          | 0/2087 [00:00<?, ?it/s]

step 10   | loss: 3.0613 | lm_loss: 3.0290 | router_z_loss: 0.0220 | router_aux_loss: 0.0102 | lr: 0.000080
step 20   | loss: 2.7867 | lm_loss: 2.7573 | router_z_loss: 0.0194 | router_aux_loss: 0.0101 | lr: 0.000160
step 30   | loss: 2.6294 | lm_loss: 2.6043 | router_z_loss: 0.0150 | router_aux_loss: 0.0101 | lr: 0.000240
step 40   | loss: 2.5059 | lm_loss: 2.4840 | router_z_loss: 0.0117 | router_aux_loss: 0.0101 | lr: 0.000320
step 50   | loss: 2.3115 | lm_loss: 2.2930 | router_z_loss: 0.0084 | router_aux_loss: 0.0101 | lr: 0.000400


0it [00:00, ?it/s]

<<< EVAL >>> loss: 2.2903
step 60   | loss: 2.2155 | lm_loss: 2.1983 | router_z_loss: 0.0070 | router_aux_loss: 0.0101 | lr: 0.000398
step 70   | loss: 2.0435 | lm_loss: 2.0280 | router_z_loss: 0.0055 | router_aux_loss: 0.0100 | lr: 0.000396
step 80   | loss: 2.1350 | lm_loss: 2.1198 | router_z_loss: 0.0053 | router_aux_loss: 0.0099 | lr: 0.000394
step 90   | loss: 2.0281 | lm_loss: 2.0132 | router_z_loss: 0.0049 | router_aux_loss: 0.0099 | lr: 0.000392
step 100  | loss: 2.0571 | lm_loss: 2.0427 | router_z_loss: 0.0045 | router_aux_loss: 0.0099 | lr: 0.000390


0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.9878
step 110  | loss: 1.9012 | lm_loss: 1.8876 | router_z_loss: 0.0038 | router_aux_loss: 0.0098 | lr: 0.000388
step 120  | loss: 1.9250 | lm_loss: 1.9112 | router_z_loss: 0.0040 | router_aux_loss: 0.0098 | lr: 0.000386
step 130  | loss: 1.9637 | lm_loss: 1.9498 | router_z_loss: 0.0040 | router_aux_loss: 0.0098 | lr: 0.000384
step 140  | loss: 1.9374 | lm_loss: 1.9240 | router_z_loss: 0.0036 | router_aux_loss: 0.0098 | lr: 0.000382
step 150  | loss: 1.9180 | lm_loss: 1.9046 | router_z_loss: 0.0036 | router_aux_loss: 0.0098 | lr: 0.000380


0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.8761
step 160  | loss: 1.9958 | lm_loss: 1.9828 | router_z_loss: 0.0032 | router_aux_loss: 0.0098 | lr: 0.000378
step 170  | loss: 1.7992 | lm_loss: 1.7865 | router_z_loss: 0.0029 | router_aux_loss: 0.0098 | lr: 0.000376
step 180  | loss: 1.8263 | lm_loss: 1.8128 | router_z_loss: 0.0037 | router_aux_loss: 0.0099 | lr: 0.000374
step 190  | loss: 1.7817 | lm_loss: 1.7687 | router_z_loss: 0.0031 | router_aux_loss: 0.0098 | lr: 0.000373
step 200  | loss: 1.7914 | lm_loss: 1.7785 | router_z_loss: 0.0031 | router_aux_loss: 0.0098 | lr: 0.000371


0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.7428
step 210  | loss: 1.7795 | lm_loss: 1.7665 | router_z_loss: 0.0032 | router_aux_loss: 0.0098 | lr: 0.000369
step 220  | loss: 1.7894 | lm_loss: 1.7765 | router_z_loss: 0.0030 | router_aux_loss: 0.0099 | lr: 0.000367
step 230  | loss: 1.6386 | lm_loss: 1.6256 | router_z_loss: 0.0032 | router_aux_loss: 0.0099 | lr: 0.000365
step 240  | loss: 1.7444 | lm_loss: 1.7316 | router_z_loss: 0.0030 | router_aux_loss: 0.0098 | lr: 0.000363
step 250  | loss: 1.7312 | lm_loss: 1.7175 | router_z_loss: 0.0038 | router_aux_loss: 0.0099 | lr: 0.000361


0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.6630
step 260  | loss: 1.6239 | lm_loss: 1.6107 | router_z_loss: 0.0033 | router_aux_loss: 0.0099 | lr: 0.000359
step 270  | loss: 1.6648 | lm_loss: 1.6513 | router_z_loss: 0.0036 | router_aux_loss: 0.0099 | lr: 0.000357
step 280  | loss: 1.5151 | lm_loss: 1.5018 | router_z_loss: 0.0036 | router_aux_loss: 0.0098 | lr: 0.000355
step 290  | loss: 1.6363 | lm_loss: 1.6228 | router_z_loss: 0.0035 | router_aux_loss: 0.0099 | lr: 0.000353
step 300  | loss: 1.6073 | lm_loss: 1.5936 | router_z_loss: 0.0038 | router_aux_loss: 0.0099 | lr: 0.000351


0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.5532
step 310  | loss: 1.4659 | lm_loss: 1.4525 | router_z_loss: 0.0036 | router_aux_loss: 0.0099 | lr: 0.000349
step 320  | loss: 1.5855 | lm_loss: 1.5714 | router_z_loss: 0.0041 | router_aux_loss: 0.0100 | lr: 0.000347
step 330  | loss: 1.4525 | lm_loss: 1.4392 | router_z_loss: 0.0034 | router_aux_loss: 0.0099 | lr: 0.000345
step 340  | loss: 1.5250 | lm_loss: 1.5117 | router_z_loss: 0.0035 | router_aux_loss: 0.0099 | lr: 0.000343
step 350  | loss: 1.4944 | lm_loss: 1.4809 | router_z_loss: 0.0036 | router_aux_loss: 0.0098 | lr: 0.000341


0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.4352
step 360  | loss: 1.4269 | lm_loss: 1.4137 | router_z_loss: 0.0034 | router_aux_loss: 0.0098 | lr: 0.000339
step 370  | loss: 1.3977 | lm_loss: 1.3842 | router_z_loss: 0.0036 | router_aux_loss: 0.0099 | lr: 0.000337
step 380  | loss: 1.3770 | lm_loss: 1.3638 | router_z_loss: 0.0033 | router_aux_loss: 0.0099 | lr: 0.000335
step 390  | loss: 1.3644 | lm_loss: 1.3512 | router_z_loss: 0.0033 | router_aux_loss: 0.0099 | lr: 0.000333
step 400  | loss: 1.3154 | lm_loss: 1.3017 | router_z_loss: 0.0038 | router_aux_loss: 0.0099 | lr: 0.000331


0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.2986
step 410  | loss: 1.2645 | lm_loss: 1.2511 | router_z_loss: 0.0035 | router_aux_loss: 0.0099 | lr: 0.000329
step 420  | loss: 1.2789 | lm_loss: 1.2653 | router_z_loss: 0.0037 | router_aux_loss: 0.0099 | lr: 0.000327
step 430  | loss: 1.2741 | lm_loss: 1.2606 | router_z_loss: 0.0036 | router_aux_loss: 0.0099 | lr: 0.000325
step 440  | loss: 1.2213 | lm_loss: 1.2074 | router_z_loss: 0.0040 | router_aux_loss: 0.0099 | lr: 0.000323
step 450  | loss: 1.3690 | lm_loss: 1.3557 | router_z_loss: 0.0034 | router_aux_loss: 0.0099 | lr: 0.000321


0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.2049
step 460  | loss: 1.1806 | lm_loss: 1.1670 | router_z_loss: 0.0037 | router_aux_loss: 0.0099 | lr: 0.000319
step 470  | loss: 1.1628 | lm_loss: 1.1495 | router_z_loss: 0.0034 | router_aux_loss: 0.0099 | lr: 0.000318
step 480  | loss: 1.2746 | lm_loss: 1.2612 | router_z_loss: 0.0035 | router_aux_loss: 0.0099 | lr: 0.000316
step 490  | loss: 1.0425 | lm_loss: 1.0294 | router_z_loss: 0.0033 | router_aux_loss: 0.0099 | lr: 0.000314
step 500  | loss: 1.1179 | lm_loss: 1.1047 | router_z_loss: 0.0032 | router_aux_loss: 0.0099 | lr: 0.000312


0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.0961
step 510  | loss: 1.0947 | lm_loss: 1.0810 | router_z_loss: 0.0038 | router_aux_loss: 0.0099 | lr: 0.000310
step 520  | loss: 1.0740 | lm_loss: 1.0609 | router_z_loss: 0.0031 | router_aux_loss: 0.0099 | lr: 0.000308
step 530  | loss: 0.9425 | lm_loss: 0.9291 | router_z_loss: 0.0035 | router_aux_loss: 0.0099 | lr: 0.000306
step 540  | loss: 1.0273 | lm_loss: 1.0129 | router_z_loss: 0.0045 | router_aux_loss: 0.0099 | lr: 0.000304
step 550  | loss: 1.0177 | lm_loss: 1.0044 | router_z_loss: 0.0034 | router_aux_loss: 0.0099 | lr: 0.000302


0it [00:00, ?it/s]

<<< EVAL >>> loss: 1.0067
step 560  | loss: 0.9703 | lm_loss: 0.9567 | router_z_loss: 0.0036 | router_aux_loss: 0.0100 | lr: 0.000300
step 570  | loss: 0.9480 | lm_loss: 0.9346 | router_z_loss: 0.0034 | router_aux_loss: 0.0099 | lr: 0.000298
step 580  | loss: 0.8762 | lm_loss: 0.8629 | router_z_loss: 0.0034 | router_aux_loss: 0.0100 | lr: 0.000296
step 590  | loss: 0.8394 | lm_loss: 0.8258 | router_z_loss: 0.0036 | router_aux_loss: 0.0100 | lr: 0.000294
step 600  | loss: 0.9707 | lm_loss: 0.9572 | router_z_loss: 0.0035 | router_aux_loss: 0.0099 | lr: 0.000292


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.9199
step 610  | loss: 0.8359 | lm_loss: 0.8224 | router_z_loss: 0.0035 | router_aux_loss: 0.0099 | lr: 0.000290
step 620  | loss: 0.9534 | lm_loss: 0.9405 | router_z_loss: 0.0030 | router_aux_loss: 0.0099 | lr: 0.000288
step 630  | loss: 0.8812 | lm_loss: 0.8680 | router_z_loss: 0.0033 | router_aux_loss: 0.0099 | lr: 0.000286
step 640  | loss: 0.8688 | lm_loss: 0.8561 | router_z_loss: 0.0027 | router_aux_loss: 0.0099 | lr: 0.000284
step 650  | loss: 0.8130 | lm_loss: 0.8005 | router_z_loss: 0.0026 | router_aux_loss: 0.0099 | lr: 0.000282


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.8584
step 660  | loss: 0.8435 | lm_loss: 0.8302 | router_z_loss: 0.0034 | router_aux_loss: 0.0099 | lr: 0.000280
step 670  | loss: 0.7488 | lm_loss: 0.7359 | router_z_loss: 0.0031 | router_aux_loss: 0.0099 | lr: 0.000278
step 680  | loss: 0.7541 | lm_loss: 0.7413 | router_z_loss: 0.0029 | router_aux_loss: 0.0099 | lr: 0.000276
step 690  | loss: 0.8484 | lm_loss: 0.8353 | router_z_loss: 0.0032 | router_aux_loss: 0.0099 | lr: 0.000274
step 700  | loss: 0.7816 | lm_loss: 0.7686 | router_z_loss: 0.0031 | router_aux_loss: 0.0099 | lr: 0.000272


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.8177
step 710  | loss: 0.7874 | lm_loss: 0.7747 | router_z_loss: 0.0028 | router_aux_loss: 0.0099 | lr: 0.000270
step 720  | loss: 0.7785 | lm_loss: 0.7656 | router_z_loss: 0.0029 | router_aux_loss: 0.0099 | lr: 0.000268
step 730  | loss: 0.7181 | lm_loss: 0.7053 | router_z_loss: 0.0029 | router_aux_loss: 0.0099 | lr: 0.000266
step 740  | loss: 0.8211 | lm_loss: 0.8085 | router_z_loss: 0.0027 | router_aux_loss: 0.0099 | lr: 0.000265
step 750  | loss: 0.6736 | lm_loss: 0.6608 | router_z_loss: 0.0030 | router_aux_loss: 0.0099 | lr: 0.000263


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.7525
step 760  | loss: 0.7305 | lm_loss: 0.7177 | router_z_loss: 0.0029 | router_aux_loss: 0.0099 | lr: 0.000261
step 770  | loss: 0.8149 | lm_loss: 0.8021 | router_z_loss: 0.0029 | router_aux_loss: 0.0099 | lr: 0.000259
step 780  | loss: 0.7337 | lm_loss: 0.7208 | router_z_loss: 0.0029 | router_aux_loss: 0.0100 | lr: 0.000257
step 790  | loss: 0.7227 | lm_loss: 0.7102 | router_z_loss: 0.0026 | router_aux_loss: 0.0099 | lr: 0.000255
step 800  | loss: 0.8370 | lm_loss: 0.8246 | router_z_loss: 0.0025 | router_aux_loss: 0.0099 | lr: 0.000253


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.7020
step 810  | loss: 0.8217 | lm_loss: 0.8088 | router_z_loss: 0.0029 | router_aux_loss: 0.0100 | lr: 0.000251
step 820  | loss: 0.7038 | lm_loss: 0.6909 | router_z_loss: 0.0029 | router_aux_loss: 0.0099 | lr: 0.000249
step 830  | loss: 0.6691 | lm_loss: 0.6567 | router_z_loss: 0.0025 | router_aux_loss: 0.0100 | lr: 0.000247
step 840  | loss: 0.6765 | lm_loss: 0.6638 | router_z_loss: 0.0027 | router_aux_loss: 0.0100 | lr: 0.000245
step 850  | loss: 0.6723 | lm_loss: 0.6597 | router_z_loss: 0.0027 | router_aux_loss: 0.0099 | lr: 0.000243


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.6741
step 860  | loss: 0.6255 | lm_loss: 0.6128 | router_z_loss: 0.0028 | router_aux_loss: 0.0100 | lr: 0.000241
step 870  | loss: 0.6101 | lm_loss: 0.5980 | router_z_loss: 0.0022 | router_aux_loss: 0.0099 | lr: 0.000239
step 880  | loss: 0.6598 | lm_loss: 0.6475 | router_z_loss: 0.0024 | router_aux_loss: 0.0099 | lr: 0.000237
step 890  | loss: 0.6582 | lm_loss: 0.6457 | router_z_loss: 0.0026 | router_aux_loss: 0.0099 | lr: 0.000235
step 900  | loss: 0.7254 | lm_loss: 0.7133 | router_z_loss: 0.0022 | router_aux_loss: 0.0099 | lr: 0.000233


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.6544
step 910  | loss: 0.5693 | lm_loss: 0.5569 | router_z_loss: 0.0025 | router_aux_loss: 0.0099 | lr: 0.000231
step 920  | loss: 0.6845 | lm_loss: 0.6724 | router_z_loss: 0.0022 | router_aux_loss: 0.0099 | lr: 0.000229
step 930  | loss: 0.5832 | lm_loss: 0.5705 | router_z_loss: 0.0027 | router_aux_loss: 0.0099 | lr: 0.000227
step 940  | loss: 0.6486 | lm_loss: 0.6359 | router_z_loss: 0.0027 | router_aux_loss: 0.0100 | lr: 0.000225
step 950  | loss: 0.5547 | lm_loss: 0.5426 | router_z_loss: 0.0023 | router_aux_loss: 0.0099 | lr: 0.000223


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.6283
step 960  | loss: 0.6195 | lm_loss: 0.6074 | router_z_loss: 0.0022 | router_aux_loss: 0.0099 | lr: 0.000221
step 970  | loss: 0.5589 | lm_loss: 0.5466 | router_z_loss: 0.0024 | router_aux_loss: 0.0099 | lr: 0.000219
step 980  | loss: 0.5480 | lm_loss: 0.5357 | router_z_loss: 0.0024 | router_aux_loss: 0.0099 | lr: 0.000217
step 990  | loss: 0.5939 | lm_loss: 0.5816 | router_z_loss: 0.0024 | router_aux_loss: 0.0099 | lr: 0.000215
step 1000 | loss: 0.5394 | lm_loss: 0.5270 | router_z_loss: 0.0025 | router_aux_loss: 0.0099 | lr: 0.000213


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.6029
step 1010 | loss: 0.7421 | lm_loss: 0.7301 | router_z_loss: 0.0020 | router_aux_loss: 0.0099 | lr: 0.000211
step 1020 | loss: 0.6450 | lm_loss: 0.6328 | router_z_loss: 0.0023 | router_aux_loss: 0.0099 | lr: 0.000210
step 1030 | loss: 0.6361 | lm_loss: 0.6243 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000208
step 1040 | loss: 0.5771 | lm_loss: 0.5653 | router_z_loss: 0.0020 | router_aux_loss: 0.0099 | lr: 0.000206
step 1050 | loss: 0.6057 | lm_loss: 0.5935 | router_z_loss: 0.0023 | router_aux_loss: 0.0099 | lr: 0.000204


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5790
step 1060 | loss: 0.5448 | lm_loss: 0.5325 | router_z_loss: 0.0023 | router_aux_loss: 0.0099 | lr: 0.000202
step 1070 | loss: 0.5303 | lm_loss: 0.5184 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000200
step 1080 | loss: 0.5491 | lm_loss: 0.5367 | router_z_loss: 0.0026 | router_aux_loss: 0.0099 | lr: 0.000198
step 1090 | loss: 0.5267 | lm_loss: 0.5145 | router_z_loss: 0.0024 | router_aux_loss: 0.0099 | lr: 0.000196
step 1100 | loss: 0.5727 | lm_loss: 0.5607 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000194


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5622
step 1110 | loss: 0.5492 | lm_loss: 0.5371 | router_z_loss: 0.0022 | router_aux_loss: 0.0099 | lr: 0.000192
step 1120 | loss: 0.5332 | lm_loss: 0.5215 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000190
step 1130 | loss: 0.4756 | lm_loss: 0.4639 | router_z_loss: 0.0019 | router_aux_loss: 0.0098 | lr: 0.000188
step 1140 | loss: 0.5621 | lm_loss: 0.5501 | router_z_loss: 0.0022 | router_aux_loss: 0.0099 | lr: 0.000186
step 1150 | loss: 0.5241 | lm_loss: 0.5117 | router_z_loss: 0.0024 | router_aux_loss: 0.0099 | lr: 0.000184


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5460
step 1160 | loss: 0.6750 | lm_loss: 0.6630 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000182
step 1170 | loss: 0.5131 | lm_loss: 0.5011 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000180
step 1180 | loss: 0.4901 | lm_loss: 0.4781 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000178
step 1190 | loss: 0.4839 | lm_loss: 0.4717 | router_z_loss: 0.0022 | router_aux_loss: 0.0099 | lr: 0.000176
step 1200 | loss: 0.5407 | lm_loss: 0.5287 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000174


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5271
step 1210 | loss: 0.4878 | lm_loss: 0.4756 | router_z_loss: 0.0022 | router_aux_loss: 0.0099 | lr: 0.000172
step 1220 | loss: 0.4988 | lm_loss: 0.4871 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000170
step 1230 | loss: 0.5150 | lm_loss: 0.5034 | router_z_loss: 0.0018 | router_aux_loss: 0.0098 | lr: 0.000168
step 1240 | loss: 0.4666 | lm_loss: 0.4550 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000166
step 1250 | loss: 0.5611 | lm_loss: 0.5491 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000164


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.5013
step 1260 | loss: 0.5193 | lm_loss: 0.5072 | router_z_loss: 0.0022 | router_aux_loss: 0.0099 | lr: 0.000162
step 1270 | loss: 0.5042 | lm_loss: 0.4921 | router_z_loss: 0.0022 | router_aux_loss: 0.0099 | lr: 0.000160
step 1280 | loss: 0.3864 | lm_loss: 0.3747 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000158
step 1290 | loss: 0.4235 | lm_loss: 0.4116 | router_z_loss: 0.0020 | router_aux_loss: 0.0099 | lr: 0.000157
step 1300 | loss: 0.5670 | lm_loss: 0.5550 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000155


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4900
step 1310 | loss: 0.4102 | lm_loss: 0.3983 | router_z_loss: 0.0020 | router_aux_loss: 0.0099 | lr: 0.000153
step 1320 | loss: 0.4398 | lm_loss: 0.4281 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000151
step 1330 | loss: 0.5274 | lm_loss: 0.5157 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000149
step 1340 | loss: 0.5686 | lm_loss: 0.5562 | router_z_loss: 0.0026 | router_aux_loss: 0.0098 | lr: 0.000147
step 1350 | loss: 0.5008 | lm_loss: 0.4888 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000145


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4874
step 1360 | loss: 0.4318 | lm_loss: 0.4201 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000143
step 1370 | loss: 0.3705 | lm_loss: 0.3588 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000141
step 1380 | loss: 0.4116 | lm_loss: 0.3994 | router_z_loss: 0.0023 | router_aux_loss: 0.0099 | lr: 0.000139
step 1390 | loss: 0.4179 | lm_loss: 0.4061 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000137
step 1400 | loss: 0.4705 | lm_loss: 0.4587 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000135


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4802
step 1410 | loss: 0.4383 | lm_loss: 0.4264 | router_z_loss: 0.0020 | router_aux_loss: 0.0099 | lr: 0.000133
step 1420 | loss: 0.3036 | lm_loss: 0.2920 | router_z_loss: 0.0018 | router_aux_loss: 0.0098 | lr: 0.000131
step 1430 | loss: 0.4946 | lm_loss: 0.4827 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000129
step 1440 | loss: 0.5254 | lm_loss: 0.5135 | router_z_loss: 0.0020 | router_aux_loss: 0.0098 | lr: 0.000127
step 1450 | loss: 0.5022 | lm_loss: 0.4906 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000125


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4558
step 1460 | loss: 0.5294 | lm_loss: 0.5176 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000123
step 1470 | loss: 0.4446 | lm_loss: 0.4327 | router_z_loss: 0.0020 | router_aux_loss: 0.0098 | lr: 0.000121
step 1480 | loss: 0.3494 | lm_loss: 0.3375 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000119
step 1490 | loss: 0.5330 | lm_loss: 0.5215 | router_z_loss: 0.0016 | router_aux_loss: 0.0099 | lr: 0.000117
step 1500 | loss: 0.5739 | lm_loss: 0.5621 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000115


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4443
step 1510 | loss: 0.5581 | lm_loss: 0.5464 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000113
step 1520 | loss: 0.4784 | lm_loss: 0.4668 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000111
step 1530 | loss: 0.3954 | lm_loss: 0.3835 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000109
step 1540 | loss: 0.3952 | lm_loss: 0.3836 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000107
step 1550 | loss: 0.3995 | lm_loss: 0.3878 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000105


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4483
step 1560 | loss: 0.4524 | lm_loss: 0.4408 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000103
step 1570 | loss: 0.5533 | lm_loss: 0.5417 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000102
step 1580 | loss: 0.4729 | lm_loss: 0.4609 | router_z_loss: 0.0021 | router_aux_loss: 0.0099 | lr: 0.000100
step 1590 | loss: 0.3863 | lm_loss: 0.3745 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000098
step 1600 | loss: 0.5633 | lm_loss: 0.5514 | router_z_loss: 0.0020 | router_aux_loss: 0.0099 | lr: 0.000096


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4245
step 1610 | loss: 0.4275 | lm_loss: 0.4158 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000094
step 1620 | loss: 0.4210 | lm_loss: 0.4095 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000092
step 1630 | loss: 0.3861 | lm_loss: 0.3743 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000090
step 1640 | loss: 0.4456 | lm_loss: 0.4338 | router_z_loss: 0.0020 | router_aux_loss: 0.0099 | lr: 0.000088
step 1650 | loss: 0.3780 | lm_loss: 0.3662 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000086


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4253
step 1660 | loss: 0.4927 | lm_loss: 0.4809 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000084
step 1670 | loss: 0.4519 | lm_loss: 0.4404 | router_z_loss: 0.0016 | router_aux_loss: 0.0098 | lr: 0.000082
step 1680 | loss: 0.4153 | lm_loss: 0.4037 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000080
step 1690 | loss: 0.4755 | lm_loss: 0.4639 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000078
step 1700 | loss: 0.3685 | lm_loss: 0.3569 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000076


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4166
step 1710 | loss: 0.3462 | lm_loss: 0.3345 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000074
step 1720 | loss: 0.4024 | lm_loss: 0.3907 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000072
step 1730 | loss: 0.4290 | lm_loss: 0.4175 | router_z_loss: 0.0016 | router_aux_loss: 0.0099 | lr: 0.000070
step 1740 | loss: 0.4752 | lm_loss: 0.4634 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000068
step 1750 | loss: 0.4271 | lm_loss: 0.4153 | router_z_loss: 0.0019 | router_aux_loss: 0.0098 | lr: 0.000066


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4199
step 1760 | loss: 0.4045 | lm_loss: 0.3928 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000064
step 1770 | loss: 0.3568 | lm_loss: 0.3450 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000062
step 1780 | loss: 0.3619 | lm_loss: 0.3502 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000060
step 1790 | loss: 0.3981 | lm_loss: 0.3863 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000058
step 1800 | loss: 0.3505 | lm_loss: 0.3389 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000056


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4016
step 1810 | loss: 0.4007 | lm_loss: 0.3889 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000054
step 1820 | loss: 0.4253 | lm_loss: 0.4135 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000052
step 1830 | loss: 0.4739 | lm_loss: 0.4621 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000050
step 1840 | loss: 0.3520 | lm_loss: 0.3405 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000049
step 1850 | loss: 0.3562 | lm_loss: 0.3446 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000047


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.4020
step 1860 | loss: 0.5023 | lm_loss: 0.4905 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000045
step 1870 | loss: 0.4983 | lm_loss: 0.4865 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000043
step 1880 | loss: 0.4400 | lm_loss: 0.4284 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000041
step 1890 | loss: 0.3513 | lm_loss: 0.3398 | router_z_loss: 0.0017 | router_aux_loss: 0.0098 | lr: 0.000039
step 1900 | loss: 0.5587 | lm_loss: 0.5469 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000037


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.3965
step 1910 | loss: 0.4177 | lm_loss: 0.4061 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000035
step 1920 | loss: 0.4008 | lm_loss: 0.3890 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000033
step 1930 | loss: 0.3832 | lm_loss: 0.3716 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000031
step 1940 | loss: 0.4015 | lm_loss: 0.3899 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000029
step 1950 | loss: 0.4620 | lm_loss: 0.4505 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000027


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.3863
step 1960 | loss: 0.3871 | lm_loss: 0.3753 | router_z_loss: 0.0019 | router_aux_loss: 0.0099 | lr: 0.000025
step 1970 | loss: 0.3896 | lm_loss: 0.3779 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000023
step 1980 | loss: 0.4136 | lm_loss: 0.4020 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000021
step 1990 | loss: 0.3407 | lm_loss: 0.3290 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000019
step 2000 | loss: 0.4843 | lm_loss: 0.4727 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000017


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.3835
step 2010 | loss: 0.4303 | lm_loss: 0.4185 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000015
step 2020 | loss: 0.3774 | lm_loss: 0.3659 | router_z_loss: 0.0016 | router_aux_loss: 0.0099 | lr: 0.000013
step 2030 | loss: 0.3979 | lm_loss: 0.3862 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000011
step 2040 | loss: 0.3444 | lm_loss: 0.3329 | router_z_loss: 0.0017 | router_aux_loss: 0.0098 | lr: 0.000009
step 2050 | loss: 0.3665 | lm_loss: 0.3548 | router_z_loss: 0.0018 | router_aux_loss: 0.0099 | lr: 0.000007


0it [00:00, ?it/s]

<<< EVAL >>> loss: 0.3885
step 2060 | loss: 0.4335 | lm_loss: 0.4221 | router_z_loss: 0.0016 | router_aux_loss: 0.0098 | lr: 0.000005
step 2070 | loss: 0.3745 | lm_loss: 0.3629 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000003
step 2080 | loss: 0.3213 | lm_loss: 0.3097 | router_z_loss: 0.0017 | router_aux_loss: 0.0099 | lr: 0.000001
Training complete


In [13]:
# Number of full eval batches (floor division, so a partial final batch is not counted).
len(trainer.eval_dataset) // trainer.total_eval_batch_size



31

In [10]:
# Standalone loader for inspecting raw training batches. No collate_fn is
# passed here; batches are run through `collator` manually in the next cell.
train_dataloader = DataLoader(
    dataset=tokenized_dataset["train"],
    shuffle=True,
    batch_size=trainer.total_train_batch_size,
)




In [11]:
# Take a single batch and collate it. next(iter(...)) raises StopIteration on
# an empty loader instead of silently keeping a stale `collated` left in the
# kernel by an earlier run (which the for/break form would do).
collated = collator(next(iter(train_dataloader)))


In [12]:
# Inspect the collated batch: a dict of tensors (input_ids, labels, key_padding_mask, ...).
collated

{'input_ids': tensor([[ 0,  9,  7,  ...,  1,  1,  1],
         [ 0, 32,  7,  ...,  1,  1,  1],
         [ 0,  9,  7,  ...,  1,  1,  1],
         ...,
         [ 0, 16,  7,  ...,  1,  1,  1],
         [ 0,  9,  7,  ...,  1,  1,  1],
         [ 0, 32,  7,  ...,  1,  1,  1]]),
 'labels': tensor([[-100, -100, -100,  ..., -100, -100, -100],
         [-100,   16, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         ...,
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100, -100, -100,  ..., -100, -100, -100],
         [-100,   16, -100,  ..., -100, -100, -100]]),
 'key_padding_mask': tensor([[False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         ...,
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  True],
         [False, False, False,  ...,  True,  True,  