# XAI Experiments - Notebook para Testes Individuais

Este notebook permite testar individualmente os componentes do pipeline XAI.

## 1. Configuração e Imports

In [None]:
import sys

# Make the project modules that live in ./scripts importable. The relative path
# is resolved against the current working directory, so the notebook is
# expected to be launched from the project root. The membership check keeps
# repeated runs of this cell from stacking duplicate entries on sys.path.
SCRIPTS_DIR = './scripts'
if SCRIPTS_DIR not in sys.path:
    sys.path.insert(0, SCRIPTS_DIR)

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from config import (
    N_SAMPLES, DATA_DIR, RESULTS_DIR,
    VIT_XAI_METHODS, CNN_XAI_METHODS,
    EMOTION_CLASSES, print_config, get_device
)

# Show the active project configuration (via config.print_config).
print_config()

## 2. Carregar Dados

In [None]:
from data_loader import get_all_images, load_dataset

# Load a small sample of images for quick, individual experiments.
# NOTE(review): n_samples=7 is intended as one image per class; confirm that
# load_dataset actually stratifies by class rather than sampling at random.
df = load_dataset(n_samples=7)
print("\nImagens carregadas: " + str(len(df)))

# Last expression of the cell, so the notebook renders the first rows.
df.head()

In [None]:
# Pick the first loaded sample as the test image for the rest of the notebook.
first_row = df.iloc[0]
img_path = first_row['path']
true_label = first_row['label']

# Image.open is lazy and keeps the file handle open until the data is loaded or
# the image is closed; the context manager releases it. convert('RGB') forces
# the pixel data to load and makes single-channel (grayscale) images render in
# gray instead of through matplotlib's default colormap.
with Image.open(img_path) as src:
    img = src.convert('RGB')

plt.figure(figsize=(4, 4))
plt.imshow(img)
plt.title(f"Label: {true_label}")
plt.axis('off')
plt.show()

## 3. Carregar Modelo ViT

In [None]:
from vit import (
    build_transform_from_convnext,
    load_vit_model,
    run_xai_on_image,
)

# Load the ViT classifier; `device` is reused by the following cells.
model_vit, cfg_vit, device = load_vit_model()

# build_transform_from_convnext returns four values; only the first is kept and
# used below as the ViT input transform, the other three are discarded.
(transform_vit, _, _, _) = build_transform_from_convnext()

## 4. Testar XAI no ViT

In [None]:
from utils import get_label_name

# Run every configured ViT XAI method on the selected image; returns the image,
# the predicted class index, its confidence and a {method_name: heatmap} dict.
pil_img, pred_idx, conf, maps_vit = run_xai_on_image(
    img_path,
    model_vit,
    transform_vit,
    device,
    methods=tuple(VIT_XAI_METHODS),
)

print("Predição: {} (confiança: {:.2%})".format(get_label_name(pred_idx), conf))
print("Métodos XAI: {}".format(list(maps_vit)))

In [None]:
# Visualize each ViT heatmap overlaid on the original image.
from visualization import save_xai_visualization  # NOTE(review): not used in this cell

n = len(maps_vit)
# squeeze=False always returns a 2-D array of Axes, so indexing also works when
# no XAI map was produced (n == 0, a single subplot); keep the only row.
fig, axes = plt.subplots(1, n + 1, figsize=(4 * (n + 1), 4), squeeze=False)
axes = axes[0]

vis_size = (224, 224)  # (width, height) for PIL.Image.resize
img_resized = pil_img.resize(vis_size)
# Stretch every heatmap over the same pixel area as the resized image. Without
# an explicit extent, a heatmap whose resolution differs from the image would
# be drawn in its own coordinate frame and end up misaligned.
overlay_extent = (-0.5, vis_size[0] - 0.5, vis_size[1] - 0.5, -0.5)

axes[0].imshow(img_resized)
axes[0].set_title('Original')
axes[0].axis('off')

for i, (name, hm) in enumerate(maps_vit.items(), 1):
    axes[i].imshow(img_resized)
    axes[i].imshow(hm, alpha=0.35, cmap='turbo', extent=overlay_extent)
    axes[i].set_title(f'ViT - {name}')
    axes[i].axis('off')

plt.suptitle(f"ViT: {get_label_name(pred_idx)} ({conf:.1%})")
plt.tight_layout()
plt.show()

## 5. Carregar Modelo CNN

In [None]:
from cnn import (
    build_cnn_transform,
    load_cnn_model,
    run_xai_on_image_cnn,
)

# Load the CNN classifier together with its data_config.
# NOTE(review): this rebinds `device`, which the ViT cells also use; presumably
# both loaders select the same device — confirm.
model_cnn, data_config, device = load_cnn_model()

# Build the CNN input transform from data_config; of the four returned values
# only the first is kept.
(transform_cnn, _, _, _) = build_cnn_transform(data_config)

## 6. Testar XAI na CNN

In [None]:
# Run every configured CNN XAI method on the same image used for the ViT.
pil_img_cnn, pred_idx_cnn, conf_cnn, maps_cnn = run_xai_on_image_cnn(
    img_path,
    model_cnn,
    transform_cnn,
    device,
    methods=tuple(CNN_XAI_METHODS),
)

