In [1]:
from dataclasses import asdict, dataclass, field, fields
from enum import Enum, StrEnum
from typing import Optional, Tuple, List, Dict, Any, Iterable, Union

from balm.data import load_dataset, DataCollator, Dataset
from balm.models.balm import BalmForMaskedLM
from balm.models.balm_moe import BalmMoEForMaskedLM
from balm.models.balm_moe_rope import BalmMoERoPEForMaskedLM
from balm.tokenizer import Tokenizer

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Optimizer
from torch.utils.data import DataLoader

In [2]:
# Vocabulary file used to build the tokenizer (path is relative to the notebook's
# working directory).
VOCAB_PATH = "./vocab.json"
tokenizer = Tokenizer(vocab=VOCAB_PATH)

In [3]:
def remove_sep(txt: str) -> str:
    """Swap every ``</s>`` separator in ``txt`` for a ``<cls><cls>`` pair.

    Used as the ``preprocess_fn`` when loading the text dataset, so each raw
    line is rewritten before it is tokenized.

    Args:
        txt: One raw line of text.

    Returns:
        The same text with each occurrence of ``</s>`` replaced by
        ``<cls><cls>``; text without a separator is returned unchanged.
    """
    pieces = txt.split("</s>")
    return "<cls><cls>".join(pieces)


# All three splits deliberately point at the same small test file, so the
# "test" and "eval" splits are not held-out data in this notebook.
TEST_DATA_FILE = "./balm/test_data/test_1k.txt"
SPLIT_NAMES = ("train", "test", "eval")
data_files = {split_name: TEST_DATA_FILE for split_name in SPLIT_NAMES}

# remove_sep is passed as the preprocessing hook for the loaded text.
dataset = load_dataset("text", data_files=data_files, preprocess_fn=remove_sep)

In [4]:
def tokenize_example(example):
    """Run the tokenizer on the "text" entry of whatever ``dataset.map`` passes in.

    Padding and truncation are both enabled, with ``max_length=320``.
    """
    return tokenizer(
        example["text"],
        padding=True,
        truncation=True,
        max_length=320,
    )


# NOTE(review): `remove_columns` is still disabled, so the raw "text" column is
# kept next to the tokenizer outputs (the recorded batch keys are 'text',
# 'input_ids', 'attention_mask'). Confirm whether downstream consumers expect that.
tokenized_dataset = dataset.map(
    tokenize_example,
    # remove_columns="text"
)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [5]:
# Data collator built around the shared tokenizer. It is handed to the Trainer as
# `data_collator` and also called directly on batches in the manual training loop.
collator = DataCollator(tokenizer=tokenizer)

In [6]:
# Shuffled batches over the tokenized training split for the manual training loop.
# No `collate_fn` is given here; the DataCollator is applied to each batch inside
# the loop instead.
TRAIN_BATCH_SIZE = 32
train_dataloader = DataLoader(
    dataset=tokenized_dataset["train"],
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
)



In [7]:
# Model class under test. BalmMoERoPEForMaskedLM was the previously toggled
# alternative (presumably constructed with the same arguments — TODO confirm).
model_cls = BalmMoEForMaskedLM

# Hyperparameters gathered in one place so they are easy to compare between runs.
model_config = dict(
    embed_dim=256,
    ffn_dim=1024,
    num_experts=4,
    num_layers=8,
    num_heads=8,
    expert_capacity=128,
    router_z_loss_coef=0.01,
    router_aux_loss_coef=0.01,
    vocab_size=tokenizer.vocab_size,
)
model = model_cls(**model_config)

# Device placement and optimizer creation are intentionally left disabled here:
# model = model.to(device)
# optimizer = optim.Adam(model.parameters(), lr=4e-4)



In [8]:
# Display the model's parameter count (accessed as an attribute, not called).
model.num_parameters

18982689

## Trainer todos:
* train/eval dataloader
* move everything to the correct device (cuda or cpu)
* learning rate (including scheduler)
* gradient accumulation steps
* logging
* checkpointing
* distributed training
* gradient clipping



In [9]:
# class ExplicitEnum(Enum):
#     def __init__(self, value):
#         self._value = value

#     def __repr__(self):
#         return f"{self.__class__.__name__}.{self.name}"
    
#     @classmethod
#     def _missing_(cls, value):
#         raise ValueError(
#             f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
#         )


# class IntervalStrategy(ExplicitEnum):
#     NO = "no"
#     STEPS = "steps"
#     EPOCHS = "epochs"


# class PaddingStrategy(ExplicitEnum):
#     LONGEST = "longest"
#     MAX_LENGTH = "max_length"
#     DO_NOT_PAD = "do_not_pad"


# class SchedulerType(ExplicitEnum):
#     LINEAR = "linear"
#     COSINE = "cosine"
#     COSINE_WITH_RESTARTS = "cosine_with_restarts"
#     POLYNOMIAL = "polynomial"
#     CONSTANT = "constant"
#     CONSTANT_WITH_WARMUP = "constant_with_warmup"
#     INVERSE_SQRT = "inverse_sqrt"
#     REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"


