In [1]:
import re
import argparse
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
import torchaudio.transforms as T

import torch
from torch import nn
from torch.nn import functional as F

import torchaudio

from tqdm.auto import tqdm

from sklearn.cluster import KMeans
from collections import Counter
os.environ["NCCL_DEBUG"] = "INFO"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MusicResNet18(nn.Module):
    def __init__(self, num_classes):
        super(MusicResNet18, self).__init__()

        self.resnet = models.resnet18(pretrained=True)
        self.fc = nn.Linear(1000, num_classes)

    def forward(self, x, embeddings=False):
        x = x.expand(x.size(0), 3, x.size(2), x.size(3))  
        out = self.resnet(x)

        if embeddings:
            return out
        else:
            lastlayer = self.fc(out)
            return lastlayer

In [3]:
class AudioDataset(Dataset):
    def __init__(self, csv_file, phase = 'TRAINING', target_sr=16000, device='cuda:1',chunk_dur = 30):
        self.data = pd.read_csv(csv_file, index_col=0)

        self.target_sr = target_sr
        self.chunk_dur = chunk_dur
        self.chunk_len = int(self.target_sr*self.chunk_dur)
        self.device = torch.device(device)
        self.phase = phase.upper()
        
        self.data = self.data[~self.data['panns'].isin(['Silence', 'Speech'])]
        self.data = self.data[~self.data['Singer'].isin(['Unlabelled', 'Other'])]

        if self.phase == 'TRAINING':
            self.data = self.data[self.data['Labeled'].isin([1])]
        else:
            self.data = self.data[self.data['Labeled'].isin([0])]

        self.singers = self.data['Singer'].unique()
        self.singer_to_idx = {singer: idx for idx, singer in enumerate(self.singers)}
        self.p_labels = np.zeros(len(self.data), dtype=int)

        self.data['label_idx'] = 0
        self.data['label_idx'] = self.data['Singer'].replace(self.singer_to_idx) 

        self.get_mel = T.MelSpectrogram(
            sample_rate=self.target_sr,
            n_fft=1280,
            n_mels=128,
            f_min=15,
            f_max=14000,
            hop_length=400,
            power=1.5
        ).to(self.device)

    def _fix_length(self, audio):
        if audio.shape[1] < self.chunk_len:
            pad_amount = self.chunk_len - audio.shape[1]
            audio = torch.nn.functional.pad(audio, (0, pad_amount))
        else:
            audio = audio[:, :self.chunk_len]
        return audio

    def _preprocess(self, audio, sr):
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0).unsqueeze(0)
        if sr != self.target_sr:
            resampler = T.Resample(orig_freq=sr, new_freq=self.target_sr).to(self.device)
            audio = resampler(audio)
        audio = self._fix_length(audio)
        mel_spec = self.get_mel(audio)
        return audio, mel_spec
        
    def __len__(self):
        return len(self.data)
    
    def calculate_class_weights(self):
        # Calculate class weights based on the true label distribution
        label_counts = self.data['label_idx'].value_counts().sort_index().values
        total_count = len(self.data)
        self.class_weights = total_count / (len(label_counts) * label_counts)
    
              
    def update_plabels(self, X, num_clusters = 69):

        print('Updating pseudo-labels...')
        kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(X)
        cluster_labels = kmeans.labels_

        valid_indices = np.arange(len(self.data))
        cluster_labels = cluster_labels[valid_indices]
        
        # Assign labels to clusters based on majority vote
        for cluster in range(num_clusters):
            indices = np.where(cluster_labels == cluster)[0]
            if len(indices) == 0:
                continue

            true_labels = self.data.iloc[indices]['label_idx'].values
            if len(true_labels) > 0:
                most_common_label = Counter(true_labels).most_common(1)[0][0]
                self.p_labels[indices] = most_common_label
            else:
                self.p_labels[indices] = -1

        p_label_counts = np.bincount(self.p_labels)
        total_count = len(self.p_labels)
        self.p_weights = total_count / (len(p_label_counts) * p_label_counts) 

    
    def __getitem__(self, idx):
        audio_path = self.data.iloc[idx]['Audio_Path']
        label_orig = self.data.iloc[idx]['Singer']
        label = self.singer_to_idx[label_orig]
        waveform, sample_rate = torchaudio.load(audio_path)

        audios, mel = self._preprocess(audio = waveform.to(self.device), sr = sample_rate)
        
        return idx, mel, label_orig, self.p_labels[idx], label

In [None]:
train_dataset = AudioDataset('/home/surge_siya/ALL CSV/Panns_CSV.csv', phase = 'training')
test_dataset = AudioDataset('/home/surge_siya/ALL CSV/Panns_CSV.csv', phase = 'testing')
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [None]:
device = 'cuda:1'
num_classes = len(train_dataset.singers)
model = MusicResNet18(num_classes=num_classes).to(device)
num_epochs = 31
num_clusters = 69
lr = 0.001

