In [2]:
import torch
import torch.nn as nn               
import torch.nn.functional as F       
import torch.optim as optim            
from torch.utils.data import Dataset, DataLoader 

#
import torchaudio
import torchaudio.transforms as T   
import numpy as np                     
import pandas as pd                   

import matplotlib.pyplot as plt

import os
from collections import OrderedDict, defaultdict

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [None]:
# Training hyperparameters. Everything a reader might want to tune lives here.
EPOCHS = 25
INIT_LR = 1e-3
WEIGHT_DECAY = 1e-2
BATCH_SIZE = 512

# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
LOSS = torch.nn.BCEWithLogitsLoss()
# Lower-case alias for code that refers to the criterion as `loss_fn`.
loss_fn = LOSS


def build_optimizer(module):
    """Return an AdamW optimizer over `module`'s parameters.

    Uses the notebook-level `INIT_LR` and `WEIGHT_DECAY` so the hyperparameters
    are defined in exactly one place.

    Args:
        module: any object exposing `.parameters()` (e.g. an `nn.Module`).

    Returns:
        A `torch.optim.AdamW` instance bound to `module.parameters()`.
    """
    return optim.AdamW(module.parameters(), lr=INIT_LR, weight_decay=WEIGHT_DECAY)


# `model` is only created in a later cell, so on a fresh kernel it does not
# exist when this cell runs. Bind the optimizer here only if a model is already
# present (i.e. this cell is re-run after the model cell); otherwise leave it
# unset so this cell no longer fails with a NameError.
# NOTE: an optimizer must be built *after* the model's final layers are in
# place, otherwise it will not track the replaced parameters.
optimizer = build_optimizer(model) if "model" in globals() else None

In [3]:
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import os
import json
from collections import OrderedDict
import torchvision.transforms as T

class ViTSpectogramDataset(Dataset):
    """Dataset of spectrogram chunks stored in ``.pt`` files, sized for a ViT.

    Each file holds a tensor whose first dimension indexes chunks. Every
    (file, chunk) pair is one sample; the sample's label is the part of the
    file name before the first underscore.

    Args:
        audio_dir: iterable of paths to ``.pt`` files (despite the name, this
            is iterated as a collection of file paths, not a directory).
        label_to_idx: mapping from label string to its index in the target
            vector.
        max_cache_size: maximum number of whole file tensors kept in the
            in-memory LRU cache. Values <= 0 disable caching.
        chunk_counts_path: JSON file mapping each file stem (file name without
            ``.pt``) to the number of chunks stored in that file.
        image_size: ``(height, width)`` every chunk is resized to.
    """

    def __init__(self, audio_dir, label_to_idx, max_cache_size=5,
                 chunk_counts_path='./dataset_init.json', image_size=(224, 224)):
        self.label_to_idx = label_to_idx
        self.audio_dir = audio_dir
        self.chunk_index_pairs = []
        self.cache = OrderedDict()
        self.max_cache_size = max_cache_size
        self.image_size = tuple(image_size)

        with open(chunk_counts_path, 'r') as file:
            chunk_counts = json.load(file)

        for path in self.audio_dir:
            stem = os.path.basename(path).replace(".pt", "")
            # The label is the leading token of the stem, e.g. "abc_train" -> "abc".
            label = stem.split('_')[0]
            self.chunk_index_pairs.extend(
                (path, label, n) for n in range(chunk_counts[stem])
            )

        # ViT expects 224x224 and 3 channels.
        # Kept for backward compatibility; __getitem__ resizes with F.interpolate.
        self.resize = T.Resize(self.image_size)

    def load_cached_tensor(self, file_path):
        """Return the tensor stored at `file_path`, using the LRU cache.

        A hit marks the entry as most recently used. A miss loads the file,
        stores it, and evicts the least recently used entry if the cache has
        grown beyond `max_cache_size`.
        """
        if file_path in self.cache:
            self.cache.move_to_end(file_path)
            return self.cache[file_path]

        # NOTE(review): torch.load unpickles the file; only point this at
        # tensors you produced yourself.
        tensor = torch.load(file_path)
        self.cache[file_path] = tensor
        if len(self.cache) > self.max_cache_size:
            self.cache.popitem(last=False)
        # Return the local reference: with max_cache_size <= 0 the entry that
        # was just inserted is the one evicted, so it is no longer in the cache.
        return tensor

    def __len__(self):
        """Number of (file, chunk) samples."""
        return len(self.chunk_index_pairs)

    def __getitem__(self, idx):
        """Return ``(chunk, target)`` for sample `idx`.

        Returns:
            chunk: float32 tensor of shape ``[C, *image_size]``; single-channel
                (or 2-D) chunks are expanded to 3 channels.
            target: float32 one-hot vector of length ``len(label_to_idx)``.

        Raises:
            KeyError: if the sample's label is not present in `label_to_idx`.
        """
        file_path, label, chunk_index = self.chunk_index_pairs[idx]
        if label not in self.label_to_idx:
            raise KeyError(
                f"label {label!r} (from {file_path!r}) is not in label_to_idx"
            )

        tensor = self.load_cached_tensor(file_path)
        chunk = tensor[chunk_index].to(torch.float32)  # [C, H, W], [1, H, W] or [H, W]

        # A bare [H, W] spectrogram gets an explicit channel dimension first.
        if chunk.dim() == 2:
            chunk = chunk.unsqueeze(0)
        # Ensure 3 channels
        if chunk.size(0) == 1:
            chunk = chunk.repeat(3, 1, 1)
        # Resize to the target resolution (interpolate needs a batch dimension).
        chunk = F.interpolate(
            chunk.unsqueeze(0),
            size=self.image_size,
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)

        label_tensor = torch.zeros(len(self.label_to_idx), dtype=torch.float32)
        label_tensor[self.label_to_idx[label]] = 1.0
        return chunk, label_tensor


