In [1]:
# Import necessary libraries
import numpy as np
import torch
from torch import nn, optim
import torchaudio
import os
from tqdm.auto import tqdm
import random
import gc
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from transformers import AutoConfig, AutoModel,Wav2Vec2FeatureExtractor
from sklearn.metrics import roc_auc_score, roc_curve, f1_score, classification_report, ConfusionMatrixDisplay

In [2]:
def getlabels(path):
    text = readtxtfile(path)
    filename2label = {}
    for item in tqdm(text):
        key = item.split(' ')[1]
        value = item.split(' ')[-1]
        filename2label[key] = value
    return filename2label

In [3]:
def readtxtfile(path):
    with open(path, 'r') as file:
        text = file.read().splitlines()
        return text

In [4]:
class ASVSpoof(torch.utils.data.Dataset):
    def __init__(self, audio_dir_path, num_samples, filename2label):
        super().__init__()
        self.audio_dir_path = audio_dir_path
        self.filename2label = filename2label
        self.audio_file_names = self.get_audio_file_names(filename2label)
        self.num_samples = num_samples
        self.labels, self.label2id, self.id2label = self.get_labels(filename2label)
        
    def __getitem__(self, index):
        audio_path = os.path.join(self.audio_dir_path, self.audio_file_names[index])
        signal, sr = torchaudio.load(audio_path)
        signal = self.mix_down_if_necessary(signal)
        signal = self.cut_if_necessary(signal)
        signal = self.right_pad_if_necessary(signal)
        signal = signal.squeeze(0)
        label = self.labels[index]
        return signal, label
    
    def __len__(self):
        return len(self.labels)
    
    def get_audio_file_names(self, filename2label):
        audio_file_names = [name + '.flac' for name in filename2label.keys()]  # Modify if extension varies
        available_files = set(os.listdir(self.audio_dir_path))
        audio_file_names = [name for name in audio_file_names if name in available_files]
        return audio_file_names
    
    def get_labels(self, filename2label):
        labels = [filename2label[os.path.splitext(name)[0]] for name in self.audio_file_names]
        id2label = {idx: label for idx, label in enumerate(sorted(set(labels)))}
        label2id = {label: idx for idx, label in id2label.items()}
        labels = [label2id[label] for label in labels]
        return labels, label2id, id2label
    
    def mix_down_if_necessary(self, signal):
        if signal.shape[0] > 1: 
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal
    
    def cut_if_necessary(self, signal):
        if signal.shape[1] > self.num_samples:
            signal = signal[:, :self.num_samples]
        return signal
    
    def right_pad_if_necessary(self, signal):
        length = signal.shape[1]
        if self.num_samples > length:
            pad_last_dim = (0, self.num_samples - length)
            signal = torch.nn.functional.pad(signal, pad_last_dim)
        return signal

In [5]:
class CustomWavLMForClassification(nn.Module):
    def __init__(self, checkpoint):
        super(CustomWavLMForClassification, self).__init__()
        
        # Load WavLM config and model
        config = AutoModel.from_pretrained(checkpoint)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)  # WavLM uses the same feature extractor as Wav2Vec2
        
        # Load the WavLM model
        self.wavlm = AutoModel.from_pretrained(checkpoint,config=config)
        
        # Additional layers
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.1)
        self.pool = nn.AdaptiveAvgPool1d(128)
        
        # Classification layer
        self.linear = nn.Linear(768 * 128, 1)
        self.sigmoid = nn.Sigmoid()
        self.to(device)

    def forward(self, input_ids,attention_mask):
        # Move inputs to the correct device
        input_ids = input_ids.squeeze(1).to(device)
        attention_mask = attention_mask.to(device)

        # Extract features using the feature extractor
        input_features = self.feature_extractor(input_ids, attention_mask=attention_mask,return_tensors="pt", padding=True, sampling_rate=16000).input_values.squeeze(0).to(device)
        # Pass through WavLM model
        features = self.wavlm(input_features).last_hidden_state
        x = features  
    
        # Apply pooling
        x = x.transpose(1, 2)  # Change shape to (batch_size, hidden_size, seq_length)
        x = self.pool(x)  # Pooling operation
#         print(f"Shape after pooling: {x.shape}")  # Should be (batch_size, hidden_size, pooled_length)

        # Flatten the output
        x = x.reshape(x.shape[0], -1)  # Flatten to (batch_size, hidden_size * pooled_length)
#         print(f"Shape after flattening: {x.shape}")  # Should be (batch_size, hidden_size * pooled_length)

        # Pass through the linear layer
        x = self.linear(x)  # Ensure dimensions match here

        # Apply sigmoid activation
        x = self.sigmoid(x)
    
        return x

In [6]:
def collate_fn(batch):
    batch = [b for b in batch if b[0] is not None]
    if not batch:
        return None, None
    input_values, labels = zip(*batch)
    
    # Pad sequences to the same length
    max_len = max(input_values, key=lambda x: x.shape[0]).shape[0]
    padded_inputs = torch.stack([x for x in input_values])
    labels = torch.tensor(labels)
    return padded_inputs, labels

