<a href="https://colab.research.google.com/github/theerfan/Maqenta/blob/main/src/QuGan_pennyLane_improved.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !pip install pennylane pennylane-cirq
# !pip install protobuf==3.13.0

In [23]:
import pennylane as qml
from pennylane.templates.layers import BasicEntanglerLayers, StronglyEntanglingLayers
from pennylane.templates.embeddings import AmplitudeEmbedding
import pennylane.numpy as np
import torch
from music21 import converter, instrument, note, chord, stream

from pathlib import Path
import pickle, glob 

In [147]:
# Register sizes: the whole circuit, and the sub-register used for the notes.
n_wires, n_note_encoding = 10, 6

# Wire index ranges handed to the PennyLane templates further down.
wires_range = range(n_wires)
encoding_range = range(n_note_encoding)

# Simulator device on which the QNodes defined later in this file run.
dev = qml.device('default.qubit', wires=n_wires)

# Torch device for tensors and layers: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    running_dev = torch.device('cuda')
else:
    running_dev = torch.device('cpu')

In [None]:
!wget https://github.com/theerfan/Maqenta/raw/main/data/notes.pk

In [148]:
def normalize_l2(notes):
    """Scale ``notes`` so that its L2 norm is 1.

    Args:
        notes: Array-like of numbers.

    Returns:
        A ``(l2_norm, normalized)`` tuple: the norm computed by
        ``np.linalg.norm`` with default arguments, and the input divided by
        that norm. When the norm is zero (all-zero or empty input) the values
        are returned unscaled as a float array instead of the NaNs a division
        by zero would produce.
    """
    l2_norm = np.linalg.norm(notes)
    if l2_norm == 0:
        # A zero vector has no direction to preserve; avoid 0 / 0.
        return l2_norm, np.array(notes, dtype=float)
    return l2_norm, notes / l2_norm

In [149]:
# Midi.py

notes_dir = "notes.pk"


