In [6]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt

# Initialize wandb API
api = wandb.Api()

# User inputs
project_name = "thesis_baselines"  # Replace with your wandb project name
entity_name = "shayan000"    # Replace with your wandb entity
metric_list = ["test/accuracy", "test/f1_macro", "test/roc_auc"]  # List of metrics to plot

# Fetch runs
runs = api.runs(f"{entity_name}/{project_name}")
runs = [run for run in runs if run.name in ["d_['MGYS00005285']   s_GO-slim_abundances", "d_['MGYS00003619']   s_GO-slim_abundances", "d_['MGYS00003677']   s_GO-slim_abundances"]]

# Collect metrics from the summary table for each run
data = []
for run in runs:
    if "Test Metrics Summary table" in run.summary:
        table = run.summary["Test Metrics Summary table"]  # Get the wandb.Table object
        print(table)
        print(type(table))
        df = table.to_dataframe()  # Convert the table to a pandas DataFrame
        run_metrics = {"Run": run.name}  # Initialize dictionary for current run
        for metric in metric_list:
            if metric in df['Metric'].values:
                value = df.loc[df['Metric'] == metric, 'Mean'].values[0]  # Extract 'Mean' value
                run_metrics[metric] = value
        data.append(run_metrics)

# Create a DataFrame from collected data
df_metrics = pd.DataFrame(data)

# Plot the data
df_metrics.set_index('Run', inplace=True)
df_metrics.plot(kind='bar', figsize=(12, 6))
plt.title("Metrics Comparison Across Runs")
plt.ylabel("Metric Values")
plt.xlabel("Runs")
plt.xticks(rotation=45)
plt.legend(title="Metrics")
plt.tight_layout()
plt.show()


{'_latest_artifact_path': 'wandb-client-artifact://10828xktzo8gu74fslfqlbuzq2duvrik70yka4raopfq9owuxf5ehcl5uyxjbrck4m43sw1n7nkb06tnc3s53wsx1hf1bliu92igyiynldj9nm9slzbz1o6lodlian5m:latest/Test Metrics Summary table.table.json', '_type': 'table-file', 'artifact_path': 'wandb-client-artifact://v7z01vp5l3jl3bvqxrwztsywavwkkon1b0z2b6968suw4ed5nwup02iqwnmmlrp5az3sul7r7tdgx1i5xy19zpj5jminudqmxu2sh4oyhhlzeunwxeh36stlkedojcue/Test Metrics Summary table.table.json', 'ncols': 3, 'nrows': 20, 'path': 'media/table/Test Metrics Summary table_5_e1a0a069afc0f0167924.table.json', 'sha256': 'e1a0a069afc0f0167924195ce72c527700aefcf104c4a5aea1f279c0f55b1945', 'size': 1392}
<class 'wandb.old.summary.SummarySubDict'>


KeyError: 'to_dataframe'

In [25]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def fetch_and_plot_metrics(
    entity,
    project,
    metrics_to_plot,
    run_names=None,
    n_recent_runs=5,
    figsize=(12, 6),
    title="Metrics Comparison Across Runs",
    table_name="Test Metrics Summary table"
):
    """
    Fetch metrics from W&B table and create a bar plot comparison.
    
    Parameters:
    -----------
    entity : str
        W&B entity name (username or team name)
    project : str
        W&B project name
    metrics_to_plot : list
        List of metric names to include in the plot
    run_names : list, optional
        List of specific run names to include. If None, uses n_recent_runs
    n_recent_runs : int, optional
        Number of most recent runs to include if run_names is None
    figsize : tuple, optional
        Figure size (width, height)
    title : str, optional
        Plot title
    table_name : str, optional
        Name of the W&B table containing the metrics
    
    Returns:
    --------
    fig : matplotlib.figure.Figure
        The generated figure
    """
    # Initialize W&B API
    api = wandb.Api()
    
    # Get runs
    runs = api.runs(f"{entity}/{project}")
    
    if run_names:
        runs = [run for run in runs if run.name in run_names]
    else:
        runs = runs[:n_recent_runs]
    
    # Collect metrics from tables
    data = []
    for run in runs:
        try:
            # Access the table from summary
            print(run.summary._json_dict)
            table = run.summary._json_dict[table_name]
            
            # Extract data and column names
            # table_data = table['data']
            # columns = table['columns']
            
            print(table)
            # Convert to DataFrame
            df = table.to_dataframe()
            
            # Get mean values from the table
            metrics_dict = {metric: df[df['Metric'] == metric]['Mean'].iloc[0] 
                          for metric in metrics_to_plot if metric in df['Metric'].values}
            metrics_dict['run_name'] = run.name
            data.append(metrics_dict)
            
        except Exception as e:
            print(f"Warning: Could not process table for run {run.name}. Error: {str(e)}")
            continue
    
    if not data:
        raise ValueError("No valid data could be collected from any runs")
    
    # Convert to DataFrame
    df_plot = pd.DataFrame(data)
    
    # Prepare data for plotting
    plot_data = df_plot.melt(
        id_vars=['run_name'],
        value_vars=metrics_to_plot,
        var_name='Metric',
        value_name='Value'
    )
    
    # Create plot
    plt.figure(figsize=figsize)
    
    # Create grouped bar plot
    sns.barplot(
        data=plot_data,
        x='run_name',
        y='Value',
        hue='Metric',
        palette='husl'
    )
    
    # Customize plot
    plt.title(title, pad=20)
    plt.xticks(rotation=45, ha='right')
    plt.xlabel('Run Name')
    plt.ylabel('Value')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    
    return plt.gcf()

# Configuration
ENTITY = "shayan000"
PROJECT = "thesis_baselines"

# Metrics to compare
METRICS = [
    "test/accuracy",
    "test/f1_macro",
    "test/precision_macro"
]

run_names = ["d_['MGYS00005285']   s_GO-slim_abundances", "d_['MGYS00003619']   s_GO-slim_abundances", "d_['MGYS00003677']   s_GO-slim_abundances"]

fig = fetch_and_plot_metrics(
    entity=ENTITY,
    project=PROJECT,
    metrics_to_plot=METRICS,
    n_recent_runs=3,
    run_names=run_names
)

plt.show()

{'Data description': {'_latest_artifact_path': 'wandb-client-artifact://2191imov5u6du0m5axe6qg0e7ue4ynytcwvv50k78wnn8kydggzlfjen7jww80v9lbekvy9vfqw8ysntpjgh3eamkz9a69vvp32g5hhjzdcumfco6y5p9dm5potwpmi6:latest/Data description.table.json', '_type': 'table-file', 'artifact_path': 'wandb-client-artifact://r05vjicj9ohws9q7t3lavz5ea03bjprra9gfwoz4b5hbo3gdt3meac30183ivojup8ma2vfud6kfexigxehloyd9hcaz81xytgb6v5pk1yc00xlo7d424td5c03xg04f/Data description.table.json', 'ncols': 9, 'nrows': 116, 'path': 'media/table/Data description_0_5cf58a8e15e59f1e1c74.table.json', 'sha256': '5cf58a8e15e59f1e1c7458ef47523394b96b9c06135e7cc5abf9534a8c4131ae', 'size': 11369}, 'Outer fold.test/accuracy': 0.6153846153846154, 'Outer fold.test/average_precision(_macro)': 0.625987525987526, 'Outer fold.test/average_precision_micro': 0.625987525987526, 'Outer fold.test/average_precision_weighted': 0.625987525987526, 'Outer fold.test/balanced_accuracy': 0.6013513513513513, 'Outer fold.test/f1(_binary)': 0.675324675324675

ValueError: No valid data could be collected from any runs