In [None]:
# 1. Imports and Setup
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
import numpy as np
import re
import sys
import glob
from collections import Counter, defaultdict
import ast  # For safely evaluating string representations of lists/dicts

# Define directories
RESULTS_DIR = 'results'  # Directory containing single-agent CSV results
RESULTS_DIR_SINGLE = RESULTS_DIR  # Same directory; explicit single-agent alias
RESULTS_DIR_MULTI = 'results_multi'
PLOT_DIR = 'plots'

# Define the datasets (categories) to include
# These should match the category names returned by Question_Handler
INCLUDED_DATASETS = ['MFQ_30', '6_concepts']

# Create plots directory if it doesn't exist (exist_ok makes this safe to re-run)
os.makedirs(PLOT_DIR, exist_ok=True)

# Add MoralBench repo to path to import Question_Handler
MORAL_BENCH_REPO_DIR = '../MoralBench_AgentEnsembles'  # Adjust if needed
moral_bench_path = os.path.abspath(MORAL_BENCH_REPO_DIR)
if moral_bench_path not in sys.path:
    # Insert at the front so modules from the repository take precedence on import.
    sys.path.insert(0, moral_bench_path)
print(f"Using MoralBench repository at: {moral_bench_path}")

In [None]:
# Updated mapping based on standard MFT + Liberty
# Maps a lowercase moral-foundation key to its display name.
# NOTE(review): presumably the keys are the lowercase prefixes produced by
# extract_category_from_id (e.g. 'harm' from 'harm_3') — confirm against the question IDs.
# NOTE(review): plot_moral_radar_convergence defines its own local dict_map in which
# 'harm' is displayed as 'Care' instead of 'Harm' — confirm which label is intended.
dict_map = {
   'authority': 'Authority',
   'fairness': 'Fairness',
   'harm': 'Harm', # Care/Harm
   'ingroup': 'Loyalty', # Loyalty/Betrayal
   'purity': 'Sanctity', # Sanctity/Degradation
   'liberty': 'Liberty'
}

# Define the order for plotting categories
# (uses the display names, i.e. the values of dict_map above)
PLOT_CATEGORIES = ['Harm', 'Fairness', 'Loyalty', 'Authority', 'Sanctity', 'Liberty']

In [None]:
# 2. Question Handler Definition (Copied for self-containment)
# Note: Ideally, this would be imported from a shared module.
class Question_Handler():
    """Index and load questions (and their answer scores) from a MoralBench-style repository.

    Layout read by this class, relative to ``repo_dir``:
      * ``questions/<category>/<question_id>.txt`` -- one text file per question
      * ``answers/<category>.json`` -- a JSON object keyed by question id

    Questions are numbered globally starting at 1 by walking the categories in sorted
    order and, within each category, the ``.txt`` files in sorted order.

    Directory listings and parsed answer files are cached per instance, because callers
    look up the same questions many times (e.g. once per DataFrame row). Create a new
    instance to pick up changes made on disk.
    """

    def __init__(self, repo_dir):
        self.repo_dir = os.path.abspath(repo_dir)  # Use absolute path
        self.questions_dir = os.path.join(self.repo_dir, 'questions')
        self.answers_dir = os.path.join(self.repo_dir, 'answers')
        # category -> sorted list of question file names (only successful listings are cached)
        self._question_files_cache = {}
        # category -> parsed answers JSON object (only successful loads are cached)
        self._answers_cache = {}
        self.categories = self.list_categories()
        self._build_question_map()

    def _build_question_map(self):
        """Builds a map from question number to (category, index)."""
        self.question_map = {}
        current_question_num = 1
        for category in self.categories:
            for i in range(self.get_question_count(category)):
                self.question_map[current_question_num] = {'category': category, 'index': i}
                current_question_num += 1
        self.total_questions = current_question_num - 1

    def _list_question_files(self, category_folder):
        """Return the sorted ``.txt`` file names of a category, or None if it cannot be listed."""
        cached = self._question_files_cache.get(category_folder)
        if cached is not None:
            return cached
        questions_path = os.path.join(self.questions_dir, category_folder)
        try:
            question_files = sorted(f for f in os.listdir(questions_path) if f.endswith('.txt'))
        except (FileNotFoundError, NotADirectoryError):
            return None
        self._question_files_cache[category_folder] = question_files
        return question_files

    def _load_answers(self, category_folder):
        """Return the parsed answers object of a category, or None if missing or unreadable."""
        if category_folder in self._answers_cache:
            return self._answers_cache[category_folder]
        answers_path = os.path.join(self.answers_dir, f"{category_folder}.json")
        if not os.path.exists(answers_path):
            return None
        try:
            with open(answers_path, 'r', encoding='utf-8') as f:
                all_answers = json.load(f)
        except json.JSONDecodeError:
            print(f"Warning: Error decoding JSON from {answers_path}")
            return None
        except Exception as e:
            print(f"Warning: Error reading answers file {answers_path}: {e}")
            return None
        if not isinstance(all_answers, dict):
            # Per-question lookup needs a mapping; anything else is treated as unreadable.
            print(f"Warning: Error reading answers file {answers_path}: "
                  f"expected a JSON object, got {type(all_answers).__name__}")
            return None
        self._answers_cache[category_folder] = all_answers
        return all_answers

    def get_question_category_and_index(self, question_number):
        """Gets the category and index for a given question number (None if unknown)."""
        return self.question_map.get(question_number)

    def get_question_category(self, question_number):
        """Gets the category for a given question number (None if unknown)."""
        mapping = self.question_map.get(question_number)
        return mapping['category'] if mapping else None

    def get_question_count(self, category_folder):
        """
        Get the number of questions in a specific category folder.

        Returns 0 when the folder does not exist.
        """
        question_files = self._list_question_files(category_folder)
        return len(question_files) if question_files is not None else 0

    def list_categories(self):
        """
        List all available question categories (sub-directories of the questions directory), sorted.

        Returns an empty list, after printing a warning, when the questions directory is missing.
        """
        if not os.path.exists(self.questions_dir):
            print(f"Warning: Questions directory {self.questions_dir} not found!")
            return []
        try:
            return sorted(
                d for d in os.listdir(self.questions_dir)
                if os.path.isdir(os.path.join(self.questions_dir, d))
            )
        except FileNotFoundError:
            print(f"Warning: Error listing categories in {self.questions_dir}.")
            return []

    def load_question_answer(self, category_folder, index):
        """
        Load a question and its possible answers using an index.

        ``index`` is the 0-based position of the question among the sorted ``.txt`` files
        of ``category_folder``.

        Returns a dict with keys 'question_id' (file name without extension),
        'question_text' and 'answers'. 'answers' is the entry for the question in the
        category's answers JSON, ``{}`` when the question has no entry, or None when the
        answers file is missing or unreadable. Returns None when the category folder is
        missing, the index is out of range or the question file cannot be read.
        """
        question_files = self._list_question_files(category_folder)
        if question_files is None:
            return None

        try:
            if index < 0 or index >= len(question_files):
                return None

            # Get question filename and ID
            question_file = question_files[index]
            question_id = os.path.splitext(question_file)[0]

            # Read question content
            question_path = os.path.join(self.questions_dir, category_folder, question_file)
            with open(question_path, 'r', encoding='utf-8') as f:
                question_text = f.read()
        except FileNotFoundError:
            # The cached listing is stale (file removed since it was listed); drop it.
            self._question_files_cache.pop(category_folder, None)
            return None
        except Exception as e:
            print(f"Warning: Unexpected error loading question {category_folder}/{index}: {e}")
            return None

        all_answers = self._load_answers(category_folder)
        question_answers = None if all_answers is None else all_answers.get(question_id, {})

        return {
            'question_id': question_id,
            'question_text': question_text,
            'answers': question_answers
        }

    def get_question(self, number):
        """Gets question data by absolute (1-based, global) number; None if the number is unknown."""
        mapping = self.get_question_category_and_index(number)
        if mapping:
            return self.load_question_answer(mapping['category'], mapping['index'])
        return None

    def get_total_question_count(self):
        """Returns the total number of questions across all categories."""
        return self.total_questions

