In [1]:
import os
from pathlib import Path

import torch
from datatrove.utils.dataset import DatatroveFolderDataset
from torch import Tensor
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments

from src.optim import get_wsd_scheduler
from src.utilities import get_logger


In [2]:
# Configure the logger and configure colorlog
logger = get_logger("training", "info")

In [5]:
tok_path = Path("/home/pl487/rdd/tokenizer_train/2024-08-30T12-00-43/tok-vocab32000/")
tok = AutoTokenizer.from_pretrained(str(tok_path), clean_up_tokenization_spaces=False)

In [25]:
tok(["ciao"])

{'input_ids': [[67, 24885]], 'token_type_ids': [[0, 0]], 'attention_mask': [[1, 1]]}

In [17]:
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

In [23]:
# Adapted from SmolLM
# https://huggingface.co/HuggingFaceTB/SmolLM-135M/blob/main/config.json
config = LlamaConfig(
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=tok.eos_token_id,
    eos_token_id=tok.eos_token_id,
    hidden_act="silu",
    hidden_size=512,
    intermediate_size=1024,
    initializer_range=0.02,
    max_position_embeddings=2048,
    mlp_bias=False,
    model_type="llama",
    num_attention_heads=9,
    num_hidden_layers=8,
    num_key_value_heads=3,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    rope_theta=10000.0,
    tie_word_embeddings=True,
    torch_dtype="bfloat16",
    use_cache=True,
    vocab_size=tok.vocab_size,
)

model = LlamaForCausalLM(config)

logger.info(f"Memory footprint: {model.get_memory_footprint() / 1e6:.2f} MB")
logger.info(f"Num parameters: {model.num_parameters() / 1e6:.1f}M")


