In [None]:
import numpy as np
from matplotlib import pyplot as plt
import os
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize, LinearSegmentedColormap
import csv
import ast  # To safely parse the string list
from collections import defaultdict

In [None]:
def init_plotting():
    """Apply the notebook-wide matplotlib style through ``plt.rcParams``.

    Sets the default figure size, font sizes (title is 1.5x the base font),
    tick mark length/width for both axes, a frameless auto-placed legend and
    a thicker axes border. Spines and tick positions are not modified.
    """
    base_font = 12
    tick_label_font = 10

    style = {
        'figure.figsize': (7., 4.),
        'font.size': base_font,
        'axes.labelsize': base_font,
        'axes.titlesize': 1.5 * base_font,
        'legend.fontsize': base_font,
        'xtick.labelsize': tick_label_font,
        'ytick.labelsize': tick_label_font,
        'legend.frameon': False,
        'legend.loc': 'best',
        'axes.linewidth': 2.,
    }

    # Major and minor ticks share the same geometry on both axes.
    for axis in ('xtick', 'ytick'):
        for level in ('major', 'minor'):
            style[f'{axis}.{level}.size'] = 3
            style[f'{axis}.{level}.width'] = 1

    plt.rcParams.update(style)


init_plotting()

In [None]:
def smooth(y, weight=0.85):
    """Exponential moving average of ``y``.

    The running average is seeded with the first element, so the first
    output equals ``y[0]``; every later output is
    ``previous * weight + (1 - weight) * value``.

    Args:
        y: Indexable sequence of numbers (list, tuple or 1-D array).
        weight: Share of the previous smoothed value kept at each step.
            ``0`` reproduces the input; values close to ``1`` smooth heavily.

    Returns:
        A list of smoothed values with the same length as ``y``. An empty
        input yields an empty list instead of raising ``IndexError``.
    """
    # len() rather than truthiness so NumPy arrays are accepted as well.
    if len(y) == 0:
        return []

    smoothed = []
    last = y[0]
    for val in y:
        last = last * weight + (1 - weight) * val
        smoothed.append(last)
    return smoothed

def moving_std(data, window=5):
    """Rolling standard deviation with edge padding.

    The series is extended on each side by ``window // 2`` copies of its
    first / last value, then ``np.std`` (population standard deviation,
    ``ddof=0``) is taken over ``window`` consecutive padded values for each
    original position. For an odd ``window`` the slice is centred on the
    position.

    Args:
        data: Sequence of numbers. Any iterable is accepted (list, tuple,
            1-D array); it is copied into a list before padding.
        window: Number of padded values per standard-deviation estimate.

    Returns:
        A list with one standard deviation per input element. An empty
        input yields an empty list.
    """
    # Copy into a plain list: `list + tuple` raises TypeError and
    # `list + ndarray` is an arithmetic add, not a concatenation.
    values = list(data)
    if not values:
        return []

    half = window // 2
    padded = [values[0]] * half + values + [values[-1]] * half
    return [np.std(padded[i:i + window]) for i in range(len(values))]

def truncate_colormap(cmap, minval=0.0, maxval=0.95, n=256):
    """Return a new colormap covering only the ``[minval, maxval]`` slice of ``cmap``.

    ``cmap`` is sampled at ``n`` evenly spaced points between ``minval`` and
    ``maxval`` and the sampled colours are handed to
    ``LinearSegmentedColormap.from_list``. Useful for dropping the very
    bright or very dark ends of a colormap.
    """
    label = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return LinearSegmentedColormap.from_list(label, sampled_colors)

In [None]:
# Nested mapping filled below:
#   data[sample_size][learning_rate][num_classes] -> list of per-cycle mean
#   accuracies, ordered by ascending cycle number.
data = defaultdict(lambda: defaultdict(dict))
root = "results_true"
sample_sizes = [32, 64, 128, 256, 512, 1024]


