In [1]:
import numpy as np
import os
from torch.utils.data import Dataset
import librosa
import torch
import matplotlib.pyplot as plt
from torch import nn
from tqdm import tqdm
import torchvision
from sklearn.metrics import roc_curve


from torchvision.models import resnet34

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f"Using {device}")

Using cuda


In [2]:
def pad_random(x: np.ndarray, max_len: int = 64000):
    """Return a random window of exactly ``max_len`` entries along axis 0 of ``x``.

    Inputs longer than ``max_len`` are cropped at a random offset. Inputs of
    length ``max_len`` or shorter are first repeated end-to-end along axis 0
    until they are strictly longer than ``max_len`` and are then cropped the
    same way, so the result is a randomly shifted, looped copy of the signal.

    Args:
        x: Array whose first axis is the one to crop/repeat (e.g. a mono
            waveform of shape ``(n_samples,)``).
        max_len: Length of the returned window along axis 0.

    Returns:
        An array with ``shape[0] == max_len``; trailing axes are unchanged.

    Raises:
        ValueError: If ``x`` has no entries along axis 0, since there is
            nothing to repeat.
    """
    x_len = x.shape[0]
    if x_len == 0:
        raise ValueError("pad_random() received an empty array; nothing to crop or repeat")

    if x_len <= max_len:
        # Repeat along axis 0 only. The repeat count guarantees the tiled
        # length exceeds max_len, so the random offset below is always valid.
        num_repeats = max_len // x_len + 1
        x = np.tile(x, (num_repeats,) + (1,) * (x.ndim - 1))
        x_len = x.shape[0]

    # Offset is drawn from NumPy's global RNG, in [0, x_len - max_len).
    stt = np.random.randint(x_len - max_len)
    return x[stt:stt + max_len]

class SVDD2024(Dataset):
    """
    Dataset class for the SVDD 2024 dataset.

    Audio is read from ``<base_dir>/<partition>_set`` and the list of
    utterances from ``<base_dir>/<partition>.txt``. For the "test" partition
    the protocol file is optional: when it is missing, every ``*.flac`` file
    found under the audio directory is used instead.

    Each item is a tuple ``(features, label, file_name)`` where ``features``
    is the MFCC matrix of a ``max_len``-sample excerpt with a leading channel
    axis, ``label`` is 1 for "bonafide" and 0 otherwise (always 0 for the
    test partition), and ``file_name`` is the utterance name without the
    ".flac" extension.
    """
    def __init__(self, base_dir, partition="train", max_len=64000):
        """
        Args:
            base_dir: Dataset root holding the ``*_set`` folders and protocol files.
            partition: One of "train", "dev" or "test".
            max_len: Number of audio samples each excerpt is cropped/looped to.

        Raises:
            FileNotFoundError: If the protocol file of a non-test partition is missing.
        """
        assert partition in ["train", "dev", "test"], "Invalid partition. Must be one of ['train', 'dev', 'test']"
        self.partition = partition
        # From here on base_dir points at the audio folder; the protocol file
        # below is still looked up in the dataset root passed by the caller.
        self.base_dir = os.path.join(base_dir, partition + "_set")
        self.max_len = max_len

        self.transforms = torchvision.transforms.Compose([torchvision.transforms.Resize((224,224))])

        try:
            with open(os.path.join(base_dir, f"{partition}.txt"), "r") as f:
                # Skip blank lines (e.g. a trailing newline) so that every
                # entry can be parsed in __getitem__.
                self.file_list = [line for line in f if line.strip()]
        except FileNotFoundError:
            if partition != "test":
                raise FileNotFoundError(f"File {partition}.txt not found in {base_dir}")
            self.file_list = self._find_flac_files()

    def _find_flac_files(self):
        """List every ``*.flac`` under ``self.base_dir``, extension removed."""
        suffix = ".flac"
        found = []
        for root, _, files in os.walk(self.base_dir):
            for file in files:
                if file.endswith(suffix):
                    # __getitem__ appends ".flac" and joins with base_dir, so
                    # store the path relative to base_dir without the suffix.
                    rel_path = os.path.relpath(os.path.join(root, file), self.base_dir)
                    found.append(rel_path[:-len(suffix)])
        # os.walk order depends on the filesystem; sort for reproducibility.
        return sorted(found)

    def __len__(self):
        """Number of utterances in this partition."""
        return len(self.file_list)

    def __getitem__(self, index):
        """
        Load one utterance and return ``(features, label, file_name)``.

        Returns ``None`` (after printing the error) when the audio cannot be
        loaded or processed. Callers must filter such items out; the default
        DataLoader collation does not accept ``None``.
        """
        if self.partition == "test":
            file_name = self.file_list[index].strip()
            label = 0 # dummy label. Not used for test set.
        else:
            # Protocol line: third space-separated field is the file name,
            # the last field is "bonafide" for genuine audio.
            fields = self.file_list[index].split(" ")
            file_name = fields[2].strip()
            bonafide_or_spoof = fields[-1].strip()
            label = 1 if bonafide_or_spoof == "bonafide" else 0
        try:
            x, _ = librosa.load(os.path.join(self.base_dir, file_name + ".flac"), sr=16000, mono=True)
            x = pad_random(x, self.max_len)  # random max_len-sample excerpt
            x = librosa.util.normalize(x)

            x = librosa.feature.mfcc(y=x, sr=16000, n_mfcc=44,hop_length=160,win_length=320)

            # file_name is used for generating the score file for submission
            return torch.unsqueeze(torch.tensor(x),dim=0), label, file_name

        except Exception as e:
            print(f"Error loading {file_name}: {e}")
            return None

