#Setup

In [None]:
import pandas as np
import numpy as np
import math
import time

# from tqdm.notebook import tqdm
from tqdm import tqdm
from matplotlib import pyplot as plt
from datetime import datetime

from sklearn.preprocessing import LabelEncoder

In [None]:
from torch import cuda

# Choose the compute device once; later cells move models and batches onto it with .to(device).
_has_gpu = cuda.is_available()
device = 'cpu'
if _has_gpu:
    device = 'cuda'
print(device)

#Set Variables

In [None]:
import os

SAVE_PATH = "_TEMP_MODELS/" #Path where models will be saved
os.makedirs(SAVE_PATH, exist_ok=True)

# Sample rate (Hz) handed to the AST feature extractor. The AST checkpoint expects
# 16 kHz audio, so load waveforms with librosa.load(path, sr = SAMPLING_RATE).
SAMPLING_RATE = 16000

#Load Data

In [None]:
# NOTE: Load audio using librosa
# librosa.load(wavPath, sr = 16000) #AST required sample rate
# Fill the per-split lists below with waveforms and their labels.
original_train_audio, original_val_audio, original_test_audio = [], [], []
original_train_labels, original_val_labels, original_test_labels = [], [], []

# Split sizes; prepare_datasets uses them to decide which split strategy applies.
train_size, val_size, test_size = map(
    len, (original_train_labels, original_val_labels, original_test_labels)
)

# Concatenate the splits in train -> val -> test order; prepare_datasets relies on
# this ordering when it slices contiguous index ranges.
all_audio = [*original_train_audio, *original_val_audio, *original_test_audio]
all_labels = np.array([*original_train_labels, *original_val_labels, *original_test_labels])

# LABELS: the class set is defined by the training split only.
unique_labels = np.unique(original_train_labels)
num_class = len(unique_labels)

lEnc = LabelEncoder().fit(unique_labels)
print(unique_labels)
print(lEnc.transform(unique_labels))

# Integer class ids for every sample, aligned with all_audio.
all_targets = lEnc.transform(all_labels)

#Models

In [None]:
import torch
import torch.optim as optim
from torch import nn
from torchvision import models, transforms
from torchsummary import summary

In [None]:
from transformers import AutoConfig, AutoModelForSequenceClassification
from transformers import AutoFeatureExtractor, ASTForAudioClassification, ASTFeatureExtractor
from transformers import logging
# Lower transformers' log level to errors only, hiding its warning-level output
# (presumably including notices emitted when the AST checkpoint is loaded with
# ignore_mismatched_sizes=True — confirm before relying on it).
logging.set_verbosity_error()

def get_model(checkpoint, num_class, args = None):
  """Build the classification model registered for ``checkpoint``.

  Args:
    checkpoint: Hugging Face checkpoint name. Only
      "MIT/ast-finetuned-audioset-10-10-0.4593" is supported.
    num_class: Number of output classes for the classification head.
    args: Optional dict of config overrides forwarded to ``AST``.

  Returns:
    An ``AST`` module.

  Raises:
    ValueError: If ``checkpoint`` is not supported. This is a subclass of
      ``Exception``, so existing ``except Exception`` handlers still match.

  NOTE(review): ``AST`` loads its weights from the module-level ``checkpoint``
  variable, not from this argument; keep the two in sync.
  """
  if checkpoint != "MIT/ast-finetuned-audioset-10-10-0.4593":
    raise ValueError("Unknown checkpoint: %r" % (checkpoint,))
  return AST(num_class, args = args)

##AST

In [None]:
class AST(torch.nn.Module):
  """Wrap ``ASTForAudioClassification`` so it can be built from a class count.

  ``forward`` returns the wrapped model's output unchanged; the training code
  indexes it like a dict (``output["logits"]``).

  NOTE(review): the config and pretrained weights are loaded from the
  module-level ``checkpoint`` variable, which must be set before construction.
  """

  def __init__(self, num_class, args = None):
    """
    Args:
      num_class: Number of labels for the classification head.
      args: Optional dict of config overrides (e.g. ``num_hidden_layers``,
        ``hidden_dropout_prob``) forwarded to ``AutoConfig.from_pretrained``.
    """
    super().__init__()

    # Load the checkpoint's config once, applying any overrides on top of it.
    config = AutoConfig.from_pretrained(checkpoint, num_labels = num_class, **(args or {}))

    # ignore_mismatched_sizes=True: num_labels (and any overridden architecture
    # settings) may not match the checkpoint's saved classifier head.
    self.ast = ASTForAudioClassification.from_pretrained(checkpoint, ignore_mismatched_sizes = True, config = config)

  def forward(self, input_values, labels = None):
    """Pass ``input_values`` (and optional ``labels``) to the wrapped model and return its output."""
    return self.ast(input_values = input_values, labels = labels)

