In [25]:
from dataclasses import asdict, dataclass, field, fields
from enum import Enum, StrEnum
from typing import Optional, Tuple, List, Dict, Any, Iterable, Union

from balm.data import load_dataset, DataCollator, Dataset
from balm.models.balm import BalmForMaskedLM
from balm.models.balm_moe import BalmMoEForMaskedLM
from balm.models.balm_moe_rope import BalmMoERoPEForMaskedLM
from balm.tokenizer import Tokenizer

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Optimizer
from torch.utils.data import DataLoader

In [3]:
# Build the tokenizer from the JSON vocabulary file (path is relative to the
# notebook's working directory).
tokenizer = Tokenizer(vocab="./vocab.json")

In [4]:
def remove_sep(txt):
    """
    Substitute every ``</s>`` separator in ``txt`` with two ``<cls>`` tokens.

    Despite the name, the separator is replaced rather than dropped.
    """
    pieces = txt.split("</s>")
    return "<cls><cls>".join(pieces)


# NOTE: every split reads the same test_1k.txt file, so "test"/"eval" results are
# not held-out measurements; this setup is only suitable as a smoke test.
data_files = dict.fromkeys(("train", "test", "eval"), "./balm/test_data/test_1k.txt")

dataset = load_dataset("text", data_files=data_files, preprocess_fn=remove_sep)

In [5]:
# Tokenize the "text" field of every example in each split (max_length=320 with
# truncation and padding enabled). The raw "text" column is kept; an earlier
# version of this cell had `remove_columns="text"` commented out.
tokenized_dataset = dataset.map(
    lambda example: tokenizer(
        example["text"],
        max_length=320,
        truncation=True,
        padding=True,
    )
)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [6]:
# The collator is applied by hand to each batch's input_ids in the training loop;
# it is not passed to the DataLoader as collate_fn.
collator = DataCollator(tokenizer=tokenizer)

In [7]:
# Shuffled training batches of 32 examples. No custom collate_fn is given here;
# the DataCollator is applied inside the training loop instead.
train_dataloader = DataLoader(tokenized_dataset["train"], shuffle=True, batch_size=32)



In [8]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU. The nested
# conditional short-circuits, so MPS is only probed when CUDA is unavailable.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)

print(f"Using device: {device}")



Using device: mps


In [10]:
# Mixture-of-experts masked LM. `BalmMoERoPEForMaskedLM` (also imported) was the
# alternative tried here; presumably it accepts the same keyword arguments.
model = BalmMoEForMaskedLM(
    # transformer backbone
    embed_dim=256,
    ffn_dim=1024,
    num_layers=8,
    num_heads=8,
    # expert routing
    num_experts=4,
    expert_capacity=128,
    router_z_loss_coef=0.01,
    router_aux_loss_coef=0.01,
    # vocabulary must match the tokenizer
    vocab_size=tokenizer.vocab_size,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=4e-4)



In [11]:
# Presumably the model's total parameter count; it is read as an attribute or
# property here (not called) and displays an int.
model.num_parameters

18982689

## Trainer TODOs
* train/eval dataloaders
* move the model and every batch to the correct device (CUDA, MPS, or CPU)
* learning rate (including a scheduler)
* gradient accumulation steps
* logging
* checkpointing
* distributed training
* gradient clipping



In [16]:
# Scratch enum for choosing how training/evaluation is scheduled. It is built
# with the functional Enum API; the members are STEPS="steps" and EPOCHS="epochs",
# in that order. (IntervalStrategy, defined in a later cell, has the same values
# plus "no".)
IterationStrategy = Enum(
    "IterationStrategy",
    [("STEPS", "steps"), ("EPOCHS", "epochs")],
)



In [22]:
# Scratch check: look up an IterationStrategy member by its string value.
eval_strategy = IterationStrategy("steps")

In [23]:
# Displays True: the value lookup in the previous cell returned the STEPS member.
eval_strategy == IterationStrategy.STEPS



True