class Midi:
    """Load a corpus of notes and turn it into normalized training sequences.

    Attributes set by the constructor:
        seq_length: Number of notes in every input sequence.
        device: Torch device on which the training tensors are created.
        notes: List of note / chord strings (from the cache file or parsed
            from the MIDI files).
        network_input: ``(n_patterns, seq_length)`` double tensor whose rows
            are divided by their L2 norm.
        network_output: ``(n_patterns,)`` tensor holding the integer id of
            the note that follows each input sequence.
        input_norms: ``(n_patterns,)`` tensor with the L2 norm each row had
            before normalization.
    """

    def __init__(self, seq_length, device):
        """Load (or build) the note list and prepare the training tensors.

        Args:
            seq_length: Number of notes per input sequence.
            device: Torch device for the created tensors.
        """
        self.seq_length = seq_length
        self.device = device

        if Path(notes_dir).is_file():
            # NOTE: unpickling can execute arbitrary code; only load a cache
            # file that comes from a trusted source.
            with open(notes_dir, "rb") as cache_file:
                self.notes = pickle.load(cache_file)
        else:
            # get_notes() writes the cache file itself, so no second dump here.
            self.notes = self.get_notes()

        self.network_input, self.network_output = self.prepare_sequences(self.notes)
        print(f"Input shape: {self.network_input.shape}")
        print(f"Output shape: {self.network_output.shape}")

    def get_notes(self):
        """Parse every ``midi_songs/*.mid`` file into a flat list of strings.

        Notes are stored as their pitch name and chords as the dot-joined
        normal order. Every interval between notes is assumed to be the
        same (0.5). The resulting list is also pickled to ``notes_dir``.

        Returns:
            The list of note / chord strings.
        """
        notes = []

        for file in glob.glob("midi_songs/*.mid"):
            midi = converter.parse(file)

            print(f"Parsing {file}")

            try:  # file has instrument parts
                s2 = instrument.partitionByInstrument(midi)
                notes_to_parse = s2.parts[0].recurse()
            except Exception:  # file has notes in a flat structure
                # `Exception` rather than a bare except, so Ctrl-C and
                # SystemExit are no longer swallowed by the fallback.
                notes_to_parse = midi.flat.notes

            for element in notes_to_parse:
                if isinstance(element, note.Note):
                    notes.append(str(element.pitch))
                elif isinstance(element, chord.Chord):
                    notes.append(".".join(str(n) for n in element.normalOrder))

        with open(notes_dir, "wb") as filepath:
            pickle.dump(notes, filepath)

        return notes

    def prepare_sequences(self, notes):
        """Build the normalized input sequences and their target notes.

        Also sets ``n_vocab``, ``note_to_int``, ``int_to_note`` and
        ``input_norms`` on the instance.

        Args:
            notes: Sequence of note / chord strings.

        Returns:
            ``(network_input, network_output)``: a double tensor of shape
            ``(n_patterns, seq_length)`` with L2-normalized rows, and a
            tensor with the integer id of the note following each row.
        """
        # Sorted vocabulary and the two lookup tables between names and ids.
        pitchnames = sorted(set(notes))
        self.n_vocab = len(pitchnames)
        self.note_to_int = {name: number for number, name in enumerate(pitchnames)}
        self.int_to_note = dict(enumerate(pitchnames))

        network_input = []
        network_output = []

        # Sliding window over the `notes` argument: seq_length notes in,
        # the next note out.
        for start in range(len(notes) - self.seq_length):
            end = start + self.seq_length
            network_input.append([self.note_to_int[name] for name in notes[start:end]])
            network_output.append(self.note_to_int[notes[end]])

        n_patterns = len(network_input)

        # Shape is (number of different inputs, sequence length).
        network_input = torch.tensor(
            network_input, device=self.device, dtype=torch.double
        ).reshape(n_patterns, self.seq_length)

        # Row norms are computed with torch so this also works when the
        # tensor lives on a CUDA device.
        self.input_norms = torch.linalg.norm(network_input, dim=1)

        # Divide every row by its norm in one vectorized step. Rows whose
        # norm is zero are left as zeros instead of becoming NaN.
        safe_norms = torch.where(
            self.input_norms == 0,
            torch.ones_like(self.input_norms),
            self.input_norms,
        )
        network_input = network_input / safe_norms.unsqueeze(1)

        return (
            network_input,
            torch.tensor(network_output, device=self.device),
        )

    def create_midi_from_model(self, prediction_output, filename):
        """Convert predicted note strings to notes/chords and write a MIDI file.

        Args:
            prediction_output: Iterable of strings; a string containing "."
                or consisting only of digits is treated as a chord, anything
                else as a single note name.
            filename: Path the MIDI file is written to.
        """
        offset = 0
        output_notes = []

        # create note and chord objects based on the values generated by the model
        for pattern in prediction_output:
            # pattern is a chord
            if ("." in pattern) or pattern.isdigit():
                chord_notes = []
                for current_note in pattern.split("."):
                    new_note = note.Note(int(current_note))
                    new_note.storedInstrument = instrument.Piano()
                    chord_notes.append(new_note)
                new_chord = chord.Chord(chord_notes)
                new_chord.offset = offset
                output_notes.append(new_chord)
            # pattern is a note
            else:
                new_note = note.Note(pattern)
                new_note.offset = offset
                new_note.storedInstrument = instrument.Piano()
                output_notes.append(new_note)

            # increase offset each iteration so that notes do not stack
            offset += 0.5

        midi_stream = stream.Stream(output_notes)

        midi_stream.write("midi", fp=filename)


In [150]:
# One note per amplitude: 2 ** n_note_encoding notes in every sequence.
seq_length = pow(2, n_note_encoding)
print("Initialized Midi")
midi = Midi(seq_length=seq_length, device=running_dev)

Initialized Midi
torch.Size([44792, 64])
Input shape: torch.Size([44792, 64])
Output shape: torch.Size([44792])


In [182]:
# Inspect one training sample. In a notebook cell only the value of the last
# expression is displayed, so the norm lookup is evaluated but not shown.
midi.input_norms[120]
midi.network_input[120]

tensor([0.1702, 0.0974, 0.1533, 0.0974, 0.0890, 0.1832, 0.0890, 0.1702, 0.0890,
        0.1533, 0.0890, 0.1910, 0.0974, 0.1702, 0.0974, 0.1832, 0.0974, 0.1747,
        0.0974, 0.1832, 0.0890, 0.1793, 0.0890, 0.1618, 0.0890, 0.0890, 0.1533,
        0.0974, 0.0974, 0.0974, 0.0974, 0.0890, 0.0890, 0.0890, 0.0890, 0.0974,
        0.0974, 0.0974, 0.0974, 0.1611, 0.0890, 0.1527, 0.0890, 0.1832, 0.0890,
        0.0890, 0.1793, 0.0974, 0.0974, 0.0974, 0.0974, 0.0890, 0.0890, 0.0890,
        0.0890, 0.0974, 0.0974, 0.0974, 0.0974, 0.1832, 0.0767, 0.1793, 0.1708,
        0.1832], dtype=torch.float64)

