In [None]:
!pip install git+https://github.com/PennyLaneAI/pennylane

In [2]:
import pennylane as qml

from pennylane.templates.layers import BasicEntanglerLayers, StronglyEntanglingLayers, RandomLayers
from pennylane.templates.embeddings import AmplitudeEmbedding
import pennylane.numpy as np
import torch
from music21 import converter, instrument, note, chord, stream

from pathlib import Path
import pickle, glob 

In [10]:
# Circuit geometry: n_wires qubits in total; the first n_note_encoding of them
# receive the amplitude-encoded note window (see encode_music).
n_wires = 12
wires_range = range(n_wires)

n_note_encoding = 7
encoding_range = range(n_note_encoding)

# Simulator shared by every QNode in this notebook.
dev = qml.device('default.qubit', wires=n_wires)

# Torch device for tensors and TorchLayers: GPU when available, otherwise CPU.
running_dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
running_dev

device(type='cuda')

In [None]:
!wget https://github.com/theerfan/Maqenta/raw/main/data/notes.pk

In [8]:
# Midi.py

# Pickle cache of the parsed note/chord tokens (downloaded above or written by
# Midi.get_notes).
notes_dir = "notes.pk"


class Midi:
    """Load note/chord tokens and turn them into normalised training windows.

    On construction the tokens are read from the ``notes_dir`` pickle (or parsed
    from ``midi_songs/*.mid`` and cached when the pickle is missing). They are then
    sliced into windows of ``seq_length`` token indices, with the following token
    as the target.

    Attributes set during construction:
        notes: list of token strings.
        n_vocab: number of distinct tokens.
        note_to_int / int_to_note: token <-> index maps (tokens in sorted order).
        network_input: float64 tensor (n_patterns, seq_length); each row is a
            window of token indices divided by its L2 norm (all-zero rows stay 0).
        input_norms: float64 tensor (n_patterns,) with each window's L2 norm.
        network_output: int64 tensor (n_patterns,) with the next-token index.
    """

    def __init__(self, seq_length, device):
        self.seq_length = seq_length
        self.device = device

        if Path(notes_dir).is_file():
            # NOTE(review): pickle.load can execute arbitrary code; only load a
            # notes.pk from a trusted source.
            with open(notes_dir, "rb") as fh:
                self.notes = pickle.load(fh)
        else:
            # get_notes() writes the cache itself, so no second dump here.
            self.notes = self.get_notes()

        self.network_input, self.network_output = self.prepare_sequences(self.notes)
        print(f"Input shape: {self.network_input.shape}")
        print(f"Output shape: {self.network_output.shape}")

    def get_notes(self):
        """Parse every ``midi_songs/*.mid`` file into tokens and cache them.

        Single notes become ``str(pitch)``. Chords become their ``normalOrder``
        values joined by ``"."``. Timing is discarded: every interval between
        tokens is assumed to be the same (0.5). When a file has instrument parts,
        only the first part is used.

        Returns:
            list[str]: the tokens in file order; also pickled to ``notes_dir``.
        """
        notes = []

        for file in glob.glob("midi_songs/*.mid"):
            midi = converter.parse(file)

            print("Parsing %s" % file)

            try:  # file has instrument parts
                s2 = instrument.partitionByInstrument(midi)
                notes_to_parse = s2.parts[0].recurse()
            except Exception:  # file has notes in a flat structure
                notes_to_parse = midi.flat.notes

            for element in notes_to_parse:
                if isinstance(element, note.Note):
                    notes.append(str(element.pitch))
                elif isinstance(element, chord.Chord):
                    notes.append(".".join(str(n) for n in element.normalOrder))

        with open(notes_dir, "wb") as filepath:
            pickle.dump(notes, filepath)

        return notes

    def prepare_sequences(self, notes):
        """Build normalised (window, next-token) pairs from ``notes``.

        Args:
            notes: sequence of token strings.

        Returns:
            tuple: ``(network_input, network_output)`` on ``self.device``, with the
            shapes described in the class docstring. Also sets ``n_vocab``,
            ``note_to_int``, ``int_to_note`` and ``input_norms`` on ``self``.
        """
        pitchnames = sorted(set(notes))
        self.n_vocab = len(pitchnames)

        self.note_to_int = {name: index for index, name in enumerate(pitchnames)}
        self.int_to_note = {index: name for index, name in enumerate(pitchnames)}

        # Slide a window over the `notes` argument (the same sequence the
        # vocabulary above was built from).
        encoded = [self.note_to_int[name] for name in notes]
        network_input = []
        network_output = []
        for i in range(len(encoded) - self.seq_length):
            network_input.append(encoded[i : i + self.seq_length])
            network_output.append(encoded[i + self.seq_length])

        n_patterns = len(network_input)

        # (number of windows, sequence length); reshape also covers the empty case.
        network_input = torch.tensor(
            network_input, device=self.device, dtype=torch.double
        ).reshape(n_patterns, self.seq_length)

        self.input_norms = torch.linalg.norm(network_input, dim=1)

        # A window whose tokens all map to index 0 has norm 0. Dividing by 1
        # instead keeps that row at zero rather than turning it into NaNs.
        # input_norms itself keeps the true (zero) value.
        safe_norms = torch.where(
            self.input_norms > 0, self.input_norms, torch.ones_like(self.input_norms)
        )
        network_input = network_input / safe_norms.unsqueeze(1)

        return (
            network_input,
            torch.tensor(network_output, device=self.device, dtype=torch.long),
        )

    def create_midi_from_model(self, prediction_output, filename):
        """Write a sequence of note/chord tokens to a MIDI file.

        A token containing ``"."`` or made only of digits is a chord. Each
        ``"."``-separated part is turned into ``note.Note(int(part))``. Any other
        token is a single ``note.Note(token)``. Elements are placed 0.5 apart and
        tagged with a Piano instrument.

        Args:
            prediction_output: iterable of token strings.
            filename: destination path of the MIDI file.
        """
        offset = 0
        output_notes = []

        for pattern in prediction_output:
            if ("." in pattern) or pattern.isdigit():
                chord_notes = []
                for current_note in pattern.split("."):
                    new_note = note.Note(int(current_note))
                    new_note.storedInstrument = instrument.Piano()
                    chord_notes.append(new_note)
                new_chord = chord.Chord(chord_notes)
                new_chord.offset = offset
                output_notes.append(new_chord)
            else:
                new_note = note.Note(pattern)
                new_note.offset = offset
                new_note.storedInstrument = instrument.Piano()
                output_notes.append(new_note)

            # Advance so consecutive elements do not stack at the same offset.
            offset += 0.5

        midi_stream = stream.Stream(output_notes)

        midi_stream.write("midi", fp=filename)


