# Training & Generation for Paper "Responding to the Call: Exploring Automatic Music Composition Using a Knowledge-Enhanced Model"

In [2]:
# GPU settings
# On Mac (Apple Silicon), CUDA is not available.
# We will select device (CPU or MPS) later using torch, so no need to set CUDA_VISIBLE_DEVICES here.
device = torch.device("mps") 

In [1]:
import sys
import math
import time
import glob
import datetime
import random
import pickle
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from main_knowledge import *
import saver

In [3]:
# ###--- data ---###
path_data_root = './dataset/'
path_test_data = os.path.join(path_data_root, 'test.npz')
path_train_data = os.path.join(path_data_root, 'train.npz')
path_dictionary =  os.path.join(path_data_root, 'dictionary.pkl')

##uncomment the following to run on the complete training and test data
###--- data ---###
# path_data_root = '../train_test_data/'
# path_test_data = os.path.join(path_data_root, 'test.npz')
# path_train_data = os.path.join(path_data_root, 'train.npz')
# path_dictionary =  os.path.join(path_data_root, 'dictionary.pkl')



###--- training config ---###
path_exp = './exp'
batch_size =8
init_lr = 0.0001
max_grad_norm = 3
path_gendir = 'gen_midis'

import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")     # Apple GPU
else:
    device = torch.device("cpu")
# NaN が続く場合の対処: 以下で CPU に切り替えて再実行してみてください
# device = torch.device("cpu")

In [4]:
def get_train_data():
  dictionary = pickle.load(open(path_dictionary, 'rb'))
  event2word, word2event = dictionary
  train_data = np.load(path_train_data,allow_pickle=True)
  return train_data, event2word, word2event, dictionary
def get_test_data():
  dictionary = pickle.load(open(path_dictionary, 'rb'))
  event2word, word2event = dictionary
  test_data = np.load(path_test_data,allow_pickle=True)
  return test_data, event2word, word2event, dictionary


In [5]:
# Load data
train_data, event2word, word2event, dictionary = get_train_data()
test_data, event2word, word2event, dictionary = get_test_data()



# config
n_class = []
for key in event2word.keys():
    n_class.append(len(dictionary[0][key]))
print('num of classes:', n_class)

# # unpack
train_x = train_data['x']
train_y = train_data['y']
train_mask = train_data['mask']


fact_candidate=np.load('./dataset/knowledge.npy',allow_pickle=True)

# uncomment the following to load predefined external knowledge
# fact_candidate=torch.load('../train_test_data/external_knowledge/candidate_train')


test_x = test_data['x'][:]
test_y = test_data['y'][:]
test_mask = test_data['mask'][:]

# run
start_time = time.time()

num of classes: [234, 135, 18, 7, 130, 22, 130]


Initialize Model

In [6]:
net = TransformerModel(n_class)
info_load_model = False
# Use single-device training (CPU or MPS) on Mac
net.to(device)
net.train()
n_parameters = network_paras(net)
print('n_parameters: {:,}'.format(n_parameters))
# optimizers
optimizer = optim.Adam(net.parameters(), lr=init_lr)


>>>>>: [234, 135, 18, 7, 130, 22, 130]
n_parameters: 46,876,164


Train Model

