In [1]:
import asyncio
import collections
import json
import os
import re
import subprocess
import sys
from datetime import datetime
from typing import Literal

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from dotenv import load_dotenv
from openai import OpenAI


In [2]:
# core variables to import from src
from src import models, TEMP
# import question handler
from src import GGB_Statements

Questions already have IDs


In [3]:
# Absolute paths to the original OUS statements and their inverted
# counterparts (resolved against the current working directory).
QUESTION_JSON, Inverted_JSON = (
    os.path.abspath(relative_path)
    for relative_path in ('GGB_benchmark/OUS.json', 'GGB_benchmark/OUSinverted.json')
)
# One question handler per statement set.
ous_Qs, ous_iQs = (GGB_Statements(json_path) for json_path in (QUESTION_JSON, Inverted_JSON))

In [None]:
from src import create_config_hash, get_multi_agent_filenames, setup_logger_multi, load_checkpoint_multi, get_prompt, get_client
import asyncio
import random
import time
import gc
from typing import Sequence, List, Dict, Any


In [None]:
class PromptHandler:
    """Placeholder for prompt-construction logic; it currently defines no behaviour."""


class RingHandler():
    """Configuration and output files for a ring (round-robin) multi-agent run.

    On construction it builds the model-ensemble config, hashes it with
    ``create_config_hash``, and then asks the project helpers for the CSV /
    log / checkpoint filenames, a logger, and the already-completed runs.

    Args:
        models: model names; each contributes ``nagents_per_model`` agents.
        Qs: question handler exposing ``get_total_questions()``. If falsy,
            the question range collapses to ``(1, 1)``.
        nrounds: number of passes around the ring in one chat.
        nrepeats: number of independent repeats (iterations) per question.
        shuffle: whether the agent speaking order should be shuffled.
        chat_type: label passed to the filename helper.
        nagents_per_model: agents of each model that join a single chat
            (stored as ``"number"`` in each ensemble entry).
    """

    def __init__(self, models, Qs, nrounds=3, nrepeats=10, shuffle=False, chat_type = 'ring', nagents_per_model=1):
        self.models = models
        self.Qs = Qs
        # 1-based, inclusive question range covering every available question.
        self.QUESTION_RANGE = (1, Qs.get_total_questions() if Qs else 1)
        self.N_ITERATIONS_PER_QUESTION = nrepeats
        self.N_CONVERGENCE_LOOPS = nrounds
        self.SHUFFLE_AGENTS = shuffle
        self.CHAT_TYPE = chat_type
        self.N_AGENTS_PER_MODEL = nagents_per_model

        # configuration
        self.configure()
        # files for saving, logging and checkpoints
        self.initiate_files()

    def configure(self):
        """Build ``MODEL_ENSEMBLE_CONFIG``, ``config_details`` and ``CONFIG_HASH``."""
        # "number" is how many agents of a model are created for ONE chat
        # (run_single_ring_iteration loops over it); repeats are tracked
        # separately in N_ITERATIONS_PER_QUESTION and must not be used here.
        self.MODEL_ENSEMBLE_CONFIG = [{'model': m, "number": self.N_AGENTS_PER_MODEL} for m in self.models]
        self.config_details = {'ensemble': self.MODEL_ENSEMBLE_CONFIG, 'loops': self.N_CONVERGENCE_LOOPS, 'shuffle': self.SHUFFLE_AGENTS}
        self.CONFIG_HASH = create_config_hash(self.config_details)

    def initiate_files(self):
        """Resolve output filenames, set up the logger and load the checkpoint."""
        self.csv_file, self.log_file, self.checkpoint_file = get_multi_agent_filenames(
            self.CHAT_TYPE,
            self.config_details,
            self.QUESTION_RANGE,
            self.N_ITERATIONS_PER_QUESTION,
            model_identifier="ensemble",
        )
        self.logger = setup_logger_multi(self.log_file)
        self.completed_runs = load_checkpoint_multi(self.checkpoint_file)

    

    




