#Setup

In [None]:
import pickle
import numpy as np
import pandas as pd

import os
import time
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from datetime import datetime

from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sn
from tabulate import tabulate

import torch
from torch import nn
from torchvision import models, transforms
from torchsummary import summary

# from tqdm.notebook import tqdm
from tqdm import tqdm

from sklearn.preprocessing import LabelEncoder

In [None]:
from torch import cuda
#Run on the GPU when CUDA is available, otherwise fall back to the CPU.
#`device` is read by the model-loading, tuning, and training cells below.
device = 'cuda' if cuda.is_available() else 'cpu'
print(device)

#Set Variables

In [None]:
#Teacher checkpoints
#Select at least 1:
#- BERT: "bert-base-uncased"
#- RoBERTa: "roberta-base"
#- MentalBERT: "mental/mental-bert-base-uncased"
#- MMEMOG: "custom/MM-EMOG-SenticNet"
#- AST: "MIT/ast-finetuned-audioset-10-10-0.4593"
teacher_checkpoints = ["roberta-base", "custom/MM-EMOG-SenticNet", "MIT/ast-finetuned-audioset-10-10-0.4593"]

#Teacher model paths (path to each saved pytorch teacher model).
#The teacher-loading cell treats a path ending in ".pt" as a pickled torch model
#and any other path as a pretrained model location. Fill these in before running.
model_paths = {
  "bert-base-uncased": "",
  "roberta-base": "",
  "mental/mental-bert-base-uncased": "",
  "custom/MM-EMOG-SenticNet": "",
  "MIT/ast-finetuned-audioset-10-10-0.4593": ""
}
#Paths aligned index-by-index with teacher_checkpoints
teacher_model_paths = [model_paths[x] for x in teacher_checkpoints]

#Student model type ("Transformer" is the only value the student code paths accept)
student_checkpoint = "Transformer"

#Session Title. Used as the prefix of saved output file names.
sessionTitle = "Multi-modal Multi-teacher Knowledge Distillation"

#Load data

In [None]:
#Data Preparation
#Fill in the three splits below. Within each split, sentences, labels, and audio
#must be aligned index-by-index (enforced by the asserts further down).
original_train_sentences = []
original_val_sentences = []
original_test_sentences = []

original_train_labels = []
original_val_labels = []
original_test_labels = []

#NOTE: Load audio using librosa
# librosa.load(wavPath, sr = 16000) #AST required sample rate
original_train_audio = []
original_val_audio = []
original_test_audio = []

#Every sample needs a sentence, a label, and an audio clip
assert len(original_train_sentences) == len(original_train_labels) == len(original_train_audio)
assert len(original_val_sentences) == len(original_val_labels) == len(original_val_audio)
assert len(original_test_sentences) == len(original_test_labels) == len(original_test_audio)

train_size = len(original_train_sentences)
val_size = len(original_val_sentences)
test_size = len(original_test_sentences)

#Index ranges of each split inside the concatenated all_* containers (train, then val, then test)
train_idx = np.arange(0, train_size)
val_idx = np.arange(train_size, train_size + val_size)
test_idx = np.arange(train_size + val_size, train_size + val_size + test_size)

#Sentences and labels become numpy arrays so they can be indexed with index arrays;
#audio stays a plain Python list
all_sentences = np.array(original_train_sentences + original_val_sentences + original_test_sentences)
all_labels = np.array(original_train_labels + original_val_labels + original_test_labels)
all_audio = original_train_audio + original_val_audio + original_test_audio

#Label Encoding
#Classes are derived from the training labels only.
#NOTE(review): a val/test label that never occurs in training presumably makes
#the transform of all_labels below fail - confirm the splits share one label set.
unique_labels = np.unique(original_train_labels)
num_class = len(unique_labels)

lEnc = LabelEncoder()
lEnc.fit(unique_labels)

print(unique_labels)
print(lEnc.transform(unique_labels))

#Integer class ids for every sample, in the same order as all_sentences
all_targets = lEnc.transform(all_labels)

#Load Resources

In [None]:
#Load list of emoticons
#Source: https://c.r74n.com/faces

#Read explicitly as UTF-8: the list contains non-ASCII faces and the platform
#default encoding is not guaranteed to decode them.
with open("TextEmoticonList.txt", "r", encoding = "utf-8") as file:
  rawEmoticons = file.read().split("\n")

#Keep only emoticons without spaces in-between and longer than one character
#(the length check also drops empty lines)
emoticonList = [emoticon for emoticon in rawEmoticons if " " not in emoticon and len(emoticon) > 1]

print(len(emoticonList))
print(emoticonList[:10])

In [None]:
#Load list of emojis
#Source: https://www.airtable.com/universe/exphjm5ifnV0bX4Kb/emojis-database?explore=true

emojiList = pd.read_csv("Emojis-Grid view.csv")
#Drop the rows whose "Emoji" cell is the literal string "C", then keep that column as a list
emojiList = emojiList.loc[emojiList["Emoji"] != "C", "Emoji"].tolist()

#Unicode versions: the escaped ASCII spelling of every emoji
emojiList_uni = [symbol.encode('unicode-escape').decode('ASCII') for symbol in emojiList]

#Quick sanity check of what was loaded
print(len(emojiList))
print(emojiList[:10])
print(emojiList_uni[:10])

# Preprocess

##Text

In [None]:
#FLAGS (read by preprocess_str and get_tokenizer)
DEIDENTIFY = True     #Replace urls, emails, and usernames with reserved tokens
EMOPRESERVE = True    #Identify emojis/emoticons in text and skip text cleaning on them
TEXTCLEAN = False     #Minimal cleaning: drop unusual characters, separate contractions and major punctuation
TOKEN_TYPE = "wp"     #wp: word piece (BERT Tokenizer); ws: word split

In [None]:
import re

#Placeholder tokens substituted for identifying text when DEIDENTIFY is set
tokenURL = "_URL_"
tokenEmail = "_EMAIL_"
tokenUsername = "_USER_"
#Collected so they can be shielded from text cleaning and added to tokenizer vocabularies
reserveTokens = [tokenURL, tokenEmail, tokenUsername]

