In [13]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Global figure styling: ggplot theme plus a two-colour palette
# (coral orange #FF7F50, dodger blue #1E90FF) used as seaborn's default.
plt.style.use('ggplot')
sns.set_palette(["#FF7F50", "#1E90FF"])

# Output folder for every saved figure; exist_ok=True makes re-running the
# cell a no-op instead of an error.
os.makedirs('saved_plots', exist_ok=True)

# Folder that holds the per-opponent / per-domain result CSV files.
data_dir = 'combined_results'

# The five agents under evaluation, in the order they should appear on every x-axis.
our_agents = ['Group40Agent0' + str(idx) for idx in range(1, 6)]

# Opponent-pool keys (the first underscore-separated field of each CSV name)
# mapped to the labels shown in titles and legends.
opponent_names = dict(
    our='Our Agents',
    anl='ANL2022',
    baseline='Baseline',
    students='Student',
)

# Iteration order for the per-opponent plots; dicts keep insertion order, so
# this is the keys of `opponent_names` in the order listed above.
opponent_types = list(opponent_names)

# Load every result CSV in `data_dir` into one long-format table, `plot_data`.
# File names are expected to look like "<opponent>_<domain_type>[_...].csv";
# only the first two underscore-separated fields of the stem are used.
# Each CSV is read with its first column as the index (agent names).
relevant_metrics = ['avg_utility', 'avg_nash_product', 'avg_social_welfare', 'avg_num_offers']
all_data = []

# os.listdir returns entries in arbitrary order; sorting makes the
# concatenated row order reproducible between runs.
for file in sorted(os.listdir(data_dir)):
    if not file.endswith('.csv'):
        continue

    parts = os.path.splitext(file)[0].split('_')
    if len(parts) < 2:
        # Report malformed names instead of failing with a bare IndexError.
        print(f"Skipping '{file}': expected '<opponent>_<domain_type>.csv'")
        continue
    opponent, domain_type = parts[0], parts[1]

    df = pd.read_csv(os.path.join(data_dir, file), index_col=0)

    # Keep only our agents' rows and the metric columns. .loc + .copy() gives
    # an independent frame, so the column assignments below never act on a
    # view of the freshly read CSV.
    df = df.loc[df.index.isin(our_agents), relevant_metrics].copy()

    # Metadata columns; unknown opponent prefixes fall back to the raw prefix.
    df['opponent'] = opponent_names.get(opponent, opponent)
    df['domain_type'] = domain_type
    df['agent'] = df.index

    all_data.append(df.reset_index(drop=True))

if not all_data:
    raise ValueError(f"No usable result CSV files found in '{data_dir}'")

# ignore_index avoids duplicate row labels coming from the per-file 0..k indexes.
plot_data = pd.concat(all_data, ignore_index=True)

# Ordered categorical so every plot lists agents in `our_agents` order.
plot_data['agent'] = pd.Categorical(plot_data['agent'],
                                    categories=our_agents,
                                    ordered=True)

# Stable sort keeps the (file-sorted) row order within each agent.
plot_data = plot_data.sort_values('agent', kind='stable')

# Metric column -> human-readable label. The insertion order decides which
# subplot of the 2x2 overview each metric lands in, and the order in which
# the per-metric figures are produced.
metrics = dict(
    avg_utility='Average Utility',
    avg_nash_product='Average Nash Product',
    avg_social_welfare='Average Social Welfare',
    avg_num_offers='Average Number of Offers',
)