class MultiAgentHandler():
    """Placeholder for a generic multi-agent runner (not yet implemented).

    The constructor records its arguments as attributes;
    ``run_single_ring_iteration`` is a stub that returns ``None``.

    Args:
        models: model names to include.
        Qs: question handler.
        chat_type: label for the chat topology.
        nrounds: number of conversation rounds.
        nrepeats: number of repeats per question.
        shuffle: whether agent order should be shuffled.
    """

    def __init__(self, models, Qs, chat_type, nrounds = 3, nrepeats=10, shuffle = False):
        self.models = models
        self.Qs = Qs
        self.chat_type = chat_type
        self.nrounds = nrounds
        self.nrepeats = nrepeats
        self.shuffle = shuffle

    async def run_single_ring_iteration(self, model_ensemble, task, max_loops, config_details, question_num, question_id, iteration_idx, shuffle=False):
        """Stub for running one chat iteration; currently does nothing and returns None.

        ``self`` is required: without it, an instance call would bind the
        instance to ``model_ensemble`` and shift every other argument.
        """
        return None

In [None]:
# --- Configuration ---
# Question handler for this run, assigned explicitly so the cell works on a
# fresh kernel (Restart & Run All). Use `ous_iQs` to run the inverted statements.
Qs = ous_Qs
CHAT_TYPE = "ring_ggb" # Changed from round_robin_ggb for clarity
QUESTION_RANGE = (1, Qs.get_total_questions() if Qs else 1) # Use total GGB questions (1-based, inclusive)
N_ITERATIONS_PER_QUESTION = 1
N_CONVERGENCE_LOOPS = 3
SHUFFLE_AGENTS = False
N_ENSEMBLE_MODELS = 3   # how many leading entries of `models` join the ring
AGENTS_PER_MODEL = 1    # agents created per model ("number" in each ensemble entry)

# One entry per model; empty (which makes main_ring_convergence abort) when
# fewer than N_ENSEMBLE_MODELS models are defined.
MODEL_ENSEMBLE_CONFIG = [
    {"model": model_name, "number": AGENTS_PER_MODEL}
    for model_name in models[:N_ENSEMBLE_MODELS]
] if len(models) >= N_ENSEMBLE_MODELS else []

# --- Generate Filenames and Load Checkpoint ---
config_details_for_filename = {'ensemble': MODEL_ENSEMBLE_CONFIG, 'loops': N_CONVERGENCE_LOOPS, 'shuffle': SHUFFLE_AGENTS}
CONFIG_HASH = create_config_hash(config_details_for_filename)
# Changed model_identifier to "ensemble" for ring chat
csv_file, log_file, checkpoint_file = get_multi_agent_filenames(CHAT_TYPE, config_details_for_filename, QUESTION_RANGE, N_ITERATIONS_PER_QUESTION, model_identifier="ensemble")
logger = setup_logger_multi(log_file)
completed_runs = load_checkpoint_multi(checkpoint_file)

def _message_timestamp_iso(message):
    """Return a chat message's timestamp as an ISO-8601 string, or None.

    Checks ``message.timestamp`` first, then ``message.created_at`` (the field
    name used by newer autogen-agentchat message classes — NOTE(review):
    confirm against the installed version). Values without ``isoformat`` are
    converted with ``str``.
    """
    stamp = getattr(message, 'timestamp', None) or getattr(message, 'created_at', None)
    if not stamp:
        return None
    try:
        return stamp.isoformat()
    except AttributeError:
        return str(stamp)


