In [4]:
from model import MusicTransformer
import custom
from custom.metrics import *
from custom.criterion import SmoothCrossEntropyLoss, CustomSchedule
from custom.config import config
from data import Data
from midi_processor.processor import encode_midi, decode_midi
from extra import *

import os
from preprocess import preprocess_midi_files_under
from progress.bar import Bar
import pickle

import utils
import datetime
import time

import torch
import torch.optim as optim
from tensorboardX import SummaryWriter


In [3]:
# Run the project's MIDI preprocessing over every file in dataset/midi; results
# are written under dataset/preprocess. (The recorded run was stopped by hand.)
dataset_root = 'dataset'
midi_folder, preprocess_folder = (
    os.path.join(dataset_root, sub_dir) for sub_dir in ('midi', 'preprocess')
)

preprocess_midi_files_under(midi_folder, preprocess_folder)

 [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_01_R1_2014_wav--1.midi]

 [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_01_R1_2014_wav--2.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_01_R1_2014_wav--3.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_01_R1_2014_wav--5.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_02_R1_2014_wav--1.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_02_R1_2014_wav--2.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_02_R1_2014_wav--4.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_02_R1_2014_wav--5.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_03_R1_2014_wav--2.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_03_R1_2014_wav--3.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_03_R1_2014_wav--4.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_03_R1_2014_wav--5.midi] [dataset\midi\MIDI-UNPROCESSED_01-03_R1_2014_MID--AUDIO_03_R1_2014_wav--6.midi] [dataset\midi\MIDI-Unprocessed_01_R1_20

KeyboardInterrupt: 

In [5]:
# Populate the shared `config` object from the listed YAML file(s); the recorded
# output below shows the resulting settings.
config_paths = ["config/thor_basic.yml"]
get_config(config, config_paths)

CONFIG_FILE_NAME = save.yml
batch_size = 8
debug = true
device = cuda
dropout = 0.1
embedding_dim = 128
epochs = 200
event_dim = 388
experiment = embedding256-layer6
fp16 = None
l_r = 0.01
label_smooth = 0.1
load_path = None
max_seq = 256
num_layers = 6
pad_token = 388
pickle_dir = dataset\temp
token_eos = 390
token_sos = 389
vocab_size = 391

In [6]:
# Build the Data loader over the configured pickle directory and show its summary.
pickle_dir = config.pickle_dir
dataset = Data(pickle_dir)
print(dataset)

<class Data has "124" files>


In [None]:
# Train the MusicTransformer: one optimizer step per sampled batch, an eval pass
# every 100 batches, TensorBoard logging, and a checkpoint whenever the eval
# loss improves.

# load data
dataset = Data(config.pickle_dir)
print(dataset)


# load model
# NOTE(review): learning_rate is not used below; the effective rate is whatever
# CustomSchedule sets (it is logged as 'learning_rate').
learning_rate = config.l_r

# define model (resumes from the last final checkpoint)
mt = load_model('models/final.pth', config, new=False)

mt.to(config.device)
# lr=0 here; the rate is presumably driven by CustomSchedule on every step.
opt = optim.Adam(mt.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = CustomSchedule(config.embedding_dim, optimizer=opt)

# multi-GPU set: `single_mt` always refers to the unwrapped model, which is used
# for evaluation and for saving state_dicts without DataParallel's 'module.' prefix.
single_mt = mt
if torch.cuda.device_count() > 1:
    mt = torch.nn.DataParallel(mt, output_device=torch.cuda.device_count()-1)

# init metric set
metric_set = MetricsSet({
    'accuracy': CategoricalAccuracy(),
    'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token),
    'bucket':  LogitsBucketting(config.vocab_size)
})

print(mt)
# Report the device actually used (torch.cuda.device is a class, not a device).
print('| Summary - Device Info : {}'.format(config.device))

# define tensorboard writer
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
train_log_dir = 'logs/'+config.experiment+'/'+current_time+'/train'
eval_log_dir = 'logs/'+config.experiment+'/'+current_time+'/eval'

train_summary_writer = SummaryWriter(train_log_dir)
eval_summary_writer = SummaryWriter(eval_log_dir)

# Best eval loss seen so far, kept as a plain float (not a graph-holding tensor).
best_loss = float('inf')

# Train Start
print(">> Train start...")
idx = 0  # global step counter shared by the TensorBoard writers
try:
    for e in range(config.epochs):
        print(">>> [Epoch was updated]")
        for b in range(len(dataset.files) // config.batch_size):
            scheduler.optimizer.zero_grad()
            try:
                batch_x, batch_y = dataset.slide_seq2seq_batch(config.batch_size, config.max_seq)
                batch_x = torch.from_numpy(batch_x).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
                batch_y = torch.from_numpy(batch_y).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
            except IndexError as err:
                # Skip batches the sampler cannot build, but report it in debug
                # mode instead of swallowing the error silently.
                if config.debug:
                    print('[Skip batch {}]: {}'.format(b, err))
                continue

            start_time = time.time()
            mt.train()
            sample = mt.forward(batch_x)
            metrics = metric_set(sample, batch_y)
            loss = metrics['loss']
            loss.backward()
            scheduler.step()
            end_time = time.time()

            if config.debug:
                print("[Loss]: {}".format(loss))

            train_summary_writer.add_scalar('loss', metrics['loss'], global_step=idx)
            train_summary_writer.add_scalar('accuracy', metrics['accuracy'], global_step=idx)
            train_summary_writer.add_scalar('learning_rate', scheduler.rate(), global_step=idx)
            train_summary_writer.add_scalar('iter_p_sec', end_time-start_time, global_step=idx)

            if b % 100 == 0:
                single_mt.eval()
                # Evaluation needs no autograd graph; no_grad() saves memory.
                with torch.no_grad():
                    eval_x, eval_y = dataset.slide_seq2seq_batch(2, config.max_seq, 'eval')
                    eval_x = torch.from_numpy(eval_x).contiguous().to(config.device, dtype=torch.int)
                    eval_y = torch.from_numpy(eval_y).contiguous().to(config.device, dtype=torch.int)

                    # In eval mode forward() returns (logits, attention weights).
                    eval_prediction, weights = single_mt.forward(eval_x)
                    eval_metrics = metric_set(eval_prediction, eval_y)

                # Compare, record and name the checkpoint with the same eval loss,
                # so best_loss really tracks the best eval loss seen so far.
                eval_loss = float(eval_metrics['loss'])
                if eval_loss < best_loss:
                    best_loss = eval_loss
                    torch.save(single_mt.state_dict(), 'models/train-{}.pth'.format(eval_loss))

                if b == 0:
                    train_summary_writer.add_histogram("target_analysis", batch_y, global_step=e)
                    train_summary_writer.add_histogram("source_analysis", batch_x, global_step=e)
                    for i, weight in enumerate(weights):
                        attn_log_name = "attn/layer-{}".format(i)
                        utils.attention_image_summary(
                            attn_log_name, weight, step=idx, writer=eval_summary_writer)

                eval_summary_writer.add_scalar('loss', eval_metrics['loss'], global_step=idx)
                eval_summary_writer.add_scalar('accuracy', eval_metrics['accuracy'], global_step=idx)
                eval_summary_writer.add_histogram("logits_bucket", eval_metrics['bucket'], global_step=idx)

                print('\n====================================================')
                print('Epoch/Batch: {}/{}'.format(e, b))
                print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(metrics['loss'], metrics['accuracy']))
                print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_metrics['loss'], eval_metrics['accuracy']))
            torch.cuda.empty_cache()
            idx += 1

            # switch output device to: gpu-1 ~ gpu-n
            sw_start = time.time()
            if torch.cuda.device_count() > 1:
                mt.output_device = idx % (torch.cuda.device_count() -1) + 1
            sw_end = time.time()
            if config.debug:
                print('output switch time: {}'.format(sw_end - sw_start) )

    # Only reached when every epoch completes.
    torch.save(single_mt.state_dict(), 'models/final.pth')