# Function to create consistent bar plots
def create_barplot(data, x, y, hue, title, xlabel, ylabel, filename):
    plt.figure(figsize=(10, 6))
    ax = sns.barplot(
        data=data,
        x=x,
        y=y,
        hue=hue,
        palette=["#FF7F50", "#1E90FF"],  # Orange and blue
        errorbar=None,
        dodge=True
    )
    
    plt.title(title, fontsize=14)
    plt.ylabel(ylabel, fontsize=12)
    plt.xlabel(xlabel, fontsize=12)
    plt.xticks(rotation=45)
    
    # Add value labels on top of each bar
    for container in ax.containers:
        ax.bar_label(container, fmt='%.3f', padding=3, fontsize=9)
    
    # Move legend outside the plot consistently
    plt.legend(title=hue.capitalize(), bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(f'saved_plots/{filename}.png', dpi=300, bbox_inches='tight')
    plt.close()

# Overview figure: one panel per metric in a 2x2 grid, bars per agent split by
# domain type. No opponent split is requested here, so rows for the same
# (agent, domain_type) from different opponent pools are aggregated by
# seaborn (its default estimator is the mean) and drawn without error bars.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(18, 14))
fig.suptitle('Agent Performance Across Different Opponent Types and Domains', fontsize=16, y=1.02)

for ax, (metric, title) in zip(axes.flat, metrics.items()):
    sns.barplot(
        ax=ax,
        data=plot_data,
        x='agent',
        order=our_agents,
        y=metric,
        hue='domain_type',
        palette=["#FF7F50", "#1E90FF"],
        dodge=True,
        errorbar=None,
    )

    # Panel title and y-label both use the metric's display name.
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Agent', fontsize=12)
    ax.set_ylabel(title, fontsize=12)
    ax.tick_params(axis='x', rotation=45)

    # Print each bar's height (3 decimals) just above it.
    for container in ax.containers:
        ax.bar_label(container, fmt='%.3f', padding=3, fontsize=9)

    # Legend to the right of each panel, matching the single-plot layout.
    ax.legend(title='Domain type', bbox_to_anchor=(1.05, 1), loc='upper left')

fig.tight_layout()
fig.savefig('saved_plots/overall_comparison.png', dpi=300, bbox_inches='tight')
plt.close(fig)

# One bar chart per (opponent type, metric): bars per agent, split by domain type.
# Files are named "<opponent key>_<metric column>.png".
for opponent in opponent_types:
    opponent_name = opponent_names.get(opponent, opponent)
    opponent_data = plot_data[plot_data['opponent'] == opponent_name]

    # No CSV was loaded for this opponent pool: skip it rather than saving
    # blank figures (with empty-legend warnings) for every metric.
    if opponent_data.empty:
        print(f"No data for opponent type '{opponent_name}'; skipping its plots")
        continue

    for metric, title in metrics.items():
        create_barplot(
            data=opponent_data,
            x='agent',
            y=metric,
            hue='domain_type',
            title=f'{title} Against {opponent_name}',
            xlabel='Agent',
            ylabel=title,
            filename=f'{opponent}_{metric}'
        )

# Line plots: one figure per metric, one line per (opponent, domain type)
# combination running across the agents on the x-axis.
#
# Size the palette from the opponent labels actually present in the data.
# Prefixes missing from `opponent_names` (kept as raw prefixes when loading)
# or pools with no CSV would otherwise give seaborn a palette list whose
# length differs from the number of hue levels. Depending on the seaborn
# version, that raises a ValueError or silently reuses or drops colours.
n_opponents = max(plot_data['opponent'].nunique(), 1)

for metric, title in metrics.items():
    plt.figure(figsize=(12, 6))
    sns.lineplot(
        data=plot_data,
        x='agent',
        y=metric,
        hue='opponent',
        style='domain_type',
        markers=True,
        dashes=False,
        markersize=10,
        linewidth=2.5,
        sort=False,  # plot_data is already sorted by the ordered 'agent' category
        palette=sns.color_palette("husl", n_opponents)
    )

    plt.title(f'{title} Comparison Across Agents and Domains')
    plt.ylabel(title)
    plt.xlabel('Agent')
    plt.xticks(rotation=45)
    # NOTE(review): this legend holds both the opponent (hue) and the domain
    # type (style) entries, even though it is titled 'Opponent'.
    plt.legend(title='Opponent', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'saved_plots/{metric}_trends.png', dpi=300, bbox_inches='tight')
    plt.close()

# Report completion once every figure above has been written to saved_plots/.
print("All plots saved to the 'saved_plots' directory")

All plots saved to the 'saved_plots' directory
