In [None]:
!pip install -U datasets evaluate optuna

In [None]:
import traceback

import evaluate
import numpy as np
import torch
from datasets import Dataset, DatasetDict, load_dataset
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification

In [None]:
# --- Data location ---
# Hub identifier of the dataset and the local directory used as download cache.
DATASET_NAME = "username/my_test_audio_dataset"
CACHE_DIR = "./cache"

# --- Audio processing ---
TARGET_SAMPLE_RATE = 16_000  # samples per second (16 kHz)
CHUNK_LENGTH_MS = 1_000      # clip length in milliseconds (1 s)
# Clip length expressed in samples at the target rate.
CHUNK_LENGTH_SAMPLES = TARGET_SAMPLE_RATE * CHUNK_LENGTH_MS // 1000

# --- Training hyperparameters ---
BATCH_SIZE = 32
EPOCHS = 5
LR = 3e-5

# --- Model ---
MODEL_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
CHECKPOINT_FILENAME = "ast_best_model"

# --- Hyperparameter search ---
STUDY_NAME = "drone-audio-detection-05-17"
N_TRIALS = 10

In [None]:
def load_dataset_splits(dataset_name: str) -> DatasetDict:
    """Load every split of a dataset from the Hugging Face Hub.

    The dataset is fetched with ``load_dataset`` using ``CACHE_DIR`` as the
    cache directory, and the row count of each split is printed.

    Args:
        dataset_name: Hub identifier of the dataset to load.

    Returns:
        The split-keyed ``DatasetDict`` returned by ``load_dataset``.

    Raises:
        Exception: Any error raised by ``load_dataset``; it is printed and
            re-raised unchanged.
    """
    # Guard only the load itself, so a problem while reporting split sizes is
    # never mislabelled as "Failed to load dataset".
    try:
        dataset = load_dataset(dataset_name, cache_dir=CACHE_DIR)
    except Exception as e:
        print(f"Failed to load dataset {dataset_name}: {e}")
        raise

    print(f"Dataset loaded successfully with splits: {list(dataset.keys())}")

    # Report whatever splits exist instead of assuming 'train' is present.
    # The three known names keep their original human-readable labels.
    split_labels = {'train': 'Train', 'valid': 'Validation', 'test': 'Test'}
    for split_name, split in dataset.items():
        label = split_labels.get(split_name, split_name)
        print(f"{label} split size: {split.num_rows}")

    return dataset

In [None]:
def build_transformer_model(num_classes: int, model_checkpoint: str, label2id=None):
    """
    Loads a pre-trained Audio Spectrogram Transformer (AST) model
    from Hugging Face for PyTorch, with label metadata for this task.

    Args:
        num_classes (int): The number of output classes for the classification layer.
        model_checkpoint (str): The Hugging Face AST model identifier.
        label2id (dict[str, int] | None): Mapping from label name to class id.
            Defaults to ``{'not_drone': 0, 'drone': 1}``.

    Returns:
        The ``ASTForAudioClassification`` model loaded from ``model_checkpoint``
        with the adjusted config.

    Raises:
        ValueError: If ``label2id`` does not describe exactly ``num_classes``
            classes with ids ``0 .. num_classes - 1``.
    """
    if label2id is None:
        label2id = {'not_drone': 0, 'drone': 1}

    # Validate up front (before anything is downloaded): the label maps and
    # num_labels are written to the same config, so they must agree.
    if len(label2id) != num_classes:
        raise ValueError(
            f"num_classes={num_classes} does not match the {len(label2id)} "
            f"labels in label2id: {sorted(label2id)}"
        )
    if sorted(label2id.values()) != list(range(num_classes)):
        raise ValueError(
            f"label2id ids must be exactly 0..{num_classes - 1}, "
            f"got {sorted(label2id.values())}"
        )

    ast_config = ASTConfig.from_pretrained(model_checkpoint)

    ast_config.num_labels = num_classes
    ast_config.label2id = dict(label2id)
    ast_config.id2label = {class_id: name for name, class_id in label2id.items()}

    # Load the pre-trained AST model.
    # ignore_mismatched_sizes=True allows replacing the classification head
    # when its size differs from the one stored in the checkpoint.
    model = ASTForAudioClassification.from_pretrained(
        model_checkpoint,
        cache_dir=CACHE_DIR,
        config=ast_config,
        ignore_mismatched_sizes=True,
    )
    return model


In [None]:
def get_feature_extractor(model_checkpoint: str):
    """
    Fetch the feature extractor that belongs to an AST checkpoint.

    Args:
        model_checkpoint (str): The Hugging Face AST model identifier.

    Returns:
        ASTFeatureExtractor: The extractor loaded via ``from_pretrained``.

    Raises:
        Exception: Any loading error is printed and then re-raised unchanged.
    """
    try:
        return ASTFeatureExtractor.from_pretrained(model_checkpoint)
    except Exception as load_error:
        # Report the failure here, but let the caller decide how to handle it.
        print(f"Error loading feature extractor: {load_error}")
        raise

