# XAI Project for signature classification using CEDAR dataset

## Colab & Kaggle API Configuration

In [None]:
# Clone the project repository (IPython/Colab shell escape, not plain Python).
!git clone https://github.com/silvano315/eXplainability-for-signature-detection.git

In [None]:
import os

# Work from inside the cloned repository so the relative paths used by later
# cells (config/, src/, data/, metadata/) resolve.
repo_dir = "eXplainability-for-signature-detection"
os.chdir(repo_dir)
os.getcwd()  # echo the new working directory as the cell output

In [None]:
# Mount Google Drive: it holds the Kaggle token and the experiment outputs.
from google.colab import drive

colab_mount_point = '/content/drive'
drive.mount(colab_mount_point)

In [None]:
# Move your Kaggle API to /root/.config/kaggle and /root/.kaggle/kaggle.json

os.makedirs('/root/.kaggle', exist_ok = True)

!cp /content/drive/MyDrive/Kaggle_api/kaggle.json /root/.config/kaggle.json
!cp /content/drive/MyDrive/Kaggle_api/kaggle.json /root/.kaggle/kaggle.json

## Import libraries

In [None]:
import json
import logging
import yaml
import torch
import pandas as pd
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

from src.utils.kaggle_downloader import setup_dataset
from src.utils.logger_setup import get_logger
from src.utils.dataset_analyzer import create_dataset_metadata, validate_dataset_consistency, \
                                        save_metadata, load_metadata
from src.utils.eda import print_dataset_statistics, plot_dataset_distribution, \
                            show_sample_images, analyze_image_properties, generate_eda_report
from src.data.cedar_dataset import CEDARDataset, create_dataloaders, create_balanced_splits
from src.model.model_factory import get_available_models, validate_model_config, create_model
from src.training.experiment import Experiment
from src.training.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from src.training.trainer import ModelTrainer
from src.visualization.plot_results import scatter_plot_metrics, plot_confusion_matrix, plot_misclassified_images

## Configurations

In [None]:
# Load the project configuration (paths, dataset, training and preprocessing
# sections) from YAML; safe_load only builds plain Python data.
config_file = 'config/config.yaml'
with open(config_file, 'r') as config_stream:
    config = yaml.safe_load(config_stream)

In [None]:
# Resolve the raw-data and models directories from the config and make sure
# the models directory exists (parents included, no error if present).
data_paths_cfg = config['paths']['data']
DATA_PATH = Path(data_paths_cfg['raw_path'])
MODELS_PATH = Path(data_paths_cfg['models_path'])
MODELS_PATH.mkdir(exist_ok=True, parents=True)

In [None]:
# Training hyper-parameters and the number of target classes, read from the config.
training_cfg = config['training']
BATCH_SIZE = training_cfg['batch_size']
NUM_EPOCHS = training_cfg['num_epochs']
LEARNING_RATE = training_cfg['learning_rate']
NUM_CLASSES = config['dataset']['num_classes']

## Load data

In [None]:
# Download the CEDAR signature dataset from Kaggle and report where it lives.
dataset_path = setup_dataset()
print(f"Dataset found at: {dataset_path}")

In [None]:
# Build the dataset metadata from the extracted CEDAR signatures folder.
signatures_path = Path("data/raw") / "cedardataset" / "signatures"
metadata = create_dataset_metadata(signatures_path)

In [None]:
# Either persist the freshly generated metadata (SAVE_METADATA = True) or
# replace it with the copy previously saved at metadata/metadata.json.
SAVE_METADATA = False
dataset_metadata_path = Path("metadata") / "metadata.json"

if not SAVE_METADATA:
    metadata = load_metadata(dataset_metadata_path)
else:
    save_metadata(metadata, dataset_metadata_path)

In [None]:
# Run the dataset consistency checks and pretty-print the report: scalar
# entries on one line, dict entries expanded one indented line per sub-key.
validate_metadata = validate_dataset_consistency(metadata)
for check_name, check_result in validate_metadata.items():
    if not isinstance(check_result, dict):
        print(f"{check_name}: {check_result}")
        continue
    print(f"{check_name}:")
    for sub_name, sub_result in check_result.items():
        print(f"  {sub_name}: {sub_result}")

## Exploratory Data Analysis

In [None]:
# Generate the complete EDA report for the signatures folder; output_dir is
# passed as the report destination.
signatures_path = Path("data/raw") / "cedardataset" / "signatures"
output_dir = Path("reports") / "eda"
generate_eda_report(signatures_path, metadata, output_dir)

In [None]:
# If you want to run the EDA step by step: print the dataset statistics.
print_dataset_statistics(metadata)

In [None]:
# If you want to run the EDA step by step: plot the dataset distribution.
plot_dataset_distribution(metadata)

## Create Data Loaders

In [None]:
# Reload the saved metadata (only needed if the cells above were not run).
dataset_metadata_path = Path("metadata") / "metadata.json"
metadata = load_metadata(dataset_metadata_path)

In [None]:
# Add balanced split assignments to the metadata, then preview the result.
metadata_with_splits = create_balanced_splits(metadata)