print("Predição CNN: {} (confiança: {:.2%})".format(get_label_name(pred_idx_cnn), conf_cnn))
print("Métodos XAI: {}".format(list(maps_cnn)))

In [None]:
# Visualize each CNN heatmap overlaid on the original image.
n = len(maps_cnn)
# squeeze=False always returns a 2-D array of Axes, so indexing also works when
# no XAI map was produced (n == 0, a single subplot); keep the only row.
fig, axes = plt.subplots(1, n + 1, figsize=(4 * (n + 1), 4), squeeze=False)
axes = axes[0]

vis_size = (224, 224)  # (width, height) for PIL.Image.resize
img_resized = pil_img_cnn.resize(vis_size)
# Stretch every heatmap over the same pixel area as the resized image. Without
# an explicit extent, a heatmap whose resolution differs from the image would
# be drawn in its own coordinate frame and end up misaligned.
overlay_extent = (-0.5, vis_size[0] - 0.5, vis_size[1] - 0.5, -0.5)

axes[0].imshow(img_resized)
axes[0].set_title('Original')
axes[0].axis('off')

for i, (name, hm) in enumerate(maps_cnn.items(), 1):
    axes[i].imshow(img_resized)
    axes[i].imshow(hm, alpha=0.35, cmap='turbo', extent=overlay_extent)
    axes[i].set_title(f'CNN - {name}')
    axes[i].axis('off')

plt.suptitle(f"CNN: {get_label_name(pred_idx_cnn)} ({conf_cnn:.1%})")
plt.tight_layout()
plt.show()

## 7. Calcular Métricas

In [None]:
from metrics import compute_all_metrics
from config import MEAN, STD

# Heatmap to score (edit to choose another method). 'Rollout' is preferred; if
# it is not among the configured VIT_XAI_METHODS, fall back to the first map
# that was produced instead of failing with a KeyError.
metric_method = 'Rollout' if 'Rollout' in maps_vit else next(iter(maps_vit))
hm_vit = maps_vit[metric_method]

metrics = compute_all_metrics(
    model_vit, pil_img, hm_vit, device, MEAN, STD,
    model_type='vit', true_label_idx=df.iloc[0]['label_idx'],
)

print(f"Métricas de Fidelidade e Localidade ({metric_method}):")
for k, v in metrics.items():
    # np.floating covers numpy scalars such as float32, which are not
    # subclasses of Python float and would otherwise be printed unformatted.
    if isinstance(v, (float, np.floating)):
        print(f"  {k}: {v:.4f}")
    else:
        print(f"  {k}: {v}")

## 8. Comparação ViT vs CNN

In [None]:
from visualization import save_comparison_visualization  # NOTE(review): not used in this cell

# Preferred method per model, falling back to the first available map so the
# cell still runs when the configured method lists do not include them.
cmp_vit_method = 'Rollout' if 'Rollout' in maps_vit else next(iter(maps_vit))
cmp_cnn_method = 'LayerCAM' if 'LayerCAM' in maps_cnn else next(iter(maps_cnn))

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

vis_size = (224, 224)  # (width, height) for PIL.Image.resize
img_np = np.array(pil_img.resize(vis_size))
# Pin both heatmaps to the image's pixel area so they line up with it
# regardless of their native resolution.
overlay_extent = (-0.5, vis_size[0] - 0.5, vis_size[1] - 0.5, -0.5)

axes[0].imshow(img_np)
axes[0].set_title(f'Original\n{true_label}')
axes[0].axis('off')

axes[1].imshow(img_np)
axes[1].imshow(maps_vit[cmp_vit_method], alpha=0.4, cmap='turbo', extent=overlay_extent)
axes[1].set_title(f'ViT ({cmp_vit_method})\n{get_label_name(pred_idx)} {conf:.1%}')
axes[1].axis('off')

axes[2].imshow(img_np)
axes[2].imshow(maps_cnn[cmp_cnn_method], alpha=0.4, cmap='turbo', extent=overlay_extent)
axes[2].set_title(f'CNN ({cmp_cnn_method})\n{get_label_name(pred_idx_cnn)} {conf_cnn:.1%}')
axes[2].axis('off')

plt.tight_layout()
plt.show()

## 9. Pipeline Completo (Opcional)

In [None]:
# Optional: uncomment to run the full pipeline on a few images.
# NOTE(review): assumes main.run_full_analysis accepts these keyword arguments
# and returns a pandas DataFrame (hence .head()) — confirm against main.py.

# from main import run_full_analysis
# 
# results = run_full_analysis(
#     n_samples=7,
#     models=['vit', 'cnn'],
#     save_heatmaps=True,
#     generate_plots=True,
#     verbose=True
# )
# results.head()

## 10. Limpeza de Memória

In [None]:
import gc

import torch

# Drop the notebook's references to both models so their memory can be
# reclaimed. pop(..., None) tolerates a model whose cell was never run, where a
# plain `del model_vit, model_cnn` would raise NameError. Any other live
# reference to a model (not visible here) would still keep it in memory.
for _model_name in ('model_vit', 'model_cnn'):
    globals().pop(_model_name, None)
del _model_name

gc.collect()
# Release cached, now-unused CUDA memory; skipped on CPU-only setups.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Memória liberada!")