In [None]:
# Import libraries
import os
import torch
import torch.nn as nn

# We will be using hugging face transformers library here
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig

# For PEFT we will be using peft library
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 

# datasets library for loading data
from datasets import load_dataset

# We will be using "trl" library for SFT finetuning 
# SFT stands for Supervised Fine-Tuning
from trl import SFTTrainer, SFTConfig

# Login to Hugging Face
# We need to use HF token to access certain models 
from huggingface_hub import login
# Never hardcode an access token in the notebook: read it from the HF_TOKEN
# environment variable instead. When the variable is unset or empty the token
# is None, which makes `login` fall back to its interactive prompt.
hf_token = os.environ.get("HF_TOKEN") or None
login(token=hf_token)

In [1]:
'''
What is SFT Finetuning
'''
# Supervised Fine-Tuning (SFT) involves training a pretrained LLM on a dataset of input-output pairs, where each example includes a prompt (instruction, query, etc.) 
# and a desired response. This helps the model learn how to better follow instructions or generate task-specific outputs.

# SFT is a commonly used technique in training/fine-tuning large language models (LLMs), especially for instruction-following tasks.

'''
What is bitsandbytes
'''
# The bitsandbytes library is an open-source, CUDA-based optimization library (low-bit quantization and 8-bit optimizers) used to reduce GPU memory usage and increase training speed.

'\nWhat is bitsandbytes\n'

In [None]:
# Checkpoint that will be fine-tuned
model_name = "microsoft/Phi-3-mini-4k-instruct"

# Where downloaded tokenizer/model files are cached; adjust if needed
save_directory = "./cache"

# Request accelerated model, dataset, and tokenizer downloads from the Hugging Face Hub.
# NOTE(review): huggingface_hub appears to read this variable when it is first
# imported, and it is already imported by the first cell -- confirm the flag
# actually takes effect when set here.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Dataset used for fine-tuning
ds = load_dataset("ruslanmv/ai-medical-chatbot")

In [None]:
'''
Why "nf4" over "fp4" for 4-bit quantization
'''
# Compared to fp4 (4-bit floating point), nf4 (4-bit NormalFloat) gives:
#   1. Better distribution of values
#   2. Better downstream performance

In [None]:
# BitsAndBytes quantization settings used when the model is loaded
bnb_config = BitsAndBytesConfig(
    # Load the model weights in 4-bit precision instead of full 16/32-bit
    load_in_4bit=True,
    # Double quantization for even better compression
    bnb_4bit_use_double_quant=True,
    # nf4 = NormalFloat 4-bit, a more accurate 4-bit quantization format
    bnb_4bit_quant_type="nf4",
    # bfloat16 is the dtype for forward/backward computation, not for weight storage
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Tokenizer for the chosen checkpoint, cached under save_directory
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    cache_dir=save_directory,
)

# Base model, loaded with the BitsAndBytes quantization config defined above
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced",
    trust_remote_code=True,
    cache_dir=save_directory,
)

# Tokenizer padding: reuse the EOS token when no pad token is defined, pad on the right
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Ready the quantized model for k-bit training
model = prepare_model_for_kbit_training(model)

# LoRA settings. 'qkv_proj' and 'o_proj' are good default targets for Phi-3;
# 'all-linear' is also an option. analyze_model_layers (defined further down)
# can be used to look for other candidates.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["qkv_proj", "o_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the model with the LoRA adapters and report how many parameters are trainable
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# --- analyze_model_layers function (good to keep) ---
def _unwrap_adapter_name(module_name):
    """Drop adapter-internal path components (``base_layer``, ``lora_*``) and everything after them."""
    parts = module_name.split('.')
    for index, part in enumerate(parts):
        if part == 'base_layer' or part.startswith('lora_'):
            return '.'.join(parts[:index])
    return module_name


def _suggest_target_modules(linear_layers):
    """Return a sorted, de-duplicated list of leaf module names to suggest as LoRA targets."""
    # Phi-3 style projection names first, then more generic attention/MLP names
    known_patterns = (
        'qkv_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj',
        'query', 'key', 'value', 'dense',
    )
    # Map adapter sub-modules back to the layer they wrap so that names such as
    # 'base_layer' or an adapter key like 'default' are never suggested.
    targets = [_unwrap_adapter_name(layer) for layer in linear_layers]

    matched = {
        target.split('.')[-1]
        for target in targets
        if target and any(pattern in target for pattern in known_patterns)
    }
    if matched:
        return sorted(matched)

    # Fallback: leaf names of the first few linear layers
    return sorted({target.split('.')[-1] for target in targets[:5] if target})


