In [1]:
import os
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchnet import meter
from tqdm import tqdm
import numpy as np
import time
import torch.nn.functional as F
import torch.nn as nn


In [2]:
from cqt_loader_setlist import CQT
from models import CQTNet
from config import DefaultConfig
from utility import calc_MAP, norm

# Set Parameters

In [3]:
torch.cuda.is_available()

True

In [25]:
modelname = 'CQTNet'
load_latest = True
load_model_path = 'check_points/CQTNet.pth'
model_save_path = 'finetuning'

load_latest=True 

parallel=True

opt = DefaultConfig(load_model_path=load_model_path, 
                    notes=model_save_path,
                    batch_sz=32,
                    max_epoch=50, 
                    lr=0.001, 
                    lr_decay=0.8, 
                    weight_decay=1e-5, 
                    use_gpu=True, 
                    num_workers=4, 
                    parallel=parallel, 
                    load_latest=load_latest,
                    gpu_device="cuda:0")


torch.cuda.is_available() True
use_gpu True
cuda:0


In [5]:
torch.cuda.get_device_name(device=0)

'NVIDIA GeForce RTX 3090'

In [6]:
#root = "/Users/dirceusilva/Documents/BancoDados/setlist_all/setlist_65k/features_develop/universe_develop"
root = '/mnt/dataset/dataset_cover/universe_develop'
train_data_path = os.path.join(root, "universe_train")
val_data_path = os.path.join(root, "universe_val")
test_data_path = os.path.join(root, "universe_test")

# step 1: Build DataLoader

In [7]:
#######################################################
#                  Create Dataset
#######################################################
train_dataset0 = CQT(train_data_path, mode="train", out_length=200, num_workers=opt.num_workers)
train_dataset1 = CQT(train_data_path, mode="train", out_length=300, num_workers=opt.num_workers)
train_dataset2 = CQT(train_data_path, mode="train", out_length=400, num_workers=opt.num_workers)

valid_dataset = CQT(val_data_path, mode="valid", out_length=None, num_workers=1) 
test_dataset = CQT(test_data_path, mode="test", out_length=None, num_workers=1)

#######################################################
#                  Define Dataloaders
#######################################################

train_dataloader0 = DataLoader(train_dataset0, batch_size=opt.batch_size, shuffle=True, generator=torch.Generator(device='cuda:0'))
train_dataloader1 = DataLoader(train_dataset1, batch_size=opt.batch_size, shuffle=True, generator=torch.Generator(device='cuda:0'))
train_dataloader2 = DataLoader(train_dataset2, batch_size=opt.batch_size, shuffle=True, generator=torch.Generator(device='cuda:0'))

val_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, generator=torch.Generator(device='cuda:0'))
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, generator=torch.Generator(device='cuda:0'))

loaders = {'train': train_dataloader0, 'valid': val_dataloader, 'test': test_dataloader}


# step 2: configure model

In [15]:

#opt._parse(kwargs)

model = CQTNet.CQTNet()
model.load_state_dict(torch.load(load_model_path, map_location=opt.device))

print(model)