In [55]:
class ExplicitEnum(str, Enum):
    """
    Base enum for string-valued options, with a readable repr and a helpful error.

    Mixing in ``str`` makes members compare (and hash) equal to their raw string
    values, e.g. ``IntervalStrategy.STEPS == "steps"``. Configuration fields typed
    ``Union[<Enum>, str]`` can therefore be compared the same way whether they hold
    a member or a plain string. ``str(member)`` and ``repr(member)`` are unchanged
    from the plain-Enum version.

    Looking up an unknown value raises ``ValueError`` listing every accepted value.
    """

    def __repr__(self):
        # Show members as ``ClassName.MEMBER`` instead of ``<ClassName.MEMBER: 'value'>``.
        return f"{self.__class__.__name__}.{self.name}"

    @classmethod
    def _missing_(cls, value):
        # Enum calls this only when `value` matches no member; replace the terse
        # default error with one that enumerates the valid choices.
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )


class IntervalStrategy(ExplicitEnum):
    """
    How often a periodic action (evaluation, logging, checkpoint saving) runs.

    Used as the type of the ``evaluation_strategy``, ``logging_strategy`` and
    ``save_strategy`` fields of ``TrainingArguments``. Note that the epoch value is
    the plural ``"epochs"``.
    """

    NO = "no"  # never during training
    STEPS = "steps"  # every N steps (N from the matching *_steps argument)
    EPOCHS = "epochs"  # at the end of each epoch


class PaddingStrategy(ExplicitEnum):
    """
    String identifiers for padding modes (presumably for tokenization).

    In the visible cells this enum is only exercised by a scratch lookup. The
    per-member comments give the conventional meaning of each name; confirm it
    against whatever consumes this enum.
    """

    LONGEST = "longest"  # presumably: pad to the longest sequence in the batch
    MAX_LENGTH = "max_length"  # presumably: pad to a fixed max_length
    DO_NOT_PAD = "do_not_pad"  # presumably: no padding


class SchedulerType(ExplicitEnum):
    """
    String identifiers for learning-rate scheduler types.

    Used as the type of ``TrainingArguments.lr_scheduler_type``. Only the names are
    defined here; nothing in this class builds a scheduler from them. Note that
    ``REDUCE_ON_PLATEAU`` has the value ``"reduce_lr_on_plateau"``, which differs
    from its member name.
    """

    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    INVERSE_SQRT = "inverse_sqrt"
    REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"


class OptimizerNames(ExplicitEnum):
    """
    Stores the acceptable string identifiers for optimizers.

    Used as the type of ``TrainingArguments.optim``. This enum only names the
    optimizers; nothing in it constructs or configures one. Looking up an unknown
    identifier raises ``ValueError`` listing the valid ones (via
    ``ExplicitEnum._missing_``).
    """

    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_FUSED = "adamw_torch_fused"
    ADAMW_TORCH_XLA = "adamw_torch_xla"
    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_ANYPRECISION = "adamw_anyprecision"
    SGD = "sgd"
    ADAGRAD = "adagrad"
    ADAMW_BNB = "adamw_bnb_8bit"
    # NOTE: ADAMW_8BIT has its own distinct value, so it is a separate member, not
    # an Enum alias of ADAMW_BNB; presumably both should select the same optimizer.
    ADAMW_8BIT = "adamw_8bit"  # intended as an alias for adamw_bnb_8bit
    LION_8BIT = "lion_8bit"
    LION = "lion_32bit"
    PAGED_ADAMW = "paged_adamw_32bit"
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
    PAGED_LION = "paged_lion_32bit"
    PAGED_LION_8BIT = "paged_lion_8bit"
    RMSPROP = "rmsprop"
    RMSPROP_BNB = "rmsprop_bnb"
    RMSPROP_8BIT = "rmsprop_bnb_8bit"
    RMSPROP_32BIT = "rmsprop_bnb_32bit"
    GALORE_ADAMW = "galore_adamw"
    GALORE_ADAMW_8BIT = "galore_adamw_8bit"
    GALORE_ADAFACTOR = "galore_adafactor"
    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
    