def preprocess_str(string):
  """Clean one raw text sample before tokenization.

  Steps, each controlled by a module-level flag:
  - DEIDENTIFY: replace urls, emails, and usernames with reserved tokens.
  - EMOPRESERVE: split out known emoticons, emojis, and reserved tokens so
    that the remaining steps leave them untouched.
  - TEXTCLEAN: remove non-major punctuation and separate contractions and
    major punctuation from words with whitespace.
  - Always: lowercase and strip every segment that is not preserved.

  Args:
    string: raw text.

  Returns:
    The processed segments joined with single spaces.
  """

  #Preclean
  if DEIDENTIFY:
    string = re.sub(r"https?://[^\s]+", tokenURL, string)              #Links
    string = re.sub(r"[\w.+-]+@[\w-]+\.[\w.-]+", tokenEmail, string)   #Email
    string = re.sub(r"@[a-zA-Z0-9_]{2,}", tokenUsername, string)       #Usernames

  #Emoticon/Emoji split
  tokens = [string]
  preserved = set()   #Segments that must skip cleaning and lowercasing
  if EMOPRESERVE:
    allEmo = emoticonList + emojiList + emojiList_uni + reserveTokens
    preserved = set(allEmo)   #Set membership instead of scanning the list per token
    for emoticon in allEmo:
      if emoticon not in string:
        continue

      #Alphabetic emoticons only count when delimited by whitespace or the string
      #boundary, so that they do not split ordinary words containing them.
      #The delimiters are non-capturing: re.split returns every capturing group,
      #and the previous nested groups produced whitespace-only fragments.
      if emoticon.isalpha():
        regEx = r"(?:^|\s)" + re.escape(emoticon) + r"(?:\s|$)"
      else:
        regEx = re.escape(emoticon)

      splits = [re.split("(" + regEx + ")", segment) for segment in tokens]
      #Drop fragments that are empty once stripped; they used to become empty tokens
      tokens = [piece.strip() for parts in splits for piece in parts if piece.strip() != ""]

  for idx in range(len(tokens)):
    if tokens[idx] in preserved: #Skip emoticons, emojis, reserved tokens
      continue

    if TEXTCLEAN:
      tokens[idx] = re.sub(r"[^A-Za-z0-9(),!?\.\'\`]", " ", tokens[idx])
      tokens[idx] = re.sub(r"\'s", " \'s", tokens[idx])
      tokens[idx] = re.sub(r"\'ve", " \'ve", tokens[idx])
      tokens[idx] = re.sub(r"n\'t", " n\'t", tokens[idx])
      tokens[idx] = re.sub(r"\'re", " \'re", tokens[idx])
      tokens[idx] = re.sub(r"\'d", " \'d", tokens[idx])
      tokens[idx] = re.sub(r"\'ll", " \'ll", tokens[idx])
      tokens[idx] = re.sub(r",", " , ", tokens[idx])
      tokens[idx] = re.sub(r"!", " ! ", tokens[idx])
      tokens[idx] = re.sub(r"\(", " ( ", tokens[idx])
      tokens[idx] = re.sub(r"\)", " ) ", tokens[idx])
      tokens[idx] = re.sub(r"\?", " ? ", tokens[idx])
      tokens[idx] = re.sub(r"\.", " . ", tokens[idx])
      tokens[idx] = re.sub(r"\s{2,}", " ", tokens[idx])

    #Lower case and strip by default
    tokens[idx] = tokens[idx].lower().strip()

  return " ".join(tokens)

##Tokenizer

In [None]:
def get_tokenizer(token_type, checkpoint = None):
  """Return the tokenizer for the requested token type.

  Args:
    token_type: "wp" for a word-piece tokenizer loaded through transformers,
      "ws" for plain whitespace splitting. Case-insensitive.
    checkpoint: only used for "wp". None, "bert-base-uncased", and
      "custom/MM-EMOG-SenticNet" load the "bert-base-uncased" BertTokenizer;
      any other value is passed to AutoTokenizer.from_pretrained.

  Returns:
    For "wp": the tokenizer, extended with the reserved tokens when DEIDENTIFY
    is set and with the emoticon/emoji lists when EMOPRESERVE is set.
    For "ws": a callable that maps a string to its whitespace-separated words.

  Raises:
    ValueError: if token_type is neither "wp" nor "ws".
  """
  kind = token_type.lower()

  if kind == "wp":
    if checkpoint in [None, "bert-base-uncased", "custom/MM-EMOG-SenticNet"]:
      from transformers import BertTokenizer
      tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    else:
      from transformers import AutoTokenizer
      tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    if DEIDENTIFY:
      tokenizer.add_tokens(reserveTokens)

    if EMOPRESERVE:
      #Add spaces to alpha emotions to avoid splitting words that commonly has them (ie "omo" in "tomorrow")
      temp = [" %s " % x if x.isalpha() else x for x in emoticonList]
      tokenizer.add_tokens(temp + emojiList + emojiList_uni)

    return tokenizer
  elif kind == "ws":
    #This branch used to evaluate `string.split()` on an undefined name and
    #raised NameError; return the whitespace splitter itself instead.
    return str.split
  else:
    raise ValueError("Unknown value for TOKEN_TYPE")

##Audio: AST

In [None]:
def get_norm_stats(checkpoint, max_ast_length, sampling_rate, audio = None):
  """Compute the mean and std of the un-normalized AST input features.

  Args:
    checkpoint: name passed to ASTFeatureExtractor.from_pretrained.
    max_ast_length: max_length given to the feature extractor.
    sampling_rate: sampling rate handed to the feature extractor.
    audio: waveforms to compute the statistics over. Defaults to the
      module-level `all_audio`, which was the previous hard-coded behaviour.

  Returns:
    Tuple (mean, std) of 0-d tensors taken over every value of "input_values".
  """
  if audio is None:
    audio = all_audio

  #do_normalize is off so the statistics describe the raw features
  extractor = ASTFeatureExtractor.from_pretrained(checkpoint, max_length = max_ast_length, do_normalize = False)
  features = extractor(audio, sampling_rate = sampling_rate, return_tensors = "pt")["input_values"]

  return torch.mean(features), torch.std(features)

# Models

In [None]:
from transformers import AutoConfig, AutoModelForSequenceClassification
from transformers import AutoFeatureExtractor, ASTForAudioClassification, ASTFeatureExtractor
from transformers import logging
logging.set_verbosity_error()

