## dot plots

In [11]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

In [12]:
def collect_results(base_dir):
    """
    Collect MCD, LSD, and FAD metrics from evaluation CSV files.
    """
    base_dir = Path(base_dir)
    metrics = ["MCD", "LSD", "FAD"]

    results = {metric: {file_type: [] for file_type in ["base_model", "per_chan_4_bit", "per_chan_8_bit", 
                                                        "per_group_4_bit", "per_group_8_bit", 
                                                        "per_tensor_4_bit", "per_tensor_8_bit"]}
               for metric in metrics}

    # iterate over each channel folder
    for channel_folder in sorted(base_dir.iterdir(), key=lambda x: int(x.name) if x.name.isdigit() else float('inf')):
        if not channel_folder.is_dir():
            continue

        # load the CSV file for the current sample
        csv_file = channel_folder / "evaluation_results.csv"
        if not csv_file.exists():
            print(f"No CSV file found in {channel_folder}")
            continue

        try:
            df = pd.read_csv(csv_file)
            
            # extract metric values for each file type
            for _, row in df.iterrows():
                file_name = row["file"]
                file_type = file_name.split("_channel")[0]  # Get the prefix before "_channel"
                
                for metric in metrics:
                    results[metric][file_type].append(row[metric])
        except Exception as e:
            print(f"Failed to process {csv_file}: {e}")
            continue

    return results

In [95]:
def plot_metrics(results, base_dir):
    """
    Plot MCD, LSD, and FAD dotplots across channels.
    """
    metrics = list(results.keys())
    # file_type_colors = {
    #     "base_model": "silver",
    #     "per_chan_4_bit": "tomato",
    #     "per_chan_8_bit": "gold",
    #     "per_group_4_bit": "yellowgreen",
    #     "per_group_8_bit": "skyblue",
    #     "per_tensor_4_bit": "thistle",
    #     "per_tensor_8_bit": "pink",
    # }
    file_type_colors = {
        "base_model": "black",
        "per_chan_4_bit": "red",
        "per_chan_8_bit": "orange",
        "per_group_4_bit": "green",
        "per_group_8_bit": "blue",
        "per_tensor_4_bit": "purple",
        "per_tensor_8_bit": "pink",
    }

    for metric in metrics:
        plt.figure(figsize=(14, 6))
        plt.title(f"{metric} Across Samples", fontsize=24)
        plt.xlabel("Sample", fontsize=21)
        plt.ylabel(metric, fontsize=21)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.xticks(range(1, 17), labels=[str(i) for i in range(1, 17)])

        for file_type, color in file_type_colors.items():
            for channel_idx, channel_data in enumerate(results[metric][file_type], start=1):
                # Ensure channel_data is a list
                if not isinstance(channel_data, list):
                    channel_data = [channel_data]

                plt.scatter([channel_idx] * len(channel_data), channel_data, label=file_type if channel_idx == 1 else None,
                            color=color, alpha=0.7, s=150)

        # add legend outside the plot to the right
        plt.legend(loc="upper left", bbox_to_anchor=(1, 1), fontsize=18)
        plt.grid(True, linestyle="--", alpha=0.6)
        plt.tight_layout()

        # save plot as a SVG file
        output_path = Path(base_dir) / f"{metric}_dotplot.svg"
        plt.savefig(output_path)
        plt.close()
        print(f"Saved {metric} dotplot to {output_path}")

In [96]:
base_directory = "./examples/wavs"
results = collect_results(base_directory)
plot_metrics(results, base_directory)

Saved MCD dotplot to examples/wavs/MCD_dotplot.svg
Saved LSD dotplot to examples/wavs/LSD_dotplot.svg
Saved FAD dotplot to examples/wavs/FAD_dotplot.svg


## box plots

In [61]:
def collect_metrics_for_boxplots(base_dir):
    """
    Collect MCD, LSD, and FAD metrics grouped by quantization methods.
    """
    base_dir = Path(base_dir)
    metrics = ["MCD", "LSD", "FAD"]
    metric_data = {metric: {} for metric in metrics}

    for channel_folder in sorted(base_dir.iterdir(), key=lambda x: int(x.name) if x.name.isdigit() else float('inf')):
        if not channel_folder.is_dir():
            continue

        csv_file = channel_folder / "evaluation_results.csv"
        if not csv_file.exists():
            print(f"No evaluation_results.csv found in {channel_folder}")
            continue

        try:
            df = pd.read_csv(csv_file)

            for metric in metrics:
                if metric not in df.columns:
                    print(f"{metric} not found in {csv_file}")
                    continue

                for _, row in df.iterrows():
                    file_type = row["file"].split("_channel")[0]
                    if file_type not in metric_data[metric]:
                        metric_data[metric][file_type] = []
                    metric_data[metric][file_type].append(row[metric])

        except Exception as e:
            print(f"Failed to process {csv_file}: {e}")
            continue

    return metric_data

In [105]:
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from pathlib import Path