In [8]:
saver_agent = saver.Saver(path_exp)
###TAsk1 ###
num_batch = len(train_x) // batch_size
candidate_number = 3
n_epoch = 5
start_time = time.time()
for epoch in range(n_epoch):
    acc_loss = 0
    acc_losses = np.zeros(7)
    n_updated = 0  # 有限 loss で更新したバッチ数（NaN スキップ用）
    with tqdm(range(num_batch)) as bar:
        for bidx in range(num_batch):  # num_batch
            # index
            bidx_st = batch_size * bidx
            bidx_ed = batch_size * (bidx + 1)
            # unpack batch data
            batch_x = train_x[bidx_st:bidx_ed]
            batch_y = train_y[bidx_st:bidx_ed]
            batch_mask = train_mask[bidx_st:bidx_ed]
            batch_x = torch.from_numpy(batch_x).long().to(device)
            batch_y = torch.from_numpy(batch_y).long().to(device)
            batch_mask = torch.from_numpy(batch_mask).float().to(device)

            # first task (loss_mask は float で渡す)
            zeros_token = torch.LongTensor([[0, 0, 0, 0, 0, 0, 0]]).to(device)
            zeros_mask_float = torch.zeros(1, 1, dtype=torch.float32, device=device)
            t1 = torch.cat([batch_y[:, :-1], zeros_token.expand(batch_y.shape[0], 1, 7)], 1)
            t2 = torch.cat([batch_y[:, 1:], zeros_token.expand(batch_y.shape[0], 1, 7)], 1)
            batch_mask1 = torch.cat([batch_mask[:, :-1], zeros_mask_float.expand(batch_y.shape[0], 1)], 1)

            src_mask = []
            for i in range(len(batch_x)):
                src_mask.append(int(torch.where(batch_x[i][:, 3] == 0)[0][0]))

            tgt_mask = []
            for i in range(len(t1)):
                tgt_mask.append(int(torch.where(t1[i][:, 3] == 0)[0][0]))

            losses = net.train_step(batch_x, t1, t2, src_mask, tgt_mask, batch_mask1, None)
            loss_task1 = (losses[0] + losses[1] + losses[2] + losses[3] + losses[4] + losses[5] + losses[6]) / 7

            # second task
            knowledge_base = {}
            knowledge_base['item'] = {}
            knw_mask_t = []
            batch_knowledge = {}
            for idx in range(candidate_number):
                batch_knowledge[idx] = fact_candidate[idx][bidx_st:bidx_ed]
                batch_knowledge[idx] = torch.from_numpy(batch_knowledge[idx]).long().to(device)

                knw_mask = []
                for j in range(len(batch_knowledge[idx])):
                    knw_mask.append(int(torch.where(batch_knowledge[idx][j][:, 3] == 0)[0][0]))
                knw_mask_t.append(knw_mask)

            for i in range(candidate_number):
                knowledge_base['item'][i] = net.forward_hidden(batch_knowledge[i], knw_mask_t[i])

            losses = net.train_step(batch_x, t1, t2, src_mask, tgt_mask, batch_mask1, knowledge_base['item'])
            loss_task2 = (losses[0] + losses[1] + losses[2] + losses[3] + losses[4] + losses[5] + losses[6]) / 7

            # third task
            loss_task3 = 0
            for can in range(candidate_number):
                t1 = torch.cat([batch_knowledge[can][:, :-1], zeros_token.expand(batch_knowledge[can].shape[0], 1, 7)], 1)
                t2 = torch.cat([batch_knowledge[can][:, 1:], zeros_token.expand(batch_knowledge[can].shape[0], 1, 7)], 1)
                mask_list = []
                for i in knw_mask_t[can]:
                    mask = np.concatenate([np.ones(i), np.zeros(256 - i)])
                    mask_list.append(mask)
                batch_mask = torch.tensor(mask_list, dtype=torch.float32, device=device)
                batch_mask1 = torch.cat([batch_mask[:, :-1], zeros_mask_float.expand(batch_y.shape[0], 1)], 1)
                tgt_mask = []
                for i in range(len(t1)):
                    tgt_mask.append(int(torch.where(t1[i][:, 3] == 0)[0][0]))
                losses = net.train_step(batch_x, t1, t2, src_mask, tgt_mask, batch_mask1, knowledge_base['item'])
                loss_3 = (losses[0] + losses[1] + losses[2] + losses[3] + losses[4] + losses[5] + losses[6]) / 7
                loss_task3 += loss_3

            loss_task3 = loss_task3 / candidate_number

            # final loss
            loss = loss_task1 + (loss_task2 + loss_task3) / 2

            # Update (NaN/Inf のときは更新をスキップしてモデルを守る)
            if torch.isfinite(loss):
                net.zero_grad()
                loss.backward()
                if max_grad_norm is not None:
                    clip_grad_norm_(net.parameters(), max_grad_norm)
                optimizer.step()
                acc_losses += np.array([l.item() for l in losses])
                acc_loss += loss.item()
                n_updated += 1
            else:
                net.zero_grad()

            # print
            sys.stdout.write('{}/{} | Loss: {:06f} | {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}, {:04f}\r'.format(
                bidx, num_batch, loss.item() if torch.isfinite(loss) else float('nan'),
                losses[0].item() if torch.isfinite(losses[0]) else float('nan'),
                losses[1].item() if torch.isfinite(losses[1]) else float('nan'),
                losses[2].item() if torch.isfinite(losses[2]) else float('nan'),
                losses[3].item() if torch.isfinite(losses[3]) else float('nan'),
                losses[4].item() if torch.isfinite(losses[4]) else float('nan'),
                losses[5].item() if torch.isfinite(losses[5]) else float('nan'),
                losses[6].item() if torch.isfinite(losses[6]) else float('nan')))
            sys.stdout.flush()
            bar.update()

    # epoch loss（有限で更新したバッチ数で割る）
    runtime = time.time() - start_time
    epoch_loss = acc_loss / n_updated if n_updated > 0 else float('nan')
    acc_losses = acc_losses / n_updated if n_updated > 0 else acc_losses
    print('------------------------------------')
    print('epoch: {}/{} | Loss: {} | time: {} | batches_ok: {}/{}'.format(
        epoch + 1, n_epoch, epoch_loss, str(datetime.timedelta(seconds=runtime)), n_updated, num_batch))

    # save model, with policy
    loss = epoch_loss
    if 0.4 < loss <= 1:
        fn = int(loss * 10) * 10
        saver_agent.save_model(net, name='loss_' + str(fn))
    elif 0.01 < loss <= 0.40:
        fn = int(loss * 100)
        saver_agent.save_model(net, name='loss_' + str(fn))
    elif loss <= 0.01:
        print('Finished')
    else:
        saver_agent.save_model(net, name='loss_high' + "_epoch_" + str(epoch))

  0%|          | 0/12 [00:00<?, ?it/s]

  batch_mask = torch.tensor(mask_list, dtype=torch.float32, device=device)