In [None]:
# AmplitudeEmbedding over n_note_encoding wires takes 2**n_note_encoding
# features, so each training window holds exactly that many notes.
seq_length = 1 << n_note_encoding
print("Initialized Midi")
midi = Midi(seq_length, running_dev)

In [12]:
def encode_music(notes):
    """Amplitude-encode ``notes`` onto the note-encoding wires (normalize=True)."""
    AmplitudeEmbedding(features=notes, wires=encoding_range, normalize=True)

def music_generator(weights):
    """Generator ansatz on the note-encoding wires.

    StronglyEntanglingLayers and BasicEntanglerLayers were tried as alternatives.
    """
    RandomLayers(weights, wires=encoding_range)

def discriminator(weights):
    """Discriminator ansatz over all wires (BasicEntanglerLayers was an alternative)."""
    StronglyEntanglingLayers(weights, wires=wires_range)

def measurement(wire_count):
    """Expectation of the PauliZ tensor product on wires 0 .. wire_count-1.

    For ``wire_count <= 1`` this is just PauliZ on wire 0.
    """
    observable = qml.PauliZ(0)
    wire = 1
    while wire < wire_count:
        observable = observable @ qml.PauliZ(wire)
        wire += 1
    return qml.expval(observable)

In [13]:
@qml.qnode(dev, interface="torch")
def real_music_discriminator(inputs, weights):
    """Discriminator on a real note window.

    Amplitude-encodes ``inputs``, applies the discriminator ansatz with ``weights``
    and returns the PauliZ-product expectation on the first n_note_encoding wires.
    The TorchLayer weight_shapes dict built later uses "weights" as its key, so
    the argument name must not change.
    """
    encode_music(inputs)
    discriminator(weights)
    # NOTE(review): only n_note_encoding of the n_wires wires are measured,
    # although the discriminator acts on all of them.
    return measurement(n_note_encoding)