In [152]:
def real_music(notes):
    """Apply an AmplitudeEmbedding of ``notes`` to the ``encoding_range`` wires.

    The template is called with ``normalize=True``.
    """
    AmplitudeEmbedding(notes, wires=encoding_range, normalize=True)

def music_generator(weights):
    """Apply BasicEntanglerLayers with ``weights`` to the ``encoding_range`` wires."""
    BasicEntanglerLayers(weights=weights, wires=encoding_range)

def discriminator(weights):
    """Apply StronglyEntanglingLayers with ``weights`` to every wire in ``wires_range``."""
    StronglyEntanglingLayers(weights=weights, wires=wires_range)

def measurement():
    """Return the expectation value of Pauli-Z taken jointly over all wires.

    Builds ``PauliZ(0) @ PauliZ(1) @ ... @ PauliZ(n_wires - 1)`` and wraps
    the product in ``qml.expval``.
    """
    joint_z = qml.PauliZ(0)
    for wire in range(1, n_wires):
        next_z = qml.PauliZ(wire)
        joint_z = joint_z @ next_z
    return qml.expval(joint_z)

In [153]:
@qml.qnode(dev, interface="torch")
def real_music_discriminator(inputs, weights):
    """QNode: embed real notes, run the discriminator, measure.

    Args:
        inputs: Features passed to ``real_music``.
        weights: Weights passed to ``discriminator``.

    Returns:
        The value produced by ``measurement()``.
    """
    real_music(notes=inputs)
    discriminator(weights=weights)
    return measurement()

@qml.qnode(dev, interface="torch")
def generated_music_discriminator(inputs, note_weights, weights):
    """QNode: run the generator, then the discriminator, then measure.

    Args:
        inputs: Not used by the circuit body.
            NOTE(review): presumably kept because TorchLayer expects an
            ``inputs`` argument on the QNode - confirm.
        note_weights: Weights passed to ``music_generator``.
        weights: Weights passed to ``discriminator``.

    Returns:
        The value produced by ``measurement()``.
    """
    music_generator(weights=note_weights)
    discriminator(weights=weights)
    return measurement()

In [154]:
n_variational_layers = 3

# Weight shapes handed to TorchLayer, keyed by the QNode argument they fill.
real_shapes = dict(weights=(n_variational_layers, n_wires, 3))

# Torch layer wrapping the circuit that sees real music.
real_layer = qml.qnn.TorchLayer(real_music_discriminator, real_shapes)
real_layer = real_layer.to(running_dev)

# Same discriminator shape, plus the generator's own weights.
generated_shapes = dict(
    weights=(n_variational_layers, n_wires, 3),
    note_weights=(n_variational_layers, n_note_encoding),
)

# Torch layer wrapping the circuit that sees generated music.
generated_layer = qml.qnn.TorchLayer(generated_music_discriminator, generated_shapes)
generated_layer = generated_layer.to(running_dev)

In [155]:
def sync_weights(source_layer, target_layer):
    """Copy the ``weights`` values of ``source_layer`` into ``target_layer``.

    The values are copied in place rather than rebinding
    ``target_layer.weights``, so every existing reference to the target's
    parameter tensor sees the new values.

    NOTE(review): PennyLane's TorchLayer appears to evaluate its QNode from
    an internal ``qnode_weights`` dict that attribute rebinding would leave
    untouched; that dict entry is updated too when present - confirm against
    the installed PennyLane version.

    Args:
        source_layer: Layer whose ``weights`` are read.
        target_layer: Layer whose ``weights`` are overwritten.
    """
    with torch.no_grad():
        values = source_layer.weights.detach().clone()

        destinations = [target_layer.weights]
        internal = getattr(target_layer, "qnode_weights", None)
        if isinstance(internal, dict) and "weights" in internal:
            destinations.append(internal["weights"])

        for destination in destinations:
            # If an earlier rebind made the target alias the source
            # parameter there is nothing to copy for that reference.
            if destination is not source_layer.weights:
                destination.copy_(values)

In [175]:
# Start both circuits from the same discriminator weights, going through the
# shared helper instead of repeating its logic inline.
sync_weights(real_layer, generated_layer)

In [157]:
def prob_fun_disc_true(layer):
    """Wrap ``layer`` so that its output is read as a probability.

    Returns:
        A function that evaluates ``layer`` on its input and applies
        ``(x + 1) / 2``, which maps the interval [-1, 1] onto [0, 1].
    """
    def prob_true(inputs):
        # Shift and rescale the layer output in a single expression.
        return (layer(inputs) + 1) / 2

    return prob_true

