# Experiment 5: Multi-Agent System

In [None]:
# Setup
import sys
import json
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict, field
from enum import Enum
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# RAG
import chromadb
from chromadb.utils import embedding_functions

# LLM
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Make modules one directory up importable; presumably this is where the
# `common` module imported later lives — confirm against the repo layout.
sys.path.append('..')

# Shared plotting defaults for every figure in this notebook.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (14, 6)})

print("Imports loaded")

In [None]:
# Configuration
import os  # only needed here, to allow overriding MODEL_PATH from the environment

DB_PATH = Path("../data/vector_db")
# The default is the original author's local checkout. Set GEMMA_MODEL_PATH to
# run the notebook elsewhere without editing it.
MODEL_PATH = Path(os.environ.get(
    "GEMMA_MODEL_PATH",
    "/home/sskaplun/study/genAI/kaggle/models/gemma-2-9b-it",
))
OUTPUT_DIR = Path("../evaluation/experiment_05")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

COLLECTION_NAME = "ukrainian_math"
EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

# Multi-agent parameters
TOP_K = 5                 # chunks retrieved per question
TEMPERATURE = 0.7         # default sampling temperature for agent generations
MAX_NEW_TOKENS = 512      # default generation budget
MAX_ITERATIONS = 2        # max generate → solve → validate attempts per question
QUALITY_THRESHOLD = 0.7   # QualityAgent score needed to stop early

print(f"Model path: {MODEL_PATH}")
print(f"Max Iterations: {MAX_ITERATIONS}")
print(f"Quality Threshold: {QUALITY_THRESHOLD}")
print(f"CUDA: {torch.cuda.is_available()}")

In [None]:
class AgentRole(Enum):
    """Identifies which agent produced an ``AgentMessage``."""

    TOPIC = "topic_agent"
    TASK_GENERATOR = "task_generator"
    SOLUTION = "solution_agent"
    QUALITY = "quality_agent"
    ORCHESTRATOR = "orchestrator"


