In [1]:
import argparse
import random
import os
import torch
import re
from typing import Dict, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from ptflops import get_model_complexity_info  # Install ptflops if not already installed


from tools import *
from models import create_net

In [2]:
base_folder = "../ckpt/"
save_base_folder = "../notebooks/ckpts/"

In [3]:
def parse_folder_name(folder_name: str) -> argparse.Namespace:
    """
    Parse the folder name to extract arguments as a dictionary.
    """
    # Define regex to match args attributes in folder name
    pattern = (
        r"(?P<dataset>cifar\d+)-"
        r"(?P<arch>[a-zA-Z0-9_]+)-"
        r"(?P<attention_type>[a-zA-Z0-9_]+)-"
        r"(?:param(?P<attention_param>[^-]+)-)?"
        r"(?:paramTwo(?P<param_two>[^-]+)-)?"
        r"mark_Trial(?P<id>\d+)"
    )
    match = re.match(pattern, folder_name)

    if match:
        # Convert the matched dictionary into a Namespace
        args_dict = match.groupdict()
    else:
        # Fallback pattern for cases without attention_type
        print(f"Folder name '{folder_name}' does not fully match expected format. Defaulting attention_type to 'none'.")
        partial_pattern = (
            r"(?P<dataset>cifar\d+)-"
            r"(?P<arch>[a-zA-Z0-9_]+)-"
            r"mark_Trial(?P<id>\d+)"
        )
        partial_match = re.match(partial_pattern, folder_name)

        if partial_match:
            args_dict = partial_match.groupdict()
            # Set default values for missing fields
            args_dict["attention_type"] = "none"
            args_dict["attention_param"] = None
            args_dict["param_two"] = None
        else:
            raise ValueError(f"Folder name '{folder_name}' does not match even the fallback format.")

    # Convert numeric fields to the appropriate type
    if args_dict["attention_param"] is not None:
        args_dict["attention_param"] = float(args_dict["attention_param"])
    if args_dict["param_two"] is not None:
        args_dict["param_two"] = int(args_dict["param_two"])

    return argparse.Namespace(**args_dict)