In [168]:
# Probability readers for the two circuits.
prob_real_true = prob_fun_disc_true(layer=real_layer)
prob_gen_true = prob_fun_disc_true(layer=generated_layer)

# Placeholder input for the generated circuit: a single float64 zero.
empty_input = torch.zeros(1, dtype=torch.float64, device=running_dev)

def disc_cost(inputs):
    """Discriminator loss for ``inputs``.

    Returns ``prob_gen_true(empty_input) - prob_real_true(inputs)``.
    """
    fake_term = prob_gen_true(empty_input)
    real_term = prob_real_true(inputs)
    return fake_term - real_term

def gen_cost():
    """Generator loss: the negated ``prob_gen_true(empty_input)``."""
    fake_term = prob_gen_true(empty_input)
    return -fake_term

In [203]:
def gen_batch_inputs(batch_size=1):
    """Return ``batch_size`` rows of ``midi.network_input`` at random indices.

    Indices are drawn independently with ``np.random.randint``, so the same
    row can appear more than once.
    """
    n_samples = len(midi.network_input)
    row_ids = np.random.randint(0, n_samples, size=batch_size)
    return midi.network_input[row_ids]

In [204]:
def discriminator_iteration(n_iterations, batch_size, learning_rate):
    """Run Adam steps that minimize ``disc_cost`` over ``real_layer``.

    Args:
        n_iterations: Number of optimizer steps.
        batch_size: Number of random samples drawn for every step.
        learning_rate: Adam learning rate.

    NOTE(review): the optimizer only holds ``real_layer``'s parameters, so
    the generated-sample term of ``disc_cost`` influences the update only if
    ``generated_layer`` evaluates with those same parameter tensors - confirm.
    """
    opt = torch.optim.Adam(real_layer.parameters(), lr=learning_rate)
    # Reference value for the progress messages; detached so that it does
    # not keep an autograd graph alive.
    best_cost = disc_cost(midi.network_input[0]).detach()

    for _ in range(n_iterations):
        opt.zero_grad()
        # Sample a batch of data; it is input data, not something to
        # differentiate with respect to.
        batch_inputs = gen_batch_inputs(batch_size).detach()
        # Average over the batch so the loss is a scalar for any batch size.
        loss = disc_cost(batch_inputs).mean()
        # Backpropagate the loss
        loss.backward()
        # Update the weights
        opt.step()
        # Update the best cost
        if loss < best_cost:
            best_cost = loss.detach()
            print("New best Discriminator cost:", best_cost)

In [205]:
def generator_iteration(n_iterations, learning_rate):
    """Run Adam steps that minimize ``gen_cost`` over the generator weights.

    Only ``generated_layer.note_weights`` is optimized. The layer also
    carries the discriminator's ``weights``, and a generator step must not
    move those.

    Args:
        n_iterations: Number of optimizer steps.
        learning_rate: Adam learning rate.
    """
    opt = torch.optim.Adam([generated_layer.note_weights], lr=learning_rate)
    # Reference value for the progress messages; detached so that it does
    # not keep an autograd graph alive.
    best_cost = gen_cost().detach()

    for _ in range(n_iterations):
        # Clear every gradient on the layer, including the discriminator
        # weights that still receive gradients during backward().
        generated_layer.zero_grad()
        # Compute the loss
        loss = gen_cost()
        # Backpropagate the loss
        loss.backward()
        # Update the weights
        opt.step()
        # Update the best cost
        if loss < best_cost:
            best_cost = loss.detach()
            print("New best Generator cost:", best_cost)

In [207]:
# The real iteration: alternate discriminator and generator training.
steps, n_iterations = 100, 100
learning_rate = 0.01
batch_size = 3

for _ in range(steps):
    discriminator_iteration(
        n_iterations=n_iterations,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )
    # Hand the freshly trained discriminator weights to the generated circuit.
    sync_weights(source_layer=real_layer, target_layer=generated_layer)
    generator_iteration(n_iterations=n_iterations, learning_rate=learning_rate)

In [171]:
steps = 100
learning_rate = 0.01

# Training the discriminator on a single fixed sample.
opt = torch.optim.Adam(real_layer.parameters(), lr=learning_rate)
best_cost = disc_cost(midi.network_input[0])

for i in range(steps):
    opt.zero_grad()
    cost = disc_cost(midi.network_input[0])
    cost.backward()
    opt.step()
    # Keep the lowest cost seen so far.
    best_cost = cost if cost < best_cost else best_cost

    # Report every tenth step and on the final step.
    if (i + 1) % 10 == 0 or (i + 1) == steps:
        print("Step {}: Cost = {}".format(i + 1, cost))