# class OptimizerNames(ExplicitEnum):
#     """
#     Stores the acceptable string identifiers for optimizers.
#     """

#     ADAMW_HF = "adamw_hf"
#     ADAMW_TORCH = "adamw_torch"
#     ADAMW_TORCH_FUSED = "adamw_torch_fused"
#     ADAMW_TORCH_XLA = "adamw_torch_xla"
#     ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
#     ADAMW_APEX_FUSED = "adamw_apex_fused"
#     ADAFACTOR = "adafactor"
#     ADAMW_ANYPRECISION = "adamw_anyprecision"
#     SGD = "sgd"
#     ADAGRAD = "adagrad"
#     ADAMW_BNB = "adamw_bnb_8bit"
#     ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
#     LION_8BIT = "lion_8bit"
#     LION = "lion_32bit"
#     PAGED_ADAMW = "paged_adamw_32bit"
#     PAGED_ADAMW_8BIT = "paged_adamw_8bit"
#     PAGED_LION = "paged_lion_32bit"
#     PAGED_LION_8BIT = "paged_lion_8bit"
#     RMSPROP = "rmsprop"
#     RMSPROP_BNB = "rmsprop_bnb"
#     RMSPROP_8BIT = "rmsprop_bnb_8bit"
#     RMSPROP_32BIT = "rmsprop_bnb_32bit"
#     GALORE_ADAMW = "galore_adamw"
#     GALORE_ADAMW_8BIT = "galore_adamw_8bit"
#     GALORE_ADAFACTOR = "galore_adafactor"
#     GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
#     GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
#     GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
    

In [10]:
from balm.training.training_arguments import TrainingArguments



In [11]:
# Training arguments with every option left at its default; "./" is the only
# argument supplied (presumably the output directory — TODO confirm against
# TrainingArguments).
training_args = TrainingArguments("./")

In [12]:
# Display the torch device the training arguments resolved to on this machine.
training_args.device



device(type='mps')

In [13]:
# class Trainer:
#     def __init__(
#         self, 
#         model: nn.Module, 
#         data_collator: DataCollator, 
#         optimizer: Optional[Optimizer] = None,
#         train_dataset: Optional[Dataset] = None,
#         eval_dataset: Optional[Dataset] = None,
#         batch_size: int = 32,
#         eval_batch_size: int = 32,
#     ):
#         self.model = model
#         self.data_collator = data_collator
#         self.optimizer = optimizer
#         self.batch_size = batch_size
#         self.eval_batch_size = eval_batch_size
#         self.train_dataset = train_dataset
#         self.eval_dataset = eval_dataset

#         self._device = None
#         self._train_dataloader = None
#         self._eval_dataloader = None

#     @property
#     def device(self):
#         """
#         The device to run the model on. 
#         Will check for CUDA, MPS (Apple Silicon), and CPU in that order.
#         """
#         if self._device is None:
#             if torch.cuda.is_available():
#                 self._device = torch.device("cuda")
#             elif torch.backends.mps.is_available():
#                 self._device = torch.device("mps")
#             else:
#                 self._device = torch.device("cpu")
#         return self._device
    
#     @property
#     def train_dataloader(self):
#         if self._train_dataloader is None:
#             self._train_dataloader = DataLoader(
#                 self.train_dataset,
#                 batch_size=self.batch_size,
#                 shuffle=True,
#             )
#         return self._train_dataloader
    
#     @property
#     def eval_dataloader(self):
#         if self._eval_dataloader is None:
#             self._eval_dataloader = DataLoader(
#                 self.eval_dataset,
#                 batch_size=self.eval_batch_size,
#                 shuffle=False,
#             )
#         return self._eval_dataloader
        

#     def train(self, dataloader):
#         self.model.train()
#         total_loss = 0
#         for batch in tqdm(dataloader):
#             inputs = self.collator(batch)
#             outputs = self.model(**inputs)
#             loss = outputs["loss"]
#             loss.backward()
#             self.optimizer.step()
#             self.optimizer.zero_grad()
#             total_loss += loss.item()
#         return total_loss / len(dataloader)

In [14]:
from balm.training.trainer import Trainer



In [15]:
# Trainer settings collected in a dict so individual values are easy to tweak
# between runs; no eval dataset or eval batch size is configured.
trainer_kwargs = dict(
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    max_steps=1000,
    per_device_train_batch_size=32,
    # per_device_eval_batch_size=32,
)
trainer = Trainer(**trainer_kwargs)



In [16]:
# Display the number of training steps the trainer will run; in the recorded run
# this equals the `max_steps` value passed to the Trainer.
trainer.n_train_steps

1000

In [20]:
# Peek at the first batch from the trainer's dataloader to see which fields it
# carries; nothing is printed if the dataloader yields no batches.
first_batch = next(iter(trainer.train_dataloader), None)
if first_batch is not None:
    print(first_batch.keys())

dict_keys(['text', 'input_ids', 'attention_mask'])


