# Experiment 3: Advanced RAG

In [None]:
# Setup
import sys
import json
from pathlib import Path
from typing import Dict, List, Any, Set, Tuple
from dataclasses import dataclass, asdict
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# RAG components
import chromadb
from chromadb.utils import embedding_functions

# LLM
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Make the parent directory importable (presumably where the project's `common` module lives).
sys.path.append('..')

# Plotting defaults shared by every figure in this notebook.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (14, 6)})

print("Imports loaded")

In [None]:
# Configuration
import os

DB_PATH = Path("../data/vector_db")
# The model location is machine-specific; set GEMMA_MODEL_PATH to point elsewhere.
MODEL_PATH = Path(os.environ.get(
    "GEMMA_MODEL_PATH",
    "/home/sskaplun/study/genAI/kaggle/models/gemma-2-9b-it",
))
OUTPUT_DIR = Path("../evaluation/experiment_03")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

COLLECTION_NAME = "ukrainian_math"
EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

# Advanced RAG parameters
NUM_QUERY_EXPANSIONS = 3  # total queries per question: the original + 2 LLM paraphrases
RETRIEVAL_K = 15  # candidates fetched per query, before merging
FINAL_K = 5  # chunks kept after re-ranking
TEMPERATURE = 0.7  # sampling temperature for answer generation
MAX_NEW_TOKENS = 512  # generation cap for the answer

# Re-ranking weights: rerank_score = weighted sum of relevance, diversity and
# content-type preference (the three weights sum to 1.0).
RELEVANCE_WEIGHT = 0.5
DIVERSITY_WEIGHT = 0.3
CONTENT_TYPE_WEIGHT = 0.2

print(f"Query Expansions: {NUM_QUERY_EXPANSIONS}")
print(f"Retrieval K: {RETRIEVAL_K} → Final K: {FINAL_K}")
print(f"CUDA: {torch.cuda.is_available()}")

In [None]:
@dataclass
class RetrievedChunk:
    """One chunk returned by the vector store, with its scores and a ready-made citation."""
    text: str
    content_type: str
    confidence: float
    filename: str
    page_start: int
    page_end: int
    distance: float  # raw distance reported by the vector store
    relevance: float  # 1 - distance, as computed at retrieval time
    rerank_score: float  # filled in by the re-ranking step
    citation: str


@dataclass
class AdvancedRAGResponse:
    """Everything produced for one question by the advanced RAG pipeline."""
    question: str
    expanded_queries: List[str]  # the original question plus its paraphrases
    answer: str
    citations: List[str]
    retrieved_chunks: List[RetrievedChunk]
    avg_relevance: float
    avg_rerank_score: float
    answer_length: int

    def to_dict(self):
        """Serializable summary: scalar fields plus a chunk count (the chunks themselves are omitted)."""
        exported_fields = (
            'question',
            'expanded_queries',
            'answer',
            'citations',
            'avg_relevance',
            'avg_rerank_score',
            'answer_length',
        )
        summary = {name: getattr(self, name) for name in exported_fields}
        summary['num_chunks'] = len(self.retrieved_chunks)
        return summary

print("Dataclasses defined")

## 1. Load Vector Database

In [None]:
print("="*80)
print("LOADING VECTOR DATABASE")
print("="*80)

client = chromadb.PersistentClient(path=str(DB_PATH))

embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name=EMBEDDING_MODEL
)

collection = client.get_collection(
    name=COLLECTION_NAME,
    embedding_function=embedding_function
)

print(f"\nCollection: {COLLECTION_NAME}")
print(f"  Total chunks: {collection.count():,}")

## 2. Load LLM

In [None]:
print("="*80)
print("LOADING LLM")
print("="*80)

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
model = AutoModelForCausalLM.from_pretrained(
    str(MODEL_PATH),
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.float16
)

print("Model loaded")

## 3. Query Expansion

