# Kanana RAG Fine-tuning with Evaluation Metrics

This notebook fine-tunes Kanana 1.5 8B Instruct with LoRA on a RAG task (Jecheon tourism Q&A) and tracks several RAG-specific metrics:

**Metrics tracked:**
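1. **Train Loss** - Standard training loss
2. **Context Precision@3** - Whether the correct document is among the first 3 documents supplied with the example (a hit rate; depends on the data, not on the model)
3. **Context Recall** - Whether the correct document appears anywhere in the supplied documents (also independent of the model)
4. **ROUGE-L** - Lexical similarity with the reference answer
5. **BERTScore** - Semantic similarity with the reference answer
6. **Answer Relevance** - Keyword overlap between the answer and the question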

All metrics are logged to Weights & Biases during training.
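
As a toy illustration of the two context metrics, the sketch below checks whether a gold document id appears among the first k entries of a document list (Precision@k as used in this notebook) or anywhere in the list (Recall). The document ids are made up for the example; the `doc_id` field name follows the data format used later in the notebook.

```python
from typing import Dict, List

def hit_at_k(documents: List[Dict], correct_doc_id: str, k: int) -> float:
    """Return 1.0 if the gold document is among the first k documents, else 0.0."""
    return 1.0 if correct_doc_id in [d["doc_id"] for d in documents[:k]] else 0.0

def hit_anywhere(documents: List[Dict], correct_doc_id: str) -> float:
    """Return 1.0 if the gold document appears anywhere in the list, else 0.0."""
    return 1.0 if correct_doc_id in [d["doc_id"] for d in documents] else 0.0

# Hypothetical document list: the gold document "d4" sits in 4th position
docs = [{"doc_id": "d1"}, {"doc_id": "d2"}, {"doc_id": "d3"}, {"doc_id": "d4"}]
precision_at_3 = hit_at_k(docs, "d4", k=3)
recall = hit_anywhere(docs, "d4")
print(precision_at_3, recall)
```

Averaging these 0/1 values over the evaluation samples gives the reported scores.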

In [None]:
# Install required packages
!pip install -q transformers peft datasets wandb bitsandbytes accelerate
!pip install -q rouge-score bert-score evaluate scikit-learn

In [None]:
# Check CUDA availability
import os

# Select the GPU before importing torch so the setting is in place when CUDA initializes
os.environ["NVIDIA_VISIBLE_DEVICES"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Initialize Weights & Biases
import wandb

wandb.login()
wandb.init(
    project="kanana-rag-finetuning",
    name="kanana-rag-with-metrics",
    config={
        "model": "kakaocorp/kanana-1.5-8b-instruct-2505",
        "task": "RAG fine-tuning with evaluation metrics",
        "dataset": "Jecheon Tourism"
    }
)

In [None]:
# Load training data and split into train/test
import json
from sklearn.model_selection import train_test_split

data_path = "/home/user/goodganglabs/data/processed/training_data.jsonl"

# Load all data
data_list = []
with open(data_path, 'r', encoding='utf-8') as f:
    for line in f:
        data_list.append(json.loads(line))

# Split 80/20
train_data, test_data = train_test_split(
    data_list, 
    test_size=0.2, 
    random_state=42
)

print(f"Total: {len(data_list)} examples")
print(f"Train: {len(train_data)} examples")
print(f"Test: {len(test_data)} examples")

In [None]:
# Load model and tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_name = "kakaocorp/kanana-1.5-8b-instruct-2505"

print(f"Loading {model_name}...")
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto"
)

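tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fall back to EOS as the padding token only if the tokenizer does not define one
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token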

# Apply LoRA
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj"]
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

print("Model loaded!")
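
For intuition about what `r=8` and `lora_alpha=32` mean: LoRA adds two low-rank matrices to each adapted linear layer, `A` of shape `r x d_in` and `B` of shape `d_out x r`, and scales their product by `lora_alpha / r`. The sketch below does that arithmetic. The 4096 x 4096 layer shape is an illustrative assumption, not a value read from the Kanana config; `model.print_trainable_parameters()` reports the real totals.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one linear layer: A (r x d_in) plus B (d_out x r)."""
    return r * d_in + d_out * r

r, lora_alpha = 8, 32
scaling = lora_alpha / r                 # factor applied to the low-rank update B @ A
per_layer = lora_params(4096, 4096, r)   # assumed layer shape, for illustration only
print(f"scaling={scaling}, added params per adapted layer={per_layer}")
```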

In [None]:
# Define RAG Metrics
from rouge_score import rouge_scorer
from bert_score import score as bert_score_fn
import numpy as np
from typing import List, Dict

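class WhitespaceTokenizer:
    """Whitespace tokenizer for ROUGE.

    The default rouge_score tokenizer keeps only [a-z0-9] characters, so it
    would drop Hangul entirely and score Korean text as 0.
    """
    def tokenize(self, text):
        return text.split()

class RAGMetrics:
    def __init__(self):
        self.rouge_scorer = rouge_scorer.RougeScorer(
            ['rougeL'], use_stemmer=False, tokenizer=WhitespaceTokenizer()
        )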
    
    def context_precision_at_k(self, documents: List[Dict], correct_doc_id: str, k: int = 3) -> float:
        """Correct document in top-k?"""
        if not correct_doc_id:
            return 0.0
        doc_ids = [doc['doc_id'] for doc in documents[:k]]
        return 1.0 if correct_doc_id in doc_ids else 0.0
    
    def context_recall(self, documents: List[Dict], correct_doc_id: str) -> float:
        """Correct document retrieved?"""
        if not correct_doc_id:
            return 0.0
        doc_ids = [doc['doc_id'] for doc in documents]
        return 1.0 if correct_doc_id in doc_ids else 0.0
    
    def rouge_l(self, prediction: str, reference: str) -> float:
        """ROUGE-L F1 score"""
        scores = self.rouge_scorer.score(reference, prediction)
        return scores['rougeL'].fmeasure
    
    def bert_score(self, predictions: List[str], references: List[str]) -> float:
        """BERTScore F1 (batched)"""
        if not predictions or not references:
            return 0.0
        P, R, F1 = bert_score_fn(
            predictions, references,
            lang='ko', verbose=False,
            device='cuda' if torch.cuda.is_available() else 'cpu'
        )
        return F1.mean().item()
    
    def answer_relevance(self, answer: str, question: str) -> float:
        """Simple keyword overlap-based relevance"""
        if not answer or len(answer.strip()) < 10:
            return 0.0
        
        stop_words = {'은', '는', '이', '가', '을', '를', '의', '에', '에서', '로', '으로', '와', '과', '도', '만', '?'}
        q_words = set(question.split()) - stop_words
        a_words = set(answer.split()) - stop_words
        
        if not q_words:
            return 1.0
        
        overlap = len(q_words & a_words) / len(q_words)
        return min(overlap, 1.0)

rag_metrics = RAGMetrics()
print("RAG Metrics initialized!")
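
ROUGE-L is built on the longest common subsequence (LCS) of tokens shared by the prediction and the reference. The pure-Python sketch below is meant only to show what the number measures. It is not the `rouge_score` implementation, and splitting on whitespace is an assumption made for the example, as are the two toy sentences.

```python
from typing import List

def lcs_length(a: List[str], b: List[str]) -> int:
    """Length of the longest common subsequence of two token lists (dynamic programming)."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_f1(prediction: str, reference: str) -> float:
    """ROUGE-L F1 over whitespace tokens: harmonic mean of LCS precision and recall."""
    pred, ref = prediction.split(), reference.split()
    if not pred or not ref:
        return 0.0
    lcs = lcs_length(pred, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(pred), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)

# Toy sentences: 3 predicted tokens, 4 reference tokens, 2 tokens in common
score = rouge_l_f1("의림지는 제천시에 있습니다", "의림지는 제천시에 있는 저수지입니다")
print(f"{score:.4f}")
```

Here the LCS has length 2, so precision is 2/3, recall is 2/4, and F1 is 4/7.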

In [None]:
# Format data for training
from datasets import Dataset

INSTRUCTION = """당신은 제천시 관광 안내 전문가입니다.
제공된 여러 문서 중에서 질문과 관련된 문서를 찾아, 그 문서의 내용을 바탕으로 정확하고 친절하게 답변해주세요.

답변 시 주의사항:
1. 관련 문서의 내용만을 바탕으로 답변하세요
2. 문서에 정보가 없으면 "제공된 정보에는 해당 내용이 없습니다"라고 답변하세요
3. 추측하거나 문서 외부 지식을 사용하지 마세요
4. 간결하고 이해하기 쉽게 답변하세요"""

def format_example(example):
    # Build prompt
    info_sections = [f"Information:\n{doc['content']}" for doc in example['documents']]
    prompt = INSTRUCTION + "\n\n" + "\n\n".join(info_sections) + f"\n\nQuestion: {example['question']}"
    
    return {
        "prompt": prompt,
        "answer": example['answer'],
        "question": example['question'],
        "documents": example['documents'],
        "correct_doc_id": example.get('correct_doc_id', None)
    }

# Format train and test
train_formatted = [format_example(ex) for ex in train_data]
test_formatted = [format_example(ex) for ex in test_data]

# Create training dataset with text column
train_texts = [f"{ex['prompt']}\n\nAnswer: {ex['answer']}{tokenizer.eos_token}" for ex in train_formatted]
train_dataset = Dataset.from_dict({"text": train_texts})

print(f"Train: {len(train_formatted)}, Test: {len(test_formatted)}")

In [None]:
# Tokenize training data
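def tokenize(examples):
    tokens = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=2048
    )
    # Ignore padding positions in the loss (-100 is skipped by cross-entropy).
    # Mask by attention_mask rather than by token id: pad may equal eos,
    # and the real EOS at the end of each answer must still be learned.
    tokens["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    return tokens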

train_dataset_tokenized = train_dataset.map(tokenize, batched=True, remove_columns=["text"])

print(f"Tokenized training set: {len(train_dataset_tokenized)} examples")
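
Label positions set to `-100` are skipped by the cross-entropy loss, which is how padding is kept out of the training signal. The sketch below builds such labels from plain Python lists. The optional `prompt_len` argument, which also masks the prompt so that only answer tokens contribute to the loss, is a common refinement and not something this notebook does. The helper name and token ids are hypothetical.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss

def mask_labels(input_ids: List[int], attention_mask: List[int], prompt_len: int = 0) -> List[int]:
    """Copy input_ids to labels, ignoring padding and the first prompt_len tokens."""
    return [
        tok if (mask == 1 and i >= prompt_len) else IGNORE_INDEX
        for i, (tok, mask) in enumerate(zip(input_ids, attention_mask))
    ]

# Hypothetical ids: 3 prompt tokens, 2 answer tokens, 2 right-padding tokens
ids = [11, 12, 13, 21, 22, 0, 0]
attn = [1, 1, 1, 1, 1, 0, 0]

padding_only = mask_labels(ids, attn)                 # mask padding only
answer_only = mask_labels(ids, attn, prompt_len=3)    # also mask the prompt
print(padding_only)
print(answer_only)
```

The sketch assumes right padding, so that the prompt occupies the first `prompt_len` positions.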

In [None]:
# Define evaluation function (called periodically during training)
def evaluate_rag(model, test_data, tokenizer, num_samples=20):
    """Evaluate model on test set and return metrics"""
    model.eval()
    
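    # Sample a fixed subset for faster evaluation; the fixed seed keeps the
    # subset identical across evaluation steps so the metrics are comparable
    import random
    samples = random.Random(42).sample(test_data, min(num_samples, len(test_data)))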
    
    predictions = []
    references = []
    context_precisions = []
    context_recalls = []
    answer_relevances = []
    
    for sample in samples:
        # Generate answer; end the prompt with "Answer:" to match the training text format
        inputs = tokenizer(sample['prompt'] + "\n\nAnswer:", return_tensors="pt").to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=False,  # greedy decoding so evaluation is reproducible
                pad_token_id=tokenizer.eos_token_id
            )
        
        # Decode only the newly generated tokens (everything after the prompt)
        prompt_len = inputs["input_ids"].shape[1]
        pred_answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        
        predictions.append(pred_answer)
        references.append(sample['answer'])
        
        # Compute context metrics
        if sample['correct_doc_id']:
            context_precisions.append(
                rag_metrics.context_precision_at_k(sample['documents'], sample['correct_doc_id'], k=3)
            )
            context_recalls.append(
                rag_metrics.context_recall(sample['documents'], sample['correct_doc_id'])
            )
        
        # Compute answer relevance
        answer_relevances.append(
            rag_metrics.answer_relevance(pred_answer, sample['question'])
        )
    
    # Compute ROUGE-L (average over samples)
    rouge_scores = [rag_metrics.rouge_l(p, r) for p, r in zip(predictions, references)]
    
    # Compute BERTScore (batched)
    bert_score = rag_metrics.bert_score(predictions, references)
    
    metrics = {
        "eval_context_precision@3": np.mean(context_precisions) if context_precisions else 0.0,
        "eval_context_recall": np.mean(context_recalls) if context_recalls else 0.0,
        "eval_rouge_l": np.mean(rouge_scores),
        "eval_bert_score": bert_score,
        "eval_answer_relevance": np.mean(answer_relevances)
    }
    
    model.train()
    return metrics

print("Evaluation function defined!")

In [None]:
# Create custom callback for periodic evaluation
from transformers import TrainerCallback

class RAGEvaluationCallback(TrainerCallback):
    def __init__(self, test_data, tokenizer, eval_steps=50):
        self.test_data = test_data
        self.tokenizer = tokenizer
        self.eval_steps = eval_steps
    
    def on_step_end(self, args, state, control, model=None, **kwargs):
        # Evaluate every eval_steps
        if state.global_step % self.eval_steps == 0 and state.global_step > 0:
            print(f"\n[Step {state.global_step}] Running RAG evaluation...")
            metrics = evaluate_rag(model, self.test_data, self.tokenizer, num_samples=15)
            
            # Log to wandb
            wandb.log({**metrics, "step": state.global_step})
            
            print(f"Context Precision@3: {metrics['eval_context_precision@3']:.3f}")
            print(f"Context Recall: {metrics['eval_context_recall']:.3f}")
            print(f"ROUGE-L: {metrics['eval_rouge_l']:.3f}")
            print(f"BERTScore: {metrics['eval_bert_score']:.3f}")
            print(f"Answer Relevance: {metrics['eval_answer_relevance']:.3f}\n")

eval_callback = RAGEvaluationCallback(test_formatted, tokenizer, eval_steps=50)
print("Evaluation callback created!")

In [None]:
# Configure training
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./outputs/kanana-rag-metrics",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=10,
    bf16=True,
    logging_steps=5,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    report_to="wandb",
    seed=42
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_tokenized,
    callbacks=[eval_callback]  # Add evaluation callback
)

print("Trainer initialized with RAG evaluation callback!")
print(f"Will evaluate every {eval_callback.eval_steps} steps")
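
To estimate how many optimizer steps and evaluation runs these settings produce, the sketch below works through the arithmetic. The training-set size of 800 is a placeholder, since the real size is printed by the data-loading cell. The function name is hypothetical, and the exact rounding of steps per epoch depends on the `transformers` version; this sketch rounds up.

```python
import math

def training_schedule(num_examples: int, per_device_batch: int, grad_accum: int,
                      epochs: int, eval_steps: int, num_devices: int = 1) -> dict:
    """Estimate optimizer steps and the number of periodic evaluations."""
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "num_evals": total_steps // eval_steps,  # callback fires at multiples of eval_steps
    }

# 800 examples is an assumed size; the other values mirror the configuration above
plan = training_schedule(num_examples=800, per_device_batch=2, grad_accum=4,
                         epochs=3, eval_steps=50)
print(plan)
```

With a batch size of 2 and 4 accumulation steps on one GPU, each optimizer step sees 8 examples.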

In [None]:
# Train model
print("Starting training with RAG metrics tracking...\n")
print("="*60)

trainer_stats = trainer.train()

print("\n" + "="*60)
print("Training completed!")
print(f"Final train loss: {trainer_stats.training_loss:.4f}")
print(f"Training time: {trainer_stats.metrics['train_runtime']:.2f}s")

In [None]:
# Final full evaluation on test set
print("Running final evaluation on full test set...\n")

final_metrics = evaluate_rag(model, test_formatted, tokenizer, num_samples=len(test_formatted))

print("Final Test Metrics:")
print("=" * 40)
print(f"Context Precision@3: {final_metrics['eval_context_precision@3']:.4f}")
print(f"Context Recall:      {final_metrics['eval_context_recall']:.4f}")
print(f"ROUGE-L:             {final_metrics['eval_rouge_l']:.4f}")
print(f"BERTScore:           {final_metrics['eval_bert_score']:.4f}")
print(f"Answer Relevance:    {final_metrics['eval_answer_relevance']:.4f}")

# Log to wandb
wandb.log({f"final_{k}": v for k, v in final_metrics.items()})

In [None]:
# Save model
output_dir = "./outputs/kanana-rag-final"
print(f"Saving model to {output_dir}...")

trainer.model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print("Model saved!")

In [None]:
# Test inference
print("Testing inference on sample...\n")

test_sample = test_formatted[0]
# End the prompt with "Answer:" to match the training text format
inputs = tokenizer(test_sample['prompt'] + "\n\nAnswer:", return_tensors="pt").to(model.device)

model.eval()
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

# Decode only the newly generated tokens (everything after the prompt)
prompt_len = inputs["input_ids"].shape[1]
generated_answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

print("Question:", test_sample['question'])
print("\nGenerated Answer:")
print(generated_answer)
print("\nReference Answer:")
print(test_sample['answer'])

In [None]:
# Close wandb
wandb.finish()
print("WandB run finished!")

## Summary

This notebook:

### Training
- ✅ Loads Kanana 1.5 8B instruct model
- ✅ Splits data into train (80%) / test (20%)
- ✅ Applies LoRA for efficient fine-tuning
- ✅ Trains with standard cross-entropy loss

### Evaluation Metrics (tracked during training)
1. **Train Loss** - Standard language modeling loss
2. **Context Precision@3** - How often the correct document is among the first 3 documents supplied with each example
3. **Context Recall** - How often the correct document appears anywhere in the supplied documents
4. **ROUGE-L** - Lexical overlap between generated and reference answers
5. **BERTScore** - Semantic similarity using multilingual BERT embeddings (the `bert-score` default model for `lang='ko'`)
6. **Answer Relevance** - Whether answer contains relevant keywords from question

### Logging
- ✅ All metrics logged to Weights & Biases
- ✅ Evaluation runs every 50 training steps
- ✅ Final full evaluation on complete test set

### Next Steps
- Compare baseline (pre-trained) vs fine-tuned metrics
- Analyze failure cases
- Upload model to Hugging Face Hub
- Generate report with metric visualizations