In [3]:
# Both splits share the same dataset root.
# Note that `test_ds` wraps the labelled "dev" partition, despite its name.
train_ds = SVDD2024(base_dir='./temp/ds/', partition='train')
test_ds = SVDD2024(base_dir='./temp/ds/', partition='dev')

In [4]:
# Shuffle the training stream so batches are not served in protocol-file
# order on every epoch; the evaluation stream keeps its fixed order.
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=4)

In [5]:
# Grab the feature tensor of one training batch to check its shape
# (each batch is a (features, labels, file_names) triple).
im = next(iter(train_loader))[0]
im.shape

torch.Size([16, 1, 44, 401])

In [6]:
# ResNet-34 built without requesting pretrained weights. The stem convolution
# is replaced so the network accepts the single-channel MFCC input.
model = resnet34()
model.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    bias=False,
)

In [7]:
# Replace the classification head with a single output logit (bonafide vs. spoof).
model.fc = nn.Linear(in_features=512, out_features=1, bias=True)

### Model Testing

In [8]:
# Smoke-test the forward pass on the sample batch; no gradients are needed.
# NOTE(review): no model.eval() call precedes this cell, so the network runs
# in its default (training) mode here.
with torch.inference_mode():
    out = model(im)

out.shape

torch.Size([16, 1])

In [13]:
# Turn the raw logits into per-sample probabilities, flattened to one value per item.
out.sigmoid().detach().squeeze()

tensor([0.5659, 0.6237, 0.5919, 0.6035, 0.5649, 0.5832, 0.5538, 0.5650, 0.6342,
        0.5868, 0.6292, 0.5994, 0.6193, 0.5587, 0.5936, 0.6003])

### EER Functions