#Training

In [None]:
from collections.abc import Mapping


class CustomDataset(torch.utils.data.Dataset):
    """Index-aligned dataset of model inputs, labels and sample ids.

    Each item is a dict holding one sample's model inputs plus its ``"ids"``
    and ``"labels"`` entries, so a DataLoader collates it into a batch dict
    that the training loop can unpack into the model.
    """

    def __init__(self, inputs, labels, ids):
        """
        Args:
            inputs: Either a mapping of feature name -> indexable per-sample
                values (such as the dict returned by build_inputs), or a plain
                indexable sequence of inputs.
            labels: Indexable per-sample targets; its length is the dataset size.
            ids: Indexable per-sample identifiers aligned with ``labels``.
        """
        self.inputs = inputs
        self.labels = labels
        self.ids = ids

    def __getitem__(self, idx):
        # Any Mapping (dict, UserDict-based feature containers, ...) is sliced
        # per key; anything else is treated as one input stored under "input".
        if isinstance(self.inputs, Mapping):
            item = {key: val[idx] for key, val in self.inputs.items()}
        else:
            item = {"input": self.inputs[idx]}

        item["ids"] = self.ids[idx]
        item["labels"] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)

In [None]:
def get_norm_stats(max_ast_length, sampling_rate, audio = None):
  """Return the (mean, std) of un-normalised AST feature-extractor outputs.

  Args:
    max_ast_length: ``max_length`` passed to ``ASTFeatureExtractor``.
    sampling_rate: Sample rate of the waveforms in Hz.
    audio: Waveforms to measure. Defaults to the module-level ``all_audio``
      (all splits). Pass the training waveforms only to keep validation/test
      data out of the normalisation statistics.

  Returns:
    Tuple of scalar tensors ``(mean, std)`` over every feature value.
  """
  if audio is None:
    audio = all_audio

  # do_normalize=False so the statistics describe the raw extractor output.
  extractor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", max_length = max_ast_length, do_normalize = False)
  features = extractor(audio, sampling_rate = sampling_rate, return_tensors = "pt")["input_values"]

  # NOTE(review): any padding frames added by the extractor are included in both statistics.
  return torch.mean(features), torch.std(features)

In [None]:
#Gather all inputs for all models
from transformers import AutoTokenizer

def build_inputs(audio, max_lengths = None, return_max_len = False, max_ast_length = 512):
  """Convert raw waveforms into normalised AST model inputs.

  Args:
    audio: List of waveforms sampled at ``SAMPLING_RATE``.
    max_lengths: Not used for AST; passed through unchanged and returned
      when ``return_max_len`` is True. Defaults to a fresh empty dict.
    return_max_len: If True, return ``(inputs, max_lengths)``.
    max_ast_length: ``max_length`` given to the AST feature extractor.

  Returns:
    Dict mapping each feature-extractor output key (e.g. ``"input_values"``)
    to a float32 tensor, or ``(that dict, max_lengths)``.
  """
  if max_lengths is None:
    max_lengths = {}

  mean, std = get_norm_stats(max_ast_length, SAMPLING_RATE)

  # Extract features normalised with the dataset statistics from get_norm_stats.
  feature_extractor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", max_length = max_ast_length, do_normalize = True, mean = float(mean), std = float(std))
  extracted = feature_extractor(audio, sampling_rate=SAMPLING_RATE, return_tensors="pt")
  print("AST -> Check mean 0, std 0.5:", torch.mean(extracted["input_values"]), torch.std(extracted["input_values"]))

  inputs = {k: v.float() for k, v in extracted.items()}

  if return_max_len:
    return inputs, max_lengths
  return inputs