@dataclass
class AgentMessage:
    """One step of the multi-agent conversation.

    Attributes:
        role: The agent that produced the message.
        content: The message text (retrieved context, task, solution or
            quality feedback, depending on ``role``).
        metadata: Agent-specific extras (citations, scores, ...). A fresh
            dict is created per instance via ``default_factory``.
    """

    role: AgentRole
    content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class MultiAgentResponse:
    """Final result of one orchestrated run for a single question."""

    question: str
    final_answer: str  # task text followed by solution text
    task_text: str
    solution_text: str
    conversation_history: List[AgentMessage]
    citations: List[str]  # one citation string per retrieved chunk
    avg_relevance: float  # mean retrieval relevance reported by the topic agent
    iterations: int  # number of generate → solve → validate attempts run
    quality_score: float
    answer_length: int  # len(final_answer) in characters

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable summary of this response.

        ``conversation_history`` is reduced to a message count because
        ``AgentMessage.role`` is an ``Enum``, which ``json`` cannot encode.
        ``citations`` is included (as a copy) so saved results keep their
        sources.
        """
        return {
            'question': self.question,
            'final_answer': self.final_answer,
            'task_text': self.task_text,
            'solution_text': self.solution_text,
            'avg_relevance': self.avg_relevance,
            'iterations': self.iterations,
            'quality_score': self.quality_score,
            'answer_length': self.answer_length,
            'num_messages': len(self.conversation_history),
            'citations': list(self.citations),
        }


print("Dataclasses defined")

## 1. Load Infrastructure

In [None]:
print("=" * 80)
print("LOADING INFRASTRUCTURE")
print("=" * 80)

# --- Vector DB: persistent Chroma store queried through a sentence-transformers embedder
client = chromadb.PersistentClient(path=str(DB_PATH))
embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL)
collection = client.get_collection(name=COLLECTION_NAME, embedding_function=embedding_function)
print(f"Vector DB: {collection.count():,} chunks")

# --- LLM: 4-bit NF4 weights with double quantization and fp16 compute
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
model = AutoModelForCausalLM.from_pretrained(
    str(MODEL_PATH),
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)
print("LLM loaded")

## 2. Define Agents

In [None]:
def generate_llm(prompt: str, max_tokens: int = MAX_NEW_TOKENS, temp: float = TEMPERATURE) -> str:
    """Generate a single-turn chat completion with the shared model.

    All agents funnel their prompts through this helper. The prompt is sent as
    one ``user`` message through the tokenizer's chat template, generated with
    the module-level ``model``, and only the newly generated tokens are
    decoded.

    Args:
        prompt: Full text of the user turn (agents prepend their own
            instructions to it).
        max_tokens: Maximum number of new tokens to generate.
        temp: Sampling temperature; ``temp <= 0`` selects greedy decoding.

    Returns:
        The decoded completion with surrounding whitespace stripped.
    """
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The rendered chat template already contains the BOS token; letting the
    # tokenizer add special tokens again would prepend a duplicate BOS.
    inputs = tokenizer(
        formatted, return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    do_sample = temp > 0
    # Pass sampling knobs only when sampling. For greedy decoding they are
    # ignored, and transformers warns about them.
    sampling_kwargs = {"temperature": temp, "top_p": 0.9} if do_sample else {}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            pad_token_id=tokenizer.eos_token_id,
            **sampling_kwargs,
        )

    # generate() returns prompt + completion; keep only the completion.
    prompt_len = inputs['input_ids'].shape[1]
    return tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True
    ).strip()

print("Core LLM function defined")

In [None]:
class TopicAgent:
    """Agent 1: retrieve textbook context relevant to a topic.

    Queries a Chroma collection and formats the top-k chunks as a numbered,
    cited context block for the downstream agents.
    """

    def __init__(self, collection, k: int = TOP_K):
        """
        Args:
            collection: Chroma collection, queried via ``collection.query``.
            k: Number of chunks requested per query.
        """
        self.collection = collection
        self.k = k
        self.role = AgentRole.TOPIC

    def retrieve(self, topic: str) -> AgentMessage:
        """Retrieve and format context for ``topic``.

        Returns:
            An AgentMessage whose content is the source-labelled chunks joined
            by blank lines. Its metadata holds ``citations`` (one per chunk),
            ``avg_relevance`` (0.0 when nothing was retrieved) and
            ``num_chunks``.
        """
        results = self.collection.query(
            query_texts=[topic],
            n_results=self.k
        )
        # A single query text was sent, so each field holds one inner list.
        documents = results['documents'][0]
        metadatas = results['metadatas'][0]
        distances = results['distances'][0]

        chunks = []
        citations = []

        # assumes every chunk's metadata carries filename/page_start/page_end/
        # content_type (written by the indexing step) — TODO confirm
        for i, (doc, meta) in enumerate(zip(documents, metadatas), 1):
            citation = f"[{meta['filename']}, с. {meta['page_start']}-{meta['page_end']}]"
            header = f"[Джерело {i}] {citation} | Тип: {meta['content_type']}"
            chunks.append(f"{header}\n{doc}")
            citations.append(citation)

        context = "\n\n".join(chunks)
        # NOTE(review): "1 - distance" is a similarity in [0, 1] only for
        # cosine distance; Chroma's default space is L2. Confirm the
        # collection's metric.
        # np.mean([]) would give nan (plus a RuntimeWarning), which json.dump
        # writes as non-standard NaN, so report 0.0 for an empty result.
        avg_relevance = float(np.mean([1 - d for d in distances])) if distances else 0.0

        return AgentMessage(
            role=self.role,
            content=context,
            metadata={
                'citations': citations,
                'avg_relevance': avg_relevance,
                'num_chunks': len(chunks)
            }
        )

print("TopicAgent defined")

In [None]:
class TaskGeneratorAgent:
    """Agent 2: turn retrieved context into the statement of a math problem.

    The agent only writes the problem text; solving it is left to a
    separate agent.
    """

    def __init__(self):
        # Ukrainian instructions: produce ONLY the problem statement, grounded in
        # the context, with concrete numbers and no solution, in the
        # "**Задача:** ..." format.
        self.system_prompt = """Ти — експерт-агент з ГЕНЕРАЦІЇ математичних задач українською мовою.

Твоя ЄДИНА задача: згенерувати ТІЛЬКИ текст задачі (умову) на основі контексту.

