### Installs, imports and settings

In [None]:
# Experiment configuration.
# Exactly one DATASET line and one MODEL line should be active; the
# commented-out lines list the other available choices.

# Dataset identifier (the hub namespace is prepended in a later cell):
DATASET = "n-f-n"
# DATASET = "s-f-o"
# DATASET = "s-f-c"
# DATASET = "s-r-5o"
# DATASET = "s-r-5c"
# DATASET = "s-r-10c"
# DATASET = "saa-march"

# Model architecture:
MODEL = "wavlm-basic"
# MODEL = "wavlm-tdnn"
# MODEL = "wav2vec2"
# MODEL = "tdnn"

# Learning parameters:
FREEZING = False        # True -> the base model's freeze_base_model() is called
BATCH_SIZE = 8          # batch size used for training and evaluation
LEARNING_RATE = 1e-4    # learning rate handed to the trainer / optimizer
AUDIO_LENGTH_SEC = 5    # clips are truncated to this many seconds

In [None]:
# Hub id of the dataset, plus a run name that encodes every hyper-parameter
# so that different configurations never overwrite each other.
DATASET_NAME = f"reralle/{DATASET}"
FREEZE_STRING = "freezing" if FREEZING else "unfrozen"
OUTPUT_MODEL = "_".join(
    [
        MODEL,
        DATASET,
        f"{BATCH_SIZE}batch",
        f"{AUDIO_LENGTH_SEC}sec",
        f"{LEARNING_RATE}lr",
        FREEZE_STRING,
    ]
)
OUTPUT_MODEL

In [None]:
%%capture
!pip install git+https://github.com/huggingface/accelerate   ## as of may 11th something is broken in the pypi version so installing it via github
!pip install transformers==4.28.0 datasets evaluate huggingface_hub torchaudio librosa 
!export PYTORCH_CUDA_ALLOC_CONF=garbage_collection_threshold:0.6,max_split_size_mb:128

In [None]:
# Authenticate against the Hugging Face Hub. The training cell below sets
# push_to_hub=True and later calls trainer.push_to_hub(), which presumably
# need these credentials.
!huggingface-cli login

In [None]:
from datasets import load_dataset, load_metric, Audio
from transformers import AutoFeatureExtractor, TrainingArguments, Trainer, EarlyStoppingCallback, IntervalStrategy
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import SequenceClassifierOutput
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader, Dataset
import torchaudio
import evaluate
import librosa
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, Audio as IPythonAudio
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
from datetime import datetime
from typing import Optional, Tuple
from dataclasses import is_dataclass

In [None]:
# Import only the model classes required by the selected architecture and
# choose the pretrained checkpoint to fine-tune. The plain 'tdnn' model has
# no pretrained base, so BASE_MODEL stays undefined for it.
if MODEL in ('wavlm-basic', 'wavlm-tdnn'):
  BASE_MODEL = "microsoft/wavlm-large"
  if MODEL == 'wavlm-basic':
    from transformers import WavLMForSequenceClassification
  else:
    from transformers import WavLMModel, WavLMPreTrainedModel
elif MODEL == 'wav2vec2':
  from transformers import Wav2Vec2ForSequenceClassification
  BASE_MODEL = "facebook/wav2vec2-large"

## Loading data

In [None]:
# Fetch every split of the selected dataset from the Hugging Face Hub.
dataset = load_dataset(path=DATASET_NAME)

In [None]:
# Build the two label lookup tables handed to the model config. Ids are kept
# as strings, exactly as the rest of the notebook expects.
labels = dataset["train"].features["label"].names
label2id = {name: str(index) for index, name in enumerate(labels)}
id2label = {str(index): name for index, name in enumerate(labels)}
num_labels = len(id2label)


## Feature extraction

In [None]:
# Decode every clip at 16 kHz, the rate the rest of the notebook assumes.
TARGET_SAMPLING_RATE = 16_000
dataset = dataset.cast_column("audio", Audio(sampling_rate=TARGET_SAMPLING_RATE))