def get_plm(checkpoint, num_class, args = None):
  """Build a teacher model for the given checkpoint name.

  Args:
    checkpoint: one of the supported text checkpoints ("bert-base-uncased",
      "roberta-base", "mental/mental-bert-base-uncased"), the AST checkpoint
      "MIT/ast-finetuned-audioset-10-10-0.4593", or a name starting with "custom/".
    num_class: number of output classes.
    args: optional dict. For text checkpoints it holds extra config overrides,
      for AST it is forwarded to the AST wrapper, and for custom checkpoints it
      holds the MLP keyword arguments (must include "pt_weights" and
      "pt_weights_dim").

  Returns:
    The constructed model.

  Raises:
    ValueError: if a custom checkpoint is requested without args.
    Exception: if the checkpoint is not recognised.
  """

  if checkpoint in ["bert-base-uncased", "roberta-base", "mental/mental-bert-base-uncased"]:
    #Load the config once (it used to be loaded twice in a row)
    overrides = args if args is not None else {}
    config = AutoConfig.from_pretrained(checkpoint, num_labels = num_class, **overrides)

    return AutoModelForSequenceClassification.from_pretrained(checkpoint, config = config)
  elif checkpoint == "MIT/ast-finetuned-audioset-10-10-0.4593":
    return AST(num_class, args = args)
  elif checkpoint.split("/")[0] == "custom":
    #Without this check a missing args dict surfaced as an opaque TypeError
    if args is None:
      raise ValueError("Custom checkpoints require args with 'pt_weights' and 'pt_weights_dim'")
    assert "pt_weights" in args
    assert "pt_weights_dim" in args

    return MLP(num_class = num_class, **args)
  else:
    raise Exception("Unknown checkpoint")

In [None]:
def get_student_model(st_checkpoint, params = None):
  """Create the student classifier.

  Args:
    st_checkpoint: student type; only "Transformer" is handled.
    params: optional dict of config overrides. When given it must contain
      "num_hidden_layers", "num_attention_heads", "hidden_dropout_prob", and
      "hidden_act".

  Returns:
    A "bert-base-uncased" sequence-classification model with num_class labels
    whose token embeddings are resized to the student tokenizer's length.

  Raises:
    Exception: for any other st_checkpoint.
  """
  #Guard clause: reject unsupported student types up front
  if st_checkpoint != "Transformer":
    raise Exception("Unknown student checkpoint.")

  overrides = {}
  if params is not None:
    for required_key in ("num_hidden_layers", "num_attention_heads", "hidden_dropout_prob", "hidden_act"):
      assert required_key in params
    overrides = params

  config = AutoConfig.from_pretrained("bert-base-uncased", num_labels = num_class, **overrides)
  student = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", config = config)
  #The student tokenizer carries added tokens, so the embedding matrix must match its size
  student.resize_token_embeddings(len(student_tokenizer))
  return student

##AST

In [None]:
#Wraps AST into a model class
class AST(torch.nn.Module):
  """Thin wrapper around ASTForAudioClassification.

  Args:
    num_class: number of output classes for the classification head.
    args: optional dict of extra config overrides.
    checkpoint: pretrained AST checkpoint to load. The constructor used to read
      an undefined global `checkpoint`; it is now an explicit keyword argument
      that defaults to the AST checkpoint used throughout this file.
  """
  def __init__(self, num_class, args = None, checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"):
    super(AST, self).__init__()

    #Load config once (it used to be loaded twice in a row)
    overrides = args if args is not None else {}
    config = AutoConfig.from_pretrained(checkpoint, num_labels = num_class, **overrides)

    #Load pretrained model; ignore_mismatched_sizes lets the head be sized for num_class
    self.ast = ASTForAudioClassification.from_pretrained(checkpoint, ignore_mismatched_sizes = True, config = config)

  def forward(self, input_values, labels = None):
    """Forward the inputs to the wrapped model and return its output unchanged."""
    return self.ast(input_values = input_values, labels = labels)

##MLP

In [None]:
class MLP(torch.nn.Module):
  """Feed-forward classifier over flattened pretrained token embeddings.

  Args:
    pt_weights: embedding matrix handed to nn.Embedding.from_pretrained.
    pt_weights_dim: width of one embedding vector.
    num_class: number of output classes.
    num_layers: number of linear layers; values below 2 build a single layer.
    hidden_dim: width of the hidden layers.
    dropout: dropout probability used after every layer.
    max_length: number of tokens per sample (the input is flattened to
      max_length * pt_weights_dim features).
  """
  def __init__(self, pt_weights, pt_weights_dim, num_class, num_layers, hidden_dim, dropout, max_length):
    super().__init__()

    self.num_layers = num_layers
    flat_dim = max_length * pt_weights_dim

    #Load Embeddings
    self.embeddings = nn.Embedding.from_pretrained(pt_weights)

    #Single-layer variant: flattened embeddings map straight to the class scores
    if self.num_layers < 2:
      self.l1 = nn.Linear(flat_dim, num_class)
      self.r1 = nn.ReLU()
      self.d1 = nn.Dropout(dropout)
      return

    #Input layer
    self.l1 = nn.Linear(flat_dim, hidden_dim)
    self.r1 = nn.ReLU()
    self.d1 = nn.Dropout(dropout)

    #Hidden stack: one (Linear, ReLU, Dropout) triple per extra layer
    hidden_stack = []
    for _ in range(num_layers - 2):
      hidden_stack += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]
    self.mod_list = nn.ModuleList(hidden_stack)

    #Output layer (followed by ReLU and Dropout in forward)
    self.lf = nn.Linear(hidden_dim, num_class)
    self.rf = nn.ReLU()
    self.df = nn.Dropout(dropout)

  def forward(self, input_ids):
    """Return {"logits": scores} for a batch of token ids."""
    #Generate embeddings, then flatten everything after the batch dimension
    embedded = self.embeddings(input_ids)
    x = embedded.view(embedded.shape[0], -1)

    #First layer is shared by both variants
    x = self.d1(self.r1(self.l1(x)))

    if self.num_layers >= 2:
      #NOTE(review): mod_list holds 3 * (num_layers - 2) modules but only its first
      #num_layers - 2 entries are applied here. Behaviour is kept as-is because
      #saved teacher models depend on this class; confirm before changing.
      for position in range(self.num_layers - 2):
        x = self.mod_list[position](x)

      x = self.df(self.rf(self.lf(x)))

    return {"logits": x}

# Training Functions

## Initialize model