In [53]:
# Scratch check: ExplicitEnum subclasses support lookup by string value.
ps = PaddingStrategy("max_length")

In [54]:
# Rendered via ExplicitEnum's custom __repr__ as ClassName.MEMBER.
ps

PaddingStrategy.MAX_LENGTH

In [None]:
@dataclass
class TrainingArguments:
    """
    Arguments for training a model.

    Fields typed ``Union[<Enum>, str]`` accept either an enum member or its string
    value. ``__post_init__`` checks that each such value is one of the accepted
    choices and raises ``ValueError`` (listing the valid values) otherwise. The
    stored value is not converted.

    Parameters
    ----------
    output_dir : str
        The output directory where the model predictions and checkpoints will be written.

    overwrite_output_dir : bool, default=False
        Overwrite the content of the output directory. Use this to continue training
        if output_dir points to a checkpoint directory.

    evaluation_strategy : Union[IntervalStrategy, str], default="no"
        The evaluation strategy to use. Possible values are:

            - `"no"`: No evaluation is done during training.
            - `"steps"`: Evaluation is done (and logged) every `eval_steps`.
            - `"epochs"`: Evaluation is done at the end of each epoch.

    per_device_train_batch_size : int, default=8
        Batch size per GPU/MPS core or CPU for training.

    per_device_eval_batch_size : int, default=8
        Batch size per GPU/MPS core or CPU for evaluation.

    gradient_accumulation_steps : int, default=1
        Number of update steps to accumulate the gradients for, before performing a
        backward/update pass. Must be at least 1.

        .. note::
            When using gradient accumulation, a step is only counted for logging, evaluation,
            and saving if it includes a backward pass. Therefore, logging, evaluation, and
            save operations will be conducted every `gradient_accumulation_steps * xxx_step`
            training examples.

    eval_delay : float, default=0
        Number of epochs or steps to wait before the first evaluation can be
        performed, depending on `evaluation_strategy`.

    learning_rate : float, default=5e-5
        The initial learning rate for AdamW.

    weight_decay : float, default=0.0
        Weight decay for AdamW, if applied.

    adam_beta1 : float, default=0.9
        Beta1 for AdamW optimizer.

    adam_beta2 : float, default=0.999
        Beta2 for AdamW optimizer.

    adam_epsilon : float, default=1e-8
        Epsilon for AdamW optimizer.

    max_grad_norm : float, default=1.0
        Max gradient norm (for gradient clipping).

    num_train_epochs : float, default=3.0
        Total number of training epochs to perform.

    max_steps : int, default=-1
        If set to a positive number, the total number of training steps to perform.
        Overrides `num_train_epochs`. For a finite dataset, training is reiterated
        through the dataset (if all data is exhausted) until `max_steps` is reached.

    lr_scheduler_type : Union[SchedulerType, str], default="linear"
        The scheduler type to use. Must be a `SchedulerType` value.

    lr_scheduler_kwargs : Dict, default=empty dict
        Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the
        cosine with hard restarts. Each instance gets its own new dict.

    warmup_ratio : float, default=0.0
        Linear warmup over warmup_ratio fraction of total steps. Must be in `[0, 1]`.

    warmup_steps : int, default=0
        Linear warmup over warmup_steps.

    logging_dir : str, default=None
        Tensorboard log dir.

    logging_strategy : Union[IntervalStrategy, str], default="steps"
        The logging strategy to use (same possible values as `evaluation_strategy`).

    logging_first_step : bool, default=False
        Whether to log the first global_step.

    logging_steps : float, default=500
        Log every X updates steps. Should be an integer or a float in range `[0,1)`.
        If smaller than 1, will be interpreted as ratio of total training steps.

    save_strategy : Union[IntervalStrategy, str], default="steps"
        The checkpoint saving strategy to use (same possible values as
        `evaluation_strategy`).

    save_steps : float, default=500
        Save every X updates steps. Should be an integer or a float in range `[0,1)`.
        If smaller than 1, will be interpreted as ratio of total training steps.

    save_total_limit : int, default=None
        Maximum number of checkpoints to keep in `output_dir`; unlimited if None.

    use_cpu : bool, default=False
        Whether to use CPU instead of an available accelerator.

    seed : int, default=42
        Random seed that will be set for reproducibility.

    data_seed : int, default=None
        Random seed to be used with data samplers. If not set, random generators
        for data sampling will use the same seed as `seed`. This can be used to
        ensure reproducibility of data sampling, independent of the model seed.

    fp16 : bool, default=False
        Whether to use fp16 (mixed) precision instead of 32-bit.

    ddp_backend : str, default=None
        Backend to use for distributed training: one of "nccl", "gloo", "mpi",
        "ccl", or "hccl".

    dataloader_drop_last : bool, default=False
        Whether to drop the last incomplete batch.

    eval_steps : float, default=None
        Run an evaluation every X steps. Should be an integer or a float in range
        `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.

    run_name : str, default=None
        Optional descriptor for the run (e.g. for wandb logging).

    disable_tqdm : bool, default=None
        Whether to disable the tqdm progress bars.

    remove_unused_columns : bool, default=True
        Whether to remove columns not required by the model from the dataset.

    label_names : List[str], default=None
        The keys in the input dictionaries that correspond to the labels.

    accelerator_config : str or dict, default=None
        Path to an accelerator JSON config file, or an already loaded config dict.

    deepspeed : str or dict, default=None
        Path to a DeepSpeed JSON config file, or an already loaded config dict.

    label_smoothing_factor : float, default=0.0
        The label smoothing epsilon to apply (zero means no label smoothing).

    optim : Union[OptimizerNames, str], default="adamw_torch"
        The optimizer to use. Must be an `OptimizerNames` value (e.g. "adamw_torch",
        "adamw_torch_fused", "adafactor", "sgd").

    optim_args : str, default=None
        Optional arguments to supply to the optimizer.

    report_to : List[str], default=None
        The list of integrations to report results and logs to.

    dataloader_pin_memory : bool, default=True
        Whether to pin memory for the dataloaders.

    dataloader_persistent_workers : bool, default=False
        If True, dataloader worker processes are kept alive after the dataset has
        been consumed once. Can speed up training at the cost of more RAM.

    resume_from_checkpoint : str, default=None
        The path to the checkpoint from which to resume training.

    include_inputs_for_metrics : bool, default=False
        Whether or not the inputs will be passed to the `compute_metrics` function.
        This is intended for metrics that need inputs, predictions and references
        for scoring calculation in Metric class.

    dataloader_num_workers : int, default=0
        The number of workers to use for the dataloaders.

    Raises
    ------
    ValueError
        From ``__post_init__`` if a strategy/scheduler/optimizer value is not an
        accepted choice, if `gradient_accumulation_steps` < 1, or if
        `warmup_ratio` is outside `[0, 1]`.
    """

    output_dir: str = field(
        metadata={
            "help": "The output directory where the model predictions and checkpoints will be written."
        },
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    evaluation_strategy: Union[IntervalStrategy, str] = field(
        default="no",
        metadata={"help": "The evaluation strategy to use."},
    )
    per_device_train_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for training."},
    )
    per_device_eval_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for evaluation."},
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={
            "help": "Number of updates steps to accumulate before performing a backward/update pass."
        },
    )
    eval_delay: Optional[float] = field(
        default=0,
        metadata={
            "help": (
                "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the"
                " evaluation_strategy."
            )
        },
    )

    learning_rate: float = field(
        default=5e-5, metadata={"help": "The initial learning rate for AdamW."}
    )
    weight_decay: float = field(
        default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}
    )
    adam_beta1: float = field(
        default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}
    )
    adam_beta2: float = field(
        default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}
    )
    adam_epsilon: float = field(
        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
    )
    max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})

    num_train_epochs: float = field(
        default=3.0, metadata={"help": "Total number of training epochs to perform."}
    )
    max_steps: int = field(
        default=-1,
        metadata={
            "help": "If > 0: set total number of training steps to perform. Override num_train_epochs."
        },
    )
    lr_scheduler_type: Union[SchedulerType, str] = field(
        default="linear",
        metadata={"help": "The scheduler type to use."},
    )
    lr_scheduler_kwargs: Optional[Dict] = field(
        default_factory=dict,
        metadata={
            "help": (
                "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts"
            )
        },
    )
    warmup_ratio: float = field(
        default=0.0,
        metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."},
    )
    warmup_steps: int = field(
        default=0, metadata={"help": "Linear warmup over warmup_steps."}
    )
    logging_dir: Optional[str] = field(
        default=None, metadata={"help": "Tensorboard log dir."}
    )
    logging_strategy: Union[IntervalStrategy, str] = field(
        default="steps",
        metadata={"help": "The logging strategy to use."},
    )
    logging_first_step: bool = field(
        default=False, metadata={"help": "Log the first global_step"}
    )
    logging_steps: float = field(
        default=500,
        metadata={
            "help": (
                "Log every X updates steps. Should be an integer or a float in range `[0,1)`. "
                "If smaller than 1, will be interpreted as ratio of total training steps."
            )
        },
    )
    save_strategy: Union[IntervalStrategy, str] = field(
        default="steps",
        metadata={"help": "The checkpoint save strategy to use."},
    )
    save_steps: float = field(
        default=500,
        metadata={
            "help": (
                "Save checkpoint every X updates steps. Should be an integer or a float in range `[0,1)`. "
                "If smaller than 1, will be interpreted as ratio of total training steps."
            )
        },
    )
    save_total_limit: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in"
                " `output_dir`. When `load_best_model_at_end` is enabled, the 'best' checkpoint according to"
                " `metric_for_best_model` will always be retained in addition to the most recent ones. For example,"
                " for `save_total_limit=5` and `load_best_model_at_end=True`, the four last checkpoints will always be"
                " retained alongside the best model. When `save_total_limit=1` and `load_best_model_at_end=True`,"
                " it is possible that two checkpoints are saved: the last one and the best one (if they are different)."
                " Default is unlimited checkpoints"
            )
        },
    )
    use_cpu: bool = field(
        default=False,
        metadata={
            "help": " Whether or not to use cpu. If set to False, we will use cuda/tpu/mps/npu device if available."
        },
    )
    seed: int = field(
        default=42,
        metadata={"help": "Random seed that will be set at the beginning of training."},
    )
    data_seed: Optional[int] = field(
        default=None, metadata={"help": "Random seed to be used with data samplers."}
    )
    fp16: bool = field(
        default=False,
        metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
    )
    ddp_backend: Optional[str] = field(
        default=None,
        metadata={
            "help": "The backend to be used for distributed training",
            "choices": ["nccl", "gloo", "mpi", "ccl", "hccl"],
        },
    )
    dataloader_drop_last: bool = field(
        default=False,
        metadata={
            "help": "Drop the last incomplete batch if it is not divisible by the batch size."
        },
    )
    eval_steps: Optional[float] = field(
        default=None,
        metadata={
            "help": (
                "Run an evaluation every X steps. Should be an integer or a float in range `[0,1)`. "
                "If smaller than 1, will be interpreted as ratio of total training steps."
            )
        },
    )
    run_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional descriptor for the run. Notably used for wandb logging."
        },
    )
    disable_tqdm: Optional[bool] = field(
        default=None,
        metadata={"help": "Whether or not to disable the tqdm progress bars."},
    )

    remove_unused_columns: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Remove columns not required by the model when using an nlp.Dataset."
        },
    )
    label_names: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "The list of keys in your dictionary of inputs that correspond to the labels."
        },
    )
    # Do not touch this type annotation or it will stop working in CLI
    accelerator_config: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Config to be used with the internal Accelerator object initializtion. The value is either a "
                "accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`."
            )
        },
    )
    # Do not touch this type annotation or it will stop working in CLI
    deepspeed: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already"
                " loaded json file as a dict"
            )
        },
    )
    label_smoothing_factor: float = field(
        default=0.0,
        metadata={
            "help": "The label smoothing epsilon to apply (zero means no label smoothing)."
        },
    )

    # Default optimizer identifier (a plain class attribute, not a dataclass field,
    # because it is unannotated). Making "adamw_torch_fused" the default for
    # torch >= 2.1.0 was considered but is not enabled.
    default_optim = "adamw_torch"
    optim: Union[OptimizerNames, str] = field(
        default=default_optim,
        metadata={"help": "The optimizer to use."},
    )
    optim_args: Optional[str] = field(
        default=None, metadata={"help": "Optional arguments to supply to optimizer."}
    )
    report_to: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "The list of integrations to report the results and logs to."
        },
    )
    dataloader_pin_memory: bool = field(
        default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
    )
    dataloader_persistent_workers: bool = field(
        default=False,
        metadata={
            "help": "If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will increase RAM usage."
        },
    )
    resume_from_checkpoint: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to a folder with a valid checkpoint for your model."
        },
    )
    include_inputs_for_metrics: bool = field(
        default=False,
        metadata={
            "help": "Whether or not the inputs will be passed to the `compute_metrics` function."
        },
    )
    # Documented above but previously missing. Appended last so the positional
    # order of all existing fields is unchanged.
    dataloader_num_workers: int = field(
        default=0,
        metadata={"help": "The number of workers to use for the dataloaders."},
    )

    def __post_init__(self):
        """
        Validate option values so typos fail at construction time.

        Raises
        ------
        ValueError
            If an enum-valued option is not an accepted choice (the message lists
            the valid values), if `gradient_accumulation_steps` < 1, or if
            `warmup_ratio` is outside `[0, 1]`.
        """
        # Looking up a value on an ExplicitEnum subclass returns the member for a
        # valid string (or member) and raises ValueError otherwise. The field keeps
        # the caller's original value; this loop only validates.
        for enum_cls, value in (
            (IntervalStrategy, self.evaluation_strategy),
            (IntervalStrategy, self.logging_strategy),
            (IntervalStrategy, self.save_strategy),
            (SchedulerType, self.lr_scheduler_type),
            (OptimizerNames, self.optim),
        ):
            enum_cls(value)

        if self.gradient_accumulation_steps < 1:
            raise ValueError(
                f"gradient_accumulation_steps must be >= 1, got {self.gradient_accumulation_steps}"
            )
        if not 0.0 <= self.warmup_ratio <= 1.0:
            raise ValueError(
                f"warmup_ratio must be in [0, 1], got {self.warmup_ratio}"
            )