0/12 | Loss: 7.940293 | 4.742430, 5.434772, 2.712217, 1.711072, 4.584793, 3.487069, 4.829752

  8%|▊         | 1/12 [00:59<10:58, 59.89s/it]

1/12 | Loss: 6.673228 | 3.813344, 4.247586, 1.997353, 0.832455, 4.569431, 3.098693, 4.679370

 17%|█▋        | 2/12 [04:28<24:35, 147.51s/it]

2/12 | Loss: 5.944417 | 3.009916, 3.337332, 1.870485, 0.831047, 4.305063, 2.868166, 4.513022

 25%|██▌       | 3/12 [08:24<28:08, 187.67s/it]

3/12 | Loss: 5.385541 | 2.324681, 2.666664, 1.687931, 0.810862, 4.262573, 2.710723, 4.360603

 33%|███▎      | 4/12 [18:16<46:19, 347.39s/it]

4/12 | Loss: 5.063503 | 2.065205, 2.411430, 1.685024, 0.847697, 4.100104, 2.599566, 4.072678

 42%|████▏     | 5/12 [21:38<34:25, 295.10s/it]

5/12 | Loss: 4.737877 | 1.768463, 2.075567, 1.663471, 0.816741, 4.013179, 2.532423, 3.830820

 50%|█████     | 6/12 [25:18<26:57, 269.56s/it]

6/12 | Loss: 4.512064 | 1.567641, 1.809139, 1.601376, 0.806081, 3.896291, 2.301658, 3.690968

 58%|█████▊    | 7/12 [28:01<19:33, 234.62s/it]

7/12 | Loss: 4.287124 | 1.480174, 1.697854, 1.692660, 0.807396, 3.667346, 2.238256, 3.414869

 67%|██████▋   | 8/12 [32:26<16:17, 244.28s/it]

8/12 | Loss: 4.113119 | 1.335177, 1.583395, 1.666529, 0.790374, 3.476475, 2.142140, 3.412488

 75%|███████▌  | 9/12 [35:32<11:17, 225.94s/it]

9/12 | Loss: 3.919990 | 1.160221, 1.440814, 1.552501, 0.774239, 3.400986, 2.030225, 3.287166

 83%|████████▎ | 10/12 [38:44<07:11, 215.71s/it]

