In [19]:
import os
import torch
import gc
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

In [51]:
model_name = "meta-llama/Llama-2-7b-hf"  # or "meta-llama/Llama-2-7b-chat-hf"
new_model = "llama-2-7b-mcq-finetuned"

In [21]:
# Dataset files (local JSONL files)
train_file = "train_finetune.jsonl"
val_file = "valid_finetune.jsonl"

In [52]:

# Training Hyperparameters
local_rank = -1
per_device_train_batch_size = 4
per_device_eval_batch_size = 4
gradient_accumulation_steps = 1
learning_rate = 2e-4
max_grad_norm = 0.3
weight_decay = 0.001
num_train_epochs = 2
max_steps = -1  # -1 means train for num_train_epochs


In [23]:

# LoRA Parameters
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

In [24]:
# Quantization Parameters
use_4bit = True
use_nested_quant = False
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"

In [25]:
# Sequence and Training Settings
max_seq_length = 1024  # Increased for longer contexts
packing = False
gradient_checkpointing = True

In [26]:
# Precision Settings
fp16 = False
bf16 = False

In [27]:
# Optimizer and Scheduler
optim = "paged_adamw_32bit"
lr_scheduler_type = "cosine"
warmup_ratio = 0.03

In [32]:
# Logging and Saving
group_by_length = True
save_steps = 100
logging_steps = 10
output_dir = "llama2-mcq-finetuned"
report_to = "tensorboard"
tb_log_dir = "llama2-mcq-finetuned/logs"


In [33]:
# Device Configuration
device_map = {"": 0}

In [34]:
# NEFTune (adds noise to embeddings for better generalization)
use_neftune = True
neftune_noise_alpha = 5

In [53]:
print("="*80)
print("CONFIGURATION LOADED")
print("="*80)
print(f"Model: {model_name}")
print(f"Train file: {train_file}")
print(f"Val file: {val_file}")
print(f"Output dir: {output_dir}")
print("="*80 + "\n")

CONFIGURATION LOADED
Model: meta-llama/Llama-2-7b-hf
Train file: train_finetune.jsonl
Val file: valid_finetune.jsonl
Output dir: llama2-mcq-finetuned



In [55]:
print("Loading datasets from local JSONL files...")

train_dataset = load_dataset("json", data_files=train_file, split="train")
val_dataset = load_dataset("json", data_files=val_file, split="train")

Loading datasets from local JSONL files...


In [59]:
# Shuffle datasets
train_dataset = train_dataset.shuffle(seed=42)
val_dataset = val_dataset.shuffle(seed=42)

In [57]:
print(f"✓ Train samples: {len(train_dataset)}")
print(f"✓ Validation samples: {len(val_dataset)}")
print(f"✓ Columns: {train_dataset.column_names}\n")

✓ Train samples: 8653
✓ Validation samples: 2528
✓ Columns: ['Context', 'question', 'answerChoices', 'correctAnswer']



In [60]:
# Print a sample
print("="*80)
print("SAMPLE DATA:")
print("="*80)
sample = train_dataset[0]
print(f"Context (first 200 chars): {sample['Context'][:200]}...")
print(f"Question: {sample['question']}")
print(f"Answer Choices: {sample['answerChoices']}")
print(f"Correct Answer: {sample['correctAnswer']}")
print("="*80 + "\n")

SAMPLE DATA:
Context (first 200 chars): Fossils are the preserved remains of animals, plants, and other organisms from the distant past. Examples of fossils include bones, teeth, and impressions. By studying fossils, evidence for evolution ...
Question: the fossil record provides evidence for
Answer Choices: (A) when organisms lived on earth.. (B) how some species have gone extinct.. (C) how species evolved.. (D) all of the above.
Correct Answer: D



