In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os

# Root folder containing the "Normal" and "Attack" result trees.
# Point this at another location if the results live elsewhere.
base_dir = "."
normal_dir, attack_dir = (
    os.path.join(base_dir, sub_dir) for sub_dir in ("Normal", "Attack")
)

# Experiment grid. These tags are spliced into CSV file names and
# sub-directory names by the plotting loops, so they must match the
# on-disk naming exactly.
datasets = [
    "mnist",
    "fashion_mnist",
    "femnist",
]
k_values = [
    "k3",
    "k5",
]
distribs = [
    "iid",
    "noniid",
]
# The digit after "alpha" is the α value in tenths: α = 0.1, 0.3, 0.5.
alphas = [
    "alpha1",
    "alpha3",
    "alpha5",
]
# Tags identifying the malicious-client setting in attack file names.
malicious = [
    "client1",
    "client2",
]
# Legend names for the attack types, looked up by position.
scenarios = [
    "Scenario 1",
    "Scenario 2",
]

def read_accuracy(file_path, column_name, column_idx, *, n_rounds=10):
    """Load one accuracy column from a CSV file, best-effort.

    Any problem with the file is reported with a printed warning and
    answered with an all-zero placeholder, so that a single bad result
    file does not abort a whole batch of plots.

    Args:
        file_path: Path of the CSV file to read.
        column_name: Header of the column whose values are returned.
        column_idx: Accepted for compatibility with existing callers; it is
            currently not used to locate the column.
        n_rounds: Length of the zero-filled placeholder returned on failure.

    Returns:
        A NumPy array holding the column's values, or ``np.zeros(n_rounds)``
        when the file is missing, empty, unparsable, or lacks the column.
    """
    placeholder = np.zeros(n_rounds)

    # Only the read itself is guarded here, so each handler matches the one
    # statement that can raise it.
    try:
        df = pd.read_csv(file_path)
    except FileNotFoundError:
        print(f"Warning: File not found - {file_path}")
        return placeholder
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        # An empty or malformed file is treated like a missing one instead
        # of crashing the run.
        print(f"Warning: Could not parse {file_path} ({err})")
        return placeholder

    try:
        return df[column_name].values
    except KeyError:
        print(f"Warning: '{column_name}' column not found in {file_path}")
        return placeholder

def _alpha_label(alpha):
    """Turn an ``alphaN`` tag into its α value as text (``"alpha3"`` -> ``"0.3"``)."""
    tenths = int(alpha[len("alpha"):])
    return f"{tenths / 10:.1f}"


def _plot_curve(accuracies, label, dist):
    """Draw one accuracy curve; solid for "iid", dashed for anything else."""
    # Size the x-axis from the data so a CSV with a different number of
    # rounds does not make plt.plot fail on mismatched lengths.
    rounds = range(1, len(accuracies) + 1)
    plt.plot(rounds, accuracies, label=label,
             linestyle="-" if dist == "iid" else "--", linewidth=2)


def _finish_figure(title, out_file):
    """Apply the shared title/axis/legend styling, save to ``out_file`` and close."""
    plt.title(title, fontsize=30)
    plt.xlabel("Round", fontsize=25)
    plt.ylabel("Accuracy (%)", fontsize=25)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    # The legend is placed outside the axes, to the right of the plot area.
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=20)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_file, dpi=300, bbox_inches="tight")
    plt.close()


# For every dataset, write two PNG files into the current working directory:
# one with the attack-free ("Normal") curves and one with the attack curves.
for dataset in datasets:
    # --- Normal runs: one curve per (k, distribution, alpha) combination ---
    plt.figure(figsize=(14, 8))
    for k in k_values:
        for dist in distribs:
            for alpha in alphas:
                file_name = f"{dataset}-{dist}-{k}-{alpha}.csv"
                file_path = os.path.join(normal_dir, dataset, k, file_name)
                accuracies = read_accuracy(file_path, "Round_Accuracy", 0)
                # Derive the label from the tag itself: alpha1/alpha3/alpha5
                # stand for 0.1/0.3/0.5, not for consecutive 0.1 steps.
                alpha_val = _alpha_label(alpha)
                _plot_curve(accuracies, f"{k} {dist} α={alpha_val}", dist)
    _finish_figure(f"{dataset.capitalize()} (Normal)",
                   f"{dataset}_normal_convergence.png")

    # --- Attack runs: one curve per (k, distribution, attack, client tag) ---
    plt.figure(figsize=(14, 8))
    for k in k_values:
        for dist in distribs:
            for scen_idx, atk in enumerate(["backdoor", "feature"]):
                for i, mal in enumerate(malicious):
                    file_name = f"{dataset}-{dist}-{mal}-{atk}.csv"
                    file_path = os.path.join(attack_dir, dataset, k, file_name)
                    accuracies = read_accuracy(file_path, "Round_Test_Accuracy", 1)
                    # The percentage in the label comes from the position of
                    # the client tag in `malicious` (first -> 10%, second -> 20%).
                    label = f"{k} {dist} {scenarios[scen_idx]} {10*(i+1)}%"
                    _plot_curve(accuracies, label, dist)
    _finish_figure(f"{dataset.capitalize()} (Attack)",
                   f"{dataset}_attack_convergence.png")

print("Figures generated.")

Figures generated.