In [None]:
def _parse_query_variants(raw_output: str, existing: List[str], limit: int) -> List[str]:
    """
    Extract up to `limit` new query paraphrases from the model's one-per-line output.

    Leading list markers ("1.", "2)", "-", "*", "•" followed by whitespace) and
    surrounding asterisks/quotes are removed. Blank lines, lead-in lines ending
    with ':' and case-insensitive repeats of `existing` (or of each other) are skipped.
    """
    import re

    if limit <= 0:
        return []

    seen = {q.casefold() for q in existing}
    variants: List[str] = []
    for line in raw_output.splitlines():
        text = re.sub(r"^\s*(?:[-*•]|\d+[.)])\s+", "", line).strip().strip('*"«»').strip()
        if not text or text.endswith(":") or text.casefold() in seen:
            continue
        seen.add(text.casefold())
        variants.append(text)
        if len(variants) >= limit:
            break
    return variants


def expand_query(query: str, num_variants: int = NUM_QUERY_EXPANSIONS) -> List[str]:
    """
    Return `query` followed by LLM-generated Ukrainian paraphrases, for multi-query retrieval.

    The model is asked (sampling, temperature 0.8, top_p 0.9, at most 150 new tokens)
    for `num_variants - 1` rephrasings, one per line. The result always starts with
    the original query and has at most `num_variants` entries. It can be shorter if
    the model produced fewer usable lines.

    Returns [] when `num_variants` <= 0 and [query] when it is 1; no LLM call is
    made in either case.
    """
    if num_variants <= 1:
        return [query] if num_variants == 1 else []

    queries = [query]  # the original query always comes first

    expansion_prompt = f"""Перефразуй це запитання українською мовою {num_variants-1} різними способами, 
зберігаючи математичний зміст. Використовуй різні формулювання та синоніми.

Оригінальне запитання: {query}

Варіанти (по одному на рядок):"""

    messages = [{"role": "user", "content": expansion_prompt}]
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The rendered chat template already contains its special tokens (for Gemma it
    # begins with <bos>). The Transformers chat-template guide says to tokenize such
    # text with add_special_tokens=False, otherwise a second <bos> is prepended.
    inputs = tokenizer(
        formatted, return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.8,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens (the prompt tokens are sliced off).
    result = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:],
        skip_special_tokens=True
    ).strip()

    queries.extend(_parse_query_variants(result, queries, num_variants - 1))
    return queries

print("Query expansion function defined")

## 4. Multi-Query Retrieval & Re-ranking

In [None]:
def retrieve_with_queries(
    queries: List[str],
    k: int = RETRIEVAL_K
) -> List[RetrievedChunk]:
    """
    Retrieve the top-`k` chunks for every query and merge them into one deduplicated list.

    Args:
        queries: query strings (typically the output of expand_query).
        k: number of nearest chunks requested per query.

    Returns:
        Unique chunks in first-seen order (query order, then rank within each query).
        A chunk returned by several queries keeps the highest relevance seen.
        rerank_score is left at 0.0 for the re-ranking step to fill in.
        An empty `queries` list returns [] without querying the store.
    """
    if not queries:
        return []

    # One batched call. Chroma returns one result list per entry of query_texts,
    # in the same order as `queries`.
    results = collection.query(query_texts=list(queries), n_results=k)

    # Key on the full chunk text rather than a prefix, so distinct chunks that
    # merely start the same way are not merged together.
    merged: Dict[str, RetrievedChunk] = {}
    for docs, metas, dists in zip(
        results['documents'], results['metadatas'], results['distances']
    ):
        for doc, meta, dist in zip(docs, metas, dists):
            # NOTE(review): 1 - distance lies in [0, 1] only for cosine distance; under
            # Chroma's default L2 space it can go negative. Confirm the collection's
            # "hnsw:space" setting.
            relevance = 1 - dist
            existing = merged.get(doc)
            if existing is None:
                merged[doc] = RetrievedChunk(
                    text=doc,
                    content_type=meta['content_type'],
                    confidence=meta['confidence'],
                    filename=meta['filename'],
                    page_start=meta['page_start'],
                    page_end=meta['page_end'],
                    distance=dist,
                    relevance=relevance,
                    rerank_score=0.0,  # computed later by rerank_chunks
                    citation=f"[{meta['filename']}, с. {meta['page_start']}-{meta['page_end']}]"
                )
            elif relevance > existing.relevance:
                # Same chunk found by another query with a better score: keep the best.
                existing.distance = dist
                existing.relevance = relevance

    return list(merged.values())

print("Hybrid retrieval function defined")