In [61]:
# ============================================================================
# LOAD MODEL AND TOKENIZER
# ============================================================================
def load_model(model_name):
    """Load tokenizer and model with QLoRA configuration"""
    
    # Compute dtype
    compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

    # Quantization Config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant,
    )

    # Check for bfloat16 support
    if compute_dtype == torch.float16 and use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16, you can accelerate training with bf16=True")
            print("=" * 80)

    print(f"Loading model: {model_name}...")
    
    # Load Model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device_map,
        quantization_config=bnb_config,
        trust_remote_code=True,
    )

    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # Load LoRA configuration
    peft_config = LoraConfig(
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        r=lora_r,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Load Tokenizer
    print(f"Loading tokenizer: {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    
    print("✓ Model and tokenizer loaded successfully!\n")

    return model, tokenizer, peft_config


In [62]:
# Load model, tokenizer, and peft config
model, tokenizer, peft_config = load_model(model_name)


Your GPU supports bfloat16, you can accelerate training with bf16=True
Loading model: meta-llama/Llama-2-7b-hf...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading tokenizer: meta-llama/Llama-2-7b-hf...
✓ Model and tokenizer loaded successfully!



In [None]:
# ============================================================================
# DATA FORMATTING FUNCTION
# ============================================================================
def formatting_func(example):
    """
    Format examples for Llama2 instruction fine-tuning
    Using Llama2-chat format with system prompt
    """
    output_texts = []
    
    for i in range(len(example['question'])):
        context = str(example["Context"][i])
        options = str(example["answerChoices"][i])
        question = str(example["question"][i])
        answer = str(example["correctAnswer"][i])
        
        # Llama2-chat format with system prompt
        template = f"""<s>[INST]<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible using the context text provided. Your answer must be exactly one letter corresponding to the correct option. Do not include any explanation, punctuation, or extra text.\n<</SYS>>

        Context: {context}
        Question: {question}
        Options: {options}
        Answer: [/INST]{answer}</s>"""
    
        
        output_texts.append(template)
    
    return output_texts


In [47]:
# ============================================================================
# DATA COLLATOR (For Completion Only)
# ============================================================================
# This ensures we only compute loss on the answer portion, not the instruction
response_template = "[/INST]"
collator = DataCollatorForCompletionOnlyLM(
    tokenizer.encode(response_template, add_special_tokens=False)[2:], 
    tokenizer=tokenizer
)

In [48]:

print("✓ Data formatting and collator configured\n")

✓ Data formatting and collator configured



In [49]:
# ============================================================================
# TRAINING ARGUMENTS
# ============================================================================
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to=report_to,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    auto_find_batch_size=True,  # Automatically find optimal batch size if OOM
    save_total_limit=2,
)

print("✓ Training arguments configured\n")

✓ Training arguments configured



In [63]:
# ============================================================================
# INITIALIZE TRAINER (Updated for TRL 1.0.0+ compatibility)
# ============================================================================
print("Initializing SFTTrainer...")

# Try using SFTConfig (for newer TRL versions)
try:
    from trl import SFTConfig
    
    sft_config = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim=optim,
        save_steps=save_steps,
        logging_steps=logging_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        fp16=fp16,
        bf16=bf16,
        max_grad_norm=max_grad_norm,
        max_steps=max_steps,
        warmup_ratio=warmup_ratio,
        group_by_length=group_by_length,
        lr_scheduler_type=lr_scheduler_type,
        report_to=report_to,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        auto_find_batch_size=True,
        save_total_limit=2,
        # SFT-specific parameters
        max_seq_length=max_seq_length,
        packing=packing,
        neftune_noise_alpha=neftune_noise_alpha if use_neftune else None,
    )
    
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        tokenizer=tokenizer,
        args=sft_config,
        data_collator=collator,
        formatting_func=formatting_func,
    )
    print("✓ Using SFTConfig (TRL 1.0.0+ compatible)")
    
except ImportError:
    # Fallback for older TRL versions
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=packing,
        data_collator=collator,
        formatting_func=formatting_func,
        neftune_noise_alpha=neftune_noise_alpha if use_neftune else None,
    )
    print("✓ Using legacy SFTTrainer initialization (pre-1.0.0)")

print("✓ Trainer initialized successfully!\n")

Initializing SFTTrainer...


Map:   0%|          | 0/8653 [00:00<?, ? examples/s]

Map:   0%|          | 0/2528 [00:00<?, ? examples/s]

✓ Using SFTConfig (TRL 1.0.0+ compatible)
✓ Trainer initialized successfully!



In [64]:
# ============================================================================
# TRAIN THE MODEL
# ============================================================================
print("\n" + "="*80)
print("STARTING TRAINING")
print("="*80 + "\n")