Правила:
- Використовуй ТІЛЬКИ інформацію з наданого контексту
- Формулюй задачу чітко та зрозуміло
- Включи конкретні числові значення
- Використовуй українську математичну термінологію
- НЕ пиши розв'язання (це зробить інший агент)

Формат:
**Задача:** [текст умови задачі]

Все! Більше нічого не пиши."""
        self.role = AgentRole.TASK_GENERATOR

    def generate(self, context: str, topic: str) -> AgentMessage:
        """Ask the LLM for one problem statement about ``topic``.

        Args:
            context: Retrieved textbook passages the problem must be based on.
            topic: The user's topic/question, echoed into the prompt.

        Returns:
            An AgentMessage holding the raw LLM output (capped at 300 new
            tokens) with the topic recorded in its metadata.
        """
        prompt = f"{self.system_prompt}\n\nКОНТЕКСТ:\n{context}\n\nТЕМА: {topic}\n\nТВОЯ ЗАДАЧА:"
        return AgentMessage(
            role=self.role,
            content=generate_llm(prompt, max_tokens=300),
            metadata={'topic': topic},
        )

print("TaskGeneratorAgent defined")

In [None]:
class SolutionAgent:
    """Agent 3: write a step-by-step solution for a generated problem."""

    def __init__(self):
        # Ukrainian instructions: solve step by step, explain each step, use the
        # context's formulas, check calculations, finish with
        # "**Відповідь:** ..." (final answer).
        self.system_prompt = """Ти — експерт-агент з РОЗВ'ЯЗУВАННЯ математичних задач українською мовою.

Твоя ЄДИНА задача: надати покрокове розв'язання для заданої задачі.

Правила:
- Розв'язуй крок за кроком
- Поясни кожен крок зрозумілою мовою
- Використовуй формули з контексту
- Перевіряй обчислення
- Дай фінальну відповідь

Формат:
**Розв'язання:**
1. [перший крок з поясненням]
2. [другий крок]
...

**Відповідь:** [фінальна відповідь]"""
        self.role = AgentRole.SOLUTION

    def solve(self, task: str, context: str) -> AgentMessage:
        """Ask the LLM to solve ``task`` using ``context`` as reference theory.

        Args:
            task: Problem statement produced by the task generator.
            context: Retrieved passages with formulas and theory.

        Returns:
            An AgentMessage holding the raw LLM solution (capped at 500 new
            tokens) with the task recorded in its metadata.
        """
        prompt = f"{self.system_prompt}\n\nКОНТЕКСТ (формули та теорія):\n{context}\n\n{task}\n\nТВОЄ РОЗВ'ЯЗАННЯ:"
        return AgentMessage(
            role=self.role,
            content=generate_llm(prompt, max_tokens=500),
            metadata={'task': task},
        )

print("SolutionAgent defined")

In [None]:
import re

# Matches the "ОЦІНКА" ("score") label followed by a number. It tolerates:
# - markdown bold ("**ОЦІНКА:** 0.8"),
# - a decimal comma ("0,8"),
# - trailing punctuation ("0.8."),
# - an optional "/N" scale ("8/10").
_QUALITY_SCORE_RE = re.compile(
    r"ОЦІНКА[\s:*]*(\d+(?:[.,]\d+)?)(?:\s*/\s*(\d+(?:[.,]\d+)?))?",
    re.IGNORECASE,
)


class QualityAgent:
    """Agent 4: score the generated task and solution with the LLM."""

    def __init__(self):
        self.role = AgentRole.QUALITY
        # Ukrainian instructions: rate clarity, correctness, completeness,
        # Ukrainian language and context fidelity, then answer as
        # "ОЦІНКА: <0..1>", "КОМЕНТАР: ...", "ПРОПОЗИЦІЇ: ...".
        self.system_prompt = """Ти — експерт-агент з КОНТРОЛЮ ЯКОСТІ математичних задач та розв'язань.

Твоя задача: оцінити якість задачі та розв'язання за критеріями:
1. Чіткість формулювання задачі (0-1)
2. Коректність розв'язання (0-1)
3. Повнота пояснення (0-1)
4. Українська мова (0-1)
5. Відповідність контексту (0-1)