In [None]:
class MTKDDataset(torch.utils.data.Dataset):
    """Dataset that bundles model inputs with their label and sample id.

    Args:
        inputs: either a mapping of field name to an indexable of per-sample
            values, or a single indexable of per-sample inputs.
        labels: indexable of per-sample labels; its length is the dataset length.
        ids: indexable of per-sample identifiers.
    """
    def __init__(self, inputs, labels, ids):
        self.inputs = inputs
        self.labels = labels
        self.ids = ids

    def __getitem__(self, idx):
        """Return one sample as a dict with the input fields plus "ids" and "labels"."""
        #isinstance (rather than an exact type comparison) also accepts dict subclasses
        if isinstance(self.inputs, dict):
            item = {key: val[idx] for key, val in self.inputs.items()}
        else:
            item = {"input": self.inputs[idx]}

        item["ids"] = self.ids[idx]
        item['labels'] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)

In [None]:
import os
#NOTE(review): presumably a workaround that makes HTTP downloads skip SSL certificate
#verification - confirm it is still needed, since it weakens transport security
#for every download made by this process.
os.environ['CURL_CA_BUNDLE'] = ''

In [None]:
#Gather all inputs for all models
from transformers import AutoTokenizer

#Tokenizer for the student; also read by get_student_model to size the embedding matrix
student_tokenizer = get_tokenizer("wp", "bert-base-uncased")

def build_inputs(sentences, audio, max_lengths = None, return_max_len = False):
  """Build the tensors every teacher and the student need for a set of samples.

  Args:
    sentences: raw texts; each one is cleaned with preprocess_str.
    audio: waveforms aligned with `sentences` (only used by the AST teacher).
    max_lengths: optional dict mapping a teacher checkpoint (or "student") to
      the padded sequence length to use. Missing entries are computed from
      `sentences`, capped at MAX_LENGTH, and stored in the dict. Pass the dict
      returned for the training split to pad val/test identically.
    return_max_len: when True, also return the max_lengths dict.

  Returns:
    Dict of tensors keyed "<checkpoint>|<field>" for teachers and
    "student|<field>" for the student; plus max_lengths when requested.

  Raises:
    Exception: on an unhandled teacher or student checkpoint.
  """
  #The default used to be a shared mutable dict that this function writes to,
  #so lengths computed in one call silently leaked into later calls.
  if max_lengths is None:
    max_lengths = {}

  inputs = {}
  clean_sentences = [preprocess_str(x) for x in sentences]

  #Teacher inputs
  for tcheck in teacher_checkpoints:
    print("Building inputs for", tcheck)
    if tcheck not in ["bert-base-uncased", "roberta-base", "mental/mental-bert-base-uncased", "custom/MM-EMOG-SenticNet", "MIT/ast-finetuned-audioset-10-10-0.4593"]:
      raise Exception("Unhandled checkpoint")

    is_ast = tcheck == "MIT/ast-finetuned-audioset-10-10-0.4593"
    if is_ast:
      MAX_AST = 512
      mean, std = get_norm_stats(tcheck, MAX_AST, SAMPLING_RATE)

      #Extract features, normalized with the statistics computed above
      feature_extractor = ASTFeatureExtractor.from_pretrained(tcheck, max_length = MAX_AST, do_normalize = True, mean = float(mean), std = float(std))
      model_input = feature_extractor(audio, sampling_rate=SAMPLING_RATE, return_tensors="pt")
      print("AST -> Check mean 0, std 0.5:", torch.mean(model_input["input_values"]), torch.std(model_input["input_values"]))

    else: #Standard PLM inputs
      targs = {}
      ttokenizer = get_tokenizer("wp", tcheck)
      if tcheck == "custom/MM-EMOG-SenticNet":
        #This teacher's forward only takes input_ids
        targs["return_token_type_ids"] = False
        targs["return_attention_mask"] = False

      if tcheck not in max_lengths:
        actual_max = min(MAX_LENGTH, max([len(ttokenizer.tokenize(x)) for x in clean_sentences]))
        max_lengths[tcheck] = actual_max
      else:
        actual_max = max_lengths[tcheck]
      model_input = ttokenizer(clean_sentences, padding = "max_length", truncation = True, max_length = actual_max, **targs)

    #Audio features are floats, token fields are integer ids
    dtype = torch.FloatTensor if is_ast else torch.LongTensor
    for k, v in model_input.items():
      inputs[tcheck + "|" + k] = dtype(v)

  #Build student inputs
  #BERT model with BERT inputs
  if student_checkpoint == "Transformer":
    if "student" not in max_lengths:
      actual_max = min(MAX_LENGTH, max([len(student_tokenizer.tokenize(x)) for x in clean_sentences]))
      max_lengths["student"] = actual_max
    else:
      actual_max = max_lengths["student"]
    student_input = student_tokenizer(clean_sentences, padding = "max_length", truncation = True, max_length = actual_max)

  else:
    raise Exception("Unknown student checkpoint.")

  for k, v in student_input.items():
    inputs["student|" + k] = torch.LongTensor(v)

  if return_max_len:
    return inputs, max_lengths
  return inputs