print("\n")
print("You can see the split key updated")
# Preview only the first six entries.
shown = 0
for key, value in metadata_with_splits.items():
    print(f"{key}: {value}")
    shown += 1
    if shown == 6:
        break

In [None]:
# Build the train/val/test DataLoaders from the split-annotated metadata.
# NOTE(review): BATCH_SIZE from the config is not passed here - confirm that
# create_dataloaders picks up the intended batch size on its own.
data_path = Path("data/raw") / "cedardataset" / "signatures"
dataloaders = create_dataloaders(data_path, metadata_with_splits)

In [None]:
# Inspect the first three batches of every split, then print the per-split
# class counts and subject info exposed by the underlying dataset.
for split_name, split_loader in dataloaders.items():
    print(f"{split_name} DataLoader:")
    # zip with range(3) stops after three batches without fetching a fourth.
    for batch_idx, (images, labels) in zip(range(3), split_loader):
        print(f"  Batch {batch_idx+1}: {len(images)} images")
        print(f"  Labels: {labels}")
        print(f"  Images shape: {images.shape}")
    print("\n")
    print(f"Class counts: {split_loader.dataset.get_class_counts()}")
    print(f"Subject info: {split_loader.dataset.get_subject_info()}")

## Model Definition & Training Setup

In [None]:
# Model configuration --> Baseline model.
# Run either this cell or the transfer-learning cell below; whichever runs
# last defines model_config. num_classes follows the config file, the same
# way the transfer-learning configuration does, instead of a hard-coded 2.
model_config = {
    'type': 'baseline',
    'num_classes': NUM_CLASSES,
    'input_channels': 3
}

In [None]:
# Model configuration --> Transfer learning model: pretrained ResNet-50 with
# the custom classifier option enabled.
model_config = dict(
    type='transfer',
    model_name='resnet50',
    num_classes=NUM_CLASSES,
    pretrained=True,
    use_custom_classifier=True,
)

In [None]:
# Take a look at every available model.
get_available_models()

In [None]:
# Validate the selected model configuration before the model is built.
validate_model_config(model_config)

In [None]:
# Create the model described by the currently selected model_config.
model = create_model(model_config)

In [None]:
# Optimizer and loss: Adam over all model parameters at the configured
# learning rate, and cross-entropy for the classification objective.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Set up experiment tracking; the experiment root lives on Google Drive.
logger = get_logger()

# TODO: consider deriving the experiment name from the model configuration.
experiment_settings = dict(
    name="signature_detection_resnet_v1",
    root="/content/drive/MyDrive/XAI_ProfAI/experiments",
    logger=logger,
)
experiment = Experiment(**experiment_settings)
experiment.init()

In [None]:
# Setup callbacks: early stopping and best-model checkpointing on val_loss,
# plus a learning-rate reduction when the monitored value plateaus.
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=config['training']['early_stopping']['patience'],
        min_delta=config['training']['early_stopping']['min_delta'],
        verbose=True
    ),
    # Store the best checkpoint next to this run's other artefacts
    # (experiment.root also receives history/ and the final weights), i.e. on
    # Drive rather than the ephemeral Colab working directory. Name it after
    # the ResNet run instead of 'baseline'.
    ModelCheckpoint(
        filepath=str(Path(experiment.root) / 'best_model_resnet_v1.pth'),
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        optimizer=optimizer,
        mode='min',
        patience=5,
        factor=0.1,
        verbose=True
    )
]

In [None]:
# Use the GPU when available and build the trainer around the model,
# optimizer, loss, experiment tracker and logger.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

trainer = ModelTrainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    experiment=experiment,
    device=device,
    logger=logger,
)

In [None]:
# Train the model with the configured callbacks, then save the final weights
# into the experiment folder.
trained_model = trainer.train(
    train_loader=dataloaders['train'],
    val_loader=dataloaders['val'],
    num_epochs=NUM_EPOCHS,
    callbacks=callbacks
)

# Wrap in Path(...) so the join works whether experiment.root is a str or a
# Path (other cells only use it inside f-strings); a TypeError here would
# otherwise surface only after the whole training run.
torch.save(trained_model.state_dict(), Path(experiment.root) / 'final_model_resnet_v1.pth')

## Validation on Test Set

In [None]:
# Evaluate on the held-out test split, record the metrics in the experiment
# history under 'test', and log them.
test_loader = dataloaders['test']
test_logs = trainer.validate(test_loader)
experiment.save_history('test', **test_logs)
logger.info(f"Test Results: {test_logs}")

In [None]:
# Predict on the test split and plot the confusion matrix.
# NOTE(review): the log line assumes plot_confusion_matrix writes
# 'confusion_matrix.png' - confirm.
test_targets, test_predictions = trainer.predict(dataloaders['test'])

class_names = ["original", "forgery"]
plot_confusion_matrix(test_targets, test_predictions, classes=class_names)
logger.info("Confusion matrix saved as 'confusion_matrix.png'")

In [None]:
# Persist the raw test targets and predictions as JSON in the experiment's
# results directory.
test_results = {
    name: values.tolist()
    for name, values in (('targets', test_targets), ('predictions', test_predictions))
}

