In [1]:
from utils import CNNBackbone, MultiTaskDataset, ASTBackbone, get_device, SpectrogramDataset, CLASS_MAPPING, torch_train_val_split, Classifier, train, set_seed, plot_train_val_losses, test_model, get_regression_report, create_folder
import torch.optim as optim
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

DATA_PATH = "/home/alex/Downloads/archive(1)/data/"
EPOCHS = 100
LR = 1e-4
BATCH_SIZE = 32
RANDOM_SEED = 42
RNN_HIDDEN_SIZE = 64
NUM_CATEGORIES = 1
cnn_in_channels = 1
cnn_filters = [32, 64, 128, 256]
cnn_out_feature_size = 256
DEVICE = get_device()


create_folder("model_weights"), create_folder("assets")

(None, None)

In [2]:
valence_data = SpectrogramDataset(DATA_PATH + "multitask_dataset/", class_mapping=CLASS_MAPPING, train=True, regression=1)
energy_data = SpectrogramDataset(DATA_PATH + "multitask_dataset/", class_mapping=CLASS_MAPPING, train=True, regression=2)
dancability_data = SpectrogramDataset(DATA_PATH + "multitask_dataset/", class_mapping=CLASS_MAPPING, train=True, regression=3)

In [3]:
multi_task_labels = []
for valence_y, energy_y, dancability_y in zip(valence_data.labels, energy_data.labels, dancability_data.labels):
    multi_task_labels.append((valence_y, energy_y, dancability_y))
np.array(multi_task_labels)

array([[0.578, 0.973, 0.873],
       [0.839, 0.782, 0.655],
       [0.587, 0.956, 0.204],
       ...,
       [0.337, 0.592, 0.316],
       [0.536, 0.404, 0.366],
       [0.477, 0.949, 0.431]])

In [4]:
multi_task_dataset = MultiTaskDataset(features=valence_data.feats, labels=np.array(multi_task_labels))
dataloader = DataLoader(multi_task_dataset, batch_size=2, shuffle=True)

In [5]:
import torch.nn as nn

class MultiTaskClassifier(nn.Module):
    def __init__(self, num_tasks, backbone, task_feature_sizes):
        """
        num_tasks (int): The number of tasks (e.g., 3 metrics)
        backbone (nn.Module): The shared backbone (CNNBackbone or LSTMBackbone)
        task_feature_sizes (list of int): Output sizes for each task
        """
        super(MultiTaskClassifier, self).__init__()
        self.backbone = backbone  # Shared backbone
        
        # Separate output layers for each task
        self.output_layers = nn.ModuleList([
            nn.Linear(self.backbone.feature_size, task_feature_sizes[i]) for i in range(num_tasks)
        ])
        
        # Criterion for each task
        self.criterions = [nn.MSELoss() for _ in range(num_tasks)]  # Regression losses

    def forward(self, x, targets):
        """
        x: Input features
        targets: List of target tensors for each task
        lengths: Sequence lengths (for LSTM inputs)
        """
        # Shared feature extraction
        feats = self.backbone(x)
        
        # Task-specific outputs each element holds the predictions for the corresponding head
        logits = [output_layer(feats) for output_layer in self.output_layers]
        logits = [logit.squeeze(-1) for logit in logits]
        # Compute losses for each task
        losses = [criterion(logits[i], targets[:, i]) for i, criterion in enumerate(self.criterions)]
        
        # Weighted sum of losses (equal weight for simplicity; can be tuned)
        total_loss = sum(losses)
        
        return total_loss, losses, logits


In [10]:
set_seed(RANDOM_SEED)
backbone = ASTBackbone(
    fstride=10,                     
    tstride=10,                   
    input_fdim=dancability_data[0][0].shape[1],      
    input_tdim=dancability_data[0][0].shape[0],     
    imagenet_pretrain=False,      
    model_size='small224',          
    feature_size=1    
)