In [None]:
def prepare_datasets(sentences, targets, audios, train_percent = 0.9, cv = False, train_idx = None, test_idx = None, show_result = False):
  """Split the data, build all model inputs, and wrap them in data loaders.

  Split selection:
  - cv=True: `train_idx` is split randomly into train/val by `train_percent`;
    `test_idx` (possibly empty) is used as the test split.
  - no validation set but a test set: the first train_size samples are split
    randomly into train/val; everything after them is the test split.
  - validation set available: the predefined train/val/test ranges are used.

  Args:
    sentences: numpy array of all texts.
    targets: numpy array of all encoded labels.
    audios: list of all waveforms.
    train_percent: share of the training indices kept for training in the
      random-split cases.
    cv: force a random train/val split of `train_idx`.
    train_idx: indices to split when cv is True (required in that case).
    test_idx: test indices when cv is True; defaults to no test split.
    show_result: unused; kept for interface compatibility.

  Returns:
    (train_loader, val_loader) when the test split is empty, otherwise
    (train_loader, val_loader, test_loader).

  Raises:
    Exception: if no split rule applies.
  """
  #None defaults replace the previous mutable list defaults
  train_idx = [] if train_idx is None else train_idx
  test_idx = [] if test_idx is None else test_idx

  #Generate splits (fixed seed so the random split is reproducible)
  np.random.seed(123)

  #Cross validation / Force split train set to train & val
  if cv:
    assert len(train_idx) > 0
    #assert len(test_idx) > 0 #If test_idx = none -> split only train/val, test with all data

    idx_train = np.random.choice(train_idx, int(len(train_idx) * train_percent), replace = False)
    #Set lookup instead of scanning the index array once per candidate
    chosen = set(idx_train.tolist())
    idx_val = [x for x in train_idx if x not in chosen]
    idx_test = test_idx

  #Train and test only
  elif (val_size == 0) and (test_size != 0):
    idx_train = np.random.choice(np.arange(train_size), int(train_size * train_percent), replace = False)
    chosen = set(idx_train.tolist())
    idx_val = [x for x in np.arange(train_size) if x not in chosen]
    idx_test = np.arange(train_size, len(all_sentences))

  #Train, val, and test
  elif val_size != 0:
    idx_train = np.arange(0, train_size)
    idx_val = np.arange(train_size, train_size + val_size)
    idx_test = np.arange(train_size + val_size, len(all_sentences))
  else:
    raise Exception("Unknown split.")

  print("Data Loader split:")
  print("  - Train:", len(idx_train))
  print("  - Val:", len(idx_val))
  print("  - Test:", len(idx_test))

  #Generate inputs
  print("Training inputs")
  train_inputs, max_lengths = build_inputs(sentences[idx_train], [audios[i] for i in idx_train], return_max_len = True) #Extract max_lengths to implement on val and test sets
  print("Validation inputs")
  val_inputs = build_inputs(sentences[idx_val], [audios[i] for i in idx_val], max_lengths = max_lengths)
  print("Test inputs")
  test_inputs = build_inputs(sentences[idx_test], [audios[i] for i in idx_test], max_lengths = max_lengths) if len(idx_test) > 0 else []

  #Prepare loaders
  train_targets = torch.LongTensor(targets[idx_train])
  val_targets = torch.LongTensor(targets[idx_val])
  test_targets = torch.LongTensor(targets[idx_test])

  #Each dataset item carries the teacher and student inputs of one sample together,
  #so shuffling keeps them aligned. Train and val are shuffled; test keeps its order.
  train_loader = torch.utils.data.DataLoader(MTKDDataset(train_inputs, train_targets, idx_train), shuffle=True, batch_size = BATCH_SIZE)
  val_loader = torch.utils.data.DataLoader(MTKDDataset(val_inputs, val_targets, idx_val), shuffle = True, batch_size = BATCH_SIZE)
  test_loader = torch.utils.data.DataLoader(MTKDDataset(test_inputs, test_targets, idx_test), shuffle = False, batch_size = BATCH_SIZE)

  if len(test_targets) == 0:
    return train_loader, val_loader
  return train_loader, val_loader, test_loader


##Train

In [None]:
import time
import torch.optim as optim
import torch.nn.functional as F

In [None]:
#Hard-label loss on the student logits
criterion_ce = torch.nn.CrossEntropyLoss()
#Soft-label loss against the averaged teacher probabilities.
#reduction = "batchmean" sums over classes and averages over the batch, which is
#the mathematical KL divergence; the default "mean" also divides by the number of
#classes and triggers a PyTorch warning.
criterion_kl = torch.nn.KLDivLoss(reduction = "batchmean")
#Relative weights of the two loss terms
weight_ce = 1.0
weight_kl = 1.0

def _teacher_mean_probs(batch):
  """Return the class probabilities of all teachers averaged per sample."""
  tm_scores = []
  for check in teacher_checkpoints:
    #Select this teacher's fields; the "|" is part of the prefix so that one
    #checkpoint name can never match the keys of another
    prefix = check + "|"
    input_args = {k[len(prefix):]: v.to(device) for k, v in batch.items() if k.startswith(prefix)}
    tm = teacher_models[check]
    tm.eval()

    with torch.no_grad():
      tm_scores.append(F.softmax(tm(**input_args)["logits"], dim = -1))

  return torch.mean(torch.stack(tm_scores, dim = 1), dim = 1)

def _student_logits(st_model, batch):
  """Run the student on its own fields of the batch and return the logits."""
  prefix = "student|"
  input_args = {k[len(prefix):]: v.to(device) for k, v in batch.items() if k.startswith(prefix)}
  return st_model(**input_args)["logits"]

def _distillation_loss(st_output, targets, mean_scores):
  """Weighted sum of cross entropy on labels and KL divergence to the teachers."""
  return weight_ce * criterion_ce(st_output, targets) + weight_kl * criterion_kl(torch.log_softmax(st_output, dim=1), mean_scores)

#Training student with multi-teacher
def train_student(st_model, optimizer, epoch, show_result = True):
  """Train the student for one epoch on train_loader, then score it on val_loader.

  Args:
    st_model: student model returning a mapping with "logits".
    optimizer: optimizer over the student's parameters.
    epoch: zero-based epoch index (only used for the progress line).
    show_result: print a one-line summary when True.

  Returns:
    Tuple (train loss, val loss, train accuracy, val accuracy, train weighted F1,
    val weighted F1). Every value is the mean of the per-batch values, so the
    F1 figures are batch averages rather than F1 over the whole split.
  """
  loss_batch_train = []
  acc_batch_train = []
  f1_batch_train = []
  loss_batch_val = []
  acc_batch_val = []
  f1_batch_val = []

  t = time.time()

  #TRAINING
  st_model.train()
  for batch in train_loader:
    targets = batch["labels"].to(device)

    mean_scores = _teacher_mean_probs(batch)
    st_output = _student_logits(st_model, batch)
    loss = _distillation_loss(st_output, targets, mean_scores)

    optimizer.zero_grad()
    #The graph is used for a single backward pass, so it does not need to be retained
    loss.backward()
    optimizer.step()

    st_output = st_output.detach().float()
    train_acc = accuracy(st_output, targets)[0]
    train_f1 = f1_score(targets.cpu(), np.argmax(st_output.cpu().numpy(), axis = -1), average = "weighted")

    loss_batch_train.append(loss.item())
    acc_batch_train.append(train_acc.item())
    f1_batch_train.append(train_f1)

  #VALIDATION
  st_model.eval()
  with torch.no_grad():
    for batch in val_loader:
      targets = batch["labels"].to(device)

      mean_scores = _teacher_mean_probs(batch)
      st_output = _student_logits(st_model, batch)
      loss = _distillation_loss(st_output, targets, mean_scores)

      val_acc = accuracy(st_output, targets)[0]
      val_f1 = f1_score(targets.cpu(), np.argmax(st_output.cpu().numpy(), axis = -1), average = "weighted")

      loss_batch_val.append(loss.item())
      acc_batch_val.append(val_acc.item())
      f1_batch_val.append(val_f1)

  if show_result:
    print(  'Epoch: {:04d}'.format(epoch+1),
            'loss_train: {:.4f}'.format(np.mean(loss_batch_train)),
            'acc_train: {:.4f}'.format(np.mean(acc_batch_train)),
            'f1w_train: {:.4f}'.format(np.mean(f1_batch_train)),
            'loss_val: {:.4f}'.format(np.mean(loss_batch_val)),
            'acc_val: {:.4f}'.format(np.mean(acc_batch_val)),
            'f1w_val: {:.4f}'.format(np.mean(f1_batch_val)),
            'time: {:.4f}s'.format(time.time() - t))

  #return losses for early stopping
  return np.mean(loss_batch_train), np.mean(loss_batch_val), np.mean(acc_batch_train), np.mean(acc_batch_val), np.mean(f1_batch_train), np.mean(f1_batch_val)