# --- Initialize Question Handler ---
# Build the module-level handler `Qs` used by the helper functions below.
# Any exception raised while constructing the handler or printing its summary
# is reported and leaves `Qs` set to None, which the helpers treat as "unavailable".
try:
    Qs = Question_Handler(MORAL_BENCH_REPO_DIR)
    print(f"Question Handler initialized. Found {Qs.get_total_question_count()} questions in {len(Qs.categories)} categories.")
    print(f"Available categories: {Qs.categories}")
except Exception as e:
    print(f"Error initializing Question_Handler: {e}")
    Qs = None

In [None]:
# Look up which category and in-category index global question number 88 maps to.
# NOTE(review): raises AttributeError if the handler failed to initialise (Qs is None).
Qs.get_question_category_and_index(88)

# Qs.load_question_answer('MFQ_30', 88)

In [None]:
# Making sure every question loads and exposes the expected keys (including 'question_id').
# Iterates over every question number the handler knows about instead of a hard-coded
# upper bound, and reports questions that fail to load rather than raising on None.
if Qs:
    for i in range(1, Qs.get_total_question_count() + 1):
        q_info = Qs.get_question(i)  # Test the Question_Handler
        if q_info is None:
            print(f'Question {i}: could not be loaded')
            continue
        print(f'{q_info.keys()}')

In [None]:
# --- Helper Functions ---

def extract_category_from_id(question_id):
    """Extracts the category name from the question_id (e.g., 'fairness_3' -> 'fairness').

    The leading run of letters/underscores is taken as the category; underscores become
    spaces and the result is lowercased ('harm_care_2' -> 'harm care').

    IDs that start with a dataset prefix return the dataset name in its canonical
    spelling, matching the entries of INCLUDED_DATASETS: 'mfq_...' -> 'MFQ_30' and
    '6_concepts...' -> '6_concepts'.

    Returns 'Unknown' for non-string input or when no category prefix can be found.
    """
    if not isinstance(question_id, str):
        return 'Unknown'

    # The general pattern below cannot match a leading digit, so the '6_concepts'
    # dataset prefix has to be recognised before it.
    if re.match(r"6[_ ]concepts", question_id, re.IGNORECASE):
        return '6_concepts'

    match = re.match(r"([a-zA-Z_]+)_?\d*", question_id)
    if not match:
        return 'Unknown'

    category_name = match.group(1).replace('_', ' ').title()
    if category_name.startswith('Mfq '):
        # Keep the original dataset name; do not lowercase it like a moral category.
        return 'MFQ_30'
    return category_name.strip().lower()

def get_category_from_qnum(q_num): # this gets out the dataset category (eg MFQ_30)
    """Gets the dataset category name using the Question_Handler based on question number.

    Returns 'Unknown' when the handler is not initialized or the question number is not
    mapped, so the result is always a string (consistent with the other *_from_qnum helpers).
    """
    if Qs:
        category = Qs.get_question_category(q_num)
        if category is not None:
            return category
    return 'Unknown' # Fallback if Qs is not initialized or question not found

def get_moralbench_scores(question_number, answer):
    """Look up the score stored for ``answer`` in the answers of question ``question_number``.

    Returns None when the handler is unavailable, the question cannot be loaded, the
    question has no (or empty) answers, or ``answer`` is not one of its answer keys.
    """
    if not Qs:
        return None

    question = Qs.get_question(question_number)
    if not question:
        return None

    # Missing key, None and an empty mapping are all treated as "no answers".
    answer_scores = question['answers'] if 'answers' in question else None
    if not answer_scores:
        return None

    if answer not in answer_scores:
        return None
    return answer_scores[answer]

def get_question_id_from_qnum(q_num): # this gets out the question_id (eg fairness_3)
    """Return the question ID of question number ``q_num`` via the Question_Handler.

    Falls back to 'Unknown' when the handler is not initialized or the question
    (or its 'question_id' entry) is not available.
    """
    fallback = 'Unknown'
    if not Qs:
        return fallback

    question = Qs.get_question(q_num)
    has_id = bool(question) and 'question_id' in question
    return question['question_id'] if has_id else fallback

def get_moral_category_from_qnum(q_num): # this gets out the moral category (eg Harm)
    """Return the moral category of question number ``q_num``.

    The category is derived from the question's ID with extract_category_from_id.
    Falls back to 'Unknown' when the handler is not initialized or the question
    (or its 'question_id' entry) is not available.
    """
    if not Qs:
        return 'Unknown'

    question = Qs.get_question(q_num)
    if not question or 'question_id' not in question:
        return 'Unknown'
    return extract_category_from_id(question['question_id'])

def safe_literal_eval(val):
    """Parse ``val`` as a Python literal (list, dict, ...) without executing code.

    Returns the parsed object, or None when ``val`` is not a valid literal
    (including non-string input such as None or NaN).
    """
    unparseable = (ValueError, SyntaxError, TypeError)
    try:
        parsed = ast.literal_eval(val)
    except unparseable:
        parsed = None
    return parsed

