# References:
- [DeepSpeed ZeRO tutorial](https://www.deepspeed.ai/tutorials/zero/)
- [Ray + DeepSpeed for distributed parallel fine-tuning](https://medium.com/sage-ai/fine-tuning-large-language-models-a-guide-into-distributed-parallel-training-with-deepspeed-ray-784914926a17)
- [llama-recipes training utilities (parallel fine-tuning)](https://github.com/facebookresearch/llama-recipes/blob/main/src/llama_recipes/utils/train_utils.py)

In [1]:
# Hub id of the chat model that gets fine-tuned.
base_model: str = "NousResearch/Llama-2-7b-chat-hf"
# Name for the fine-tuned result (also used as the checkpoint directory later on).
new_model: str = "llama-2-7b-chat-finetuned"

In [2]:
def training_function(
    model_name: str | None = None,
    data_files: str | list[str] = "NIv2_zs_opt_task092_check_prime_classification.json",
):
    """Fine-tune a 4-bit quantized causal LM with LoRA adapters using TRL's SFTTrainer.

    Every JSON record is flattened into one ``text`` string of the form
    ``"Instruction: ... Question: ... Answer: ..."``; SFTTrainer then packs the
    examples into ``max_seq_length``-token sequences.

    Imports are kept local, presumably so the function can also be handed to
    ``accelerate.notebook_launcher`` (as the launcher cell does).

    Args:
        model_name: Hugging Face model id or local path. ``None`` (default)
            uses the notebook-level ``base_model``.
        data_files: File(s) passed to ``load_dataset("json", data_files=...)``.
            Each record must provide string fields ``instruction``,
            ``question`` and ``answer``.

    Returns:
        The PEFT-wrapped model after ``trainer.train()`` has run.
    """
    import os

    import torch
    from datasets import load_dataset
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        TrainingArguments,
    )
    from trl import SFTTrainer

    if model_name is None:
        # Resolved at call time so later edits to `base_model` are picked up.
        model_name = base_model

    # ---- Dataset -----------------------------------------------------------
    raw_train = load_dataset("json", data_files=data_files)["train"]

    def to_text(example):
        # Flatten one record into the single field SFTTrainer reads.
        return {
            "text": "Instruction: " + example["instruction"]
            + " Question: " + example["question"]
            + " Answer: " + example["answer"]
        }

    train_dataset = raw_train.map(to_text, remove_columns=raw_train.column_names)
    train_dataset.set_format(type="torch", columns=["text"])

    # ---- Tokenizer ---------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Reuse EOS as the pad token (Llama-2 tokenizers typically define none).
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # ---- 4-bit base model --------------------------------------------------
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    use_flash_attention = False

    # Pin the whole model to this process's GPU. Without an explicit map a
    # quantized model goes to the current CUDA device, which may be cuda:0 in
    # every process started by notebook_launcher/torchrun. LOCAL_RANK is unset
    # in a plain notebook run, so that case still uses device 0.
    device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map,
        use_cache=False,
        use_flash_attention_2=use_flash_attention,
        # NOTE(review): float16 here (and fp16=True below) disagrees with the
        # bfloat16 compute dtype above — confirm which precision is intended.
        torch_dtype=torch.float16,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # No hand-built optimizer/scheduler: SFTTrainer creates both from
    # TrainingArguments (optim="paged_adamw_32bit"). The previous Adam +
    # linear-warmup scheduler were never passed to the trainer, and the
    # scheduler read an undefined `train_dataloader` (NameError).

    # ---- LoRA adapters -----------------------------------------------------
    peft_config = LoraConfig(
        lora_alpha=32,
        lora_dropout=0.1,
        r=16,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config)

    # ---- Training ----------------------------------------------------------
    # NOTE(review): with these settings the recorded run hit CUDA OOM on a
    # 24 GiB GPU that another process was also using. Lower
    # per_device_train_batch_size (raising gradient_accumulation_steps to keep
    # the effective batch) or max_seq_length if that recurs.
    args = TrainingArguments(
        output_dir="llama-sft",
        num_train_epochs=1,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        optim="paged_adamw_32bit",
        logging_steps=5,
        learning_rate=5e-5,
        fp16=True,
        max_grad_norm=0.3,
        # NOTE(review): the "constant" schedule does not apply warmup, so
        # warmup_ratio has no effect; "constant_with_warmup" would honor it.
        warmup_ratio=0.03,
        lr_scheduler_type="constant",
        disable_tqdm=False,
        report_to="tensorboard",
        save_steps=100,
        save_total_limit=5,
    )

    max_seq_length = 1024  # max sequence length for model and packing of the dataset

    # `model` is already a PeftModel, so no peft_config is passed again here.
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        dataset_text_field="text",
        args=args,
    )

    trainer.train()

    return model

In [3]:
# class LLaMAModelWrapper(torch.nn.Module):
#     def __init__(self, model_name):
#         super().__init__()
#         self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
#         self.model.config.use_cache = False
#         self.model.config.pretraining_tp = 1
    
#     def forward(self, *args, **kwargs):
#         return self.model(*args, **kwargs)

# LLaMAModelWrapper(
#     base_model
# )


In [4]:
# Run fine-tuning in this kernel's own process and keep the resulting PEFT model.
# NOTE(review): this loads the model onto the GPU and so initializes CUDA in the
# notebook process; accelerate's notebook_launcher (later cell) is expected to
# refuse to start afterwards — restart the kernel before launching multi-process.
model = training_function()

  from .autonotebook import tqdm as notebook_tqdm


[2024-03-05 01:49:48,917] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


`low_cpu_mem_usage` was None, now set to True since model is quantized.
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.27s/it]
Generating train split: 478 examples [00:00, 828.22 examples/s]
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/peft/peft_model.py", line 1083, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1168, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 997, in forward
    layer_outputs = self._gradient_checkpointing_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 482, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 261, in forward
    outputs = run_function(*args)
              ^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 749, in forward
    hidden_states = self.mlp(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 236, in forward
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
                                                                ^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/bitsandbytes/nn/modules.py", line 256, in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py", line 577, in matmul_4bit
    return MatMul4Bit.apply(A, B, out, bias, quant_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tushar/.conda/envs/finetune/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py", line 516, in forward
    output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 344.00 MiB. GPU 0 has a total capacity of 23.68 GiB of which 6.69 MiB is free. Process 2646781 has 6.75 GiB memory in use. Including non-PyTorch memory, this process has 16.87 GiB memory in use. Of the allocated memory 14.59 GiB is allocated by PyTorch, and 1.85 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


In [None]:
from accelerate import notebook_launcher

# Launch `training_function` in 8 worker processes.
# NOTE(review): assumes the machine exposes 8 GPUs (one process each) — confirm.
# NOTE(review): launching is expected to fail if CUDA was already initialized in
# this kernel (e.g. by the single-process `training_function()` cell); restart first.
notebook_launcher(training_function, num_processes=8)

In [12]:
# deepspeed_args = {
#     "train_batch_size": 1,
#     "gradient_accumulation_steps": 1,
#     "gradient_clipping": 1.0,
#     "fp16": {
#         "enabled": True,
#         "loss_scale": 0,
#         "initial_scale_power": 16
#     }
# }

In [21]:
# model_engine, optimizer, _, _ = deepspeed.initialize(args=deepspeed_args,
#                                                      model=model,
#                                                      model_parameters=model.parameters(),
#                                                      config="ds_config.json")


In [17]:
# Manual DeepSpeed training loop.
# NOTE(review): `data_loader` and `model_engine` are not created by any live cell
# (the `deepspeed.initialize` cell is commented out); define them before running.
num_epochs = 10

# Make sure dropout etc. are active even if the engine was left in eval mode.
model_engine.train()

for _ in range(num_epochs):
    for batch in data_loader:
        # Assumes `batch` is a dict with 'input_ids' and 'attention_mask'.
        input_ids = batch['input_ids'].to(model_engine.device)
        attention_mask = batch['attention_mask'].to(model_engine.device)

        # Causal-LM targets are the inputs themselves, but padded positions are
        # set to -100 so the loss ignores them instead of learning to emit padding.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        # Forward pass
        outputs = model_engine(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Backward pass: the engine's backward/step replace the usual
        # loss.backward() / optimizer.step() / zero_grad() sequence.
        model_engine.backward(loss)
        model_engine.step()

OutOfMemoryError: CUDA out of memory. Tried to allocate 172.00 MiB. GPU 0 has a total capacity of 23.68 GiB of which 71.25 MiB is free. Including non-PyTorch memory, this process has 23.54 GiB memory in use. Of the allocated memory 22.93 GiB is allocated by PyTorch, and 36.80 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [18]:
# Display the device the DeepSpeed engine's model lives on.
model_engine.device

device(type='cuda', index=0)

In [None]:
# Save a DeepSpeed checkpoint of the engine into the directory named by `new_model`.
model_engine.save_checkpoint(new_model)