In [None]:
class Trainer:
    """
    Minimal trainer for a masked-LM model.

    Parameters
    ----------
    model : nn.Module
        Model to train. Each forward pass receives ``input_ids``, ``labels`` and
        ``key_padding_mask`` keyword arguments, mirroring the manual training loop
        in this notebook. The return value must expose a ``loss`` attribute.
    data_collator : DataCollator
        Applied to each batch's token ids. Its output is read via
        ``collated["input_ids"]`` and the optional ``"labels"`` /
        ``"attention_mask"`` entries.
    optimizer : Optimizer, optional
        Optimizer used by :meth:`train`; required before training.
    train_dataset, eval_dataset : Dataset, optional
        Datasets wrapped lazily by :attr:`train_dataloader` / :attr:`eval_dataloader`.
    batch_size, eval_batch_size : int, default=32
        Batch sizes of the training and evaluation dataloaders.
    """

    def __init__(
        self, 
        model: nn.Module, 
        data_collator: DataCollator, 
        optimizer: Optional[Optimizer] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        batch_size: int = 32,
        eval_batch_size: int = 32,
    ):
        self.model = model
        self.data_collator = data_collator
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.eval_batch_size = eval_batch_size
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        # Lazily initialised by the properties below.
        self._device = None
        self._train_dataloader = None
        self._eval_dataloader = None

    @property
    def device(self):
        """
        The device to run the model on. 
        Will check for CUDA, MPS (Apple Silicon), and CPU in that order.
        """
        if self._device is None:
            if torch.cuda.is_available():
                self._device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self._device = torch.device("mps")
            else:
                self._device = torch.device("cpu")
        return self._device

    @property
    def train_dataloader(self):
        """Shuffled DataLoader over `train_dataset`, created on first access."""
        if self._train_dataloader is None:
            if self.train_dataset is None:
                raise ValueError("Trainer.train_dataloader requires a train_dataset.")
            self._train_dataloader = DataLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
            )
        return self._train_dataloader

    @property
    def eval_dataloader(self):
        """Non-shuffled DataLoader over `eval_dataset`, created on first access."""
        if self._eval_dataloader is None:
            if self.eval_dataset is None:
                raise ValueError("Trainer.eval_dataloader requires an eval_dataset.")
            self._eval_dataloader = DataLoader(
                self.eval_dataset,
                batch_size=self.eval_batch_size,
                shuffle=False,
            )
        return self._eval_dataloader

    def _prepare_inputs(self, batch):
        # Collate one dataloader batch and move the resulting tensors to `self.device`.
        from collections.abc import Mapping

        # DataLoader batches of dataset rows arrive as a mapping; like the manual
        # loop, only the token ids are handed to the collator.
        token_ids = batch["input_ids"] if isinstance(batch, Mapping) else batch
        collated = self.data_collator(token_ids)

        labels = collated.get("labels", None)
        attention_mask = collated.get("attention_mask", None)
        return {
            "input_ids": collated["input_ids"].to(self.device),
            "labels": labels.to(self.device) if labels is not None else None,
            # The model takes the padding mask under the name `key_padding_mask`.
            "key_padding_mask": (
                attention_mask.to(self.device) if attention_mask is not None else None
            ),
        }

    def train(self, dataloader=None):
        """
        Run one pass over `dataloader` and return the mean training loss.

        Parameters
        ----------
        dataloader : iterable of batches, optional
            Batches to train on; defaults to :attr:`train_dataloader`.

        Returns
        -------
        float
            Mean of ``loss.item()`` over the batches seen, or ``nan`` if the
            dataloader yielded no batches.

        Raises
        ------
        ValueError
            If no optimizer was supplied (or, when `dataloader` is omitted, no
            train_dataset).
        """
        if self.optimizer is None:
            raise ValueError("Trainer.train requires an optimizer.")
        if dataloader is None:
            dataloader = self.train_dataloader

        # Parameters are moved in place, so the optimizer's references stay valid.
        self.model.to(self.device)
        self.model.train()
        # Drop any gradients left over from work done outside this call.
        self.optimizer.zero_grad()

        total_loss = 0.0
        num_batches = 0
        for batch in tqdm(dataloader):
            inputs = self._prepare_inputs(batch)
            outputs = self.model(**inputs)
            loss = outputs.loss
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            total_loss += loss.item()
            num_batches += 1

        # Count batches instead of using len(dataloader): avoids ZeroDivisionError
        # on empty input and works for iterables without a length.
        return total_loss / num_batches if num_batches else float("nan")