def load_and_preprocess_data(results_dir):
    """Loads all CSV files from a directory and preprocesses them into one DataFrame.

    Each file is classified by its columns:
      * a file with an 'agent_responses' column is treated as multi-agent output: the
        column is parsed with safe_literal_eval, exploded to one row per agent response
        and the response dictionaries are expanded into columns;
      * a file with 'model_name' and 'run_index' columns is treated as single-agent output;
      * any other file is skipped with a warning.

    Columns added: 'run_type' ('single' or 'multi'), 'category', 'question_id',
    'moral_category', 'confidence_numeric', 'answer_clean' and 'score'. Rows whose
    'category' is not listed in INCLUDED_DATASETS are dropped.

    Files are processed in sorted file-name order so the row order of the result is
    reproducible. A file that fails to load or process is reported (with traceback)
    and skipped; it does not abort the whole load.

    Parameters
    ----------
    results_dir : str
        Directory that contains the result CSV files.

    Returns
    -------
    pandas.DataFrame
        The combined, preprocessed rows of all usable files, or an empty DataFrame
        (no columns) when the directory is missing or no rows are retained.
    """
    all_data_rows = []
    print(f"Checking directory: {results_dir}")
    if not os.path.exists(results_dir):
        print(f"Warning: Directory not found: {results_dir}")
        return pd.DataFrame()

    print(f"Found directory: {results_dir}. Searching for CSV files...")
    # os.listdir() returns entries in arbitrary order; sort so the combined frame (and
    # anything downstream that looks at its first row) does not depend on the filesystem.
    csv_filenames = sorted(name for name in os.listdir(results_dir) if name.endswith(".csv"))

    for filename in csv_filenames:
        filepath = os.path.join(results_dir, filename)
        print(f"  Loading file: {filename}")
        try:
            df_raw = pd.read_csv(filepath)
            if df_raw.empty:
                print(f"    Warning: File is empty: {filename}")
                continue

            # Determine run type early based on columns
            is_multi_agent = 'agent_responses' in df_raw.columns
            is_single_agent = 'model_name' in df_raw.columns and 'run_index' in df_raw.columns and not is_multi_agent

            # --- Process based on run type ---
            if is_multi_agent:
                print(f"    Processing as multi-agent data...")
                df_raw['run_type'] = 'multi'
                # Explode the agent_responses column
                df_raw['agent_responses_parsed'] = df_raw['agent_responses'].apply(safe_literal_eval)
                df_exploded = df_raw.explode('agent_responses_parsed')
                df_exploded = df_exploded.dropna(subset=['agent_responses_parsed']) # Drop rows where parsing failed or was empty

                # Expand the dictionary into columns; both sides get a fresh index so
                # the column-wise concat aligns row by row.
                agent_data = pd.json_normalize(df_exploded['agent_responses_parsed'])
                df = pd.concat([df_exploded.drop(columns=['agent_responses', 'agent_responses_parsed']).reset_index(drop=True),
                                agent_data.reset_index(drop=True)], axis=1)
                print(f"    Exploded agent responses. Shape after explode: {df.shape}")

            elif is_single_agent:
                print(f"    Processing as single-agent data...")
                df = df_raw.copy() # Use the raw df directly
                df['run_type'] = 'single'
            else:
                print(f"    Warning: Could not determine run type for {filename}. Skipping.")
                continue

            # --- Add Category and Moral Category (Common Logic) ---
            if 'question_num' in df.columns and Qs:
                df['category'] = df['question_num'].apply(get_category_from_qnum) # get the dataset category
                df['question_id'] = df['question_num'].apply(get_question_id_from_qnum) # get the question_id
                df['moral_category'] = df['question_num'].apply(get_moral_category_from_qnum) # get the moral category
                print(f"    Extracted categories from 'question_num'. Unique values: {df['category'].unique()[:5]}...")
            elif 'question_id' in df.columns: # Fallback if question_num missing but question_id exists
                # NOTE(review): 'category' receives the value derived from the question ID
                # here, which is then compared against INCLUDED_DATASETS below — rows may
                # all be filtered out on this path; confirm this is intended.
                df['category'] = df['question_id'].apply(extract_category_from_id) # Attempt to get moral category
                df['moral_category'] = df['question_id'].apply(extract_category_from_id) # Use same logic for moral category
                # Try to infer dataset category if possible (might be less reliable)
                if Qs:
                    # This requires reversing the map, might be slow/complex. Stick to moral category for now.
                    print("    Warning: 'question_num' missing. Using 'question_id' for moral category. Dataset category might be inaccurate.")
                else:
                    print("    Warning: 'question_num' missing and Qs handler failed. Using 'question_id' for moral category.")
            else:
                df['category'] = 'Unknown'
                df['moral_category'] = 'Unknown'
                df['question_id'] = 'Unknown'
                print("    Warning: Could not determine category ('question_num' or 'question_id' missing, or Qs handler failed).")

            # --- Filter by Dataset ---
            initial_rows = len(df)
            df = df[df['category'].isin(INCLUDED_DATASETS)]
            filtered_rows = len(df)
            print(f"    Filtered by INCLUDED_DATASETS ({INCLUDED_DATASETS}). Kept {filtered_rows}/{initial_rows} rows.")

            if not df.empty:
                all_data_rows.append(df)
            else:
                print(f"    Info: No rows remaining after filtering for datasets {INCLUDED_DATASETS}.")

        except pd.errors.EmptyDataError:
            print(f"    Warning: Skipping empty file: {filename}")
        except Exception as e:
            # Best effort: report the problem with this file and carry on with the others.
            print(f"    Error loading or processing file {filename}: {e}")
            import traceback
            traceback.print_exc() # Print full traceback for debugging

    if not csv_filenames:
        print(f"Warning: No CSV files found in directory: {results_dir}")

    if not all_data_rows:
        print(f"No data loaded or retained from {results_dir} after processing and filtering. Check CSV files exist, are not empty, and contain data matching INCLUDED_DATASETS: {INCLUDED_DATASETS}.")
        return pd.DataFrame()

    print(f"Concatenating data from {len(all_data_rows)} files/dataframes.")
    combined_df = pd.concat(all_data_rows, ignore_index=True)

    # --- Data Cleaning (Common Logic) ---
    # Convert confidence to numeric, coercing errors ('extracted_confidence' takes precedence)
    if 'extracted_confidence' in combined_df.columns:
        combined_df['confidence_numeric'] = pd.to_numeric(combined_df['extracted_confidence'], errors='coerce')
    elif 'confidence' in combined_df.columns:
        combined_df['confidence_numeric'] = pd.to_numeric(combined_df['confidence'], errors='coerce')
    else:
        print("Warning: No 'confidence' or 'extracted_confidence' column found for numeric conversion.")
        combined_df['confidence_numeric'] = np.nan # Add column as NaN

    # Clean up answer strings (remove leading/trailing spaces, periods)
    if 'extracted_answer' in combined_df.columns:
        combined_df['answer_clean'] = combined_df['extracted_answer'].astype(str).str.strip().str.rstrip('.')
    elif 'answer' in combined_df.columns:
        combined_df['answer_clean'] = combined_df['answer'].astype(str).str.strip().str.rstrip('.')
    else:
        print("Warning: No 'answer' or 'extracted_answer' column found for cleaning.")
        combined_df['answer_clean'] = 'Unknown'

    # --- Calculate Score (Common Logic, requires 'question_num' and 'answer_clean') ---
    if Qs and 'question_num' in combined_df.columns and 'answer_clean' in combined_df.columns:
        print("    Calculating MoralBench scores...")
        combined_df['score'] = combined_df.apply(lambda row: get_moralbench_scores(row['question_num'], row['answer_clean']), axis=1)
        print(f"    Score calculation done. NaN count: {combined_df['score'].isna().sum()}")
    else:
        print("    Warning: Could not calculate scores ('question_num' or 'answer_clean' missing, or Qs handler failed).")
        combined_df['score'] = np.nan

    print(f"Finished loading and preprocessing for {results_dir}. Resulting dataframe shape: {combined_df.shape}")
    print(f"Columns: {combined_df.columns.tolist()}")
    return combined_df

In [None]:
# test the category extraction
# Loads question 40 and prints its ID together with the dataset category and the
# moral category derived for it.
# NOTE(review): raises if Qs is None or question 40 cannot be loaded (qinfo would be None).
qinfo = Qs.get_question(40)
# The results of the next two calls are not stored; the same calls are repeated
# inside the print below.
get_category_from_qnum(40)
get_moral_category_from_qnum(40)
print(f"Question ID: {qinfo['question_id']}, Category: {get_category_from_qnum(40)}, Moral Category: {get_moral_category_from_qnum(40)}")

In [None]:
# test dataloading single
# Loads and preprocesses every CSV found in RESULTS_DIR_SINGLE; the result is an empty
# DataFrame when nothing could be loaded or retained.
single_agent_df = load_and_preprocess_data(RESULTS_DIR_SINGLE)

In [None]:
# Quick inspection of the loaded single-agent data.
# NOTE(review): the first expression's value is not stored or printed; presumably only the
# last expression (head) is shown as the notebook cell output — confirm.
# NOTE(review): raises KeyError if single_agent_df is the empty, column-less DataFrame.
single_agent_df['category'].unique()
single_agent_df.head(5)

In [None]:
# List the columns available after loading and preprocessing.
single_agent_df.columns

In [None]:
# --- Plotting Functions ---

def plot_answer_distribution(df, plot_filename):
    """Save a count plot of 'A' versus 'B' answers over all rows of ``df``.

    Only rows whose 'answer_clean' is exactly 'A' or 'B' are counted. Nothing is
    saved (a message is printed instead) when ``df`` is empty, lacks the
    'answer_clean' column, or contains no 'A'/'B' answers.
    """
    has_answers = 'answer_clean' in df.columns
    if df.empty or not has_answers:
        print("Cannot plot answer distribution: DataFrame is empty or 'answer_clean' column missing.")
        return

    fig = plt.figure(figsize=(10, 6))
    # Keep only the two valid answer letters for clarity
    is_a_or_b = df['answer_clean'].isin(['A', 'B'])
    ab_rows = df[is_a_or_b]
    if ab_rows.empty:
        print("Cannot plot answer distribution: No 'A' or 'B' answers found after cleaning.")
        plt.close(fig)
        return

    ax = sns.countplot(data=ab_rows, x='answer_clean', order=['A', 'B'])
    ax.set_title('Overall Distribution of Answers (A vs B)')
    ax.set_xlabel('Answer')
    ax.set_ylabel('Count')
    plt.savefig(plot_filename)
    plt.close(fig)
    print(f"Saved answer distribution plot to {plot_filename}")

def plot_confidence_distribution(df, plot_filename):
    """Save a histogram of the numeric confidence scores in ``df``.

    Rows with a missing 'confidence_numeric' are ignored. The histogram uses unit-wide
    bins centred on the integers 0..5. Nothing is saved (a message is printed instead)
    when ``df`` is empty, lacks the column, or has no valid scores.
    """
    has_confidence = 'confidence_numeric' in df.columns
    if df.empty or not has_confidence:
        print("Cannot plot confidence distribution: DataFrame is empty or 'confidence_numeric' column missing.")
        return

    fig = plt.figure(figsize=(10, 6))
    # Drop NaN values before plotting
    scored_rows = df.dropna(subset=['confidence_numeric'])
    if scored_rows.empty:
        print("Cannot plot confidence distribution: No valid numeric confidence scores found.")
        plt.close(fig)
        return

    bin_edges = np.arange(-0.5, 6.5, 1)
    ax = sns.histplot(data=scored_rows, x='confidence_numeric', bins=bin_edges, kde=False)
    ax.set_title('Distribution of Confidence Scores')
    ax.set_xlabel('Confidence Score (0-5)')
    ax.set_ylabel('Count')
    ax.set_xticks(range(6)) # Ensure ticks are 0, 1, 2, 3, 4, 5
    ax.set_xlim(-0.5, 5.5)
    plt.savefig(plot_filename)
    plt.close(fig)
    print(f"Saved confidence distribution plot to {plot_filename}")

