In [11]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

In [12]:
def collect_results(base_dir, file_types=None):
    """
    Collect MCD, LSD, and FAD metrics from per-channel evaluation CSV files.

    Every immediate subfolder of ``base_dir`` is visited in order: numeric
    names first, ascending ("2" before "10"), then non-numeric names. A
    subfolder contributes data only if it contains ``evaluation_results.csv``
    with a ``file`` column plus one column per metric. The file type of a row
    is the part of its ``file`` value before the first ``"_channel"``.

    Column checks happen before anything is appended. A malformed CSV is
    therefore skipped as a whole and can no longer leave some file types with
    more entries than others. Rows whose file type is not in ``file_types``
    are skipped individually (with a message) instead of aborting the rest of
    the CSV.

    Parameters:
    - base_dir (str or Path): Path to the directory containing numbered channel subfolders.
    - file_types (iterable of str, optional): File-type prefixes to collect.
      Defaults to the base model plus the six quantization variants.

    Returns:
    - results (dict): ``{metric: {file_type: [value, ...]}}``. Each list holds
      one value per matching CSV row, in channel-folder order. Skipped
      channels leave no entry, so a list position is not necessarily the
      channel number.
    """
    base_dir = Path(base_dir)
    metrics = ["MCD", "LSD", "FAD"]
    if file_types is None:
        file_types = ["base_model", "per_chan_4_bit", "per_chan_8_bit",
                      "per_group_4_bit", "per_group_8_bit",
                      "per_tensor_4_bit", "per_tensor_8_bit"]
    file_types = list(file_types)
    known_types = set(file_types)

    results = {metric: {file_type: [] for file_type in file_types} for metric in metrics}
    required_columns = ["file", *metrics]

    def channel_sort_key(path):
        # Numeric folder names sort numerically; anything else is pushed to the end.
        return int(path.name) if path.name.isdigit() else float("inf")

    for channel_folder in sorted(base_dir.iterdir(), key=channel_sort_key):
        if not channel_folder.is_dir():
            continue  # e.g. previously saved *_dotplot.png files

        csv_file = channel_folder / "evaluation_results.csv"
        if not csv_file.exists():
            print(f"No CSV file found in {channel_folder}")
            continue

        try:
            df = pd.read_csv(csv_file)
        except Exception as e:  # best-effort: one unreadable CSV should not stop the sweep
            print(f"Failed to process {csv_file}: {e}")
            continue

        # Validate before appending anything so a bad CSV cannot half-populate results.
        missing = [col for col in required_columns if col not in df.columns]
        if missing:
            print(f"Failed to process {csv_file}: missing columns {missing}")
            continue

        row_types = df["file"].astype(str).str.split("_channel").str[0]
        for file_type, values in zip(row_types, df[metrics].itertuples(index=False, name=None)):
            if file_type not in known_types:
                print(f"Skipping unknown file type {file_type!r} in {csv_file}")
                continue
            for metric, value in zip(metrics, values):
                results[metric][file_type].append(value)

    return results

In [32]:
def plot_metrics(results, base_dir):
    """
    Plot one dotplot per metric across channels and save each one as a PNG.

    A point's x position is its 1-based position in the per-file-type list.
    This equals the channel number only if every channel folder contributed a
    row for that file type. The number of x ticks is derived from the longest
    list instead of being fixed at 16.

    Parameters:
    - results (dict): ``{metric: {file_type: [value, ...]}}``, as returned by
      ``collect_results``. A list entry may itself be a list, in which case all
      of its values are drawn at the same x position. File types missing from
      ``results`` are skipped.
    - base_dir (str or Path): Directory in which ``<metric>_dotplot.png`` is written.
    """
    file_type_colors = {
        "base_model": "black",
        "per_chan_4_bit": "red",
        "per_chan_8_bit": "orange",
        "per_group_4_bit": "green",
        "per_group_8_bit": "blue",
        "per_tensor_4_bit": "purple",
        "per_tensor_8_bit": "pink",
    }

    for metric, per_type in results.items():
        fig, ax = plt.subplots(figsize=(14, 6))
        ax.set_title(f"{metric} Across Samples", fontsize=16)
        ax.set_xlabel("Sample", fontsize=14)
        ax.set_ylabel(metric, fontsize=14)

        # One tick per sample position actually present in the data.
        n_samples = max((len(values) for values in per_type.values()), default=0)
        ticks = list(range(1, n_samples + 1))
        ax.set_xticks(ticks)
        ax.set_xticklabels([str(i) for i in ticks])

        # One scatter call per file type, so each type gets exactly one legend entry.
        for file_type, color in file_type_colors.items():
            xs, ys = [], []
            for sample_idx, sample_data in enumerate(per_type.get(file_type, []), start=1):
                values = sample_data if isinstance(sample_data, list) else [sample_data]
                xs.extend([sample_idx] * len(values))
                ys.extend(values)
            if xs:
                ax.scatter(xs, ys, label=file_type, color=color, alpha=0.7, s=100)

        # Legend outside the axes on the right; skipped when nothing was drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc="upper left", bbox_to_anchor=(1, 1), fontsize=12)
        ax.grid(True, linestyle="--", alpha=0.6)
        fig.tight_layout()

        output_path = Path(base_dir) / f"{metric}_dotplot.png"
        fig.savefig(output_path)
        plt.close(fig)  # free the figure; many figures are created in this loop
        print(f"Saved {metric} dotplot to {output_path}")

In [33]:
# Root folder: each numbered subfolder ("1", "2", ...) is expected to hold an
# evaluation_results.csv. The dotplots are written back into this same folder.
base_directory = "./examples/wavs"

# Gather every channel's metrics, then render one <metric>_dotplot.png per metric.
results = collect_results(base_directory)
plot_metrics(results, base_directory)

Saved MCD dotplot to examples/wavs/MCD_dotplot.png
Saved LSD dotplot to examples/wavs/LSD_dotplot.png
Saved FAD dotplot to examples/wavs/FAD_dotplot.png