results_file = f"{experiment.results_dir}/test_results.json"
with open(results_file, 'w') as results_stream:
    json.dump(test_results, results_stream)

## Analysis of Results

In [None]:
# Generate and save the training-history plots for this experiment.
experiment.plot_history()

In [None]:
# Compare train vs. validation metrics using the history CSVs saved under
# the experiment root.
history_dir = f'{experiment.root}/history'
scatter_plot_metrics(f'{history_dir}/train.csv', f'{history_dir}/val.csv')

In [None]:
# Replace missing values in the 'lr' column of the val and test history CSVs
# with 0, rewriting each file in place (val first, then test).
# TODO: to be refactored - fill 'lr' when the history is written instead.
for split_name in ("val", "test"):
    history_csv = f"{experiment.root}/history/{split_name}.csv"
    split_history = pd.read_csv(history_csv)
    split_history['lr'] = split_history['lr'].fillna(0)
    split_history.to_csv(history_csv, index=False)

In [None]:
# Re-create the experiment (without a logger; this replaces the earlier
# `experiment` object), reload the val/train/test histories from file, and
# average the validation metrics over the last 5 epochs.
experiment = Experiment("signature_detection_resnet_v1", "/content/drive/MyDrive/XAI_ProfAI/experiments")
for split_name in ("val", "train", "test"):
    experiment.load_history_from_file(split_name)

avg_metrics = experiment.calculate_average_metrics('val', last_n_epochs=5)
print("Average validation metrics:", avg_metrics)

In [None]:
# Export the collected results as JSON into this run's results folder on Drive.
experiments_root = "/content/drive/MyDrive/XAI_ProfAI/experiments"
results_json = f"{experiments_root}/signature_detection_resnet_v1/results/results.json"
experiment.export_results_to_json(results_json)

In [None]:
# Find the epoch with the highest validation accuracy and report it.
# NOTE(review): indexing with best_epoch - 1 assumes get_best_epoch returns a
# 1-based epoch number - confirm.
metric = 'accuracy'

best_epoch = experiment.get_best_epoch(metric, mode='max')
best_value = experiment.history['val'][metric][best_epoch - 1]
print(f"Best validation accuracy was achieved at epoch {best_epoch} with {100*best_value:.1f}%")

In [None]:

# Plot the learning-rate values recorded in the training history.
train_lr = experiment.history['train']['lr']
experiment.plot_learning_rate(train_lr)

In [None]:
# Plot up to 16 misclassified test images with ground truth and prediction.
# The preprocessing mean/std are passed along (presumably to undo the input
# normalisation for display - confirm in plot_misclassified_images).
image_norm = config['preprocessing']['image']
plot_misclassified_images(
    model=trained_model,
    dataloader=dataloaders['test'],
    device=device,
    num_images=16,
    class_names=["original", "forgery"],
    mean=image_norm['mean'],
    std=image_norm['std'],
)

## eXplainability

In [None]:
from src.xai.explainer import SignatureExplainer
from src.xai.visualizer import XAIVisualizer

In [None]:
# Set up the explainer around the model and the visualizer used to plot maps.
# NOTE(review): this wraps `model`, whereas the misclassification plot uses
# `trained_model` - confirm trainer.train returns the same (trained) object.
explainer, visualizer = SignatureExplainer(model, device), XAIVisualizer()

In [None]:
# Take the first image of the first test batch, add a leading batch dimension
# of size 1 and move it to the device. The result is already batched.
image = next(iter(dataloaders['test']))[0][0].unsqueeze(0).to(device)

In [None]:
# Single explanation: Grad-CAM.
# `image` was already unsqueezed to a batch of one when it was selected, so it
# is passed as-is; unsqueezing again would feed a 5-D tensor to the model.
# NOTE(review): "conv3" must name a layer of the model being explained -
# confirm it exists for the ResNet-50 transfer model.
grad_cam_map = explainer.grad_cam(image, target_layer="conv3")
visualizer.plot_single_explanation(image, grad_cam_map, "Grad-CAM")

In [None]:
# Integrated Gradients. `image` already carries a batch dimension of one, so
# it is passed as-is (a second unsqueeze would produce a 5-D input).
int_grad_map = explainer.integrated_gradients(image)
visualizer.plot_single_explanation(image, int_grad_map, "Integrated-Gradients")

In [None]:
# Occlusion map. `image` already carries a batch dimension of one, so it is
# passed as-is (a second unsqueeze would produce a 5-D input).
occ_map = explainer.occlusion_map(image)
visualizer.plot_single_explanation(image, occ_map, "Occlusion")

In [None]:
# Compare all explanation methods side by side.
# `image` already carries a batch dimension of one, so it is passed as-is
# (a second unsqueeze would produce a 5-D input).
# NOTE(review): originally marked "TO BE FIXED" - confirm this runs end to end,
# and that "conv3" is a valid target layer for the explained model.
explanations = explainer.compare_explanations(image, target_layer="conv3")
visualizer.plot_comparison(image, explanations)