In [3]:
from utils import SpectrogramDataset, torch_train_val_split, ASTBackbone, get_regression_report,\
    CLASS_MAPPING, Classifier, Regressor, train, get_device, free_gpu_memory, test_model, plot_train_val_losses
import torch
import os

# --- Paths -----------------------------------------------------------------
# Everything is resolved relative to the directory the kernel was started in.
DATA_PATH = os.path.join(os.getcwd(), "data/")
model_weights_path = os.path.join(os.getcwd(), "model_weights/")
assets_path = os.path.join(os.getcwd(), "assets/")

# Dataset folders living under DATA_PATH.
mel_specs_path = os.path.join(DATA_PATH, "fma_genre_spectrograms/")
beat_mel_specs_path = os.path.join(DATA_PATH, "fma_genre_spectrograms_beat/")
multitask_path = os.path.join(DATA_PATH, "multitask_dataset/")

# Make sure the output folders exist before anything tries to write to them.
os.makedirs(model_weights_path, exist_ok=True)
os.makedirs(assets_path, exist_ok=True)

# --- Hyper-parameters ------------------------------------------------------
EPOCHS = 100
LR = 1e-4
BATCH_SIZE = 8
VAL_SIZE = .2
RANDOM_SEED = 42
NUM_CATEGORIES = 10
AST_MODEL_SIZES = ['tiny224', 'small224', 'base224', 'base384']

# Seed PyTorch's RNGs so weight initialisation and torch-driven shuffling are
# repeatable across "Restart & Run All".
# NOTE(review): helpers in `utils` that draw from other RNGs (e.g. NumPy or
# `random`) are not covered by this call -- confirm against their source.
torch.manual_seed(RANDOM_SEED)

# --- Device ----------------------------------------------------------------
if torch.cuda.is_available():
    print(f"Detected GPU: {torch.cuda.get_device_name(0)}")
    print(f"Free GPU Memory: {free_gpu_memory():.2f}%")
else:
    print("CUDA is not available. No compatible GPU detected.")
DEVICE = get_device()

Detected GPU: NVIDIA GeForce RTX 4060 Laptop GPU
Free GPU Memory: 96.95%


In [4]:
import torch
import torch.optim as optim

# Build the genre-classification dataset and split it into train/val loaders.
mel_specs_data = SpectrogramDataset(mel_specs_path, class_mapping=CLASS_MAPPING, train=True)
mel_specs_train_dl, mel_specs_val_dl = torch_train_val_split(
    dataset=mel_specs_data,
    batch_eval=BATCH_SIZE,
    batch_train=BATCH_SIZE,
    val_size=VAL_SIZE,
    shuffle=True,
)

# Peek at one training batch (it unpacks into three items; only the first is
# needed) to read the shape of a single sample.
x_b1, _, _ = next(iter(mel_specs_train_dl))
input_shape = x_b1[0].shape

# init AST Backbone and classifier
# Axis 0 of a sample is passed as the time dimension and axis 1 as the
# frequency dimension.
model_size = AST_MODEL_SIZES[0]
backbone = ASTBackbone(
    fstride=10,
    tstride=10,
    input_fdim=input_shape[1],
    input_tdim=input_shape[0],
    imagenet_pretrain=True,
    model_size=model_size,
    feature_size=NUM_CATEGORIES,
)
model = Classifier(NUM_CATEGORIES, backbone).to(DEVICE)
genre_optimizer = optim.Adam(model.parameters(), lr=LR)

# Train the model on Spectrograms
train_losses, val_losses = train(
    model,
    mel_specs_train_dl,
    mel_specs_val_dl,
    genre_optimizer,
    EPOCHS,
    device=DEVICE,
)

# Save pretrained weights inside the dedicated weights folder (created in the
# config cell) instead of littering the working directory. Later cells locate
# the checkpoint through `pretrained_weights`.
pretrained_weights = os.path.join(
    model_weights_path, model_size + "_ast_spectrogram_pretraining.pth"
)
torch.save(model.state_dict(), pretrained_weights)

Training started for model checkpoint...