In [None]:
def prepare_datasets(targets, audios, train_percent = 0.9, cv = False, train_idx = None, test_idx = None, show_result = False):
  """Split samples into train/val(/test) and wrap each split in a DataLoader.

  Split strategy (reads the module-level ``train_size``/``val_size``):
    * ``cv=True``: randomly keep ``train_percent`` of ``train_idx`` for
      training, use the remainder for validation, and ``test_idx`` (may be
      empty) as the test split.
    * no validation split (``val_size == 0``): randomly keep ``train_percent``
      of the first ``train_size`` samples for training, the rest for
      validation, and every sample after them as test (possibly none).
    * otherwise: use the contiguous train / val / test ranges.

  Args:
    targets: Numpy array of integer class ids, indexed by sample position.
    audios: Sequence of waveforms aligned with ``targets``.
    train_percent: Fraction of candidate training indices kept for training.
    cv: Force a random train/val split of ``train_idx``.
    train_idx: Candidate training indices; required when ``cv`` is True.
    test_idx: Test indices used when ``cv`` is True.
    show_result: Unused; kept for call compatibility.

  Returns:
    ``(train_loader, val_loader, test_loader)``, or
    ``(train_loader, val_loader)`` when the test split is empty.

  Raises:
    ValueError: If ``cv`` is True and ``train_idx`` is empty.
  """
  if train_idx is None:
    train_idx = []
  if test_idx is None:
    test_idx = []

  # Fixed seed so the random train/val split is reproducible between runs.
  np.random.seed(123)

  if cv:
    # Cross validation / force-split the given training indices into train & val.
    if len(train_idx) == 0:
      raise ValueError("cv=True requires a non-empty train_idx")
    idx_train = np.random.choice(train_idx, int(len(train_idx) * train_percent), replace = False)
    chosen = set(idx_train.tolist())  # set lookup keeps the val split linear
    idx_val = [x for x in train_idx if x not in chosen]
    idx_test = test_idx

  elif val_size == 0:
    # No validation split supplied: carve one out of the training samples.
    candidates = np.arange(train_size)
    idx_train = np.random.choice(candidates, int(train_size * train_percent), replace = False)
    chosen = set(idx_train.tolist())
    idx_val = [x for x in candidates if x not in chosen]
    idx_test = np.arange(train_size, len(targets))

  else:
    # Train, val, and test are stored contiguously in that order.
    idx_train = np.arange(0, train_size)
    idx_val = np.arange(train_size, train_size + val_size)
    idx_test = np.arange(train_size + val_size, len(targets))

  print("Data Loader split:")
  print("  - Train:", len(idx_train))
  print("  - Val:", len(idx_val))
  print("  - Test:", len(idx_test))

  # Generate inputs
  print("Training inputs")
  train_inputs, max_lengths = build_inputs([audios[i] for i in idx_train], return_max_len = True) #Extract max_lengths to implement on val and test sets
  print("Validation inputs")
  val_inputs = build_inputs([audios[i] for i in idx_val], max_lengths = max_lengths)
  print("Test inputs")
  test_inputs = build_inputs([audios[i] for i in idx_test], max_lengths = max_lengths) if len(idx_test) > 0 else []

  # Prepare loaders
  train_targets = torch.LongTensor(targets[idx_train])
  val_targets = torch.LongTensor(targets[idx_val])
  test_targets = torch.LongTensor(targets[idx_test])

  # Train and val loaders reshuffle every epoch; the test loader keeps a fixed order.
  train_loader = torch.utils.data.DataLoader(CustomDataset(train_inputs, train_targets, idx_train), shuffle=True, batch_size = BATCH_SIZE)
  val_loader = torch.utils.data.DataLoader(CustomDataset(val_inputs, val_targets, idx_val), shuffle = True, batch_size = BATCH_SIZE)
  test_loader = torch.utils.data.DataLoader(CustomDataset(test_inputs, test_targets, idx_test), shuffle = False, batch_size = BATCH_SIZE)

  if len(test_targets) == 0:
    return train_loader, val_loader
  return train_loader, val_loader, test_loader