def plot_answer_by_category(df, plot_filename):
    """Save a grouped count plot of 'A' versus 'B' answers per dataset category.

    Only rows with an 'A'/'B' answer and a category listed in INCLUDED_DATASETS are
    used; categories appear in sorted order. Nothing is saved (a message is printed
    instead) when ``df`` is empty, a required column is missing, or no rows qualify.
    """
    needed_columns = ('answer_clean', 'category')
    if df.empty or not all(col in df.columns for col in needed_columns):
        print("Cannot plot answer by category: DataFrame empty or required columns missing.")
        return

    fig = plt.figure(figsize=(12, 7))
    valid_answer = df['answer_clean'].isin(['A', 'B'])
    included_category = df['category'].isin(INCLUDED_DATASETS)
    plot_data = df[valid_answer & included_category]
    if plot_data.empty:
        print(f"Cannot plot answer by category: No 'A' or 'B' answers found for included datasets {INCLUDED_DATASETS}.")
        plt.close(fig)
        return

    category_order = sorted(cat for cat in plot_data['category'].unique() if cat in INCLUDED_DATASETS)
    if len(category_order) == 0:
        print(f"Cannot plot answer by category: No data found for included datasets {INCLUDED_DATASETS}.")
        plt.close(fig)
        return

    ax = sns.countplot(data=plot_data, x='category', hue='answer_clean', order=category_order, hue_order=['A', 'B'])
    ax.set_title('Answer Distribution (A vs B) by Category')
    ax.set_xlabel('Category')
    ax.set_ylabel('Count')
    plt.xticks(rotation=45, ha='right')
    plt.legend(title='Answer')
    plt.tight_layout()
    plt.savefig(plot_filename)
    plt.close(fig)
    print(f"Saved answer by category plot to {plot_filename}")

def plot_confidence_by_category(df, plot_filename):
    """Save a bar plot of the mean confidence score per dataset category.

    Bars show the mean of 'confidence_numeric' with standard-deviation error bars, for
    categories listed in INCLUDED_DATASETS (sorted). Rows without a numeric confidence
    are ignored. Nothing is saved (a message is printed instead) when ``df`` is empty,
    a required column is missing, or no rows qualify.
    """
    needed_columns = ('confidence_numeric', 'category')
    if df.empty or not all(col in df.columns for col in needed_columns):
        print("Cannot plot confidence by category: DataFrame empty or required columns missing.")
        return

    fig = plt.figure(figsize=(12, 7))
    scored_rows = df.dropna(subset=['confidence_numeric'])
    plot_data = scored_rows[scored_rows['category'].isin(INCLUDED_DATASETS)]
    if plot_data.empty:
        print(f"Cannot plot confidence by category: No valid numeric confidence scores found for included datasets {INCLUDED_DATASETS}.")
        plt.close(fig)
        return

    category_order = sorted(cat for cat in plot_data['category'].unique() if cat in INCLUDED_DATASETS)
    if len(category_order) == 0:
        print(f"Cannot plot confidence by category: No data found for included datasets {INCLUDED_DATASETS}.")
        plt.close(fig)
        return

    # Show mean and std dev
    ax = sns.barplot(data=plot_data, x='category', y='confidence_numeric', order=category_order, estimator=np.mean, errorbar='sd')
    ax.set_title('Average Confidence Score by Category')
    ax.set_xlabel('Category')
    ax.set_ylabel('Average Confidence Score (0-5)')
    plt.xticks(rotation=45, ha='right')
    ax.set_ylim(0, 5) # Set y-axis limits
    plt.tight_layout()
    plt.savefig(plot_filename)
    plt.close(fig)
    print(f"Saved confidence by category plot to {plot_filename}")

def plot_answer_by_model(df, plot_filename):
    """Plots the distribution of answers (A vs B) for each model (single agent runs).

    Uses only rows with run_type 'single', a category listed in INCLUDED_DATASETS and an
    'A'/'B' answer; models are listed in sorted order on the y-axis.

    Nothing is saved (a message is printed instead) when ``df`` is empty, any of the
    columns 'run_type', 'category', 'answer_clean' or 'model_name' is missing, or no
    rows qualify.
    """
    # Validate before indexing: an empty, column-less DataFrame (as returned by the
    # loader when nothing was found) would otherwise raise KeyError on df['run_type'].
    required_cols = ('run_type', 'category', 'answer_clean', 'model_name')
    if df.empty or any(col not in df.columns for col in required_cols):
        print("Cannot plot answer by model: DataFrame is empty or required columns missing.")
        return

    df_single = df[(df['run_type'] == 'single') & (df['category'].isin(INCLUDED_DATASETS))]
    if df_single.empty:
        print(f"Cannot plot answer by model: No single-agent data for included datasets {INCLUDED_DATASETS} or required columns missing.")
        return

    plt.figure(figsize=(14, 8))
    plot_data = df_single[df_single['answer_clean'].isin(['A', 'B'])]
    if plot_data.empty:
        print(f"Cannot plot answer by model: No 'A' or 'B' answers found in single-agent data for included datasets {INCLUDED_DATASETS}.")
        plt.close()
        return
    model_order = sorted(plot_data['model_name'].unique())
    sns.countplot(data=plot_data, y='model_name', hue='answer_clean', order=model_order, hue_order=['A', 'B'])
    plt.title('Answer Distribution (A vs B) by Model (Single Agent Runs)')
    plt.xlabel('Count')
    plt.ylabel('Model Name')
    plt.legend(title='Answer')
    plt.tight_layout()
    plt.savefig(plot_filename)
    plt.close()
    print(f"Saved answer by model plot to {plot_filename}")

def plot_confidence_by_model(df, plot_filename):
    """Plots the average confidence score for each model (single agent runs).

    Uses only rows with run_type 'single', a category listed in INCLUDED_DATASETS and a
    numeric confidence; bars show the mean with standard-deviation error bars and models
    are listed in sorted order on the y-axis.

    Nothing is saved (a message is printed instead) when ``df`` is empty, any of the
    columns 'run_type', 'category', 'confidence_numeric' or 'model_name' is missing, or
    no rows qualify.
    """
    # Validate before indexing: an empty, column-less DataFrame (as returned by the
    # loader when nothing was found) would otherwise raise KeyError on df['run_type'].
    required_cols = ('run_type', 'category', 'confidence_numeric', 'model_name')
    if df.empty or any(col not in df.columns for col in required_cols):
        print("Cannot plot confidence by model: DataFrame is empty or required columns missing.")
        return

    df_single = df[(df['run_type'] == 'single') & (df['category'].isin(INCLUDED_DATASETS))]
    if df_single.empty:
        print(f"Cannot plot confidence by model: No single-agent data for included datasets {INCLUDED_DATASETS} or required columns missing.")
        return

    plt.figure(figsize=(14, 8))
    plot_data = df_single.dropna(subset=['confidence_numeric'])
    if plot_data.empty:
        print(f"Cannot plot confidence by model: No valid numeric confidence scores found in single-agent data for included datasets {INCLUDED_DATASETS}.")
        plt.close()
        return
    model_order = sorted(plot_data['model_name'].unique())
    sns.barplot(data=plot_data, y='model_name', x='confidence_numeric', order=model_order, estimator=np.mean, errorbar='sd')
    plt.title('Average Confidence Score by Model (Single Agent Runs)')
    plt.xlabel('Average Confidence Score (0-5)')
    plt.ylabel('Model Name')
    plt.xlim(0, 5)
    plt.tight_layout()
    plt.savefig(plot_filename)
    plt.close()
    print(f"Saved confidence by model plot to {plot_filename}")

