In [4]:
from model import MusicTransformer
import custom
from custom.metrics import *
from custom.criterion import SmoothCrossEntropyLoss, CustomSchedule
from custom.config import config
from data import Data

import utils
import datetime
import time

import torch
import torch.optim as optim
from tensorboardX import SummaryWriter

# check cuda
if torch.cuda.is_available():
    config.device = torch.device('cuda')
else:
    config.device = torch.device('cpu')


In [2]:
import os
from midi_processor.processor import encode_midi
from preprocess import preprocess_midi
from progress.bar import Bar
import pickle
# data = preprocess_midi('preprocessed_midi/midi')

midi_paths = list(utils.find_files_by_extensions('dataset\\midi', ['.mid', '.midi']))
os.makedirs('dataset\\midi', exist_ok=True)
os.makedirs('dataset\\preprocessed_midi', exist_ok=True)
# out_fmt = '{}-{}.data'


for path in Bar('Processing').iter(midi_paths):
    print(' ', end='[{}]'.format(path), flush=True)

    try:
        data = preprocess_midi(path)
    except KeyboardInterrupt:
        print(' Abort')
    except EOFError:
        print('EOF Error')

    with open('{}\\{}.pickle'.format('dataset\\preprocessed_midi', path.split('\\')[-1]), 'wb') as f:
        pickle.dump(data, f)

 [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_01_R1_2014_wav--1.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_01_R1_2014_wav--2.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_01_R1_2014_wav--3.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_01_R1_2014_wav--5.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_02_R1_2014_wav--1.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_02_R1_2014_wav--2.midi] Abort
 [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_02_R1_2014_wav--4.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_02_R1_2014_wav--5.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_03_R1_2014_wav--2.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_03_R1_2014_wav--3.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_03_R1_2014_wav--4.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_03_R1_2014_wav--5.midi] [dataset\midi\MIDI-UNPROCESSED_0

In [16]:
from custom.config import config
model_dir = "config"
configs = ["config\\thor_basic.yml"]
config.load(model_dir, configs)
model_folder = "models"
config.pickle_dir = "dataset\\temp"
config

CONFIG_FILE_NAME = save.yml
batch_size = 8
debug = true
device = cuda
dropout = 0.1
embedding_dim = 128
epochs = 100
event_dim = 388
experiment = embedding256-layer6
fp16 = None
l_r = 0.01
label_smooth = 0.1
load_path = None
max_seq = 1024
num_layers = 6
pad_token = 388
pickle_dir = dataset\temp
token_eos = 390
token_sos = 389
vocab_size = 391

In [7]:
load_model_path = "C:\\Users\\Draco\\Documents\\GitHub\\MusicTransformer-pytorch\\models\\train-10.pth"
model = MusicTransformer(
            embedding_dim=config.embedding_dim,
            vocab_size=config.vocab_size,
            num_layer=config.num_layers,
            max_seq=config.max_seq,
            dropout=config.dropout,
            debug=config.debug, loader_path=config.load_path
)
model.load_state_dict(torch.load(load_model_path, weights_only=True))
model.eval()

  return self.fget.__get__(instance, owner)()


MusicTransformer(
  (Decoder): Encoder(
    (embedding): Embedding(391, 128)
    (pos_encoding): DynamicPositionEmbedding()
    (enc_layers): ModuleList(
      (0-3): 4 x EncoderLayer(
        (rga): RelativeGlobalAttention(
          (Wq): Linear(in_features=128, out_features=128, bias=True)
          (Wk): Linear(in_features=128, out_features=128, bias=True)
          (Wv): Linear(in_features=128, out_features=128, bias=True)
          (fc): Linear(in_features=128, out_features=128, bias=True)
        )
        (FFN_pre): Linear(in_features=128, out_features=64, bias=True)
        (FFN_suf): Linear(in_features=64, out_features=128, bias=True)
        (layernorm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (layernorm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (fc): Linear(in_features=128, 

In [19]:
# load data
# dataset = Data(config.pickle_dir)
# print(dataset)


# # load model
# learning_rate = config.l_r

# # define model
# # mt = model
# mt = MusicTransformer(
#             embedding_dim=config.embedding_dim,
#             vocab_size=config.vocab_size,
#             num_layer=config.num_layers,
#             max_seq=config.max_seq,
#             dropout=config.dropout,
#             debug=config.debug, loader_path=config.load_path
# )
# mt.to(config.device)
opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = CustomSchedule(config.embedding_dim, optimizer=opt)

# multi-GPU set
if torch.cuda.device_count() > 1:
    single_mt = mt
    mt = torch.nn.DataParallel(mt, output_device=torch.cuda.device_count()-1)
else:
    single_mt = mt

# init metric set
metric_set = MetricsSet({
    'accuracy': CategoricalAccuracy(),
    'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token),
    'bucket':  LogitsBucketting(config.vocab_size)
})

print(mt)
print('| Summary - Device Info : {}'.format(torch.cuda.device))

# define tensorboard writer
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
train_log_dir = 'logs/'+config.experiment+'/'+current_time+'/train'
eval_log_dir = 'logs/'+config.experiment+'/'+current_time+'/eval'

train_summary_writer = SummaryWriter(train_log_dir)
eval_summary_writer = SummaryWriter(eval_log_dir)

# Train Start
print(">> Train start...")
idx = 0
for e in range(config.epochs):
    print(">>> [Epoch was updated]")
    for b in range(len(dataset.files) // config.batch_size):
        scheduler.optimizer.zero_grad()
        try:
            batch_x, batch_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq)
            batch_x = torch.from_numpy(batch_x).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
            batch_y = torch.from_numpy(batch_y).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
        except IndexError:
            continue

        start_time = time.time()
        mt.train()
        sample = mt.forward(batch_x)
        metrics = metric_set(sample, batch_y)
        loss = metrics['loss']
        loss.backward()
        scheduler.step()
        end_time = time.time()

        if config.debug:
            print("[Loss]: {}".format(loss))

        train_summary_writer.add_scalar('loss', metrics['loss'], global_step=idx)
        train_summary_writer.add_scalar('accuracy', metrics['accuracy'], global_step=idx)
        train_summary_writer.add_scalar('learning_rate', scheduler.rate(), global_step=idx)
        train_summary_writer.add_scalar('iter_p_sec', end_time-start_time, global_step=idx)

        # result_metrics = metric_set(sample, batch_y)
        if b % 100 == 0:
            single_mt.eval()
            eval_x, eval_y = dataset.slide_seq2seq_batch(2, config.max_seq, 'eval')
            eval_x = torch.from_numpy(eval_x).contiguous().to(config.device, dtype=torch.int)
            eval_y = torch.from_numpy(eval_y).contiguous().to(config.device, dtype=torch.int)

            eval_preiction, weights = single_mt.forward(eval_x)

            eval_metrics = metric_set(eval_preiction, eval_y)
            if e % 10 == 0:
                torch.save(single_mt.state_dict(), model_folder+'/train-{}.pth'.format(e))
            if b == 0:
                train_summary_writer.add_histogram("target_analysis", batch_y, global_step=e)
                train_summary_writer.add_histogram("source_analysis", batch_x, global_step=e)
                for i, weight in enumerate(weights):
                    attn_log_name = "attn/layer-{}".format(i)
                    utils.attention_image_summary(
                        attn_log_name, weight, step=idx, writer=eval_summary_writer)

            eval_summary_writer.add_scalar('loss', eval_metrics['loss'], global_step=idx)
            eval_summary_writer.add_scalar('accuracy', eval_metrics['accuracy'], global_step=idx)
            eval_summary_writer.add_histogram("logits_bucket", eval_metrics['bucket'], global_step=idx)

            print('\n====================================================')
            print('Epoch/Batch: {}/{}'.format(e, b))
            print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(metrics['loss'], metrics['accuracy']))
            print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_metrics['loss'], eval_metrics['accuracy']))
        torch.cuda.empty_cache()
        idx += 1

        # switch output device to: gpu-1 ~ gpu-n
        sw_start = time.time()
        if torch.cuda.device_count() > 1:
            mt.output_device = idx % (torch.cuda.device_count() -1) + 1
        sw_end = time.time()
        if config.debug:
            print('output switch time: {}'.format(sw_end - sw_start) )

torch.save(single_mt.state_dict(), model_folder+'/final.pth'.format(idx))
eval_summary_writer.close()
train_summary_writer.close()

MusicTransformer(
  (Decoder): Encoder(
    (embedding): Embedding(391, 128)
    (pos_encoding): DynamicPositionEmbedding()
    (enc_layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (rga): RelativeGlobalAttention(
          (Wq): Linear(in_features=128, out_features=128, bias=True)
          (Wk): Linear(in_features=128, out_features=128, bias=True)
          (Wv): Linear(in_features=128, out_features=128, bias=True)
          (fc): Linear(in_features=128, out_features=128, bias=True)
        )
        (FFN_pre): Linear(in_features=128, out_features=64, bias=True)
        (FFN_suf): Linear(in_features=64, out_features=128, bias=True)
        (layernorm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (layernorm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (fc): Linear(in_features=128, 

In [20]:
from midi_processor.processor import encode_midi, decode_midi
dataset = Data(config.pickle_dir)
metric_set = MetricsSet({
    'accuracy': CategoricalAccuracy(),
    'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token),
    'bucket':  LogitsBucketting(config.vocab_size)
})
# mt = model
mt.to(config.device)
mt.train()
batch_x, batch_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq)
batch_x = torch.from_numpy(batch_x).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
batch_y = torch.from_numpy(batch_y).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
sample = mt.forward(batch_x)
metrics = metric_set(sample, batch_y)
metrics['bucket'].shape
output = torch.reshape(metrics['bucket'], (batch_x.shape))
output[0]
# decode_midi(output[0], file_path="dataset\\output.midi")

tensor([198,  65, 256,  ...,  75, 256, 374], device='cuda:0',
       dtype=torch.int32)

In [26]:
np_arr = output[0].tolist()
import pretty_midi
from midi_processor.processor import Event, _event_seq2snote_seq, _merge_note
event_sequence = [Event.from_int(idx) for idx in np_arr]
event_sequence
snote_seq = _event_seq2snote_seq(event_sequence)
note_seq = _merge_note(snote_seq)
note_seq.sort(key=lambda x:x.start)

mid = pretty_midi.PrettyMIDI()
# if want to change instument, see https://www.midi.org/specifications/item/gm-level-1-sound-set
instument = pretty_midi.Instrument(1, False, "Developed By Yang-Kichang")
instument.notes = note_seq

mid.instruments.append(instument)
mid.instruments

info removed pitch: 70
info removed pitch: 79
info removed pitch: 66
info removed pitch: 70
info removed pitch: 70
info removed pitch: 70
info removed pitch: 66
info removed pitch: 71


[Instrument(program=1, is_drum=False, name="Developed By Yang-Kichang")]

In [30]:
for idx, _ in enumerate(event_sequence):
    print(event_sequence[idx], np_arr[idx])

<Event type: note_off, value: 70> 198
<Event type: note_on, value: 65> 65
<Event type: time_shift, value: 0> 256
<Event type: velocity, value: 17> 373
<Event type: velocity, value: 18> 374
<Event type: note_on, value: 72> 72
<Event type: time_shift, value: 8> 264
<Event type: velocity, value: 18> 374
<Event type: velocity, value: 18> 374
<Event type: note_on, value: 56> 56
<Event type: time_shift, value: 8> 264
<Event type: velocity, value: 17> 373
<Event type: velocity, value: 18> 374
<Event type: note_on, value: 60> 60
<Event type: time_shift, value: 9> 265
<Event type: velocity, value: 17> 373
<Event type: velocity, value: 16> 372
<Event type: note_on, value: 77> 77
<Event type: time_shift, value: 9> 265
<Event type: velocity, value: 18> 374
<Event type: velocity, value: 17> 373
<Event type: note_on, value: 77> 77
<Event type: time_shift, value: 9> 265
<Event type: velocity, value: 18> 374
<Event type: velocity, value: 17> 373
<Event type: note_on, value: 60> 60
<Event type: time_sh

In [25]:
# np_arr = batch_y[0].tolist()
np_arr = batch_x[0].tolist()
# np_arr = output[0].tolist()
decode_midi(np_arr, file_path="dataset\\output.midi")

info removed pitch: 84
info removed pitch: 82
info removed pitch: 80
info removed pitch: 79
info removed pitch: 77
info removed pitch: 75
info removed pitch: 74
info removed pitch: 72
info removed pitch: 56
info removed pitch: 58
info removed pitch: 68
info removed pitch: 46
info removed pitch: 65
info removed pitch: 67
info removed pitch: 51
info removed pitch: 63
info removed pitch: 73
info removed pitch: 70
info removed pitch: 62
info removed pitch: 55
info removed pitch: 61
info removed pitch: 59
info removed pitch: 60
info removed pitch: 53
info removed pitch: 64
info removed pitch: 29
info removed pitch: 52
info removed pitch: 47
info removed pitch: 43
info removed pitch: 40
info removed pitch: 36
info removed pitch: 41
info removed pitch: 44
info removed pitch: 48
info removed pitch: 66
info removed pitch: 57
info removed pitch: 76
info removed pitch: 78
info removed pitch: 81
info removed pitch: 83
info removed pitch: 50
info removed pitch: 87
info removed pitch: 86


<pretty_midi.pretty_midi.PrettyMIDI at 0x20483024850>

In [80]:
encoded_midi = encode_midi("dataset\\midi\\MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav.midi")
decode_midi(encoded_midi, file_path="dataset\\output3.midi")

<pretty_midi.pretty_midi.PrettyMIDI at 0x218898b2290>

In [95]:
encoded_midi

[355,
 260,
 373,
 42,
 313,
 375,
 49,
 264,
 170,
 285,
 376,
 57,
 283,
 185,
 377,
 57,
 276,
 177,
 376,
 49,
 275,
 376,
 42,
 278,
 170,
 374,
 42,
 282,
 177,
 376,
 49,
 284,
 185,
 374,
 57,
 268,
 379,
 61,
 270,
 378,
 66,
 256,
 170,
 371,
 42,
 290,
 177,
 185,
 189,
 275,
 368,
 49,
 271,
 170,
 267,
 372,
 57,
 278,
 185,
 375,
 57,
 275,
 177,
 377,
 49,
 272,
 376,
 42,
 274,
 170,
 374,
 42,
 276,
 177,
 376,
 49,
 274,
 185,
 376,
 57,
 266,
 379,
 61,
 257,
 194,
 269,
 377,
 62,
 257,
 170,
 371,
 42,
 290,
 177,
 185,
 189,
 267,
 370,
 50,
 263,
 170,
 271,
 372,
 59,
 276,
 187,
 372,
 59,
 273,
 178,
 375,
 50,
 272,
 375,
 42,
 274,
 170,
 374,
 42,
 274,
 178,
 375,
 50,
 275,
 376,
 54,
 267,
 187,
 378,
 59,
 264,
 190,
 261,
 376,
 57,
 258,
 170,
 371,
 42,
 286,
 370,
 49,
 264,
 178,
 182,
 187,
 261,
 372,
 54,
 256,
 170,
 257,
 177,
 272,
 182,
 371,
 54,
 274,
 371,
 49,
 274,
 369,
 42,
 280,
 170,
 369,
 42,
 276,
 177,
 369,
 49,
 291,
 367,
 53