In [25]:
# !pip install -U \
#     torch==2.3.0+cu121 \
#     torchvision==0.18.0+cu121 \
#     torchaudio==2.3.0+cu121 \
#     bitsandbytes==0.43.3 \
#     triton==2.3.0 \
#     peft==0.10.0 \
#     trl==0.9.6 \
#     transformers==4.37.2 \
#     accelerate==0.27.2 \
#     datasets==2.16.0 \
#     evaluate==0.4.2 \
#     tensorboard==2.20.0 \
#     scipy==1.11.4 \
#     pandas==2.1.4 \
#     tqdm==4.67.1 \
#     --extra-index-url https://download.pytorch.org/whl/cu121


In [26]:
"""
Llama-2 7B Fine-tuning with RAG-Augmented Context
For Multiple Choice Question Answering

Dataset format expected:
{
  "question": "...",
  "answerChoices": "(A)... (B)... (C)... (D)...",
  "correctAnswer": "D",
  "context": "Retrieved passage 1...\n\nRetrieved passage 2..."
}

Files needed:
- train_with_context.jsonl
- test_with_context.jsonl  
- val_with_context.jsonl
"""

# ============================================================================
# CELL 1: IMPORTS
# ============================================================================
import os
import torch
import gc
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

In [27]:
# ============================================================================
# CELL 2: MODEL CONFIGURATION
# ============================================================================
model_name = "meta-llama/Llama-2-7b-hf"  # or "meta-llama/Llama-2-7b-chat-hf"
new_model = "llama-2-7b-mcq-rag-finetuned"

In [28]:
# ============================================================================
# CELL 3: DATASET FILES (CHANGED FOR RAG DATASETS)
# ============================================================================
# Dataset files (RAG-augmented with context field)
train_file = "train_fine_tune_with_context.jsonl"
val_file = "valid_fine_tune_with_context.jsonl"

In [29]:
# ============================================================================
# CELL 4: TRAINING HYPERPARAMETERS (KEPT ORIGINAL)
# ============================================================================
local_rank = -1
per_device_train_batch_size = 4
per_device_eval_batch_size = 4
gradient_accumulation_steps = 1
learning_rate = 2e-4
max_grad_norm = 0.3
weight_decay = 0.001
num_train_epochs = 2
max_steps = -1  # -1 means train for num_train_epochs


In [30]:
# ============================================================================
# CELL 5: LORA PARAMETERS (KEPT ORIGINAL)
# ============================================================================
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

In [31]:
# ============================================================================
# CELL 6: QUANTIZATION PARAMETERS (KEPT ORIGINAL)
# ============================================================================
use_4bit = True
use_nested_quant = False
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"

In [32]:
# ============================================================================
# CELL 7: SEQUENCE AND TRAINING SETTINGS (KEPT ORIGINAL)
# ============================================================================
max_seq_length = 1024
packing = False
gradient_checkpointing = True

In [33]:
# ============================================================================
# CELL 8: PRECISION SETTINGS (KEPT ORIGINAL)
# ============================================================================
fp16 = False
bf16 = False

In [34]:
# ============================================================================
# CELL 9: OPTIMIZER AND SCHEDULER (KEPT ORIGINAL)
# ============================================================================
optim = "paged_adamw_32bit"
lr_scheduler_type = "cosine"
warmup_ratio = 0.03

In [35]:
# ============================================================================
# CELL 10: LOGGING AND SAVING (KEPT ORIGINAL)
# ============================================================================
group_by_length = True
save_steps = 100
logging_steps = 10
output_dir = "llama2-mcq-rag-finetuned"
report_to = "tensorboard"
tb_log_dir = "llama2-mcq-rag-finetuned/logs"

In [36]:
# ============================================================================
# CELL 11: DEVICE CONFIGURATION (KEPT ORIGINAL)
# ============================================================================
device_map = {"": 0}

In [37]:
#============================================================================
# CELL 12: NEFTUNE (KEPT ORIGINAL)
# ============================================================================
use_neftune = True
neftune_noise_alpha = 5