STARTING TRAINING



In [65]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,0.7874,2.393914
2,0.4038,2.584598
3,0.5095,2.93459




KeyError: 'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight'

In [1]:
"""
Recover the best model from saved checkpoints
When load_best_model_at_end fails due to 4-bit quantization issues
"""

import os
import json
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# ============================================================================
# CONFIGURATION
# ============================================================================
OUTPUT_DIR = "./llama2-mcq-finetuned"
BEST_MODEL_DIR = "./llama2-mcq-best"

print("="*80)
print("FINDING BEST CHECKPOINT")
print("="*80)

# ============================================================================
# FIND ALL CHECKPOINTS AND THEIR METRICS
# ============================================================================
checkpoints_info = []

if os.path.exists(OUTPUT_DIR):
    checkpoints = [d for d in os.listdir(OUTPUT_DIR) if d.startswith("checkpoint-")]
    checkpoints.sort(key=lambda x: int(x.split("-")[1]))
    
    print(f"Found {len(checkpoints)} checkpoints:\n")
    
    for checkpoint in checkpoints:
        checkpoint_path = os.path.join(OUTPUT_DIR, checkpoint)
        trainer_state_path = os.path.join(checkpoint_path, "trainer_state.json")
        
        if os.path.exists(trainer_state_path):
            with open(trainer_state_path) as f:
                state = json.load(f)
                
                # Get epoch and eval loss
                epoch = state.get("epoch")
                
                # Find eval loss from log history
                eval_loss = None
                for log_entry in reversed(state.get("log_history", [])):
                    if "eval_loss" in log_entry:
                        eval_loss = log_entry["eval_loss"]
                        break
                
                checkpoints_info.append({
                    'checkpoint': checkpoint,
                    'path': checkpoint_path,
                    'epoch': epoch,
                    'eval_loss': eval_loss,
                    'step': int(checkpoint.split("-")[1])
                })
                
                print(f"  {checkpoint}")
                print(f"    Epoch: {epoch}")
                print(f"    Eval Loss: {eval_loss}")
                print()

# ============================================================================
# FIND BEST CHECKPOINT (LOWEST EVAL LOSS)
# ============================================================================
if checkpoints_info:
    # Filter out checkpoints without eval_loss
    valid_checkpoints = [c for c in checkpoints_info if c['eval_loss'] is not None]
    
    if valid_checkpoints:
        best_checkpoint = min(valid_checkpoints, key=lambda x: x['eval_loss'])
        
        print("="*80)
        print("BEST CHECKPOINT IDENTIFIED")
        print("="*80)
        print(f"Checkpoint: {best_checkpoint['checkpoint']}")
        print(f"Epoch: {best_checkpoint['epoch']}")
        print(f"Eval Loss: {best_checkpoint['eval_loss']}")
        print(f"Path: {best_checkpoint['path']}")
        print("="*80 + "\n")
        
        # ====================================================================
        # LOAD AND SAVE BEST MODEL
        # ====================================================================
        print("Loading best checkpoint...")
        
        model = AutoPeftModelForCausalLM.from_pretrained(
            best_checkpoint['path'],
            device_map="auto",
            torch_dtype="auto",
            low_cpu_mem_usage=True
        )
        
        print("✓ Model loaded successfully!")
        
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(best_checkpoint['path'])
        
        # Save to new location
        print(f"\nSaving best model to: {BEST_MODEL_DIR}")
        model.save_pretrained(BEST_MODEL_DIR)
        tokenizer.save_pretrained(BEST_MODEL_DIR)
        
        print("✓ Best model saved!")
        
        # Save metadata
        metadata = {
            'source_checkpoint': best_checkpoint['checkpoint'],
            'epoch': best_checkpoint['epoch'],
            'eval_loss': best_checkpoint['eval_loss'],
            'step': best_checkpoint['step']
        }
        
        with open(os.path.join(BEST_MODEL_DIR, "best_model_info.json"), "w") as f:
            json.dump(metadata, f, indent=2)
        
        print(f"✓ Metadata saved to {BEST_MODEL_DIR}/best_model_info.json")
        
        print("\n" + "="*80)
        print("SUCCESS!")
        print("="*80)
        print(f"Your best model is saved at: {BEST_MODEL_DIR}")
        print("\nTo use it:")
        print("  from peft import AutoPeftModelForCausalLM")
        print(f"  model = AutoPeftModelForCausalLM.from_pretrained('{BEST_MODEL_DIR}')")
        print("="*80)
        
    else:
        print("❌ No checkpoints with evaluation metrics found!")