In [4]:
import os

# Directory containing the processed tensor files for both splits.
PROCESSED_AUDIO_DIR = './data/processed_train_audio'

# Split the files into train/test lists by the `_train` / `_test` marker in
# their names; files carrying neither marker are ignored.
train_files, test_files = [], []
# os.listdir() returns entries in arbitrary order, so sort them to keep the
# sample order (and therefore dataset indices) reproducible across runs.
for file in sorted(os.listdir(PROCESSED_AUDIO_DIR)):
    path = f'{PROCESSED_AUDIO_DIR}/{file}'
    if '_train' in file:
        train_files.append(path)
    elif '_test' in file:
        test_files.append(path)


In [6]:
import pandas as pd
# Build the label vocabulary from the metadata table: every distinct
# `primary_label` (as a string), in sorted order, gets a consecutive index.
metadata = pd.read_csv("./data/processed_data.csv")
unique_labels = sorted(set(metadata["primary_label"].astype(str)))
label_to_idx = dict(zip(unique_labels, range(len(unique_labels))))
print(len(label_to_idx))

tensor_dir = "./data/processed_train_audio/"

206


In [9]:
from torch.utils.data import DataLoader

# One dataset per split; both share the same label -> index mapping.
train_dataset = ViTSpectogramDataset(audio_dir=train_files, label_to_idx=label_to_idx)
test_dataset = ViTSpectogramDataset(audio_dir=test_files, label_to_idx=label_to_idx)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=BATCH_SIZE)


In [None]:
import timm
import torch.nn as nn

# One output logit per known label.
num_classes = len(label_to_idx)

# Start from the pretrained ViT and swap its classification head for a fresh
# linear layer sized to our label set, then move everything to `device`.
model = timm.create_model(model_name="vit_base_patch16_224", pretrained=True)
model.head = nn.Linear(
    in_features=model.head.in_features,
    out_features=num_classes,
)
model = model.to(device)


In [None]:
import torch
# Train for EPOCHS epochs, evaluating on the test split after each one.
#
# The optimizer is built here, after the classification head has been replaced
# and the model moved to `device`, so it tracks exactly the parameters that are
# trained below (an optimizer created earlier would not see the new head).
optimizer = optim.AdamW(model.parameters(), lr=INIT_LR, weight_decay=WEIGHT_DECAY)

for epoch in range(EPOCHS):
    # ---- TRAINING ----
    model.train()
    train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)  # shape: [batch, num_classes]

        # `LOSS` is the criterion defined with the hyperparameters.
        loss = LOSS(output, target)  # target must be float32 multi-hot
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Periodic progress report (shapes are logged as a sanity check).
        if batch_idx % 100 == 0:
            print("Output shape:", output.shape)
            print("Target shape:", target.shape)
            print(f"Epoch {epoch} [{batch_idx}/{len(train_loader)}] Loss: {loss.item():.4f}")

    # Mean loss per batch over the epoch.
    train_loss /= len(train_loader)

    # ---- EVALUATION ----
    model.eval()

    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = LOSS(output, target)
            test_loss += loss.item()

            # Multi-label prediction (threshold at 0.5)
            pred = (torch.sigmoid(output) > 0.5).float()

            # Count exact matches (all labels correct)
            correct += (pred == target).all(dim=1).sum().item()
            total += data.size(0)

    test_loss /= len(test_loader)
    accuracy = 100. * correct / total

    print(f"\nEpoch: {epoch}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.0f}%)")