In [None]:
import pandas as pd
import ast
import matplotlib.pyplot as plt

In [None]:
# Load files: each results CSV is paired with the optimizer that produced it,
# so the path list and the optimizer list below cannot drift out of alignment
# (the plotting code zips them together and would silently mislabel curves).
results_files = [
    ("data/adam_results.csv", "Adam"),
    ("data/adam_results_8.csv", "Adam"),
    ("data/sgd_results.csv", "SGD"),
    # NOTE(review): unlike the other "_8" files this name has no underscore
    # before the 8 — confirm it matches the file on disk.
    ("data/sgd_results8.csv", "SGD"),
    ("data/ivon_results.csv", "IVON"),
    ("data/ivon_results_8.csv", "IVON"),
]
file_paths = [path for path, _ in results_files]
optimizers = [optimizer for _, optimizer in results_files]

# Metric columns are stored in the CSVs as Python-literal strings and are parsed
# back with ast.literal_eval (presumably per-epoch lists — the plotting code
# takes len() of each cell).
metric_columns = ['train_loss', 'val_loss', 'train_acc', 'val_acc']


def load_results(path):
    """Read one results CSV and parse its metric columns.

    Args:
        path: Path to a CSV containing the columns in ``metric_columns``.

    Returns:
        The DataFrame with each metric cell replaced by the Python object
        obtained from ``ast.literal_eval`` of its text.

    Raises:
        ValueError / SyntaxError: If a metric cell is not a valid Python literal
            (propagated from ``ast.literal_eval``).
    """
    data = pd.read_csv(path)
    for column in metric_columns:
        # literal_eval only accepts literals, so unlike eval() it cannot run code.
        data[column] = data[column].apply(ast.literal_eval)
    return data


# Load data
datasets = [load_results(path) for path in file_paths]

def _collect_curves(datasets, optimizers, model_type, ensemble_model_type):
    """Gather the runs to plot as ``(legend_label, row)`` pairs.

    For each dataset (paired with its optimizer name), rows whose
    ``model_type`` equals *model_type* come first, followed by rows equal to
    *ensemble_model_type*; this list order is the order curves are drawn in.
    """
    curves = []
    for data, optimizer in zip(datasets, optimizers):
        for kind in (model_type, ensemble_model_type):
            subset = data[data['model_type'] == kind]
            curves.extend(
                (f"{optimizer} - {kind}", row) for _, row in subset.iterrows()
            )
    return curves


def compare_optimizers_subplots(
    datasets,
    optimizers,
    model_type,
    ensemble_model_type,
    title="Comparison of Optimizers for single network and batchensembles of 8",
):
    """Plot per-epoch loss/accuracy curves of two model types across optimizers.

    Draws a 2x2 grid (train loss, validation loss, train accuracy, validation
    accuracy) and shows it. Every row of every dataset whose ``model_type``
    equals *model_type* or *ensemble_model_type* becomes one curve labelled
    ``"<optimizer> - <model_type>"``, plotted against 1-based epoch numbers.

    Args:
        datasets: DataFrames with a ``model_type`` column and the metric
            columns ``train_loss``, ``val_loss``, ``train_acc`` and ``val_acc``,
            where each metric cell is a sequence of per-epoch values.
        optimizers: Optimizer names, one per entry of *datasets*, same order.
        model_type: ``model_type`` value of the single-network runs.
        ensemble_model_type: ``model_type`` value of the ensemble runs.
        title: Figure super-title. The default is the title this notebook has
            always used; pass another when comparing a different setup.

    Raises:
        ValueError: If *datasets* and *optimizers* differ in length, since
            pairing them with ``zip`` would otherwise silently drop entries.
    """
    # Materialise so any iterable works and lengths can be compared.
    datasets = list(datasets)
    optimizers = list(optimizers)
    if len(datasets) != len(optimizers):
        raise ValueError(
            f"got {len(datasets)} datasets but {len(optimizers)} optimizer "
            "names; they must be paired one-to-one"
        )

    metrics = ['train_loss', 'val_loss', 'train_acc', 'val_acc']
    titles = ["Train Loss", "Validation Loss", "Train Accuracy", "Validation Accuracy"]
    # NOTE(review): the accuracy label assumes values are stored as
    # percentages — confirm against the code that wrote the CSVs.
    ylabels = ["Loss", "Loss", "Accuracy (%)", "Accuracy (%)"]

    # Select the runs once instead of re-filtering every dataset per subplot.
    curves = _collect_curves(datasets, optimizers, model_type, ensemble_model_type)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    for ax, metric, subplot_title, ylabel in zip(axes.flat, metrics, titles, ylabels):
        for label, row in curves:
            values = row[metric]
            ax.plot(range(1, len(values) + 1), values, label=label)

        ax.set_title(subplot_title, fontsize=12)
        ax.set_xlabel("Epochs", fontsize=10)
        ax.set_ylabel(ylabel, fontsize=10)
        # Skip the legend when nothing was plotted (e.g. a misspelled model
        # type); matplotlib warns about a legend with no labelled artists.
        if curves:
            ax.legend(fontsize=8, loc='best')

    fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    plt.show()

# Compare the "complex" single network against its batch ensemble for every
# optimizer. To compare the "simple" network instead, pass
# model_type="simple" and ensemble_model_type="batchensemble".
compare_optimizers_subplots(
    datasets,
    optimizers,
    model_type="complex",
    ensemble_model_type="batchensemble_complex",
)