def music_generator_circuit(inputs, note_weights):
  """Encode ``inputs`` and apply the generator ansatz (no measurement)."""
  encode_music(inputs)
  music_generator(note_weights)

@qml.qnode(dev, interface="torch")
def generated_music_discriminator(inputs, note_weights, weights):
    """Generator circuit followed by the discriminator ansatz.

    ``note_weights`` parameterise the generator and ``weights`` the discriminator;
    both names are keys in the generated_shapes dict built later. Returns the same
    PauliZ-product expectation as real_music_discriminator.
    """
    music_generator_circuit(inputs, note_weights)
    discriminator(weights)
    return measurement(n_note_encoding)

In [15]:
# Depth of the discriminator and generator ansatze.
n_disc_layers = 12
n_gen_layers = 20

# Discriminator alone; StronglyEntanglingLayers weights are (layers, wires, 3).
real_shapes = dict(weights=(n_disc_layers, n_wires, 3))

real_layer = qml.qnn.TorchLayer(real_music_discriminator, real_shapes).to(running_dev)

# Generator followed by a copy of the discriminator. The copy's "weights" are
# frozen here, so only "note_weights" train in this layer. The copy is refreshed
# from real_layer through sync_weights.
generated_shapes = dict(
    weights=(n_disc_layers, n_wires, 3),
    note_weights=(n_gen_layers, n_note_encoding),
)

generated_layer = qml.qnn.TorchLayer(generated_music_discriminator, generated_shapes).to(running_dev)
generated_layer.weights.requires_grad = False

In [16]:
def sync_weights(source_layer, target_layer, attr="weights"):
    """Copy the ``attr`` parameter (default ``"weights"``) from source to target.

    The values are copied in place into the target's existing tensor, so the
    target keeps its own storage, device and ``requires_grad`` flag. Assigning
    ``.data`` on the slices produced by iterating a tensor only rebinds those
    temporary views and leaves the target parameter unchanged, so that pattern
    must not be used here.

    Args:
        source_layer: object exposing the parameter as attribute ``attr``.
        target_layer: object whose ``attr`` parameter receives the values.
        attr: attribute name of the parameter to copy.

    Raises:
        RuntimeError: if the two tensors' shapes are not broadcast-compatible.
    """
    source = getattr(source_layer, attr)
    target = getattr(target_layer, attr)
    with torch.no_grad():
        target.copy_(source)

In [17]:
def prob_fun_disc_true(layer):
    """Wrap ``layer`` so its output (expected in [-1, 1]) maps affinely to [0, 1].

    Returns:
        A function ``inputs -> (layer(inputs) + 1) / 2``.
    """
    def prob_true(inputs):
        return (layer(inputs) + 1) / 2

    return prob_true

In [18]:
# "Probability judged real" views of the two discriminator circuits.
prob_real_true = prob_fun_disc_true(real_layer)
prob_gen_true = prob_fun_disc_true(generated_layer)

# NOTE(review): not referenced by any visible cell.
empty_input = torch.tensor(np.zeros((1,))).to(running_dev)

def disc_cost(inputs):
    """Discriminator loss: P(generated judged real) - P(real judged real)."""
    p_generated = prob_gen_true(inputs)
    p_real = prob_real_true(inputs)
    return p_generated - p_real

def gen_cost(inputs):
    """Generator loss: negated probability that the generated circuit is judged real."""
    p_generated = prob_gen_true(inputs)
    return -p_generated