else:
    print(f"❌ No checkpoints found in {OUTPUT_DIR}")



FINDING BEST CHECKPOINT
Found 2 checkpoints:

  checkpoint-8653
    Epoch: 1.0
    Eval Loss: 2.393913984298706

  checkpoint-25959
    Epoch: 3.0
    Eval Loss: 2.9345898628234863

BEST CHECKPOINT IDENTIFIED
Checkpoint: checkpoint-8653
Epoch: 1.0
Eval Loss: 2.393913984298706
Path: ./llama2-mcq-finetuned/checkpoint-8653

Loading best checkpoint...




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Model loaded successfully!

Saving best model to: ./llama2-mcq-best
✓ Best model saved!
✓ Metadata saved to ./llama2-mcq-best/best_model_info.json

SUCCESS!
Your best model is saved at: ./llama2-mcq-best

To use it:
  from peft import AutoPeftModelForCausalLM
  model = AutoPeftModelForCausalLM.from_pretrained('./llama2-mcq-best')


In [None]:
#need to change the path
# ============================================================================
# EVALUATE THE MODEL
# ============================================================================
print("Evaluating model on validation set...")
evaluation_results = trainer.evaluate()
print("\n" + "="*80)
print("EVALUATION RESULTS:")
print("="*80)
for key, value in evaluation_results.items():
    print(f"{key}: {value}")
print("="*80 + "\n")

In [None]:
# OPTIONAL: MERGE AND SAVE FULL MODEL (For Inference)
# ============================================================================
# Note: This step requires ~14GB VRAM for Llama-2-7b
# You can skip this and merge later on your local machine

MERGE_NOW = False  # Set to True if you want to merge on cloud

if MERGE_NOW:
    print("Merging LoRA weights with base model for inference...")
    
    # Clear VRAM
    del model
    del trainer
    gc.collect()
    torch.cuda.empty_cache()
    
    # Reload base model in FP16
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )
    
    # Load and merge LoRA weights
    model = PeftModel.from_pretrained(base_model, output_dir)
    model = model.merge_and_unload()
    
    # Reload tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    
    # Save merged model
    merged_output_dir = f"{output_dir}_merged"
    print(f"Saving merged model to {merged_output_dir}...")
    model.save_pretrained(merged_output_dir)
    tokenizer.save_pretrained(merged_output_dir)
    
    print(f"✓ Merged model saved to: {merged_output_dir}")
else:
    print("Skipping merge step. You can merge later on your local machine.")

print("\n" + "="*80)
print("ALL DONE!")
print("="*80)
print(f"✓ LoRA adapter saved to: {output_dir}")
print(f"  Size: ~100-400MB (download this!)")
print("\nTo use the model:")
print("\n1. With LoRA adapter (recommended for now):")
print(f"   from peft import AutoPeftModelForCausalLM")
print(f"   model = AutoPeftModelForCausalLM.from_pretrained('{output_dir}')")
print("\n2. Or merge later on your local machine (see code comments)")
print("="*80)

In [None]:
#do not use yet.
from tqdm import tqdm

def evaluate_mcq(model, tokenizer, dataset, max_samples=100):
    correct = 0
    total = min(len(dataset), max_samples)
    
    for i in tqdm(range(total)):
        example = dataset[i]
        context = example["Context"]
        question = example["question"]
        choices = example["answerChoices"]
        correct_answer = example["correctAnswer"]

        prompt = f"""<s>[INST]<<SYS>>
You are a helpful assistant that answers multiple choice questions.
<</SYS>>

Context:
{context}

Question:
{question}

Options:
{choices}
Your answers should only be the choice from the given multiple Options and not have any text after the answer is done.
Your answer: [/INST]"""

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=5)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        pred = decoded.strip()[-1].upper()

        if pred == correct_answer.upper():
            correct += 1
    
    return correct / total