def plot_boxplots(metric_data, output_dir):
    """
    Create box plots for MCD, LSD, and FAD metrics grouped by quantization methods.
    """
    output_dir = Path(output_dir)
    metrics = ["MCD", "LSD", "FAD"]

    # Define custom outline and fill colors
    outline_colors = ["black", "red", "orange", "green", "blue", "purple", "pink"]
    fill_colors = [(*to_rgba(color)[:3], 0.3) for color in outline_colors]  # Lighter fill with 30% opacity

    for metric in metrics:
        plt.figure(figsize=(12, 6))
        plt.title(f"{metric} Across Quantization Levels", fontsize=24)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.xlabel("Quantization Level", fontsize=21)

        # Prepare data for boxplot
        data = [metric_data[metric][file_type] for file_type in sorted(metric_data[metric].keys())]
        labels = sorted(metric_data[metric].keys())

        # Create the boxplot with custom styles
        boxplot = plt.boxplot(
            data,
            labels=labels,
            patch_artist=True,
            whiskerprops=dict(linewidth=2, linestyle="--"),
            capprops=dict(linewidth=2),  # Bold caps
            medianprops=dict(linewidth=2, color="darkred"),
            flierprops=dict(marker="o", markersize=6, linestyle="none", markeredgecolor="black", markerfacecolor="grey")
        )

        # Apply custom colors for each box
        for patch, outline_color, fill_color in zip(boxplot["boxes"], outline_colors, fill_colors):
            patch.set_edgecolor(outline_color)  # Set the outline color
            patch.set_linewidth(2)             # Bold the outline
            patch.set_facecolor(fill_color)    # Set the lighter fill color

        plt.xticks(rotation=45, fontsize=18)
        plt.grid(axis="y", linestyle="--", alpha=0.7)
        plt.tight_layout()

        output_path = output_dir / f"{metric}_boxplot.svg"
        plt.savefig(output_path)
        plt.close()
        print(f"Saved {metric} box plot to {output_path}")

In [106]:
base_directory = "./examples/wavs"
metrics_data = collect_metrics_for_boxplots(base_directory)
plot_boxplots(metrics_data, base_directory)

Saved MCD box plot to examples/wavs/MCD_boxplot.svg
Saved LSD box plot to examples/wavs/LSD_boxplot.svg
Saved FAD box plot to examples/wavs/FAD_boxplot.svg


  boxplot = plt.boxplot(
  boxplot = plt.boxplot(
  boxplot = plt.boxplot(


## table

In [45]:
def save_combined_metrics_to_csv(base_dir):
    """
    Combine MCD, LSD, and FAD metrics from all evaluation_results.csv files into three large CSV files,
    with channels as rows and file types as columns.

    Parameters:
    - base_dir (str): Base directory containing numbered subfolders with evaluation_results.csv files.
    """
    base_dir = Path(base_dir)
    metrics = ["MCD", "LSD", "FAD"]
    
    # Initialize dictionaries to store combined results for each metric
    combined_results = {metric: {} for metric in metrics}

    # Iterate over each numbered folder
    for channel_folder in sorted(base_dir.iterdir(), key=lambda x: int(x.name) if x.name.isdigit() else float('inf')):
        if not channel_folder.is_dir():
            continue

        # Load the evaluation_results.csv file
        csv_file = channel_folder / "evaluation_results.csv"
        if not csv_file.exists():
            print(f"No evaluation_results.csv found in {channel_folder}")
            continue

        try:
            df = pd.read_csv(csv_file)
            
            # Add metrics to combined_results
            for metric in metrics:
                if metric not in df.columns:
                    print(f"{metric} not found in {csv_file}")
                    continue
                
                # Store results for each file type in this channel
                channel_name = f"sample {channel_folder.name}"
                if channel_name not in combined_results[metric]:
                    combined_results[metric][channel_name] = {}
                
                for _, row in df.iterrows():
                    file_type = row["file"].split("_channel")[0]
                    combined_results[metric][channel_name][file_type] = row[metric]

        except Exception as e:
            print(f"Failed to process {csv_file}: {e}")
            continue

    # Save each metric to a separate CSV file
    for metric in metrics:
        metric_data = {}

        # Collect all file types across channels
        file_types = set()
        for channel_data in combined_results[metric].values():
            file_types.update(channel_data.keys())
        file_types = sorted(file_types)  # Ensure consistent ordering

        # Create columns for each file type and rows for each channel
        for channel_name in sorted(combined_results[metric].keys(), key=lambda x: int(x.split()[-1])):
            row = {file_type: combined_results[metric][channel_name].get(file_type, None) for file_type in file_types}
            metric_data[channel_name] = row

        # Convert to DataFrame and transpose
        df_metric = pd.DataFrame(metric_data).T
        df_metric.index.name = ""
        df_metric.reset_index(inplace=True)

        # Save to CSV
        output_csv = base_dir / f"{metric}_results.csv"
        df_metric.to_csv(output_csv, index=False)
        print(f"Saved {metric} results to {output_csv}")

In [46]:
save_combined_metrics_to_csv(base_directory)

Saved MCD results to examples/wavs/MCD_results.csv
Saved LSD results to examples/wavs/LSD_results.csv
Saved FAD results to examples/wavs/FAD_results.csv