def analyze_model_layers(model, model_label=None):
    """
    Analyze model layers and suggest target_modules

    Walks ``model.named_modules()`` and sorts module names into mutually
    exclusive categories, checked in this order: ``nn.Linear`` instances,
    names containing "attention", names containing "embed". A summary of each
    category and a suggested ``target_modules`` list are printed.

    Args:
        model: The ``nn.Module`` to inspect. It may already carry LoRA
            adapters; adapter sub-modules are folded back onto the layer they
            wrap when building the suggestion.
        model_label: Optional name shown in the report header. When omitted,
            the notebook-level ``model_name`` is used if it exists, otherwise
            the model's class name.

    Returns:
        Tuple ``(linear_layers, attention_layers, embedding_layers)`` of
        module-name lists, in traversal order.
    """
    if model_label is None:
        # Keep the original header text when run inside the notebook, without
        # failing if the global has not been defined.
        model_label = globals().get("model_name", type(model).__name__)

    # Categorize layers
    linear_layers = []
    attention_layers = []
    embedding_layers = []

    for name, module in model.named_modules():
        lowered = name.lower()
        if isinstance(module, nn.Linear):
            linear_layers.append(name)
        elif "attention" in lowered:
            attention_layers.append(name)
        elif "embed" in lowered:
            embedding_layers.append(name)

    print(f"Model: {model_label}")
    print("="*50)

    print("\n🎯 LINEAR LAYERS (Best for LoRA target_modules):")
    for layer in linear_layers[:10]:  # Show first 10
        print(f"  {layer}")
    if len(linear_layers) > 10:
        print(f"  ... and {len(linear_layers) - 10} more")

    print(f"\n📊 ATTENTION LAYERS ({len(attention_layers)}):")
    for layer in attention_layers[:5]:  # Show first 5
        print(f"  {layer}")
    if len(attention_layers) > 5:
        print(f"  ... and {len(attention_layers) - 5} more")

    print(f"\n📝 EMBEDDING LAYERS ({len(embedding_layers)}):")
    for layer in embedding_layers:
        print(f"  {layer}")

    # Sorted so the printed suggestion is the same on every run
    print(f"\n💡 SUGGESTED target_modules:")
    print(f"  {_suggest_target_modules(linear_layers)}")

    return linear_layers, attention_layers, embedding_layers
# --- End analyze_model_layers function ---

# Usage
# NOTE(review): `model` has already been wrapped by get_peft_model at this
# point, so the listing presumably includes adapter sub-modules as well as the
# base layers -- confirm this is the intended model to inspect.
linear_layers, attention_layers, embedding_layers = analyze_model_layers(model)

# Format instruction data (remains the same based on your preference)
def format_instruction_data(example):
    """Render one row's 'Patient' and 'Doctor' fields as a single chat transcript under "text"."""
    transcript_template = (
        "<|user|>\nPatient's input: {patient}\n<|end|>\n"
        "<|assistant|>\n{doctor}\n<|end|>"
    )
    transcript = transcript_template.format(
        patient=example['Patient'],
        doctor=example['Doctor'],
    )
    return {"text": transcript}

# Run the formatter over the dataset one example at a time (not batched)
formatted_dataset = ds.map(function=format_instruction_data, batched=False)

# Preview the formatted text of the first training example
print(formatted_dataset["train"][0]["text"])

# Enable gradient checkpointing for memory efficiency during training
model.gradient_checkpointing_enable()

# Single source of truth for the training hyperparameters.
# A separate TrainingArguments object used to be built here but was never
# handed to the trainer, while the trainer received a second, inline SFTConfig
# with different values. The values kept below are the ones from that inline
# SFTConfig, i.e. the ones the trainer actually used. The dropped, never-applied
# alternatives were: optim="paged_adamw_8bit", learning_rate=2e-5,
# lr_scheduler_type="cosine", batch size 1 with 2 accumulation steps,
# max_grad_norm=0.3, warmup_ratio=0.03 and TensorBoard reporting -- set them
# here if that configuration is the one you want.
training_args = SFTConfig(
    output_dir="model_traning_outputs",
    num_train_epochs=3,  # Set this for 1 full training run.
    # max_steps=60,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    warmup_steps=5,
    weight_decay=0.01,
    optim="adamw_8bit",
    fp16=False,  # Keep False when bf16=True
    bf16=True,   # Requires a GPU with bfloat16 support
    seed=3407,
    report_to="none",
    max_seq_length=2048,
    dataset_text_field="text",
    dataset_num_proc=4,
    packing=False,  # Can make training 5x faster for short sequences.
)

# `model` is already a PEFT model (wrapped by get_peft_model earlier), so no
# peft_config is passed here: handing the trainer both an adapter-wrapped model
# and a LoRA config is redundant.
trainer = SFTTrainer(
    model=model,
    train_dataset=formatted_dataset['train'],
    args=training_args,
)

# Start training
trainer.train()

# Optional: save the fine-tuned model (LoRA adapters) and the tokenizer.
# This step is currently disabled; uncomment the lines below to enable it.
# output_dir = "./phi3_doctor_response_finetuned_adapters_bnb" # Adjusted output dir name
# os.makedirs(output_dir, exist_ok=True)
# trainer.model.save_pretrained(output_dir)
# tokenizer.save_pretrained(output_dir)

# print(f"Fine-tuned model adapters saved to {output_dir}")

In [None]:
'''
What is optimizer pagination (PagedAdam8bit, PagedLion8bit etc)
'''
# - Paged optimizers in the bitsandbytes library are memory-optimized variants of 8-bit optimizers designed to efficiently handle very large models by 
#   paging optimizer states in and out of memory. 
# 
# - Paging is a memory management technique where data (like optimizer states) is moved between fast memory (GPU VRAM) and slower memory (CPU RAM) as needed.


''