In [23]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from dataset import JSBChoralesDataset, VOCAB_SIZE, TOKEN_PAD, jsb_refactor_json
from lstm import MusicLSTM
from tqdm import tqdm

In [9]:
# Build the refactored JSB chorales data from the raw JSON file, then wrap its
# 'test' split for prompting. The second path is presumably where the refactored
# JSON is written — confirm against `jsb_refactor_json`.
RAW_CHORALES_JSON = 'JSB-Chorales-dataset/Jsb16thSeparated.json'
REFACTORED_CHORALES_JSON = 'JSB-Chorales-dataset/Jsb16thSeparated_refactored.json'
new_dataset = jsb_refactor_json(RAW_CHORALES_JSON, REFACTORED_CHORALES_JSON)
test_dataset = JSBChoralesDataset(new_dataset['test'], max_seq=1024, random_seq=False)

In [25]:
# Seed the generator with the opening tokens of one held-out chorale.
PROMPT_LENGTH = 100
sequence = test_dataset[5]
print(sequence['input_ids'].shape)
# Cast to integer ids and add a leading batch dimension: (PROMPT_LENGTH,) -> (1, PROMPT_LENGTH).
music = sequence['input_ids'][:PROMPT_LENGTH].long().unsqueeze(0)
music.shape

torch.Size([1024])


torch.Size([1, 100])

In [28]:
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): VOCAB_SIZE + 1 presumably reserves one extra id (e.g. padding) — confirm
# against the dataset module. The printed architecture shows no layer of width 2048,
# so confirm that MusicLSTM actually uses fc_dim.
model = MusicLSTM(VOCAB_SIZE + 1, embedding_dim = 768, hidden_dim = 768, num_layers = 3, fc_dim = 2048, device = device)
# map_location lets a checkpoint saved on a GPU load onto whichever device is active.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load('../saved_models/music_lstm_jsb.pt', map_location=device))
# eval() disables the model's Dropout layer for inference. load_state_dict does not
# change the module's train/eval mode, and torch.no_grad() does not either.
model = model.to(device).eval()

In [19]:
# Display the module structure of the loaded model.
model

MusicLSTM(
  (embedding): Embedding(391, 768)
  (lstm): ModuleList(
    (0-2): 3 x LSTM(768, 768, batch_first=True)
  )
  (fc1): Linear(in_features=768, out_features=391, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (relu): ReLU()
)

In [30]:
# Autoregressive sampling. On every step the whole sequence generated so far is fed
# back through the model. The next token is sampled from the temperature-scaled
# softmax at the last position and appended to the context.
SEED = 0
# Never updated: the full context is re-fed on each step, so no state is carried over.
hidden = None
generated_length = 400
temperature = 1
# torch.multinomial is stochastic; seed so re-running this cell reproduces the output.
torch.manual_seed(SEED)
music = music.to(device)
with torch.no_grad():
    generated_music = music.clone()
    print(music.shape)
    for _ in tqdm(range(generated_length)):
        outputs = model(generated_music, hidden)

        # Logits at the last position, scaled by temperature and normalised over the final axis.
        probabilities = torch.softmax(outputs[:, -1, :] / temperature, dim=-1)
        next_token = torch.multinomial(probabilities, num_samples=1)

        # Append the sampled token so it becomes part of the next step's context.
        generated_music = torch.cat((generated_music, next_token), dim=1)

torch.Size([1, 100])


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:11<00:00, 33.81it/s]


In [31]:
# Shape check: prompt length plus generated_length tokens, with a batch dimension of 1.
generated_music.shape

torch.Size([1, 500])

In [32]:
import pretty_midi
import IPython.display
import numpy as np

In [33]:
def synthesize_chorale(note_list, output_file="bach_chorale.mid", bpm=120, num_voices=4):
    """Render a flat, voice-interleaved note list to a MIDI file.

    ``note_list`` is read as consecutive frames of ``num_voices`` values, one frame
    per sixteenth note. Frame ``t`` is ``note_list[t * num_voices:(t + 1) * num_voices]``,
    and position ``j`` within a frame belongs to voice ``j``. A value of ``-1`` is a
    rest; any other value is passed straight through as a MIDI pitch number. Within a
    voice, consecutive frames holding the same value are merged into one sustained
    note. A note lasting a single frame is kept as a one-sixteenth note.

    Args:
        note_list: Flat sequence of ints whose length is a multiple of ``num_voices``.
        output_file: Path of the MIDI file to write.
        bpm: Tempo in quarter notes per minute; each frame lasts a quarter of a beat.
        num_voices: Number of voices interleaved in each frame.

    Raises:
        ValueError: If ``len(note_list)`` is not a multiple of ``num_voices``.
    """
    if len(note_list) % num_voices != 0:
        raise ValueError(f"The length of the list must be a multiple of {num_voices}.")

    midi = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=0)
    sixteenth_note_duration = 60 / bpm / 4
    num_frames = len(note_list) // num_voices

    for voice in range(num_voices):
        # The strided slice selects this voice's value in every frame.
        voice_line = note_list[voice::num_voices]
        run_start = 0
        # Scan one past the last frame so the final run is always closed.
        for frame in range(1, num_frames + 1):
            if frame < num_frames and voice_line[frame] == voice_line[run_start]:
                continue  # the current value is still held
            pitch = voice_line[run_start]
            if pitch != -1:
                instrument.notes.append(pretty_midi.Note(
                    velocity=100,
                    pitch=int(pitch),
                    start=run_start * sixteenth_note_duration,
                    end=frame * sixteenth_note_duration,
                ))
            run_start = frame

    # Add the instrument to the PrettyMIDI object
    midi.instruments.append(instrument)

    # Write the MIDI data to a file
    midi.write(output_file)
    print(f"MIDI file saved as {output_file}")

In [36]:
# Convert the generated ids to a flat Python list, write them as MIDI, and play them back.
midi_path = 'bach_chorale.mid'
sr = 48000
note_list = generated_music.detach().squeeze(0).cpu().tolist()
synthesize_chorale(note_list, output_file=midi_path)
midi_data = pretty_midi.PrettyMIDI(midi_path)
# synthesize() defaults to fs=44100. The waveform must be rendered at the same rate
# that Audio is told to play it at, otherwise playback is sped up and pitched up.
IPython.display.Audio(midi_data.synthesize(fs=sr), rate=sr)

MIDI file saved as bach_chorale.mid
