In [9]:
import os

import yaml
import torch
import numpy as np
from music21 import chord, note, stream, tempo, duration
from lightning.data import MusicDataWrapper, MusicDataWrapperCNN
from lightning.rnn import LitAttentionRNN
from lightning.seq2seq import LitSeq2Seq
from lightning.vae import LitVAE
from lightning.seg import LitMusicSeg
from lightning.gan import MuseGAN

def show_and_write(notes, durs, fn):
    """Build a music21 stream from parallel note/duration lists and save it as MIDI.

    Despite the name, nothing is displayed; the stream is only written to disk.

    Args:
        notes: Iterable of note tokens. ``"START"`` entries are skipped, a token
            containing ``"."`` is split on the dots and turned into a chord, and
            any other token is passed to ``note.Note``.
        durs: Iterable of durations paired element-wise with ``notes``. A falsy
            duration (e.g. ``0``) causes the pair to be skipped. Presumably
            quarter-length values -- confirm against the tokenizer.
        fn: Output path handed to ``Stream.write("midi", ...)``.

    Returns:
        Whatever ``Stream.write`` returns for the written file.
    """
    # Create the target directory up front so a fresh checkout does not fail
    # on a missing "examples/" folder.
    out_dir = os.path.dirname(fn)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    score = stream.Stream()
    for token, dur in zip(notes, durs):
        if token == "START" or not dur:
            continue

        if "." in token:
            element = chord.Chord(token.split("."))
        else:
            element = note.Note(token)

        # Previously `dur` was only used as a skip flag, so every element kept
        # music21's default length and the generated rhythm was lost.
        element.duration = duration.Duration(dur)
        score.append(element)

    return score.write("midi", fn)

def get_music_from_tokens(tokenizer_ds, token_notes, token_durs):
    """Map note and duration token ids back to their vocabulary entries.

    Args:
        tokenizer_ds: Object exposing ``tokens_to_notes`` and
            ``tokens_to_durations`` lookups indexed by token id.
        token_notes: Iterable of note tokens; each element must provide
            ``.item()`` returning the lookup key.
        token_durs: Iterable of duration tokens, paired element-wise with
            ``token_notes`` (iteration stops at the shorter of the two).

    Returns:
        A ``(notes, durs)`` tuple of two independent lists.
    """
    # Two separate lists: `notes = durs = []` would bind both names to the same
    # list object, interleaving notes and durations in a single sequence.
    notes, durs = [], []
    for tok_note, tok_dur in zip(token_notes, token_durs):
        notes.append(tokenizer_ds.tokens_to_notes[tok_note.item()])
        durs.append(tokenizer_ds.tokens_to_durations[tok_dur.item()])
    return notes, durs


def notes_to_midi(midi_note_score, filename=None, n_parts=4, quarter_length=1/4, bpm=66):
    """Convert a (steps, parts) grid of MIDI pitches into a multi-part MIDI file.

    Column ``i`` of ``midi_note_score`` becomes part ``i`` of the score. Every
    cell produces one element of length ``quarter_length``: a rest when the
    value is NaN, otherwise a note whose pitch is ``int(value)``.

    Args:
        midi_note_score: 2-D array-like indexable as ``midi_note_score[:, i]``.
        filename: Output path passed to ``Score.write`` as ``fp``; ``None`` is
            forwarded unchanged.
        n_parts: Number of leading columns to export (default 4, as before).
        quarter_length: Length of every note/rest in quarter notes
            (default 1/4, as before).
        bpm: Tempo written as the score's metronome mark (default 66, as before).

    Returns:
        Whatever ``Score.write`` returns for the written file.
    """
    score = stream.Score()
    score.append(tempo.MetronomeMark(number=bpm))

    for part_idx in range(n_parts):
        part = stream.Part()
        for pitch in midi_note_score[:, part_idx]:
            if np.isnan(pitch):
                # Pass the length by keyword: the positional form is not
                # interpreted as a length by every music21 release.
                element = note.Rest(quarterLength=quarter_length)
            else:
                element = note.Note(int(pitch))
                element.duration = duration.Duration(quarter_length)
            part.append(element)

        score.append(part)

    # Create the target directory so writing into a new sub-folder succeeds.
    if filename is not None:
        out_dir = os.path.dirname(filename)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    return score.write('midi', fp=filename)