In [9]:
def compute_det_curve(target_scores, nontarget_scores):
    """Compute the points of a detection error trade-off (DET) curve.

    Every observed score is used as a decision threshold in ascending order.
    For the k-th smallest score, ``frr`` is the fraction of target scores at
    or below it and ``far`` is the fraction of nontarget scores above it. An
    extra leading point (``frr=0``, ``far=1``) is added for a threshold
    placed 0.001 below the smallest score.

    Args:
        target_scores: Scores of the target (positive) trials; any 1-D
            array-like such as a list, tuple or NumPy array.
        nontarget_scores: Scores of the nontarget (negative) trials; same
            accepted types.

    Returns:
        Tuple ``(frr, far, thresholds)`` of NumPy arrays, each of length
        ``len(target_scores) + len(nontarget_scores) + 1``.

    Note:
        If one of the two score sets is empty the corresponding rate is
        undefined and comes out as NaN; if both are empty there is no score
        to build a threshold from and an IndexError is raised.
    """
    # Coerce to flat NumPy arrays so that lists/tuples work as well as ndarrays.
    target_scores = np.asarray(target_scores).ravel()
    nontarget_scores = np.asarray(nontarget_scores).ravel()

    n_scores = target_scores.size + nontarget_scores.size
    all_scores = np.concatenate((target_scores, nontarget_scores))
    labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size)))

    # Sort labels based on scores (stable sort keeps the order of tied scores)
    indices = np.argsort(all_scores, kind='mergesort')
    labels = labels[indices]

    # Compute false rejection and false acceptance rates:
    # tar_trial_sums[k] = targets among the k+1 lowest scores (rejected),
    # nontarget_trial_sums[k] = nontargets above that point (accepted).
    tar_trial_sums = np.cumsum(labels)
    nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums)

    frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size))  # false rejection rates
    far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size))  # false acceptance rates
    thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))  # Thresholds are the sorted scores

    return frr, far, thresholds


def compute_eer(target_scores, nontarget_scores):
    """Return ``(eer, threshold)``: the equal error rate and where it occurs.

    The EER is taken at the DET-curve point where the false rejection and
    false acceptance rates are closest, and is reported as the mean of the
    two rates at that point.
    """
    frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
    # Index of the operating point with the smallest |FRR - FAR| gap.
    eer_index = np.abs(frr - far).argmin()
    eer = np.mean([frr[eer_index], far[eer_index]])
    return eer, thresholds[eer_index]


In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm

def compute_det_curve(target_scores, nontarget_scores):
    """Compute the points of a detection error trade-off (DET) curve.

    Every observed score is used as a decision threshold in ascending order.
    For the k-th smallest score, ``frr`` is the fraction of target scores at
    or below it and ``far`` is the fraction of nontarget scores above it. An
    extra leading point (``frr=0``, ``far=1``) is added for a threshold
    placed 0.001 below the smallest score.

    Args:
        target_scores: Scores of the target (positive) trials; any 1-D
            array-like such as a list, tuple or NumPy array.
        nontarget_scores: Scores of the nontarget (negative) trials; same
            accepted types.

    Returns:
        Tuple ``(frr, far, thresholds)`` of NumPy arrays, each of length
        ``len(target_scores) + len(nontarget_scores) + 1``.

    Note:
        If one of the two score sets is empty the corresponding rate is
        undefined and comes out as NaN; if both are empty there is no score
        to build a threshold from and an IndexError is raised.
    """
    # Coerce to flat NumPy arrays so that lists/tuples work as well as ndarrays.
    target_scores = np.asarray(target_scores).ravel()
    nontarget_scores = np.asarray(nontarget_scores).ravel()

    n_scores = target_scores.size + nontarget_scores.size
    all_scores = np.concatenate((target_scores, nontarget_scores))
    labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size)))

    # Sort labels based on scores (stable sort keeps the order of tied scores)
    indices = np.argsort(all_scores, kind='mergesort')
    labels = labels[indices]

    # Compute false rejection and false acceptance rates:
    # tar_trial_sums[k] = targets among the k+1 lowest scores (rejected),
    # nontarget_trial_sums[k] = nontargets above that point (accepted).
    tar_trial_sums = np.cumsum(labels)
    nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums)

    frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size))  # false rejection rates
    far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size))  # false acceptance rates
    thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))  # Thresholds are the sorted scores

    return frr, far, thresholds

