In [38]:
import torch
import torch.nn as nn               
import torch.nn.functional as F       
import torch.optim as optim            
from torch.utils.data import Dataset, DataLoader 
import timm 
#
import torchaudio
import torchaudio.transforms as T   
import numpy as np                     
import pandas as pd                   

import matplotlib.pyplot as plt

import os
import torchvision.transforms as T

from collections import OrderedDict, defaultdict

device = "cuda" if torch.cuda.is_available() else "cpu"


# Hyperparameters

In [39]:
EPOCHS = 25
INIT_LR = 1e-3
WEIGHT_DECAY = 1e-2
BATCH_SIZE = 64
LOSS = torch.nn.CrossEntropyLoss()



# Dataset init

In [40]:
# import torch
# from torch.utils.data import Dataset
# import torch.nn.functional as F
# import os
# import json
# from collections import OrderedDict
# import torchvision.transforms as T

# class AsTSpectogramDataset(Dataset):
#     def __init__(self, audio_dir, label_to_idx, max_cache_size=5):
#         self.label_to_idx = label_to_idx
#         self.audio_dir = audio_dir
#         self.chunk_index_pairs = []
#         self.cache = OrderedDict()
#         self.max_cache_size = max_cache_size

#         with open('./dataset_init.json', 'r') as file:
#             data = json.load(file)

#         for path in self.audio_dir:
#             label = os.path.basename(path).replace(".pt", "")
#             amount_of_chunks = data[label]
#             for n in range(amount_of_chunks):
#                 self.chunk_index_pairs.append((path, label.split('_')[0], n))


#     def load_cached_tensor(self, file_path):
#         if file_path in self.cache:
#             self.cache.move_to_end(file_path)
#         else:
#             tensor = torch.load(file_path)
#             self.cache[file_path] = tensor
#             if len(self.cache) > self.max_cache_size:
#                 self.cache.popitem(last=False)
#         return self.cache[file_path]

#     def __len__(self):
#         return len(self.chunk_index_pairs)
        
#     def __getitem__(self, idx):
#         file_path, label, chunk_index = self.chunk_index_pairs[idx]
#         tensor = self.load_cached_tensor(file_path)
        
#         chunk = tensor[chunk_index].to(torch.float32)
#         if chunk.size(0) == 3:
#             chunk = chunk[0].unsqueeze(0)  # Von [3, H, W] zu [1, H, W]

#         chunk = F.interpolate(chunk.unsqueeze(0), size=(128, 1024), mode='bilinear', align_corners=False).squeeze(0)

#         if label in self.label_to_idx:
#             label_index = self.label_to_idx[label]
#         else:
#             raise ValueError(f"Label '{label}' not found in label_to_idx")
#         print("Final shape:", chunk.shape)  # Sollte [1, 128, 1024] sein
#         return chunk, label_index

In [None]:
import torch
from torch.utils.data import Dataset
import torchaudio

class ASTRawAudioDataset(Dataset):
    def __init__(self, file_label_pairs, label_to_idx, feature_extractor, target_sr=16000):
        self.file_label_pairs = file_label_pairs
        self.label_to_idx = label_to_idx
        self.feature_extractor = feature_extractor
        self.target_sr = target_sr

    def __len__(self):
        return len(self.file_label_pairs)

    def __getitem__(self, idx):
        file_path, label = self.file_label_pairs[idx]
        waveform, sr = torchaudio.load(file_path)

        if sr != self.target_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sr)
            waveform = resampler(waveform)

        waveform = waveform.mean(dim=0)  


        inputs = self.feature_extractor(
            [waveform.numpy()],
            sampling_rate=self.target_sr,
            return_tensors="pt",
            padding=True
        )

        input_values = inputs["input_values"].squeeze(0)  
        label_index = self.label_to_idx[label]

        return input_values, label_index

In [None]:
#versucht aber hat ned geklappt 
import torch
from transformers import AutoFeatureExtractor, ASTForAudioClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