In [None]:
# Load the feature extractor once; `preprocess_features` and the Trainer read
# this module-level name.
try:
    feature_extractor = get_feature_extractor(MODEL_CHECKPOINT)
    print(f"Successfully loaded feature extractor: {feature_extractor}")
    print(f"Target sampling rate from feature extractor: {feature_extractor.sampling_rate} Hz")
    # The AST extractor's max_length counts spectrogram frames, not audio samples.
    print(f"Feature extractor expects max_length: {feature_extractor.max_length} frames")
except Exception as e:
    # print() has no `exc_info` keyword (that belongs to the logging API);
    # passing it raised a TypeError inside this handler. Print the traceback
    # explicitly instead.
    print(f"Failed to load feature extractor: {e}")
    traceback.print_exc()
    feature_extractor = None

In [None]:
def preprocess_features(example):
    """Applies the AST feature extractor to a batch of audio data.

    This function is designed to be used with `dataset.with_transform()`.

    Args:
        example: A dictionary representing a batch of examples from the HF
            dataset. Its 'input_values' entry must be a list of audio items,
            each providing an "array" (and normally a "sampling_rate").

    Returns:
        The same dictionary with 'input_values' replaced by the feature
        extractor's output (PyTorch tensors).

    Raises:
        RuntimeError: If the module-level `feature_extractor` is None.
        ValueError: If a clip's recorded sampling rate differs from the
            feature extractor's sampling rate.
        KeyError: If the feature extractor output has no 'input_values'.
    """
    if feature_extractor is None:
        raise RuntimeError("Feature extractor is not loaded; cannot preprocess audio.")

    expected_rate = feature_extractor.sampling_rate

    # The extractor is told below that the audio is at `expected_rate`, so
    # check each clip's recorded rate first rather than trusting it blindly.
    audio_arrays = []
    for audio in example['input_values']:
        try:
            actual_rate = audio["sampling_rate"]
        except (KeyError, TypeError):
            actual_rate = None  # rate not recorded; nothing to validate
        if actual_rate is not None and actual_rate != expected_rate:
            raise ValueError(
                f"Audio sampling rate {actual_rate} Hz does not match the feature "
                f"extractor's {expected_rate} Hz; resample the dataset first."
            )
        audio_arrays.append(audio["array"])

    # Apply the feature extractor
    # NOTE(review): the AST extractor appears to pad/truncate to its own
    # configured max_length (in frames) -- confirm that max_length/padding/
    # truncation passed here have any effect.
    inputs = feature_extractor(
        audio_arrays,
        sampling_rate=expected_rate,
        max_length=CHUNK_LENGTH_SAMPLES,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )

    if "input_values" not in inputs:
        print(f"Feature extractor output did not contain expected keys ('input_values'). Found: {inputs.keys()}")
        raise KeyError("Could not find processed features in feature extractor output.")

    example["input_values"] = inputs["input_values"]
    return example

