## Imports

In [2]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torchaudio import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

from data import Sep28K
from models import LSTMModel

def set_seed(seed):
    """Seed every global RNG the notebook relies on with the same `seed`.

    Covers Python's `random`, NumPy's legacy global generator, PyTorch's CPU generator and the
    generators of all CUDA devices, in that order.
    """
    import random
    import numpy as np
    import torch

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

set_seed(42)

TASK_TYPE = 'mtl'
# SummaryWriter ignores `comment` whenever `log_dir` is given, so put the task name in the path itself.
_writer = SummaryWriter(log_dir=f'./logs/{TASK_TYPE}')

# Pick a real torch.device. `torch.mps` is a module, not a device, so `.to(torch.mps)` would fail.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

In [3]:
import torch
import torch.nn as nn

class CCCLoss(nn.Module):
    """Concordance-correlation-coefficient loss, ``1 - mean(CCC)``.

    CCC(a, b) = 2*cov(a, b) / (var(a) + var(b) + (mean(a) - mean(b))**2).

    - CCC is computed separately for every slice along ``dim``, using population statistics
      (``unbiased=False``).
    - The per-slice values are averaged.
    - CCC is symmetric, so swapping the two arguments does not change the result.

    Args:
        eps: Added to the denominator. Constant inputs (zero variance and equal means) then give
            a finite loss instead of NaN.
        dim: Dimension each CCC is computed along (default 1). Two 1-D inputs are scored as a
            single series along dim 0.
    """

    def __init__(self, eps: float = 1e-8, dim: int = 1):
        super(CCCLoss, self).__init__()
        self.eps = eps
        self.dim = dim

    def forward(self, y_true, y_pred):
        """Return the scalar loss between targets ``y_true`` and predictions ``y_pred``."""
        dim = self.dim
        if y_true.dim() == 1 and y_pred.dim() == 1:
            # A 1-D tensor has no dim 1; treat it as one series.
            dim = 0

        y_true_mean = torch.mean(y_true, dim=dim, keepdim=True)
        y_pred_mean = torch.mean(y_pred, dim=dim, keepdim=True)

        y_true_var = torch.var(y_true, dim=dim, unbiased=False)
        y_pred_var = torch.var(y_pred, dim=dim, unbiased=False)

        covariance = torch.mean((y_true - y_true_mean) * (y_pred - y_pred_mean), dim=dim)

        # Squeeze only the reduced dim, so this term keeps the same shape as the
        # variance/covariance terms even when other dims have size 1.
        mean_gap_sq = (y_true_mean - y_pred_mean).squeeze(dim) ** 2

        ccc = (2 * covariance) / (y_true_var + y_pred_var + mean_gap_sq + self.eps)
        return 1 - ccc.mean()

class AverageMeter(object):
    """Running ``total / count`` tracker for a scalar, optionally logged to TensorBoard.

    ``update(val, n)`` adds ``val`` to the total and ``n`` to the count. The total is NOT
    multiplied by ``n``. Use it in one of two ways:

    - Pass a per-batch *sum* with ``n=batch_size`` (e.g. the number of correct predictions).
    - Pass a per-batch *mean* with ``n=1``, giving the mean of batch means.

    Args:
        writer: ``SummaryWriter`` used by :meth:`write`. If ``None``, :meth:`write` is a no-op.
        name: Tag prefix; values are logged under ``<name>/val``.
    """

    def __init__(self, writer: "SummaryWriter" = None, name=None):
        self.reset()
        self._writer = writer
        self._name = name

    def reset(self):
        """Clear the last value, running total, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Add ``val`` to the running total and ``n`` to the count, then refresh ``avg``."""
        self.val = val
        self.sum += val
        self.count += n
        # Guard against n=0 on an empty meter instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

    def write(self, epoch):
        """Log the current average under ``<name>/val`` at step ``epoch``; no-op without a writer."""
        if self._writer is None:
            return
        self._writer.add_scalar(self._name + '/val', self.avg, epoch)

# configurations
from yacs.config import CfgNode as CN
# Run-level hyper-parameters, gathered in a single yacs node.
# NOTE(review): the DataLoader and optimizer cells below hardcode batch_size=256/32 and lr=0.01
# instead of reading cfg.batch_size / cfg.lr — confirm which values are intended.
cfg = CN({
    'seed': 42,
    'batch_size': 32,
    'num_workers': 4,
    'lr': 1e-3,
    'epochs': 100,
})


## Prepare Data