extractor_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = AutoFeatureExtractor.from_pretrained(
    extractor_name,
    trust_remote_code=True
)


feature_extractor.return_attention_mask = True
feature_extractor.max_length = 1024  
feature_extractor.padding_value = 0.0


def collate_fn(batch):
    waveforms, labels = zip(*batch)  

    inputs = feature_extractor(
        list(waveforms),               
        sampling_rate=16000,           
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=1024
    )

    input_values = inputs["input_values"].to(device)
    labels = torch.tensor(labels).to(device)
    return input_values, labels

model = ASTForAudioClassification.from_pretrained(
    extractor_name,
    num_labels=len(label_to_idx),
    ignore_mismatched_sizes=True
).to(device)

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([206]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([206, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Train, Test split

In [None]:
import os
import random

base_dir = './data/train_audio'
label_to_idx = {}
train_files = []
test_files = []
test_split_ratio = 0.2  

all_labels = sorted(os.listdir(base_dir))

for idx, label in enumerate(all_labels):
    label_dir = os.path.join(base_dir, label)
    if not os.path.isdir(label_dir):
        continue  

    label_to_idx[label] = idx
    audio_files = [os.path.join(label_dir, f) for f in os.listdir(label_dir) if f.endswith('.ogg')]

    #split n shuffle
    random.shuffle(audio_files)
    split_point = int(len(audio_files) * (1 - test_split_ratio))
    train_files.extend([(f, label) for f in audio_files[:split_point]])
    test_files.extend([(f, label) for f in audio_files[split_point:]])

# DataLoaders

In [82]:
train_dataset = ASTRawAudioDataset(train_files, label_to_idx, feature_extractor)
test_dataset = ASTRawAudioDataset(test_files, label_to_idx, feature_extractor)

In [83]:
from torch.utils.data import DataLoader

BATCH_SIZE = BATCH_SIZE
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [84]:
print("Train dataset size:", len(train_dataset))
print("Test dataset size:", len(test_dataset))

sample, label_index = train_dataset[0]
print("Sample shape:", sample.shape)
print("Label index:", label_index)


Train dataset size: 22777
Test dataset size: 5787
Sample shape: torch.Size([1024, 128])
Label index: 0


In [85]:
import torch.nn as nn
from torch.optim import AdamW

loss_fn = nn.CrossEntropyLoss()

optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import torch

best_f1 = 0.0
best_model_path = "best_ast_model.pth"

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0

    for batch_idx, (waveform, target) in enumerate(tqdm(train_loader, desc=f"Training epoch {epoch}")):
        waveform = waveform.to(device)           
        target = target.to(device)               

       
        inputs = feature_extractor(
            waveform,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        ).to(device)

   
        optimizer.zero_grad()
        outputs = model(**inputs)
        logits = outputs.logits


        loss = loss_fn(logits, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch {epoch} Train loss: {avg_train_loss:.4f}")


    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for waveform, target in tqdm(test_loader, desc=f"Evaluating epoch {epoch}"):
            waveform = waveform.to(device)
            target = target.to(device)

            inputs = feature_extractor(
                waveform,
                sampling_rate=16000,
                return_tensors="pt",
                padding=True
            ).to(device)

            outputs = model(**inputs)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).cpu().numpy()

            all_preds.extend(preds)
            all_targets.extend(target.cpu().numpy())

    acc = accuracy_score(all_targets, all_preds)
    prec = precision_score(all_targets, all_preds, average='macro', zero_division=0)
    rec = recall_score(all_targets, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)

    print(f"Epoch {epoch}: Accuracy={acc:.4f}, Precision={prec:.4f}, Recall={rec:.4f}, F1={f1:.4f}")

    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), best_model_path)
        print(f"Best model saved at epoch {epoch} with F1={f1:.4f}")

Training epoch 0:   0%|          | 0/356 [00:02<?, ?it/s]


AssertionError: choose a window size 400 that is [2, 64]