[[36m2024-08-30 15:51:43,292[0m][[34mtraining[0m][[32mINFO[0m] - Memory footprint: 68.96 MB[0m
[[36m2024-08-30 15:51:43,294[0m][[34mtraining[0m][[32mINFO[0m] - Num parameters: 34.5M[0m


In [17]:
model = model.to("cuda")
model.forward(torch.randint(0, 10000, size=(16, 516), device="cuda")).logits.shape

torch.Size([16, 516, 32000])

In [18]:
# too many arguments, use the set methods to make things clearer
training_args = TrainingArguments(
    # =======
    # logging
    # =======
    output_dir=f"training_outputs/{hub_model_id}",
    logging_strategy="steps",
    logging_first_step=True,
    log_level="passive", # takes it from global
    logging_steps=1,
    report_to="tensorboard",
    include_num_input_tokens_seen=True,
    # =============
    # checkpointing
    # =============
    save_strategy="steps",
    save_steps=50,
    save_safetensors=True,
    # ===========
    # push to hub
    # ===========
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="all_checkpoints",
    hub_private_repo=True,
    # =====
    # setup
    # =====
    eval_strategy="no",
    seed=42,
    bf16=True,
    bf16_full_eval=True,
    tf32=True,
    torch_compile=False,
    # ============
    # optimisation
    # ============
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    optim="adamw_torch",
    learning_rate=2e-5,
    weight_decay=0.1,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_epsilon=1e-8,
    max_grad_norm=1.0,
    # lr_scheduler_type="wsd",
    lr_scheduler_kwargs=dict(
        final_lr_factor=0.0,
        init_div_factor=100,
        frac_decay=0.2,
        decay_type="sqrt",
    ),  # use to pass 
    warmup_steps=2_000,
    num_train_epochs=1,
    max_steps=100,
    # ===========
    # dataloading
    # ===========
    dataloader_num_workers=os.cpu_count() - 1,
)

In [19]:
class LMTrainer(Trainer):
    
    def create_optimizer(self) -> Optimizer:
        # need to set self.optimizer
        
        # Get params that require grad
        param_dict = {pn: p for pn, p in self.model.named_parameters() if p.requires_grad}
        
        # Create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': self.args.weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        
        logger.info(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        logger.info(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        
        # Create AdamW optimizer and use the fused version
        self.optimizer = AdamW(
            optim_groups, 
            lr=self.args.learning_rate, 
            betas=(self.args.adam_beta1, self.args.adam_beta2), 
            eps=self.args.adam_epsilon, 
            fused=True,
        )
        
        return self.optimizer
    
    def create_scheduler(self, num_training_steps: int, optimizer: Optimizer = None) -> LRScheduler:
        # HACK: to avoid changing too much stuff, just assume that when I pass kwargs
        # I mean that I want the wsd scheduler
        if self.args.lr_scheduler_kwargs is not None:
            return get_wsd_scheduler(
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=num_training_steps,
                **self.args.lr_scheduler_kwargs,
            )
        
        return super().create_scheduler(num_training_steps, optimizer)

    def get_train_dataloader(self) -> DataLoader:
        target_repo = "hf://datasets/pietrolesci/fineweb-edu-10BT"
        ds = DatatroveFolderDataset(
            folder_path=f"{target_repo}/{tok_path.name}", 
            seq_len=config.max_position_embeddings, 
            shuffle=True,
            seed=42,
            token_size=2 if config.vocab_size < 65_000 else 4,
        )

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": True,
            "persistent_workers": False,
            "shuffle": False,
            "drop_last": self.args.dataloader_drop_last,
        }

        return self.accelerator.prepare(DataLoader(ds, **dataloader_params))
    
    def compute_loss(self, model, inputs, return_outputs=False) -> Tensor:
        input_ids = inputs["input_ids"]
        labels = input_ids.clone()
        outputs = model(input_ids=input_ids, labels=labels)
        return outputs.loss
                
    # def compute_loss(self, model, inputs, return_outputs=False):
    #     labels = inputs.pop("labels")
    #     outputs = model(**inputs)
    #     logits = outputs.logits
    #     shift_logits = logits[..., :-1, :].contiguous()
    #     shift_labels = labels[..., 1:].contiguous()
    #     loss_fct = nn.CrossEntropyLoss()
    #     loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    #     return (loss, outputs) if return_outputs else loss

    

In [20]:
trainer = LMTrainer(model, args=training_args)

max_steps is given, it will override any value given in num_train_epochs


In [21]:
dl = trainer.get_train_dataloader()
batch = next(iter(dl))

In [26]:
batch["input_ids"].max()

tensor(65513, device='cuda:0')

In [23]:
trainer.model.forward(
    torch.randint(0, 10000, size=(16, 516), device="cuda")
)

CausalLMOutputWithPast(loss=None, logits=tensor([[[-0.0479, -0.3145, -0.7266,  ..., -0.1289, -0.3750,  0.4805],
         [ 0.1235, -0.2061, -0.6328,  ...,  0.0679, -0.3223,  0.4316],
         [-0.0593, -0.3906, -0.5352,  ...,  0.0679, -0.1445,  0.2490],
         ...,
         [-0.1289,  0.3457,  0.0393,  ..., -0.2520, -0.2422,  0.2637],
         [-0.3438,  0.6680,  0.2188,  ..., -0.6367, -0.1875,  0.2871],
         [-0.3086,  0.6289,  0.0012,  ..., -0.4023, -0.0879,  0.4375]],

        [[-0.3027,  0.1934, -0.4766,  ...,  0.3496,  0.1807,  0.0962],
         [ 0.1689,  0.0294, -0.4434,  ..., -0.0391,  0.1680,  0.2930],
         [ 0.1992, -0.1030, -0.1973,  ..., -0.0830,  0.2490,  0.2070],
         ...,
         [ 0.2119,  0.0557,  0.0378,  ...,  0.0369,  1.0156,  0.1387],
         [-0.1465, -0.0435, -0.2598,  ..., -0.0066,  0.8008,  0.1167],
         [-0.0554, -0.1514, -0.0586,  ..., -0.1484,  0.8750,  0.0771]],

        [[ 0.0388,  0.3027, -0.2793,  ...,  0.2891, -0.6914,  0.4219],
    

In [None]:
batch["input_ids"].shape

In [None]:
model = trainer.accelerator.prepare_model(model)

In [None]:
model.forward(input_ids=batch["input_ids"])

In [None]:
trainer.compute_loss(model, batch)

In [None]:
trainer.train()