In [None]:
# %pip install -q evaluate datasets  # uncomment on a fresh environment; %pip installs into the running kernel

In [2]:
import os

import evaluate
import numpy as np
import torch
from datasets import Audio, Dataset, DatasetDict, load_dataset
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification, Trainer, TrainingArguments

In [3]:
# --- Dataset and filesystem locations ---
# Identifier handed to load_dataset(); it is blank here and must be filled in before the loading cell runs.
DATASET_NAME = ''
ROOT_DIR = './'
CACHE_DIR = './cache'

# --- Audio Processing Parameters ---
TARGET_SAMPLE_RATE = 16000  # Hz; audio is cast to this rate when the dataset is loaded
CHUNK_LENGTH_MS = 1000      # length of one chunk in milliseconds
# Same chunk length expressed in samples (rate * duration in seconds).
CHUNK_LENGTH_SAMPLES = TARGET_SAMPLE_RATE * CHUNK_LENGTH_MS // 1000

# --- Training Hyperparameters ---
EPOCHS = 5
BATCH_SIZE = 32  # per-device batch size for both training and evaluation
LR = 3e-5

# --- Model Configuration ---
MODEL_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
MODEL_SAVE_DIR = os.path.join(ROOT_DIR, 'output_models')
# NOTE(review): not referenced by any of the cells shown in this notebook.
CHECKPOINT_FILENAME = "ast_best_model"

In [4]:
def load_dataset_splits(dataset_name: str) -> DatasetDict:
    """Load all splits of a Hugging Face Hub dataset and resample its audio.

    The ``audio`` column is cast to mono audio at ``TARGET_SAMPLE_RATE`` and
    the sizes of the ``train``/``valid``/``test`` splits that are present are
    printed.

    Args:
        dataset_name: Name of the dataset to load. Must be a non-empty string.

    Returns:
        The loaded ``DatasetDict`` (one entry per split) with the ``audio``
        column cast to the target sampling rate.

    Raises:
        ValueError: If ``dataset_name`` is empty or only whitespace.
        Exception: Any error raised while loading or casting is printed and
            then re-raised unchanged.
    """
    # Fail early with an actionable message instead of an opaque loader error
    # when the config cell's placeholder was never filled in.
    if not dataset_name or not dataset_name.strip():
        raise ValueError(
            "dataset_name is empty; set DATASET_NAME to a Hugging Face dataset identifier."
        )

    try:
        dataset = load_dataset(dataset_name, cache_dir=CACHE_DIR)
        dataset = dataset.cast_column("audio", Audio(sampling_rate=TARGET_SAMPLE_RATE, mono=True))

        print(f"Dataset loaded successfully with splits: {list(dataset.keys())}")
        # Only report splits that exist, so that logging never causes a KeyError.
        for label, split_name in (("Train", "train"), ("Validation", "valid"), ("Test", "test")):
            if split_name in dataset:
                print(f"{label} split size: {dataset[split_name].num_rows}")

        return dataset
    except Exception as e:
        print(f"Failed to load dataset {dataset_name}: {e}")
        raise

In [5]:
def build_transformer_model(num_classes: int, model_checkpoint: str, label2id=None):
    """
    Loads a pre-trained Audio Spectrogram Transformer (AST) model
    from Hugging Face for PyTorch and configures its label mapping.

    Args:
        num_classes (int): The number of output classes for the classification layer.
        model_checkpoint (str): The Hugging Face AST model identifier.
        label2id (dict, optional): Mapping from label name to class id. Must
            contain exactly ``num_classes`` entries with ids ``0..num_classes-1``.
            When omitted, the drone labels ``{'not_drone': 0, 'drone': 1}`` are
            used for ``num_classes == 2``; for any other class count the label
            names produced by setting ``num_labels`` on the config are kept.

    Returns:
        torch.nn.Module: The PyTorch AST model loaded with the adjusted config.

    Raises:
        ValueError: If ``label2id`` is given but inconsistent with ``num_classes``.
    """
    if label2id is None:
        # The built-in drone labels only describe the binary task. Applying
        # them for another class count would overwrite the label maps with two
        # entries and silently undo the requested num_classes.
        if num_classes == 2:
            label2id = {'not_drone': 0, 'drone': 1}
    elif len(label2id) != num_classes or sorted(label2id.values()) != list(range(num_classes)):
        raise ValueError(
            f"label2id must map exactly {num_classes} labels to ids 0..{num_classes - 1}; "
            f"got {label2id}"
        )

    ast_config = ASTConfig.from_pretrained(model_checkpoint)

    ast_config.num_labels = num_classes
    if label2id is not None:
        ast_config.label2id = dict(label2id)
        ast_config.id2label = {idx: name for name, idx in label2id.items()}

    # ignore_mismatched_sizes=True lets the checkpoint load even though its
    # classification head has a different number of outputs than this config.
    model = ASTForAudioClassification.from_pretrained(
        model_checkpoint,
        cache_dir=CACHE_DIR,
        config=ast_config,
        ignore_mismatched_sizes=True,
    )
    return model