In [None]:
def calculate_diversity_score(chunks: List[RetrievedChunk]) -> List[float]:
    """
    Score each chunk by how rare its content type and source file are within `chunks`.

    The raw score is the mean of 1/(number of chunks sharing its content_type) and
    1/(number of chunks sharing its filename). Scores are then divided by the largest
    raw score, so the top chunk gets 1.0. Returns [] for an empty list.
    """
    if not chunks:
        return []

    per_type = Counter(chunk.content_type for chunk in chunks)
    per_file = Counter(chunk.filename for chunk in chunks)

    raw = [
        (1.0 / per_type[chunk.content_type] + 1.0 / per_file[chunk.filename]) / 2
        for chunk in chunks
    ]
    top = max(raw)
    return [value / top for value in raw]

def calculate_content_type_score(chunk: RetrievedChunk) -> float:
    """
    Return a fixed preference weight for the chunk's content_type.

    explanation 1.0 > definition 0.9 > problem 0.8 > example = theorem 0.7 >
    formula 0.6. Any other content type scores 0.5.
    """
    preference = dict(
        explanation=1.0,
        definition=0.9,
        problem=0.8,
        example=0.7,
        theorem=0.7,
        formula=0.6,
    )
    fallback = 0.5
    return preference.get(chunk.content_type, fallback)

def rerank_chunks(
    chunks: List[RetrievedChunk],
    final_k: int = FINAL_K
) -> List[RetrievedChunk]:
    """
    Score every candidate, then return the `final_k` best by re-rank score.

    rerank_score = RELEVANCE_WEIGHT * relevance
                 + DIVERSITY_WEIGHT * diversity (calculate_diversity_score)
                 + CONTENT_TYPE_WEIGHT * content-type preference (calculate_content_type_score)

    Side effect: sets `rerank_score` on every chunk in `chunks`. The order of the
    caller's list is left unchanged.

    NOTE(review): relevance is not rescaled to [0, 1] before weighting, unlike the other
    two terms. If distances are not cosine, the weights may not balance as intended.
    """
    diversity_scores = calculate_diversity_score(chunks)

    for chunk, diversity in zip(chunks, diversity_scores):
        chunk.rerank_score = (
            RELEVANCE_WEIGHT * chunk.relevance +
            DIVERSITY_WEIGHT * diversity +
            CONTENT_TYPE_WEIGHT * calculate_content_type_score(chunk)
        )

    # sorted() is stable like list.sort() but does not reorder the caller's list.
    ranked = sorted(chunks, key=lambda c: c.rerank_score, reverse=True)
    return ranked[:final_k]

print("Re-ranking functions defined")

## 5. Advanced RAG Pipeline

In [None]:
SYSTEM_PROMPT = """Ти — досвідчений викладач математики для українських учнів 10-11 класів.

Твоє завдання:
- Згенерувати математичну задачу з розв'язанням на основі ТІЛЬКИ наданого контексту
- Використовувати ТІЛЬКИ українську мову
- Використовувати математичну термінологію з підручників
- Обов'язково посилатися на джерела
- Надати чітке покрокове розв'язання

Формат відповіді:
**Задача:** [текст задачі на основі контексту]

**Розв'язання:**
[покрокове рішення з посиланнями на джерела]

**Відповідь:** [фінальна відповідь]

ВАЖЛИВО: Використовуй ТІЛЬКИ інформацію з наданого контексту!"""

def format_context(chunks: List[RetrievedChunk]) -> str:
    """
    Render chunks as numbered source blocks for the prompt.

    Each block is a header line (source number starting at 1, citation, content type,
    re-rank score to 3 decimals) followed by the chunk text. Blocks are separated by
    a blank line; an empty list gives "".
    """
    return "\n\n".join(
        f"[Джерело {number}] {chunk.citation} | Тип: {chunk.content_type} "
        f"| Оцінка: {chunk.rerank_score:.3f}\n{chunk.text}"
        for number, chunk in enumerate(chunks, start=1)
    )

