In [1]:
import os
import torch
import matplotlib.pyplot as plt
from mlxtend.plotting import plot_confusion_matrix
from tabulate import tabulate

In [2]:
def print_metrics(history, mode='train', epoch=-1):
    """Print a grid table of the metrics recorded for one data split at one epoch.

    Args:
        history: Mapping from split name (e.g. 'train', 'val', 'test') to a dict of
            per-epoch lists. Each split dict must provide the keys 'loss',
            'accuracy', 'precision', 'recall', 'f1', 'specificity',
            'confusion_matrix' and 'time'.
        mode: Split to report. It is lower-cased for the lookup and capitalized
            for the table header.
        epoch: Index into every per-epoch list; the default -1 selects the last
            recorded epoch.

    Returns:
        None. If ``mode`` is not a key of ``history`` an error message is printed
        instead of the table.
    """
    mode_history = history.get(mode.lower())
    if mode_history is None:
        print(f"Error: Mode '{mode}' not found in history.")
        return

    def at_epoch(key):
        return mode_history[key][epoch]

    loss = at_epoch('loss')
    accuracy = at_epoch('accuracy')
    precision = at_epoch('precision')
    recall = at_epoch('recall')
    f1 = at_epoch('f1')
    specificity = at_epoch('specificity')
    cm = at_epoch('confusion_matrix')
    elapsed = at_epoch('time')

    # int() turns 0-d tensor / numpy scalar cells into plain ints, so the table
    # shows "5" rather than e.g. "tensor(5)".
    # NOTE(review): the plotting cell labels matrix rows as "True label". Under that
    # convention with TP = cm[0][0], false positives would be cm[1][0] and false
    # negatives cm[0][1], which is the opposite of the FP/FN labels below. Confirm
    # the matrix layout against the code that builds it.
    table = [
        ["Loss", f"{loss:.8f}"],
        ["Accuracy", f"{accuracy:.4f}"],
        ["Precision", f"{precision:.4f}"],
        ["Recall", f"{recall:.4f}"],
        ["F1 Score", f"{f1:.4f}"],
        ["Specificity", f"{specificity:.4f}"],
        ["TP", int(cm[0][0])],
        ["TN", int(cm[1][1])],
        ["FP", int(cm[0][1])],
        ["FN", int(cm[1][0])],
        ["Time", f"{elapsed:.2f}s"],
    ]

    print(tabulate(table, headers=[mode.capitalize(), 'Value'], tablefmt="grid"))

def plot_single_metric(ax, metric_name, title, train_metric_1, val_metric_1, train_metric_2=None, val_metric_2=None, color='blue', label_1='Model 1', label_2='Model 2', color_2='red'):
    """Draw training (solid) and validation (dashed) curves of one metric on ``ax``.

    Args:
        ax: Matplotlib Axes to draw on.
        metric_name: Key of the metric inside each history dict; its capitalized
            form is used as the y-axis label.
        title: Human-readable metric name used in legend labels and the axes title.
        train_metric_1, val_metric_1: Training / validation history dicts of the
            first model (each maps metric name -> per-epoch values).
        train_metric_2, val_metric_2: Optional history dicts of a second model;
            its curves are drawn only when both are provided.
        color: Line colour of the first model.
        label_1, label_2: Legend prefixes of the first and second model.
        color_2: Line colour of the second model. It defaults to 'red', the value
            that used to be hard-coded.
    """
    def draw_pair(train_history, val_history, label, line_color):
        # Solid line for training, dashed line for validation, same colour per model.
        ax.plot(train_history[metric_name], label=label + ' Training ' + title, color=line_color, linewidth=2)
        ax.plot(val_history[metric_name], label=label + ' Validation ' + title, linestyle='--', color=line_color, linewidth=2)

    draw_pair(train_metric_1, val_metric_1, label_1, color)

    if train_metric_2 is not None and val_metric_2 is not None:
        draw_pair(train_metric_2, val_metric_2, label_2, color_2)

    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric_name.capitalize())
    ax.set_title('Training and Validation ' + title)
    ax.legend()
    ax.grid(True, linestyle=':', linewidth=0.5)