In [38]:
# ============================================================================
# CELL 13: PRINT CONFIGURATION
# ============================================================================
print("="*80)
print("CONFIGURATION LOADED")
print("="*80)
print(f"Model: {model_name}")
print(f"Train file: {train_file}")
print(f"Val file: {val_file}")
print(f"Output dir: {output_dir}")
print("="*80 + "\n")

CONFIGURATION LOADED
Model: meta-llama/Llama-2-7b-hf
Train file: train_fine_tune_with_context.jsonl
Val file: valid_fine_tune_with_context.jsonl
Output dir: llama2-mcq-rag-finetuned



In [39]:
# ============================================================================
# CELL 14: LOAD DATASETS (ADAPTED FOR RAG FILES)
# ============================================================================
print("Loading datasets from local JSONL files...")

train_dataset = load_dataset("json", data_files=train_file, split="train")
val_dataset = load_dataset("json", data_files=val_file, split="train")

# Shuffle datasets
train_dataset = train_dataset.shuffle(seed=42)
val_dataset = val_dataset.shuffle(seed=42)

print(f"✓ Train samples: {len(train_dataset)}")
print(f"✓ Validation samples: {len(val_dataset)}")
print(f"✓ Columns: {train_dataset.column_names}\n")

Loading datasets from local JSONL files...
✓ Train samples: 8653
✓ Validation samples: 2528
✓ Columns: ['question', 'answerChoices', 'correctAnswer', 'context']



In [40]:
# ============================================================================
# CELL 15: PRINT SAMPLE DATA (ADAPTED FOR 'context' FIELD)
# ============================================================================
print("="*80)
print("SAMPLE DATA:")
print("="*80)
sample = train_dataset[0]
print(f"Context (first 200 chars): {sample['context'][:200]}...")
print(f"Question: {sample['question']}")
print(f"Answer Choices: {sample['answerChoices']}")
print(f"Correct Answer: {sample['correctAnswer']}")
print("="*80 + "\n")

SAMPLE DATA:
Context (first 200 chars): The water cycle continuously recycles Earths water. Condensation plays an important role in this cycle. Find condensation in the water cycle Figure 1.3. It changes water vapor in the atmosphere to liq...
Question: Which statement about the water cycle is false?
Answer Choices: (A) The water cycle is a global cycle.. (B) The water cycle takes place only on and above Earths surface.. (C) In the water cycle, water exists in three different states.. (D) Water cycle processes include condensation..
Correct Answer: B