In [4]:
from sklearn.model_selection import train_test_split

# SEP-28k audio clips and their label table.
data_path = '../datasets/sep28k/clips'
label_path = '../datasets/sep28k/SEP-28k_labels_new.csv'

# 40-band Mel spectrogram with a 400-sample window and a 160-sample hop.
# NOTE(review): that is 25 ms / 10 ms only if the clips are 16 kHz (torchaudio's default) — confirm.
trans = transforms.MelSpectrogram(
    win_length=400,
    hop_length=160,
    n_mels=40,
)
dataset = Sep28K(
    root=data_path,
    label_path=label_path,
    transforms=trans,
)

In [5]:
from torch.utils.data import Subset

# Split *indices*, not the Dataset itself. Given a Dataset, train_test_split indexes it element
# by element into Python lists, which loads and transforms every clip up front. Splitting
# range(len(dataset)) with the same random_state gives the same partitions, and Subset keeps
# loading lazy. Holdout = 10% of all clips, divided 70/30 into validation/test.
train_idx, holdout_idx = train_test_split(list(range(len(dataset))), test_size=0.1, random_state=42)
val_idx, test_idx = train_test_split(holdout_idx, test_size=0.3, random_state=42)

train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(dataset, val_idx)
test_dataset = Subset(dataset, test_idx)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)



In [6]:
# Report how many samples landed in each split.
for caption, subset in (
    ("Number of samples in train dataset:", train_dataset),
    ("Number of samples in validation dataset:", val_dataset),
    ("Number of samples in test dataset:", test_dataset),
):
    print(caption, len(subset))

Number of samples in train dataset: 19670
Number of samples in validation dataset: 1530
Number of samples in test dataset: 656


## Train in the multi-task learning (MTL) setting

In [7]:
TASK_TYPE = 'mtl'

# input_size=40 matches the n_mels=40 Mel spectrogram configured in the data cell.
model = LSTMModel(input_size=40, hidden_size=64, num_layers=1, output_size=6)
optimizer = Adam(model.parameters(), lr=0.01)

# Per-task criteria: cross-entropy for task 1, concordance (CCC) loss for task 2.
loss_t1, loss_t2 = nn.CrossEntropyLoss(), CCCLoss()

# Epoch-level loss trackers, logged through the notebook-wide TensorBoard writer.
t1_loss_train, t2_loss_train, t1_loss_val, t2_loss_val = (
    AverageMeter(name=meter_name, writer=_writer)
    for meter_name in ('t1_loss_train', 't2_loss_train', 't1_loss_val', 't2_loss_val')
)

In [None]:
def _task_losses(model, batch, criterion, task_type, device):
    """Forward one ``(x, y_t1, y_t2, y)`` batch and return ``(loss1, loss2)``.

    Which tasks are scored:

    - Task 1 for ``'mtl'`` and ``'t1'``.
    - Task 2 for every ``task_type`` except ``'t1'``.

    An inactive task's loss is ``None``.
    """
    loss_t1, loss_t2 = criterion
    x, y_t1, y_t2, _ = batch
    x, y_t1, y_t2 = x.to(device), y_t1.to(device), y_t2.to(device)
    # squeeze(1) matches how the notebook's training cells feed the model.
    # NOTE(review): assumes the loader yields x with a singleton dim at index 1 — confirm.
    pred_t1, pred_t2 = model(x.squeeze(1), task_type=task_type)
    # Loss modules take (prediction, target); nn.CrossEntropyLoss in particular is not symmetric.
    loss1 = loss_t1(pred_t1, y_t1) if task_type in ('mtl', 't1') else None
    loss2 = loss_t2(pred_t2, y_t2) if task_type != 't1' else None
    return loss1, loss2