def plot_moral_radar_convergence(df, metric='score', show_sem=False, figsize=(18, 9), save_path=None):
    """
    Create radar plots for moral foundations by convergence loop (message_index) and dataset category.

    One polar subplot is drawn per dataset category; each line is one loop (message_index)
    and each spoke is one moral category. The plotted value is computed in two steps:
    the mean of ``metric`` over all rows of a (loop, dataset, moral category, question)
    group, then the sum of these per-question means over the questions of the moral
    category. The SEM of that sum is the square root of the sum of the squared
    per-question SEMs; groups with fewer than two valid values contribute an SEM of 0.

    Parameters:
    -----------
    df : pandas DataFrame
        DataFrame containing multi-agent results, exploded by agent response.
        Requires columns: 'run_type', 'message_index', 'category', 'moral_category',
        'question_num', metric ('score' or 'confidence_numeric').
    metric : str, default='score'
        Column to plot ('score' or 'confidence_numeric').
    show_sem : bool, default=False
        Whether to show standard error as shaded region.
    figsize : tuple, default=(18, 9)
        Size of the figure (width, height).
    save_path : str, optional
        Path to save the figure to. If None, the figure is not saved, only returned.

    Returns:
    --------
    fig : matplotlib figure or None
        The created figure with subplots for each dataset category, or None when the
        input is empty, is not multi-agent data, lacks required columns, or contains no
        rows for INCLUDED_DATASETS.
    """
    # 'run_type' is checked for presence first so a frame without it is rejected
    # with the warning below instead of raising KeyError.
    if df.empty or 'run_type' not in df.columns or df['run_type'].iloc[0] != 'multi':
        print("Warning: plot_moral_radar_convergence requires a non-empty multi-agent DataFrame.")
        return None

    # Check required columns
    required_cols = ['message_index', 'category', 'moral_category', 'question_num', metric]
    if not all(col in df.columns for col in required_cols):
        print(f"Error: DataFrame missing one or more required columns for radar plot: {required_cols}")
        return None

    # Map moral_category to display names (kept local so the module-level dict_map,
    # which labels 'harm' differently, is neither shadowed nor changed).
    display_names = {
        'authority': 'Authority', 'fairness': 'Fairness', 'harm': 'Care',
        'ingroup': 'Loyalty', 'purity': 'Sanctity', 'liberty': 'Liberty'
    }

    # Filter for included datasets
    df_filtered = df[df['category'].isin(INCLUDED_DATASETS)].copy()
    if df_filtered.empty:
        print(f"Warning: No data found for included datasets {INCLUDED_DATASETS} in the multi-agent data.")
        return None

    def _sem(values):
        """Standard error of the mean of the non-NaN values; 0 when fewer than two exist."""
        n_valid = np.sum(~np.isnan(values))
        if n_valid < 2:
            # The sample standard deviation (ddof=1) is undefined for a single value.
            return 0
        return np.nanstd(values, ddof=1) / np.sqrt(n_valid)

    # Step 1: mean and SEM of the metric across all rows (agents/runs) that share the
    # same loop, dataset category, moral category and question.
    grouping_cols_agg = ['message_index', 'category', 'moral_category', 'question_num']

    print(f"Calculating stats grouped by: {grouping_cols_agg}")
    question_stats = df_filtered.groupby(grouping_cols_agg).agg(
        mean_metric=(metric, lambda x: np.nanmean(x)),
        sem_metric=(metric, _sem)
    ).reset_index()
    print(f"Calculated question stats. Shape: {question_stats.shape}")

    # Step 2: Sum the means across questions for each loop, category, moral_category
    foundation_keys = ['message_index', 'category', 'moral_category']
    summed_means = question_stats.groupby(foundation_keys)['mean_metric'].sum().reset_index()
    print(f"Calculated summed means. Shape: {summed_means.shape}")

    # Step 3: Calculate the propagated SEM for the sums (sqrt of sum of squared SEMs)
    summed_sems = question_stats.groupby(foundation_keys)['sem_metric'].apply(
        lambda x: np.sqrt(np.sum(x**2))
    ).reset_index()
    print(f"Calculated summed SEMs. Shape: {summed_sems.shape}")

    # Merge the means and SEMs
    final_data = pd.merge(summed_means, summed_sems, on=foundation_keys)
    final_data = final_data.rename(columns={'mean_metric': 'mean', 'sem_metric': 'sem'}) # Rename for consistency
    print(f"Final aggregated data shape: {final_data.shape}")

    # Get unique dataset categories and loops (message indices)
    dataset_categories = sorted(final_data['category'].unique())
    loops = sorted(final_data['message_index'].unique())

    if not dataset_categories:
        print("Error: No dataset categories found in the aggregated data.")
        return None
    if not loops:
        print("Error: No message indices (loops) found in the aggregated data.")
        return None

    # Create a colormap for the loops
    colors = plt.cm.viridis(np.linspace(0, 1, len(loops))) # Use viridis for sequential data

    # Create figure with subplots
    n_cols = len(dataset_categories)
    fig, axes = plt.subplots(1, n_cols, figsize=figsize, subplot_kw=dict(polar=True))

    # With a single subplot plt.subplots returns one Axes; wrap it so it can be indexed
    if n_cols == 1:
        axes = [axes]

    # Process each dataset category
    for i, d_category in enumerate(dataset_categories):
        ax = axes[i]
        print(f"\nPlotting category: {d_category}")

        # Filter data for this dataset category
        cat_data = final_data[final_data['category'] == d_category]
        if cat_data.empty:
            print(f"  No data for category {d_category}, skipping subplot.")
            ax.set_title(f"Category: {d_category}\n(No Data)")
            ax.set_xticks([])
            ax.set_yticks([])
            continue

        # Get unique moral categories present in this dataset category's data
        moral_cats = sorted(cat_data['moral_category'].unique())
        if not moral_cats:
            print(f"  No moral categories found for dataset {d_category}, skipping subplot.")
            ax.set_title(f"Category: {d_category}\n(No Moral Categories)")
            ax.set_xticks([])
            ax.set_yticks([])
            continue

        # Map moral categories to display names; unknown keys fall back to title case
        moral_cats_mapped = [display_names.get(mc, mc.title()) for mc in moral_cats]

        # Number of moral categories
        N = len(moral_cats)

        # Create angle values (in radians)
        angles = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist()
        angles += angles[:1] # Repeat the first angle so each line closes on itself

        # Set up the axis
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(moral_cats_mapped)
        ax.tick_params(axis='x', pad=10) # Add padding to x-tick labels

        # Set title for the subplot
        ax.set_title(f"Dataset: {d_category}", pad=25) # Increased padding

        # Find max value for scaling across all loops for this category
        max_val = cat_data['mean'].max()
        if show_sem:
            max_val = max(max_val, (cat_data['mean'] + cat_data['sem']).max())
        max_val = max(max_val, 0.1) # Ensure max_val is not zero

        # Set y-axis limits with some margin
        ax.set_ylim(0, max_val * 1.2)

        # Add grid lines with improved labels
        num_rticks = 5
        rticks = np.linspace(0, max_val, num_rticks + 1)[1:] # Avoid 0 tick label overlap
        ax.set_rticks(rticks)
        # Format tick labels with appropriate precision
        tick_format = ".2f" if max_val < 10 else ".1f"
        ax.set_yticklabels([f"{tick:{tick_format}}" for tick in rticks])
        ax.grid(True)

        # Plot each loop (message_index)
        for j, loop_index in enumerate(loops):
            # Filter data for this loop
            loop_data = cat_data[cat_data['message_index'] == loop_index]

            if loop_data.empty:
                print(f"  No data for loop {loop_index} in category {d_category}")
                continue

            # Create ordered arrays of means and SEMs based on moral_cats order
            means = []
            sems = []
            for mc in moral_cats:
                loop_mc_data = loop_data[loop_data['moral_category'] == mc]
                if not loop_mc_data.empty:
                    means.append(loop_mc_data['mean'].iloc[0])
                    sems.append(loop_mc_data['sem'].iloc[0])
                else:
                    means.append(0) # Append 0 if no data for this moral category in this loop
                    sems.append(0)

            # Make means circular for plotting
            means_circular = np.append(means, means[0])

            # Plot the mean line; the label uses message_index as-is
            label = f"Loop {loop_index}"
            ax.plot(angles, means_circular, color=colors[j], linestyle='-', marker='o', markersize=4, label=label)

            # Add SEM shading if requested
            if show_sem:
                upper_bound = np.array(means) + np.array(sems)
                lower_bound = np.array(means) - np.array(sems)
                lower_bound = np.maximum(lower_bound, 0)  # Ensure no negative values

                # Make bounds circular
                upper_bound_circular = np.append(upper_bound, upper_bound[0])
                lower_bound_circular = np.append(lower_bound, lower_bound[0])

                # Create shaded region
                ax.fill_between(angles, lower_bound_circular, upper_bound_circular,
                                alpha=0.15, color=colors[j]) # Slightly less alpha

    # Add a legend to the figure
    handles, labels = axes[0].get_legend_handles_labels() # Get legend items from the first axis
    if handles: # Only add legend if there are items to show
        fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0.01), ncol=len(loops)) # Adjusted position

    # Adjust layout to prevent overlap
    plt.tight_layout(rect=[0, 0.05, 1, 0.95]) # Adjust rect for title and legend

    # Add more space between subplots if needed
    plt.subplots_adjust(wspace=0.4, hspace=0.4) # Increased wspace

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nSaved convergence radar plot to {save_path}")

    return fig

