<a href="https://colab.research.google.com/github/rapturt9/wisdom_agents/blob/sinem/run_single_agents_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ARE YOU IN COLAB?
# Set to True when running on Google Colab. Later cells read this flag to
# decide whether to pip-install the dependencies and whether to warn about a
# missing OPENROUTER_API_KEY.
in_colab = False

In [None]:
# Core Variables are now primarily in helpers.py
# These can be overridden here if needed for specific notebook runs
# TEMP_OVERRIDE: when not None, the import cell uses it as TEMP instead of the value from helpers.
TEMP_OVERRIDE = 1
# MODELS_OVERRIDE: when non-empty, the import cell uses it as the model list instead of the one from helpers.
MODELS_OVERRIDE = ["openai/gpt-4o-mini", "anthropic/claude-3.5-haiku", "mistralai/mistral-7b-instruct"] # Example override

# 1. API Definitions/Setup & Imports

In [None]:
# Install the third-party packages this notebook imports, but only on Colab
# (a local environment is expected to have them already).
# NOTE(review): nest_asyncio, which the final cell imports when run under
# Jupyter, is not installed here - confirm it is available on the target runtime.
if in_colab:
    !pip install -U "autogen-agentchat" "autogen-ext[openai,azure]"
    !pip install python-dotenv
    # If helpers.py and GreatestGoodBenchmark.json are not in the root, upload them or clone the repo
    # Example: !git clone <your_repo_url> and then adjust sys.path if needed

In [None]:
import os
import json
import collections
import sys

# Add the parent directory to sys.path to find the helpers module if running from a subdirectory
# Or ensure helpers.py is in the same directory or PYTHONPATH
_parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))  # Adjust if your structure differs
if _parent_dir not in sys.path:  # re-running this cell must not pile up duplicate entries
    sys.path.append(_parent_dir)
try:
    from helpers import Qs, get_prompt as common_get_prompt, models as common_models, TEMP as common_TEMP, get_client as common_get_client
except ImportError as e:
    print(f"Error importing from helpers: {e}")
    print("Please ensure helpers.py is accessible and all its dependencies are installed.")
    # Define placeholders so the remaining cells can still be defined. A run
    # cannot succeed without the real helpers: Qs stays None and the run loop
    # refuses to start in that case.
    if 'Qs' not in globals(): Qs = None
    # The stubs accept any arguments because the notebook calls the real
    # helpers as get_prompt(group_chat=False) and get_client(model_name); a
    # zero-argument stub would raise TypeError instead of acting as a no-op.
    if 'common_get_prompt' not in globals(): common_get_prompt = lambda *args, **kwargs: ""
    if 'common_models' not in globals(): common_models = []
    if 'common_TEMP' not in globals(): common_TEMP = 1
    if 'common_get_client' not in globals(): common_get_client = lambda *args, **kwargs: None

# Use overridden values if they exist, otherwise use values from helpers
TEMP = TEMP_OVERRIDE if 'TEMP_OVERRIDE' in globals() and TEMP_OVERRIDE is not None else common_TEMP
models = MODELS_OVERRIDE if 'MODELS_OVERRIDE' in globals() and MODELS_OVERRIDE else common_models
get_prompt = common_get_prompt # Using the one from helpers
get_client = common_get_client # Using the one from helpers

from openai import OpenAI # OpenAI might still be needed directly or is handled by get_client
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
# from autogen_ext.models.openai import OpenAIChatCompletionClient # This is now in helpers.get_client
from dotenv import load_dotenv
load_dotenv()

# Pick up the OpenRouter key from the environment (load_dotenv() has run by now); intended for get_client in helpers
API_KEY = os.getenv("OPENROUTER_API_KEY")
# Only warn for local runs: in Colab the key might be loaded differently by get_client
if not (API_KEY or in_colab):
    print("Warning: OPENROUTER_API_KEY not found in .env file for local execution.")

In [None]:
# Question_Handler class is removed as we are using Qs from helpers.py
# ...existing code...

In [None]:
# @title: code for writing files and saving checkpoints
import os
import csv
import asyncio
import json
from datetime import datetime