def get_feature_extractor(model_checkpoint: str):
    """
    Builds the AST feature extractor for a checkpoint.

    The extractor is created with ``sampling_rate=TARGET_SAMPLE_RATE``.

    Args:
        model_checkpoint (str): The Hugging Face AST model identifier.

    Returns:
        ASTFeatureExtractor: The AST feature extractor instance.

    Raises:
        Exception: Any loading error is printed and then re-raised unchanged.
    """
    try:
        extractor = ASTFeatureExtractor.from_pretrained(
            model_checkpoint, sampling_rate=TARGET_SAMPLE_RATE
        )
    except Exception as load_error:
        print(f"Error loading feature extractor: {load_error}")
        raise
    return extractor

In [None]:
# Load the feature extractor once; preprocess_features and the Trainer cell
# both read this module-level `feature_extractor`.
try:
    feature_extractor = get_feature_extractor(MODEL_CHECKPOINT)
    print(f"Successfully loaded feature extractor: {feature_extractor}")
    print(f"Target sampling rate from feature extractor: {feature_extractor.sampling_rate} Hz")
    # For the AST extractor, max_length is the length of the extracted
    # spectrogram features (frames), not a number of audio samples.
    print(f"Feature extractor expects max_length: {feature_extractor.max_length} frames")
except Exception as e:
    # print() has no `exc_info` keyword (that belongs to the logging API), so
    # passing it raised TypeError and hid the original error.
    print(f"Failed to load feature extractor: {e}")
    feature_extractor = None

In [7]:
def preprocess_features(example):
    """Replace the raw audio in a batch with AST feature-extractor output.

    This function is designed to be used with `dataset.with_transform()`.

    Args:
        example: A dictionary representing a batch of examples from the HF
            dataset. Its 'input_values' entry must be a sequence of audio
            dicts, each exposing the waveform under the "array" key.

    Returns:
        The same dictionary, with 'input_values' overwritten by the
        extractor's processed 'input_values'.

    Raises:
        KeyError: If the feature extractor output has no 'input_values' key.
    """
    # Collect the waveform of every clip in the batch.
    waveforms = [clip["array"] for clip in example['input_values']]

    # NOTE(review): CHUNK_LENGTH_SAMPLES counts audio samples, while the AST
    # extractor's max_length is described in terms of extracted features.
    # Confirm that max_length/padding/truncation have the intended effect here.
    extracted = feature_extractor(
        waveforms,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=CHUNK_LENGTH_SAMPLES,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )

    # Guard clause: without processed features there is nothing to hand to the model.
    if "input_values" not in extracted:
        print(f"Feature extractor output did not contain expected keys ('input_values'). Found: {extracted.keys()}")
        raise KeyError("Could not find processed features in feature extractor output.")

    example["input_values"] = extracted["input_values"]
    return example

In [17]:
# Metric objects from the `evaluate` library, loaded once and read by compute_metrics.
accuracy, recall, precision, f1 = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "recall", "precision", "f1")
)

# Averaging mode passed to the precision, recall and F1 computations.
AVERAGE = "binary"

def compute_metrics(eval_pred):
    """Compute accuracy, precision, recall and F1 for an evaluation pass.

    Args:
        eval_pred: Object exposing `predictions` (per-class scores, argmaxed
            over axis 1) and `label_ids` (reference labels).

    Returns:
        dict: The merged results of the accuracy, precision, recall and F1
        metrics, in that order.
    """
    predicted_ids = np.argmax(eval_pred.predictions, axis=1)
    references = eval_pred.label_ids

    # Accuracy takes no averaging mode; the other three metrics share AVERAGE.
    scores = accuracy.compute(predictions=predicted_ids, references=references)
    for averaged_metric in (precision, recall, f1):
        scores.update(
            averaged_metric.compute(predictions=predicted_ids, references=references, average=AVERAGE)
        )
    return scores