Step 10: Cost = -0.21044705564375155
Step 20: Cost = -0.2956607206036578
Step 30: Cost = -0.3632238684045632
Step 40: Cost = -0.4055262398664411
Step 50: Cost = -0.4272241644579412
Step 60: Cost = -0.4369314413286276
Step 70: Cost = -0.4408592439583635


KeyboardInterrupt: 

In [17]:
# NOTE(review): `circuit` is not defined anywhere in the visible source; this
# looks like a leftover from an earlier version of the notebook - confirm
# before running. The call passes the first training row reshaped to 64 values.
circuit(midi.network_input[0].reshape(64))

tensor([-0.24559963, -0.07313006, -0.12988893, -0.05583117,  0.03701939,
         0.05884297], requires_grad=True)

In [None]:
# NOTE(review): `tf` is never imported in the visible source, so this line
# fails as written; it also rebinds the `opt` name used by the torch training
# cell above. Presumably a leftover from a TensorFlow version - confirm.
opt = tf.keras.optimizers.SGD(0.4)

In [None]:
def disc_iteration():
  """Minimize the discriminator cost for 50 steps, then print probabilities.

  Uses the module-level ``opt`` and prints the cost every fifth step.

  NOTE(review): ``disc_weights``, ``gen_weights`` and ``prob_fake_true`` are
  not defined in the visible source, and ``disc_cost`` / ``prob_real_true``
  are called here with weights, while the definitions earlier in this file
  take circuit inputs. This looks like leftover TensorFlow-era code -
  confirm before running.
  """
  # Zero-argument closure, the form `opt.minimize` is given below.
  cost = lambda: disc_cost(disc_weights)

  print("####### Minimizing discriminator cost #######")

  for step in range(50):
    opt.minimize(cost, disc_weights)
    
    # Report progress every fifth step.
    if step % 5 == 0:
      cost_val = cost().numpy()
      print("Step {}: cost = {}".format(step, cost_val))

  print("####### Finished minimizing discriminator cost #######")

  print("Prob(real classified as real): ", prob_real_true(disc_weights).numpy())
  print("Prob(fake classified as real): ", prob_fake_true(gen_weights, disc_weights).numpy())

In [None]:
def gen_iteration():
  """Minimize the generator cost for 50 steps, then print a probability.

  Uses the module-level ``opt`` and prints the cost every fifth step.

  NOTE(review): ``gen_weights``, ``disc_weights`` and ``prob_fake_true`` are
  not defined in the visible source, and the ``gen_cost`` defined earlier in
  this file takes no arguments, so ``gen_cost(gen_weights)`` would raise a
  TypeError with that definition. This looks like leftover TensorFlow-era
  code - confirm before running.
  """
  # Zero-argument closure, the form `opt.minimize` is given below.
  cost = lambda: gen_cost(gen_weights)

  print("####### Minimizing generator cost #######")

  for step in range(50):
    opt.minimize(cost, gen_weights)
    # Report progress every fifth step.
    if step % 5 == 0:
      cost_val = cost().numpy()
      print("Step {}: cost = {}".format(step, cost_val))

  print("####### Finished minimizing generator cost #######")

  print("Prob(fake classified as real): ", prob_fake_true(gen_weights, disc_weights).numpy())

In [None]:
def compare_data():
  """Print the X/Y/Z expectation values on wire 0 for two circuits.

  NOTE(review): ``real_data``, ``generator``, ``phi``, ``theta``, ``omega``
  and ``gen_weights`` are not defined in the visible source, and the
  ``interface="tf"`` argument does not match the torch interface used by the
  QNodes earlier in this file. Presumably leftover code from a single-qubit
  example - confirm before running.
  """

  # Pauli observables on wire 0.
  obs = [qml.PauliX(0), qml.PauliY(0), qml.PauliZ(0)]

  bloch_vector_real = qml.map(real_data, obs, dev, interface="tf")
  bloch_vector_generator = qml.map(generator, obs, dev, interface="tf")

  print("Real Bloch vector: {}".format(bloch_vector_real([phi, theta, omega])))
  print("Generator Bloch vector: {}".format(bloch_vector_generator(gen_weights)))

In [None]:
# The training loop
# Five rounds of: train the discriminator, train the generator, compare.
# NOTE(review): the three functions called here depend on names that are not
# defined in the visible source - confirm before running.

for i in range(5):
  disc_iteration()
  gen_iteration()
  compare_data()