In [None]:
if MODEL in {'wavlm-basic', 'wavlm-tdnn', 'wav2vec2'}:

    # Feature extractor that belongs to the pretrained checkpoint.
    feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_MODEL)

    def preprocess_function(examples):
        """Turn a batch of decoded audio into model inputs.

        Args:
            examples: batch dict produced by ``dataset.map(batched=True)``;
                ``examples["audio"]`` is a list of dicts with an "array" entry.

        Returns:
            The feature extractor output, truncated to AUDIO_LENGTH_SEC seconds.
        """
        audio_arrays = [x["array"] for x in examples["audio"]]
        sampling_rate = feature_extractor.sampling_rate
        return feature_extractor(
            audio_arrays,
            sampling_rate=sampling_rate,
            # Derive the sample budget from the extractor's own rate instead
            # of repeating a hard-coded 16000.
            max_length=int(AUDIO_LENGTH_SEC * sampling_rate),
            truncation=True,
        )

    encoded_dataset = dataset.map(preprocess_function, remove_columns=["audio"], batched=True)
    # A bare expression inside an `if` block is not displayed by the notebook,
    # so print the summary explicitly.
    print(encoded_dataset)

elif MODEL == 'tdnn':
    # For normalization
    def compute_mean_std(dataset):
        """Compute per-feature mean and std of the MFCC features of `dataset`.

        The features are built the same way as in ``MFCCDataset.__getitem__``:
        13 MFCCs, their deltas, and the deltas of those deltas, all computed
        along the time axis and stacked to 39 values per frame. The statistics
        must describe the same features that are normalised later.

        Args:
            dataset: iterable of examples with an "audio" dict holding
                "array" and "sampling_rate".

        Returns:
            (mean, std): two numpy arrays with one entry per feature.
        """
        all_features = []

        for example in dataset:
            audio = example["audio"]
            signal = audio["array"][:(AUDIO_LENGTH_SEC * 16000)]
            sr = audio["sampling_rate"]
            mfccs = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13, n_fft=int(0.025*sr), hop_length=int(0.010*sr))

            # librosa.feature.delta works along the last axis. Keep the
            # (n_mfcc, frames) layout while differentiating so the deltas run
            # over time, and transpose only afterwards.
            mfccs_delta = librosa.feature.delta(mfccs)
            mfccs_delta_delta = librosa.feature.delta(mfccs_delta)

            features = np.vstack([mfccs, mfccs_delta, mfccs_delta_delta]).T
            all_features.append(features)

        all_features = np.vstack(all_features)
        mean = np.mean(all_features, axis=0)
        std = np.std(all_features, axis=0)
        # A constant feature would otherwise cause a division by zero when
        # the statistics are applied.
        std = np.where(std == 0, 1.0, std)

        return mean, std

    class MFCCDataset(Dataset):
        """PyTorch dataset yielding normalised MFCC features and a label.

        Each item is a ``(frames, 39)`` float tensor (13 MFCCs, deltas and
        delta-deltas) together with the class id as a long tensor.
        """

        def __init__(self, dataset, mean=None, std=None):
            """
            Args:
                dataset: indexable collection of examples with "audio" and
                    "label" entries.
                mean: optional per-feature mean used for normalisation.
                std: optional per-feature std used for normalisation.
            """
            self.dataset = dataset
            # Convert the statistics to float32 tensors once, so that
            # normalisation is a plain tensor operation and does not mix
            # torch tensors with numpy arrays.
            self.mean = None if mean is None else torch.as_tensor(mean, dtype=torch.float32)
            self.std = None if std is None else torch.as_tensor(std, dtype=torch.float32)

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            # Index the underlying dataset once instead of once per field.
            example = self.dataset[idx]
            audio = example['audio']
            label = example['label']

            # Get the audio signal directly from the dataset
            # NOTE(review): clips shorter than AUDIO_LENGTH_SEC yield fewer
            # frames; batching presumably needs equal lengths, so confirm.
            signal = audio['array'][:AUDIO_LENGTH_SEC * 16000]
            sr = audio['sampling_rate']

            # Compute MFCCs
            mfccs = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13, n_fft=int(0.025*sr), hop_length=int(0.010*sr))

            # Compute deltas and delta-deltas
            mfccs_delta = librosa.feature.delta(mfccs)
            mfccs_delta_delta = librosa.feature.delta(mfccs_delta)

            # Concatenate MFCCs, deltas, and delta-deltas
            features = torch.tensor(np.vstack([mfccs, mfccs_delta, mfccs_delta_delta]), dtype=torch.float32)

            # Transpose to (frames, features)
            features = torch.transpose(features, 0, 1)

            # Normalize the features
            if self.mean is not None and self.std is not None:
                features = (features - self.mean) / self.std

            return features.float(), torch.tensor(label, dtype=torch.long)

    # Normalisation statistics are taken from the training split only and
    # reused for validation and test.
    train_mean, train_std = compute_mean_std(dataset["train"])

    # Create the PyTorch Datasets
    train_dataset, validation_dataset, test_dataset = (
        MFCCDataset(dataset[split_name], mean=train_mean, std=train_std)
        for split_name in ("train", "validation", "test")
    )

    # Create the DataLoaders (only the training data is shuffled)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
    validation_dataloader = DataLoader(validation_dataset, shuffle=False, batch_size=BATCH_SIZE)
    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