In [42]:
# ============================================================================
# CELL 16: LOAD MODEL AND TOKENIZER FUNCTION (KEPT ORIGINAL)
# ============================================================================
def load_model(model_name):
    """Load tokenizer and model with QLoRA configuration"""
    
    # Compute dtype
    compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

    # Quantization Config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=use_nested_quant,
    )

    # Check for bfloat16 support
    if compute_dtype == torch.float16 and use_4bit:
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            print("=" * 80)
            print("Your GPU supports bfloat16, you can accelerate training with bf16=True")
            print("=" * 80)

    print(f"Loading model: {model_name}...")
    
    # Load Model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device_map,
        quantization_config=bnb_config,
        trust_remote_code=True,
    )

    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # Load LoRA configuration
    peft_config = LoraConfig(
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        r=lora_r,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Load Tokenizer
    print(f"Loading tokenizer: {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    
    print("✓ Model and tokenizer loaded successfully!\n")

    return model, tokenizer, peft_config


In [43]:

# ============================================================================
# CELL 17: LOAD MODEL
# ============================================================================
model, tokenizer, peft_config = load_model(model_name)


Your GPU supports bfloat16, you can accelerate training with bf16=True
Loading model: meta-llama/Llama-2-7b-hf...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading tokenizer: meta-llama/Llama-2-7b-hf...
✓ Model and tokenizer loaded successfully!



In [44]:
# ============================================================================
# CELL 18: DATA FORMATTING FUNCTION (ADAPTED FOR 'context' FIELD)
# ============================================================================
def formatting_func(example):
    """
    Format examples for Llama2 instruction fine-tuning
    Using Llama2-chat format with system prompt
    
    KEY CHANGE: Uses 'context' field instead of 'Context'
    """
    output_texts = []
    
    for i in range(len(example['question'])):
        context = str(example["context"][i])  # Changed from "Context" to "context"
        options = str(example["answerChoices"][i])
        question = str(example["question"][i])
        answer = str(example["correctAnswer"][i])
        
        # Llama2-chat format with system prompt
        template = f"""<s>[INST]<<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible using the context text provided. Your answer must be exactly one letter corresponding to the correct option. Do not include any explanation, punctuation, or extra text.
<</SYS>>

Context: {context}

Question: {question}

Options: {options}

Answer: [/INST]{answer}</s>"""
    
        
        output_texts.append(template)
    
    return output_texts


In [45]:
# ============================================================================
# CELL 19: DATA COLLATOR (KEPT ORIGINAL)
# ============================================================================
# This ensures we only compute loss on the answer portion, not the instruction
response_template = "[/INST]"
collator = DataCollatorForCompletionOnlyLM(
    tokenizer.encode(response_template, add_special_tokens=False)[2:], 
    tokenizer=tokenizer
)

print("✓ Data formatting and collator configured\n")


✓ Data formatting and collator configured



In [46]:
# ============================================================================
# CELL 20: TRAINING ARGUMENTS (KEPT ORIGINAL)
# ============================================================================
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to=report_to,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    auto_find_batch_size=True,
    save_total_limit=2,
)

print("✓ Training arguments configured\n")


✓ Training arguments configured



In [47]:
# ============================================================================
# CELL 21: INITIALIZE TRAINER (KEPT ORIGINAL LOGIC)
# ============================================================================
print("Initializing SFTTrainer...")

try:
    from trl import SFTConfig
    
    sft_config = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim=optim,
        save_steps=save_steps,
        logging_steps=logging_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        fp16=fp16,
        bf16=bf16,
        max_grad_norm=max_grad_norm,
        max_steps=max_steps,
        warmup_ratio=warmup_ratio,
        group_by_length=group_by_length,
        lr_scheduler_type=lr_scheduler_type,
        report_to=report_to,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        auto_find_batch_size=True,
        save_total_limit=2,
        max_seq_length=max_seq_length,
        packing=packing,
        neftune_noise_alpha=neftune_noise_alpha if use_neftune else None,
    )
    
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        tokenizer=tokenizer,
        args=sft_config,
        data_collator=collator,
        formatting_func=formatting_func,
    )
    print("✓ Using SFTConfig (TRL 1.0.0+ compatible)")
    
except ImportError:
    # Fallback for older TRL versions
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=packing,
        data_collator=collator,
        formatting_func=formatting_func,
        neftune_noise_alpha=neftune_noise_alpha if use_neftune else None,
    )
    print("✓ Using legacy SFTTrainer initialization (pre-1.0.0)")

print("✓ Trainer initialized successfully!\n")

Initializing SFTTrainer...


Map:   0%|          | 0/2528 [00:00<?, ? examples/s]

✓ Using SFTConfig (TRL 1.0.0+ compatible)
✓ Trainer initialized successfully!



In [48]:
# ============================================================================
# CELL 22: START TRAINING
# ============================================================================
print("\n" + "="*80)
print("STARTING TRAINING")
print("="*80 + "\n")

trainer.train()


STARTING TRAINING



Epoch,Training Loss,Validation Loss
1,1.0957,0.61633
2,0.2112,0.650122




TrainOutput(global_step=8654, training_loss=0.5778177645450137, metrics={'train_runtime': 18713.2336, 'train_samples_per_second': 0.925, 'train_steps_per_second': 0.462, 'total_flos': 3.52288019874816e+17, 'train_loss': 0.5778177645450137, 'epoch': 2.0})

In [50]:
"""
Fixed Best Checkpoint Recovery
Handles 4-bit quantization issues when loading from checkpoint
"""

