In [None]:
import torch
import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import Trainer, TrainingArguments
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

import utils

In [None]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and ``requires_grad``.

    Output:
        One line on stdout with the trainable parameter count, the total
        parameter count and the trainable percentage. A model without any
        parameters reports ``0.0`` percent instead of raising
        ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        # bitsandbytes 4-bit weights pack two values per byte of storage, so
        # numel() reports the packed size; scale back to the logical count.
        if param.__class__.__name__ == "Params4bit":
            element_size = getattr(param, "element_size", None)
            bytes_per_element = element_size() if callable(element_size) else 1
            num_params *= 2 * bytes_per_element
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    # Guard the ratio so an empty model does not divide by zero.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

### Load base model with quantization settings

In [None]:
# Hub identifier of the checkpoint that both the tokenizer and the model load.
base_model = 'meta-llama/Llama-2-7b-chat-hf'

# bitsandbytes settings: 4-bit NF4 weights with double quantization enabled,
# and bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(base_model)

# The device map {"": 0} assigns the root module (i.e. the whole model) to device 0.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map={"": 0},
    quantization_config=bnb_config,
)

In [None]:
# Turn on gradient checkpointing on the loaded model before handing it to peft.
model.gradient_checkpointing_enable()
# peft helper that readies the quantized model for training; the returned
# object replaces `model` for every later cell.
model = prepare_model_for_kbit_training(model)

### Fine-tune only the adapter (LoRA)

In [None]:
# LoRA adapter settings: rank-8 updates scaled by alpha=32, attached to the
# query/key/value projection modules, with 5% dropout and no bias terms trained.
config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=['q_proj', 'k_proj', 'v_proj'],
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the model with the adapter, then report how many parameters will train.
model = get_peft_model(model, config)
print_trainable_parameters(model)

### Add special tokens and update embeddings

In [None]:
# The tokenizer may ship without a padding token; in that case register one
# through the project helper, which receives both tokenizer and model.
if tokenizer.pad_token is None:
    utils.resize_token_embedding(
        special_tokens_dict=dict(pad_token=utils.PAD_TOKEN),
        tokenizer=tokenizer,
        model=model,
    )

# add_special_tokens returns how many tokens were actually new to the vocabulary.
num_added_tokens = tokenizer.add_special_tokens(
    {
        "eos_token": utils.EOS_TOKEN,
        "bos_token": utils.BOS_TOKEN,
        "unk_token": utils.UNK_TOKEN,
    }
)

# Any genuinely new token gets an id beyond the current embedding matrix, which
# would fail at lookup time. Grow the matrix only when needed (never shrink it).
# NOTE(review): rows created here are freshly initialised — confirm that is
# acceptable if utils.*_TOKEN are not already in the base vocabulary.
if num_added_tokens > 0 and len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

### Prepare MedQuAD

In [None]:
# Location of the pickled question/answer DataFrame.
# NOTE: unpickling executes arbitrary code — only load this file from a trusted source.
data_path = 'data/med_qa.merged.pkl.tar.gz'

# Read the pickled frame and convert it into a `datasets.Dataset`.
ds_medqa = Dataset.from_pandas(pd.read_pickle(data_path))

# Tokenize the QA data with the project helper to obtain the training set.
ds_train = utils.tokenize_QA_for_llm(ds_medqa, tokenizer)

### Set hyperparameters and train adapter

In [None]:
# Hyperparameters for the adapter run. With a per-device batch of 2 and 64
# accumulation steps, each optimizer step covers 128 examples per device.
train_args = TrainingArguments(
    output_dir='outputs',
    overwrite_output_dir=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=64,
    learning_rate=2e-5,
    warmup_steps=24,
    max_steps=500,
    optim='paged_adamw_8bit',
    fp16=True,
    logging_steps=15,
)

# The project-specific collator is built from the same tokenizer used to
# tokenize the training set.
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=ds_train,
    data_collator=utils.DataCollatorForLLM(tokenizer),
)

# Switch off the model's `use_cache` config flag before training starts
# (presumably because gradient checkpointing is enabled — confirm).
model.config.use_cache = False
trainer.train()