def generate_answer(prompt: str) -> str:
    """
    Sample a completion for `prompt`, sent as a single user turn through the chat template.

    Uses TEMPERATURE, top_p=0.9 and at most MAX_NEW_TOKENS new tokens. Returns only
    the newly generated text: prompt tokens are sliced off, special tokens are removed
    and surrounding whitespace is stripped.
    """
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The rendered template already contains its special tokens (for Gemma it begins
    # with <bos>). add_special_tokens=False stops the tokenizer from prepending a
    # second <bos>, as the Transformers chat-template guide recommends.
    inputs = tokenizer(
        formatted, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    prompt_length = inputs['input_ids'].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    return tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True
    ).strip()

def _mean_or_zero(values: List[float]) -> float:
    """Mean of `values` as a plain float; 0.0 for an empty list (np.mean([]) gives nan and a warning)."""
    return float(np.mean(values)) if values else 0.0


def advanced_rag_generate(
    question: str,
    verbose: bool = False
) -> AdvancedRAGResponse:
    """
    Run the full pipeline for one question: expand -> retrieve -> re-rank -> generate.

    Args:
        question: the user question; it is inserted verbatim into the prompt.
        verbose: print progress for each step.

    Returns:
        AdvancedRAGResponse with the generated answer, the re-ranked chunks and their
        citations, and the mean relevance / re-rank score of those chunks (0.0 when
        no chunks were retrieved).
    """
    if verbose:
        print(f"\nQuestion: {question}")
        print("  Step 1: Query expansion...")

    # 1. Query expansion: the original question plus LLM paraphrases.
    expanded_queries = expand_query(question)
    if verbose:
        for i, q in enumerate(expanded_queries, 1):
            print(f"    {i}. {q}")
        print(f"  Step 2: Hybrid retrieval (k={RETRIEVAL_K})...")

    # 2. Retrieve with every query and merge duplicate chunks.
    chunks = retrieve_with_queries(expanded_queries)
    if verbose:
        print(f"    Retrieved {len(chunks)} unique chunks")
        print(f"  Step 3: Re-ranking to top-{FINAL_K}...")

    # 3. Re-rank and keep the top FINAL_K. Averages are computed once and reused below.
    reranked_chunks = rerank_chunks(chunks, final_k=FINAL_K)
    avg_relevance = _mean_or_zero([c.relevance for c in reranked_chunks])
    avg_rerank_score = _mean_or_zero([c.rerank_score for c in reranked_chunks])
    if verbose:
        print(f"    Avg relevance: {avg_relevance:.3f} | Avg rerank score: {avg_rerank_score:.3f}")
        print("  Step 4: Generating answer...")

    # 4. Generate. The instructions are inlined into the prompt because generate_answer
    #    sends a single user message only.
    context = format_context(reranked_chunks)
    prompt = f"{SYSTEM_PROMPT}\n\nКОНТЕКСТ З ПІДРУЧНИКІВ:\n{context}\n\nЗАПИТАННЯ:\n{question}\n\nТВОЯ ВІДПОВІДЬ:"
    answer = generate_answer(prompt)

    if verbose:
        print(f"    Generated {len(answer)} chars")

    return AdvancedRAGResponse(
        question=question,
        expanded_queries=expanded_queries,
        answer=answer,
        citations=[c.citation for c in reranked_chunks],
        retrieved_chunks=reranked_chunks,
        avg_relevance=avg_relevance,
        avg_rerank_score=avg_rerank_score,
        answer_length=len(answer)
    )

print("Advanced RAG pipeline defined")

## 6. Test Questions

In [None]:
from common import STANDARD_TEST_QUESTIONS, EVALUATION_DATASET

TEST_QUESTIONS = STANDARD_TEST_QUESTIONS
print(f"Test set: {len(TEST_QUESTIONS)} questions")

# Reference answers keyed by question text; each EVALUATION_DATASET item provides
# 'input' and 'expected_answer' (a later duplicate input overwrites an earlier one).
question_to_expected = {}
for item in EVALUATION_DATASET:
    question_to_expected[item['input']] = item['expected_answer']
print(f"Expected answers loaded for {len(question_to_expected)} questions")

## 7. Run Advanced RAG Experiment

In [None]:
print("="*80)
print("RUNNING ADVANCED RAG EXPERIMENT")
print("="*80)

responses = []

for i, question in enumerate(TEST_QUESTIONS, 1):
    print(f"\n[{i}/{len(TEST_QUESTIONS)}] {question}")
    print("-"*80)
    
    response = advanced_rag_generate(question, verbose=True)
    responses.append(response)
    
    print(f"\nAnswer:\n{response.answer}")
    print(f"\nTop-3 Citations:")
    for j, citation in enumerate(response.citations[:3], 1):
        print(f"  {j}. {citation}")

print(f"\n{'='*80}")
print(f"Completed {len(responses)} advanced RAG responses")
print("="*80)

## 8. Evaluation

In [None]:
# Module-level import of the shared helpers; the evaluation cell below calls
# common.evaluate_advanced_rag.
import common

print("Evaluation functions loaded from common.py")

In [None]:
# Evaluate
print("="*80)
print("EVALUATION")
print("="*80)

evaluations = []

for i, response in enumerate(responses, 1):
    # None when the question has no reference answer in EVALUATION_DATASET.
    expected_answer = question_to_expected.get(response.question)
    metrics = common.evaluate_advanced_rag(
        response.answer,
        response.answer_length,
        response.avg_relevance,
        response.avg_rerank_score,
        expected_answer
    )
    evaluations.append({
        'question': response.question,
        'metrics': metrics,
        'answer_length': response.answer_length,
        'avg_relevance': response.avg_relevance,
        'avg_rerank': response.avg_rerank_score
    })

    print(f"\n{i}. {response.question[:50]}...")
    print(f"   Overall: {metrics['overall_score']:.3f} | "
          f"Rerank: {metrics['rerank_quality']:.3f} | "
          f"Ukrainian: {metrics['ukrainian_ratio']:.3f}")

# Summary
print(f"\n{'='*80}")
print("SUMMARY")
print("="*80)

# Numeric per-question metrics are averaged. Boolean flags become hit rates.
# With no evaluations every summary value is 0.0, instead of nan or a ZeroDivisionError.
n_evaluations = len(evaluations)
averaged_keys = (
    'overall_score',
    'retrieval_quality',
    'rerank_quality',
    'ukrainian_ratio',
    'completeness',
    'correctness',
)
rate_keys = (
    ('structure_rate', 'has_structure'),
    ('citation_rate', 'has_citations'),
)

avg_metrics = {
    key: float(np.mean([e['metrics'][key] for e in evaluations])) if n_evaluations else 0.0
    for key in averaged_keys
}
for summary_key, flag_key in rate_keys:
    avg_metrics[summary_key] = (
        sum(e['metrics'][flag_key] for e in evaluations) / n_evaluations
        if n_evaluations else 0.0
    )

if not n_evaluations:
    print("  (no evaluations — all summary metrics set to 0.0)")

for key, value in avg_metrics.items():
    print(f"  {key:20s}: {value:.3f}")


## 9. Save Results

In [None]:
def _json_default(obj):
    """json.dump fallback that converts NumPy scalars/arrays (e.g. np.bool_, np.int64) to built-ins."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# NOTE(review): the metric value types come from common.evaluate_advanced_rag and are
# not visible here. _json_default covers the case where they are NumPy types.
results = {
    'experiment': 'advanced_rag',
    'description': 'Query expansion + hybrid retrieval + re-ranking',
    'model': MODEL_PATH.name,  # derived from the path actually loaded
    'config': {
        'query_expansions': NUM_QUERY_EXPANSIONS,
        'retrieval_k': RETRIEVAL_K,
        'final_k': FINAL_K,
        'relevance_weight': RELEVANCE_WEIGHT,
        'diversity_weight': DIVERSITY_WEIGHT,
        'content_type_weight': CONTENT_TYPE_WEIGHT,
        # Additional settings recorded so the run can be reproduced.
        'embedding_model': EMBEDDING_MODEL,
        'temperature': TEMPERATURE,
        'max_new_tokens': MAX_NEW_TOKENS
    },
    'avg_metrics': avg_metrics,
    'responses': [r.to_dict() for r in responses],
    'evaluations': evaluations
}

with open(OUTPUT_DIR / 'results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2, default=_json_default)

print(f"Results saved to {OUTPUT_DIR}")
print("\n" + "="*80)
print("EXPERIMENT 3 COMPLETE")
print("="*80)
print(f"\nOverall Score: {avg_metrics['overall_score']:.3f}")
print(f"Retrieval Quality: {avg_metrics['retrieval_quality']:.3f}")
print(f"Re-rank Quality: {avg_metrics['rerank_quality']:.3f}")
print(f"Ukrainian Ratio: {avg_metrics['ukrainian_ratio']:.3f}")