In [1]:
from transformers import AutoFeatureExtractor, ASTForAudioClassification
from datasets import load_dataset
import matplotlib.pyplot as plt
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate
print(sampling_rate)
feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

# audio file is decoded on the fly
inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_ids = torch.argmax(logits, dim=-1).item()
predicted_label = model.config.id2label[predicted_class_ids]
predicted_label
print (predicted_label)

# compute loss - target_label is e.g. "down"
target_label = model.config.id2label[0]
print (target_label)
inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
loss = model(**inputs).loss
round(loss.item(), 2)

  from .autonotebook import tqdm as notebook_tqdm


16000
Speech
Speech


0.17

In [100]:
inputs

{'input_values': tensor([[[-1.2776, -1.2776, -1.2776,  ..., -1.2776, -1.2776, -1.2776],
         [-1.2776, -1.2776, -1.2776,  ..., -1.2776, -1.2776, -1.2776],
         [-1.2776, -1.2776, -1.2776,  ..., -1.2776, -1.2776, -1.2776],
         ...,
         [ 0.4185,  0.0726,  0.4494,  ...,  0.5430,  0.4019,  0.4213],
         [ 0.3091,  0.0120,  0.3888,  ...,  0.5611,  0.4199,  0.4125],
         [ 0.3717,  0.1178,  0.4946,  ...,  0.6881,  0.4622,  0.3717]]])}

In [3]:
from transformers import ASTForAudioClassification, ASTFeatureExtractor
import torch
import librosa
import numpy as np

# Load the feature extractor and model
model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(model_name)
model = ASTForAudioClassification.from_pretrained(model_name)


In [16]:
import torchaudio

# Load an audio file (replace 'audio.wav' with your file)
# audio_path = r"C:\Users\pepij\OneDrive - Delft University of Technology\THESIS\data\WAV_Groningen_1\WAV_Groningen_1\Noorderplantsoen\NP142.wav"
audio_path = r"C:\Users\pepij\OneDrive - Delft University of Technology\THESIS\data\10672568 (1)\all_wav\TS418.wav"
# audio_path = r"C:\Users\pepij\.cache\huggingface\datasets\downloads\extracted\d5bf28a8657c1072a648bb608e048a0064c56318d0909ed1ba42d92596386abf\dev_clean\1272\141231\1272-141231-0016.flac"
# audio_path = r"C:\Users\pepij\Downloads\6jiO0tPLK7U_000090.flac"
# audio_path = r"C:\Users\pepij\Downloads\glLQrEijrKg_000300.flac"
# audio_path = r"C:\Users\pepij\Downloads\Labrador Dog Barking Sound - The SOund ButtOn.mp3"
# audio_path = r"C:\Users\pepij\Downloads\sound effect party Time.mp3"
waveform, sample_rate = torchaudio.load(audio_path)

print('before: ', waveform, sample_rate)

# Convert to mono and resample to 16 kHz if needed
waveform = waveform.mean(dim=0)  # Convert to mono
if sample_rate != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
print('after:', waveform)
# Extract features
inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")




before:  tensor([[ 0.0114,  0.0008, -0.0023,  ..., -0.0925, -0.0984, -0.1039],
        [-0.0501, -0.0554, -0.0579,  ..., -0.1169, -0.1267, -0.1331]]) 48000
after: tensor([-0.0159, -0.0332, -0.0416,  ..., -0.0766, -0.1076, -0.0771])


In [17]:
print(inputs['input_values'].shape) 
print(inputs)

torch.Size([1, 1024, 128])
{'input_values': tensor([[[-0.1945, -0.7048, -0.3279,  ..., -0.0445,  0.0113, -0.0500],
         [-0.2524, -0.5661, -0.1893,  ..., -0.1017, -0.0756, -0.0289],
         [-0.0648, -0.3554,  0.0214,  ..., -0.0062,  0.0460,  0.0109],
         ...,
         [-0.1665, -0.4081, -0.0312,  ..., -0.3489, -0.3885, -0.2701],
         [-0.0134, -0.3589,  0.0180,  ..., -0.2585, -0.3129, -0.3194],
         [ 0.0587, -0.3125,  0.0643,  ..., -0.2951, -0.3233, -0.3755]]])}


In [18]:
# Run the model
with torch.no_grad():
    outputs = model(**inputs)

# Get class probabilities
logits = outputs.logits
# probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
probs = torch.sigmoid(logits).squeeze()
# predicted_label = torch.argmax(probs, dim=-1).item()

# Get the top 10 predictions
top_probs, top_indices = torch.topk(probs, 10)

# Load label names
id2label = model.config.id2label

# Print top 10 predictions with probabilities
print("Top 10 Predictions:")
for i in range(10):
    label = id2label[top_indices[i].item()]
    probability = top_probs[i].item()
    print(f"{label}: {probability:.4f}")


Top 10 Predictions:
Speech: 0.4165
Bird: 0.2337
Animal: 0.1782
Pigeon, dove: 0.1461
Outside, urban or manmade: 0.0910
Outside, rural or natural: 0.0894
Coo: 0.0819
Bird vocalization, bird call, bird song: 0.0760
Vehicle: 0.0698
Walk, footsteps: 0.0623


In [10]:
# Load label names
id2label = model.config.id2label

# Print prediction
print(f"Predicted label: {id2label[predicted_label]}")
# print(id2label)

KeyError: 'Speech'

In [12]:
id2label

{0: 'Speech',
 1: 'Male speech, man speaking',
 2: 'Female speech, woman speaking',
 3: 'Child speech, kid speaking',
 4: 'Conversation',
 5: 'Narration, monologue',
 6: 'Babbling',
 7: 'Speech synthesizer',
 8: 'Shout',
 9: 'Bellow',
 10: 'Whoop',
 11: 'Yell',
 12: 'Battle cry',
 13: 'Children shouting',
 14: 'Screaming',
 15: 'Whispering',
 16: 'Laughter',
 17: 'Baby laughter',
 18: 'Giggle',
 19: 'Snicker',
 20: 'Belly laugh',
 21: 'Chuckle, chortle',
 22: 'Crying, sobbing',
 23: 'Baby cry, infant cry',
 24: 'Whimper',
 25: 'Wail, moan',
 26: 'Sigh',
 27: 'Singing',
 28: 'Choir',
 29: 'Yodeling',
 30: 'Chant',
 31: 'Mantra',
 32: 'Male singing',
 33: 'Female singing',
 34: 'Child singing',
 35: 'Synthetic singing',
 36: 'Rapping',
 37: 'Humming',
 38: 'Groan',
 39: 'Grunt',
 40: 'Whistling',
 41: 'Breathing',
 42: 'Wheeze',
 43: 'Snoring',
 44: 'Gasp',
 45: 'Pant',
 46: 'Snort',
 47: 'Cough',
 48: 'Throat clearing',
 49: 'Sneeze',
 50: 'Sniff',
 51: 'Run',
 52: 'Shuffle',
 53: 'Walk