In [1]:
import sys
sys.path.append("..")

from glob import glob
import matplotlib.pyplot as plt
import ipywidgets as ipw
from IPython.display import Audio
import numpy as np
import pandas as pd
import pickle
import yaml

from lib.notebooks import plot_groups_metrics
from communicative_agent import CommunicativeAgent

In [2]:
# Every sub-directory of the communicative-agent output folder, in lexicographic order.
# The trailing "/" in the pattern restricts matches to directories, so each path ends with a separator.
agents_path = sorted(glob("../out/communicative_agent/*/"))

In [3]:
# Training metrics grouped by configuration: group label -> {agent directory -> metrics dict}
groups_metrics = {}

# Column-oriented summary table: one list per column, one entry appended per agent.
# Insertion order here is the column order of the final DataFrame.
agents_columns = {
    column: []
    for column in (
        "path",                      # Short identifier cut from the agent directory path
        "datasets",                  # Comma-joined dataset names of the sound quantizer
        "inverse_learning_rate",     # config['training']['inverse_model_learning_rate']
        "inverse_layers",            # "<num_layers>x<hidden_size>" of the inverse model
        "inverse_dropout_p",         # Dropout probability of the inverse model
        "direct_learning_rate",      # config['training']['direct_model_learning_rate'] (0 for synth)
        "direct_layers",             # "<layer count>x<first layer size>", or "synth"
        "direct_dropout_p",          # Dropout probability of the direct model (0 for synth)
        "jerk_weight",               # config['training']['jerk_loss_weight']
        "direct_estimation_error",   # Test-split direct model estimation error (0 for synth)
        "inverse_estimation_error",  # Test-split inverse model estimation error
        "jerk",                      # Test-split inverse model jerk
        "repetition_error",          # Test-split inverse model repetition error
    )
}

# Walk every agent directory and collect its hyper-parameters and selected test metrics
for agent_path in agents_path:
    # load_nn=False: only the configuration is read below, not the networks
    agent = CommunicativeAgent.reload(agent_path, load_nn=False)
    config = agent.config

    # Metrics saved next to the agent.
    # NOTE(review): pickle.load can execute arbitrary code; only point this at trusted training outputs.
    metrics_file = "%s/metrics.pickle" % agent_path
    with open(metrics_file, "rb") as f:
        metrics = pickle.load(f)

    # agent_path ends with a separator, so [-5:-1] keeps the last four characters of the directory name
    row = {"path": agent_path[-5:-1]}

    dataset_names = ",".join(agent.sound_quantizer.config['dataset']['names'])
    row["datasets"] = dataset_names

    # Inverse model hyper-parameters
    training_config = config['training']
    row["inverse_learning_rate"] = training_config['inverse_model_learning_rate']
    inverse_config = config['model']['inverse_model']
    row["inverse_layers"] = f"{inverse_config['num_layers']}x{inverse_config['hidden_size']}"
    row["inverse_dropout_p"] = inverse_config['dropout_p']

    # Direct model hyper-parameters.
    # NOTE(review): only the presence of the key is tested, not its value — confirm that
    # configs never carry an explicit `use_synth_as_direct_model: False`.
    has_direct_model = 'use_synth_as_direct_model' not in config['model']
    if has_direct_model:
        row["direct_learning_rate"] = training_config['direct_model_learning_rate']
        hidden_layers = config['model']['direct_model']['hidden_layers']
        row["direct_layers"] = f"{len(hidden_layers)}x{hidden_layers[0]}"
        row["direct_dropout_p"] = config['model']['direct_model']['dropout_p']
    else:
        # No trained direct model: fill the columns with placeholders
        row["direct_learning_rate"] = 0
        row["direct_layers"] = "synth"
        row["direct_dropout_p"] = 0

    jerk_weight = training_config['jerk_loss_weight']
    row["jerk_weight"] = jerk_weight

    # Report test metrics at the index where the validation repetition error is lowest
    best_index = np.argmin(metrics["validation"]["inverse_model_repetition_error"])

    if has_direct_model:
        row["direct_estimation_error"] = metrics["test"]["direct_model_estimation_error"][best_index]
    else:
        row["direct_estimation_error"] = 0

    test_metrics = metrics["test"]
    row["inverse_estimation_error"] = test_metrics["inverse_model_estimation_error"][best_index]
    row["jerk"] = test_metrics["inverse_model_jerk"][best_index]
    row["repetition_error"] = test_metrics["inverse_model_repetition_error"][best_index]

    # Commit the finished row to the column lists
    for column, value in row.items():
        agents_columns[column].append(value)

    # Agents sharing datasets, synthesizer articulatory type and jerk weight form one group
    group_name = "\n".join((
        f"datasets={dataset_names}",
        f"synth_art={agent.synthesizer.config['dataset']['art_type']}",
        f"jerk_w={jerk_weight}",
        # f"frame_padding={config['model']['sound_quantizer']['frame_padding']}",
    ))

    # Register the full metrics of this agent under its group
    groups_metrics.setdefault(group_name, {})[agent_path] = metrics

