In [1]:
from torch import cuda

device = 'cuda' if cuda.is_available() else 'cpu'

In [2]:
!git clone https://github.com/xxbean/pozalabs

Cloning into 'pozalabs'...
remote: Enumerating objects: 103, done.[K
remote: Counting objects: 100% (34/34), done.[K
remote: Compressing objects: 100% (31/31), done.[K
remote: Total 103 (delta 6), reused 0 (delta 0), pack-reused 69[K
Receiving objects: 100% (103/103), 64.63 MiB | 17.23 MiB/s, done.
Resolving deltas: 100% (13/13), done.


In [3]:

import sys
sys.path.append("/content/pozalabs/ComMU-code")
sys.path.append("/content/pozalabs/Multi-Scale-1D-ResNet")
from model.multi_scale_ori import *
from model.meta_labeling import * #여기에 메타 데이터 함수 정보 있음
import torch
import numpy as np
from torchsummary import summary
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import tqdm
import torch.nn.functional as F
import sys
import glob
import copy

In [5]:
%cd /content/pozalabs/
%ls

/content/pozalabs
CAS_colab.ipynb  [0m[01;34mComMU-code[0m/  [01;34mMulti-Scale-1D-ResNet[0m/  README.md


In [9]:
import glob
from sklearn.model_selection import train_test_split

raw_train = glob.glob("./ComMU-code/dataset/commu_midi/train/raw/**")
raw_val = glob.glob("./ComMU-code/dataset/commu_midi/val/raw/**")

train_meta_npy_ = np.load("/content/pozalabs/ComMU-code/dataset/output_npy_ 2/input_train.npy", allow_pickle=True)[:10000]
train_midi_npy_ = np.load("/content/pozalabs/ComMU-code/dataset/output_npy_ 2/target_train.npy", allow_pickle=True)[:10000]
val_meta_npy_ = np.load("/content/pozalabs/ComMU-code/dataset/output_npy_ 2/input_train.npy", allow_pickle=True)[10000:]
val_midi_npy_ = np.load("/content/pozalabs/ComMU-code/dataset/output_npy_ 2/target_train.npy", allow_pickle=True)[10000:]
test_meta_npy_ = np.load('/content/pozalabs/ComMU-code/dataset/output_npy_ 2/input_val.npy', allow_pickle=True)
test_midi_npy_ = np.load("/content/pozalabs/ComMU-code/dataset/output_npy_ 2/target_val.npy", allow_pickle=True)

label_list = np.unique(train_meta_npy_[:,0])

In [10]:
print(
    
len(train_meta_npy_),
len(train_midi_npy_),
len(val_meta_npy_),
len(val_midi_npy_),
len(test_meta_npy_),
len(test_midi_npy_))

10000 10000 381 381 763 763


In [11]:
class Commu(Dataset):
    def __init__(self, meta_npy, midi_npy,seq_len):
        self.meta_npy = meta_npy
        self.midi_npy = midi_npy
        self.seq_len = seq_len
        self.label_npy = np.zeros_like(self.meta_npy)
        for i in range(11):
            self.label_npy[:,i] = np.array(list(map(meta_list[i],meta_npy[:,i])))

    def __len__(self):
        return len(self.meta_npy)

    def __getitem__(self, idx):
        label = self.label_npy[idx]
        midi = np.zeros((1,self.seq_len))
        midi_real = self.midi_npy[idx][:self.seq_len]
        midi[:,:len(midi_real)] = midi_real
        midi = torch.tensor(midi.tolist(),dtype=torch.float)
        label = torch.tensor(label.tolist(),dtype=torch.float)
        return midi,label

In [12]:
real_data = Commu(train_meta_npy_, train_midi_npy_,512)
real_loader = DataLoader(real_data, batch_size=256, shuffle=True)
val_data = Commu(val_meta_npy_, val_midi_npy_,512)
val_loader = DataLoader(val_data, batch_size=256, shuffle=True)
test_data = Commu(test_meta_npy_, test_midi_npy_,512)
test_loader = DataLoader(test_data, batch_size=256, shuffle=True)

In [13]:
test_batch = next(iter(real_loader))[0]
test_target = next(iter(real_loader))[1]

In [14]:
msresnet = MSResNet(input_channel=1, layers=[1, 1, 1, 1], num_classes=5)
msresnet = msresnet.cuda()
# summary(msresnet,(1,512))

In [15]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(msresnet.parameters(), lr=0.001, momentum=0.9)
num_epochs = 100
meta_num = 1

In [18]:
best_acc=0

for epoch in range(num_epochs):  

    running_loss = 0.0
    for i, data in enumerate(tqdm.tqdm(real_loader)):
        msresnet.train()
        # get the inputs; data is a list of [inputs, labels]
        midi,label = data
        midi = midi.cuda()
        label = label.cuda()[:,meta_num]
#         label = F.one_hot(label.to(torch.int64), num_classes = 10)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = msresnet(midi)[0]
        loss = criterion(outputs, label.long())
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
    print(running_loss / 41)
    if epoch %6 == 5:
        val_num = 0
        correct_num = 0
        with torch.no_grad():
            msresnet.eval()
            for i, data in enumerate(val_loader):
                midi,label = data
                midi = midi.cuda()
                label = label[:,meta_num]
                num = label.shape[0]
                val_num += num
                outputs = msresnet(midi)[0].argmax(1).cpu()
                correct_num += sum(outputs==label)
                acc = correct_num/val_num
                if acc > best_acc:
                  torch.save(msresnet.state_dict(), '/content/pozalabs/ComMU-code/model_state_dict.pt')
                  best_model = copy.deepcopy(msresnet.state_dict())
                
            print(f"validation accuracy for epoch {epoch} is : {correct_num/val_num}")
        

print('Finished Training')

100%|██████████| 40/40 [00:09<00:00,  4.02it/s]


0.7477613120544248


100%|██████████| 40/40 [00:02<00:00, 18.61it/s]


0.598884576704444


100%|██████████| 40/40 [00:02<00:00, 18.06it/s]


0.5772044455132833


100%|██████████| 40/40 [00:02<00:00, 17.09it/s]


0.5570219071899972


100%|██████████| 40/40 [00:02<00:00, 17.86it/s]


0.5366361591874099


100%|██████████| 40/40 [00:02<00:00, 17.44it/s]


0.5231320370988148
validation accuracy for epoch 5 is : 0.6299212574958801


100%|██████████| 40/40 [00:02<00:00, 18.39it/s]


0.501304014426906


100%|██████████| 40/40 [00:02<00:00, 18.42it/s]


0.4795155481594365


100%|██████████| 40/40 [00:02<00:00, 16.00it/s]


0.4563251210422051


100%|██████████| 40/40 [00:02<00:00, 18.24it/s]


0.43505435745890547


100%|██████████| 40/40 [00:02<00:00, 18.14it/s]


0.41595010931898907


100%|██████████| 40/40 [00:02<00:00, 18.14it/s]


0.3980434827688264
validation accuracy for epoch 11 is : 0.6640419960021973


100%|██████████| 40/40 [00:02<00:00, 18.06it/s]


0.3715086877346039


100%|██████████| 40/40 [00:02<00:00, 16.67it/s]


0.3506080108444865


100%|██████████| 40/40 [00:02<00:00, 17.86it/s]


0.33674886531946135


100%|██████████| 40/40 [00:02<00:00, 18.17it/s]


0.3113880324654463


100%|██████████| 40/40 [00:02<00:00, 17.89it/s]


0.2800146890122716


100%|██████████| 40/40 [00:02<00:00, 18.02it/s]


0.2571710929638002
validation accuracy for epoch 17 is : 0.7322834730148315


100%|██████████| 40/40 [00:02<00:00, 16.54it/s]


0.24050199876471265


100%|██████████| 40/40 [00:02<00:00, 15.55it/s]


0.23378275171285723


100%|██████████| 40/40 [00:02<00:00, 17.91it/s]


0.19725232698568484


100%|██████████| 40/40 [00:02<00:00, 17.88it/s]


0.172179400193982


100%|██████████| 40/40 [00:02<00:00, 18.00it/s]


0.1582961305975914


100%|██████████| 40/40 [00:02<00:00, 17.29it/s]


0.17495775077401138
validation accuracy for epoch 23 is : 0.8136482834815979


100%|██████████| 40/40 [00:02<00:00, 16.36it/s]


0.15859320160092377


100%|██████████| 40/40 [00:02<00:00, 17.84it/s]


0.12713069105293692


100%|██████████| 40/40 [00:02<00:00, 17.74it/s]


0.12941828515471482


100%|██████████| 40/40 [00:02<00:00, 17.90it/s]


0.10488486262719805


100%|██████████| 40/40 [00:02<00:00, 17.55it/s]


0.0901439150840771


100%|██████████| 40/40 [00:02<00:00, 16.25it/s]


0.0839554860097606
validation accuracy for epoch 29 is : 0.8136482834815979


100%|██████████| 40/40 [00:02<00:00, 17.75it/s]


0.07276199276490909


100%|██████████| 40/40 [00:02<00:00, 17.61it/s]


0.08911125598157324


100%|██████████| 40/40 [00:02<00:00, 17.59it/s]


0.08113347684464804


100%|██████████| 40/40 [00:02<00:00, 17.58it/s]


0.07095824345582868


100%|██████████| 40/40 [00:02<00:00, 16.05it/s]


0.07367337199791175


100%|██████████| 40/40 [00:02<00:00, 17.65it/s]


0.05048013655695974
validation accuracy for epoch 35 is : 0.7769029140472412


100%|██████████| 40/40 [00:02<00:00, 17.58it/s]


0.0449545842300101


100%|██████████| 40/40 [00:02<00:00, 17.46it/s]


0.04439944287807476


100%|██████████| 40/40 [00:02<00:00, 17.49it/s]


0.04857441878355131


100%|██████████| 40/40 [00:02<00:00, 15.98it/s]


0.04598690142355314


100%|██████████| 40/40 [00:02<00:00, 17.51it/s]


0.09545704895039885


100%|██████████| 40/40 [00:02<00:00, 17.06it/s]


0.07161402770477097
validation accuracy for epoch 41 is : 0.8530183434486389


100%|██████████| 40/40 [00:02<00:00, 17.55it/s]


0.047220688449536884


100%|██████████| 40/40 [00:02<00:00, 17.62it/s]


0.060880362578644986


100%|██████████| 40/40 [00:02<00:00, 16.34it/s]


0.043700653105610755


100%|██████████| 40/40 [00:02<00:00, 17.35it/s]


0.03140346908078688


100%|██████████| 40/40 [00:02<00:00, 17.58it/s]


0.027771229424127717


100%|██████████| 40/40 [00:02<00:00, 17.38it/s]


0.023274189724427897
validation accuracy for epoch 47 is : 0.8110235929489136


100%|██████████| 40/40 [00:02<00:00, 17.59it/s]


0.030045895133076643


100%|██████████| 40/40 [00:02<00:00, 15.51it/s]


0.03679520691313395


100%|██████████| 40/40 [00:02<00:00, 16.45it/s]


0.026376128333007415


100%|██████████| 40/40 [00:02<00:00, 17.72it/s]


0.02673690990976444


100%|██████████| 40/40 [00:02<00:00, 17.69it/s]


0.031629302629792105


100%|██████████| 40/40 [00:02<00:00, 16.43it/s]


0.021763787059703978
validation accuracy for epoch 53 is : 0.8188976645469666


100%|██████████| 40/40 [00:02<00:00, 15.98it/s]


0.027040485610685696


100%|██████████| 40/40 [00:02<00:00, 16.93it/s]


0.03580070216572139


100%|██████████| 40/40 [00:02<00:00, 17.69it/s]


0.030689268986262928


100%|██████████| 40/40 [00:02<00:00, 17.21it/s]


0.024752619324206578


100%|██████████| 40/40 [00:02<00:00, 17.86it/s]


0.10397661281976758


100%|██████████| 40/40 [00:02<00:00, 17.12it/s]


0.09918424932331574
validation accuracy for epoch 59 is : 0.8320209980010986


100%|██████████| 40/40 [00:02<00:00, 16.86it/s]


0.053809087768923944


100%|██████████| 40/40 [00:02<00:00, 16.45it/s]


0.02700198543962182


100%|██████████| 40/40 [00:02<00:00, 17.18it/s]


0.021068898393068372


100%|██████████| 40/40 [00:02<00:00, 17.63it/s]


0.03140459027959079


100%|██████████| 40/40 [00:02<00:00, 17.30it/s]


0.06572893847961252


100%|██████████| 40/40 [00:02<00:00, 16.58it/s]


0.04770590301330497
validation accuracy for epoch 65 is : 0.8372703194618225


100%|██████████| 40/40 [00:02<00:00, 17.72it/s]


0.02721556507777877


100%|██████████| 40/40 [00:02<00:00, 17.66it/s]


0.050932803531972375


100%|██████████| 40/40 [00:02<00:00, 17.63it/s]


0.03211825509078619


100%|██████████| 40/40 [00:02<00:00, 17.53it/s]


0.026547439850685074


100%|██████████| 40/40 [00:02<00:00, 16.18it/s]


0.01607734754272714


100%|██████████| 40/40 [00:02<00:00, 17.74it/s]


0.013998351559588096
validation accuracy for epoch 71 is : 0.8241469860076904


100%|██████████| 40/40 [00:02<00:00, 17.58it/s]


0.017675089150122027


100%|██████████| 40/40 [00:02<00:00, 17.63it/s]


0.01296647919750795


100%|██████████| 40/40 [00:02<00:00, 17.07it/s]


0.013434515845757432


100%|██████████| 40/40 [00:02<00:00, 16.01it/s]


0.01295044989802125


100%|██████████| 40/40 [00:02<00:00, 17.72it/s]


0.010994154982632252


100%|██████████| 40/40 [00:02<00:00, 17.68it/s]


0.015285747489187776
validation accuracy for epoch 77 is : 0.8110235929489136


100%|██████████| 40/40 [00:02<00:00, 17.67it/s]


0.014548755500738213


100%|██████████| 40/40 [00:02<00:00, 17.60it/s]


0.013269333577737576


100%|██████████| 40/40 [00:02<00:00, 16.29it/s]


0.012331777070535392


100%|██████████| 40/40 [00:02<00:00, 17.10it/s]


0.05515711003255735


100%|██████████| 40/40 [00:02<00:00, 17.56it/s]


0.018257585763022666


100%|██████████| 40/40 [00:02<00:00, 17.66it/s]


0.014138126654959306
validation accuracy for epoch 83 is : 0.8241469860076904


100%|██████████| 40/40 [00:02<00:00, 17.52it/s]


0.025602213765789823


100%|██████████| 40/40 [00:02<00:00, 16.37it/s]


0.014760287620536074


100%|██████████| 40/40 [00:02<00:00, 16.79it/s]


0.03934084462774236


100%|██████████| 40/40 [00:02<00:00, 17.49it/s]


0.10039095353426003


100%|██████████| 40/40 [00:02<00:00, 17.50it/s]


0.11261047295680861


100%|██████████| 40/40 [00:02<00:00, 17.24it/s]


0.053909162286578154
validation accuracy for epoch 89 is : 0.8188976645469666


100%|██████████| 40/40 [00:02<00:00, 16.52it/s]


0.020830673869790102


100%|██████████| 40/40 [00:02<00:00, 16.05it/s]


0.020806623520556747


100%|██████████| 40/40 [00:02<00:00, 17.36it/s]


0.014672363778894268


100%|██████████| 40/40 [00:02<00:00, 17.54it/s]


0.014211026576870098


100%|██████████| 40/40 [00:02<00:00, 17.37it/s]


0.022343165275254626


100%|██████████| 40/40 [00:02<00:00, 16.62it/s]


0.011459235045150286
validation accuracy for epoch 95 is : 0.8320209980010986


100%|██████████| 40/40 [00:02<00:00, 16.40it/s]


0.012905118793885156


100%|██████████| 40/40 [00:02<00:00, 17.35it/s]


0.008543537577568757


100%|██████████| 40/40 [00:02<00:00, 17.33it/s]


0.007871362908811467


100%|██████████| 40/40 [00:02<00:00, 17.23it/s]

0.008399380439120094
Finished Training





In [19]:
msresnet_test = MSResNet(input_channel=1, layers=[1, 1, 1, 1], num_classes=5)
msresnet_test = msresnet_test.cuda()
msresnet_test.load_state_dict(best_model)
msresnet_test.eval()


MSResNet(
  (conv1): Conv1d(1, 64, kernel_size=(7,), stride=(2,), padding=(3,), bias=False)
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer3x3_1): Sequential(
    (0): BasicBlock3x3(
      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv1d(64, 64, kernel_size=(1,), stride=(2,), bias=False)
        (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer3x3_2): Sequential(
    (0): BasicB

In [20]:

correct_test = 0
for i, data in enumerate(test_loader):
    with torch.no_grad():
        midi,label = data
        midi = midi.cuda()
        label = label[:,meta_num]
        num = label.shape[0]
        val_num += num
        outputs = msresnet(midi)[0].argmax(1).cpu()
        correct_num += sum(outputs==label)
        acc = correct_num/val_num
print("Test accuracy:", (100 * float(acc)))


Test accuracy: 81.2062919139862