In [7]:
def EER(labels, outputs):
    fpr, tpr, threshold = roc_curve(labels, outputs, pos_label=1)
    fnr = 1 - tpr
    eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
    eer_threshold
    eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    return eer

In [8]:
from torch.utils.data import random_split
torch.cuda.empty_cache()
# Mixed precision scaler
scaler = torch.amp.GradScaler('cuda')

train_audio_files_path = '/kaggle/input/asvpoof-2019-dataset/LA/LA/ASVspoof2019_LA_train/flac'
train_labels_path = '/kaggle/input/asvpoof-2019-dataset/LA/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'
val_audio_files_path = '/kaggle/input/asvpoof-2019-dataset/LA/LA/ASVspoof2019_LA_dev/flac'
val_labels_path = '/kaggle/input/asvpoof-2019-dataset/LA/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'
test_audio_files_path = '/kaggle/input/asvpoof-2019-dataset/LA/LA/ASVspoof2019_LA_eval/flac'
test_labels_path = '/kaggle/input/asvpoof-2019-dataset/LA/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt'


train_filename2label = getlabels(train_labels_path)
val_filename2label = getlabels(val_labels_path)
test_filename2label = getlabels(test_labels_path)

print("Length of training dataset: ", len(train_filename2label))
print("Length of validation dataset: ", len(val_filename2label))
print("Length of test dataset: ", len(test_filename2label))

num_samples = 4 * 16000
# Update the DataLoader creation to use the new subset datasets
train_dataset = ASVSpoof(train_audio_files_path, num_samples, train_filename2label)
val_dataset = ASVSpoof(val_audio_files_path, num_samples, val_filename2label)
test_dataset = ASVSpoof(test_audio_files_path, num_samples, test_filename2label)

device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
checkpoint = "microsoft/wavlm-base"
model = CustomWavLMForClassification(checkpoint)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
print(train_dataset[0][0].shape)

train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size

# Split the train_dataset into train and validation subsets
train_subset, val_subset = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_subset, batch_size=32, shuffle=True,collate_fn=collate_fn, pin_memory=True)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=True,collate_fn=collate_fn, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True,collate_fn=collate_fn, pin_memory=True)

t_steps = len(train_loader)
v_steps = len(val_loader)
ts_steps = len(test_loader)

print(t_steps)
print(v_steps)
print(ts_steps)

  0%|          | 0/25380 [00:00<?, ?it/s]

  0%|          | 0/24844 [00:00<?, ?it/s]

  0%|          | 0/71237 [00:00<?, ?it/s]

Length of training dataset:  25380
Length of validation dataset:  24844
Length of test dataset:  71237
cuda


config.json:   0%|          | 0.00/2.24k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/378M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

torch.Size([64000])
635
159
2227


In [9]:
gc.collect()
torch.cuda.empty_cache()

In [10]:
torch.autograd.set_detect_anomaly(True)
num_epochs = 6
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_val_accuracy = 0.0  # Initialize the best validation accuracy