## Model

In [None]:
if MODEL in {'wavlm-tdnn', 'tdnn'}:

  # Source: https://github.com/cvqluu/TDNN
  class TDNN(nn.Module):
      '''
      Single time-delay neural network layer.

      TDNN as defined by https://www.danielpovey.com/files/2015_interspeech_multisplice.pdf
      The affine transformation is not applied globally to all frames but to
      smaller windows with local context.

      Context size and dilation determine the frames selected
      (although context size is not really defined in the traditional sense)
      For example:
          context size 5 and dilation 1 is equivalent to [-2,-1,0,1,2]
          context size 3 and dilation 2 is equivalent to [-2, 0, 2]
          context size 1 and dilation 1 is equivalent to [0]
      '''

      def __init__(
              self,
              input_dim=23,
              output_dim=512,
              context_size=5,
              stride=1,
              dilation=1,
              batch_norm=False,
              dropout_p=0.0
      ):
          '''
          input_dim: number of features per input frame
          output_dim: number of features per output frame
          context_size: number of frames combined into one output frame
          stride: step, in frames, between consecutive context windows
          dilation: spacing, in frames, between the frames inside a window
          batch_norm: True to include batch normalisation after the non linearity
          dropout_p: dropout probability; 0 disables dropout
          '''
          super().__init__()
          self.context_size = context_size
          self.stride = stride
          self.input_dim = input_dim
          self.output_dim = output_dim
          self.dilation = dilation
          self.dropout_p = dropout_p
          self.batch_norm = batch_norm

          self.kernel = nn.Linear(input_dim*context_size, output_dim)
          self.nonlinearity = nn.ReLU()
          if self.batch_norm:
              self.bn = nn.BatchNorm1d(output_dim)
          if self.dropout_p:
              self.drop = nn.Dropout(p=self.dropout_p)

      def forward(self, x):
          '''
          input: size (batch, seq_len, input_features)
          output: size (batch, new_seq_len, output_features)
          '''
          _, _, d = x.shape
          assert (d == self.input_dim), 'Input dimension was wrong. Expected ({}), got ({})'.format(self.input_dim, d)
          x = x.unsqueeze(1)

          # Unfold input into smaller temporal contexts. The window spans the
          # whole feature axis, so only the temporal stride matters; it was
          # previously fixed to 1, which silently ignored the `stride` argument.
          x = F.unfold(
              x,
              (self.context_size, self.input_dim),
              stride=(self.stride, self.input_dim),
              dilation=(self.dilation, 1)
          )

          # (batch, input_dim*context_size, new_seq_len) -> (batch, new_seq_len, ...)
          x = x.transpose(1, 2)
          x = self.kernel(x)
          x = self.nonlinearity(x)

          if self.dropout_p:
              x = self.drop(x)

          if self.batch_norm:
              # BatchNorm1d normalises dimension 1, so move features there and back.
              x = x.transpose(1, 2)
              x = self.bn(x)
              x = x.transpose(1, 2)

          return x

  # Source: https://github.com/r39ashmi/e2e_dialect/blob/main/classification/main_tdnn.py
  class TDNNHead(nn.Module):
      """Classification head: five TDNN layers, mean pooling over time and
      three fully connected layers.

      Args:
          config: object with a ``hidden_size`` attribute (features per input
              frame) and, optionally, ``num_labels`` (number of classes).
      """

      def __init__(self, config):
          super().__init__()
          # The output layer used to be fixed to 10 classes. Use the number of
          # labels from the config when present, and keep 10 as the fallback
          # for configs that do not carry one.
          num_labels = getattr(config, "num_labels", 10)

          self.frame1 = TDNN(input_dim=config.hidden_size, output_dim=128, context_size=5, dilation=2, dropout_p=0.5)
          self.frame2 = TDNN(input_dim=128, output_dim=128, context_size=3, dilation=3, dropout_p=0.5)
          self.frame3 = TDNN(input_dim=128, output_dim=128, context_size=3, dilation=4, dropout_p=0.5)
          self.frame4 = TDNN(input_dim=128, output_dim=128, context_size=1, dilation=1, dropout_p=0.5)
          self.frame5 = TDNN(input_dim=128, output_dim=375, context_size=1, dilation=1, dropout_p=0.5)
          self.relu = nn.ReLU()
          self.linear1 = nn.Linear(375, 375)
          self.linear2 = nn.Linear(375, 150)
          self.linear3 = nn.Linear(150, num_labels)

      def forward(self, features, **kwargs):
          """Map (batch, frames, hidden_size) features to per-class
          log-probabilities of shape (batch, num_labels)."""
          output1 = self.frame1(features)
          output1 = self.frame2(output1)
          output1 = self.frame3(output1)
          output1 = self.frame4(output1)
          output1 = self.frame5(output1)
          # Average over the time axis to get one vector per utterance.
          output1 = output1.permute(0, 2, 1)
          output1 = torch.mean(output1, 2)
          output1 = self.relu(self.linear1(output1))
          output1 = self.relu(self.linear2(output1))
          output1 = self.linear3(output1)
          return F.log_softmax(output1, dim=-1)

