In [1]:
import torch
from datasets import Audio, ClassLabel, load_dataset
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, Trainer, TrainingArguments

In [3]:
# Pretrained AST checkpoint (fine-tuned on AudioSet) that we adapt below.
# The matching feature extractor converts raw waveforms into the spectrogram
# inputs the model expects; its configuration is printed in the next cell.
model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
model = AutoModelForAudioClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

In [4]:
# Inspect the extractor settings (sampling rate, mel bins, max_length, normalisation stats).
feature_extractor

ASTFeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "ASTFeatureExtractor",
  "feature_size": 1,
  "max_length": 1024,
  "mean": -4.2677393,
  "num_mel_bins": 128,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000,
  "std": 4.5689974
}



In [6]:
# Patch size the pretrained AST encoder was configured with.
patch_size = model.config.patch_size
patch_size


16

In [7]:
# Re-purpose the AudioSet classification head for a binary task.
# A 2-way head trained with CrossEntropyLoss consumes plain integer 0/1 labels
# directly; a single sigmoid logit would instead need float labels shaped
# (batch, 1) for BCEWithLogitsLoss.
NUM_LABELS = 2
SEED = 42  # fixed seed so the freshly initialised head is reproducible

# Updating only model.config is not enough: the model keeps its own num_labels
# attribute (copied from the config when it was built) and uses it in the loss.
model.num_labels = NUM_LABELS
model.config.num_labels = NUM_LABELS  # keeps the saved config (id2label/label2id) consistent
# Pin the loss explicitly instead of inheriting the AudioSet checkpoint's setting.
model.config.problem_type = "single_label_classification"

# Swap only the final projection; the head's pretrained LayerNorm is kept.
torch.manual_seed(SEED)
model.classifier.dense = torch.nn.Linear(model.config.hidden_size, NUM_LABELS)

In [None]:
# Hub dataset id (or local path) to fine-tune on -- replace the placeholder.
# Later cells expect "train" and "test" splits.
DATASET_NAME = "your_dataset_name"
dataset = load_dataset(DATASET_NAME)

In [None]:
# Preprocess the dataset
# Column holding the decoded audio -- replace with the actual column name in your dataset.
AUDIO_COLUMN = "audio"


def preprocess_function(examples):
    """Convert one dataset row's waveform into AST input features.

    ``dataset.map`` is called below without ``batched=True``, so ``examples`` is a
    single row whose ``AUDIO_COLUMN`` entry is a dict with ``"array"`` and
    ``"sampling_rate"`` keys.

    Returns a dict with ``"input_values"`` holding ONE spectrogram with no leading
    batch dimension, so that batching rows yields (batch, time, mel_bins).
    """
    audio = examples[AUDIO_COLUMN]
    features = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
    # The extractor always returns a batch (here of size 1); keep just that one item.
    # (Using return_tensors="pt" here stored an extra batch dim in every row.)
    return {"input_values": features["input_values"][0]}


# The AST feature extractor rejects audio whose sampling rate differs from its own
# (16 kHz per the config printed above), so decode/resample the column first.
dataset = dataset.cast_column(AUDIO_COLUMN, Audio(sampling_rate=feature_extractor.sampling_rate))

# Replace the raw audio with spectrogram features; the label column is left untouched.
# NOTE: this rebinds `dataset`, so re-run the load cell before re-running this one.
dataset = dataset.map(preprocess_function, remove_columns=[AUDIO_COLUMN])

In [None]:
# Training hyper-parameters. output_dir is the same path the final model is saved to below.
training_args = TrainingArguments(
    output_dir="./ast-finetuned-binary",  # checkpoints directory
    # `eval_strategy` replaces the deprecated `evaluation_strategy` keyword,
    # which recent transformers releases no longer accept.
    eval_strategy="epoch",  # evaluate on the eval split at the end of every epoch
    per_device_train_batch_size=4,  # each input is a 1024x128 spectrogram; tune to your hardware
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    learning_rate=5e-5,
    save_steps=500,  # write a checkpoint every 500 optimisation steps
    logging_dir="./logs",
    logging_steps=100,  # log training metrics every 100 steps
)


In [None]:
# Fine-tune: wire model, hyper-parameters and the preprocessed splits into a
# Trainer, run training, then persist the resulting model.
train_split = dataset["train"]
eval_split = dataset["test"]  # held-out split used for the per-epoch evaluation

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    # The feature extractor is handed over as the Trainer's `tokenizer`
    # (presumably so it is saved next to the weights).
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour
    # of `processing_class=` -- check against the installed version.
    tokenizer=feature_extractor,
)

trainer.train()

trainer.save_model("./ast-finetuned-binary")

### Inference

In [None]:
# Score one held-out example with the fine-tuned model and threshold it into a 0/1 prediction.
THRESHOLD = 0.5  # decision threshold on the positive-class probability

sample = dataset["test"][0]
# Reshape to (1, time, mel_bins); this works whether or not the stored features
# kept the feature extractor's leading batch dimension.
input_values = (
    torch.as_tensor(sample["input_values"], dtype=torch.float32)
    .reshape(1, -1, model.config.num_mel_bins)
    .to(model.device)
)

model.eval()
with torch.no_grad():
    outputs = model(input_values=input_values)

logits = outputs.logits
if logits.shape[-1] == 1:
    # Single-logit head: the sigmoid of the logit is P(positive).
    probabilities = torch.sigmoid(logits).squeeze(-1)
else:
    # Two-logit head: column 1 of the softmax is P(positive).
    probabilities = torch.softmax(logits, dim=-1)[..., 1]

predictions = (probabilities > THRESHOLD).float()