set_seed(RANDOM_SEED)
model = MultiTaskClassifier(num_tasks=3, backbone=backbone, task_feature_sizes=[1, 1, 1])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7)  # Every 10 epochs, reduce LR by factor of 0.7
inputs, targets, lengths = next(iter(dataloader))
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
model.to(DEVICE)
for epoch in range(80):
    model.train()
    optimizer.zero_grad()
    loss, losses, logits = model(inputs.float(), targets.float())
    loss.backward()
    # clip gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    if epoch == 0 or (epoch+1)%5 == 0:
        loss1, loss2, loss3 = losses
        print(f'Epoch {epoch+1}\n\tTotal Loss at training set: {loss.item()}\n\t{loss1.item()}, {loss2.item()}, {loss3.item()}')


Epoch 1
	Total Loss at training set: 0.4920566976070404
	0.271818071603775, 0.1185767725110054, 0.10166185349225998
Epoch 5
	Total Loss at training set: 8.623558044433594
	7.062807083129883, 1.3222465515136719, 0.23850411176681519
Epoch 10
	Total Loss at training set: 0.47748029232025146
	0.30532971024513245, 0.09234171360731125, 0.07980887591838837
Epoch 15
	Total Loss at training set: 0.21627843379974365
	0.06962907314300537, 0.12688522040843964, 0.019764143973588943
Epoch 20
	Total Loss at training set: 0.1924845427274704
	0.07543405890464783, 0.11010538041591644, 0.006945108529180288
Epoch 25
	Total Loss at training set: 0.17525197565555573
	0.07651092112064362, 0.08794593811035156, 0.0107951108366251
Epoch 30
	Total Loss at training set: 0.1587393581867218
	0.0710899606347084, 0.08248107880353928, 0.00516832061111927
Epoch 35
	Total Loss at training set: 0.1441487818956375
	0.06292903423309326, 0.08041144907474518, 0.0008082911954261363
Epoch 40
	Total Loss at training set: 0.1402

In [10]:
# Example training loop
set_seed(RANDOM_SEED)
backbone = CNNBackbone(valence_data[0][0].shape, cnn_in_channels, cnn_filters, cnn_out_feature_size)
set_seed(RANDOM_SEED)
model = MultiTaskClassifier(num_tasks=3, backbone=backbone, task_feature_sizes=[1, 1, 1])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()
backbone.train()
model.to(DEVICE)
for epoch in range(EPOCHS):
    running_loss = 0.
    for inputs, targets, _ in dataloader:
        optimizer.zero_grad()
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        loss, losses, logits = model(inputs.float(), targets.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        print(f'Epoch {epoch+1}\n\tTotal Loss at training set: {running_loss / len(dataloader)}')


Epoch 1
	Total Loss at training set: 0.004171649085150824
Epoch 1
	Total Loss at training set: 4314.486393871307
Epoch 1
	Total Loss at training set: 4455.163112621307
Epoch 1
	Total Loss at training set: 4806.403425121307
Epoch 1
	Total Loss at training set: 5998.0998140101965
Epoch 1
	Total Loss at training set: 7151.465091787974
Epoch 1
	Total Loss at training set: 8052.096897343529
Epoch 1
	Total Loss at training set: 8486.65175845464
Epoch 1
	Total Loss at training set: 8594.07314734353
Epoch 1
	Total Loss at training set: 8600.06731780794
Epoch 1
	Total Loss at training set: 8671.255238814883
Epoch 1
	Total Loss at training set: 8930.74989159266
Epoch 1
	Total Loss at training set: 9261.431697148217
Epoch 1
	Total Loss at training set: 9495.900377703772
Epoch 1
	Total Loss at training set: 9738.337391592662
Epoch 1
	Total Loss at training set: 9834.048398537107
Epoch 1
	Total Loss at training set: 9862.144431523217
Epoch 1
	Total Loss at training set: 9883.29841155794
Epoch 1
	To

KeyboardInterrupt: 