In [9]:
import torch
import torch.nn as nn               
import torch.nn.functional as F       
import torch.optim as optim            
from torch.utils.data import Dataset, DataLoader 

#
import torchaudio
import torchaudio.transforms as T   
import numpy as np                     
import pandas as pd                   

import matplotlib.pyplot as plt

import os
import torchvision.transforms as T

from collections import OrderedDict, defaultdict

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [10]:
# Training hyper-parameters.
EPOCHS = 25
INIT_LR = 1e-3  # NOTE(review): confirm this matches the lr actually passed to the optimizer
WEIGHT_DECAY = 1e-2  # AdamW (decoupled) weight decay
BATCH_SIZE = 64
LOSS = torch.nn.CrossEntropyLoss()  # multi-class criterion on raw logits + integer class targets

# Seed PyTorch's global RNG (used e.g. for nn.Linear initialisation and DataLoader
# shuffling) so that runs are repeatable.
SEED = 42
torch.manual_seed(SEED)

In [12]:
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import os
import json
from collections import OrderedDict
import torchvision.transforms as T

class ViTSpectogramDataset(Dataset):
    """Map-style dataset of spectrogram chunks resized for an image-style ViT.

    Every path in ``audio_dir`` points to a ``.pt`` file whose loaded object is
    indexed by chunk number (``torch.load(path)[n]`` is chunk ``n``). The number
    of chunks per file is read from a JSON index keyed by the file name without
    its ``.pt`` suffix. A file's class name is the part of that stem before the
    first ``'_'`` (``'owl_train.pt'`` -> ``'owl'``).

    Each item is ``(chunk, label_index)`` where ``chunk`` is a float32 tensor of
    shape ``(3, *image_size)``: single-channel chunks (``(1, H, W)`` or
    ``(H, W)``) are repeated to three channels, then bilinearly resized.

    Args:
        audio_dir: Iterable of ``.pt`` file paths (despite the name, not a directory).
        label_to_idx: Mapping from class name to integer class index.
        max_cache_size: Number of loaded files kept in an in-memory LRU cache.
        index_path: JSON file mapping file stem -> number of chunks.
        image_size: Output spatial size ``(H, W)``; defaults to 224x224.

    Raises:
        KeyError: At construction, if a file stem is missing from the JSON index.
        ValueError: On item access, if the item's class name is not in ``label_to_idx``.

    NOTE(review): every cache miss loads a whole file. With shuffled sampling
    over many files, most lookups will likely miss a small cache; raise
    ``max_cache_size`` (memory permitting) if loading dominates step time.
    """

    def __init__(self, audio_dir, label_to_idx, max_cache_size=5,
                 index_path='./dataset_init.json', image_size=(224, 224)):
        self.label_to_idx = label_to_idx
        self.audio_dir = audio_dir
        self.chunk_index_pairs = []  # one (file_path, class_name, chunk_index) per item
        self.cache = OrderedDict()  # file_path -> loaded object, least recently used first
        self.max_cache_size = max_cache_size
        self.image_size = tuple(image_size)

        with open(index_path, 'r') as file:
            chunk_counts = json.load(file)

        for path in self.audio_dir:
            stem = os.path.basename(path)
            # Strip only a trailing '.pt'; str.replace would also remove '.pt' mid-name.
            if stem.endswith(".pt"):
                stem = stem[:-len(".pt")]
            class_name = stem.split('_')[0]
            self.chunk_index_pairs.extend(
                (path, class_name, n) for n in range(chunk_counts[stem])
            )

    def load_cached_tensor(self, file_path):
        """Return ``torch.load(file_path)`` (on CPU), memoised in an LRU cache."""
        if file_path in self.cache:
            self.cache.move_to_end(file_path)  # mark as most recently used
            return self.cache[file_path]
        tensor = torch.load(file_path, map_location="cpu")
        self.cache[file_path] = tensor
        # Evict least-recently-used entries. Return the local reference so a cache
        # size of 0 cannot evict the tensor before it is handed back.
        while len(self.cache) > max(self.max_cache_size, 0):
            self.cache.popitem(last=False)
        return tensor

    def __len__(self):
        return len(self.chunk_index_pairs)

    def __getitem__(self, idx):
        file_path, label, chunk_index = self.chunk_index_pairs[idx]
        # Validate the label before paying for a (possibly uncached) file load.
        if label not in self.label_to_idx:
            print(f"[WARNING] Label '{label}' not found in label_to_idx for file {file_path}")
            raise ValueError(f"Label '{label}' not found in label_to_idx")
        label_index = self.label_to_idx[label]

        chunk = self.load_cached_tensor(file_path)[chunk_index].to(torch.float32)
        if chunk.dim() == 2:
            # (H, W) -> (1, H, W) so the channel handling below applies.
            chunk = chunk.unsqueeze(0)
        if chunk.size(0) == 1:
            # The pretrained ViT takes 3-channel input; replicate the single channel.
            chunk = chunk.repeat(3, 1, 1)
        # interpolate needs a batch dim: (C, H, W) -> (1, C, H, W) -> (C, *image_size).
        chunk = F.interpolate(chunk.unsqueeze(0), size=self.image_size,
                              mode='bilinear', align_corners=False).squeeze(0)
        return chunk, label_index


In [13]:
import os

# Split the processed .pt files into train/test path lists and give every training
# class a contiguous integer index. The class name is the part of the file name
# before the first '_' (e.g. 'owl_train.pt' -> 'owl'), the same rule the dataset uses.
PROCESSED_AUDIO_DIR = './data/processed_train_audio'