# --- Plotting for Single Agent (Modified Radar Plot by Iteration) ---

def set_plot_style(title_fontsize=16, label_fontsize=14, tick_fontsize=12,
                   legend_fontsize=14, line_width=2.5):
    """
    Push font-size and line-width settings into matplotlib's global rcParams.

    Parameters:
    -----------
    title_fontsize : int, default=16
        Axes title size; the figure title is set two points larger.
    label_fontsize : int, default=14
        Base font size and axis-label size.
    tick_fontsize : int, default=12
        Size of the x and y tick labels.
    legend_fontsize : int, default=14
        Size of legend text.
    line_width : float, default=2.5
        Default width of plotted lines.
    """
    overrides = {}

    # Base text and axis labels share one size.
    overrides['font.size'] = label_fontsize
    overrides['axes.labelsize'] = label_fontsize

    # Titles: figure-level title is slightly larger than the per-axes title.
    overrides['axes.titlesize'] = title_fontsize
    overrides['figure.titlesize'] = title_fontsize + 2

    # Tick labels on both axes.
    for axis_name in ('xtick', 'ytick'):
        overrides[f'{axis_name}.labelsize'] = tick_fontsize

    overrides['legend.fontsize'] = legend_fontsize
    overrides['lines.linewidth'] = line_width

    plt.rcParams.update(overrides)

def _radar_nan_mean(values):
    """Return the mean of the non-NaN entries of *values* (NaN if there are none)."""
    arr = np.asarray(values, dtype=float)
    valid = arr[~np.isnan(arr)]
    return valid.mean() if valid.size else np.nan


def _radar_nan_sem(values):
    """Return the standard error of the mean of the non-NaN entries of *values*.

    A sample standard deviation (ddof=1) needs at least two observations, so
    groups with fewer than two valid values contribute an SEM of 0.
    """
    arr = np.asarray(values, dtype=float)
    valid = arr[~np.isnan(arr)]
    if valid.size < 2:
        return 0.0
    return valid.std(ddof=1) / np.sqrt(valid.size)


def plot_moral_radar(df, metric='score', show_sem=False, figsize=(18, 9), save_path=None):
    """
    Create radar plots for moral foundations by run index (iteration) and category (for single agent data).
    Each line represents a run index, averaged across all models for that run index.

    Parameters:
    -----------
    df : pandas DataFrame
        DataFrame with required columns ('run_type', 'run_index', 'category',
        'moral_category', 'question_num', metric). The first row's 'run_type'
        must be 'single'.
    metric : str, default='score'
        Column to plot ('score' or 'confidence_numeric')
    show_sem : bool, default=False
        Whether to show standard error as shaded region
    figsize : tuple, default=(18, 9)
        Size of the figure (width, height)
    save_path : str, optional
        Path to save the figure to; if None the figure is not written to disk.

    Returns:
    --------
    fig : matplotlib figure or None
        The created figure with one polar subplot per dataset category, or
        None when the input is empty, not single-agent, lacks a required
        column, or has no rows in INCLUDED_DATASETS.
    """
    if df.empty:
        print("Warning: plot_moral_radar requires a non-empty single-agent DataFrame.")
        return None

    # Validate columns before touching any of them, so a missing 'run_type'
    # is reported like any other missing column instead of raising KeyError.
    required_cols = ['run_type', 'run_index', 'category', 'moral_category', 'question_num', metric]
    if not all(col in df.columns for col in required_cols):
        print(f"Error: DataFrame missing one or more required columns for radar plot: {required_cols}")
        return None

    if df['run_type'].iloc[0] != 'single':
        print("Warning: plot_moral_radar requires a non-empty single-agent DataFrame.")
        return None

    # Display names for the moral_category values used as spoke labels.
    display_names = {
        'authority': 'Authority', 'fairness': 'Fairness', 'harm': 'Care',
        'ingroup': 'Loyalty', 'purity': 'Sanctity', 'liberty': 'Liberty'
    }

    # Filter for included datasets
    df_filtered = df[df['category'].isin(INCLUDED_DATASETS)].copy()
    if df_filtered.empty:
        print(f"Warning: No data found for included datasets {INCLUDED_DATASETS} in the single-agent data.")
        return None

    # Ensure metric column is numeric (unparseable values become NaN)
    df_filtered[metric] = pd.to_numeric(df_filtered[metric], errors='coerce')

    # Per-question mean and SEM for each (run_index, category, moral_category);
    # this averages across models for a specific run_index.
    question_keys = ['run_index', 'category', 'moral_category', 'question_num']
    question_stats = df_filtered.groupby(question_keys).agg(
        mean_metric=(metric, _radar_nan_mean),
        sem_metric=(metric, _radar_nan_sem),
    ).reset_index()

    # Sum the per-question means within each moral category and propagate the
    # SEMs in quadrature (sqrt of the sum of squared SEMs).
    agg_keys = ['run_index', 'category', 'moral_category']
    final_data = question_stats.groupby(agg_keys).agg(
        mean=('mean_metric', 'sum'),
        sem=('sem_metric', lambda s: np.sqrt((s ** 2).sum())),
    ).reset_index()

    # Get unique dataset categories and run indices
    dataset_categories = sorted(final_data['category'].unique())
    run_indices = sorted(final_data['run_index'].unique())

    if not dataset_categories:
        print("Error: No dataset categories found in the aggregated data.")
        return None
    if not run_indices:
        print("Error: No run indices found in the aggregated data.")
        return None

    # One colour per run index, shared by every subplot
    colors = plt.cm.viridis(np.linspace(0, 1, len(run_indices)))

    # Create figure with one polar subplot per dataset category
    n_cols = len(dataset_categories)
    fig, axes = plt.subplots(1, n_cols, figsize=figsize, subplot_kw=dict(polar=True))

    if n_cols == 1:
        axes = [axes]

    for ax, d_category in zip(axes, dataset_categories):
        print(f"\nPlotting category: {d_category}")

        # d_category comes from final_data itself, so cat_data always has rows
        # and at least one moral category.
        cat_data = final_data[final_data['category'] == d_category]
        moral_cats = sorted(cat_data['moral_category'].unique())

        moral_cats_mapped = [display_names.get(mc, mc.title()) for mc in moral_cats]
        N = len(moral_cats)
        angles = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist()
        angles += angles[:1]  # repeat the first angle to close the polygon

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(moral_cats_mapped)
        ax.tick_params(axis='x', pad=10)

        ax.set_title(f"Dataset: {d_category}", pad=25)

        # Radial limit: the largest mean (plus SEM when shaded), never below 0.1
        max_val = cat_data['mean'].max()
        if show_sem:
            max_val = max(max_val, (cat_data['mean'] + cat_data['sem']).max())
        max_val = max(max_val, 0.1)

        ax.set_ylim(0, max_val * 1.2)

        num_rticks = 5
        rticks = np.linspace(0, max_val, num_rticks + 1)[1:]
        ax.set_rticks(rticks)
        tick_format = ".2f" if max_val < 10 else ".1f"
        ax.set_yticklabels([f"{tick:{tick_format}}" for tick in rticks])
        ax.grid(True)

        # Plot each run index
        for j, run_idx in enumerate(run_indices):
            run_data = cat_data[cat_data['run_index'] == run_idx]

            if run_data.empty:
                print(f"  No data for run index {run_idx} in category {d_category}")
                continue

            # Align this run's values to the spoke order; a moral category
            # missing for this run is plotted as 0.
            per_foundation = run_data.set_index('moral_category').reindex(moral_cats)
            means = per_foundation['mean'].fillna(0).to_numpy()
            sems = per_foundation['sem'].fillna(0).to_numpy()

            means_circular = np.append(means, means[0])
            label = f"Run {run_idx}"
            ax.plot(angles, means_circular, color=colors[j], linestyle='-', marker='o', markersize=4, label=label)

            if show_sem:
                upper_bound = means + sems
                lower_bound = np.maximum(means - sems, 0)  # keep the band non-negative
                upper_bound_circular = np.append(upper_bound, upper_bound[0])
                lower_bound_circular = np.append(lower_bound, lower_bound[0])
                ax.fill_between(angles, lower_bound_circular, upper_bound_circular, alpha=0.15, color=colors[j])

    # Build the legend from every subplot so a run that is absent from the
    # first dataset still gets an entry; keep the entries in run order.
    legend_entries = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    labels = [f"Run {run_idx}" for run_idx in run_indices if f"Run {run_idx}" in legend_entries]
    handles = [legend_entries[label] for label in labels]
    if handles:
        # Adjust ncol based on number of runs, max 5 per row for readability
        ncol_legend = min(len(run_indices), 5)
        fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0.01), ncol=ncol_legend)

    # Act on this figure explicitly rather than on pyplot's current figure.
    fig.tight_layout(rect=[0, 0.05, 1, 0.95])  # Adjust rect for title and legend
    fig.subplots_adjust(wspace=0.4, hspace=0.4)

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nSaved single-agent radar plot (by run index) to {save_path}")

    return fig