def get_consistent_filenames(model_name, question_range, num_runs):
    """Build the CSV, log and checkpoint paths for one model / question range / run count.

    The three files share one base name derived from the arguments, so repeated
    calls with the same arguments always resolve to the same files. The
    'results', 'logs' and 'checkpoints' directories are created (relative to
    the current working directory) if they do not exist yet.

    Args:
        model_name: Model identifier; '/' and ':' are replaced by '_' in the file name.
        question_range: (first, last) question numbers embedded in the file name.
        num_runs: Number of runs per question embedded in the file name.

    Returns:
        Tuple (csv_file, log_file, checkpoint_file) of relative paths.
    """
    model_slug = model_name.replace("/", "_").replace(":", "_")
    first_q, last_q = question_range
    stem = f"single_{model_slug}_q{first_q}-{last_q}_n{num_runs}"

    # (directory, file name) pairs in the order they are returned
    targets = (
        ('results', f"{stem}.csv"),
        ('logs', f"{stem}.log"),
        ('checkpoints', f"{stem}_checkpoint.json"),
    )
    for folder, _ in targets:
        os.makedirs(folder, exist_ok=True)

    return tuple(os.path.join(folder, file_name) for folder, file_name in targets)


def save_checkpoint(checkpoint_file, completed_runs):
    """Save the current progress to the specified checkpoint file.

    The data is serialized to a temporary file next to ``checkpoint_file`` and
    then moved into place with ``os.replace``. If serialization fails or the
    process is interrupted mid-write, the previous checkpoint is left intact
    instead of being truncated (a truncated checkpoint cannot be decoded and
    would make the next run start from scratch).

    Saving is best effort: any error is printed and swallowed so that a failed
    checkpoint write does not abort the run.

    Args:
        checkpoint_file: Path of the JSON checkpoint file to write.
        completed_runs: JSON-serializable progress data.
    """
    tmp_file = f"{checkpoint_file}.tmp"
    try:
        with open(tmp_file, 'w') as f:
            json.dump(completed_runs, f, indent=4)
        # Same directory as the target, so this is a rename rather than a copy
        os.replace(tmp_file, checkpoint_file)
        # print(f"Checkpoint saved to {checkpoint_file}") # Can be verbose
    except Exception as e:
        print(f"Error saving checkpoint to {checkpoint_file}: {e}")
        # Do not leave a half-written temporary file behind
        try:
            os.remove(tmp_file)
        except OSError:
            pass


def load_checkpoint(checkpoint_file):
    """Read previously saved progress from a JSON checkpoint file.

    Never raises for a missing, unreadable or corrupt file: each of those cases
    prints a message and yields an empty dict so the caller starts fresh.

    Args:
        checkpoint_file: Path of the JSON checkpoint file.

    Returns:
        The decoded checkpoint content, or ``{}`` when nothing usable was found.
    """
    if os.path.exists(checkpoint_file):
        try:
            with open(checkpoint_file, 'r') as fh:
                progress = json.loads(fh.read())
            print(f"Loaded checkpoint from {checkpoint_file}")
            return progress
        except json.JSONDecodeError:
            # The file exists but does not contain valid JSON
            print(f"Error decoding JSON from checkpoint file {checkpoint_file}. Starting fresh.")
        except Exception as e:
            # Anything else, e.g. the file cannot be opened or read
            print(f"Error loading checkpoint {checkpoint_file}: {e}. Starting fresh.")
        return {}

    print(f"Checkpoint file {checkpoint_file} not found. Starting fresh.")
    return {}



In [None]:
import os
# OpenAI import might be redundant if get_client handles it all
import json
import collections
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
# OpenAIChatCompletionClient is now handled by get_client from helpers
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.conditions import MaxMessageTermination
import logging # Added for logger setup in run_single_agent_and_save

def extract_answer_from_response(content):
    """Return the text between the first ``<ANSWER>`` tag and the ``</ANSWER>`` that follows it.

    The closing tag is searched for only after the opening tag, so a stray
    ``</ANSWER>`` earlier in the response cannot produce an empty or wrong
    slice.

    Args:
        content: Text of the agent's response. Non-string values (e.g. ``None``)
            are treated as a response without an answer.

    Returns:
        The stripped answer text, or the sentinel string
        "No answer found in the agent's response." when no complete tag pair
        is present.
    """
    not_found = "No answer found in the agent's response."
    if not isinstance(content, str):
        return not_found

    open_tag, close_tag = "<ANSWER>", "</ANSWER>"
    start_index = content.find(open_tag)
    if start_index == -1:
        return not_found

    body_start = start_index + len(open_tag)
    end_index = content.find(close_tag, body_start)
    if end_index == -1:
        return not_found
    return content[body_start:end_index].strip()