##Test

In [None]:
def test_model():
    """Run the module-level `model` over `test_loader` and collect its outputs.

    NOTE(review): this expects batches that unpack into (inputs, targets) and a
    global named `model`. The loaders built by prepare_datasets wrap MTKDDataset,
    whose items are dicts, so this helper looks like a leftover from an earlier
    pipeline - confirm before relying on it.

    Returns:
        Tuple (gold, output) of numpy arrays with the targets and model outputs.
    """
    model.eval()

    output = []
    gold = []
    #Inference only: skip building the autograd graph
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            out = model(x_batch)
            output.extend(out.cpu().numpy())
            gold.extend(y_batch.cpu().numpy())

    return np.array(gold), np.array(output)

## Evaluation

In [None]:
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each k in `topk`.

    Args:
        output: tensor of scores with one row per sample and one column per class.
        target: tensor of class indices, one per sample.
        topk: tuple of k values to evaluate.

    Returns:
        List with one 0-d tensor per k: the percentage of samples whose target
        is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    #Indices of the maxk best classes per sample, transposed to (maxk, batch)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        #reshape instead of view: the slice of the transposed tensor is not
        #guaranteed to be contiguous, and view raises on it for k > 1
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

In [None]:
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, hamming_loss, roc_auc_score
from scipy.special import softmax

def evaluate_output(outputs, targets, targetLabels, title = "", show_results = True, showClassMatrix = False, multiLabel = False, return_results = False):
    """Score model outputs against targets and optionally print, plot, or return the metrics.

    Args:
        outputs: per-sample scores (one column per class). For single-label data
            they may be logits or probabilities; logits are passed through softmax
            before the ROC AUC computation. For multi-label data they are
            thresholded at 0.5.
        targets: true labels (class ids, or an indicator matrix when multiLabel).
        targetLabels: class names, ordered like the output columns.
        title: heading for the printed report and the confusion-matrix plot.
        show_results: print the metric report.
        showClassMatrix: plot the confusion matrix (one per label when multiLabel).
        multiLabel: treat the task as multi-label.
        return_results: return the metrics as a dict.

    Returns:
        A dict of metrics when return_results is True, otherwise None.
    """
    outputs = np.asarray(outputs)

    if multiLabel:
        preds = outputs >= 0.5
        #probs used to be undefined on this path, which crashed the ROC AUC call
        probs = outputs
    else:
        preds = np.argmax(outputs, axis = 1)
        #Treat outputs as probabilities only if every row is non-negative and sums to 1
        #(checked with a tolerance on all rows instead of exact equality on the first)
        looks_like_probs = np.all(outputs >= 0) and np.allclose(np.sum(outputs, axis = -1), 1)
        probs = outputs if looks_like_probs else softmax(outputs, axis = -1)

    #Named accuracy_value so the module-level accuracy() helper is not shadowed
    accuracy_value = accuracy_score(targets, preds)
    f1_score_micro = f1_score(targets, preds, average='micro')
    f1_score_macro = f1_score(targets, preds, average='macro')
    f1_score_weighted = f1_score(targets, preds, average="weighted")

    #Per-class ROC AUC for multi-label and multi-class data, single score for binary
    per_class_roc = multiLabel or len(targetLabels) > 2
    if multiLabel:
        roc_scores = roc_auc_score(targets, probs, average = None)
        roc_type = "Average"
    elif len(targetLabels) > 2:
        roc_scores = roc_auc_score(targets, probs, average = None, multi_class = "ovr")
        roc_type = "Average"
    else:
        roc_scores = roc_auc_score(targets, probs[:,1])
        roc_type = "Positive Class (%s)" % targetLabels[1]

    if show_results:
        print()
        print("=" * 50)
        print(title)
        print("=" * 50)
        print("Accuracy Score: %.4f" % (accuracy_value))
        print("F1 Score (Micro): %.4f" % (f1_score_micro))
        print("F1 Score (Macro): %.4f" % (f1_score_macro))
        print("F1 Score (Weighted): %.4f" % (f1_score_weighted))
        print("ROC AUC (%s): %.4f" % (roc_type, np.mean(roc_scores)))
        print()

        if multiLabel:
            ham_loss = hamming_loss(targets, preds)
            print("Hamming Loss: %.4f" % (ham_loss))

        print(targetLabels)
        print(classification_report(targets, preds, target_names = targetLabels, digits = 4))
        if per_class_roc:
            print("\nROCAUC")
            print(tabulate([roc_scores], headers = targetLabels, floatfmt=".4f"))

    if showClassMatrix:
        if multiLabel:
            #Imported here because the file-level sklearn imports do not include it
            from sklearn.metrics import multilabel_confusion_matrix
            cmArray = multilabel_confusion_matrix(targets, preds)
            for i in range(len(targetLabels)):
                plotConfusionMatrix(cmArray[i], [0,1], targetLabels[i])
        else:
            cmArray = confusion_matrix(targets, preds)
            plotConfusionMatrix(cmArray, np.unique(targetLabels), title)

    if return_results:
        results = {"Accuracy": accuracy_value,
                   "F1_Micro": f1_score_micro,
                   "F1_Macro": f1_score_macro,
                   "F1_Weighted": f1_score_weighted,
                   "Class Precision": precision_score(targets, preds, average = None),
                   "Class Recall": recall_score(targets, preds, average = None),
                   "Class F1": f1_score(targets, preds, average = None),
        }

        if per_class_roc:
            results["Class ROCAUC"] = roc_scores
        else:
            results["ROC_Class1:"] = roc_scores

        return results