In [None]:
def set_random_seed(seed: int, args: argparse.Namespace) -> None:

    setattr(args, "seed", seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    setattr(args, "cuda", torch.cuda.is_available())

    if args.cuda:
        torch.cuda.manual_seed(args.seed)

In [None]:
# def process_checkpoints(base_folder: str):
#     """
#     Process all folders in the base directory, load checkpoints, and save models with structure.
#     """
#     # Iterate through folders in base directory
#     for folder_name in os.listdir(base_folder):
#         folder_path = os.path.join(base_folder, folder_name)
#         if not os.path.isdir(folder_path):
#             continue

#         try:
#             # Parse args from folder name
#             args = parse_folder_name(folder_name)

#             if args.dataset == "cifar100":
#                 setattr(args, "num_class", 100)
#             elif args.dataset == "cifar10":
#                 setattr(args, "num_class", 10)

#             setattr(args, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
#             set_random_seed(1,args)
#             setattr(args, "num_class", 10)
#             setattr(args, "gpu_ids", "0")
#             setattr(args, "validation", True)
#             setattr(args, "workers", 16)
#             setattr(args, "batch_size", 128)
#             setattr(args, "val_size", 5000)
#             setattr(args, "dataset_dir", "../datasets/CIFAR/")

#             # Create model
#             model = create_net(args)
#             train_loader, val_loader, test_loader = CIFAR_data_loaders(args)

#             # Combine with base folder to locate the checkpoint
#             checkpoint_path = os.path.join(folder_path, "model_best_top1.pth.tar")
#             if not os.path.exists(checkpoint_path):
#                 print(f"Checkpoint not found in {checkpoint_path}, skipping.")
#                 continue

#             # Load the checkpoint
#             model, optimizer, best_acc, best_acc_5, start_epoch = load_checkpoint(
#                 args, model, checkpoint_path
#             )


#         except Exception as e:
#             print(f"Error processing folder '{folder_name}': {e}")

In [None]:
def process_checkpoints(base_folder: str, output_folder: str):
    """
    Process checkpoints to calculate stats, generate confusion matrices, and save results.
    """

    # Prepare storage for results
    stats = []

    for folder_name in os.listdir(base_folder):
        folder_path = os.path.join(base_folder, folder_name)
        if not os.path.isdir(folder_path):
            continue

        try:
            # Parse args and set additional parameters
            args = parse_folder_name(folder_name)
            set_random_seed(1, args)
            setattr(args, "num_class", 100 if args.dataset == "cifar100" else 10)
            setattr(args, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
            setattr(args, "gpu_ids", "0")
            setattr(args, "validation", True)
            setattr(args, "workers", 16)
            setattr(args, "batch_size", 128)
            setattr(args, "val_size", 5000)
            setattr(args, "dataset_dir", "../datasets/CIFAR/")

            # Create model
            model = create_net(args)

            # Get parameters and FLOPs/MACs
            param_count = sum(p.numel() for p in model.parameters())
            macs, params = get_model_complexity_info(model, (3, 32, 32), as_strings=False, verbose=False)

            # Measure inference FPS
            dummy_input = torch.randn(args.batch_size, 3, 32, 32).to(args.device)
            model.to(args.device)
            model.eval()
            with torch.no_grad():
                torch.cuda.synchronize()
                start_time = torch.cuda.Event(enable_timing=True)
                end_time = torch.cuda.Event(enable_timing=True)

                start_time.record()
                model(dummy_input)
                end_time.record()
                torch.cuda.synchronize()
                elapsed_time_ms = start_time.elapsed_time(end_time)
            fps = args.batch_size / (elapsed_time_ms / 1000)

            # Combine with base folder to locate the checkpoint
            checkpoint_path = os.path.join(folder_path, "model_best_top1.pth.tar")
            if not os.path.exists(checkpoint_path):
                print(f"Checkpoint not found in {checkpoint_path}, skipping.")
                continue

            # Load the checkpoint
            model, _, _, _, _ = load_checkpoint(args, model, checkpoint_path)
            val_loader = CIFAR_data_loaders(args)[2]  # Validation loader

            y_true, y_pred = [], []
            for images, labels in val_loader:
                images, labels = images.to(args.device), labels.to(args.device)
                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                y_true.extend(labels.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())

            # Calculate classification metrics
            report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
            precision = report["weighted avg"]["precision"]
            recall = report["weighted avg"]["recall"]
            f1 = report["weighted avg"]["f1-score"]

            # Append trial stats
            stats.append({
                "Dataset": args.dataset,
                "Architecture": args.arch,
                "Attention Type": args.attention_type,
                "Trial ID": args.id,
                "Params": param_count,
                "FLOPs": macs,
                "FPS": fps,
                "Precision": precision,
                "Recall": recall,
                "F1": f1,
            })

            # Generate confusion matrix for specific cases
            if args.attention_type in ["lap_spatial_param", "none"] and args.dataset == "cifar10" and args.id == "01":
                cm = confusion_matrix(y_true, y_pred)
                disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=range(args.num_class))
                fig, ax = plt.subplots(figsize=(10, 10))
                disp.plot(ax=ax)
                cm_filename = f"confusion_matrix_{args.dataset}_{args.arch}_{args.attention_type}_Trial{args.id}.png"
                plt.savefig(os.path.join(output_folder, cm_filename))
                plt.close(fig)

        except Exception as e:
            print(f"Error processing folder '{folder_name}': {e}")
            continue

    # Aggregate stats to calculate mean and std over trials
    stats_df = pd.DataFrame(stats)
    aggregated = (
        stats_df.groupby(["Dataset", "Architecture", "Attention Type"])
        .agg({
            "Params": ["mean", "std"],
            "FLOPs": ["mean", "std"],
            "FPS": ["mean", "std"],
            "Precision": ["mean", "std"],
            "Recall": ["mean", "std"],
            "F1": ["mean", "std"],
        })
        .reset_index()
    )

    # Save stats to CSV
    stats_filename = os.path.join(output_folder, "model_statistics.csv")
    aggregated.to_csv(stats_filename, index=False)
    print(f"Statistics saved to {stats_filename}")

In [7]:
process_checkpoints(base_folder, save_base_folder)

ResNet_CIFAR(
  273.0 k, 100.000% Params, 41.53 MMac, 99.130% MACs, 
  (conv1): Conv2d(432, 0.158% Params, 442.37 KMac, 1.056% MACs, 3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, 0.012% Params, 32.77 KMac, 0.078% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.039% MACs, inplace=True)
  (layer1): Sequential(
    14.46 k, 5.298% Params, 14.76 MMac, 35.228% MACs, 
    (0): BlockAttention(
      4.82 k, 1.766% Params, 4.92 MMac, 11.743% MACs, 
      (conv1): Conv2d(2.3 k, 0.844% Params, 2.36 MMac, 5.632% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, 0.012% Params, 32.77 KMac, 0.078% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(0, 0.000% Params, 32.77 KMac, 0.078% MACs, inplace=True)
      (conv2): Conv2d(2.3 k, 0.844% Params, 2.36 MMac, 5.632% MACs, 16, 16, kerne