# In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os


def plot_training_loss(file_path: str, output_path: str) -> None:
    """
    Plot one training-loss curve per run from a CSV file and save it as an image.

    The CSV is expected to provide the columns "Run", "step" and "value".
    Rows are grouped by "Run"; each group is drawn as its own line
    ("step" on the x-axis, "value" on the y-axis) and labelled with the
    run name in the legend.

    Args:
        file_path (str): Path to the input CSV file.
        output_path (str): Path to save the output plot image. Its parent
            directory is created when it does not exist yet; a bare file
            name (no directory part) is written to the current directory.

    Raises:
        KeyError: If the CSV lacks any of the required columns.
    """
    df = pd.read_csv(file_path)

    # Fail fast, before any figure is allocated, with a message that names
    # the offending columns. KeyError matches what the column lookups below
    # would raise anyway, so existing callers keep working.
    required_columns = ("Run", "step", "value")
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        raise KeyError(
            f"{file_path} is missing required column(s): {', '.join(missing)}"
        )

    # Ten distinguishable colors; runs beyond the tenth reuse them via modulo.
    colormap = plt.colormaps['tab10']
    colors = [colormap(i) for i in range(10)]

    # Use an explicit figure/axes pair so that exactly this figure is closed,
    # even when plotting or saving raises.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for i, (run_name, group) in enumerate(df.groupby("Run")):
            ax.plot(
                group["step"],
                group["value"],
                label=run_name,
                color=colors[i % len(colors)],
                linestyle='-',
                alpha=0.7,
            )

        ax.set_title("Training Loss Comparison")
        ax.set_xlabel("Steps")
        ax.set_ylabel("Training Loss")
        # Only draw a legend when at least one labelled curve exists;
        # otherwise matplotlib emits a "no artists with labels" warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(title="Run", loc="upper right", fontsize="small")
        ax.grid(alpha=0.5)

        # os.path.dirname returns "" for a bare file name, and
        # os.makedirs("") raises FileNotFoundError, so guard that case.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(output_path, dpi=300)
    finally:
        plt.close(fig)

# In [None]:
# Where the loss data is read from and where the rendered figure is written.
input_csv_path = "./training_loss/C10.csv"
output_image_path = "./training_loss/comparison_plot.png"

# Render the per-run loss curves and save them to disk.
plot_training_loss(file_path=input_csv_path, output_path=output_image_path)