In [None]:
if MODEL == 'wavlm-tdnn':

  class WavLMwithTDNN(WavLMPreTrainedModel):
      """WavLM encoder whose frame-level hidden states are classified by a
      TDNN head."""

      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels
          self.pooling_mode = "mean"
          self.config = config

          self.wavlm = WavLMModel(config)
          self.classifier = TDNNHead(config)

          self.init_weights()

      def freeze_base_model(self):
          """
          Calling this function will disable the gradient computation for the base model so that its parameters will not
          be updated during training. Only the classification head will be updated.
          """
          for param in self.wavlm.parameters():
              param.requires_grad = False

      def forward(
              self,
              input_values,
              attention_mask=None,
              output_attentions=None,
              output_hidden_states=None,
              return_dict=None,
              labels=None,
      ):
          """Run the encoder and the TDNN head.

          Returns a ``SequenceClassifierOutput`` (or a plain tuple when
          ``return_dict`` is false). ``loss`` is only computed when ``labels``
          are given, so the model can also be used for pure inference.
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict
          outputs = self.wavlm(
              input_values,
              attention_mask=attention_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          hidden_states = outputs[0]
          logits = self.classifier(hidden_states)

          # Without labels (e.g. at prediction time) there is nothing to
          # compare against; previously this path crashed on `labels.view`.
          loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

          if not return_dict:
              # In tuple mode the encoder output has no attributes. This
              # mirrors the stock sequence-classification head, which
              # forwards the entries from position 2 onwards.
              output = (logits,) + tuple(outputs[2:])
              return ((loss,) + output) if loss is not None else output

          return SequenceClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )

In [None]:
# Instantiate the selected architecture. The three Hugging Face models get
# the label maps so that the saved config carries readable class names.
if MODEL == 'wavlm-basic':
  model = WavLMForSequenceClassification.from_pretrained(BASE_MODEL, num_labels=num_labels, label2id=label2id, id2label=id2label)

elif MODEL == 'wavlm-tdnn':
  model = WavLMwithTDNN.from_pretrained(BASE_MODEL, num_labels=num_labels, label2id=label2id, id2label=id2label)

elif MODEL == 'wav2vec2':
  # `mid2label` was a typo for `id2label`.
  model = Wav2Vec2ForSequenceClassification.from_pretrained(BASE_MODEL, num_labels=num_labels, label2id=label2id, id2label=id2label)

elif MODEL == 'tdnn':
  from types import SimpleNamespace

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  # TDNNHead takes a single config object. Its input size is 39 because the
  # MFCC pipeline yields 13 MFCCs plus deltas and delta-deltas per frame.
  tdnn_config = SimpleNamespace(hidden_size=39, num_labels=num_labels)
  model = TDNNHead(tdnn_config)
  model.to(device)

else:
  raise ValueError(f"Unknown MODEL: {MODEL!r}")

In [None]:
# Optionally freeze the pretrained encoder so that only the head is trained.
if FREEZING:
  if hasattr(model, "freeze_base_model"):
    model.freeze_base_model()
  else:
    # The plain TDNN has no pretrained base, so there is nothing to freeze.
    print(f"FREEZING is ignored: model '{MODEL}' has no base model to freeze")

## Training

In [None]:
def print_metrics(accuracy, f1_average, acc_per_class, f1_per_class, cm, export):
    """Print a metrics report and optionally export it to Excel.

    Args:
        accuracy: overall accuracy.
        f1_average: averaged F1 score.
        acc_per_class: per-class accuracies, indexed by class id.
        f1_per_class: per-class F1 scores, indexed by class id.
        cm: confusion matrix.
        export: when true, write the per-class F1 scores and the confusion
            matrix to timestamped .xlsx files in the working directory.
    """
    print(f"Accuracy: {accuracy}")
    print(f"F1 average: {f1_average}")

    per_class_sections = (
        ("Accuracy per class:", acc_per_class),
        ("F1 score per class:", f1_per_class),
    )
    for heading, values in per_class_sections:
        print(heading)
        for class_id, class_name in id2label.items():
            print(f"\t{class_name}: {values[int(class_id)]}")

    print("Confusion matrix:")
    print(cm)
    print("\n")

    if not export:
        return

    timestamp = datetime.now()
    pd.DataFrame(f1_per_class).to_excel(f"f1_epoch{timestamp}.xlsx")

    class_names = dataset["test"].features["label"].names
    confusion_table = pd.DataFrame(
        cm,
        index=['true:{:}'.format(name) for name in class_names],
        columns=['pred:{:}'.format(name) for name in class_names],
    )
    confusion_table.to_excel(f"cm_epoch{timestamp}.xlsx")

In [None]:
def compute_metrics(eval_pred):
    """Compute and print classification metrics for one evaluation run.

    Args:
        eval_pred: object with ``predictions`` (per-class scores, one row per
            example) and ``label_ids`` (true class ids).

    Returns:
        dict with the overall "accuracy" and the macro-averaged "f1".
    """
    # Class names in class-id order. Passing them explicitly keeps the rows of
    # the confusion matrix and the per-class scores aligned with the class
    # ids; by default sklearn sorts string labels alphabetically and drops
    # classes that do not occur.
    class_names = [id2label[str(i)] for i in range(len(id2label))]

    predictions = np.argmax(eval_pred.predictions, axis=1)
    true_labels = [id2label[str(x)] for x in eval_pred.label_ids]
    predicted_labels = [id2label[str(x)] for x in predictions]

    accuracy = accuracy_score(true_labels, predicted_labels)
    f1_per_class = f1_score(true_labels, predicted_labels, labels=class_names, average=None)
    f1_average = f1_score(true_labels, predicted_labels, average="macro")
    cm = confusion_matrix(true_labels, predicted_labels, labels=class_names)

    # Per-class accuracy = correct / number of true examples of the class.
    # Classes without any true example stay NaN instead of dividing by zero.
    support = cm.sum(axis=1)
    acc_per_class = np.divide(
        cm.diagonal(),
        support,
        out=np.full(len(class_names), np.nan),
        where=support > 0,
    )

    print_metrics(accuracy, f1_average, acc_per_class, f1_per_class, cm, export=False)

    return {"accuracy": accuracy, "f1": f1_average}

In [None]:
# My twist on the early stopping callback, because the original did not work
class MyEarlyStoppingCallback(EarlyStoppingCallback):
    """Early stopping that tracks the best metric value itself.

    The patience counter is reset only when the metric beats the best value
    seen so far by more than ``early_stopping_threshold``. The best value is
    updated on every improvement, however small. The comparison treats
    larger metric values as better.
    """

    def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
        super().__init__(early_stopping_patience, early_stopping_threshold)
        self.best_metric = 0.0

    def check_metric_value(self, args, state, control, metric_value):
        # Use the trainer state passed to the callback instead of the global
        # `trainer`, so the callback works under any variable name.
        print(f"best model: {state.best_model_checkpoint}")
        print(f"metric value: {metric_value}")

        if metric_value - self.best_metric > self.early_stopping_threshold:
            self.early_stopping_patience_counter = 0
            print("reset early stopping patience")
        else:
            self.early_stopping_patience_counter += 1
            print(f"increased early stopping counter to {self.early_stopping_patience_counter}")

        if metric_value > self.best_metric:
            self.best_metric = metric_value

In [None]:
# # Evaluate 3 times per epoch
# # Useful for larger datasets but then this needs to be added to the training arguments:
#     # evaluation_strategy = IntervalStrategy.STEPS, # "steps"
#     # eval_steps = NUM_STEPS, # Evaluation and Save happens every n steps
#     # save_steps = NUM_STEPS,
#     #save_strategy = IntervalStrategy.STEPS,
# NUM_STEPS = int((len(encoded_dataset["train"]) * 31) / (3 * 1000))
# NUM_STEPS

In [None]:
if MODEL in ('wavlm-basic', 'wavlm-tdnn', 'wav2vec2'):
  # Evaluate and checkpoint once per epoch. The epoch budget is deliberately
  # huge because the early-stopping callback ends the run.
  training_options = dict(
      save_total_limit=3,
      evaluation_strategy="epoch",
      save_strategy="epoch",
      learning_rate=LEARNING_RATE,
      per_device_train_batch_size=BATCH_SIZE,
      per_device_eval_batch_size=BATCH_SIZE,
      gradient_accumulation_steps=4,
      num_train_epochs=1000,
      warmup_ratio=0.003,
      logging_steps=15,
      load_best_model_at_end=True,
      metric_for_best_model='eval_f1',
      push_to_hub=True,
  )
  args = TrainingArguments(OUTPUT_MODEL, **training_options)

  early_stopping = MyEarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.005)

  trainer = Trainer(
      model=model,
      args=args,
      train_dataset=encoded_dataset["train"],
      eval_dataset=encoded_dataset["validation"],
      tokenizer=feature_extractor,
      compute_metrics=compute_metrics,
      callbacks=[early_stopping],
  )

  trainer.train()

In [None]:
if MODEL in ('wavlm-basic', 'wavlm-tdnn', 'wav2vec2'):
  # Final evaluation on the held-out test split; compute_metrics prints the report.
  test_split = encoded_dataset["test"]
  trainer.evaluate(eval_dataset=test_split)

In [None]:
if MODEL in ('wavlm-basic', 'wavlm-tdnn', 'wav2vec2'):
  # Store the trained model locally first, then upload it to the Hub.
  trainer.save_model()
  trainer.push_to_hub()

In [None]:
if MODEL == 'tdnn':
    # Loss and optimizer for the plain TDNN training loop
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
    # One training epoch
    def train(model, dataloader, criterion, optimizer, device):
        """Train `model` for one pass over `dataloader`.

        Returns:
            (mean loss per example, accuracy) over the whole pass.
        """
        model.train()
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        for batch_features, batch_labels in dataloader:
            batch_features = batch_features.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            logits = model(batch_features)
            batch_loss = criterion(logits, batch_labels)
            batch_loss.backward()
            optimizer.step()

            # The criterion returns the batch mean, so weight it by batch size.
            loss_sum += batch_loss.item() * batch_features.size(0)
            predicted = logits.max(1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()

        return loss_sum / n_seen, n_correct / n_seen

    # Evaluation pass (no parameter updates)
    def validate(model, dataloader, criterion, device):
        """Evaluate `model` on `dataloader` without computing gradients.

        Returns:
            (mean loss per example, accuracy) over the whole pass.
        """
        model.eval()
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        with torch.no_grad():
            for batch_features, batch_labels in dataloader:
                batch_features = batch_features.to(device)
                batch_labels = batch_labels.to(device)

                logits = model(batch_features)
                batch_loss = criterion(logits, batch_labels)

                # The criterion returns the batch mean, so weight it by batch size.
                loss_sum += batch_loss.item() * batch_features.size(0)
                predicted = logits.max(1)[1]
                n_seen += batch_labels.size(0)
                n_correct += (predicted == batch_labels).sum().item()

        return loss_sum / n_seen, n_correct / n_seen

    # Training loop: train and validate every epoch, keep the per-epoch
    # history for plotting, and checkpoint whenever validation accuracy improves.
    num_epochs = 30
    best_val_acc = 0.0
    best_model_path = f"{OUTPUT_MODEL}_best.pth"

    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []

    for epoch in range(num_epochs):
        train_loss, train_acc = train(model, train_dataloader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, validation_dataloader, criterion, device)

        for history, value in (
            (train_losses, train_loss),
            (train_accuracies, train_acc),
            (val_losses, val_loss),
            (val_accuracies, val_acc),
        ):
            history.append(value)

        summary = 'Epoch: {}/{}, Train Loss: {:.4f}, Train Acc: {:.4f}, Val Loss: {:.4f}, Val Acc: {:.4f}'
        print(summary.format(epoch + 1, num_epochs, train_loss, train_acc, val_loss, val_acc))

        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_model_path)

    print("Training completed. The best model is saved as", best_model_path)

    # Plot the training/validation curves: one figure for the losses and one
    # for the accuracies.
    curves = (
        ("Loss", train_losses, "Training Loss", val_losses, "Validation Loss"),
        ("Accuracy", train_accuracies, "Training Accuracy", val_accuracies, "Validation Accuracy"),
    )
    for y_label, train_series, train_label, val_series, val_label in curves:
        plt.figure()
        plt.plot(train_series, label=train_label)
        plt.plot(val_series, label=val_label)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.legend()
        plt.show()

In [None]:
if MODEL == 'tdnn':
    # Load the best model and test.
    # map_location keeps the load working when the checkpoint was written on
    # a GPU and the notebook is re-run on a CPU-only machine.
    best_state = torch.load(f"{OUTPUT_MODEL}_best.pth", map_location=device)
    model.load_state_dict(best_state)
    test_loss, test_acc = validate(model, test_dataloader, criterion, device)
    print("Test Loss: {:.4f}, Test Acc: {:.4f}".format(test_loss, test_acc))