# Test saving a .npy array as a MIDI file

In [3]:
!pip -q install pretty_midi

[K     |████████████████████████████████| 5.6MB 11.2MB/s 
[K     |████████████████████████████████| 61kB 9.4MB/s 
[?25h  Building wheel for pretty-midi (setup.py) ... [?25l[?25hdone


In [4]:
import numpy as np
import pretty_midi
from google.colab import drive
 
# Mount Google Drive at /content/drive; the dataset archive is read from there below.
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Quietly extract the dataset archive from Drive into the current working directory.
!unzip -q "/content/drive/MyDrive/MVA/Deep Learning/GAN/Dataset (use directly).zip"

In [6]:
def set_piano_roll_to_instrument(piano_roll, instrument, velocity=100, tempo=120.0, beat_resolution=16):
    """Turn a piano roll into notes and append them to ``instrument``.

    The first two axes of ``piano_roll`` are flattened into a single time axis.
    In every pitch column a rise in value starts a note and a fall ends it.
    Notes shorter than a quarter of a beat are stretched to that minimum length
    (clamped to the end of the roll), and notes that start and end within that
    window after an earlier note of the same pitch are dropped.

    Args:
        piano_roll: 3-D array indexed as (segment, time step, pitch) with 128
            pitch columns; assumed to hold 0/1 (or boolean) values.
        instrument: object exposing a ``notes`` list (e.g. a
            ``pretty_midi.Instrument``). It is modified in place and its notes
            are left sorted by start time.
        velocity: velocity given to every created note.
        tempo: tempo in beats per minute, used to convert steps to seconds.
        beat_resolution: number of time steps per beat.

    Returns:
        None.
    """
    # Seconds per time step, and the minimum note length (a quarter of a beat).
    tpp = 60.0 / tempo / float(beat_resolution)
    threshold = 60.0 / tempo / 4
    # End of the roll in seconds, derived from the real number of time steps.
    # The previous formula (4 beats per slice of axis 0) was only right when
    # shape[1] == 4 * beat_resolution; otherwise short notes near the end could
    # be given an end time that lies before their start.
    n_steps = piano_roll.shape[0] * piano_roll.shape[1]
    phrase_end_time = 60.0 / tempo * (n_steps / float(beat_resolution))
    # Flatten to (time, pitch), pad one silent step on both sides and take the
    # first difference: positive entries are onsets, negative ones are offsets.
    piano_roll = piano_roll.reshape((n_steps, piano_roll.shape[2]))
    piano_roll_diff = np.concatenate((np.zeros((1, 128), dtype=int), piano_roll, np.zeros((1, 128), dtype=int)))
    piano_roll_search = np.diff(piano_roll_diff.astype(int), axis=0)

    # Iterate through all possible (128) pitches
    for note_num in range(128):
        # Onset / offset times of this pitch in seconds, in ascending order.
        start_idx = (piano_roll_search[:, note_num] > 0).nonzero()
        start_time = list(tpp * (start_idx[0].astype(float)))
        end_idx = (piano_roll_search[:, note_num] < 0).nonzero()
        end_time = list(tpp * (end_idx[0].astype(float)))

        # Working copies from which absorbed notes are removed.
        temp_start_time = list(start_time)
        temp_end_time = list(end_time)

        for i in range(len(start_time)):
            # Only notes that are still kept (and are not the last one) can absorb followers.
            if start_time[i] in temp_start_time and i != len(start_time) - 1:
                t = []
                current_idx = temp_start_time.index(start_time[i])
                for j in range(current_idx + 1, len(temp_start_time)):
                    # A later note is dropped when it both starts and ends within
                    # `threshold` seconds of the start of the current note.
                    if temp_start_time[j] < start_time[i] + threshold and temp_end_time[j] <= start_time[i] + threshold:
                        t.append(j)
                # For 0/1 input the collected positions are consecutive (times are
                # ascending), so popping position t[0] len(t) times removes them all.
                for _ in t:
                    temp_start_time.pop(t[0])
                    temp_end_time.pop(t[0])

        start_time = temp_start_time
        end_time = temp_end_time
        duration = [pair[1] - pair[0] for pair in zip(start_time, end_time)]

        # Ignore onsets that have no matching offset.
        if len(end_time) < len(start_time):
            d = len(start_time) - len(end_time)
            start_time = start_time[:-d]
        # Iterate through all the searched notes
        for idx in range(len(start_time)):
            if duration[idx] >= threshold:
                note_end = end_time[idx]
            elif start_time[idx] + threshold <= phrase_end_time:
                # Too short: stretch to the minimum length.
                note_end = start_time[idx] + threshold
            else:
                # Too short and too close to the end: stop at the end of the roll.
                note_end = phrase_end_time
            # Create a Note object and add it to the Instrument object
            instrument.notes.append(
                pretty_midi.Note(velocity=velocity, pitch=note_num, start=start_time[idx], end=note_end))
    # Sort the notes by their start time
    instrument.notes.sort(key=lambda note: note.start)


def write_piano_roll_to_midi(piano_roll, filename, program_num=0, is_drum=False, velocity=100,
                             tempo=120.0, beat_resolution=16):
    """Write a single piano roll to ``filename`` as a one-track MIDI file.

    Args:
        piano_roll: roll passed unchanged to ``set_piano_roll_to_instrument``.
        filename: path of the MIDI file to write.
        program_num: MIDI program number of the track.
        is_drum: whether the track is a drum track.
        velocity: velocity given to every note.
        tempo: tempo in beats per minute (also the file's initial tempo).
        beat_resolution: number of time steps per beat.
    """
    # Build the track and fill it with the notes found in the roll.
    track = pretty_midi.Instrument(program=program_num, is_drum=is_drum)
    set_piano_roll_to_instrument(piano_roll, track, velocity=velocity, tempo=tempo,
                                 beat_resolution=beat_resolution)

    # Wrap the track in a MIDI container and write it out.
    song = pretty_midi.PrettyMIDI(initial_tempo=tempo)
    song.instruments.append(track)
    song.write(filename)


def write_piano_rolls_to_midi(piano_rolls, program_nums=None, is_drum=None, filename='test.mid', velocity=100,
                              tempo=120.0, beat_resolution=24):
    """Write several piano rolls to one MIDI file, one instrument track per roll.

    Args:
        piano_rolls: sequence of rolls, each passed to ``set_piano_roll_to_instrument``.
        program_nums: MIDI program number per roll; ``None`` means program 0 for every roll.
        is_drum: drum flag per roll; ``None`` means ``False`` for every roll.
        filename: path of the MIDI file to write.
        velocity: velocity given to every note.
        tempo: tempo in beats per minute (also the file's initial tempo).
        beat_resolution: number of time steps per beat.

    Returns:
        ``False`` (after printing a message, without writing anything) when the
        three sequences differ in length; otherwise ``None`` once the file is written.
    """
    n_tracks = len(piano_rolls)
    # Resolve the defaults BEFORE comparing lengths: the length check used to run
    # first and raised TypeError on the None defaults. The defaults now match the
    # number of rolls instead of being fixed at three entries.
    if program_nums is None:
        program_nums = [0] * n_tracks
    if is_drum is None:
        is_drum = [False] * n_tracks
    if n_tracks != len(program_nums) or n_tracks != len(is_drum):
        print("Error: piano_rolls and program_nums have different sizes...")
        return False
    # Create a PrettyMIDI object
    midi = pretty_midi.PrettyMIDI(initial_tempo=tempo)
    # One Instrument per input roll
    for roll, program, drum in zip(piano_rolls, program_nums, is_drum):
        instrument = pretty_midi.Instrument(program=program, is_drum=drum)
        set_piano_roll_to_instrument(roll, instrument, velocity, tempo, beat_resolution)
        midi.instruments.append(instrument)
    # Write out the MIDI data
    midi.write(filename)

In [7]:
def save_midis(bars, file_path, tempo=80.0):
    """Pad a batch of piano rolls along the pitch axis and write them to a MIDI file.

    Args:
        bars: 4-D array indexed as (sample, time step, pitch, channel). 24 zero
            pitch columns are added before and 20 after the pitch axis, which
            brings an 84-pitch roll up to the 128 pitches the MIDI writer uses.
            The padded data is regrouped into blocks of 64 time steps.
        file_path: destination of the MIDI file.
        tempo: tempo in beats per minute.

    Every channel becomes one track with program 0 and no drum flag, so inputs
    with more than one channel are written instead of being rejected.
    """
    n_samples, n_steps, _, n_channels = bars.shape
    low_pad = np.zeros((n_samples, n_steps, 24, n_channels))
    high_pad = np.zeros((n_samples, n_steps, 20, n_channels))
    padded_bars = np.concatenate((low_pad, bars, high_pad), axis=2)
    rolls = padded_bars.reshape(-1, 64, padded_bars.shape[2], n_channels)
    # One (block, time step, pitch) roll per channel.
    tracks = [rolls[:, :, :, ch_idx] for ch_idx in range(n_channels)]
    write_piano_rolls_to_midi(tracks, program_nums=[0] * n_channels, is_drum=[False] * n_channels,
                              filename=file_path, tempo=tempo, beat_resolution=4)



In [8]:
# Load one piano-roll sample from the extracted dataset.
X = np.load("/content/JC_C/test/classic_piano_test_995.npy")

In [9]:
# Reshape to the 4-D (sample, 64, 84, channel) layout that save_midis indexes and write the MIDI file.
save_midis(X.reshape(-1, 64, 84, 1),'/content/test_save.mid')

# Data

In [10]:
"""import os
import shutil
files = [f for f in os.listdir('/content/') if "_J" in f]
#files.remove("JCP_mixed")
!mkdir /content/Data3
for f in files : 
  print(f)
  train ='/content/' +f+"/train"
  train_file = os.listdir(train)
  for k in range(len(train_file)) :
    shutil.copy2(train +"/" + train_file[k], '/content/Data3/'+ train_file[k])
  test ='/content/' + f+"/test"
  test_file = os.listdir(test)
  for k in range(len(test_file)) :
    shutil.copy2(test +"/" + test_file[k], '/content/Data3/'+ test_file[k])"""

'import os\nimport shutil\nfiles = [f for f in os.listdir(\'/content/\') if "_J" in f]\n#files.remove("JCP_mixed")\n!mkdir /content/Data3\nfor f in files : \n  print(f)\n  train =\'/content/\' +f+"/train"\n  train_file = os.listdir(train)\n  for k in range(len(train_file)) :\n    shutil.copy2(train +"/" + train_file[k], \'/content/Data3/\'+ train_file[k])\n  test =\'/content/\' + f+"/test"\n  test_file = os.listdir(test)\n  for k in range(len(test_file)) :\n    shutil.copy2(test +"/" + test_file[k], \'/content/Data3/\'+ test_file[k])'

In [11]:
import os
import shutil

# Gather every array of the mixed dataset into one flat working directory.
# os.makedirs(..., exist_ok=True) replaces `!mkdir`: it needs no shell and
# re-running the cell does not complain that the directory already exists.
os.makedirs('/content/Data', exist_ok=True)
arrays = os.listdir('/content/JCP_mixed')
for file_name in arrays:
  shutil.copy2(os.path.join('/content/JCP_mixed', file_name), os.path.join('/content/Data', file_name))

In [12]:
# Sanity check: number of files now present in the working dataset directory.
len(os.listdir('/content/Data'))

33648

# Model

In [13]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [14]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a flattened image.

    Four hidden stages (Linear -> BatchNorm1d -> ReLU) widen the features from
    ``hidden_dim`` up to ``hidden_dim * 8``; a final Linear + Sigmoid produces
    ``im_dim`` values.

    Args:
        z_dim: length of the input noise vector.
        im_dim: length of the flattened output (default 5376 = 64 * 84).
        hidden_dim: width of the first hidden layer.
    """

    def __init__(self, z_dim=10, im_dim=5376, hidden_dim=128):
        super().__init__()
        widths = [z_dim] + [hidden_dim * factor for factor in (1, 2, 4, 8)]
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU(inplace=True)])
        layers.extend([nn.Linear(widths[-1], im_dim), nn.Sigmoid()])
        self.gen = nn.Sequential(*layers)

    def forward(self, noise):
        """Run ``noise`` through the layer stack and return the result."""
        return self.gen(noise)

In [15]:
def get_noise(n_samples, z_dim, device='cpu'):
    """Draw ``n_samples`` standard-normal noise vectors of length ``z_dim`` on ``device``."""
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)

In [16]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring a flattened image with one raw logit.

    Three Linear + LeakyReLU(0.2) stages narrow the features from
    ``hidden_dim * 4`` down to ``hidden_dim``; a final Linear outputs a single
    value with no activation.

    Args:
        im_dim: length of the flattened input (default 5376 = 64 * 84).
        hidden_dim: width of the last hidden layer.
    """

    def __init__(self, im_dim=5376, hidden_dim=128):
        super().__init__()
        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2)]
        layers.append(nn.Linear(hidden_dim, 1))
        self.disc = nn.Sequential(*layers)

    def forward(self, image):
        """Return the raw score produced by the layer stack for ``image``."""
        return self.disc(image)


In [17]:
# Setting parameters
# The discriminator ends in a plain Linear layer, so the loss takes raw logits.
criterion = nn.BCEWithLogitsLoss()

z_dim = 64          # length of the generator's noise vector
display_step = 500  # step count by which the logged running losses are scaled
batch_size = 128
lr = 0.00001        # Adam learning rate used for both networks


# Use the GPU when one is attached, otherwise fall back to the CPU so the
# notebook still runs on a CPU-only runtime.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [18]:
# Generator and its Adam optimizer.
gen = Generator(z_dim=z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)

# Discriminator (default sizes) and its Adam optimizer.
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

In [19]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """Discriminator loss: average of its loss on a fake batch and on the real batch.

    Args:
        gen: generator used to produce the fake batch (its output is detached,
            so no gradient reaches the generator).
        disc: discriminator being scored.
        criterion: loss applied to (prediction, target) pairs.
        real: batch of real samples fed to ``disc``.
        num_images: number of fake samples to draw; also the target batch size.
        z_dim: length of each noise vector.
        device: device on which noise and targets are created.

    Returns:
        ``(loss on fakes with target 0 + loss on reals with target 1) / 2``.
    """
    fake_batch = gen(get_noise(num_images, z_dim, device)).detach()
    fake_targets = torch.zeros(num_images, 1, device=device)
    real_targets = torch.ones(num_images, 1, device=device)

    fake_term = criterion(disc(fake_batch), fake_targets)
    real_term = criterion(disc(real), real_targets)
    return (fake_term + real_term) / 2

In [20]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """Generator loss: how far the discriminator is from labelling fresh fakes as real.

    Args:
        gen: generator producing the fake batch (gradients flow through it).
        disc: discriminator scoring the fakes.
        criterion: loss applied to (prediction, target).
        num_images: number of fake samples to draw.
        z_dim: length of each noise vector.
        device: device on which noise and targets are created.

    Returns:
        ``criterion`` of the discriminator's scores against an all-ones target.
    """
    latent = get_noise(num_images, z_dim, device)
    scores = disc(gen(latent))
    wanted = torch.ones(num_images, 1, device=device)
    return criterion(scores, wanted)

In [21]:
# Load every saved array and give each one a leading sample axis: (1, 64, 84, 1).
path = os.listdir("/content/Data")
datas = []
for file_name in path:
    datas.append(np.load("/content/Data/" + file_name).reshape(1, 64, 84, 1))

In [22]:
# Stack the per-file (1, 64, 84, 1) arrays along the sample axis.
# NOTE(review): this rebinds `datas` from a list to an array, so the cell is not idempotent when re-run on its own.
datas = np.concatenate(datas, axis = 0)

In [23]:
# Multiply by 1 to obtain a numeric array; presumably the files store booleans — TODO confirm.
datas = 1*datas

In [24]:
import random
# Build a random visiting order over the N samples.
N = datas.shape[0]
idx = np.arange(N)
# NOTE(review): no seed is set, so the order (and therefore the batches) differs between runs.
random.shuffle(idx)  # shuffles the index array in place
idx = list(idx)

In [25]:
# Cut the shuffled indices into consecutive mini-batches of `batch_size`; the last one may be shorter.
chunks = [idx[x:x+batch_size] for x in range(0, len(idx), batch_size)]

In [32]:
n_epochs = 400
G_losses = []   # generator loss of every training step, for plotting
D_losses = []   # discriminator loss of every training step, for plotting
cur_step = 0
mean_generator_loss = 0       # running mean over the current window of `display_step` steps
mean_discriminator_loss = 0   # running mean over the current window of `display_step` steps
error = False  # NOTE(review): not used by this cell; kept in case another cell reads it

for epoch in range(n_epochs):
  print("Epoch {}".format(epoch))
  for chunk in chunks:

    # Real batch as float32 on the training device, flattened to one vector per sample.
    real = torch.tensor(datas[chunk, :, :, :], dtype=torch.float32, device=device)
    cur_batch_size = real.size(0)
    real = real.reshape(cur_batch_size, -1)

    # --- Discriminator step ---
    disc_opt.zero_grad()
    disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
    # Each loss builds its own graph from a fresh forward pass, so nothing has to
    # be kept alive after backward(); retain_graph=True only wasted memory.
    disc_loss.backward()
    disc_opt.step()

    # --- Generator step ---
    gen_opt.zero_grad()
    gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
    gen_loss.backward()
    gen_opt.step()

    # --- Bookkeeping ---
    # Start a new averaging window every `display_step` steps; previously the
    # "means" were never reset and grew without bound.
    if cur_step % display_step == 0:
      mean_discriminator_loss = 0
      mean_generator_loss = 0
    disc_loss_value = disc_loss.item()
    gen_loss_value = gen_loss.item()
    mean_discriminator_loss += disc_loss_value / display_step
    mean_generator_loss += gen_loss_value / display_step

    # Save the per-step losses for plotting later
    G_losses.append(gen_loss_value)
    D_losses.append(disc_loss_value)
    cur_step += 1

Epoch 0
Epoch 1
Epoch 2
Epoch 3
Epoch 4
Epoch 5
Epoch 6
Epoch 7
Epoch 8
Epoch 9
Epoch 10
Epoch 11
Epoch 12
Epoch 13
Epoch 14
Epoch 15
Epoch 16
Epoch 17
Epoch 18
Epoch 19
Epoch 20
Epoch 21
Epoch 22
Epoch 23
Epoch 24
Epoch 25
Epoch 26
Epoch 27
Epoch 28
Epoch 29
Epoch 30
Epoch 31
Epoch 32
Epoch 33
Epoch 34
Epoch 35
Epoch 36
Epoch 37
Epoch 38
Epoch 39
Epoch 40
Epoch 41
Epoch 42
Epoch 43
Epoch 44
Epoch 45
Epoch 46
Epoch 47
Epoch 48
Epoch 49
Epoch 50
Epoch 51
Epoch 52
Epoch 53
Epoch 54
Epoch 55
Epoch 56
Epoch 57
Epoch 58
Epoch 59
Epoch 60
Epoch 61
Epoch 62
Epoch 63
Epoch 64
Epoch 65
Epoch 66
Epoch 67
Epoch 68
Epoch 69
Epoch 70
Epoch 71
Epoch 72
Epoch 73
Epoch 74
Epoch 75
Epoch 76
Epoch 77
Epoch 78
Epoch 79
Epoch 80
Epoch 81
Epoch 82
Epoch 83
Epoch 84
Epoch 85
Epoch 86
Epoch 87
Epoch 88
Epoch 89
Epoch 90
Epoch 91
Epoch 92
Epoch 93
Epoch 94
Epoch 95
Epoch 96
Epoch 97
Epoch 98
Epoch 99
Epoch 100
Epoch 101
Epoch 102
Epoch 103
Epoch 104
Epoch 105
Epoch 106
Epoch 107
Epoch 108
Epoch 109
Epoch 110


In [33]:
# Draw two noise vectors to generate samples from the trained generator.
noise = get_noise(2, z_dim, device)

In [34]:
# Sample in evaluation mode so BatchNorm uses its running statistics instead of
# the statistics of this 2-sample batch, and skip autograd bookkeeping. The
# previous train/eval mode is restored afterwards so later training is unaffected.
was_training = gen.training
gen.eval()
with torch.no_grad():
    fake_images = gen(noise)
gen.train(was_training)

In [35]:
# Take the first generated sample, detach it from the graph, move it to the CPU
# and copy it into a 64 x 84 NumPy array.
generat = np.array(fake_images[0].detach().cpu().reshape(64, 84))

In [36]:
# Binarise the sample in place: cells above 0.5 become 1, all others 0.
# One vectorised assignment replaces the element-by-element Python loops and
# keeps the array's dtype.
generat[...] = generat > 0.5

In [37]:
# Write the binarised sample to a MIDI file (save_midis' default tempo of 80 is used).
save_midis(generat.reshape(-1, 64, 84, 1),'/content/test_GAN_A.midi')

In [38]:
"""shutil.copy2('/content/test_GAN_A.midi', '/content/drive/MyDrive/MVA/Deep Learning/GAN/test_GAN_A.midi')
save_midis(generat.reshape(-1, 64, 84, 1),'/content/test_GAN_A.mp3')
from IPython.display import Audio
Audio("/content/test_GAN_A (8).wav", autoplay=True)"""

'shutil.copy2(\'/content/test_GAN_A.midi\', \'/content/drive/MyDrive/MVA/Deep Learning/GAN/test_GAN_A.midi\')\nsave_midis(generat.reshape(-1, 64, 84, 1),\'/content/test_GAN_A.mp3\')\nfrom IPython.display import Audio\nAudio("/content/test_GAN_A (8).wav", autoplay=True)'