10/12 | Loss: 3.755123 | 0.954517, 1.241985, 1.442388, 0.741055, 3.476633, 2.026474, 3.380581

 92%|█████████▏| 11/12 [43:16<03:52, 232.77s/it]

11/12 | Loss: 3.689611 | 0.875644, 1.189360, 1.535289, 0.762414, 3.484812, 2.121113, 3.195385

100%|██████████| 12/12 [1:13:44<00:00, 368.74s/it]

------------------------------------
epoch: 1/5 | Loss: 5.0018242200215655 | time: 1:13:44.866081 | batches_ok: 12/12
 [*] saving model to ./exp, name: loss_high_epoch_0



  0%|          | 0/12 [00:00<?, ?it/s]

0/12 | Loss: 3.541519 | 0.743523, 1.093144, 1.439790, 0.747074, 3.199582, 1.985071, 3.183506

  8%|▊         | 1/12 [02:57<32:33, 177.56s/it]

1/12 | Loss: 3.431230 | 0.568942, 0.933158, 1.280981, 0.693550, 3.371836, 1.977018, 3.340547

 17%|█▋        | 2/12 [06:29<32:56, 197.63s/it]

2/12 | Loss: 3.277850 | 0.522235, 0.894832, 1.269280, 0.714085, 3.146564, 1.736918, 3.283634

 25%|██▌       | 3/12 [08:55<26:05, 173.95s/it]

3/12 | Loss: 3.205020 | 0.422663, 0.697381, 1.131768, 0.687285, 3.264328, 1.827230, 3.259530

 33%|███▎      | 4/12 [11:20<21:42, 162.85s/it]

4/12 | Loss: 3.126839 | 0.450148, 0.798900, 1.260865, 0.744214, 3.065507, 1.799808, 2.814043

 42%|████▏     | 5/12 [14:26<19:57, 171.14s/it]

5/12 | Loss: 3.031525 | 0.341029, 0.661826, 1.300632, 0.710162, 3.024765, 1.825780, 2.751773

 50%|█████     | 6/12 [18:17<19:09, 191.57s/it]

6/12 | Loss: 3.015814 | 0.365505, 0.691497, 1.273282, 0.720721, 3.057775, 1.632331, 2.765865

 58%|█████▊    | 7/12 [20:52<14:56, 179.39s/it]

7/12 | Loss: 2.934023 | 0.359105, 0.683229, 1.408111, 0.740560, 2.869030, 1.604786, 2.552162

 67%|██████▋   | 8/12 [27:31<16:37, 249.28s/it]

8/12 | Loss: 2.858841 | 0.326201, 0.641367, 1.406396, 0.740294, 2.714448, 1.553025, 2.515364

 75%|███████▌  | 9/12 [30:28<11:20, 226.68s/it]

9/12 | Loss: 2.763304 | 0.308783, 0.634121, 1.300616, 0.732511, 2.711044, 1.462580, 2.404343

 83%|████████▎ | 10/12 [33:33<07:07, 213.84s/it]

10/12 | Loss: 2.712698 | 0.266342, 0.572841, 1.227334, 0.708226, 2.862624, 1.542729, 2.638016

 92%|█████████▏| 11/12 [36:57<03:30, 210.81s/it]

11/12 | Loss: 2.910830 | 0.257439, 0.581181, 1.339155, 0.756046, 3.135469, 1.722997, 2.651734

100%|██████████| 12/12 [39:50<00:00, 199.25s/it]

------------------------------------
epoch: 2/5 | Loss: 3.0674577554066977 | time: 1:53:37.348024 | batches_ok: 12/12
 [*] saving model to ./exp, name: loss_high_epoch_1



  0%|          | 0/12 [00:00<?, ?it/s]

0/12 | Loss: 2.772635 | 0.245213, 0.582078, 1.242617, 0.740015, 2.593395, 1.534751, 2.713926

  8%|▊         | 1/12 [01:57<21:30, 117.29s/it]

1/12 | Loss: 2.744488 | 0.190346, 0.552776, 1.103043, 0.652681, 2.863896, 1.612399, 2.897841

 17%|█▋        | 2/12 [04:18<21:52, 131.20s/it]