In [19]:
def gen_batch_inputs(batch_size=1):
    """Return ``batch_size`` rows of midi.network_input drawn uniformly with replacement."""
    pool = midi.network_input
    picks = np.random.randint(0, len(pool), size=batch_size)
    return pool[picks]

def shuffle_music(datapoint):
  """Randomly permute the notes (last axis) of ``datapoint`` and detach it from autograd.

  Works on a single window ``(seq_length,)`` and on a batch
  ``(batch, seq_length)``; in the batched case every row gets the same
  permutation. The last axis is permuted, not axis 0: for the
  ``(1, seq_length)`` batches returned by gen_batch_inputs(), axis 0 has length 1
  and permuting it would leave the notes untouched.
  """
  perm = torch.randperm(datapoint.size(-1))
  return datapoint[..., perm].detach()

In [20]:
def discriminator_iteration(n_iterations, learning_rate):
    """Train real_layer for ``n_iterations`` Adam steps on disc_cost.

    Each step samples a batch with gen_batch_inputs(), minimises disc_cost and
    then mirrors the updated discriminator weights into generated_layer.
    Prints the lowest cost seen, including a baseline on midi.network_input[0].

    Args:
        n_iterations: number of optimisation steps.
        learning_rate: Adam learning rate.
    """
    # NOTE(review): a fresh Adam (with fresh moment estimates) is built on every call.
    opt = torch.optim.Adam(real_layer.parameters(), lr=learning_rate)
    # The baseline is only reported, so no autograd graph is needed for it.
    with torch.no_grad():
        best_cost = disc_cost(midi.network_input[0]).mean().item()

    for _ in range(n_iterations):
        opt.zero_grad()
        # Sample a batch of data
        batch_inputs = gen_batch_inputs().detach()
        # Reduce to a scalar so backward() also works for batches larger than one.
        loss = disc_cost(batch_inputs).mean()
        loss.backward()
        opt.step()
        # Sync after the update so generated_layer mirrors the current weights.
        # Syncing between forward and backward would mutate the mirrored tensors
        # while the pending backward pass may still reference them.
        sync_weights(real_layer, generated_layer)
        # Keep a plain float so no autograd graph is retained.
        best_cost = min(best_cost, loss.item())
    print("New best Discriminator cost:", best_cost)

In [21]:
def generator_iteration(n_iterations, learning_rate):
    """Train the trainable parameters of generated_layer for ``n_iterations`` SGD steps.

    Only parameters with ``requires_grad=True`` are optimised; the discriminator
    copy in generated_layer was frozen when the layer was built. Each step feeds a
    shuffled real window through generator + discriminator and minimises gen_cost.
    Prints the lowest cost seen, including a baseline on midi.network_input[0].

    Args:
        n_iterations: number of optimisation steps.
        learning_rate: SGD learning rate.
    """
    trainable = [p for p in generated_layer.parameters() if p.requires_grad]
    opt = torch.optim.SGD(trainable, lr=learning_rate)
    # The baseline is only reported, so no autograd graph is needed for it.
    with torch.no_grad():
        best_cost = gen_cost(midi.network_input[0]).mean().item()

    for _ in range(n_iterations):
        opt.zero_grad()
        batch_inputs = shuffle_music(gen_batch_inputs())
        # Reduce to a scalar so backward() also works for batches larger than one.
        loss = gen_cost(batch_inputs).mean()
        loss.backward()
        opt.step()
        # Keep a plain float so no autograd graph is retained.
        best_cost = min(best_cost, loss.item())
    print("New best Generator cost:", best_cost)

In [None]:
# The real iteration: alternate discriminator and generator phases.
# The trained generated_layer is cached under a name built from the settings
# below, so re-running the notebook loads it instead of retraining.
# NOTE(review): the cache name omits learning_rate and the layer depths, so
# changing only those silently reuses an old checkpoint.
steps = 100
n_iterations = 20
learning_rate = 0.1

generation_counter = 0

model_name = f"quGan-qu{n_wires}-quen{n_note_encoding}-step{steps}-iter{n_iterations}"
model_str = f"{model_name}.pt"