def extract_confidence_from_response(content):
    """Return the text between the first ``<CONF>`` tag and the ``</CONF>`` that follows it.

    The closing tag is searched for only after the opening tag, so a stray
    ``</CONF>`` earlier in the response cannot produce an empty or wrong slice.

    Args:
        content: Text of the agent's response. Non-string values (e.g. ``None``)
            are treated as a response without a confidence.

    Returns:
        The stripped confidence text, or the sentinel string
        "No confidence found in the agent's response." when no complete tag
        pair is present.
    """
    not_found = "No confidence found in the agent's response."
    if not isinstance(content, str):
        return not_found

    open_tag, close_tag = "<CONF>", "</CONF>"
    start_index = content.find(open_tag)
    if start_index == -1:
        return not_found

    body_start = start_index + len(open_tag)
    end_index = content.find(close_tag, body_start)
    if end_index == -1:
        return not_found
    return content[body_start:end_index].strip()

class Single_Agent_Handler():
    """Runs one model as a single assistant agent over GGB questions and records the results.

    Results are appended to a CSV file, progress is tracked in a JSON
    checkpoint so an interrupted run can be resumed, and details are written
    to a per-run log file.
    """

    def __init__(self, model_name:str, ggb_question_handler, prompt_template = None): # Renamed to ggb_question_handler
        """Create a handler for one model.

        Args:
            model_name: Model identifier passed to get_client and used in file names.
            ggb_question_handler: Question source; must provide
                get_question_by_index(index) with a 0-based index.
            prompt_template: System message for the agent. When None, the
                single-agent prompt from get_prompt(group_chat=False) is used.
        """
        self.model_name = model_name
        self.ggb_questions = ggb_question_handler # Using GGB_Statements instance
        self.client = get_client(model_name) # get_client is from helpers
        if prompt_template is None:
            self.prompt = get_prompt(group_chat=False) # get_prompt is from helpers
        else:
            self.prompt = prompt_template

    def _fallback_question_id(self, question_number):
        """Return the statement_id for a 1-based question number, or 'unknown_id_error'.

        Used to label placeholder results of failed runs. Unlike a direct
        ``.get`` on the lookup result, this also copes with the lookup
        returning None, which is one of the reasons a run fails.
        """
        question_data = self.ggb_questions.get_question_by_index(question_number - 1)
        if not question_data:
            return 'unknown_id_error'
        return question_data.get('statement_id', 'unknown_id_error')

    @staticmethod
    def _get_file_logger(log_file):
        """Return the logger that appends to log_file, attaching its file handler only once."""
        logger = logging.getLogger(os.path.basename(log_file).replace('.log', ''))
        # Loggers are cached by name, so a repeated run must not add a second
        # handler (every record would then be written twice).
        if not logger.handlers:
            logger.setLevel(logging.INFO)
            file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            )
            logger.addHandler(file_handler)
        return logger

    async def run_single_agent_single_question(self, question_number=1): # question_number is 1-based
        """Ask the model one question once.

        Args:
            question_number: 1-based question number.

        Returns:
            Tuple (answer, confidence, response_content, question_id), where
            answer and confidence are extracted from the <ANSWER> and <CONF>
            tags of the response. Returns (None, None, None, None) when the
            question is missing or lacks 'statement' / 'statement_id'.
        """
        question_data = self.ggb_questions.get_question_by_index(question_number - 1) # 0-based index

        if question_data is None or 'statement' not in question_data or 'statement_id' not in question_data:
            print(f"Question data for index {question_number-1} (number {question_number}) not found or malformed!")
            return None, None, None, None
        question_text = question_data['statement']
        question_id = question_data['statement_id'] # This is the GGB statement_id

        # A fresh agent per call, so no conversation state carries over between runs
        agent = AssistantAgent(
            name="assistant_agent",
            model_client=self.client,
            system_message=self.prompt
        )

        # Stops after two messages - presumably the task message plus one agent reply
        team = RoundRobinGroupChat([agent], termination_condition=MaxMessageTermination(2))
        result = await Console(team.run_stream(task=question_text))

        response_content = result.messages[-1].content
        answer = extract_answer_from_response(response_content)
        confidence = extract_confidence_from_response(response_content)

        return answer, confidence, response_content, question_id

    async def run_single_agent_multiple_times(self, question_number=1, num_runs=10):
        """Ask the model the same question num_runs times, one run after another.

        A run that yields no answer is recorded as (None, None, None, id) so the
        returned lists always have num_runs entries.

        Args:
            question_number: 1-based question number.
            num_runs: Number of repetitions.

        Returns:
            Tuple (answers, confidences, responses, question_id): three lists
            with one entry per run, plus the question id of the first run
            (None when num_runs is 0).
        """
        results = []
        for _ in range(num_runs):
            run_output = await self.run_single_agent_single_question(question_number)
            if run_output and run_output[0] is not None: # Check if answer is not None
                results.append(run_output) # (answer, confidence, response_content, question_id)
            else:
                print(f"Task returned None or malformed data for question {question_number}")
                results.append((None, None, None, self._fallback_question_id(question_number)))

        answers = [res[0] for res in results]
        confidences = [res[1] for res in results]
        responses = [res[2] for res in results]
        # All runs share the question number, so the first id represents them all
        question_id = results[0][3] if results else None

        return answers, confidences, responses, question_id

    async def run_single_agent_and_save(self, question_range=(1, 88), num_runs=1):
        """Run every question in question_range and persist results as they arrive.

        Questions already marked complete in the checkpoint are skipped. After
        each question its rows are appended to the CSV and the checkpoint is
        updated. An error while processing one question is printed and logged,
        and the loop moves on to the next question.

        Args:
            question_range: Inclusive (first, last) 1-based question numbers.
            num_runs: Number of runs per question.

        Returns:
            Tuple (all_results_this_session, csv_file, log_file), where the
            first item lists the result dicts produced by this call only.
        """
        model_name = self.model_name
        q_start, q_end = question_range
        csv_file, log_file, checkpoint_file = get_consistent_filenames(model_name, question_range, num_runs)
        completed_runs = load_checkpoint(checkpoint_file)
        all_results_this_session = []
        logger = self._get_file_logger(log_file)

        print(f"Starting/Resuming run for model {model_name} using GGB questions")
        logger.info(f"--- Starting/Resuming Run (GGB) --- Model: {model_name}, Questions: {question_range}, Runs: {num_runs} ---")

        # Progress is stored per model, keyed by the question number as a string
        model_progress = completed_runs.setdefault(str(model_name), {})

        for question_num in range(q_start, q_end + 1):
            q_checkpoint_key = str(question_num)
            if model_progress.get(q_checkpoint_key, False):
                continue

            try:
                print(f"Processing GGB question number {question_num} (index {question_num-1})...")
                logger.info(f"Processing GGB question number {question_num} (index {question_num-1})")

                # Fetch GGB question_data to log statement_id and text
                question_data = self.ggb_questions.get_question_by_index(question_num - 1)
                if not question_data or 'statement_id' not in question_data:
                    logger.warning(f"GGB Question for index {question_num-1} not found or malformed! Skipping.")
                    continue
                current_question_id = question_data['statement_id'] # This is GGB statement_id
                logger.info(f"GGB Stmt ID: {current_question_id}, Text: {question_data['statement'][:100]}...")

                answers, confidences, responses, q_id_from_run = await self.run_single_agent_multiple_times(
                    question_number=question_num,
                    num_runs=num_runs
                )
                if q_id_from_run != current_question_id and q_id_from_run is not None:
                    logger.warning(f"Mismatch in question ID for Q_num {question_num}. Expected {current_question_id}, got {q_id_from_run}")
                # Use current_question_id as the definitive ID for this loop iteration

                failed_runs = sum(answer is None for answer in answers)
                if failed_runs:
                    logger.warning(f"{failed_runs} of {len(answers)} runs returned no answer for GGB question number {question_num}")

                question_results_for_csv = [
                    {
                        "model_name": model_name,
                        "question_num": question_num, # This is the sequential number from range
                        "question_id": current_question_id, # This is GGB statement_id
                        "run_index": run_index,
                        "answer": answer,
                        "confidence": confidence,
                        "full_response": response,
                    }
                    for run_index, (answer, confidence, response)
                    in enumerate(zip(answers, confidences, responses), start=1)
                ]

                self._write_to_csv(question_results_for_csv, csv_file)
                all_results_this_session.extend(question_results_for_csv)
                # Mark complete only after the rows are on disk
                model_progress[q_checkpoint_key] = True
                save_checkpoint(checkpoint_file, completed_runs)
                print(f"  GGB Question number {question_num} (Stmt ID: {current_question_id}) completed and saved.")
                logger.info(f"GGB Question number {question_num} (Stmt ID: {current_question_id}) completed.")

            except Exception as e:
                # Best effort: one failing question must not abort the whole run
                print(f"Error processing GGB question number {question_num}: {str(e)}")
                logger.error(f"Error processing GGB question number {question_num}: {str(e)}", exc_info=True)

        processed_count = len(all_results_this_session)
        print(f"Run finished for model {model_name}. Added {processed_count} new GGB results this session.")
        logger.info(f"--- Run Finished (GGB) --- Model: {model_name}. Added {processed_count} new results. ---")
        return all_results_this_session, csv_file, log_file

    def _write_to_csv(self, results, csv_file):
        """Append result dicts to csv_file, writing the header first if the file is new or empty.

        Args:
            results: List of result dicts; keys outside the fixed column list are ignored.
            csv_file: Path of the CSV file; its directory is created if needed.
        """
        if not results:
            return  # nothing to write; do not create an empty file
        os.makedirs(os.path.dirname(csv_file) or '.', exist_ok=True)
        is_empty = not os.path.exists(csv_file) or os.path.getsize(csv_file) == 0
        # Ensure question_id is part of fieldnames
        fieldnames = ['model_name', 'question_num', 'question_id', 'run_index', 'answer', 'confidence', 'full_response']
        with open(csv_file, 'a', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='ignore')
            if is_empty:
                writer.writeheader()
            writer.writerows(results)