def train(model, optimiser, train_loader, val_loader, criterion, task_type, device, writer=None, epochs=None):
    """Train `model` for `epochs` epochs, validating after each one.

    Args:
        model: Called as ``model(x, task_type=...)``; returns ``(pred_t1, pred_t2)``.
        optimiser: Optimiser over ``model``'s parameters.
        train_loader: Iterable of ``(x, y_t1, y_t2, y)`` training batches.
        val_loader: Iterable of ``(x, y_t1, y_t2, y)`` validation batches.
        criterion: ``(loss_t1, loss_t2)`` pair of loss callables taking ``(prediction, target)``.
        task_type: ``'mtl'`` (both tasks), ``'t1'`` (task 1 only) or anything else (task 2 only).
        device: Device the batch tensors are moved to.
        writer: Optional ``SummaryWriter``; when ``None`` nothing is logged to TensorBoard.
        epochs: Number of epochs; defaults to ``cfg.epochs``.
    """
    if epochs is None:
        epochs = cfg.epochs
    use_t1 = task_type in ('mtl', 't1')
    use_t2 = task_type != 't1'

    t1_loss_train = AverageMeter(name='t1_loss_train', writer=writer)
    t2_loss_train = AverageMeter(name='t2_loss_train', writer=writer)
    t1_loss_val = AverageMeter(name='t1_loss_val', writer=writer)
    t2_loss_val = AverageMeter(name='t2_loss_val', writer=writer)
    # Only meters of active tasks are reset and logged, so inactive tasks produce no flat-zero curves.
    train_meters = [m for m, on in ((t1_loss_train, use_t1), (t2_loss_train, use_t2)) if on]
    val_meters = [m for m, on in ((t1_loss_val, use_t1), (t2_loss_val, use_t2)) if on]

    for epoch in range(epochs):
        model.train()
        for meter in train_meters:
            meter.reset()
        for batch in train_loader:
            optimiser.zero_grad()
            loss1, loss2 = _task_losses(model, batch, criterion, task_type, device)
            loss = sum(part for part in (loss1, loss2) if part is not None)
            loss.backward()
            optimiser.step()
            # n=1: the meters accumulate per-batch mean losses.
            if loss1 is not None:
                t1_loss_train.update(loss1.item())
            if loss2 is not None:
                t2_loss_train.update(loss2.item())

        model.eval()
        for meter in val_meters:
            meter.reset()
        with torch.no_grad():
            for batch in val_loader:
                loss1, loss2 = _task_losses(model, batch, criterion, task_type, device)
                if loss1 is not None:
                    t1_loss_val.update(loss1.item())
                if loss2 is not None:
                    t2_loss_val.update(loss2.item())

        if writer is not None:
            for meter in train_meters + val_meters:
                meter.write(epoch)

        print(f"Epoch: {epoch}, T1 Loss: {t1_loss_train.avg}, T2 Loss: {t2_loss_train.avg}, "
              f"T1 Val Loss: {t1_loss_val.avg}, T2 Val Loss: {t2_loss_val.avg}")


In [8]:
#  Training MTL: joint training of both tasks, early-stopping on the validation split and
#  keeping the best weights. The test split is reserved for the final evaluation below.
import copy

num_epochs = 100
best_loss = float('inf')
patience = 10  # epochs without validation improvement before stopping
epochs_without_improvement = 0
best_state = copy.deepcopy(model.state_dict())
for epoch in range(num_epochs):
    # model.eval() below persists across iterations, so return to training mode every epoch.
    model.train()
    t1_loss_train.reset()
    t2_loss_train.reset()
    for batch in train_loader:
        optimizer.zero_grad()
        X, y_t1, y_t2, y = batch
        t1_pred, t2_pred = model(X.squeeze(1), task_type=TASK_TYPE)
        loss_task_1 = loss_t1(t1_pred, y_t1)
        loss_task_2 = loss_t2(t2_pred, y_t2)
        # The meters divide a running total by a count: with per-batch mean losses, use n=1.
        t1_loss_train.update(loss_task_1.item())
        t2_loss_train.update(loss_task_2.item())
        loss = loss_task_1 + loss_task_2
        loss.backward()
        optimizer.step()
    t1_loss_train.write(epoch)
    t2_loss_train.write(epoch)

    # validate on the validation split (never on test data)
    model.eval()
    t1_loss_val.reset()
    t2_loss_val.reset()
    with torch.no_grad():
        for batch in val_loader:
            X, y_t1, y_t2, y = batch
            pred_t1, pred_t2 = model(X.squeeze(1), task_type=TASK_TYPE)
            loss_task_1 = loss_t1(pred_t1, y_t1)
            loss_task_2 = loss_t2(pred_t2, y_t2)
            t1_loss_val.update(loss_task_1.item())
            t2_loss_val.update(loss_task_2.item())
    t1_loss_val.write(epoch)
    t2_loss_val.write(epoch)
    print(f"Epoch: {epoch}, T1 Loss: {t1_loss_train.avg}, T2 Loss: {t2_loss_train.avg}, "
          f"T1 Val Loss: {t1_loss_val.avg}, T2 Val Loss: {t2_loss_val.avg}")

    # early stopping on the summed validation loss
    val_loss = t1_loss_val.avg + t2_loss_val.avg
    if val_loss < best_loss:
        best_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        break