if Path(model_str).is_file():
    print("Loading model")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    generated_layer.load_state_dict(torch.load(model_str, map_location=running_dev))
    generated_layer.eval()
else:
    print("Training model")
    for _ in range(steps):
        discriminator_iteration(n_iterations, learning_rate)
        generator_iteration(n_iterations, learning_rate)
    torch.save(generated_layer.state_dict(), model_str)

In [56]:
import random

def generate_notes(model, network_input, int_to_note, n_notes, input_norms=None):
    """Generate ``n_notes`` tokens by running ``model`` on shuffled real windows.

    A random start index is drawn. For each of the ``n_notes`` consecutive windows
    from there:
      1. The window is shuffled and fed to ``model``.
      2. The output (expected in [-1, 1]) becomes ``int((out + 1) * norm)``, where
         ``norm`` is that window's entry in ``input_norms``.
      3. An index missing from ``int_to_note`` is scaled by 1/2, then 2/3, 3/4, ...
         (truncating each time) until it is a valid key.

    Args:
        model: callable taking one window and returning a scalar tensor.
        network_input: indexable collection of windows.
        int_to_note: mapping from integer index to token.
        n_notes: number of tokens to generate; must be <= len(network_input).
        input_norms: per-window norms aligned with ``network_input``. Defaults to
            ``midi.input_norms``, which the function always used before.

    Returns:
        list of tokens (values of ``int_to_note``).

    Raises:
        ValueError: from ``random.randint`` if ``n_notes > len(network_input)``.
    """
    if input_norms is None:
        input_norms = midi.input_norms

    prediction_output = []
    with torch.no_grad():
        # Pick a random starting window such that all n_notes windows exist.
        start = random.randint(0, len(network_input) - n_notes)

        for i in range(start, start + n_notes):
            generated = model(shuffle_music(network_input[i]))
            index = int((generated + 1) * input_norms[i])
            counter = 1
            while index not in int_to_note:
                index = int(index * (counter / (counter + 1)))
                counter += 1
            prediction_output.append(int_to_note[index])

    return prediction_output

In [24]:
@qml.qnode(dev, interface="torch")
def final_music_generator(inputs, note_weights):
  """Generator circuit alone (no discriminator), used for sampling.

  Encodes ``inputs``, applies the generator ansatz with ``note_weights`` and
  returns the PauliZ-product expectation on the first n_note_encoding wires.
  "note_weights" is the key in the weight_shapes dict built for this QNode.
  """
  music_generator_circuit(inputs, note_weights)
  return measurement(n_note_encoding)

In [25]:
# TorchLayer around the generator-only QNode. Its only parameter, note_weights,
# is filled from the trained generated_layer by sync_final_weights.
weight_gens = dict(note_weights=(n_gen_layers, n_note_encoding))
generator_only_layer = qml.qnn.TorchLayer(final_music_generator, weight_gens).to(running_dev)

In [26]:
def sync_final_weights(source_layer, target_layer):
    """Copy ``note_weights`` from ``source_layer`` into ``target_layer`` in place.

    Uses ``copy_`` on the target parameter itself. Assigning ``.data`` on the
    slices produced by iterating a tensor only rebinds temporary views and leaves
    the target parameter unchanged.

    Args:
        source_layer: object exposing a ``note_weights`` tensor.
        target_layer: object whose ``note_weights`` tensor receives the values.
    """
    with torch.no_grad():
        target_layer.note_weights.copy_(source_layer.note_weights)

In [None]:
# Copy the trained generator weights into the generator-only circuit, then sample.
n_notes = 200
generated_notes = []  # NOTE(review): not used by any visible cell
print("Generating notes")
sync_final_weights(generated_layer, generator_only_layer)
notes = generate_notes(
    generator_only_layer,
    midi.network_input,
    midi.int_to_note,
    n_notes=n_notes,
)
# notes

In [60]:
# Each run of this cell writes a new numbered file (the counter resets when the
# training cell is re-run).
generation_counter += 1
print("Saving as MIDI file.")
midi.create_midi_from_model(notes, "{}_generated_{}.mid".format(model_name, generation_counter))

Saving as MIDI file.