import os
import json
import torch
from transformers import AutoTokenizer

# ============================================================================
# CONFIGURATION
# ============================================================================
OUTPUT_DIR = "llama2-mcq-rag-finetuned"  # Your training output directory
BEST_MODEL_DIR = "./llama2-mcq-best"

print("="*80)
print("FINDING BEST CHECKPOINT")
print("="*80)

# ============================================================================
# FIND BEST CHECKPOINT
# ============================================================================
checkpoints_info = []

if os.path.exists(OUTPUT_DIR):
    checkpoints = [d for d in os.listdir(OUTPUT_DIR) if d.startswith("checkpoint-")]
    checkpoints.sort(key=lambda x: int(x.split("-")[1]))
    
    print(f"Found {len(checkpoints)} checkpoints:\n")
    
    for checkpoint in checkpoints:
        checkpoint_path = os.path.join(OUTPUT_DIR, checkpoint)
        trainer_state_path = os.path.join(checkpoint_path, "trainer_state.json")
        
        if os.path.exists(trainer_state_path):
            with open(trainer_state_path) as f:
                state = json.load(f)
                
                epoch = state.get("epoch")
                
                # Find eval loss from log history
                eval_loss = None
                for log_entry in reversed(state.get("log_history", [])):
                    if "eval_loss" in log_entry:
                        eval_loss = log_entry["eval_loss"]
                        break
                
                checkpoints_info.append({
                    'checkpoint': checkpoint,
                    'path': checkpoint_path,
                    'epoch': epoch,
                    'eval_loss': eval_loss,
                    'step': int(checkpoint.split("-")[1])
                })
                
                print(f"  {checkpoint}")
                print(f"    Epoch: {epoch}")
                print(f"    Eval Loss: {eval_loss}")
                print()

# ============================================================================
# IDENTIFY BEST CHECKPOINT
# ============================================================================
if checkpoints_info:
    valid_checkpoints = [c for c in checkpoints_info if c['eval_loss'] is not None]
    
    if valid_checkpoints:
        best_checkpoint = min(valid_checkpoints, key=lambda x: x['eval_loss'])
        
        print("="*80)
        print("BEST CHECKPOINT IDENTIFIED")
        print("="*80)
        print(f"Checkpoint: {best_checkpoint['checkpoint']}")
        print(f"Epoch: {best_checkpoint['epoch']}")
        print(f"Eval Loss: {best_checkpoint['eval_loss']}")
        print(f"Path: {best_checkpoint['path']}")
        print("="*80 + "\n")
        
        # ====================================================================
        # SOLUTION: JUST COPY THE CHECKPOINT FOLDER
        # The checkpoint already contains the adapter weights
        # No need to load and re-save with AutoPeftModel
        # ====================================================================
        
        print("Copying best checkpoint to deployment directory...")
        
        import shutil
        
        # Remove existing best model dir if it exists
        if os.path.exists(BEST_MODEL_DIR):
            shutil.rmtree(BEST_MODEL_DIR)
        
        # Copy the entire checkpoint folder
        shutil.copytree(best_checkpoint['path'], BEST_MODEL_DIR)
        
        print(f"✓ Best checkpoint copied to: {BEST_MODEL_DIR}")
        
        # ====================================================================
        # SAVE METADATA
        # ====================================================================
        metadata = {
            'source_checkpoint': best_checkpoint['checkpoint'],
            'epoch': best_checkpoint['epoch'],
            'eval_loss': best_checkpoint['eval_loss'],
            'step': best_checkpoint['step'],
            'model_type': 'LoRA adapter (4-bit quantized)',
            'usage': 'Use with AutoPeftModelForCausalLM.from_pretrained()'
        }
        
        with open(os.path.join(BEST_MODEL_DIR, "best_model_info.json"), "w") as f:
            json.dump(metadata, f, indent=2)
        
        print(f"✓ Metadata saved to {BEST_MODEL_DIR}/best_model_info.json")
        
        # ====================================================================
        # LIST FILES IN BEST MODEL DIR
        # ====================================================================
        print(f"\n✓ Files in {BEST_MODEL_DIR}:")
        for item in os.listdir(BEST_MODEL_DIR):
            item_path = os.path.join(BEST_MODEL_DIR, item)
            if os.path.isfile(item_path):
                size_mb = os.path.getsize(item_path) / (1024 * 1024)
                print(f"  - {item} ({size_mb:.2f} MB)")
            else:
                print(f"  - {item}/ (directory)")
        
        print("\n" + "="*80)
        print("SUCCESS!")
        print("="*80)
        print(f"✓ Best model saved at: {BEST_MODEL_DIR}")
        print(f"✓ Epoch: {best_checkpoint['epoch']}")
        print(f"✓ Validation Loss: {best_checkpoint['eval_loss']:.4f}")
        print("\n" + "="*80)
        print("HOW TO USE:")
        print("="*80)
        print("Download the entire folder: " + BEST_MODEL_DIR)
        print("\nOn your local machine:")
        print("```python")
        print("from peft import AutoPeftModelForCausalLM")
        print("from transformers import AutoTokenizer")
        print()
        print("# Load model")
        print(f"model = AutoPeftModelForCausalLM.from_pretrained(")
        print(f"    '{BEST_MODEL_DIR}',")
        print(f"    device_map='auto',")
        print(f"    torch_dtype='auto'")
        print(f")")
        print()
        print("# Load tokenizer")
        print(f"tokenizer = AutoTokenizer.from_pretrained('{BEST_MODEL_DIR}')")
        print()
        print("# Generate answer")
        print("inputs = tokenizer(prompt, return_tensors='pt').to(model.device)")
        print("outputs = model.generate(**inputs, max_new_tokens=5)")
        print("answer = tokenizer.decode(outputs[0], skip_special_tokens=True)")
        print("```")
        print("="*80)
        
    else:
        print("❌ No checkpoints with evaluation metrics found!")