finally:
    # Flush and close the TensorBoard writers even if training is interrupted
    # (e.g. by the notebook's stop button).
    eval_summary_writer.close()
    train_summary_writer.close()

<class Data has "124" files>


  return self.fget.__get__(instance, owner)()


MusicTransformer(
  (Decoder): Encoder(
    (embedding): Embedding(391, 128)
    (pos_encoding): DynamicPositionEmbedding()
    (enc_layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (rga): RelativeGlobalAttention(
          (Wq): Linear(in_features=128, out_features=128, bias=True)
          (Wk): Linear(in_features=128, out_features=128, bias=True)
          (Wv): Linear(in_features=128, out_features=128, bias=True)
          (fc): Linear(in_features=128, out_features=128, bias=True)
        )
        (FFN_pre): Linear(in_features=128, out_features=64, bias=True)
        (FFN_suf): Linear(in_features=64, out_features=128, bias=True)
        (layernorm1): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (layernorm2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (fc): Linear(in_features=128, 

In [5]:
# NOTE(review): get_output is defined in a later cell of this notebook, so this
# cell only works once that cell has been executed (hidden kernel state).
final_checkpoint = 'models/final.pth'
get_output(final_checkpoint)

tensor([256,  61, 256,  ..., 374,  54, 256], device='cuda:0',
       dtype=torch.int32)

In [14]:
# Inspect the last training batch left in the kernel.
# NOTE(review): the recorded sequence length (1024) differs from the configured
# max_seq (256), so this batch_x presumably comes from an earlier run/config.
batch_x.size()

torch.Size([8, 1024])

In [15]:
# Decode predicted token ids back into notes by hand and wrap them in a
# single-instrument PrettyMIDI object (nothing is written to disk here).
# NOTE(review): `output` must already be defined by an earlier-executed cell;
# it is not defined above this cell, so on a fresh kernel this raises NameError.
np_arr = output[0].tolist()
import pretty_midi
from midi_processor.processor import Event, _event_seq2snote_seq, _merge_note

event_sequence = [Event.from_int(token) for token in np_arr]
snote_seq = _event_seq2snote_seq(event_sequence)
note_seq = _merge_note(snote_seq)
note_seq.sort(key=lambda note: note.start)

mid = pretty_midi.PrettyMIDI()
# To change the instrument, pick another General MIDI program number, see
# https://www.midi.org/specifications/item/gm-level-1-sound-set
instrument = pretty_midi.Instrument(1, False, "Developed By Yang-Kichang")
instrument.notes = note_seq

mid.instruments.append(instrument)
mid.instruments

info removed pitch: 68
info removed pitch: 64
info removed pitch: 64
info removed pitch: 64
info removed pitch: 52
info removed pitch: 64
info removed pitch: 64
info removed pitch: 76
info removed pitch: 64
info removed pitch: 80
info removed pitch: 64
info removed pitch: 64
info removed pitch: 64
info removed pitch: 87
info removed pitch: 64
info removed pitch: 87
info removed pitch: 85


[Instrument(program=1, is_drum=False, name="Developed By Yang-Kichang")]

In [7]:
# Print each decoded event next to the raw token id it was decoded from.
# Iterating with zip() avoids overwriting the global `idx`, which the training
# cell uses as its TensorBoard step counter.
for event, token in zip(event_sequence, np_arr):
    print(event, token)

<Event type: time_shift, value: 0> 256
<Event type: note_on, value: 61> 61
<Event type: time_shift, value: 0> 256
<Event type: velocity, value: 14> 370
<Event type: velocity, value: 14> 370
<Event type: note_on, value: 44> 44
<Event type: time_shift, value: 0> 256
<Event type: note_off, value: 68> 196
<Event type: note_on, value: 59> 59
<Event type: time_shift, value: 0> 256
<Event type: note_off, value: 61> 189
<Event type: note_on, value: 66> 66
<Event type: time_shift, value: 0> 256
<Event type: note_on, value: 78> 78
<Event type: time_shift, value: 0> 256
<Event type: velocity, value: 17> 373
<Event type: time_shift, value: 0> 256
<Event type: note_on, value: 74> 74
<Event type: time_shift, value: 0> 256
<Event type: note_off, value: 74> 202
<Event type: time_shift, value: 0> 256
<Event type: note_on, value: 44> 44
<Event type: time_shift, value: 0> 256
<Event type: velocity, value: 18> 374
<Event type: note_on, value: 71> 71
<Event type: time_shift, value: 0> 256
<Event type: note

In [16]:
# Write the decoded prediction to dataset/output.midi. os.path.join keeps the path
# portable; a literal "dataset\\output.midi" is only a directory path on Windows.
# To decode the input batch instead: np_arr = batch_x[0].tolist()
decode_midi(np_arr, file_path=os.path.join("dataset", "output.midi"))

info removed pitch: 68
info removed pitch: 64
info removed pitch: 64
info removed pitch: 64
info removed pitch: 52
info removed pitch: 64
info removed pitch: 64
info removed pitch: 76
info removed pitch: 64
info removed pitch: 80
info removed pitch: 64
info removed pitch: 64
info removed pitch: 64
info removed pitch: 87
info removed pitch: 64
info removed pitch: 87
info removed pitch: 85


<pretty_midi.pretty_midi.PrettyMIDI at 0x13f30a6bd90>

In [80]:
# Round-trip check: encode a source MIDI file to token ids, then decode it back to
# dataset/output3.midi. Paths are built with os.path.join for portability.
midi_dir = os.path.join("dataset", "midi")
source_midi = os.path.join(
    midi_dir, "MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav.midi")
encoded_midi = encode_midi(source_midi)
decode_midi(encoded_midi, file_path=os.path.join("dataset", "output3.midi"))

<pretty_midi.pretty_midi.PrettyMIDI at 0x218898b2290>

In [17]:

def get_output(model_path):
    """Run one forward pass of a saved model and return its bucketed predictions.

    Loads the checkpoint at ``model_path``, samples a single sequence from
    ``config.pickle_dir``, and reshapes the ``LogitsBucketting`` result to the
    input batch shape (presumably one predicted token id per position).

    Args:
        model_path: checkpoint path passed to ``load_model``.

    Returns:
        Tensor shaped like the sampled ``batch_x`` (batch of 1). Its first row is
        also printed, as before.
    """
    dataset = Data(config.pickle_dir)
    metric_set = MetricsSet({
        'accuracy': CategoricalAccuracy(),
        'loss': SmoothCrossEntropyLoss(config.label_smooth, config.vocab_size, config.pad_token),
        'bucket':  LogitsBucketting(config.vocab_size)
    })
    mt = load_model(model_path, config)
    mt.to(config.device)
    # Inference: eval() disables dropout so the prediction is not randomly
    # perturbed. In eval mode forward() returns (logits, attention weights), as the
    # training cell's eval pass relies on.
    mt.eval()
    batch_x, batch_y = dataset.slide_seq2seq_batch(1, config.max_seq)
    batch_x = torch.from_numpy(batch_x).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
    batch_y = torch.from_numpy(batch_y).contiguous().to(config.device, non_blocking=True, dtype=torch.int)
    # No autograd graph is needed for inference.
    with torch.no_grad():
        sample, _ = mt.forward(batch_x)
        metrics = metric_set(sample, batch_y)
    output = torch.reshape(metrics['bucket'], batch_x.shape)
    print(output[0])
    return output

# Keep the result in `output` for the decoding cells (assignment also avoids
# displaying the tensor a second time).
output = get_output('models/final.pth')

tensor([163, 304,  83,  ..., 304, 208, 105], device='cuda:0',
       dtype=torch.int32)


In [18]:
# Inspect the global batch_x left by the training loop (get_output's batch_x is
# local to that function and not visible here).
batch_x.size()

torch.Size([8, 1024])