Формат відповіді:
ОЦІНКА: [число від 0.0 до 1.0]
КОМЕНТАР: [короткий коментар]
ПРОПОЗИЦІЇ: [що покращити, якщо оцінка < 0.7]"""

    @staticmethod
    def _parse_score(feedback: str, default: float = 0.5) -> float:
        """Extract the numeric score from the LLM feedback.

        Args:
            feedback: Raw LLM output, expected to contain "ОЦІНКА: <number>".
            default: Score returned when no number follows the label.

        Returns:
            The first score found, divided by its "/N" scale when one is given
            and clamped to [0, 1]; ``default`` when nothing parseable is found.
        """
        match = _QUALITY_SCORE_RE.search(feedback)
        if match is None:
            return default
        score = float(match.group(1).replace(",", "."))
        if match.group(2) is not None:
            scale = float(match.group(2).replace(",", "."))
            if scale > 0:
                score /= scale
        return max(0.0, min(1.0, score))

    def validate(self, task: str, solution: str, context: str) -> Tuple[float, str, "AgentMessage"]:
        """Score ``task`` + ``solution`` against ``context``.

        Returns:
            ``(score, feedback, message)``. ``score`` is in [0, 1] (0.5 if
            the LLM gave no parseable score), ``feedback`` is the raw LLM text,
            and ``message`` wraps the feedback with the score in its metadata.
        """
        prompt = f"{self.system_prompt}\n\nКОНТЕКСТ:\n{context}\n\n{task}\n\n{solution}\n\nТВОЯ ОЦІНКА:"
        # A low temperature makes the grading more stable than generation.
        feedback = generate_llm(prompt, max_tokens=200, temp=0.3)
        score = self._parse_score(feedback)

        return score, feedback, AgentMessage(
            role=self.role,
            content=feedback,
            metadata={'score': score}
        )

print("QualityAgent defined")

## 3. Orchestrator

In [None]:
class Orchestrator:
    """Agent 5: coordinate retrieval, task generation, solving and validation.

    Context is retrieved once per question. Task generation, solving and
    quality validation then repeat until the quality score reaches
    ``quality_threshold`` or ``max_iterations`` attempts have run. The
    highest-scoring attempt is returned.
    """

    def __init__(
        self,
        topic_agent: TopicAgent,
        task_agent: TaskGeneratorAgent,
        solution_agent: SolutionAgent,
        quality_agent: QualityAgent,
        quality_threshold: float = QUALITY_THRESHOLD,
        max_iterations: int = MAX_ITERATIONS
    ):
        """
        Args:
            topic_agent: Retrieves context for the question.
            task_agent: Generates a problem statement from the context.
            solution_agent: Solves the generated problem.
            quality_agent: Scores task + solution in [0, 1].
            quality_threshold: Score at which the loop stops early.
            max_iterations: Maximum number of attempts; must be >= 1.

        Raises:
            ValueError: If ``max_iterations`` is less than 1 (``run`` would
                otherwise have no attempt to return).
        """
        if max_iterations < 1:
            raise ValueError(f"max_iterations must be >= 1, got {max_iterations}")
        self.topic_agent = topic_agent
        self.task_agent = task_agent
        self.solution_agent = solution_agent
        self.quality_agent = quality_agent
        self.quality_threshold = quality_threshold
        self.max_iterations = max_iterations
        self.role = AgentRole.ORCHESTRATOR

    def run(self, question: str, verbose: bool = False) -> MultiAgentResponse:
        """Run the full multi-agent workflow for ``question``.

        Args:
            question: Topic/question to generate a task and solution for.
            verbose: Print per-step progress when True.

        Returns:
            A MultiAgentResponse built from the best-scoring attempt. Its
            ``iterations`` is the number of attempts actually run, and
            ``conversation_history`` holds every message from every attempt.
        """
        conversation = []

        def log(message: str) -> None:
            if verbose:
                print(message)

        log(f"\n[ORCHESTRATOR] Starting workflow for: {question}")
        log("-" * 80)

        # Step 1: retrieve context once; every attempt reuses it.
        log("\n[1] TopicAgent: Retrieving context...")
        topic_msg = self.topic_agent.retrieve(question)
        conversation.append(topic_msg)
        log(f"    Retrieved {topic_msg.metadata['num_chunks']} chunks")
        log(f"    Avg relevance: {topic_msg.metadata['avg_relevance']:.3f}")

        context = topic_msg.content
        citations = topic_msg.metadata['citations']
        avg_relevance = topic_msg.metadata['avg_relevance']

        best = None  # (score, task_text, solution_text) of the best attempt so far
        for iteration in range(1, self.max_iterations + 1):
            log(f"\n[ITERATION {iteration}]")
            task_text, solution_text, score = self._attempt(question, context, conversation, log)

            # Strict ">" keeps the earliest attempt on ties.
            if best is None or score > best[0]:
                best = (score, task_text, solution_text)

            if score >= self.quality_threshold:
                log("\n[ORCHESTRATOR] Quality acceptable. Completing workflow.")
                break
            if iteration < self.max_iterations:
                log("\n[ORCHESTRATOR] Quality below threshold. Refining...")
                # TODO: feed the QualityAgent feedback into the next attempt.
                # For now this is a plain retry that relies on sampled
                # generation to produce a different attempt.
            else:
                log("\n[ORCHESTRATOR] Max iterations reached. Accepting best result.")

        # __init__ guarantees at least one iteration, so `best` is set.
        score, task_text, solution_text = best
        final_answer = f"{task_text}\n\n{solution_text}"

        return MultiAgentResponse(
            question=question,
            final_answer=final_answer,
            task_text=task_text,
            solution_text=solution_text,
            conversation_history=conversation,
            citations=citations,
            avg_relevance=avg_relevance,
            iterations=iteration,
            quality_score=score,
            answer_length=len(final_answer)
        )

    def _attempt(self, question: str, context: str, conversation: list, log) -> Tuple[str, str, float]:
        """Run one generate → solve → validate pass.

        Appends the three produced messages to ``conversation`` and returns
        ``(task_text, solution_text, score)``.
        """
        # Step 2: Task Generator creates a problem
        log("  [2] TaskGeneratorAgent: Creating task...")
        task_msg = self.task_agent.generate(context, question)
        conversation.append(task_msg)
        task_text = task_msg.content
        log(f"      Task: {task_text[:80]}...")

        # Step 3: Solution Agent solves it
        log("  [3] SolutionAgent: Solving task...")
        solution_msg = self.solution_agent.solve(task_text, context)
        conversation.append(solution_msg)
        solution_text = solution_msg.content
        log(f"      Solution: {solution_text[:80]}...")

        # Step 4: Quality Agent validates
        log("  [4] QualityAgent: Validating...")
        score, _feedback, quality_msg = self.quality_agent.validate(
            task_text, solution_text, context
        )
        conversation.append(quality_msg)
        log(f"      Quality Score: {score:.3f}")
        log(f"      Threshold: {self.quality_threshold}")

        return task_text, solution_text, score

print("Orchestrator defined")

## 4. Initialize Multi-Agent System

In [None]:
print("=" * 80)
print("INITIALIZING MULTI-AGENT SYSTEM")
print("=" * 80)

# Specialist agents; only the topic agent needs the vector store.
topic_agent = TopicAgent(collection, k=TOP_K)
task_agent = TaskGeneratorAgent()
solution_agent = SolutionAgent()
quality_agent = QualityAgent()

# The orchestrator wires the specialists into one retrieve → generate → solve → validate loop.
orchestrator = Orchestrator(
    topic_agent,
    task_agent,
    solution_agent,
    quality_agent,
    quality_threshold=QUALITY_THRESHOLD,
    max_iterations=MAX_ITERATIONS,
)

print("\n".join([
    "Agents initialized:",
    "  - TopicAgent (RAG)",
    "  - TaskGeneratorAgent",
    "  - SolutionAgent",
    "  - QualityAgent",
    "  - Orchestrator",
]))

## 5. Test Questions

In [None]:
from common import STANDARD_TEST_QUESTIONS, EVALUATION_DATASET

# Each question costs several LLM calls per attempt, so only the first 5
# standard questions are used for the multi-agent run.
TEST_QUESTIONS = STANDARD_TEST_QUESTIONS[:5]
print(f"Test set: {len(TEST_QUESTIONS)} questions")

# Question text → reference answer, used later when scoring correctness.
question_to_expected = dict((q['input'], q['expected_answer']) for q in EVALUATION_DATASET)
print(f"Expected answers loaded for {len(question_to_expected)} questions")

## 6. Run Multi-Agent Experiment

In [None]:
print("=" * 80)
print("RUNNING MULTI-AGENT EXPERIMENT")
print("=" * 80)

responses = []

for i, question in enumerate(TEST_QUESTIONS, start=1):
    print("\n" + "=" * 80)
    print(f"QUESTION {i}/{len(TEST_QUESTIONS)}: {question}")
    print("=" * 80)

    # Full retrieve → generate → solve → validate workflow, with step-by-step logging.
    response = orchestrator.run(question, verbose=True)
    responses.append(response)

    print("\n[FINAL RESULT]")
    print("-" * 80)
    print(response.final_answer)
    print("-" * 80)
    print(f"Iterations: {response.iterations}")
    print(f"Quality Score: {response.quality_score:.3f}")
    print(f"Citations: {len(response.citations)}")

print("\n" + "=" * 80)
print(f"Completed {len(responses)} multi-agent workflows")
print("=" * 80)

## 7. Evaluation

In [None]:
# Project-local evaluation helpers; evaluate_multi_agent is used in the next cell.
import common

print("Evaluation functions loaded from common.py")

In [None]:
# Evaluate
print("=" * 80)
print("EVALUATION")
print("=" * 80)

evaluations = []

for i, response in enumerate(responses, start=1):
    # None when the question has no entry in EVALUATION_DATASET.
    expected_answer = question_to_expected.get(response.question)
    metrics = common.evaluate_multi_agent(
        response.final_answer,
        response.answer_length,
        response.avg_relevance,
        response.quality_score,
        response.iterations,
        expected_answer,
    )
    evaluations.append(dict(
        question=response.question,
        metrics=metrics,
        answer_length=response.answer_length,
        iterations=response.iterations,
        quality_score=response.quality_score,
    ))

    print(f"\n{i}. {response.question[:50]}...")
    print(f"   Overall: {metrics['overall_score']:.3f} | "
          f"Quality: {metrics['quality_score']:.3f} | "
          f"Iterations: {response.iterations}")

# Summary
print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)

# Mean of each numeric metric; the two boolean flags are reported as rates.
avg_metrics = {
    **{
        name: np.mean([e['metrics'][name] for e in evaluations])
        for name in (
            'overall_score',
            'quality_score',
            'retrieval_quality',
            'ukrainian_ratio',
            'completeness',
            'correctness',
        )
    },
    'structure_rate': sum(e['metrics']['has_structure'] for e in evaluations) / len(evaluations),
    'citation_rate': sum(e['metrics']['has_citations'] for e in evaluations) / len(evaluations),
    'collaboration_quality': np.mean([e['metrics']['collaboration_quality'] for e in evaluations]),
    'avg_iterations': np.mean([e['iterations'] for e in evaluations]),
}

for key, value in avg_metrics.items():
    print(f"  {key:<25}: {value:.3f}")


## 8. Save Results

In [None]:
results = dict(
    experiment='multi_agent_system',
    description='Specialized agents with orchestration and quality validation',
    architecture=dict(
        agents=[
            'TopicAgent (RAG)',
            'TaskGeneratorAgent',
            'SolutionAgent',
            'QualityAgent',
            'Orchestrator',
        ],
        workflow='Retrieve → Generate Task → Solve → Validate → Iterate if needed',
        max_iterations=MAX_ITERATIONS,
        quality_threshold=QUALITY_THRESHOLD,
    ),
    avg_metrics=avg_metrics,
    responses=[r.to_dict() for r in responses],
    evaluations=evaluations,
)

# ensure_ascii=False keeps the Ukrainian text readable in the JSON file.
with open(OUTPUT_DIR / 'results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print(f"Results saved to {OUTPUT_DIR}")
print("\n" + "=" * 80)
print("EXPERIMENT 5 COMPLETE")
print("=" * 80)
print(f"\nOverall Score: {avg_metrics['overall_score']:.3f}")
print(f"Quality Score: {avg_metrics['quality_score']:.3f}")
print(f"Avg Iterations: {avg_metrics['avg_iterations']:.1f}")
print(f"Collaboration Quality: {avg_metrics['collaboration_quality']:.3f}")