In [None]:
# --- Configuration ---
# (first, last) 1-based question numbers, both inclusive. Falls back to (1, 1)
# when Qs is unavailable (e.g. the helpers import failed) so this cell still runs.
QUESTION_RANGE = (1, Qs.get_total_questions() if Qs else 1)  # Use total questions from GGB
NUM_RUNS = 10             # Define the number of runs per question
# Use models list from helpers, potentially overridden by MODELS_OVERRIDE
MODELS_TO_RUN = models[:3] # Select which models to run (e.g., first 3 from resolved 'models' list)

# --- Execution Loop ---
async def run_all_models():
    """Run each model in MODELS_TO_RUN over QUESTION_RANGE, one model after another.

    Returns immediately with an error message when Qs is None. If the upper
    bound of QUESTION_RANGE exceeds the number of available questions, the
    module-level QUESTION_RANGE is clamped before any model is run.
    """
    global QUESTION_RANGE

    if Qs is None:
        print("Error: Qs (GGB_Statements handler from helpers.py) is not initialized. Cannot run.")
        return

    print(f"Total GGB questions available: {Qs.get_total_questions()}")

    # Never ask for more questions than the benchmark provides
    if QUESTION_RANGE[1] > Qs.get_total_questions():
        print(f"Warning: Requested upper question range {QUESTION_RANGE[1]} exceeds available GGB questions {Qs.get_total_questions()}.")
        print(f"Adjusting upper range to {Qs.get_total_questions()}.")
        QUESTION_RANGE = (QUESTION_RANGE[0], Qs.get_total_questions())

    for model_id in MODELS_TO_RUN:
        print(f"\n--- Initializing handler for model: {model_id} with GGB Questions ---")
        # The handler receives the imported Qs (GGB_Statements instance) as its question source
        agent_handler = Single_Agent_Handler(model_id, Qs)

        session_results, csv_path, log_path = await agent_handler.run_single_agent_and_save(
            question_range=QUESTION_RANGE,
            num_runs=NUM_RUNS,
        )

        print(f"Run session completed for {model_id}. Results appended to {csv_path}")
        print(f"Full logs appended to {log_path}")

        # Drop the references before the next model is set up
        del agent_handler, session_results
        print(f"--- Finished handler for model: {model_id} ---\n")

# --- Start Execution ---
async def main():
    """Entry-point coroutine: awaits run_all_models()."""
    await run_all_models()

if __name__ == '__main__':
    # asyncio.run() refuses to start when an event loop is already running in
    # this thread, which is the case inside Jupyter-style kernels. Detect the
    # running loop directly instead of matching one specific shell class name,
    # so other notebook front-ends (e.g. Colab) are covered as well.
    try:
        asyncio.get_running_loop()
        _loop_is_running = True
    except RuntimeError:
        _loop_is_running = False

    if _loop_is_running:
        # nest_asyncio patches asyncio so that asyncio.run() can be nested
        # inside the already-running loop.
        try:
            import nest_asyncio
        except ImportError as err:
            raise ImportError(
                "An event loop is already running (notebook environment); "
                "install nest_asyncio (pip install nest_asyncio) to run main() here."
            ) from err
        nest_asyncio.apply()

    asyncio.run(main())