async def run_single_ring_iteration(model_ensemble, task, max_loops, config_details, question_num, question_id, iteration_idx, shuffle=False):
    """Run one round-robin ("ring") group chat on ``task`` and collect the results.

    One ``AssistantAgent`` is created for every copy of every model in
    ``model_ensemble``. The agents then take turns in a ``RoundRobinGroupChat``
    until ``max_loops * num_agents + 1`` messages have been produced.

    Args:
        model_ensemble: list of dicts with ``'model'`` (model name) and
            ``'number'`` (agents of that model to create).
        task: the statement/prompt given to the team.
        max_loops: number of full passes around the ring.
        config_details: JSON-serialisable run configuration, stored as a
            sorted-key JSON string in the result.
        question_num: 1-based sequential question number (record keeping).
        question_id: the GGB ``statement_id`` of the question.
        iteration_idx: 0-based repeat index; logged and stored 1-based.
        shuffle: if True, randomise the agents' speaking order.

    Returns:
        A result dict (conversation history and per-agent extracted answers,
        both JSON-encoded), or ``None`` if no agents were created.
    """
    # NOTE(review): extract_answer_from_response / extract_confidence_from_response
    # are not imported anywhere in this notebook; presumably they come from `src`.
    config_details_str = json.dumps(config_details, sort_keys=True)
    # 1-based iteration label, consistent with main_ring_convergence's logs.
    log_prefix = f"Q_num{question_num} (GGB ID {question_id}) Iter{iteration_idx + 1}"

    agents = []
    agent_map = {}  # agent name -> model name, used to attribute responses
    for i, model_data in enumerate(model_ensemble):
        model_name = model_data['model']
        # autogen agent names must be valid identifiers, so collapse non-word chars.
        model_text_safe = re.sub(r'\W+', '_', model_name)
        for j in range(model_data['number']):
            system_message = get_prompt(group_chat=True) # get_prompt from helpers
            agent_name = f"agent_{model_text_safe}_{i}_{j}"
            agent = AssistantAgent(
                name=agent_name,
                model_client=get_client(model_name), # get_client from helpers
                system_message=system_message,
            )
            agent_map[agent_name] = model_name
            agents.append(agent)

    if shuffle:
        random.shuffle(agents)

    num_agents = len(agents)
    if num_agents == 0:
        logger.warning(f"{log_prefix}: No agents created, skipping.")
        return None

    logger.info(f"{log_prefix}: Starting chat with {num_agents} agents.")

    # One turn per agent per loop; the +1 presumably covers the initial task message.
    termination_condition = MaxMessageTermination((max_loops * num_agents) + 1)
    team = RoundRobinGroupChat(agents, termination_condition=termination_condition)

    start_time = time.perf_counter()
    result = await Console(team.run_stream(task=task))
    duration = time.perf_counter() - start_time
    logger.info(f"{log_prefix}: Chat finished in {duration:.2f} seconds.")

    conversation_history = []
    agent_responses = []
    for msg_idx, message in enumerate(result.messages):
        conversation_history.append({
            'index': msg_idx,
            'source': message.source,
            'content': message.content,
            'timestamp': _message_timestamp_iso(message),
        })

        if message.source == "user":
            continue  # the task prompt itself, not an agent reply

        agent_name = message.source
        model_name = agent_map.get(agent_name, "unknown_model")
        answer = extract_answer_from_response(message.content)
        conf = extract_confidence_from_response(message.content)
        agent_responses.append({
            'agent_name': agent_name,
            'agent_model': model_name,
            'message_index': msg_idx,
            'extracted_answer': answer,
            'extracted_confidence': conf,
            'message_content': message.content
        })
        logger.info(f"{log_prefix} Msg{msg_idx} Agent {agent_name}: Ans={answer}, Conf={conf}")

    run_result_dict = {
        'question_num': question_num, # Sequential number from range
        'question_id': question_id,   # GGB statement_id
        'run_index': iteration_idx + 1,
        'chat_type': CHAT_TYPE,
        'config_details': config_details_str,
        'conversation_history': json.dumps(conversation_history),
        'agent_responses': json.dumps(agent_responses),
        'timestamp': datetime.now().isoformat()
    }

    # Drop references to agents/clients before collecting, to keep memory flat across many runs.
    del agents, team, result
    gc.collect()

    return run_result_dict