def compute_eer(target_scores, nontarget_scores):
    """Return ``(eer, threshold)``: the equal error rate and where it occurs.

    The EER is taken at the DET-curve point where the false rejection and
    false acceptance rates are closest, and is reported as the mean of the
    two rates at that point.
    """
    frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
    # Index of the operating point with the smallest |FRR - FAR| gap.
    eer_index = np.abs(frr - far).argmin()
    eer = np.mean([frr[eer_index], far[eer_index]])
    return eer, thresholds[eer_index]

def _run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Run one pass over ``loader``; trains if ``optimizer`` is given, else evaluates.

    Returns ``(mean_loss, scores, labels)`` where ``scores`` are the sigmoid
    outputs and ``labels`` the float targets, both as 1-D NumPy arrays.
    """
    is_training = optimizer is not None
    model.train(is_training)

    batch_losses, batch_scores, batch_labels = [], [], []
    for x, y, _ in tqdm(loader):
        x = x.to(device)
        y = y.to(device).type(torch.float32)

        # Evaluation runs under inference mode; mode=False makes this a no-op.
        with torch.inference_mode(mode=not is_training):
            # reshape(-1) rather than squeeze(): a batch holding one sample
            # must stay 1-D so it matches y and can be concatenated later.
            logits = model(x).reshape(-1)
            loss = loss_fn(logits, y)

        if is_training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_losses.append(loss.item())
        # Keep per-batch results on the CPU and concatenate once at the end
        # instead of re-copying a growing device tensor on every step.
        batch_scores.append(torch.sigmoid(logits).detach().cpu())
        batch_labels.append(y.detach().cpu())

    if not batch_losses:
        # Empty loader: nothing to average and no scores to report.
        empty = torch.empty(0).numpy()
        return float("nan"), empty, empty

    mean_loss = sum(batch_losses) / len(batch_losses)
    return mean_loss, torch.cat(batch_scores).numpy(), torch.cat(batch_labels).numpy()


def _eer_from_labels(scores, labels):
    """Split ``scores`` by label (1 = target, 0 = nontarget) and compute the EER."""
    target_scores = scores[labels == 1]
    nontarget_scores = scores[labels == 0]
    if target_scores.size == 0 or nontarget_scores.size == 0:
        # The EER is undefined unless both classes are present.
        return float("nan"), float("nan")
    return compute_eer(target_scores, nontarget_scores)


def train_model(model, train_loader, test_loader, epochs, device):
    """Train ``model`` as a binary classifier and report loss and EER per epoch.

    The model is optimised with Adam (default settings) on
    ``BCEWithLogitsLoss``, so it must emit one raw logit per sample. After
    every epoch the mean training/validation loss and the equal error rate
    of both sets are printed.

    Args:
        model: Network mapping a batch of inputs to one logit per sample.
        train_loader: Iterable of ``(inputs, labels, names)`` training batches.
        test_loader: Iterable of ``(inputs, labels, names)`` validation batches.
        epochs: Number of passes over ``train_loader``.
        device: Device (e.g. ``'cuda'`` or ``'cpu'``) to run on.
    """
    model = model.to(device)
    optimizer = optim.Adam(params=model.parameters())
    loss_fn = nn.BCEWithLogitsLoss()

    for i in tqdm(range(epochs)):
        print("\nTraining:")
        net_train_loss, train_scores, train_labels = _run_epoch(
            model, train_loader, loss_fn, device, optimizer
        )

        print("Testing:")
        net_val_loss, val_scores, val_labels = _run_epoch(model, test_loader, loss_fn, device)

        # Compute EER for training and validation sets
        train_eer, train_eer_threshold = _eer_from_labels(train_scores, train_labels)
        val_eer, val_eer_threshold = _eer_from_labels(val_scores, val_labels)

        print(f"Epoch {i + 1}/{epochs}")
        print(f"Train Loss: {net_train_loss:.4f}")
        print(f"Val Loss: {net_val_loss:.4f}")
        print(f"Train EER: {train_eer:.4f} at threshold {train_eer_threshold:.4f}")
        print(f"Val EER: {val_eer:.4f} at threshold {val_eer_threshold:.4f}")