for epoch in range(num_epochs):
    gc.collect()
    torch.cuda.empty_cache()
    y_true = []
    y_pred = []
    train_loss = 0.0
    correct = 0
    total = 0
    
    train_loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch + 1}/{num_epochs}')
    for batch_idx, (audio_inputs, labels) in train_loop:
        train_loop.set_description(f'Epoch {epoch + 1} / {num_epochs}')
        model.train()
        torch.cuda.empty_cache()
        audio_inputs = audio_inputs.to(device).squeeze(1)
        labels = labels.to(device)
        labels = labels.type(torch.cuda.FloatTensor)
        attention_mask = torch.ones(audio_inputs.shape, dtype=torch.long).to(device)
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(audio_inputs,attention_mask).squeeze(1)
        y_pred.append(outputs.detach().cpu().numpy())
        y_true.append(labels.detach().cpu().numpy())
        
        loss = criterion(outputs, labels)
        train_loss += loss.item()
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        predicted = (outputs > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        train_loop.set_postfix(Training_loss=loss.item(), Training_accuracy=100 * correct / total)
        # clear some memory
        del audio_inputs, labels, outputs, loss

    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    train_eer = EER(y_true, y_pred)
    train_accuracy = 100 * correct / total
    
    # Validation every epoch
    y_true = []
    y_pred = []
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        val_loop = tqdm(enumerate(val_loader), total=len(val_loader), desc='Validation')
        for val_batch_idx, (audio_inputs, labels) in val_loop:
            torch.cuda.empty_cache()
            audio_inputs = audio_inputs.to(device).squeeze(1)
            val_labels = labels.to(device)
            val_labels = labels.type(torch.cuda.FloatTensor)
            attention_mask = torch.ones(audio_inputs.shape, dtype=torch.long).to(device)
            
            val_outputs = model(audio_inputs,attention_mask).squeeze(1)
            y_pred.append(val_outputs.detach().cpu().numpy())
            y_true.append(labels.detach().cpu().numpy())
            curr_val_loss = criterion(val_outputs, val_labels)
            val_loss += curr_val_loss.item()

            # Calculate accuracy
            predicted = (val_outputs > 0.5).float()
            total += val_labels.size(0)
            correct += (predicted == val_labels).sum().item()
            val_loop.set_postfix(validation_loss=curr_val_loss.item(), validation_accuracy = 100 * correct / total)

    train_loss_after_epoch = train_loss / len(train_loader)
    val_loss_after_epoch = val_loss / len(val_loader)
    train_losses.append(train_loss_after_epoch)
    val_losses.append(val_loss_after_epoch)
    train_accuracies.append(train_accuracy)
    val_accuracy = 100 * correct / total
    val_accuracies.append(val_accuracy)

                 
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    val_eer = EER(y_true,y_pred)

    print(f'Epoch : {epoch + 1} Training loss : {train_loss_after_epoch} Train EER : {train_eer} Training accuracy: {train_accuracy}% Validation loss : {val_loss_after_epoch}  Val EER : {val_eer} Validation accuracy: {val_accuracy}%')
    
print("Training completed.")

Epoch 1/6:   0%|          | 0/635 [00:00<?, ?it/s]

Validation:   0%|          | 0/159 [00:00<?, ?it/s]

Epoch : 1 Training loss : 0.06186307307581856 Train EER : 0.049734427812650896 Training accuracy: 97.83293932230103% Validation loss : 0.004425547508188449  Val EER : 0.0019646365422396855 Validation accuracy: 99.90149724192277%


Epoch 2/6:   0%|          | 0/635 [00:00<?, ?it/s]

Validation:   0%|          | 0/159 [00:00<?, ?it/s]

Epoch : 2 Training loss : 0.005514037145283825 Train EER : 0.003380009657170449 Training accuracy: 99.85224586288416% Validation loss : 0.0036389736433650927  Val EER : 0.0019646365422396855 Validation accuracy: 99.88179669030733%


Epoch 3/6:   0%|          | 0/635 [00:00<?, ?it/s]

Validation:   0%|          | 0/159 [00:00<?, ?it/s]

Epoch : 3 Training loss : 0.002281739077813017 Train EER : 0.0014485755673587638 Training accuracy: 99.96059889676911% Validation loss : 0.0017524590232841427  Val EER : 0.0019646365422396855 Validation accuracy: 99.98029944838456%


Epoch 4/6:   0%|          | 0/635 [00:00<?, ?it/s]

Validation:   0%|          | 0/159 [00:00<?, ?it/s]

Epoch : 4 Training loss : 0.0024710054396363855 Train EER : 0.0014485755673587638 Training accuracy: 99.95567375886525% Validation loss : 0.002078344853200762  Val EER : 0.0 Validation accuracy: 99.94089834515367%


Epoch 5/6:   0%|          | 0/635 [00:00<?, ?it/s]

Validation:   0%|          | 0/159 [00:00<?, ?it/s]

Epoch : 5 Training loss : 0.0014002781873687772 Train EER : 0.0009657170449058426 Training accuracy: 99.97537431048069% Validation loss : 0.011601800736219236  Val EER : 0.0137524557956778 Validation accuracy: 99.822695035461%


Epoch 6/6:   0%|          | 0/635 [00:00<?, ?it/s]

Validation:   0%|          | 0/159 [00:00<?, ?it/s]

Epoch : 6 Training loss : 0.0018534872158537422 Train EER : 0.0009657170449058426 Training accuracy: 99.94089834515367% Validation loss : 0.002858759079639548  Val EER : 0.0019646365422396855 Validation accuracy: 99.94089834515367%
Training completed.


In [11]:
# # Load the checkpoint of the best validation accuracy model
# checkpoint = torch.load('/kaggle/input/rwaw_wavlm_trained_asvspoof2019/pytorch/default/1/best_model.pth')
# # Load model state dict
# model.load_state_dict(checkpoint['model_state_dict'])

In [12]:
new_outputs = []
new_labels = []
model.eval()
test_loss = 0.0
with torch.no_grad():
    test_loop = tqdm(enumerate(test_loader), total = len(test_loader))
    for test_batch_idx, (test_input_ids, test_labels) in test_loop:
        torch.cuda.empty_cache()
        test_input_ids = test_input_ids.to(device).squeeze(1)
        test_labels = test_labels.to(device)
        test_labels = test_labels.type(torch.cuda.FloatTensor) #use torch.FloatTensor if on cpu
        attention_mask = torch.ones(test_input_ids.shape, dtype=torch.long).to(device)


        test_outputs = model(test_input_ids, attention_mask).squeeze(1)
        new_outputs.append(test_outputs.cpu().numpy())
        new_labels.append(test_labels.cpu().numpy())
        curr_test_loss = criterion(test_outputs, test_labels)
        test_loss += curr_test_loss.item()
        test_loop.set_postfix(test_loss = curr_test_loss.item())
new_labels = np.concatenate(new_labels)
new_outputs = np.concatenate(new_outputs)
print(new_labels.shape, new_outputs.shape)
test_eer = EER(new_labels,new_outputs)
print(test_eer)

  0%|          | 0/2227 [00:00<?, ?it/s]

(71237,) (71237,)
0.012372535690006799


In [13]:
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'wav2vec_wavlm.pth')