In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

from utils import DEVICE
from data_loader import get_dataloaders
from models import get_model
from training import train_in_stages
from colorspace import convert_and_normalize

from analysis.precompute import compute_all_analysis_outputs
from analysis.plots_full import run_all_full_plots
from analysis.summary import build_summary_table

In [None]:
# -------------------------------
# EXPERIMENT SETTINGS
# -------------------------------
# Everything a reader might tune lives here; later cells only read these names.

DATASET = "CIFAR10"        # one of: "CIFAR10", "HAM10000", "KANSAS"
MODEL_TYPE = "resnet18"    # alternatives: "mediumcnn", "shallowcnn", "efficientnet"

COLOR_SPACES = ["rgb", "lab", "hsv", "yuv", "ycrcb", "xyz"]

EPOCHS_TOTAL = 20          # passed to train_in_stages as `final_epochs`
FIRST_STAGE_EPOCHS = 5
BATCH_SIZE = 16
IMG_SIZE = 32

# Seed the global RNGs *before* building loaders and models so that a
# Restart & Run All reproduces the same numbers.
# NOTE(review): this only helps if get_dataloaders / get_model draw from these
# global RNGs (no private generators) — confirm in data_loader.py / models.py.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"Dataset: {DATASET}")
print(f"Model: {MODEL_TYPE}")
print(f"Color spaces: {COLOR_SPACES}")
print(f"Seed: {SEED}")

train_loader, val_loader, test_loader, CLASS_NAMES, class_weights = \
    get_dataloaders(DATASET, batch_size=BATCH_SIZE, img_size=IMG_SIZE)

print("Classes:", CLASS_NAMES)
print("Device:", DEVICE)

In [None]:
# -------------------------------
# TRAIN MODELS FOR EACH COLOR SPACE
# -------------------------------
# For every color space, build a fresh model with get_model and train it with
# train_in_stages using the same loaders and epoch schedule. The returned stats
# go into `results[space]` and the trained model into `models_by_space[space]`.
#
# NOTE(review): `val_loader` is never used here — train_in_stages is given
# `test_loader`. If that loader drives checkpoint selection or early stopping,
# the reported test metrics are optimistically biased; confirm in training.py.
# NOTE(review): there is no reseed before get_model, so unless get_model seeds
# internally or loads fixed pretrained weights, each color space starts from a
# different random initialization, which confounds the comparison.

BANNER = "=========================================="

results = {}
models_by_space = {}

for space in COLOR_SPACES:
    print("\n" + BANNER)
    print(f"Training for color space: {space.upper()}")
    print(BANNER)

    model = get_model(MODEL_TYPE, num_classes=len(CLASS_NAMES))
    stats = train_in_stages(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        first_stage_epochs=FIRST_STAGE_EPOCHS,
        final_epochs=EPOCHS_TOTAL,
        color_space=space,
        class_weights=class_weights,
    )

    # Keep the stats and the trained model side by side, keyed by color space.
    results[space], models_by_space[space] = stats, model

print("\n Training completed for all color spaces!")

In [None]:
# -------------------------------
# RUN ALL ANALYSIS IN ONE CALL
# -------------------------------
# Development hot-reload: re-execute the analysis modules so edits made since
# the kernel started are picked up without re-running the (expensive) training
# cell. importlib.reload does NOT update names previously bound with
# `from module import name`, so the entry points used below and in the
# plotting cell are re-bound explicitly after reloading.
import importlib

import analysis.gradcam as gradcam
import analysis.precompute as precompute
import analysis.plots_full as plots_full

# Reload gradcam first — presumably precompute / plots_full import from it, so
# they must be reloaded after it to see the fresh version. Each module is
# reloaded exactly once.
for _module in (gradcam, precompute, plots_full):
    importlib.reload(_module)

# Re-bind the entry points to the freshly reloaded implementations.
from analysis.precompute import compute_all_analysis_outputs
from analysis.plots_full import run_all_full_plots

analysis_outputs = compute_all_analysis_outputs(
    models_by_space=models_by_space,
    results=results,
    train_loader=train_loader,
    test_loader=test_loader,
    class_names=CLASS_NAMES,
    color_spaces=COLOR_SPACES,
    convert_fn=convert_and_normalize,
)

# Unpack each analysis result into the names the later cells use.
confusions               = analysis_outputs["confusions"]
per_class_acc            = analysis_outputs["per_class_acc"]
embedding_stability      = analysis_outputs["embedding_stability"]
color_complexity         = analysis_outputs["color_complexity"]
texture_complexity       = analysis_outputs["texture_complexity"]
color_sensitivity        = analysis_outputs["color_sensitivity"]
kernel_stats_by_space    = analysis_outputs["kernel_stats"]
kernel_similarity_to_rgb = analysis_outputs["kernel_similarity"]
complexity_metrics       = analysis_outputs["complexity_metrics"]
pruning_sensitivity      = analysis_outputs["pruning_sensitivity"]
gcam_iou_to_rgb          = analysis_outputs["gcam_iou"]
attention_entropy        = analysis_outputs["attention_entropy"]
feature_redundancy       = analysis_outputs["feature_redundancy"]
deltaE_stats_all         = analysis_outputs["deltaE_stats"]
ssim_stats_all           = analysis_outputs["ssim_stats"]


In [None]:
# -------------------------------
# PLOTS (CORE + FULL)
# -------------------------------
# run_all_full_plots receives its inputs positionally, so the tuple below must
# keep exactly this order.
# NOTE(review): deltaE_stats_all and ssim_stats_all are unpacked from
# analysis_outputs but not passed here — confirm whether plots_full uses them.

# Sample indices forwarded as `gradcam_sample_indices` (presumably positions in
# the test_loader dataset — confirm in plots_full).
GRADCAM_SAMPLE_INDICES = (0, 5, 10)

plot_args = (
    results,
    COLOR_SPACES,
    CLASS_NAMES,
    confusions,
    per_class_acc,
    embedding_stability,
    color_complexity,
    texture_complexity,
    color_sensitivity,
    kernel_stats_by_space,
    kernel_similarity_to_rgb,
    complexity_metrics,
    pruning_sensitivity,
    gcam_iou_to_rgb,
    attention_entropy,
    feature_redundancy,
    models_by_space,
    test_loader,
)

run_all_full_plots(*plot_args, gradcam_sample_indices=GRADCAM_SAMPLE_INDICES)

In [None]:
# Build the cross-color-space summary table, save it as CSV, and display it.
summary_df = build_summary_table(
    color_spaces=COLOR_SPACES,
    results=results,
    analysis_outputs=analysis_outputs,
)

# DataFrame.to_csv does not create missing parent directories, and `results/`
# may not exist on a fresh checkout — create it before writing.
SUMMARY_CSV_PATH = "results/summary_table_full_metrics.csv"
os.makedirs(os.path.dirname(SUMMARY_CSV_PATH), exist_ok=True)
summary_df.to_csv(SUMMARY_CSV_PATH, index=False)

summary_df