In [None]:
def save_predictions(ids, targets, outputs, extra_title = ""):
  """Write per-sample texts, labels, predictions, and class scores to a CSV file.

  Args:
    ids: indices of the samples inside all_sentences.
    targets: encoded true labels, aligned with ids.
    outputs: 2-d array of per-class scores, aligned with ids.
    extra_title: suffix appended to the session title in the file name.
  """
  predicted = np.argmax(outputs, axis = -1)
  columns = {"Texts": all_sentences[ids],
             "Labels": lEnc.inverse_transform(targets),
             "Preds": lEnc.inverse_transform(predicted)}
  saveDF = pd.DataFrame(columns, index = ids)

  #One extra column per class holding that class's score
  for class_pos, class_name in enumerate(lEnc.classes_):
    saveDF[class_name] = outputs[:, class_pos]

  #Rows are ordered by sample id; the id itself is not written to the file
  out_path = "%s_%s.csv" % (sessionTitle, extra_title)
  saveDF.sort_index().to_csv(out_path, index = False)

In [None]:
def plotConfusionMatrix(cmArray, labels, title, savePath = ""):
  """Draw a confusion matrix as an annotated heatmap.

  Args:
    cmArray: square confusion matrix.
    labels: tick labels used for both rows and columns.
    title: figure title.
    savePath: when not empty, the figure is saved as "<savePath>_results.png".
  """
  matrix_frame = pd.DataFrame(cmArray, index = labels, columns = labels)

  plt.figure(figsize = (5,4))
  plt.title(title)
  sn.heatmap(matrix_frame, annot=True, cmap="YlGnBu", fmt='g')

  #Nothing to save when no path was given
  if savePath == "":
    return

  plt.savefig("%s_results.png" % savePath, bbox_inches = "tight")


#Tuning Functions

In [None]:
import optuna
def objective(trial):
  """Optuna objective: train a student with sampled hyperparameters and score it.

  Args:
    trial: optuna trial used to sample hyperparameters, report the validation
      F1 after every epoch, and decide on pruning.

  Returns:
    The weighted validation F1 (mean over validation batches) of the last
    epoch that was run.

  Raises:
    Exception: if student_checkpoint is not "Transformer".
    optuna.exceptions.TrialPruned: when the pruner stops the trial.
  """

  tune_dropout = trial.suggest_categorical("dropout", [0.01, 0.05, 0.1, 0.5])

  if student_checkpoint == "Transformer":
    tune_lr = trial.suggest_categorical("learning_rate", [1e-04, 1e-05, 2e-05, 3e-05, 4e-05, 5e-05])
    tune_decay = trial.suggest_categorical("weight_decay", [0, 0.01, 0.1])
    #step is passed by keyword; passing it positionally is deprecated in Optuna
    tune_layers = trial.suggest_int("num_hidden_layers", 2, 12, step = 2)
    tune_heads = trial.suggest_categorical("num_attention_heads", [ 2,  3,  4,  6,  8, 12]) #every option divides the hidden size 768
    tune_act = trial.suggest_categorical("hidden_act", ["relu", "gelu"])

    tune_epochs = trial.suggest_int("num_epochs", 3, 5)
    tune_stop = None   #Early-stopping patience; None disables early stopping

    args = {
            "num_hidden_layers": tune_layers,
            "num_attention_heads": tune_heads,
            "hidden_dropout_prob": tune_dropout,
            "hidden_act": tune_act
            }
    #get_student_model already resizes the embeddings for the added emoji and
    #reserved tokens, so no second resize is needed here
    tune_model = get_student_model(student_checkpoint, args).to(device)
    optimizer = optim.Adam(tune_model.parameters(), lr = tune_lr, weight_decay = tune_decay)

  else:
    raise Exception("Unknown student checkpoint.")

  #Training
  train_loss = []
  val_loss = []

  for epoch in range(tune_epochs):
    tLoss, vLoss, tAcc, vAcc, tF1, f1_val = train_student(tune_model, optimizer, epoch, show_result = False)
    train_loss.append(tLoss)
    val_loss.append(vLoss)

    #Stop when the best of the last tune_stop validation losses is worse than the best before them
    if tune_stop is not None and epoch > tune_stop and np.min(val_loss[-tune_stop:]) > np.min(val_loss[:-tune_stop]):
      break

    #Record metric
    trial.report(f1_val, epoch)

    # Handle pruning based on the intermediate value.
    if trial.should_prune():
        raise optuna.exceptions.TrialPruned()

  return f1_val


In [None]:
def tune_parameters(n_trials = 20):
  """Run an Optuna study that maximizes the value returned by `objective`.

  Args:
    n_trials: number of trials to run.

  Returns:
    The finished optuna study.

  Raises:
    Exception: if student_checkpoint is not "Transformer".
  """
  study = optuna.create_study(direction = "maximize")

  if student_checkpoint != "Transformer":
    raise Exception("Unknown student checkpoint.")

  #Seed the search with one hand-picked configuration that is evaluated first
  baseline_trial = {"dropout": 0.5,
                    "learning_rate": 1e-04,
                    "num_hidden_layers": 12,
                    "num_attention_heads": 12,
                    "weight_decay": 0,
                    "num_epochs": 3,
                    "hidden_act": "relu"}
  study.enqueue_trial(baseline_trial)

  study.optimize(objective, n_trials = n_trials)
  return study

#Knowledge Distillation

##Initialize Teachers

In [None]:
from transformers import AutoModel

#Load teacher models
#The training loop reads tm(...)["logits"] from every teacher, so each teacher
#must be loaded with its classification head.
teacher_models = {}
print("Teacher models:")
for check, path in zip(teacher_checkpoints, teacher_model_paths):
  if path.split(".")[-1] == "pt": #standard model
    #torch.load unpickles arbitrary objects: only load teacher files you trust.
    #map_location lets a model saved on a GPU load on a CPU-only machine.
    tm = torch.load(path, map_location = device)
  elif check == "MIT/ast-finetuned-audioset-10-10-0.4593": #pretrained audio model
    tm = ASTForAudioClassification.from_pretrained(path)
  else: #pretrained text model
    #AutoModel loads the bare encoder, whose output has no "logits" entry
    tm = AutoModelForSequenceClassification.from_pretrained(path)
  print("->", check)
  tm.to(device)
  tm.eval()
  teacher_models[check] = tm