In [None]:
def _epoch_scores(pred_batches, target_batches):
    """Return (accuracy, weighted F1) computed over every sample seen in an epoch.

    Both arguments are lists of 1-D CPU tensors collected batch by batch.
    Returns ``(nan, nan)`` when no batches were seen.
    """
    if not pred_batches:
        return float("nan"), float("nan")
    preds = torch.cat(pred_batches).numpy()
    labels = torch.cat(target_batches).numpy()
    return accuracy_score(labels, preds), f1_score(labels, preds, average = "weighted")


def train_model(show_result = True, epochs = 3, early_stop = 10, scheduler = None):
    """Train the global ``model`` on ``train_loader`` and validate on ``val_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader``, ``val_loader``, ``device``, ``checkpoint`` and
    ``filePath_trainplot``. Accuracy and weighted F1 are computed over the
    whole epoch, not averaged per batch. After training, the per-epoch
    train/val loss curves are plotted on a new figure, saved to
    ``filePath_trainplot`` and shown.

    Args:
        show_result: Print per-epoch metrics and the early-stopping notice.
        epochs: Maximum number of epochs.
        early_stop: Stop once the best validation loss of the last
            ``early_stop`` epochs is worse than the best loss before them;
            ``None`` or ``0`` disables early stopping.
        scheduler: Optional scheduler stepped with the epoch's mean
            validation loss (e.g. ``ReduceLROnPlateau``).
    """
    train_loss = []
    val_loss = []
    for epoch in range(epochs):
        t = time.time()

        model.train()
        loss_batch_train, preds_train, targets_train = [], [], []
        for batch in train_loader:
            targets = batch["labels"].to(device)
            input_args = {k: v.to(device) for k, v in batch.items() if k != "ids"}

            output = model(**input_args)
            loss_train = criterion(output["logits"], targets)
            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()

            loss_batch_train.append(loss_train.item())
            preds_train.append(torch.argmax(output["logits"], dim = -1).cpu())
            targets_train.append(targets.cpu())

        epoch_loss_train = np.mean(loss_batch_train)
        acc_train, f1_train = _epoch_scores(preds_train, targets_train)

        model.eval()
        loss_batch_val, preds_val, targets_val = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                targets = batch["labels"].to(device)
                input_args = {k: v.to(device) for k, v in batch.items() if k != "ids"}

                output = model(**input_args)
                loss_val = criterion(output["logits"], targets)

                loss_batch_val.append(loss_val.item())
                preds_val.append(torch.argmax(output["logits"], dim = -1).cpu())
                targets_val.append(targets.cpu())

        epoch_loss_val = np.mean(loss_batch_val)
        acc_val, f1_val = _epoch_scores(preds_val, targets_val)

        train_loss.append(epoch_loss_train)
        val_loss.append(epoch_loss_val)

        if scheduler is not None:
            scheduler.step(epoch_loss_val)

        if show_result:
            print('Epoch: {:04d}'.format(epoch+1),
                  'loss_train: {:.4f}'.format(epoch_loss_train),
                  'acc_train: {:.4f}'.format(acc_train),
                  'f1w_train: {:.4f}'.format(f1_train),
                  'loss_val: {:.4f}'.format(epoch_loss_val),
                  'acc_val: {:.4f}'.format(acc_val),
                  'f1w_val: {:.4f}'.format(f1_val),
                  'time: {:.4f}s'.format(time.time() - t),
                  'lr:', optimizer.param_groups[0]["lr"], flush = True)

        # Stop when none of the last `early_stop` epochs beat the best earlier validation loss.
        if early_stop and epoch > early_stop and np.min(val_loss[-early_stop:]) > np.min(val_loss[:-early_stop]):
            if show_result:
                print("Early Stopping...")
            break

    # New figure so repeated calls don't draw onto the previous plot.
    plt.figure()
    plt.title("%s" % (checkpoint))
    plt.plot(train_loss, label = "train")
    plt.plot(val_loss, label = "Val")
    plt.legend()
    plt.savefig(filePath_trainplot)
    plt.show()

#Evaluation

In [None]:
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, hamming_loss, roc_auc_score
from scipy.special import softmax