# Attention RNN

In [11]:
# Load the attention-RNN configuration.
# NOTE(review): FullLoader can build more than plain data types; only use it on trusted config files.
with open("configs/config_rnn.yaml", "r") as f:
    config_rnn = yaml.load(f, Loader=yaml.FullLoader)

# Forward slashes resolve on every OS. The previous "checkpoints\model-..." literal
# contained the invalid escape sequence "\m" and only worked as a Windows path.
checkpoint_path = "checkpoints/model-epoch=09-val_loss=2.29.ckpt"
dm_rnn = MusicDataWrapper(config_rnn)

# Freshly constructed model, used as the untrained baseline.
lit_rnn = LitAttentionRNN(config_rnn, dm_rnn.num_notes_classes, dm_rnn.num_duration_classes)
# Same architecture with weights restored from the checkpoint.
lit_rnn_trained = LitAttentionRNN.load_from_checkpoint(
    checkpoint_path,
    config=config_rnn,
    input_note_size=dm_rnn.num_notes_classes,
    input_dur_size=dm_rnn.num_duration_classes,
)

In [12]:
# Baseline sample from the freshly constructed (untrained) RNN, saved as MIDI.
generated_rnn = lit_rnn.generate(dm_rnn.dataset)
notes, durs = generated_rnn
show_and_write(notes, durs, "examples/untrained_rnn.midi")

In [13]:
# Sample from the checkpoint-restored RNN and save it next to the baseline.
generated_rnn_trained = lit_rnn_trained.generate(dm_rnn.dataset)
notes, durs = generated_rnn_trained
show_and_write(notes, durs, "examples/trained_rnn.midi")

# Encoder-Decoder model

In [18]:
# Load the encoder-decoder configuration.
# NOTE(review): FullLoader can build more than plain data types; only use it on trusted config files.
with open("configs/config_seq2seq.yaml", "r") as f:
    config_seq2seq = yaml.load(f, Loader=yaml.FullLoader)

# Forward slashes resolve on every OS. The previous backslash-separated literal
# contained the invalid escape sequence "\m" and only worked as a Windows path.
checkpoint_path = "checkpoints_seq2seq/model-epoch=06-val_loss=100.53.ckpt"
dm_seq2seq = MusicDataWrapper(config_seq2seq)

# Freshly constructed model, used as the untrained baseline.
lit_seq2seq = LitSeq2Seq(config_seq2seq, dm_seq2seq.num_notes_classes, dm_seq2seq.num_duration_classes)
# Same architecture with weights restored from the checkpoint.
lit_seq2seq_trained = LitSeq2Seq.load_from_checkpoint(
    checkpoint_path,
    config=config_seq2seq,
    input_note_size=dm_seq2seq.num_notes_classes,
    input_dur_size=dm_seq2seq.num_duration_classes,
)
# Last expression of the cell: switches to eval mode and displays the module summary.
lit_seq2seq_trained.eval()