In [None]:




class Trainer:
    """
    Bare-bones trainer: one optimisation pass over a dataloader.

    NOTE(review): this cell redefines ``Trainer``. When cells run top to bottom it
    shadows the ``Trainer`` class defined in the previous cell.

    Parameters
    ----------
    model : nn.Module
        Model to train. Each forward pass receives ``input_ids``, ``labels`` and
        ``key_padding_mask`` keyword arguments, mirroring the manual training loop
        in this notebook. The return value must expose a ``loss`` attribute.
    optimizer : Optimizer
        Optimizer stepped once per batch.
    collator : callable
        Applied to each batch's token ids. Its output is read via
        ``collated["input_ids"]`` and the optional ``"labels"`` /
        ``"attention_mask"`` entries.
    """

    def __init__(self, model, optimizer, collator):
        self.model = model
        self.optimizer = optimizer
        self.collator = collator

    def _model_device(self):
        # Train wherever the model's parameters already live (CPU if it has none).
        first_param = next(self.model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def _prepare_inputs(self, batch, device):
        # Collate one dataloader batch and move the resulting tensors to `device`.
        from collections.abc import Mapping

        token_ids = batch["input_ids"] if isinstance(batch, Mapping) else batch
        collated = self.collator(token_ids)

        labels = collated.get("labels", None)
        attention_mask = collated.get("attention_mask", None)
        return {
            "input_ids": collated["input_ids"].to(device),
            "labels": labels.to(device) if labels is not None else None,
            # The model takes the padding mask under the name `key_padding_mask`.
            "key_padding_mask": (
                attention_mask.to(device) if attention_mask is not None else None
            ),
        }

    def train(self, dataloader):
        """
        Run one pass over `dataloader` and return the mean training loss.

        Returns ``nan`` if the dataloader yields no batches.
        """
        device = self._model_device()
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in tqdm(dataloader):
            inputs = self._prepare_inputs(batch, device)
            outputs = self.model(**inputs)
            loss = outputs.loss
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            total_loss += loss.item()
            num_batches += 1
        # Count batches instead of using len(dataloader): avoids ZeroDivisionError
        # on empty input and works for iterables without a length.
        return total_loss / num_batches if num_batches else float("nan")

In [12]:
n_epochs = 10
log_every_n_steps = 5  # print the loss breakdown every this many optimizer steps

model.train()
n_steps = 0
# One progress-bar tick per optimizer step. Using the bar as a context manager
# closes it even if training is interrupted (e.g. KeyboardInterrupt).
with tqdm(total=len(train_dataloader) * n_epochs) as pbar:
    for epoch in range(n_epochs):
        for examples in train_dataloader:
            optimizer.zero_grad()
            # The collator is applied by hand to the raw token ids of each batch.
            collated = collator(examples["input_ids"])

            # Move the collated tensors onto `device`; labels and the attention
            # mask are optional in the collator output.
            input_ids = collated["input_ids"].to(device)
            labels = collated.get("labels", None)
            if labels is not None:
                labels = labels.to(device)
            attention_mask = collated.get("attention_mask", None)
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            # The model takes the padding mask under the name `key_padding_mask`.
            outputs = model(
                input_ids=input_ids, labels=labels, key_padding_mask=attention_mask
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            pbar.update(1)
            n_steps += 1
            if n_steps % log_every_n_steps == 0:
                # tqdm.write prints without interleaving with the progress bar.
                tqdm.write(
                    f"step: {n_steps}, total Loss: {loss.item():.4f}, LM loss: {outputs.lm_loss.item():.4f}, router z loss: {outputs.router_z_loss.item():.4f}, router aux loss: {outputs.router_aux_loss.item():.4f}  "
                )

  0%|          | 0/320 [00:00<?, ?it/s]

step: 5, total Loss: 0.3494, LM loss: 0.3329, router z loss: 0.0066, router aux loss: 0.0099  
step: 10, total Loss: 0.3194, LM loss: 0.3092, router z loss: 0.0005, router aux loss: 0.0098  
step: 15, total Loss: 0.3203, LM loss: 0.3099, router z loss: 0.0004, router aux loss: 0.0100  
step: 20, total Loss: 0.3068, LM loss: 0.2969, router z loss: 0.0003, router aux loss: 0.0096  
step: 25, total Loss: 0.2884, LM loss: 0.2781, router z loss: 0.0007, router aux loss: 0.0096  
step: 30, total Loss: 0.3100, LM loss: 0.2995, router z loss: 0.0008, router aux loss: 0.0097  
step: 35, total Loss: 0.2651, LM loss: 0.2547, router z loss: 0.0006, router aux loss: 0.0098  
step: 40, total Loss: 0.2867, LM loss: 0.2759, router z loss: 0.0007, router aux loss: 0.0100  
step: 45, total Loss: 0.2733, LM loss: 0.2627, router z loss: 0.0007, router aux loss: 0.0099  
step: 50, total Loss: 0.2584, LM loss: 0.2482, router z loss: 0.0005, router aux loss: 0.0097  


KeyboardInterrupt: 