# Evaluate the best-validation checkpoint, not the weights of the last (worse) epoch.
model.load_state_dict(best_state)


In [None]:
from sklearn.metrics import accuracy_score, f1_score
def test(model, test_loader, task_type='mtl'):
    model.eval()
    accuracy_t1 = AverageMeter()
    accuracy_t2 = AverageMeter()
    f1 = AverageMeter()
    with torch.no_grad():
        for batch in test_loader:
            X, y_t1, y_t2, y = batch
            pred_t1, pred_t2 = model(X.squeeze(1), task_type=task_type)
            
            if task_type == 'mtl' or task_type == 't1':
                pred_t1 = torch.argmax(pred_t1, dim=1)
                acc = (pred_t1 == y_t1).sum().item()
                accuracy_t1.update(acc, X.size(0))

            pred_t2 = torch.argmax(pred_t2, dim=1)
            acc_fluency = (pred_t2 == y).sum().item()
            accuracy_t2.update(acc_fluency, X.size(0))

            f1.update(f1_score(y, pred_t2, average='weighted'), X.size(0))

    return accuracy_t1.avg, accuracy_t2.avg, f1.avg

# Keep the MTL test metrics in a named variable so they can be compared with the STL run later.
mtl_results = test(model, test_loader, task_type=TASK_TYPE)
print(mtl_results)

## Train a single-task learning (STL) model on task 2

In [None]:
TASK_TYPE = 't2'

# Re-seed so the STL model's initialisation does not depend on the RNG state left by earlier cells.
set_seed(cfg.seed)

model = LSTMModel(input_size=40, hidden_size=64, num_layers=1, output_size=6)
optimizer = Adam(model.parameters(), lr=0.01)

loss_t1 = nn.CrossEntropyLoss()
loss_t2 = CCCLoss()
# Give the STL run its own log directory. With the shared MTL writer, its identically-named
# 't2_loss_*' tags would interleave with the MTL curves in TensorBoard.
_stl_writer = SummaryWriter(log_dir=f'./logs/{TASK_TYPE}')
t2_loss_train = AverageMeter(name='t2_loss_train', writer=_stl_writer)
t2_loss_val = AverageMeter(name='t2_loss_val', writer=_stl_writer)

In [None]:
#  Training t2: single-task training on task 2, early-stopping on the validation split and
#  keeping the best weights.
import copy

num_epochs = 100
best_loss = float('inf')
patience = 10  # epochs without validation improvement before stopping
epochs_without_improvement = 0
best_state = copy.deepcopy(model.state_dict())
for epoch in range(num_epochs):
    # model.eval() below persists across iterations, so return to training mode every epoch.
    model.train()
    t2_loss_train.reset()
    for batch in train_loader:
        optimizer.zero_grad()
        X, y_t1, y_t2, y = batch
        t1_pred, t2_pred = model(X.squeeze(1), task_type=TASK_TYPE)
        loss_task_2 = loss_t2(t2_pred, y_t2)
        # The meter divides a running total by a count: with per-batch mean losses, use n=1.
        t2_loss_train.update(loss_task_2.item())
        # Back-propagate the loss computed for *this* batch.
        loss_task_2.backward()
        optimizer.step()
    t2_loss_train.write(epoch)

    # validate
    model.eval()
    t2_loss_val.reset()
    with torch.no_grad():
        for batch in val_loader:
            X, y_t1, y_t2, y = batch
            pred_t1, pred_t2 = model(X.squeeze(1), task_type=TASK_TYPE)
            loss_task_2 = loss_t2(pred_t2, y_t2)
            t2_loss_val.update(loss_task_2.item())
    t2_loss_val.write(epoch)
    print(f"Epoch: {epoch}, T2 Loss: {t2_loss_train.avg}, T2 Val Loss: {t2_loss_val.avg}")

    # early stopping on the task-2 validation loss
    if t2_loss_val.avg < best_loss:
        best_loss = t2_loss_val.avg
        best_state = copy.deepcopy(model.state_dict())
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        break

# Evaluate the best-validation checkpoint, not the weights of the last (worse) epoch.
model.load_state_dict(best_state)


In [None]:
# Keep the STL test metrics in a named variable for side-by-side comparison with the MTL run.
stl_results = test(model, test_loader, task_type=TASK_TYPE)
print(stl_results)

## Train ConvLSTM with MTL