async def main_ring_convergence():
    """Run every configured GGB question/iteration through the ring chat, with checkpointing.

    Reads the module-level configuration (``Qs``, ``MODEL_ENSEMBLE_CONFIG``,
    ``QUESTION_RANGE``, ``N_ITERATIONS_PER_QUESTION``, ``N_CONVERGENCE_LOOPS``,
    ``SHUFFLE_AGENTS``, ``completed_runs``, file paths and ``logger``).

    For each question number in ``QUESTION_RANGE`` (1-based, inclusive) and
    each repeat:
      - runs already marked done in ``completed_runs`` are skipped;
      - otherwise ``run_single_ring_iteration`` is awaited;
      - a non-empty result is written to ``csv_file``, and the checkpoint is saved.

    An exception in one iteration is logged. The loop then continues, and that
    iteration stays unmarked so a later run retries it.

    Side effects: rewrites the global ``QUESTION_RANGE`` if it lies outside
    ``[1, Qs.get_total_questions()]``.
    """
    global QUESTION_RANGE

    if not Qs:
        print("Qs (GGB Question Handler) not available. Aborting.")
        return
    if not MODEL_ENSEMBLE_CONFIG:
        print("MODEL_ENSEMBLE_CONFIG is empty. Aborting ring convergence run.")
        return

    total_questions = Qs.get_total_questions()
    if QUESTION_RANGE[1] > total_questions:
        print(f"Warning: Requested upper question range {QUESTION_RANGE[1]} exceeds available GGB questions {total_questions}.")
        QUESTION_RANGE = (QUESTION_RANGE[0], total_questions)
        print(f"Adjusted upper range to {QUESTION_RANGE[1]}.")
    if QUESTION_RANGE[0] < 1:
        # Question numbers are 1-based; anything below 1 would become a
        # negative index into the handler (which may silently wrap around).
        print(f"Warning: Requested lower question range {QUESTION_RANGE[0]} is below 1.")
        QUESTION_RANGE = (1, QUESTION_RANGE[1])
        print(f"Adjusted lower range to {QUESTION_RANGE[0]}.")

    print(f"Starting {CHAT_TYPE} run with GGB questions.")
    logger.info(f"--- Starting New Run (GGB) --- CONFIG HASH: {CONFIG_HASH} --- Chat Type: {CHAT_TYPE} ---")

    for q_num_iter in range(QUESTION_RANGE[0], QUESTION_RANGE[1] + 1): # q_num_iter is 1-based
        q_checkpoint_key = str(q_num_iter)
        # Per-question map of iteration key -> completed flag (shared with completed_runs).
        question_progress = completed_runs.setdefault(q_checkpoint_key, {})

        # Fetch GGB question data using 0-based index
        question_data = Qs.get_question_by_index(q_num_iter - 1)
        if not question_data or 'statement' not in question_data or 'statement_id' not in question_data:
            logger.error(f"GGB Question for index {q_num_iter-1} (number {q_num_iter}) not found or malformed. Skipping.")
            continue
        task_text = question_data['statement']
        current_ggb_question_id = question_data['statement_id']

        for iter_idx in range(N_ITERATIONS_PER_QUESTION):
            iter_checkpoint_key = str(iter_idx)
            if question_progress.get(iter_checkpoint_key, False):
                print(f"Skipping GGB Question num {q_num_iter} (ID {current_ggb_question_id}), Iteration {iter_idx+1} (already completed).")
                logger.info(f"Skipping GGB Q_num{q_num_iter} (ID {current_ggb_question_id}) Iter{iter_idx+1} (already completed).")
                continue

            print(f"--- Running GGB Q_num {q_num_iter} (ID {current_ggb_question_id}), Iteration {iter_idx+1}/{N_ITERATIONS_PER_QUESTION} ---")
            logger.info(f"--- Running GGB Q_num{q_num_iter} (ID {current_ggb_question_id}) Iter{iter_idx+1}/{N_ITERATIONS_PER_QUESTION} ---")
            logger.info(f"Task: {task_text[:100]}...")

            try:
                iteration_result_data = await run_single_ring_iteration(
                    model_ensemble=MODEL_ENSEMBLE_CONFIG,
                    task=task_text,
                    max_loops=N_CONVERGENCE_LOOPS,
                    config_details=config_details_for_filename,
                    question_num=q_num_iter, # Pass the 1-based number for record keeping
                    question_id=current_ggb_question_id, # Pass GGB statement_id
                    iteration_idx=iter_idx,
                    shuffle=SHUFFLE_AGENTS
                )

                if iteration_result_data:
                    write_to_csv_multi(iteration_result_data, csv_file)
                    question_progress[iter_checkpoint_key] = True
                    save_checkpoint_multi(checkpoint_file, completed_runs)
                    print(f"--- Finished GGB Q_num {q_num_iter} (ID {current_ggb_question_id}), Iteration {iter_idx+1}. Results saved. ---")
                    logger.info(f"--- Finished GGB Q_num{q_num_iter} (ID {current_ggb_question_id}) Iter{iter_idx+1}. Results saved. ---")
                else:
                    print(f"--- GGB Q_num {q_num_iter} (ID {current_ggb_question_id}), Iteration {iter_idx+1} produced no results. ---")
                    logger.warning(f"--- GGB Q_num{q_num_iter} (ID {current_ggb_question_id}) Iter{iter_idx+1} produced no results. ---")

            except Exception as e:
                # Keep going with the next iteration; this one stays unmarked and is retried next run.
                print(f"Error during GGB Q_num {q_num_iter} (ID {current_ggb_question_id}), Iteration {iter_idx+1}: {e}")
                logger.error(f"Error during GGB Q_num{q_num_iter} (ID {current_ggb_question_id}) Iter{iter_idx+1}: {e}", exc_info=True)
            finally:
                gc.collect()

    print(f"--- Run Finished (GGB) --- CONFIG HASH: {CONFIG_HASH} ---")
    logger.info(f"--- Run Finished (GGB) --- CONFIG HASH: {CONFIG_HASH} ---")