2/12 | Loss: 2.674104 | 0.227872, 0.573631, 1.087437, 0.693090, 2.577864, 1.370390, 3.001666

 25%|██▌       | 3/12 [06:59<21:43, 144.84s/it]

3/12 | Loss: 2.650094 | 0.202107, 0.426997, 0.962114, 0.686664, 2.703645, 1.513108, 2.909814

 33%|███▎      | 4/12 [09:34<19:51, 148.94s/it]

4/12 | Loss: 2.550059 | 0.281821, 0.591208, 1.092019, 0.715007, 2.522662, 1.527362, 2.359272

 42%|████▏     | 5/12 [12:03<17:22, 148.95s/it]

5/12 | Loss: 2.480480 | 0.201228, 0.467961, 1.145132, 0.676039, 2.465185, 1.580913, 2.310281

 50%|█████     | 6/12 [14:12<14:12, 142.04s/it]

6/12 | Loss: 2.631021 | 0.273063, 0.534167, 1.116470, 0.698224, 2.638277, 1.403778, 2.300033

 58%|█████▊    | 7/12 [15:59<10:54, 130.82s/it]

7/12 | Loss: 2.585763 | 0.285767, 0.551025, 1.250572, 0.710926, 2.493743, 1.390654, 2.164216

 67%|██████▋   | 8/12 [17:59<08:28, 127.14s/it]

8/12 | Loss: 2.481032 | 0.263288, 0.525022, 1.258320, 0.695173, 2.435877, 1.333663, 2.136520

 75%|███████▌  | 9/12 [20:16<06:30, 130.25s/it]

9/12 | Loss: 2.419820 | 0.256173, 0.521473, 1.161520, 0.683896, 2.486696, 1.256647, 2.053239

 83%|████████▎ | 10/12 [22:25<04:20, 130.02s/it]

10/12 | Loss: 2.391836 | 0.224069, 0.473274, 1.086724, 0.668317, 2.660575, 1.359873, 2.316467

 92%|█████████▏| 11/12 [24:50<02:14, 134.66s/it]

11/12 | Loss: 2.679949 | 0.222629, 0.513678, 1.195316, 0.674163, 3.028886, 1.564899, 2.475988

100%|██████████| 12/12 [27:10<00:00, 135.85s/it]

------------------------------------
epoch: 3/5 | Loss: 2.588439921538035 | time: 2:20:48.011048 | batches_ok: 12/12
 [*] saving model to ./exp, name: loss_high_epoch_2



  0%|          | 0/12 [00:00<?, ?it/s]

0/12 | Loss: 2.528439 | 0.216752, 0.481240, 1.116181, 0.690079, 2.404269, 1.367196, 2.522818

  8%|▊         | 1/12 [02:10<23:52, 130.23s/it]

1/12 | Loss: 2.527067 | 0.162367, 0.468114, 0.978122, 0.615470, 2.718792, 1.466161, 2.711597

 17%|█▋        | 2/12 [04:35<23:11, 139.15s/it]

2/12 | Loss: 2.483035 | 0.207221, 0.505805, 0.986567, 0.606975, 2.421234, 1.264939, 2.842391

 25%|██▌       | 3/12 [06:56<21:00, 140.07s/it]

3/12 | Loss: 2.457973 | 0.185041, 0.379074, 0.886742, 0.631476, 2.528544, 1.408723, 2.744700

 33%|███▎      | 4/12 [09:36<19:43, 147.96s/it]

4/12 | Loss: 2.342227 | 0.264745, 0.535059, 1.021010, 0.590495, 2.342260, 1.380138, 2.218275

 42%|████▏     | 5/12 [11:54<16:48, 144.09s/it]

5/12 | Loss: 2.296931 | 0.185877, 0.420261, 1.056922, 0.604472, 2.296308, 1.474949, 2.180121

 50%|█████     | 6/12 [14:12<14:13, 142.24s/it]

6/12 | Loss: 2.423631 | 0.259520, 0.474356, 1.022384, 0.690223, 2.483207, 1.292141, 2.146458

 58%|█████▊    | 7/12 [16:14<11:17, 135.41s/it]