In [None]:
def get_device():
    """Pick the torch device to run on and print the choice.

    Preference order is CUDA, then the MPS backend, then CPU.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    if torch.cuda.is_available():
        print("CUDA available. Using GPU.")
        return torch.device("cuda")

    # The CPU message reports "CUDA/MPS not available", so MPS has to be
    # checked as well. The backend is looked up defensively in case this
    # torch build does not provide `torch.backends.mps`.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        print("MPS available. Using Apple GPU.")
        return torch.device("mps")

    print("CUDA/MPS not available. Using CPU.")
    return torch.device("cpu")

DEVICE = get_device()

In [None]:
from transformers import TrainerCallback, Trainer, TrainingArguments

class CustomCallback(TrainerCallback):
    """Trainer callback that ties Hub handling to the start and end of training.

    When training begins it calls the trainer's ``init_hf_repo``; when training
    ends it calls ``push_to_hub`` and then empties the CUDA cache.
    """

    def __init__(self, trainer) -> None:
        """Keep a reference to the trainer whose Hub helpers the hooks call."""
        super().__init__()
        self._trainer = trainer

    def on_train_begin(self, args, state, control, **kwargs):
        """Initialise the Hub repository through the trainer."""
        self._trainer.init_hf_repo()

    def on_train_end(self, args, state, control, **kwargs):
        """Push through the trainer, then release cached CUDA memory."""
        self._trainer.push_to_hub()
        torch.cuda.empty_cache()


In [None]:
# Metric objects from the `evaluate` library, used by `compute_metrics`.
accuracy, recall, precision, f1 = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "recall", "precision", "f1")
)

# Averaging mode passed to precision, recall and F1.
AVERAGE = "binary"

def compute_metrics(eval_pred):
    """Compute accuracy, precision, recall and F1 for an evaluation pass.

    Args:
        eval_pred: Object exposing ``predictions`` (logits, arg-maxed along
            axis 1) and ``label_ids`` (reference labels).

    Returns:
        dict: Accuracy result updated with the precision, recall and F1
        results, the latter three computed with ``average=AVERAGE``.
    """
    references = eval_pred.label_ids
    predicted = np.argmax(eval_pred.predictions, axis=1)

    results = accuracy.compute(predictions=predicted, references=references)
    # Accuracy takes no averaging mode; the remaining metrics all share AVERAGE.
    for averaged_metric in (precision, recall, f1):
        results.update(
            averaged_metric.compute(
                predictions=predicted, references=references, average=AVERAGE
            )
        )
    return results

In [None]:
def model_init():
    """Build a fresh two-class model on ``DEVICE`` (one per Optuna trial)."""
    return build_transformer_model(
        model_checkpoint=MODEL_CHECKPOINT,
        num_classes=2,
    ).to(DEVICE)

In [None]:
# --- Optuna search space ---
def optuna_hp_space(trial):
    """Sample one hyperparameter configuration from an Optuna trial.

    Args:
        trial: Optuna trial providing ``number`` and the ``suggest_*`` methods.

    Returns:
        dict: Values keyed by argument name, including a per-trial
        ``hub_model_id`` derived from ``STUDY_NAME`` and the trial number.
    """
    # Entries are added in a fixed order so the suggest_* calls happen in the
    # same sequence on every trial.
    space = {"hub_model_id": f"preszzz/{STUDY_NAME}-trial-{trial.number}"}
    space["learning_rate"] = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    space["per_device_train_batch_size"] = trial.suggest_categorical(
        "per_device_train_batch_size", [8, 16, 32]
    )
    space["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
    space["warmup_ratio"] = trial.suggest_float("warmup_ratio", 0.0, 0.2)
    space["lr_scheduler_type"] = trial.suggest_categorical(
        "lr_scheduler_type", ["linear", "cosine", "polynomial"]
    )
    space["max_grad_norm"] = trial.suggest_float("max_grad_norm", 0.1, 1.0)
    space["optim"] = trial.suggest_categorical(
        "optim", ["adamw_torch", "adafactor", "adamw_torch_fused"]
    )
    return space

In [None]:
def optuna_hp_name(trial):
    """Return the run name for a trial: ``<STUDY_NAME>_trial_<trial number>``."""
    return "{}_trial_{}".format(STUDY_NAME, trial.number)

In [None]:
# Load all splits of DATASET_NAME; the loader prints the split sizes.
ds = load_dataset_splits(DATASET_NAME)

In [None]:
# Chain the rename instead of rebinding `ds`: after `ds = ds.rename_column(...)`
# a second run of this cell fails because the 'audio' column no longer exists.
# `preprocess_features` is applied lazily, when examples are accessed.
processed_datasets = (
    ds.rename_column('audio', 'input_values')
    .with_transform(preprocess_features)
)

In [None]:
# Base training configuration; the Optuna search overrides the values returned
# by `optuna_hp_space` for each trial.
args = TrainingArguments(
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=EPOCHS,
    logging_steps=10,
    metric_for_best_model="accuracy",
    gradient_accumulation_steps=4,
    load_best_model_at_end=True,
    # fp16 mixed precision needs a CUDA device. Enabling it unconditionally
    # makes this cell fail on the CPU fallback that `get_device` allows.
    fp16=torch.cuda.is_available(),
    disable_tqdm=False,
    # push_to_hub=True,
    # save_total_limit=1
)

In [None]:
# Set up the trainer. `model_init` (rather than a model instance) is passed so
# a fresh model can be built for each hyperparameter-search trial.
trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=processed_datasets["train"],
    # NOTE(review): evaluation, and therefore trial/best-model selection, uses
    # the "test" split -- confirm this is intended rather than a validation split.
    eval_dataset=processed_datasets["test"],
    processing_class=feature_extractor,
    compute_metrics=compute_metrics,
)

In [None]:
# Register the Hub callback; it needs the trainer itself to call its Hub helpers.
# NOTE(review): re-running this cell presumably registers the callback again -- confirm.
trainer.add_callback(CustomCallback(trainer))

In [None]:
# Run the Optuna search: N_TRIALS trials, each trained with a configuration
# sampled by `optuna_hp_space`.
best_trial_results = trainer.hyperparameter_search(
    direction="maximize",       # We want to maximize accuracy
    backend="optuna",
    hp_space=optuna_hp_space,
    hp_name=optuna_hp_name,
    n_trials=N_TRIALS,
    # Without an explicit objective the Trainer's default adds up every
    # reported eval metric (accuracy + precision + recall + f1 here), so
    # select accuracy explicitly to match the stated goal.
    compute_objective=lambda metrics: metrics["eval_accuracy"],
    study_name=STUDY_NAME
)

In [None]:
# Show the result object returned by the hyperparameter search.
print(best_trial_results)