Training Progress:   1%|[32m          [0m| 1/100 [01:17<2:07:31, 77.29s/epoch, Epoch=1, Train Loss=2.0755, Val Loss=1.8398, Time (Train)=62.84s, Time (Val)=6.50s]


KeyboardInterrupt: 

In [5]:
# Fine-Tuning for Valence Regression
valence_data = SpectrogramDataset(
    multitask_path,
    class_mapping=CLASS_MAPPING,
    train=True,
    regression=1,  # Regression task for valence
)
valence_train_dl, valence_val_dl = torch_train_val_split(
    dataset=valence_data,
    batch_eval=BATCH_SIZE,
    batch_train=BATCH_SIZE,
    val_size=VAL_SIZE,
    shuffle=True,
)

# Initialize model for regression
pretrained_backbone = ASTBackbone(
    fstride=10,
    tstride=10,
    input_fdim=valence_data[0][0].shape[1],
    input_tdim=valence_data[0][0].shape[0],
    imagenet_pretrain=False,
    model_size=model_size,
    feature_size=1,  # Single output for regression
)
pretrained_model = Regressor(pretrained_backbone).to(DEVICE)

# Load pretrained weights for fine-tuning
pretrained_model.backbone.load_state_dict(torch.load(pretrained_weights, weights_only=True), strict=False)

# Freeze all but the last 2 Transformer layers
pretrained_backbone.freeze_layers(unfrozen_layers=2)

# Verify which parameters are frozen
# for name, param in pretrained_model.named_parameters():
#     print(f"{name}: {'Trainable' if param.requires_grad else 'Frozen'}")

# Define optimizer after freezing layers to avoid updating frozen parameters
valence_optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, pretrained_model.parameters()),  # Only trainable params
    lr=LR
)

# Fine-tune the model for Valence
finetuned_train_losses, finetuned_val_losses = train(
    pretrained_model,
    valence_train_dl,
    valence_val_dl,
    valence_optimizer,
    EPOCHS,
    device=DEVICE,
    regression_flag=True
)

# Save fine-tuned weights
torch.save(pretrained_model.state_dict(), "ast_valence_finetuning.pth")

backbone.v.cls_token: Trainable
backbone.v.pos_embed: Trainable
backbone.v.dist_token: Trainable
backbone.v.patch_embed.proj.weight: Trainable
backbone.v.patch_embed.proj.bias: Trainable
backbone.v.blocks.0.norm1.weight: Frozen
backbone.v.blocks.0.norm1.bias: Frozen
backbone.v.blocks.0.attn.qkv.weight: Frozen
backbone.v.blocks.0.attn.qkv.bias: Frozen
backbone.v.blocks.0.attn.proj.weight: Frozen
backbone.v.blocks.0.attn.proj.bias: Frozen
backbone.v.blocks.0.norm2.weight: Frozen
backbone.v.blocks.0.norm2.bias: Frozen
backbone.v.blocks.0.mlp.fc1.weight: Frozen
backbone.v.blocks.0.mlp.fc1.bias: Frozen
backbone.v.blocks.0.mlp.fc2.weight: Frozen
backbone.v.blocks.0.mlp.fc2.bias: Frozen
backbone.v.blocks.1.norm1.weight: Frozen
backbone.v.blocks.1.norm1.bias: Frozen
backbone.v.blocks.1.attn.qkv.weight: Frozen
backbone.v.blocks.1.attn.qkv.bias: Frozen
backbone.v.blocks.1.attn.proj.weight: Frozen
backbone.v.blocks.1.attn.proj.bias: Frozen
backbone.v.blocks.1.norm2.weight: Frozen
backbone.v.block

In [None]:
# Concatenate the pre-training and fine-tuning loss histories for one plot.
# `list.extend` mutates in place and returns None, so assigning its result
# would wipe the histories; build new lists instead. Leaving the original
# lists untouched also makes this cell safe to re-run.
combined_train_losses = list(train_losses) + list(finetuned_train_losses)
combined_val_losses = list(val_losses) + list(finetuned_val_losses)
plot_train_val_losses(
    combined_train_losses,
    combined_val_losses,
    save_title="assets/finetuned_ast_train_val_losses.png",
)

In [8]:
# Run the fine-tuned regressor over the valence validation loader and print
# the regression report for its predictions.
report_title = "Fine-Tuned AST Regressor on Valence Dataset"
evaluation = test_model(pretrained_model, valence_val_dl, DEVICE, regression_flag=True)
y_true, y_pred, spear_corrs = evaluation

print(report_title)
get_regression_report(y_pred, y_true, spear_corrs)

Fine-Tuned AST Regressor on Valence Dataset
	Spearman Correlation: 0.3119
	MSE: 0.0751
	MAE: 0.2299
	RMSE: 0.2741