7/12 | Loss: 2.393756 | 0.276285, 0.509369, 1.150860, 0.685908, 2.364988, 1.295784, 2.026792

 67%|██████▋   | 8/12 [18:15<08:43, 130.97s/it]

8/12 | Loss: 2.351007 | 0.255640, 0.482996, 1.172440, 0.654856, 2.342281, 1.253353, 2.047616

 75%|███████▌  | 9/12 [20:36<06:42, 134.16s/it]

9/12 | Loss: 2.307603 | 0.247785, 0.479031, 1.096418, 0.619164, 2.398427, 1.178603, 2.002655

 83%|████████▎ | 10/12 [22:48<04:26, 133.47s/it]

10/12 | Loss: 2.287364 | 0.218505, 0.436299, 1.019423, 0.668658, 2.570971, 1.289990, 2.242699

 92%|█████████▏| 11/12 [25:10<02:16, 136.08s/it]

11/12 | Loss: 2.592955 | 0.218874, 0.508168, 1.142329, 0.654534, 3.018585, 1.514839, 2.375164

100%|██████████| 12/12 [28:29<00:00, 142.45s/it]

------------------------------------
epoch: 4/5 | Loss: 2.415998876094818 | time: 2:49:17.939095 | batches_ok: 12/12
 [*] saving model to ./exp, name: loss_high_epoch_3



  0%|          | 0/12 [00:00<?, ?it/s]

0/12 | Loss: 2.445384 | 0.211057, 0.445362, 1.068253, 0.655645, 2.359735, 1.319974, 2.439739

  8%|▊         | 1/12 [02:03<22:34, 123.17s/it]

1/12 | Loss: 2.450791 | 0.154860, 0.435413, 0.933436, 0.601720, 2.658819, 1.422014, 2.639668

 17%|█▋        | 2/12 [04:22<22:07, 132.74s/it]

2/12 | Loss: 2.396947 | 0.200556, 0.480146, 0.949934, 0.592843, 2.342627, 1.214474, 2.744860

 25%|██▌       | 3/12 [06:22<19:02, 126.92s/it]

3/12 | Loss: 2.384188 | 0.176860, 0.354062, 0.852017, 0.597666, 2.451303, 1.383612, 2.672909

 33%|███▎      | 4/12 [08:25<16:43, 125.41s/it]

4/12 | Loss: 2.269512 | 0.255654, 0.513087, 0.993759, 0.582714, 2.274784, 1.357853, 2.170282

 42%|████▏     | 5/12 [10:49<15:24, 132.11s/it]

5/12 | Loss: 2.237590 | 0.179691, 0.392628, 1.023642, 0.614951, 2.245870, 1.427526, 2.146179

 50%|█████     | 6/12 [13:01<13:11, 131.94s/it]

6/12 | Loss: 2.342012 | 0.255246, 0.455700, 0.994016, 0.687131, 2.448795, 1.283306, 2.114967

 58%|█████▊    | 7/12 [14:57<10:33, 126.64s/it]

7/12 | Loss: 2.326149 | 0.269774, 0.495623, 1.109826, 0.661094, 2.332737, 1.274046, 2.016455

 67%|██████▋   | 8/12 [17:06<08:30, 127.52s/it]

8/12 | Loss: 2.271423 | 0.247071, 0.473081, 1.125527, 0.650764, 2.288570, 1.227504, 1.981307

 75%|███████▌  | 9/12 [19:17<06:26, 128.74s/it]

9/12 | Loss: 2.213410 | 0.234648, 0.464332, 1.056136, 0.619734, 2.330405, 1.150866, 1.928017

 83%|████████▎ | 10/12 [21:16<04:11, 125.56s/it]

10/12 | Loss: 2.204642 | 0.209058, 0.429428, 0.973246, 0.676543, 2.494991, 1.271742, 2.163692

 92%|█████████▏| 11/12 [23:16<02:03, 123.93s/it]

11/12 | Loss: 2.521985 | 0.211412, 0.518865, 1.095053, 0.686173, 2.944493, 1.476064, 2.292476

100%|██████████| 12/12 [25:41<00:00, 128.48s/it]