# One row per agent, ready for filtering and sorting
agents_loss = pd.DataFrame(agents_columns)

In [4]:
# Distinct values (in order of first appearance) used as filter options and defaults below
datasets = agents_loss["datasets"].unique()
jerk_weights = agents_loss["jerk_weight"].unique()

def show_top_agents(
    measure="repetition_error",
    datasets=datasets[0],
    jerk_weight=jerk_weights[0],
    use_synth_as_direct=False,
    ascending=True,
):
    """
    Display up to 30 rows of `agents_loss` matching the filters, sorted by `measure`.

    Args:
        measure (str): Column of `agents_loss` to sort by (default: "repetition_error")
        datasets (str): Value the "datasets" column must equal. Inside this function
            the name shadows the notebook-level `datasets` array.
        jerk_weight: Value the "jerk_weight" column must equal
        use_synth_as_direct (bool): If truthy keep only rows whose "direct_layers" is
            "synth"; otherwise keep only the rows where it is not
        ascending (bool): Passed to `sort_values`; True lists the smallest values first

    Returns:
        None. The resulting DataFrame is shown with `display`.
    """
    # Rows produced by agents that use the synthesizer as direct model
    uses_synth = agents_loss["direct_layers"] == "synth"

    # All three criteria combined into a single boolean mask
    keep = (
        (agents_loss["datasets"] == datasets)
        & (agents_loss["jerk_weight"] == jerk_weight)
        & (uses_synth if use_synth_as_direct else ~uses_synth)
    )

    # Rank the matching agents and show the first 30 at most
    ranked = agents_loss[keep].sort_values(measure, ascending=ascending)
    display(ranked.head(30))

# Interactive explorer around show_top_agents; each keyword below becomes one control
top_agents_explorer = ipw.interactive(
    show_top_agents,
    measure=list(agents_loss.columns),  # Any column can be the sort key, not only the metrics
    datasets=datasets,                  # Dataset combinations found among the agents
    jerk_weight=jerk_weights,           # Jerk weights found among the agents
    use_synth_as_direct=False,          # Initial state of the synth / trained direct model toggle
    ascending=True,                     # Initial state of the sort order toggle
)
top_agents_explorer

interactive(children=(Dropdown(description='measure', index=12, options=('path', 'datasets', 'inverse_learning…

In [5]:
# Metric keys handed to plot_groups_metrics.
# Entries left commented out can be re-enabled for a more detailed view
# (the sound_quantizer_* keys are presumably present in the metrics files — verify before enabling).
metrics_name = [
    # "sound_quantizer_reconstruction_error",
    # "sound_quantizer_vq_loss",
    # "direct_model_estimation_error",
    "inverse_model_estimation_error",
    "inverse_model_repetition_error",
    "inverse_model_jerk",
]

def show_metrics(split_name="test"):
    """
    Plot the metrics listed in `metrics_name` for every agent group on one data split.

    Thin wrapper that forwards the notebook-level `groups_metrics` and `metrics_name`
    to `plot_groups_metrics`; the rendering itself is defined in `lib.notebooks`.

    Args:
        split_name (str): Split to plot. The widget below offers "train",
            "validation" and "test"; the default is "test".
    """
    plot_groups_metrics(groups_metrics, metrics_name, split_name)

# Selector for the data split; picking another entry re-runs show_metrics with it
split_explorer = ipw.interactive(show_metrics, split_name=["train", "validation", "test"])
display(split_explorer)

interactive(children=(Dropdown(description='split_name', index=2, options=('train', 'validation', 'test'), val…