In [17]:
# NOTE(review): in the recorded run this call printed a batch dict (including the
# raw "text" strings) and then raised `KeyError: 0`. Possibly a dict batch is being
# indexed positionally somewhere in the trainer/collator path — confirm in
# balm.training.trainer before relying on this cell.
trainer.train()

  0%|          | 0/1000 [00:00<?, ?it/s]

{'text': ['EVQLVESGGVVVQPGGSLRLSCAASGFTFDDYAMHWVRQAPGKGLEWVSLISWDGGSTYYADSVKGRFTISRDNSKNSLYLQMNSLRAEDTALYYCAKDISGLTHPGYYDSSGYYSLGSWGQGTLVTVSS<cls><cls>QSALTQPASVSGSPGQSITISCTGTSSDVGGYNYVSWYQQHPGKAPKLMIYDVSNRPSGVSNRFSGSKSGNTASLTISGLQAEDEADYYCSSYTSSSTLAFGGGTKLTVL', 'QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARDLQGNQYSSGWSWGQGTLVTVSS<cls><cls>QSVLTQPPSVSAAPGQKVTISCSGSSSNIGNNYVSWYQQLPGTAPKLLIYDNNKRPSGIPDRFSGSKSGTSATLGITGLQTGDEADYYCGTWDSSLSAVVFGGGTKLTVL', 'QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARVRQYCSSTSCYLPDAFDIWGQGTMVTVSS<cls><cls>DIQMTQSPSTLSASVGDRVTITCRASQSISSWLAWYQQKPGKAPKLLIYKASSLESGVPSRFSGSGSGTEFTLTISSLQPDDFATYYCQQYNSYSRTFGQGTKVEIK', 'EVQLLESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKGDSMGFKPALFDYWGQGTLVTVSS<cls><cls>SYVLTQPPSVSVAPGQTARITCGGNNIGSKSVHWYQQKPGQAPVLVVYDDSDRPSGIPERFSGSNSGNTATLTISRVEAGDEADYYCQVWDSSSDHPVFGGGTKLTVL',

KeyError: 0

In [12]:
n_epochs = 10
log_every = 5  # print the loss breakdown every `log_every` optimizer steps

# No other cell defines `device` or `optimizer` (their setup lines are commented
# out), so on a fresh kernel this cell would fail with NameError. Create both here
# so the cell is self-contained under Restart & Run All.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model = model.to(device)
# Build the optimizer after the model has been moved to its device. Re-running this
# cell creates a fresh optimizer, i.e. the Adam state is reset.
optimizer = optim.Adam(model.parameters(), lr=4e-4)

model.train()
n_steps = 0
# One progress-bar tick per batch. The context manager closes the bar even when the
# run is interrupted (e.g. with KeyboardInterrupt).
with tqdm(total=len(train_dataloader) * n_epochs) as pbar:
    for epoch in range(n_epochs):
        for examples in train_dataloader:
            optimizer.zero_grad()
            # The collator is applied to the batch's "input_ids" only.
            collated = collator(examples["input_ids"])

            input_ids = collated["input_ids"].to(device)
            # "labels" and "attention_mask" are treated as optional in the collated batch.
            labels = collated.get("labels")
            if labels is not None:
                labels = labels.to(device)
            attention_mask = collated.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            # NOTE(review): the attention mask is passed as `key_padding_mask`;
            # confirm the model expects the same mask polarity.
            outputs = model(
                input_ids=input_ids,
                labels=labels,
                key_padding_mask=attention_mask,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            pbar.update(1)
            n_steps += 1
            if n_steps % log_every == 0:
                print(
                    f"step: {n_steps}, total Loss: {loss.item():.4f}, LM loss: {outputs.lm_loss.item():.4f}, router z loss: {outputs.router_z_loss.item():.4f}, router aux loss: {outputs.router_aux_loss.item():.4f}  "
                )

  0%|          | 0/320 [00:00<?, ?it/s]

step: 5, total Loss: 0.3494, LM loss: 0.3329, router z loss: 0.0066, router aux loss: 0.0099  
step: 10, total Loss: 0.3194, LM loss: 0.3092, router z loss: 0.0005, router aux loss: 0.0098  
step: 15, total Loss: 0.3203, LM loss: 0.3099, router z loss: 0.0004, router aux loss: 0.0100  
step: 20, total Loss: 0.3068, LM loss: 0.2969, router z loss: 0.0003, router aux loss: 0.0096  
step: 25, total Loss: 0.2884, LM loss: 0.2781, router z loss: 0.0007, router aux loss: 0.0096  
step: 30, total Loss: 0.3100, LM loss: 0.2995, router z loss: 0.0008, router aux loss: 0.0097  
step: 35, total Loss: 0.2651, LM loss: 0.2547, router z loss: 0.0006, router aux loss: 0.0098  
step: 40, total Loss: 0.2867, LM loss: 0.2759, router z loss: 0.0007, router aux loss: 0.0100  
step: 45, total Loss: 0.2733, LM loss: 0.2627, router z loss: 0.0007, router aux loss: 0.0099  
step: 50, total Loss: 0.2584, LM loss: 0.2482, router z loss: 0.0005, router aux loss: 0.0097  


KeyboardInterrupt: 