#Training Setup

In [None]:
#Folder that receives the aggregated results CSV written at the end of the notebook
os.makedirs("_OUTPUT", exist_ok = True)

In [None]:
BATCH_SIZE = 32
MAX_LENGTH = 256 #Adjusts programmatically if shorter

#Read by build_inputs for the AST feature extractor but previously never defined.
#16000 is the sample rate the data-loading notes give as required by AST.
SAMPLING_RATE = 16000

#Fallback epoch count for the final training loop when the tuned parameters carry
#no "num_epochs"; previously never defined. 3 matches the enqueued baseline trial.
NUM_EPOCHS = 3

NUM_RUNS = 10
show_result = True

#Hyperparameter Selection

In [None]:
#Load tuning data
#CV = True forces a random 90/10 train/val split of the supplied indices.
#Only the train + val indices are supplied: splitting over all samples would let
#test samples influence hyperparameter selection.
train_loader, val_loader = prepare_datasets(all_sentences, all_targets, all_audio, cv = True, train_idx = np.arange(train_size + val_size))

In [None]:
#TUNE
#Runs the Optuna study and reports how long the whole search took
start = datetime.now()
print("Tuning...", flush = True)
study = tune_parameters(n_trials = 50)
print("Total tuning time: %s\n" % (datetime.now() - start), flush = True)

#best_params is read by the final training loop below
best_trial = study.best_trial
best_params = best_trial.params

print("BEST:", best_trial.value)
print("Params:")
for key, value in best_params.items():
  print("    {}: {}".format(key, value))

#Run Training


##Train & Evaluate Student

In [None]:
#Loaders
#Rebuilds the loaders for the final runs, replacing the tuning loaders.
#prepare_datasets returns three loaders only when the test split is not empty,
#so this unpacking requires test data.
train_loader, val_loader, test_loader = prepare_datasets(all_sentences, all_targets, all_audio)

In [None]:
#Start below any reachable F1 so the first run is always recorded as best
#(with a 0 start and a strict ">" a run scoring exactly 0 never stored its outputs)
best_run = {"F1_Weighted": float("-inf")}
all_runs = {}   #metric name -> list with one entry per run
for run in range(NUM_RUNS):
  print("=" * 50)
  print("RUN %d" % (run + 1))
  print("=" * 50)
  start = datetime.now()

  if student_checkpoint == "Transformer":
    args = {
            "num_hidden_layers": best_params["num_hidden_layers"],
            "num_attention_heads": best_params["num_attention_heads"],
            "hidden_dropout_prob": best_params["dropout"],
            "hidden_act": best_params["hidden_act"]
            }
    student_model = get_student_model(student_checkpoint, params = args).to(device)
  else:
    #Fail here instead of with a NameError on student_model further down
    raise Exception("Unknown student checkpoint.")

  optimizer = optim.Adam(student_model.parameters(), lr=best_params["learning_rate"], weight_decay=best_params["weight_decay"] if "weight_decay" in best_params else 0)
  print(student_model.config)

  #Training
  train_loss = []
  val_loss = []

  for epoch in range(best_params["num_epochs"] if "num_epochs" in best_params else NUM_EPOCHS):
    tLoss, vLoss, tAcc, vAcc, tF1, vF1  = train_student(student_model, optimizer, epoch)
    train_loss.append(tLoss)
    val_loss.append(vLoss)

  #Testing - student only
  student_model.eval()   #Explicit, so dropout is off even if no epoch ran
  st_outputs = []
  st_targets = []
  st_ids = []
  with torch.no_grad():
    for batch in test_loader:
      input_args = {k.split("student|")[-1]: v.to(device) for k, v in batch.items() if k.startswith("student")}
      st_output = student_model(**input_args)["logits"]
      st_outputs.extend(st_output.cpu().numpy())
      st_targets.extend(batch["labels"].cpu().numpy())
      st_ids.extend(batch["ids"])

  st_ids = torch.stack(st_ids).numpy()
  st_outputs = np.array(st_outputs)

  #Evaluate
  run_results = evaluate_output(st_outputs, st_targets, lEnc.classes_, show_results = False, return_results = True)
  for k, v in run_results.items():
    all_runs.setdefault(k, []).append(v)

  #Report the 1-based run number, matching the header printed above
  print("-> RUN %s total time: %s" % (run + 1, datetime.now() - start))
  for k, v in run_results.items():
    print("-> %s: %s" % (k, v))

  #Save if best run
  if run_results["F1_Weighted"] > best_run["F1_Weighted"]:
    #Copy so the per-run metrics dict is not extended with targets/outputs
    best_run = dict(run_results)
    best_run["targets"] = st_targets
    best_run["outputs"] = st_outputs
    save_predictions(st_ids, st_targets, st_outputs, extra_title = "_bestOf%d" % NUM_RUNS)

In [None]:
# Evaluate Best Run
#Title: the session name followed by one "->" line per teacher checkpoint
title = sessionTitle + "_best"
for check in teacher_checkpoints:
  title += "\n->" + check
#best_run["outputs"] and best_run["targets"] are stored by the training loop
#for the run with the highest weighted F1
evaluate_output(best_run["outputs"], best_run["targets"], lEnc.classes_.astype(str), title, showClassMatrix = True)

In [None]:
#Evaluate average runs
#Scalar metrics are reported as "mean (± std)" text lines; metrics whose name starts
#with "Class" hold one value per class and go into a table with one row per class.
out_text = ""
tab_data = {"Classes": lEnc.classes_}
for k, v in all_runs.items():
  if k.startswith("Class"):
    #axis = 0 averages over runs and keeps the per-class dimension
    tab_data[k] = np.mean(v, axis = 0)
    tab_data[k + "(±)"] = np.std(v, axis = 0)
  else:
    out_text += "%s: %.4f (± %.4f)\n" % (k, np.mean(v), np.std(v))

print(out_text)
print(tabulate(tab_data, headers = tab_data.keys(), floatfmt=".4f"))

In [None]:
#Save all results
#One row per run, one column per metric, written into the _OUTPUT folder
pd.DataFrame(all_runs).to_csv("_OUTPUT/%s_results.csv" % (sessionTitle))