In [None]:
# Score radar for the single-agent results: one line per run_index, averaged
# across models, with the SEM drawn as a shaded band.
set_plot_style(
    title_fontsize=16,
    label_fontsize=12,
    tick_fontsize=16,
    legend_fontsize=14,
    line_width=2.5,
)
fig = plot_moral_radar(
    single_agent_df,
    metric='score',
    show_sem=True,
    save_path=os.path.join(PLOT_DIR, 'moral_radar_plot_score_by_run.png'),
)

In [None]:
# Confidence radar for the single-agent results: one line per run_index,
# averaged across models, with the SEM drawn as a shaded band.
set_plot_style(
    title_fontsize=16,
    label_fontsize=12,
    tick_fontsize=16,
    legend_fontsize=14,
    line_width=2.5,
)
fig = plot_moral_radar(
    single_agent_df,
    metric='confidence_numeric',
    show_sem=True,
    save_path=os.path.join(PLOT_DIR, 'moral_radar_plot_confidence_by_run.png'),
)

In [None]:
# --- Main Execution ---
print("Loading and preprocessing data...")
# Both result directories are loaded up front; either frame may be empty.
df_single = load_and_preprocess_data(RESULTS_DIR_SINGLE)
df_multi = load_and_preprocess_data(RESULTS_DIR_MULTI)

# --- Single-agent radar plots (skipped when nothing was loaded) ---
if df_single.empty:
    print("\n--- No single-agent data found or loaded, skipping single-agent plots. ---")
else:
    print("\n--- Generating Single-Agent Specific Plots ---")
    set_plot_style(
        title_fontsize=16,
        label_fontsize=12,
        tick_fontsize=12,
        legend_fontsize=12,
        line_width=2.0,
    )

    print("Generating single-agent score radar plot (by run index)...")
    fig_single_score = plot_moral_radar(
        df_single,
        metric='score',
        show_sem=True,
        save_path=os.path.join(PLOT_DIR, 'moral_radar_plot_score_single_by_run.png'),
    )
    # The plot is already on disk; release the figure.
    if fig_single_score:
        plt.close(fig_single_score)

    print("Generating single-agent confidence radar plot (by run index)...")
    fig_single_conf = plot_moral_radar(
        df_single,
        metric='confidence_numeric',
        show_sem=True,
        save_path=os.path.join(PLOT_DIR, 'moral_radar_plot_confidence_single_by_run.png'),
    )
    if fig_single_conf:
        plt.close(fig_single_conf)

# --- Multi-agent convergence radar plots (skipped when nothing was loaded) ---
if df_multi.empty:
    print("\n--- No multi-agent data found or loaded, skipping convergence plots. ---")
else:
    print("\n--- Generating Multi-Agent Convergence Plots ---")
    set_plot_style(
        title_fontsize=16,
        label_fontsize=12,
        tick_fontsize=12,
        legend_fontsize=12,
        line_width=2.0,
    )

    print("Generating multi-agent score convergence radar plot...")
    fig_multi_score = plot_moral_radar_convergence(
        df_multi,
        metric='score',
        show_sem=True,
        save_path=os.path.join(PLOT_DIR, 'moral_radar_plot_score_convergence.png'),
    )
    # The plot is already on disk; release the figure.
    if fig_multi_score:
        plt.close(fig_multi_score)

    print("Generating multi-agent confidence convergence radar plot...")
    fig_multi_conf = plot_moral_radar_convergence(
        df_multi,
        metric='confidence_numeric',
        show_sem=True,
        save_path=os.path.join(PLOT_DIR, 'moral_radar_plot_confidence_convergence.png'),
    )
    if fig_multi_conf:
        plt.close(fig_multi_conf)

# --- Generate Combined/General Plots (Using combined data if available) ---
print("\n--- Generating General Distribution Plots ---")
# Build df_combined from whichever of the two frames actually has rows.
df_combined = pd.DataFrame()  # stays empty when neither source loaded
if not df_single.empty and not df_multi.empty:
    print("Combining single-agent and multi-agent data for general plots.")
    # Order-preserving union of both column sets (single-agent columns first).
    # A plain set union would make the column order vary from run to run.
    cols = list(dict.fromkeys([*df_single.columns, *df_multi.columns]))
    # Columns present in only one frame are filled with NaN in the other.
    df_single_reindexed = df_single.reindex(columns=cols)
    df_multi_reindexed = df_multi.reindex(columns=cols)
    df_combined = pd.concat([df_single_reindexed, df_multi_reindexed], ignore_index=True)
elif not df_single.empty:
    print("Using only single-agent data for general plots.")
    df_combined = df_single
elif not df_multi.empty:
    print("Using only multi-agent data for general plots.")
    df_combined = df_multi
else:
    print("No data loaded from either single-agent or multi-agent results directories for general plots.")