In [None]:
def cal_accuracy(predictions,labels):
    """Return the fraction of rows whose arg-max prediction equals the label.

    Args:
        predictions: Score tensor whose last dimension indexes classes
            (e.g. logits of shape ``(batch, num_class)``).
        labels: Tensor of integer class ids, one per row of ``predictions``.

    Returns:
        Accuracy as a Python float; ``0.0`` for an empty batch.

    Raises:
        ValueError: If ``labels`` does not have one entry per prediction.
    """
    pred = torch.argmax(predictions, -1).cpu()
    lab = labels.cpu()
    if pred.shape != lab.shape:
        raise ValueError("predictions and labels disagree in shape: %s vs %s" % (tuple(pred.shape), tuple(lab.shape)))
    if pred.numel() == 0:
        return 0.0
    # .item() gives a Python int, so the division is exact float division.
    return (pred == lab).sum().item() / pred.numel()

#Tuning

In [None]:
import optuna

In [None]:
def objective(trial):
  """Optuna objective: train one AST configuration and return its validation weighted F1.

  Samples dropout, learning rate, depth, attention heads and
  ReduceLROnPlateau settings from ``trial``. It then trains on the
  module-level ``train_loader`` and evaluates on ``val_loader`` every epoch.
  Each epoch's F1 (computed over the whole validation set) is reported to
  Optuna for pruning. The F1 of the last completed epoch is returned.

  Raises:
    ValueError: If the global ``checkpoint`` is not supported.
    optuna.exceptions.TrialPruned: When Optuna decides to prune the trial.
  """
  tune_dropout = trial.suggest_categorical("dropout", [0.01, 0.05, 0.1, 0.5])

  if checkpoint != "MIT/ast-finetuned-audioset-10-10-0.4593":
    raise ValueError("Unsupported checkpoint")

  tune_lr = trial.suggest_categorical("learning_rate", [1e-03, 1e-04, 1e-05, 5e-05])
  tune_layers = trial.suggest_int("num_hidden_layers", 2, 12, step = 2)
  tune_heads = trial.suggest_categorical("num_attention_heads", [2, 3, 4, 6, 8, 12])  # each divides the 768 hidden size
  tune_patience = trial.suggest_int("scheduler_patience", 2, 5)
  tune_factor = trial.suggest_categorical("scheduler_factor", [0.1, 0.5])

  args = {
    "num_hidden_layers": tune_layers,
    "num_attention_heads": tune_heads,
    "hidden_dropout_prob": tune_dropout,
    # ASTConfig's attention-dropout field is attention_probs_dropout_prob;
    # an unrecognised key would not be applied to the config.
    "attention_probs_dropout_prob": tune_dropout,
    "max_length": MAX_LENGTH,
  }

  tune_model = get_model(checkpoint, num_class, args).to(device)
  tune_decay = 0
  tune_epochs = 25
  early_stop = 5

  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(tune_model.parameters(), lr = tune_lr, weight_decay = tune_decay)
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=tune_factor, patience=tune_patience, verbose=True)

  val_loss = []
  for epoch in range(tune_epochs):
    tune_model.train()
    for batch in train_loader:
      targets = batch["labels"].to(device)
      input_args = {k: v.to(device) for k, v in batch.items() if k != "ids"}

      output = tune_model(**input_args)
      loss_train = criterion(output["logits"], targets)
      optimizer.zero_grad()
      loss_train.backward()
      optimizer.step()

    tune_model.eval()
    loss_batch_val, preds_val, targets_val = [], [], []
    with torch.no_grad():
      for batch in val_loader:
        targets = batch["labels"].to(device)
        input_args = {k: v.to(device) for k, v in batch.items() if k != "ids"}

        output = tune_model(**input_args)
        loss_val = criterion(output["logits"], targets)

        loss_batch_val.append(loss_val.item())
        preds_val.append(torch.argmax(output["logits"], dim = -1).cpu())
        targets_val.append(targets.cpu())

    epoch_val_loss = np.mean(loss_batch_val)
    val_loss.append(epoch_val_loss)
    # F1 over the whole validation set: a mean of per-batch F1 scores would
    # depend on how the shuffled val_loader groups samples each epoch.
    f1_val = f1_score(torch.cat(targets_val).numpy(), torch.cat(preds_val).numpy(), average = "weighted")

    scheduler.step(epoch_val_loss)

    # Record metric
    trial.report(f1_val, epoch)

    # Stop when none of the last `early_stop` epochs beat the best earlier validation loss.
    if early_stop and epoch > early_stop and np.min(val_loss[-early_stop:]) > np.min(val_loss[:-early_stop]):
      break

    # Handle pruning based on the intermediate value.
    if trial.should_prune():
      raise optuna.exceptions.TrialPruned()

  return f1_val