train_files, test_files = [], []
label_to_idx = {}
for file in sorted(os.listdir(PROCESSED_AUDIO_DIR)):
    path = os.path.join(PROCESSED_AUDIO_DIR, file)
    class_name = file.split('_')[0]
    if '_train' in file:
        train_files.append(path)
        # Next free index in sorted first-seen order. Unlike `idx // 2`, this does not
        # assume exactly one test and one train file per class sitting side by side.
        label_to_idx.setdefault(class_name, len(label_to_idx))
    elif '_test' in file:
        test_files.append(path)

# Fail here, not mid-evaluation: every test class needs a training index.
assert train_files, f"no '*_train' files found in {PROCESSED_AUDIO_DIR}"
test_only_classes = {os.path.basename(p).split('_')[0] for p in test_files} - label_to_idx.keys()
assert not test_only_classes, f"test classes without training files: {sorted(test_only_classes)}"


In [14]:
# Build both splits with one shared class-index mapping so integer targets agree.
train_dataset, test_dataset = (
    ViTSpectogramDataset(split_files, label_to_idx)
    for split_files in (train_files, test_files)
)


In [None]:
# Sanity-check both splits and one decoded sample before committing to a long run.
print("Train dataset size:", len(train_dataset))
print("Test dataset size:", len(test_dataset))
# NOTE(review): a recorded run printed identical train and test sizes; confirm the split is intended.
assert len(train_dataset) > 0 and len(test_dataset) > 0, "empty dataset split"

sample, label_index = train_dataset[0]
print("Sample shape:", sample.shape)
print("Label index:", label_index)

# Fail fast rather than hours into training: the ViT below takes 3x224x224 input,
# and CrossEntropyLoss needs class targets in [0, num_classes).
assert tuple(sample.shape) == (3, 224, 224), f"unexpected sample shape {tuple(sample.shape)}"
assert 0 <= label_index < len(label_to_idx), (
    f"label index {label_index} outside [0, {len(label_to_idx)})"
)


Train dataset size: 60900
Test dataset size: 60900
Sample shape: torch.Size([3, 224, 224])
Label index: 0


In [31]:
from torch.utils.data import DataLoader

# Batch the datasets. Loading stays in the notebook process (num_workers=0), so
# each dataset's per-file LRU cache is shared by every batch. Pinned host memory
# speeds up host-to-GPU copies (and permits non_blocking transfers) on CUDA.
# NOTE(review): a recorded run took ~10 s per batch; with shuffle=True most samples
# likely come from files outside the small per-dataset cache, so file loading (not
# the model) is the probable bottleneck - confirm with a profiler.
PIN_MEMORY = device == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=0, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=0, pin_memory=PIN_MEMORY)


In [17]:
import timm
import torch.nn as nn

num_classes = len(label_to_idx)
assert num_classes > 0, "label_to_idx is empty - no training classes were found"

# ImageNet-pretrained ViT-B/16 at 224x224 input, matching the dataset's resize target.
model = timm.create_model("vit_base_patch16_224", pretrained=True)
# reset_classifier installs a fresh Linear(embed_dim, num_classes) head and also
# updates model.num_classes; a bare `model.head = ...` leaves that attribute stale.
model.reset_classifier(num_classes)
model = model.to(device)


  from .autonotebook import tqdm as notebook_tqdm


In [18]:
# Criterion and optimizer for fine-tuning, reusing the config-cell LOSS and WEIGHT_DECAY.
# NOTE(review): the fine-tuning LR (1e-4) differs from INIT_LR in the config cell;
# confirm which one is intended.
FINETUNE_LR = 1e-4
loss_fn = LOSS
optimizer = torch.optim.AdamW(model.parameters(), lr=FINETUNE_LR, weight_decay=WEIGHT_DECAY)


In [30]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import torch

from tqdm.auto import tqdm  # widget progress bars in Jupyter, plain text elsewhere

best_f1 = float("-inf")  # so the first evaluated epoch is always checkpointed
best_model_path = "best_vit_model.pth"


def _train_one_epoch(model, loader, loss_fn, optimizer, device, desc):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    for data, target in tqdm(loader, desc=desc):
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)  # guard against an empty loader


def _predict(model, loader, device, desc):
    """Return (predicted class indices, true targets) for every sample in ``loader``."""
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for data, target in tqdm(loader, desc=desc):
            output = model(data.to(device, non_blocking=True))
            preds.extend(output.argmax(dim=1).cpu().numpy())
            targets.extend(target.cpu().numpy())
    return preds, targets


for epoch in range(EPOCHS):
    avg_train_loss = _train_one_epoch(model, train_loader, loss_fn, optimizer, device,
                                      f"Training epoch {epoch}")
    print(f"Epoch {epoch} Train loss: {avg_train_loss:.4f}")

    # NOTE(review): the checkpoint is selected on the test split, so the reported
    # best F1 is optimistically biased; a separate validation split would avoid this.
    all_preds, all_targets = _predict(model, test_loader, device, f"Evaluating epoch {epoch}")
    acc = accuracy_score(all_targets, all_preds)
    prec = precision_score(all_targets, all_preds, average='macro', zero_division=0)
    rec = recall_score(all_targets, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)
    print(f"Epoch {epoch}: Accuracy={acc:.4f}, Precision={prec:.4f}, Recall={rec:.4f}, F1={f1:.4f}")

    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), best_model_path)
        print(f"Best model saved at epoch {epoch} with F1={f1:.4f}")


Training epoch 0:  16%|█▌        | 152/952 [26:17<2:18:25, 10.38s/it]


KeyboardInterrupt: 