LitSeq2Seq(
  (encoder): Encoder(
    (note_emb): Embedding(471, 100)
    (dur_emb): Embedding(18, 100)
    (rnn): GRU(200, 256, num_layers=2, batch_first=True)
    (fc_notes): Linear(in_features=256, out_features=471, bias=True)
    (fc_durs): Linear(in_features=256, out_features=18, bias=True)
  )
  (decoder): Decoder(
    (note_emb): Embedding(471, 100)
    (dur_emb): Embedding(18, 100)
    (rnn): GRU(200, 256, num_layers=2, batch_first=True)
    (fc_notes): Linear(in_features=256, out_features=471, bias=True)
    (fc_durs): Linear(in_features=256, out_features=18, bias=True)
  )
  (model): Seq2Seq(
    (encoder): Encoder(
      (note_emb): Embedding(471, 100)
      (dur_emb): Embedding(18, 100)
      (rnn): GRU(200, 256, num_layers=2, batch_first=True)
      (fc_notes): Linear(in_features=256, out_features=471, bias=True)
      (fc_durs): Linear(in_features=256, out_features=18, bias=True)
    )
    (decoder): Decoder(
      (note_emb): Embedding(471, 100)
      (dur_emb): Embeddin

In [16]:
# Seed the untrained encoder-decoder with a single START token paired with duration 0.
start_seq = (["START"], [0])
notes, durs = lit_seq2seq.generate(dm_seq2seq.dataset, start_seq=start_seq)
show_and_write(notes, durs, "examples/untrained_seq2seq.midi")

In [21]:
# Seed the trained encoder-decoder with 15 START tokens, each paired with duration 0.
seed_len = 15
start_seq = (["START"] * seed_len, [0] * seed_len)
notes, durs = lit_seq2seq_trained.generate(dm_seq2seq.dataset, start_seq=start_seq)
show_and_write(notes, durs, "examples/trained_seq2seq.midi")

# Convolutional VAE (MSE loss)

In [22]:
# Load the VAE configuration.
# NOTE(review): FullLoader can build more than plain data types; only use it on trusted config files.
with open("configs/config_vae.yaml", "r") as f:
    config_vae = yaml.load(f, Loader=yaml.FullLoader)

# Forward slashes resolve on every OS. The previous backslash-separated literal
# contained the invalid escape sequence "\m" and only worked as a Windows path.
checkpoint_path = "checkpoints_vae/model-epoch=29-valid_loss=714.52.ckpt"
dm_vae = MusicDataWrapperCNN(config_vae)

# Freshly constructed model (untrained baseline) and the checkpoint-restored one.
lit_vae = LitVAE(config_vae)
lit_vae_trained = LitVAE.load_from_checkpoint(checkpoint_path, config=config_vae)

# Put both models in eval mode; the last call also displays the module summary.
lit_vae.eval()
lit_vae_trained.eval()

LitVAE(
  (encoder): Encoder(
    (enc_conv1): Conv2d(4, 64, kernel_size=(4, 4), stride=(4, 4))
    (enc_conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(4, 4))
    (enc_conv3): Conv2d(128, 256, kernel_size=(5, 2), stride=(1, 1))
    (enc_batch1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc_batch2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc_batch3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc_lin): Linear(in_features=256, out_features=256, bias=True)
    (enc_mu): Linear(in_features=256, out_features=64, bias=True)
    (enc_sigma): Linear(in_features=256, out_features=64, bias=True)
    (dropout): Dropout(p=0.4, inplace=False)
  )
  (decoder): Decoder(
    (dec_lin): Linear(in_features=64, out_features=256, bias=True)
    (dec_conv1): ConvTranspose2d(256, 128, kernel_size=(5, 2), stride=(1, 1))
    (dec_conv2): ConvTranspose2d(128, 64, kernel_siz

In [23]:
# Draw 32 latent vectors from the encoder's sampler `N` and decode them with
# the untrained VAE; only the first decoded sample is exported.
latent_shape = torch.Size([32, config_vae["model"]["latent_size"]])
z = lit_vae.model.encoder.N.sample(latent_shape)
rec = lit_vae.generate(z)
notes_to_midi(rec[0], "examples/untrained_vae-mse.midi")

In [24]:
# Draw 32 latent vectors from the trained encoder's sampler `N` and decode
# them; only the first decoded sample is exported.
latent_shape = torch.Size([32, config_vae["model"]["latent_size"]])
z = lit_vae_trained.model.encoder.N.sample(latent_shape)
rec = lit_vae_trained.generate(z)
notes_to_midi(rec[0], "examples/trained_vae-mse.midi")

# Convolutional VAE (cross-entropy loss)

In [25]:
# Load the VAE configuration again for the cross-entropy checkpoint.
# NOTE(review): FullLoader can build more than plain data types; only use it on trusted config files.
# NOTE(review): this cell rebinds config_vae / dm_vae / lit_vae / lit_vae_trained
# from the previous section, so the cells below must be run after this one.
with open("configs/config_vae.yaml", "r") as f:
    config_vae = yaml.load(f, Loader=yaml.FullLoader)

# Forward slashes resolve on every OS. The previous backslash-separated literal
# contained the invalid escape sequence "\m" and only worked as a Windows path.
checkpoint_path = "checkpoints_vae_cross_entropy/model-epoch=29-valid_loss=718.03.ckpt"
dm_vae = MusicDataWrapperCNN(config_vae)

# Freshly constructed model (untrained baseline) and the checkpoint-restored one.
lit_vae = LitVAE(config_vae)
lit_vae_trained = LitVAE.load_from_checkpoint(checkpoint_path, config=config_vae)

# Put both models in eval mode; the last call also displays the module summary.
lit_vae.eval()
lit_vae_trained.eval()

LitVAE(
  (encoder): Encoder(
    (enc_conv1): Conv2d(4, 64, kernel_size=(4, 4), stride=(4, 4))
    (enc_conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(4, 4))
    (enc_conv3): Conv2d(128, 256, kernel_size=(5, 2), stride=(1, 1))
    (enc_batch1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc_batch2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc_batch3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc_lin): Linear(in_features=256, out_features=256, bias=True)
    (enc_mu): Linear(in_features=256, out_features=64, bias=True)
    (enc_sigma): Linear(in_features=256, out_features=64, bias=True)
    (dropout): Dropout(p=0.4, inplace=False)
  )
  (decoder): Decoder(
    (dec_lin): Linear(in_features=64, out_features=256, bias=True)
    (dec_conv1): ConvTranspose2d(256, 128, kernel_size=(5, 2), stride=(1, 1))
    (dec_conv2): ConvTranspose2d(128, 64, kernel_siz

In [26]:
# Draw 32 latent vectors from the encoder's sampler `N` and decode them with
# the untrained VAE; only the first decoded sample is exported.
latent_shape = torch.Size([32, config_vae["model"]["latent_size"]])
z = lit_vae.model.encoder.N.sample(latent_shape)
rec = lit_vae.generate(z)
notes_to_midi(rec[0], "examples/untrained_vae-cross-entropy.midi")

In [27]:
# Draw 32 latent vectors from the trained encoder's sampler `N` and decode
# them; only the first decoded sample is exported.
latent_shape = torch.Size([32, config_vae["model"]["latent_size"]])
z = lit_vae_trained.model.encoder.N.sample(latent_shape)
rec = lit_vae_trained.generate(z)
notes_to_midi(rec[0], "examples/trained_vae-cross-entropy.midi")

# Convolutional U-Net-like model (MSE loss)

In [29]:
# Load the segmentation-model configuration.
# NOTE(review): FullLoader can build more than plain data types; only use it on trusted config files.
with open("configs/config_seg.yaml", "r") as f:
    config_seg = yaml.load(f, Loader=yaml.FullLoader)

# Forward slashes resolve on every OS. The previous backslash-separated literal
# contained the invalid escape sequence "\m" and only worked as a Windows path.
checkpoint_path = "checkpoints_seg_mse/model-epoch=14-valid_loss=0.01-v1.ckpt"
dm_seg = MusicDataWrapperCNN(config_seg)

# Freshly constructed model (untrained baseline) and the checkpoint-restored one.
lit_seg = LitMusicSeg(config_seg)
lit_seg_trained = LitMusicSeg.load_from_checkpoint(checkpoint_path, config=config_seg)

# Put both models in eval mode; the last call also displays the module summary.
lit_seg.eval()
lit_seg_trained.eval()

LitMusicSeg(
  (model): MusicSeg(
    (enc_conv1): Conv2d(4, 64, kernel_size=(4, 4), stride=(4, 4))
    (enc_conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(4, 4))
    (enc_conv3): Conv2d(128, 256, kernel_size=(5, 2), stride=(1, 1))
    (enc_batch1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc_batch2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc_batch3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout): Dropout(p=0.4, inplace=False)
    (dec_conv1): ConvTranspose2d(256, 128, kernel_size=(5, 2), stride=(1, 1))
    (dec_conv2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(4, 4), output_padding=(1, 0))
    (dec_conv3): ConvTranspose2d(64, 4, kernel_size=(3, 4), stride=(4, 4))
    (dec_batch3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec_batch2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True

In [31]:
# Run the untrained model on the inputs of the first validation batch and
# export the first generated sample.
val_loader = dm_seg.val_dataloader()
data = next(iter(val_loader))
gen = lit_seg.generate(data[0])
notes_to_midi(gen[0], "examples/untrained_seg_mse.midi")

In [33]:
# First validation batch, unpacked as (x, y); x is fed to the trained model.
val_loader = dm_seg.val_dataloader()
data = next(iter(val_loader))
x, y = data
gen = lit_seg_trained.generate(x)
notes_to_midi(gen[0], "examples/seg_mse/pred.midi")

# Reduce x and y to index grids: permute the axes to (0, 3, 1, 2), then take
# the argmax over the last axis (presumably a one-hot pitch axis -- confirm).
x, y = (torch.argmax(batch.permute(0, 3, 1, 2), dim=3) for batch in (x, y))

notes_to_midi(y[0].numpy(), "examples/seg_mse/orig.midi")
notes_to_midi(x[0].numpy(), "examples/seg_mse/start.midi")

# Convolutional U-Net-like model (cross-entropy loss)

In [34]:
# Load the segmentation-model configuration again for the cross-entropy checkpoint.
# NOTE(review): FullLoader can build more than plain data types; only use it on trusted config files.
# NOTE(review): this cell rebinds config_seg / dm_seg / lit_seg / lit_seg_trained
# from the previous section, so the cells below must be run after this one.
with open("configs/config_seg.yaml", "r") as f:
    config_seg = yaml.load(f, Loader=yaml.FullLoader)

# Forward slashes resolve on every OS. The previous backslash-separated literal
# contained the invalid escape sequence "\m" and only worked as a Windows path.
checkpoint_path = "checkpoints_seg_cross-entropy/model-epoch=05-valid_loss=3.43.ckpt"
dm_seg = MusicDataWrapperCNN(config_seg)

# Freshly constructed model (untrained baseline) and the checkpoint-restored one.
lit_seg = LitMusicSeg(config_seg)
lit_seg_trained = LitMusicSeg.load_from_checkpoint(checkpoint_path, config=config_seg)

# Put both models in eval mode; the last call also displays the module summary.
lit_seg.eval()
lit_seg_trained.eval()

LitMusicSeg(
  (model): MusicSeg(
    (enc_conv1): Conv2d(4, 64, kernel_size=(4, 4), stride=(4, 4))
    (enc_conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(4, 4))
    (enc_conv3): Conv2d(128, 256, kernel_size=(5, 2), stride=(1, 1))
    (enc_batch1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc_batch2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc_batch3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout): Dropout(p=0.4, inplace=False)
    (dec_conv1): ConvTranspose2d(256, 128, kernel_size=(5, 2), stride=(1, 1))
    (dec_conv2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(4, 4), output_padding=(1, 0))
    (dec_conv3): ConvTranspose2d(64, 4, kernel_size=(3, 4), stride=(4, 4))
    (dec_batch3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dec_batch2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True

In [35]:
# Run the untrained model on the inputs of the first validation batch and
# export the first generated sample.
val_loader = dm_seg.val_dataloader()
data = next(iter(val_loader))
gen = lit_seg.generate(data[0])
notes_to_midi(gen[0], "examples/untrained_seg_cross_entropy.midi")

In [36]:
# First validation batch, unpacked as (x, y); x is fed to the trained model.
val_loader = dm_seg.val_dataloader()
data = next(iter(val_loader))
x, y = data
gen = lit_seg_trained.generate(x)
notes_to_midi(gen[0], "examples/seg_cross_entropy/pred.midi")

# Reduce x and y to index grids: permute the axes to (0, 3, 1, 2), then take
# the argmax over the last axis (presumably a one-hot pitch axis -- confirm).
x, y = (torch.argmax(batch.permute(0, 3, 1, 2), dim=3) for batch in (x, y))

notes_to_midi(y[0].numpy(), "examples/seg_cross_entropy/orig.midi")
notes_to_midi(x[0].numpy(), "examples/seg_cross_entropy/start.midi")

# GAN

In [59]:
# Load the GAN configuration.
# NOTE(review): FullLoader can build more than plain data types; only use it on trusted config files.
with open("configs/config_gan.yaml", "r") as f:
    config_gan = yaml.load(f, Loader=yaml.FullLoader)

# Forward slashes resolve on every OS. The previous backslash-separated literal
# contained the invalid escape sequence "\m" and only worked as a Windows path.
checkpoint_path = "checkpoints_gans/model-epoch=299-train_disc_loss=-168.32.ckpt"

# Bound to its own name: this used to be assigned to `dm_seg`, which silently
# replaced the segmentation data module with one built from the GAN config.
dm_gan = MusicDataWrapperCNN(config_gan)

# Freshly constructed model (untrained baseline) and the checkpoint-restored one.
lit_gan = MuseGAN(config_gan)
lit_gan_trained = MuseGAN.load_from_checkpoint(checkpoint_path, config=config_gan)

# Put both models in eval mode; the last call also displays the module summary.
lit_gan.eval()
lit_gan_trained.eval()

MuseGAN(
  (discriminator): Discriminator(
    (conv1): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1))
    (conv2): Conv2d(64, 128, kernel_size=(12, 1), stride=(12, 1), padding=(6, 0))
    (conv3): Conv2d(128, 128, kernel_size=(7, 1), stride=(7, 1), padding=(3, 0))
    (conv4): Conv2d(128, 128, kernel_size=(1, 2), stride=(1, 2), padding=(0, 1))
    (conv5): Conv2d(128, 128, kernel_size=(1, 2), stride=(1, 2), padding=(0, 1))
    (conv6): Conv2d(128, 256, kernel_size=(1, 4), stride=(1, 2), padding=(0, 1))
    (conv7): Conv2d(256, 512, kernel_size=(1, 3), stride=(1, 2))
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (dense): Linear(in_features=512, out_features=1024, bias=True)
    (output): Linear(in_features=1024, out_features=1, bias=True)
  )
  (generator): Generator(
    (unflatten): Unflatten(dim=1, unflattened_size=(512, 1, 1))
    (deconv1): ConvTranspose2d(512, 512, kernel_size=(1, 3), stride=(1, 2))
    (deconv2): ConvTranspose2d(512, 256, kernel_size=(1, 4), stride=(1, 2

In [62]:
# Sample 12 standard-normal latent vectors and run the untrained generator;
# only the first generated sample is exported.
latent_size = config_gan["model"]["latent_size"]
z = torch.randn(12, latent_size)
generated = lit_gan.generate(z)
notes_to_midi(generated[0], "examples/untrained_gan.midi")



In [61]:
# Sample 12 standard-normal latent vectors, run the trained generator and
# export every generated sample to its own numbered MIDI file.
num_samples = 12
latent_size = config_gan["model"]["latent_size"]
z = torch.randn(num_samples, latent_size)
generated = lit_gan_trained.generate(z)
for i in range(num_samples):
    out_path = f"examples/gan299/trained_gan{i}.midi"
    notes_to_midi(generated[i], out_path)