In [None]:
def tune_parameters(n_trials = 50):
  """Create an Optuna study that maximises ``objective`` and run ``n_trials`` trials.

  A fixed baseline configuration is queued ahead of the sampled trials.
  Raises ``Exception`` if the global ``checkpoint`` is not supported.
  Returns the finished study.
  """
  study = optuna.create_study(direction = "maximize")

  is_supported = checkpoint == "MIT/ast-finetuned-audioset-10-10-0.4593"
  if not is_supported:
    raise Exception("Unsupported checkpoint")

  # Baseline ("default") hyper-parameters evaluated as the first trial.
  baseline = dict(
    dropout = 0.5,
    num_hidden_layers = 12,
    num_attention_heads = 12,
    learning_rate = 5e-05,
    weight_decay = 0,
    scheduler_patience = 5,
    scheduler_factor = 0.5,
  )
  study.enqueue_trial(baseline)
  study.optimize(objective, n_trials = n_trials)

  return study

#AST

In [None]:
checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
# File-name stem for this checkpoint; "/" is replaced so the name does not create a sub-directory.
_file_stem = checkpoint.replace("/", "_")
# os.path.join avoids the doubled separator "%s/%s" produced when SAVE_PATH already ends in "/".
filePath_model = os.path.join(SAVE_PATH, _file_stem + ".pt")
filePath_log = os.path.join(SAVE_PATH, _file_stem + ".log")
filePath_trainplot = os.path.join(SAVE_PATH, _file_stem + ".png")
BATCH_SIZE = 32
MAX_LENGTH = 512

In [None]:
# Randomly split the samples 90% / 10% into train / val (prepare_datasets' default
# train_percent). No test_idx is given, so only the train and val loaders come back.
# NOTE(review): this indexes every sample in all_targets, including any from the
# original val/test splits; use np.arange(train_size) if those must stay held out.
train_loader, val_loader = prepare_datasets(all_targets, all_audio, cv = True, train_idx = np.arange(len(all_targets)))

##Tuning

In [None]:
start = datetime.now()
print("Tuning...", flush = True)
study = tune_parameters(n_trials = 2)
elapsed = datetime.now() - start
print(f"Total tuning time: {elapsed}\n", flush = True)

# Report the best trial found by the study and its hyper-parameters.
best_trial = study.best_trial
best_params = best_trial.params

print("BEST:", best_trial.value)
print("Params:")
for name, val in best_params.items():
  print(f"    {name}: {val}")

##Training

In [None]:
# Rebuild the model with the best hyper-parameters found by the study.
args = {"num_hidden_layers": best_params["num_hidden_layers"],
        "num_attention_heads": best_params["num_attention_heads"],
        "hidden_dropout_prob": best_params["dropout"],
        # ASTConfig's attention-dropout field is attention_probs_dropout_prob;
        # an unrecognised key would not be applied to the config.
        "attention_probs_dropout_prob": best_params["dropout"],
        "max_length": MAX_LENGTH}

model = get_model(checkpoint, num_class, args).to(device)
# weight_decay is not tuned (objective fixes it at 0), so it is left at Adam's default.
optimizer = optim.Adam(model.parameters(), lr = best_params["learning_rate"])
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=best_params["scheduler_factor"], patience=best_params["scheduler_patience"], verbose=True)

print("=" * 20, "MODEL CONFIG", "=" * 20, flush = True)
print(model)

train_model(epochs = 25, early_stop = 5, scheduler = scheduler)

# Saves the whole module object (not just its state_dict).
torch.save(model, filePath_model)
torch.cuda.empty_cache()