def plot_all_metrics(history_1, name_1, history_2=None, name_2=None, save_path=None):
    """Plot every per-epoch metric of a model, optionally overlaid with a second model.

    One subplot row is drawn for each metric key found in ``history_1['train']``.
    The 'confusion_matrix' key is skipped because it is not a scalar-per-epoch series.

    Args:
        history_1: History of the first model, with 'train' and 'val' entries that
            map metric name -> per-epoch values.
        name_1: Display name of the first model (legend prefix).
        history_2: Optional history of a second model with the same layout.
        name_2: Display name of the second model. The comparison is drawn only
            when both ``history_2`` and ``name_2`` are given.
        save_path: If truthy, the figure is also saved to this path.

    Raises:
        ValueError: If ``history_1['train']`` contains no plottable metric.
    """
    metrics = [name for name in history_1['train'] if name != 'confusion_matrix']
    if not metrics:
        raise ValueError("history_1['train'] contains no plottable metrics")

    compare = history_2 is not None and name_2 is not None
    color = 'blue'

    # squeeze=False keeps a 2-D array of Axes even when only one metric exists,
    # so indexing below works for any number of rows.
    fig, axs = plt.subplots(len(metrics), 1, figsize=(10, 5 * len(metrics)), squeeze=False)

    for ax, metric_name in zip(axs[:, 0], metrics):
        title = metric_name.capitalize()
        if compare:
            plot_single_metric(ax, metric_name, title, history_1['train'], history_1['val'], history_2['train'], history_2['val'], color=color, label_1=name_1, label_2=name_2)
            # Keep the metric in the title; otherwise every comparison subplot looks identical.
            ax.set_title(f"{name_1} vs {name_2}: {title}", fontsize=12, fontweight='bold')
        else:
            plot_single_metric(ax, metric_name, title, history_1['train'], history_1['val'], color=color, label_1=name_1)

    fig.suptitle('Comparison of Models', fontsize=16)
    # Leave room at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    if save_path:
        fig.savefig(save_path)
    plt.show()

In [None]:
MODELS_DIR_PATH = "../saved_models"
MODEL1_NAME = "m5_frames_40_clip_04_2024-05-11_17-54-59.pt"
MODEL2_NAME = "m2_2024-05-11_07-42-40.pt"

# map_location="cpu" lets checkpoints written during GPU training be opened on a
# CPU-only machine. It also keeps stored tensors on the CPU, where `.numpy()` works.
# NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
checkpoint1 = torch.load(os.path.join(MODELS_DIR_PATH, MODEL1_NAME), map_location="cpu")
checkpoint2 = torch.load(os.path.join(MODELS_DIR_PATH, MODEL2_NAME), map_location="cpu")

history_1 = checkpoint1['history']
history_2 = checkpoint2['history']
hyperparameters_1 = checkpoint1['hyperparameters']
hyperparameters_2 = checkpoint2['hyperparameters']

print(f"Hyperparameters for ({MODEL1_NAME}):")
print(tabulate(hyperparameters_1.items(), headers=["Parameter", "Value"]))
print()
print(f"Hyperparameters for ({MODEL2_NAME}):")
print(tabulate(hyperparameters_2.items(), headers=["Parameter", "Value"]))

In [None]:
# Overlay the learning curves of both checkpoints; the figure is not saved to disk.
plot_all_metrics(
    history_1,
    name_1="Final Model",
    history_2=history_2,
    name_2=MODEL2_NAME,
)

In [None]:
EPOCH_1 = -1
EPOCH_2 = -1

# Class names in the index order of the confusion matrix.
class_labels = ['ls_p', 'ls_a']

# (display name, checkpoint file name, checkpoint, epoch index) per compared model.
models_to_evaluate = [
    ("Final Model", MODEL1_NAME, checkpoint1, EPOCH_1),
    (MODEL2_NAME, MODEL2_NAME, checkpoint2, EPOCH_2),
]

# Print both metric tables first, then draw both confusion matrices.
for index, (_, file_name, checkpoint, epoch) in enumerate(models_to_evaluate):
    separator = "\n" if index else ""
    print(f"{separator}Metrics for ({file_name}):")
    print_metrics(checkpoint['history'], 'test', epoch)

for display_name, _, checkpoint, epoch in models_to_evaluate:
    conf_mat = checkpoint['history']['test']['confusion_matrix'][epoch]
    # .cpu() is needed if the tensor was restored onto a GPU; .numpy() rejects CUDA tensors.
    fig, ax = plot_confusion_matrix(conf_mat=conf_mat.cpu().numpy(), class_names=class_labels)
    ax.set_title(f'Confusion Matrix - {display_name}')
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')
    plt.show()