else:
    print(f"❌ No checkpoints found in {OUTPUT_DIR}")

print("\n✓ Checkpoint recovery complete!")
print(f"✓ Ready to download: {BEST_MODEL_DIR}")

FINDING BEST CHECKPOINT
Found 2 checkpoints:

  checkpoint-4327
    Epoch: 1.0
    Eval Loss: 0.6163302659988403

  checkpoint-8654
    Epoch: 2.0
    Eval Loss: 0.6501219272613525

BEST CHECKPOINT IDENTIFIED
Checkpoint: checkpoint-4327
Epoch: 1.0
Eval Loss: 0.6163302659988403
Path: llama2-mcq-rag-finetuned/checkpoint-4327

Copying best checkpoint to deployment directory...
✓ Best checkpoint copied to: ./llama2-mcq-best
✓ Metadata saved to ./llama2-mcq-best/best_model_info.json

✓ Files in ./llama2-mcq-best:
  - README.md (0.00 MB)
  - tokenizer_config.json (0.00 MB)
  - tokenizer.json (1.76 MB)
  - special_tokens_map.json (0.00 MB)
  - rng_state.pth (0.01 MB)
  - optimizer.pt (256.08 MB)
  - adapter_config.json (0.00 MB)
  - best_model_info.json (0.00 MB)
  - trainer_state.json (0.05 MB)
  - adapter_model.safetensors (128.02 MB)
  - training_args.bin (0.00 MB)
  - scheduler.pt (0.00 MB)

SUCCESS!
✓ Best model saved at: ./llama2-mcq-best
✓ Epoch: 1.0
✓ Validation Loss: 0.6163

HOW TO U

In [51]:
import shutil

# Path to the folder you want to zip
folder_path = "llama2-mcq-best"

# Output zip file name (without .zip extension)
output_zip = "llama2-mcq-best-with-rag"

# Create the zip file
shutil.make_archive(output_zip, 'zip', folder_path)

print(f"✅ Folder zipped successfully: {output_zip}.zip")


✅ Folder zipped successfully: llama2-mcq-best-with-rag.zip