In [8]:
def get_device():
    """Select the torch device to run on.

    Preference order is CUDA, then Apple MPS, then CPU. The choice is printed.

    Returns:
        torch.device: The selected device.
    """
    if torch.cuda.is_available():
        print("CUDA available. Using GPU.")
        return torch.device("cuda")

    # The fallback message mentions MPS, so actually probe for it. Older
    # PyTorch builds have no torch.backends.mps, hence the getattr.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        print("MPS available. Using Apple GPU.")
        return torch.device("mps")

    print("CUDA/MPS not available. Using CPU.")
    return torch.device("cpu")


DEVICE = get_device()

CUDA available. Using GPU.


In [None]:
# Load every split of DATASET_NAME; the helper also casts the 'audio' column to
# TARGET_SAMPLE_RATE mono and prints the split sizes.
ds = load_dataset_splits(dataset_name=DATASET_NAME)

Dataset loaded successfully with splits: ['train', 'valid', 'test']
Train split size: 585180
Validation split size: 32510
Test split size: 32511

In [10]:
# preprocess_features reads the raw audio from 'input_values', so the 'audio'
# column is renamed before the on-the-fly transform is attached. Chaining from
# `ds` without rebinding it keeps this cell idempotent: re-running it no longer
# fails because 'audio' was already renamed.
processed_datasets = ds.rename_column('audio', 'input_values').with_transform(preprocess_features)

In [None]:
# Build the AST classifier with a 2-class head from MODEL_CHECKPOINT, then move
# it to the device chosen by get_device().
model = build_transformer_model(num_classes=2, model_checkpoint=MODEL_CHECKPOINT)
model.to(DEVICE)

In [16]:
# Training configuration. Evaluation and checkpointing both happen once per
# epoch, and the checkpoint with the best validation accuracy is reloaded when
# training finishes.
args = TrainingArguments(
    output_dir="test-train-model",
    eval_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    logging_steps=10,
    # Gradients are accumulated over 4 batches of BATCH_SIZE before each optimizer step.
    gradient_accumulation_steps=4,
    warmup_ratio=0.1,
    metric_for_best_model="accuracy",
    load_best_model_at_end=True,
    # Only request fp16 mixed precision when a CUDA GPU is present, so the
    # CPU fallback supported by get_device() does not ask for it.
    fp16=torch.cuda.is_available(),
    disable_tqdm=False,
    # push_to_hub=True,
    # tpu_num_cores=48
)


In [18]:
# Set up the Trainer: train on the 'train' split, evaluate on the 'valid' split.
# Both splits come from `processed_datasets`, so preprocess_features is applied
# on the fly to each batch that is accessed.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["valid"],
    processing_class=feature_extractor,
    compute_metrics=compute_metrics,  # accuracy / precision / recall / F1 defined above
)

In [19]:
# Run fine-tuning with the settings in `args`; the returned TrainOutput is shown as the cell output.
trainer.train()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mzh-preston[0m ([33mzh-preston-queen-s-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.001,0.003579,0.998831,0.999057,0.99945,0.999253
2,0.0011,0.003091,0.999108,0.999725,0.999135,0.99943
3,0.002,0.001945,0.999539,0.999646,0.999764,0.999705
4,0.0,0.001803,0.999723,0.999803,0.999843,0.999823


TrainOutput(global_step=22855, training_loss=0.004115778539612579, metrics={'train_runtime': 27434.7446, 'train_samples_per_second': 106.649, 'train_steps_per_second': 0.833, 'total_flos': 1.9828487041060543e+20, 'train_loss': 0.004115778539612579, 'epoch': 4.998961010553946})

In [20]:
# Save the trainer's current model to MODEL_SAVE_DIR. With
# load_best_model_at_end=True this is presumably the best-accuracy checkpoint
# rather than the last one; confirm against the training log.
trainer.save_model(MODEL_SAVE_DIR)

In [21]:
# Called without arguments, so this re-evaluates on the eval_dataset given to
# the Trainer (the 'valid' split).
# NOTE(review): the 'test' split is never evaluated in the cells shown here;
# consider trainer.evaluate(eval_dataset=processed_datasets["test"]) for a held-out score.
trainer.evaluate()

{'eval_loss': 0.0018028703052550554,
 'eval_accuracy': 0.999723162103968,
 'eval_precision': 0.9998034745696093,
 'eval_recall': 0.9998427734758853,
 'eval_f1': 0.999823123636578,
 'eval_runtime': 143.347,
 'eval_samples_per_second': 226.792,
 'eval_steps_per_second': 7.088,
 'epoch': 4.998961010553946}

In [None]:
# Upload the trained model to the Hugging Face Hub.
# NOTE(review): presumably needs a logged-in Hub token with write access; confirm before running.
trainer.push_to_hub()