# Model Evaluation
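This notebook evaluates a fine-tuned Audio Spectrogram Transformer (AST) model on an unseen audio dataset. It loads the fine-tuned model together with its feature extractor, which converts raw audio directly into spectrogram features. Each file is split into 1-second chunks. A file is predicted as drone if any of its chunks is, and its score for AUC is the highest per-chunk drone probability.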

In [None]:
import torch
import evaluate
import math
import numpy as np
from sklearn.metrics import roc_auc_score
from datasets import load_dataset, Audio, Dataset
from transformers import ASTFeatureExtractor, ASTForAudioClassification, Trainer, TrainingArguments

## Configuration

In [None]:
# --- Evaluation Configuration ---
DATASET_NAME = "YOUR_EVALUATION_DATASET_NAME_OR_PATH"  # E.g., 'username/my_test_audio_dataset' or a local path
DATASET_SPLIT = 'train'  # split of DATASET_NAME that is evaluated
CONFIG_NAME = 'ours'  # dataset configuration (subset) name passed to load_dataset

# --- Model & Cache Paths ---
# Hub ID or local directory of the fine-tuned AST model -- not the dataset ID above.
MODEL_HUB_ID = "YOUR_FINE_TUNED_MODEL_ID_OR_PATH"  # E.g., 'username/my-finetuned-ast' or a local path
CACHE_DIR = './cache'

# --- Audio Processing Parameters (should match training) ---
TARGET_SAMPLE_RATE = 16000  # Hz (16kHz)
CHUNK_LENGTH_MS = 1000      # milliseconds (1 second)
# Exact integer arithmetic: no float round-trip when converting ms to samples.
CHUNK_LENGTH_SAMPLES = TARGET_SAMPLE_RATE * CHUNK_LENGTH_MS // 1000


## Device Configuration

In [None]:
# Run inference on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

## Load Dataset

In [None]:
def load_and_prepare_dataset(dataset_name: str) -> Dataset:
    """Load the evaluation split and normalise its audio column.

    Reads the module-level CONFIG_NAME, DATASET_SPLIT and CACHE_DIR settings,
    then casts the ``audio`` column to ``Audio(sampling_rate=TARGET_SAMPLE_RATE,
    mono=True)``. Any failure is printed and re-raised unchanged.

    Args:
        dataset_name: Hub dataset ID or local path passed to ``load_dataset``.

    Returns:
        The prepared dataset split.
    """
    try:
        loaded = load_dataset(
            dataset_name,
            CONFIG_NAME,
            split=DATASET_SPLIT,
            cache_dir=CACHE_DIR,
        )
        # Decode audio at the target rate as a single (mono) channel.
        audio_feature = Audio(sampling_rate=TARGET_SAMPLE_RATE, mono=True)
        loaded = loaded.cast_column("audio", audio_feature)
        print(f"Dataset {dataset_name} loaded successfully with config {CONFIG_NAME}: {loaded.num_rows} examples.")
        return loaded
    except Exception as err:
        print(f"Failed to load or prepare dataset {dataset_name}: {err}")
        raise

## Load Fine-tuned Model and Feature Extractor

In [None]:
# Load the fine-tuned AST model and its feature extractor from MODEL_HUB_ID
# (a Hugging Face Hub ID or a local directory).
print(f"Attempting to load model and feature extractor from Hugging Face Hub ID: {MODEL_HUB_ID}")
try:
    model = ASTForAudioClassification.from_pretrained(MODEL_HUB_ID, cache_dir=CACHE_DIR)
    # Use the same cache directory as the model so both artifacts are stored and reused together.
    # sampling_rate is passed explicitly so the extractor is configured for TARGET_SAMPLE_RATE.
    feature_extractor = ASTFeatureExtractor.from_pretrained(
        MODEL_HUB_ID,
        cache_dir=CACHE_DIR,
        sampling_rate=TARGET_SAMPLE_RATE,
    )
except Exception as e:
    print(f"Error loading model/feature extractor from {MODEL_HUB_ID}: {e}")
    raise

model.to(DEVICE)
model.eval()  # inference mode for evaluation

## Preprocessing Function for Evaluation
This function splits raw audio into non-overlapping chunks of `CHUNK_LENGTH_SAMPLES` (1 second) and zero-pads the final chunk if it is shorter. Each chunk is later converted to spectrogram features by the loaded `ASTFeatureExtractor`, which keeps the inputs consistent with the training setup.

In [None]:
def chunk_audio_array(audio_array: np.ndarray, chunk_length_samples: int) -> list[np.ndarray]:
    """Split an audio array into non-overlapping chunks of a fixed length.

    Chunks are taken along the first axis (samples). The last chunk is
    zero-padded at the end to ``chunk_length_samples``; for multi-channel
    arrays shaped ``(samples, channels)`` only the sample axis is padded.

    Args:
        audio_array: Audio samples; anything ``np.asarray`` accepts.
        chunk_length_samples: Number of samples per chunk; must be positive.

    Returns:
        List of chunks, each with exactly ``chunk_length_samples`` samples.
        An empty input yields an empty list.

    Raises:
        ValueError: If ``chunk_length_samples`` is not positive.
    """
    if chunk_length_samples <= 0:
        raise ValueError(f"chunk_length_samples must be positive, got {chunk_length_samples}")

    audio_array = np.asarray(audio_array)
    num_samples = audio_array.shape[0]
    chunks = []
    for start in range(0, num_samples, chunk_length_samples):
        chunk = audio_array[start:start + chunk_length_samples]
        padding_needed = chunk_length_samples - chunk.shape[0]
        if padding_needed > 0:
            # Pad only the sample axis; leave any channel axes untouched.
            pad_width = [(0, padding_needed)] + [(0, 0)] * (chunk.ndim - 1)
            chunk = np.pad(chunk, pad_width, mode="constant")
        chunks.append(chunk)
    return chunks

In [None]:
# Load the configured evaluation split (audio resampled to TARGET_SAMPLE_RATE, mono).
dataset = load_and_prepare_dataset(dataset_name=DATASET_NAME)

In [None]:
# File-level evaluation: each file is split into fixed-length chunks, every chunk is
# classified independently, and the chunk results are aggregated per file.
all_true_file_labels = []
all_aggregated_file_predictions = []  # Store 0 for not_drone, 1 for drone
all_file_drone_scores = []  # Max per-chunk drone probability per file (used for AUC)

# For binary: 0 -> not_drone, 1 -> drone.
# NOTE(review): assumes the model's label ids follow this mapping; verify against model.config.id2label.
DRONE_CLASS_INDEX = 1


def _classify_chunk(chunk_array: np.ndarray) -> tuple[int, float]:
    """Classify one audio chunk; return (predicted class index, drone-class probability)."""
    # ASTFeatureExtractor pads/truncates the fbank to its own configured max_length
    # (measured in frames, not samples), so no sample-count padding arguments are passed.
    inputs = feature_extractor(
        [chunk_array],  # Feature extractor expects a list of arrays
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
    )
    input_values = inputs.input_values.to(DEVICE)
    with torch.no_grad():
        logits = model(input_values).logits[0]  # drop the batch dimension (batch of one chunk)
    predicted_class_idx = int(torch.argmax(logits, dim=-1).item())
    drone_prob = torch.softmax(logits, dim=-1)[DRONE_CLASS_INDEX].item()
    return predicted_class_idx, drone_prob


def _score_file(raw_audio_array: np.ndarray) -> tuple[bool, float]:
    """Aggregate chunk results for one file.

    Returns (True if any chunk's argmax is the drone class, max drone probability
    over all chunks). A file with no samples has no chunks and scores (False, 0.0).
    """
    file_predicted_as_drone = False
    max_drone_prob = 0.0
    for chunk_array in chunk_audio_array(raw_audio_array, CHUNK_LENGTH_SAMPLES):
        predicted_class_idx, drone_prob = _classify_chunk(chunk_array)
        if predicted_class_idx == DRONE_CLASS_INDEX:
            file_predicted_as_drone = True
        # Keep scanning all chunks: the continuous score needs the maximum probability.
        max_drone_prob = max(max_drone_prob, drone_prob)
    return file_predicted_as_drone, max_drone_prob


print("Starting custom evaluation loop...")
for example in dataset:
    audio_data = example['audio']
    true_file_label = example['label']

    print(f"Begin processing: {audio_data['path']}")

    all_true_file_labels.append(true_file_label)
    file_predicted_as_drone, file_drone_score = _score_file(audio_data['array'])

    all_aggregated_file_predictions.append(1 if file_predicted_as_drone else 0)
    all_file_drone_scores.append(file_drone_score)
    print(f"File prediction: {'Drone' if file_predicted_as_drone else 'Not Drone'}")


print("Evaluation loop finished.")
print(f"Total files processed: {len(all_true_file_labels)}")
print(f"Number of files labelled drone: {sum(all_true_file_labels)}")
print(f"Number of files predicted drone: {sum(all_aggregated_file_predictions)}")

## Metrics Computation

In [None]:
# Averaging mode for precision/recall/F1: score a single positive class.
AVERAGE_MODE = "binary"
# Load the Hugging Face `evaluate` metrics used for file-level scoring (same load order as before).
accuracy_metric, recall_metric, precision_metric, f1_metric = (
    evaluate.load(metric_name) for metric_name in ("accuracy", "recall", "precision", "f1")
)

def compute_file_level_metrics(true_labels, predicted_labels):
    """Compute accuracy, precision, recall and F1 for file-level predictions.

    Args:
        true_labels: Ground-truth label per file.
        predicted_labels: Aggregated prediction per file.

    Returns:
        dict merging the result dicts of the accuracy, precision, recall and F1
        metrics, in that order. Precision/recall/F1 use ``AVERAGE_MODE`` with
        ``DRONE_CLASS_INDEX`` as the positive label.
    """
    results = dict(accuracy_metric.compute(predictions=predicted_labels, references=true_labels))
    # Precision/recall/F1 are computed for the drone class as the positive label.
    positive_class_kwargs = {"average": AVERAGE_MODE, "pos_label": DRONE_CLASS_INDEX}
    for metric in (precision_metric, recall_metric, f1_metric):
        results.update(
            metric.compute(predictions=predicted_labels, references=true_labels, **positive_class_kwargs)
        )
    return results

# Convert lists to numpy arrays for the metrics functions
true_labels_np = np.array(all_true_file_labels)
aggregated_predictions_np = np.array(all_aggregated_file_predictions)
file_scores_np = np.array(all_file_drone_scores)

print("\n--- File-Level Evaluation Results ---")
file_metrics = compute_file_level_metrics(true_labels_np, aggregated_predictions_np)
for key, value in file_metrics.items():
    print(f"{key}: {value}")

# --- AUC Calculation (using CONTINUOUS scores) ---
# ROC AUC is undefined unless both classes appear in the references;
# roc_auc_score raises ValueError in that case, so report it instead of aborting.
if np.unique(true_labels_np).size < 2:
    auc_score = float("nan")
    print("AUC: undefined (evaluation labels contain fewer than two classes)")
else:
    auc_score = roc_auc_score(true_labels_np, file_scores_np)
    print(f"AUC: {auc_score:.4f}")

## Trainer Initialization and Evaluation (currently disabled; the custom evaluation loop above is used instead)

In [None]:
# training_args = TrainingArguments(
#     output_dir="./eval_results",
#     per_device_eval_batch_size=BATCH_SIZE,
#     do_train=False,
#     do_eval=True,
#     report_to="none"
# )