------------------------------------
epoch: 5/5 | Loss: 2.3386694192886353 | time: 3:15:00.127150 | batches_ok: 12/12
 [*] saving model to ./exp, name: loss_high_epoch_4





In [7]:
device = torch.device("cpu")

net = TransformerModel(n_class, is_training=False)

path_saved_ckpt = "./exp/loss_high_epoch_4_params.pt"
state = torch.load(path_saved_ckpt, map_location="cpu")
net.load_state_dict(state)

net.to("cpu")
net.eval()

>>>>>: [234, 135, 18, 7, 130, 22, 130]


TransformerModel(
  (loss_func): CrossEntropyLoss()
  (word_emb_tempo): Embeddings(
    (lut): Embedding(234, 512)
  )
  (word_emb_chord): Embeddings(
    (lut): Embedding(135, 256)
  )
  (word_emb_barbeat): Embeddings(
    (lut): Embedding(18, 64)
  )
  (word_emb_type): Embeddings(
    (lut): Embedding(7, 32)
  )
  (word_emb_pitch): Embeddings(
    (lut): Embedding(130, 512)
  )
  (word_emb_duration): Embeddings(
    (lut): Embedding(22, 128)
  )
  (word_emb_velocity): Embeddings(
    (lut): Embedding(130, 512)
  )
  (pos_emb): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (in_linear): Linear(in_features=2016, out_features=512, bias=True)
  (linear_knowledge): Linear(in_features=1024, out_features=512, bias=True)
  (knowledge_selector): KnowledgeSelector(
    (linear): Linear(in_features=512, out_features=512, bias=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): Mul

In [8]:
batch_size=1
num_batch = len(test_x) // batch_size
output_total=[]
for bidx in range(5):
    bidx_st = batch_size * bidx
    bidx_ed = batch_size * (bidx + 1)
  # unpack batch data
    batch_x = test_x[bidx_st:bidx_ed]
    batch_y = test_x[bidx_st:bidx_ed]
    batch_mask = test_x[bidx_st:bidx_ed]
    batch_x = torch.from_numpy(batch_x).long().to(device)
    batch_y = torch.from_numpy(batch_y).long().to(device)
    batch_mask = torch.from_numpy(batch_mask).float().to(device)
    if isinstance(net, torch.nn.DataParallel):
          net = net.module

    src_mask=[]
    for i in range(len(batch_x)):
        src_mask.append(int(torch.where(batch_x[i][:,3]==0)[0][0]))
    output=net.inference(batch_x,src_mask,dictionary)
    output_total.append(output)
 

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


------ generate ------
bar: 1  ==Tempo_90        | D_o             | Beat_12         | Metrical        | 0               | 0               | Note_Velocity_39 | 
bar: 1  ==Tempo_102       | G_m             | 0               | SOC             | Note_Pitch_5    | Note_Duration_2040 | Note_Velocity_74 | 
bar: 1  ==0               | 0               | 0               | Note            | Note_Pitch_71   | Note_Duration_1320 | Note_Velocity_123 | 
bar: 1  ==0               | F#_m7           | 0               | Note            | 0               | Note_Duration_240 | Note_Velocity_89 | 
bar: 1  ==Tempo_91        | 0               | 0               | Note            | Note_Pitch_67   | Note_Duration_1920 | Note_Velocity_96 | 
bar: 1  ==CONTI           | B_7             | Beat_0          | Metrical        | 0               | 0               | Note_Velocity_58 | 
bar: 1  ==Tempo_76        | A#_/o7          | Beat_2          | SOC             | Note_Pitch_83   | Note_Duration_1920 | Note_Velocity_32

In [91]:
import os
os.makedirs("result", exist_ok=True)
for i in range(len(output_total)):
    write_midi(output_total[i],'./result/'+str(i)+'.mid',word2event)

TypeError: argument of type 'int' is not iterable

In [19]:
import os
os.makedirs("result", exist_ok=True)

call_0 = test_x[0]              # shape: (L, 7)
write_midi(call_0, "./result/call_0.mid", word2event)
print("saved -> ./result/call_0.mid")

saved -> ./result/call_0.mid