def _load_per_cycle_accuracy(filepath):
    """Read one result CSV and return its per-cycle mean accuracies.

    The first line is treated as a header and skipped. Each remaining row is
    read as ``cycle, accuracy_list, last_class_trained, total_classes``.
    Only rows where ``last_class_trained == total_classes - 1`` are kept.
    For a kept row, the first ``total_classes`` entries of the accuracy list
    are summed and divided by ``total_classes``; rows that share a cycle
    number are then averaged.

    Args:
        filepath: Path of the CSV file to read.

    Returns:
        A list with one averaged accuracy per cycle, in ascending cycle
        order. Empty when the file has no rows passing the filter.
    """
    cycle_dict = defaultdict(list)

    with open(filepath, newline='') as csvfile:
        reader = csv.reader(csvfile)
        # Skip header; the default avoids StopIteration on an empty file.
        next(reader, None)

        for row in reader:
            # csv.reader yields [] for blank lines (e.g. a trailing newline).
            if not row:
                continue

            cycle = int(row[0])
            acc_list_str = row[1]
            last_class_trained = int(row[2])
            total_classes = int(row[3])

            # Only keep rows where last class trained == total_classes - 1
            if last_class_trained != total_classes - 1:
                continue

            # literal_eval parses the stringified list without executing code.
            acc_list = ast.literal_eval(acc_list_str)[:total_classes]
            cycle_dict[cycle].append(sum(acc_list) / total_classes)

    return [
        sum(cycle_dict[cycle]) / len(cycle_dict[cycle])
        for cycle in sorted(cycle_dict)
    ]


# Read and process CSV files
for sample in sample_sizes:
    folder_path = os.path.join(root, str(sample))
    for filename in os.listdir(folder_path):
        if not filename.endswith(".csv"):
            continue

        # File stem must split into exactly three "_"-separated parts:
        # <ignored>_<num_classes>_<learning_rate>
        parts = filename[:-4].split("_")
        if len(parts) != 3:
            continue

        _, num_classes, learning_rate = parts
        num_classes = int(num_classes)
        learning_rate = float(learning_rate)

        filepath = os.path.join(folder_path, filename)
        per_cycle_values = _load_per_cycle_accuracy(filepath)

        data[sample][learning_rate][num_classes] = per_cycle_values


In [None]:
# Colour scale shared by all figures: class counts 2..10 mapped onto a
# trimmed "plasma" colormap. Built once because it does not depend on the loop.
norm = mcolors.Normalize(vmin=2, vmax=10)
plasma = truncate_colormap(plt.colormaps["plasma"], minval=0.1, maxval=0.95)

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("results", exist_ok=True)

# One figure per (sample size, learning rate); one curve per class count.
for sample in data:
    for lr in data[sample]:
        fig, ax = plt.subplots()

        for nc in sorted(data[sample][lr]):
            raw_y = data[sample][lr][nc]
            # A CSV with no usable rows gives an empty series; nothing to draw.
            if not raw_y:
                continue

            x = np.arange(1, len(raw_y) + 1)
            y = np.array(smooth(raw_y, weight=0.65))
            std = np.array(moving_std(raw_y, window=7))

            color = plasma(norm(nc))

            # Plot smoothed line
            ax.plot(x, y, label=f"{nc} classes", color=color, linewidth=2)

            # Shaded band: smoothed curve +/- rolling std of the raw values
            ax.fill_between(x, y - std, y + std, color=color, alpha=0.5)

        ax.set_title(f"Total samples per class: {sample}, Learning Rate: {lr}")
        ax.set_xlabel("Cycle")
        ax.set_ylabel("Avg Class Accuracy")
        # Skip the legend when no curve was drawn (avoids an empty-legend warning).
        if ax.lines:
            ax.legend(title="Num Classes")
        ax.grid(True, linestyle="--")
        fig.tight_layout()
        fig.savefig(f"results/plot_{sample}_{lr}.png")
        # Close explicitly so figures do not accumulate in memory across the loop.
        plt.close(fig)