CQTNet(
  (features): Sequential(
    (conv0): Conv2d(1, 32, kernel_size=(12, 3), stride=(1, 1), padding=(6, 0), bias=False)
    (norm0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (conv1): Conv2d(32, 64, kernel_size=(13, 3), stride=(1, 1), dilation=(1, 2), bias=False)
    (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (pool1): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 1), dilation=1, ceil_mode=False)
    (conv2): Conv2d(64, 64, kernel_size=(13, 3), stride=(1, 1), bias=False)
    (norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), dilation=(1, 2), bias=False)
    (norm3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (pool3)

In [16]:
model = CQTNet.CQTNetSetlist(original_model=model, num_classes=train_dataset0.get_nr_classes())

In [17]:
for p in model.features.parameters():
    p.requires_grad = True

for p in model.fc0.parameters():
    p.requires_grad = True

In [18]:
for name, param in model.named_parameters():
    print(name, param.requires_grad)

features.conv0.weight True
features.norm0.weight True
features.norm0.bias True
features.conv1.weight True
features.norm1.weight True
features.norm1.bias True
features.conv2.weight True
features.norm2.weight True
features.norm2.bias True
features.conv3.weight True
features.norm3.weight True
features.norm3.bias True
features.conv4.weight True
features.norm4.weight True
features.norm4.bias True
features.conv5.weight True
features.norm5.weight True
features.norm5.bias True
features.conv6.weight True
features.norm6.weight True
features.norm6.bias True
features.conv7.weight True
features.norm7.weight True
features.norm7.bias True
features.conv8.weight True
features.norm8.weight True
features.norm8.bias True
features.conv9.weight True
features.norm9.weight True
features.norm9.bias True
fc0.weight True
fc0.bias True
fc1.weight True
fc1.bias True


In [19]:
if parallel is True: 
        model = torch.nn.DataParallel(model, device_ids=[0])

if parallel is True:
    if opt.load_latest is True:
        #model.load_state_dict(torch.load(os.path.join("check_points",opt.notes)), strict=False)
        model.module.load_latest(opt.notes)
    elif opt.load_model_path:
        model.module.load(opt.load_model_path)
else:
    if opt.load_latest is True:
        model.load_latest(opt.notes)
    elif opt.load_model_path:
        model.load(opt.load_model_path)
model.to(opt.device)

check_points/<class 'models.CQTNet.CQTNetSetlist'>finetuning/latest.pth


DataParallel(
  (module): CQTNetSetlist(
    (features): Sequential(
      (conv0): Conv2d(1, 32, kernel_size=(12, 3), stride=(1, 1), padding=(6, 0), bias=False)
      (norm0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (conv1): Conv2d(32, 64, kernel_size=(13, 3), stride=(1, 1), dilation=(1, 2), bias=False)
      (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (pool1): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 1), dilation=1, ceil_mode=False)
      (conv2): Conv2d(64, 64, kernel_size=(13, 3), stride=(1, 1), bias=False)
      (norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), dilation=(1, 2), bias=False)
      (norm3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

# To device

In [20]:
model.to(opt.device)
print(model)

DataParallel(
  (module): CQTNetSetlist(
    (features): Sequential(
      (conv0): Conv2d(1, 32, kernel_size=(12, 3), stride=(1, 1), padding=(6, 0), bias=False)
      (norm0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (conv1): Conv2d(32, 64, kernel_size=(13, 3), stride=(1, 1), dilation=(1, 2), bias=False)
      (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (pool1): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 1), dilation=1, ceil_mode=False)
      (conv2): Conv2d(64, 64, kernel_size=(13, 3), stride=(1, 1), bias=False)
      (norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), dilation=(1, 2), bias=False)
      (norm3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

#  step 3: criterion and optimizer

In [21]:
criterion = torch.nn.CrossEntropyLoss()
lr = opt.lr
#if opt.parallel is True:
#    optimizer = torch.optim.Adam(model.module.parameters(), lr=lr, weight_decay=opt.weight_decay)
#else:
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
    
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=opt.lr_decay, patience=2, min_lr=5e-6)

# step 4: Train

### Validation Function

In [22]:
@torch.no_grad()
def val_slow(model, dataloader, epoch):
    
    model.eval()
    total, correct = 0, 0
    labels, features = None, None

    for ii, (data, label) in enumerate(dataloader):
        input_data = data.to(opt.device)
        #print(input.shape)
        score, feature = model(input_data)
        feature = feature.data.cpu().numpy()
        label = label.data.cpu().numpy()
        if features is not None:
            features = np.concatenate((features, feature), axis=0)
            labels = np.concatenate((labels,label))
        else:
            features = feature
            labels = label
    features = norm(features)

    dis2d = -np.matmul(features, features.T) # [-1,1] Because normalized, so mutmul is equal to ED
    np.save('dis.npy',dis2d)
    np.save('label.npy',labels)

    MAP, top10, rank1 = calc_MAP(dis2d, labels)

    print(epoch, MAP, top10, rank1 )
    
    model.train()
    
    return MAP

### Training

In [23]:
best_MAP=0
    
val_slow(model, val_dataloader, -1)

-1 0.885741948128748 0.1830084235860409 33.35643802647413


0.885741948128748

In [26]:

for epoch in range(opt.max_epoch):
    running_loss = 0
    num = 0
    for (data0, label0),(data1, label1),(data2, label2) in tqdm(zip(train_dataloader0, train_dataloader1, train_dataloader2)):
        for flag in range(3):
            if flag==0:
                data=data0
                label=label0
            elif flag==1:
                data=data1
                label=label1
            else:
                data=data2
                label=label2
            # train model
            input_data = data.requires_grad_()
            input_data = input_data.to(opt.device)
            
            target = label.to(opt.device)

            optimizer.zero_grad()
            score, _ = model(input_data)
            loss = criterion(score, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num += target.shape[0]
        
    running_loss /= num 
    
    print(running_loss)
    
    if opt.parallel is True:
        model.module.save(opt.notes)
    else:
        model.save(opt.notes)
    
    # update learning rate
    scheduler.step(running_loss) 
    
    # validate
    MAP=0
    MAP += val_slow(model, val_dataloader, epoch)
            
    if MAP>best_MAP:
        best_MAP=MAP
        print('*****************BEST*****************')
    print('')
    model.train()

0it [00:00, ?it/s]

1273it [14:32,  1.46it/s]


0.04845547107165218
model name 0515_23:55:06.pth
0 0.8959867525726601 0.18456077015643804 37.63790613718412
*****************BEST*****************



1273it [14:30,  1.46it/s]


0.021686241091678635
model name 0516_00:11:34.pth
1 0.8945765534066048 0.18400722021660648 34.38688327316486



1273it [14:32,  1.46it/s]


0.01788341089733189
model name 0516_00:28:06.pth
2 0.8978840278480384 0.18471720818291215 37.41925391095066
*****************BEST*****************



1273it [14:30,  1.46it/s]


0.01652729404235408
model name 0516_00:44:36.pth
3 0.8936143779414447 0.18406738868832734 34.14777376654633



1273it [14:31,  1.46it/s]


0.014805727773217581
model name 0516_01:01:07.pth
4 0.886697372086821 0.18293622141997595 39.037906137184116



1273it [14:30,  1.46it/s]


0.014173093404431337
model name 0516_01:17:37.pth
5 0.8948960276598023 0.18457280385078217 34.68760529482551



1273it [14:30,  1.46it/s]


0.013553052089752122
model name 0516_01:34:06.pth
6 0.8949162528699103 0.18439229843561972 32.54476534296029



1273it [14:31,  1.46it/s]


0.012684913771885001
model name 0516_01:50:36.pth
7 0.8961970900040772 0.1843080625752106 35.001925391095064



1273it [14:30,  1.46it/s]


0.012548397507735096
model name 0516_02:07:06.pth
8 0.8948004640543107 0.18433212996389892 33.415042117930206



1273it [14:30,  1.46it/s]


0.012233268506946515
model name 0516_02:23:36.pth
9 0.8926226082726599 0.18405535499398315 35.5531889290012



1273it [14:31,  1.46it/s]


0.0117982409301616
model name 0516_02:40:05.pth
10 0.8876905253794132 0.18339350180505415 39.123225030084235



1273it [14:30,  1.46it/s]


0.011598009375077883
model name 0516_02:56:33.pth
11 0.8911567738289452 0.18400722021660648 35.478941034897716



1273it [14:31,  1.46it/s]


0.011236734268900669
model name 0516_03:13:03.pth
12 0.8927252929390198 0.18409145607701566 35.17882069795427



1273it [14:30,  1.46it/s]


0.011146824651425302
model name 0516_03:29:33.pth
13 0.8924428922391396 0.18424789410348977 34.28748495788207



1273it [14:30,  1.46it/s]


0.010930902450113848
model name 0516_03:46:01.pth
14 0.8925693305474112 0.18418772563176894 32.33020457280385



1273it [14:30,  1.46it/s]


0.010991937233049302
model name 0516_04:02:32.pth
15 0.8895791187860893 0.18377858002406738 35.171480144404335



1273it [14:31,  1.46it/s]


0.010756190248728956
model name 0516_04:19:01.pth
16 0.893441251597018 0.18452466907340553 36.09602888086643



1273it [14:35,  1.45it/s]


0.010614079175320656
model name 0516_04:35:34.pth
17 0.8892046200466249 0.1838748495788207 34.90758122743682



1273it [14:29,  1.46it/s]


0.010452472546702656
model name 0516_04:52:01.pth
18 0.8930485971313826 0.18401925391095067 36.26690734055355



1273it [14:29,  1.46it/s]


0.010371498122831148
model name 0516_05:08:29.pth
19 0.8932416289271591 0.18422382671480145 33.145968712394705



1273it [14:30,  1.46it/s]


0.010446166653580294
model name 0516_05:24:58.pth
20 0.8941942702289314 0.18407942238267147 38.29614921780987



1273it [14:31,  1.46it/s]


0.010166942460282435
model name 0516_05:41:28.pth
21 0.8908055758938438 0.18399518652226235 36.11239470517449



1273it [14:31,  1.46it/s]


0.010301354018543308
model name 0516_05:57:59.pth
22 0.8916457287243019 0.18376654632972322 37.19903730445247



1273it [14:31,  1.46it/s]


0.010185980253133614
model name 0516_06:14:29.pth
23 0.8936369116104991 0.18439229843561972 37.88567990373045



1273it [14:31,  1.46it/s]


0.00997421593756195
model name 0516_06:30:59.pth
24 0.8957947981417832 0.18478941034897714 37.46257521058965



1273it [14:30,  1.46it/s]


0.009876566628559541
model name 0516_06:47:27.pth
25 0.891167358135727 0.1838146811070999 36.5560770156438



1273it [14:30,  1.46it/s]


0.00989069150567212
model name 0516_07:03:57.pth
26 0.8968833399406149 0.18488567990373045 36.65439229843562



1273it [14:27,  1.47it/s]


0.009855488163593974
model name 0516_07:20:23.pth
27 0.8900141390537771 0.18379061371841154 35.461371841155234



1273it [14:29,  1.46it/s]


0.009542039914847474
model name 0516_07:36:50.pth
28 0.8964817881306759 0.18458483754512636 34.2280385078219



1273it [14:30,  1.46it/s]


0.009932783809547823
model name 0516_07:53:20.pth
29 0.8945762704240414 0.18445246690734057 36.5503008423586



1273it [14:26,  1.47it/s]


0.009714642579956247
model name 0516_08:09:45.pth
30 0.891201646808425 0.18362214199759325 34.28026474127557



1273it [14:26,  1.47it/s]


0.009677785515673639
model name 0516_08:26:09.pth
31 0.8957246103289289 0.18460890493381468 35.04055354993983



1273it [14:30,  1.46it/s]


0.007503828593526991
model name 0516_08:42:36.pth
32 0.8942203732596277 0.18425992779783393 36.45006016847172



1273it [14:29,  1.46it/s]


0.007108400487113127
model name 0516_08:59:04.pth
33 0.8961634301014342 0.18466907340553548 32.85932611311673



1273it [14:27,  1.47it/s]


0.0070489642059185665
model name 0516_09:15:31.pth
34 0.8925399809824247 0.18392298435619736 32.96883273164862



1273it [14:29,  1.46it/s]


0.007334433473183297
model name 0516_09:31:59.pth
35 0.8884408011078043 0.18358604091456077 35.567870036101084



1273it [14:27,  1.47it/s]


0.007216434289699641
model name 0516_09:48:26.pth
36 0.8913438628834325 0.18377858002406738 35.41540312876053



1273it [14:27,  1.47it/s]


0.007227118000587542
model name 0516_10:04:53.pth
37 0.8941526546823053 0.18399518652226235 33.6208182912154



1273it [14:29,  1.46it/s]


0.005924634554841754
model name 0516_10:21:20.pth
38 0.890856101013247 0.18386281588447653 34.24344163658243



1273it [14:28,  1.47it/s]


0.005604849152871419
model name 0516_10:37:45.pth
39 0.8923300836754 0.1844765342960289 33.25944645006017



1273it [14:27,  1.47it/s]


0.005773904246960418
model name 0516_10:54:10.pth
40 0.8908733597339801 0.18421179302045726 35.19338146811071



1273it [14:30,  1.46it/s]


0.005768796773691854
model name 0516_11:10:39.pth
41 0.8938409299078944 0.18457280385078217 34.31263537906137



1273it [14:28,  1.47it/s]


0.005870562084660963
model name 0516_11:27:06.pth
42 0.893816049741361 0.18439229843561972 33.52611311672683



1273it [14:28,  1.47it/s]


0.004565184309032536
model name 0516_11:43:32.pth
43 0.8883185443769749 0.18382671480144402 32.043561973525875



1273it [14:28,  1.47it/s]


0.004821226228742124
model name 0516_11:59:59.pth
44 0.8905749160041627 0.18364620938628157 32.55475330926595



1273it [14:28,  1.47it/s]


0.004423532431932555
model name 0516_12:16:25.pth
45 0.8933473489506127 0.18441636582430807 30.151985559566786



1273it [14:27,  1.47it/s]


0.004694014283256642
model name 0516_12:32:50.pth
46 0.8888697959876874 0.1838748495788207 33.471961492178096



1273it [14:29,  1.46it/s]


0.004528028591133716
model name 0516_12:49:17.pth
47 0.8946125950218352 0.18452466907340553 33.02635379061372



1273it [14:29,  1.46it/s]


0.004587180139367258
model name 0516_13:05:44.pth
48 0.8979962095564893 0.18523465703971118 32.58146811070999
*****************BEST*****************



1273it [14:28,  1.47it/s]


0.003742056920366177
model name 0516_13:22:10.pth
49 0.8921344339553335 0.18422382671480145 34.11347773766546



In [35]:
opt.notes='experiment0'

# step 5: Test

In [None]:
def test(**kwargs):
    opt.batch_size=1
    opt.num_workers=1
    opt.model = 'CQTNet'
    opt.load_latest = False
    opt.load_model_path = 'check_points/CQTNet.pth'
    opt._parse(kwargs)
    
    model = getattr(models, opt.model)() 
    #print(model)
    if opt.load_latest is True:
        model.load_latest(opt.notes)
    elif opt.load_model_path:
        model.load(opt.load_model_path)
    model.to(opt.device)

    val_data = CQT('val', out_length=None)
    test_data = CQT('test', out_length=None)
    val_dataloader = DataLoader(val_data, 1, shuffle=False,num_workers=1)
    test_dataloader = DataLoader(test_data, 1, shuffle=False,num_workers=1)
    
    val_slow(model, val_dataloader, 0)

In [40]:
val_slow(model, test_dataloader, -1)

-1 0.900768927377684 0.0933318161128812 111.79449248975877


0.900768927377684