In [6]:
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1,reduction='none').to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,max_lr = 0.01, total_steps=num_epochs*len(train_loader))

In [7]:
device = next(model.parameters()).device
print("Model is currently on:", device)

Model is currently on: cuda:1


In [8]:
def extract_features(loader,model,d=1000, device='cuda:2',num_instances=50000):#d: embedding dimension
    print('Extracting Features')
    device = torch.device(device)
    model = model.to(device)
    X = torch.zeros((num_instances,d)).to(device)
    model.eval()
    pbar = tqdm(loader)
    with torch.no_grad():
        for idx, mel, _,_,_ in pbar:
#             print(mel.shape)
            emb = model(mel.to(device),embeddings=True) 
#             print(idx.shape, emb.shape)
            X[idx,:] = emb#.cpu().numpy()
        X = np.array(X.cpu())
    return X.astype(np.float32)

In [None]:
pbar = tqdm(range(1, num_epochs + 1))
optimizer.zero_grad()

for epoch in pbar:
    model.eval()
    X = extract_features(loader=train_loader, model=model)
    train_dataset.update_plabels(X, num_clusters=num_clusters)
    model.train()
    model = model.to(device)
    total_loss = 0
    
    for batch_idx, (idx, mel, gt, _, target) in enumerate(tqdm(train_loader)):
        mel = mel.to(device)
        target = torch.tensor(train_dataset.p_labels[idx]).to(device)
        pred = model(mel)
        # tau = torch.Tensor(train_dataset.p_weights)
        # w = torch.Tensor(train_dataset.class_weights[target])
        loss_ce = criterion(pred, target)
        total_loss += loss_ce.mean()
        loss_ce.mean().backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    print(f'loss after epoch {epoch} is {total_loss/15}')
    total_loss = 0

In [12]:
from sklearn.metrics import accuracy_score, classification_report

def test_model(model, test_loader, device='cuda:1'):
    """
    Function to test the model on the test dataset and compute accuracy.

    Parameters:
    - model: Trained PyTorch model
    - test_loader: DataLoader for the test dataset
    - device: Device to run the model on (default: 'cuda:1')

    Returns:
    - accuracy: Accuracy of the model on the test set
    """
    model.eval()
    model = model.to(device)
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for idx, mel, _, _, target in tqdm(test_loader):
            mel = mel.to(device)
            pred = model(mel)
            preds = torch.argmax(pred, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_targets.extend(target.numpy())

    accuracy = accuracy_score(all_targets, all_preds)
    report = classification_report(all_targets, all_preds)
    print(report)
    return accuracy

In [13]:
train_accuracy = test_model(model, train_loader, device='cuda:1')

100%|██████████| 15/15 [00:01<00:00,  9.45it/s]

              precision    recall  f1-score   support

           0       0.44      0.67      0.53         6
           1       0.38      0.71      0.50         7
           2       0.50      0.83      0.62         6
           3       1.00      0.86      0.92         7
           4       0.25      0.57      0.35         7
           5       1.00      0.29      0.44         7
           6       0.19      0.50      0.27         6
           7       0.26      0.71      0.38         7
           8       1.00      0.83      0.91         6
           9       0.20      0.83      0.32         6
          10       0.27      0.57      0.36         7
          11       1.00      0.71      0.83         7
          12       1.00      0.29      0.44         7
          13       0.57      0.57      0.57         7
          14       0.50      0.14      0.22         7
          15       0.67      0.86      0.75         7
          16       0.18      0.33      0.24         6
          17       0.50    


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [14]:
test_accuracy = test_model(model, test_loader, device='cuda:1')

100%|██████████| 8/8 [00:00<00:00,  9.04it/s]

              precision    recall  f1-score   support

           0       0.50      1.00      0.67         1
           1       0.00      0.00      0.00         1
           2       0.25      1.00      0.40         1
           3       1.00      1.00      1.00         1
           4       0.06      1.00      0.12         1
           5       0.00      0.00      0.00         1
           6       0.00      0.00      0.00         1
           7       0.11      1.00      0.20         1
           8       1.00      1.00      1.00         1
           9       0.00      0.00      0.00         1
          10       0.00      0.00      0.00         1
          11       0.00      0.00      0.00         1
          12       0.00      0.00      0.00         1
          13       0.50      1.00      0.67         1
          14       1.00      1.00      1.00         1
          15       0.09      1.00      0.17         1
          16       0.00      0.00      0.00         2
          17       0.00    


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