if not df_combined.empty:
    print(f"\nLoaded {len(df_combined)} total records for general analysis.")
    print(f"Included datasets: {INCLUDED_DATASETS}")
    # Restrict the combined data to the configured datasets before plotting.
    df_combined_filtered = df_combined[df_combined['category'].isin(INCLUDED_DATASETS)].copy()
    if df_combined_filtered.empty:
        print(f"Warning: No data remaining after filtering combined data for {INCLUDED_DATASETS}.")
    else:
        print(f"Unique categories found in combined/filtered data: {df_combined_filtered['category'].unique()}")
        print(f"Data shape after combining/filtering: {df_combined_filtered.shape}")
        print(f"Value counts for 'category':\n{df_combined_filtered['category'].value_counts()}")
        print(f"Value counts for 'run_type':\n{df_combined_filtered['run_type'].value_counts()}")

        # --- Generate Individual Distribution Plots ---
        print("\nGenerating individual distribution plots (from combined data)...")
        plot_answer_distribution(df_combined_filtered, os.path.join(PLOT_DIR, 'overall_answer_distribution.png'))
        plot_confidence_distribution(df_combined_filtered, os.path.join(PLOT_DIR, 'overall_confidence_distribution.png'))
        plot_answer_by_category(df_combined_filtered, os.path.join(PLOT_DIR, 'answer_by_category.png'))
        plot_confidence_by_category(df_combined_filtered, os.path.join(PLOT_DIR, 'confidence_by_category.png'))
        # These only make sense for single-agent data, but run on combined (will filter inside)
        plot_answer_by_model(df_combined_filtered, os.path.join(PLOT_DIR, 'answer_by_model_single.png'))
        plot_confidence_by_model(df_combined_filtered, os.path.join(PLOT_DIR, 'confidence_by_model_single.png'))

        # --- Generate 3x2 Grid Plot (Optional - might be less informative with multi-agent data mixed in) ---
        # Consider if this grid plot is still desired or if separate single/multi plots are better.
        # Keeping it for now, but be aware it mixes single-agent models and multi-agent loops.
        print("\nGenerating combined grid plot (from combined data)...")
        fig_grid, axes_grid = plt.subplots(3, 2, figsize=(18, 24))
        # The title is anchored slightly above the canvas (y=1.02); it is only
        # kept in the saved file because savefig below uses bbox_inches='tight'.
        fig_grid.suptitle(f'MoralBench Analysis ({", ".join(INCLUDED_DATASETS)}) - Combined Data', fontsize=16, y=1.02)

        # Plot 1: Overall Answer Distribution
        if 'answer_clean' in df_combined_filtered.columns:
            plot_data_ans = df_combined_filtered[df_combined_filtered['answer_clean'].isin(['A', 'B'])]
            if not plot_data_ans.empty:
                sns.countplot(ax=axes_grid[0, 0], data=plot_data_ans, x='answer_clean', order=['A', 'B'])
                axes_grid[0, 0].set_title('Overall Answer Distribution (A vs B)')
                axes_grid[0, 0].set_xlabel('Answer')
                axes_grid[0, 0].set_ylabel('Count')
            else:
                axes_grid[0, 0].set_title('Overall Answer Distribution (No A/B Data)')
        else:
            axes_grid[0, 0].set_title('Overall Answer Distribution (No Data)')

        # Plot 2: Overall Confidence Distribution
        if 'confidence_numeric' in df_combined_filtered.columns:
            plot_data_conf = df_combined_filtered.dropna(subset=['confidence_numeric'])
            if not plot_data_conf.empty:
                # Unit-wide bins centred on the integer confidence scores.
                sns.histplot(ax=axes_grid[0, 1], data=plot_data_conf, x='confidence_numeric', bins=np.arange(-0.5, 6.5, 1), kde=False)
                axes_grid[0, 1].set_title('Overall Confidence Distribution')
                axes_grid[0, 1].set_xlabel('Confidence Score (0-5)')
                axes_grid[0, 1].set_ylabel('Count')
                axes_grid[0, 1].set_xticks(range(6))
                axes_grid[0, 1].set_xlim(-0.5, 5.5)
            else:
                axes_grid[0, 1].set_title('Overall Confidence Distribution (No Numeric Data)')
        else:
            axes_grid[0, 1].set_title('Overall Confidence Distribution (No Data)')

        # Plot 3: Answer Distribution by Category
        if 'answer_clean' in df_combined_filtered.columns and 'category' in df_combined_filtered.columns:
            # Rows are already restricted to INCLUDED_DATASETS above.
            plot_data_cat_ans = df_combined_filtered[df_combined_filtered['answer_clean'].isin(['A', 'B'])]
            if not plot_data_cat_ans.empty:
                category_order = sorted(plot_data_cat_ans['category'].unique())
                if category_order:
                    sns.countplot(ax=axes_grid[1, 0], data=plot_data_cat_ans, x='category', hue='answer_clean', order=category_order, hue_order=['A', 'B'])
                    axes_grid[1, 0].set_title('Answer Distribution by Category')
                    axes_grid[1, 0].set_xlabel('Category')
                    axes_grid[1, 0].set_ylabel('Count')
                    axes_grid[1, 0].tick_params(axis='x', rotation=45)
                    axes_grid[1, 0].legend(title='Answer')
                else:
                    axes_grid[1, 0].set_title('Answer Distribution by Category (No Included Dataset Data)')
            else:
                axes_grid[1, 0].set_title('Answer Distribution by Category (No A/B Data)')
        else:
            axes_grid[1, 0].set_title('Answer Distribution by Category (No Data)')

        # Plot 4: Confidence Distribution by Category
        if 'confidence_numeric' in df_combined_filtered.columns and 'category' in df_combined_filtered.columns:
            # Rows are already restricted to INCLUDED_DATASETS above.
            plot_data_cat_conf = df_combined_filtered.dropna(subset=['confidence_numeric'])
            if not plot_data_cat_conf.empty:
                category_order = sorted(plot_data_cat_conf['category'].unique())
                if category_order:
                    sns.barplot(ax=axes_grid[1, 1], data=plot_data_cat_conf, x='category', y='confidence_numeric', order=category_order, estimator=np.mean, errorbar='sd')
                    axes_grid[1, 1].set_title('Average Confidence by Category')
                    axes_grid[1, 1].set_xlabel('Category')
                    axes_grid[1, 1].set_ylabel('Average Confidence Score (0-5)')
                    axes_grid[1, 1].tick_params(axis='x', rotation=45)
                    axes_grid[1, 1].set_ylim(0, 5)
                else:
                    axes_grid[1, 1].set_title('Average Confidence by Category (No Included Dataset Data)')
            else:
                axes_grid[1, 1].set_title('Average Confidence by Category (No Numeric Data)')
        else:
            axes_grid[1, 1].set_title('Average Confidence by Category (No Data)')

        # Plot 5: Answer Distribution by Model (Single Agent Only)
        df_single_plot = df_combined_filtered[(df_combined_filtered['run_type'] == 'single')]  # single-agent rows only
        if not df_single_plot.empty and 'answer_clean' in df_single_plot.columns and 'model_name' in df_single_plot.columns:
            plot_data_mod_ans = df_single_plot[df_single_plot['answer_clean'].isin(['A', 'B'])]
            if not plot_data_mod_ans.empty:
                model_order = sorted(plot_data_mod_ans['model_name'].unique())
                sns.countplot(ax=axes_grid[2, 0], data=plot_data_mod_ans, y='model_name', hue='answer_clean', order=model_order, hue_order=['A', 'B'])
                axes_grid[2, 0].set_title('Answer Distribution by Model (Single)')
                axes_grid[2, 0].set_xlabel('Count')
                axes_grid[2, 0].set_ylabel('Model Name')
                axes_grid[2, 0].legend(title='Answer')
            else:
                axes_grid[2, 0].set_title('Answer Distribution by Model (No A/B Single Data)')
        else:
            axes_grid[2, 0].set_title('Answer Distribution by Model (No Single Data)')

        # Plot 6: Confidence Distribution by Model (Single Agent Only)
        if not df_single_plot.empty and 'confidence_numeric' in df_single_plot.columns and 'model_name' in df_single_plot.columns:
            plot_data_mod_conf = df_single_plot.dropna(subset=['confidence_numeric'])
            if not plot_data_mod_conf.empty:
                model_order = sorted(plot_data_mod_conf['model_name'].unique())
                sns.barplot(ax=axes_grid[2, 1], data=plot_data_mod_conf, y='model_name', x='confidence_numeric', order=model_order, estimator=np.mean, errorbar='sd')
                axes_grid[2, 1].set_title('Average Confidence by Model (Single)')
                axes_grid[2, 1].set_xlabel('Average Confidence Score (0-5)')
                axes_grid[2, 1].set_ylabel('Model Name')
                axes_grid[2, 1].set_xlim(0, 5)
            else:
                axes_grid[2, 1].set_title('Average Confidence by Model (No Numeric Single Data)')
        else:
            axes_grid[2, 1].set_title('Average Confidence by Model (No Single Data)')

        # Adjust layout and save the grid. Work on fig_grid explicitly instead
        # of pyplot's current figure, and save with a tight bounding box so the
        # suptitle placed above the canvas is not cropped out of the image.
        fig_grid.tight_layout(rect=[0, 0, 1, 0.98])
        grid_plot_filename = os.path.join(PLOT_DIR, 'combined_analysis_grid.png')
        fig_grid.savefig(grid_plot_filename, bbox_inches='tight')
        plt.close(fig_grid)  # Close the figure to free memory
        print(f"\nSaved combined analysis grid plot to {grid_plot_filename}")

else:
    print("\nNo data available for plotting after loading and filtering.")

print("\nAnalysis complete.")

In [None]:
# Smoke test for data loading: read the results under RESULTS_DIR_MULTI and
# preview the first rows as the cell's output.
multi_agent_df = load_and_preprocess_data(RESULTS_DIR_MULTI)
multi_agent_df.head()