This is a notebook that implements the paper https://arxiv.org/pdf/1903.07227.pdf in PyTorch. The goal is to generate samples of music, in the form of midi files, that sound like Bach chorales. Each Bach chorale is a piece of music for four voices. These chorales can be encoded in arrays of shape (4, N), where N is the number of 16th notes in the chorale, and a value of 60 (say) at position (i, j) indicates that voice i is singing the pitch 60 at the jth 16th note. These encodings are in the file Jsb16thSeparated.npz.

I split these encodings into two-measure chunks, i.e. arrays of shape (4, 32). After one-hot encoding the entries, they become arrays of shape (4, 32, P), where P is the number of possible pitches.

To train a neural net to generate samples like the training samples, you generate samples which consist of random entries from a chorale plus the location of those entries. The neural net is then trained to predict the rest of the entries. For example, the net might be given the entries of one voice in the chorale and then its job is to predict the rest of the entries. In practice, this works by randomly generating arrays of shape (4, 32) whose entries are 0 and 1. A chorale is multiplied by this array to erase part of its data. Then the partially erased array and the masking array of 0s and 1s are fed through the neural net, which outputs a predicted array of shape (4, 32, P). This output array is compared with the full (4, 32, P) array of the inputted chorale via cross entropy loss, and gradient descent is applied with respect to this loss function. This encourages the network to learn the pitches in the Bach chorale that were erased in the input.

To generate good samples for listening, it helps to repeatedly resample. You start from a completely random chorale in which no notes are frozen, then gradually freeze notes (as if the composer has finally decided that this note is good) and resample the remaining notes, with the frozen notes given to the network as known entries. As you resample, you freeze more and more notes, until all the notes are frozen. At this point the sample has been generated.

In [None]:
# installations needed for in-colab midi playback
# !apt install fluidsynth
# !cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2
# !pip install midi2audio

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
import matplotlib.pyplot as plt
import pandas as pd
import mido
import time
from midi2audio import FluidSynth
from IPython.display import Audio, display
import os

# device = 'cuda:0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
softmax = torch.nn.functional.softmax

base_dir = ''

In [None]:
def play_midi(path):
    """
    Render a midi file to audio and show an in-notebook player for it.

    path is the path to the midi file to be played, relative to base_dir.
    The rendered audio is written to 'test.wav' in the working directory.
    """
    wav_file = 'test.wav'
    # remove any previous rendering before writing a new one
    if os.path.exists(wav_file):
        os.remove(wav_file)
    synth = FluidSynth('/usr/share/soundfonts/default.sf2')
    synth.midi_to_audio(base_dir + path, wav_file)
    display(Audio(wav_file))
    
# play a stored midi file; the path is relative to base_dir
path = '30000midi.mid'
play_midi(path)

In [None]:
# load training data
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, so only use this with a trusted file
data = np.load('Jsb16thSeparated.npz', encoding='bytes', allow_pickle=True)

In [None]:
# augment the training data: every chorale is added once for each of the
# 12 pitch shifts -6, ..., 5, so there's more variation in training data
shifts = range(-6, 6)
all_tracks = []
for x in data.files:
    for y in data[x]:
        all_tracks.extend(y + i for i in shifts)

print(len(all_tracks))

In [None]:
# scan all (transposed) chorales for the highest and lowest pitch that occurs

max_midi_pitch = -np.inf
min_midi_pitch = np.inf
for x in all_tracks:
    top, bottom = x.max(), x.min()
    if top > max_midi_pitch:
        max_midi_pitch = int(top)
    if bottom < min_midi_pitch:
        min_midi_pitch = int(bottom)

print(max_midi_pitch, min_midi_pitch)

In [None]:
# set global variables

I = 4 # number of voices
T = 32 # length of samples (32 = two 4/4 measures)
P = max_midi_pitch - min_midi_pitch +1 # number of different pitches
batch_size=24 # number of samples drawn per training step

In [None]:
# prepare the training dataset by cutting chorales into 2 measure pieces

train_tracks = []

for track in all_tracks:
    # after transposing, the second axis is the one that gets cut into windows of T time steps
    track = track.transpose()
    # visit the non-overlapping windows of T columns; the `+ 1` keeps the final
    # window when the number of columns is an exact multiple of T
    for cut in range(0, track.shape[1] - T + 1, T):
        window = track[:, cut:cut+T]
        # keep only windows in which every entry is > 0 (a nan entry fails this test too)
        if (window > 0).all():
            # store pitches relative to the lowest pitch, so they start at 0
            train_tracks.append(window - min_midi_pitch)

train_tracks = np.array(train_tracks).astype(int)

In [None]:
# show the shape of the array of training samples
print(train_tracks.shape)

In [None]:
# function for converting arrays of shape (T, 4) into midi files
# the input array has entries that are np.nan (representing a rest)
# or an integer between 0 and 127 inclusive

def piano_roll_to_midi(piece):
    """
    Convert a piano roll into a mido.MidiFile.

    piece is an array of shape (T, 4) for some T.
    The (i,j)th entry of the array is the midi pitch of the jth voice
    (soprano, alto, tenor, bass) at time i. It's an integer in range(128), or np.nan for a rest.
    Each row lasts one 16th note; equal pitches in consecutive rows of a voice form one held note.
    Outputs a mido object mid that you can convert to a midi file by calling its .save() method.
    """
    voices = ('soprano', 'alto', 'tenor', 'bass')
    # 480 ticks per beat and each row of the array is a 16th note
    ticks_per_row = 120

    # a trailing row of rests makes sure every sounding note receives its note_off
    piece = np.concatenate([piece, [[np.nan] * len(voices)]], axis=0)

    bpm = 50
    microseconds_per_beat = 60 * 1000000 / bpm

    mid = mido.MidiFile()

    # create a track containing tempo data
    metatrack = mido.MidiTrack()
    metatrack.append(mido.MetaMessage('set_tempo',
                                      tempo=int(microseconds_per_beat), time=0))
    mid.tracks.append(metatrack)

    # create the four voice tracks
    tracks = {}
    for voice in voices:
        tracks[voice] = mido.MidiTrack()
        mid.tracks.append(tracks[voice])
        tracks[voice].append(mido.Message(
            'program_change', program=52, time=0))

    # None stands for "rest"; it is compared with `is not None` so that midi pitch 0
    # is still treated as a note, and np.isnan is only ever applied to array entries
    past_pitches = {voice: None for voice in voices}
    # ticks elapsed in each voice since its last emitted message
    delta_time = {voice: 0 for voice in voices}

    # add notes to the four voice tracks
    for row in piece:
        for voice, raw_pitch in zip(voices, row):
            pitch = None if np.isnan(raw_pitch) else int(raw_pitch)
            previous = past_pitches[voice]
            if pitch != previous:
                if previous is not None:
                    tracks[voice].append(mido.Message('note_off', note=previous,
                                                      velocity=64, time=delta_time[voice]))
                    delta_time[voice] = 0
                if pitch is not None:
                    tracks[voice].append(mido.Message('note_on', note=pitch,
                                                      velocity=64, time=delta_time[voice]))
                    delta_time[voice] = 0
            past_pitches[voice] = pitch
            delta_time[voice] += ticks_per_row

    return mid





class Chorale:
    """
    A class to store and manipulate an array self.arr that stores a chorale.

    self.arr has shape (I, T) with integer entries in range(P); row i holds the pitches
    of voice i (soprano, alto, tenor, bass), stored as midi pitch minus 30.
    self.one_hot is the one-hot encoding of self.arr, of shape (I, T, P).
    """
    def __init__(self, arr, subtract_30=False):
        # arr is an array of shape (4, 32) with values in range(0, 57)
        # pass subtract_30=True when arr holds midi pitches instead (i.e. values that are 30 higher)
        self.arr = arr.copy()
        if subtract_30:
            self.arr -= 30

        # the one_hot representation of the array
        reshaped = self.arr.reshape(-1)
        self.one_hot = np.zeros((I*T, P))
        r = np.arange(I*T)
        self.one_hot[r, reshaped] = 1
        self.one_hot = self.one_hot.reshape(I, T, P)

    def to_image(self):
        # visualize the four tracks as images, one heatmap per voice
        # returns the matplotlib figure and its four axes
        fig, axs = plt.subplots(1, 4)
        names = ('soprano', 'alto', 'tenor', 'bass')
        for ax, name, voice in zip(axs, names, self.one_hot):
            # transpose to (P, T) and flip vertically so that higher pitches are drawn higher up
            ax.imshow(np.flip(voice.transpose(), axis=0), cmap='hot', interpolation='nearest')
            ax.set_title(name)
        fig.set_figheight(5)
        fig.set_figwidth(15)
        return fig, axs

    def play(self, filename='midi_track.mid'):
        # display an in-notebook widget for playing audio
        # saves the midi file as a file named filename in base_dir/midi_files

        midi_arr = self.arr.transpose().copy()
        midi_arr += 30
        midi = piano_roll_to_midi(midi_arr)
        # create the output directory on first use instead of failing inside midi.save
        os.makedirs(base_dir + 'midi_files', exist_ok=True)
        midi.save(base_dir + 'midi_files/' + filename)
        play_midi('midi_files/' + filename)

    def elaborate_on_voices(self, voices, model):
        # voices is a collection of voice indices taken from 0, 1, 2, 3
        # create a mask consisting of the given voices
        # generate a chorale with the same voices as in voices; the other voices come from the model
        # returns the array produced by harmonize
        mask = np.zeros((I, T))
        y = np.random.randint(P, size=(I, T))
        for i in voices:
            mask[i] = 1
            y[i] = self.arr[i].copy()
        return harmonize(y, mask, model)

    def score(self):
        # returns the pair (consonance_score, note_score)
        # consonance_score: number of (time step, pair of voices) combinations whose
        #   pitch difference mod 12 is marked 1 in consonance_dict
        # note_score: number of times a voice changes pitch from one time step to the next
        consonance_dict = {0: 1, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1, 6: 0, 7: 1, 8: 1, 9: 1, 10: 0, 11: 0}
        # use the actual array shape rather than assuming 4 voices and 32 time steps
        num_voices, num_steps = self.arr.shape

        consonance_score = 0
        for k in range(num_steps):
            for i in range(num_voices):
                for j in range(i):
                    consonance_score += consonance_dict[((self.arr[i, k] - self.arr[j, k]) % 12)]

        # compare every column with its predecessor, for all voices at once
        note_score = int(np.count_nonzero(self.arr[:, 1:] != self.arr[:, :-1]))
        return consonance_score, note_score
        
        
        
        
# harmonize a melody
def harmonize(y, C, model):
    """
    Generate an artificial Bach chorale starting from y, keeping the pitches where C==1.

    y is an array of shape (4, 32) of pitches and C is an array of the same shape whose entries are 0 and 1.
    The procedure runs 2*I*T resampling steps. In each step every entry with C==0 is, independently,
    resampled by the model with probability p and kept with probability 1-p; entries with C==1 are always kept.
    p falls linearly from alpha_max to alpha_min over the first eta fraction of the steps and then stays at alpha_min.
    For example, to harmonize the soprano line, let y be random except y[0] contains the soprano line, let C[1:] be 0 and C[0] be 1.
    Returns the array after the last step.
    """
    num_steps = int(2*I*T)
    alpha_max, alpha_min = .999, .001
    eta = 3/4

    model.eval()
    with torch.no_grad():
        x = y
        for step in range(num_steps):
            p = np.maximum(alpha_min, alpha_max - step*(alpha_max-alpha_min)/(eta*num_steps))
            # entries that are kept in this step: the fixed ones plus a random subset
            keep = C.copy()
            keep += np.random.choice(2, size = C.shape, p=[p, 1-p])
            keep[C==1] = 1
            previous = x
            x = model.pred(previous, keep)
            # restore the kept entries; only the others take the freshly sampled values
            x[keep==1] = previous[keep==1]
        return x
    
def generate_random_chorale(model):
    """
    Generate a new sample from scratch: harmonize is called on a uniformly random
    array of pitches with an all-zero C, so no entry is kept fixed.
    """
    shape = (I, T)
    random_start = np.random.randint(P, size=shape).astype(int)
    nothing_fixed = np.zeros(shape).astype(int)
    return harmonize(random_start, nothing_fixed, model)

In [None]:
hidden_size = 32

class Unit(nn.Module):
    """
    Residual block: conv -> batchnorm -> relu -> conv -> batchnorm, after which the
    block's input is added back in and a final relu is applied.
    Input and output both have hidden_size channels.
    """
    def __init__(self):
        super().__init__()
        channels = hidden_size
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        out = self.relu1(self.batchnorm1(self.conv1(x)))
        out = self.batchnorm2(self.conv2(out))
        # residual connection: add the input before the last activation
        return self.relu2(out + x)
    
    

class Net(nn.Module):
    """
    A CNN where you input a starter chorale and a mask
    and it outputs a prediction for the values
    in the starter chorale away from the mask
    that are most like the training data.

    The network is an initial convolution, followed by 16 residual Units and a fully connected layer.
    The I masked voices and the I mask planes are stacked as the 2*I input channels of a (T, P) image.
    """
    NUM_UNITS = 16

    def __init__(self):
        super(Net, self).__init__()
        self.initial_conv = nn.Conv2d(2*I, hidden_size, 3, padding=1)
        self.initial_batchnorm = nn.BatchNorm2d(hidden_size)
        self.initial_relu = nn.ReLU()
        # the units are registered as unit1, ..., unit16 so that the parameter names
        # stay the same as in previously saved state dicts
        for k in range(1, self.NUM_UNITS + 1):
            setattr(self, f'unit{k}', Unit())
        self.affine = nn.Linear(hidden_size*T*P, I*T*P)

    def forward(self, x, C):
        # x is a tensor of shape (N, I, T, P)
        # C is a tensor of 0s and 1s of shape (N, I, T)
        # returns a tensor of shape (N, I, T, P) of unnormalized scores (logits) over the P pitches

        # get the number of batches
        N = x.shape[0]

        # tile the array C into a tensor of shape (N, I, T, P)
        tiled_C = C.view(N, I, T, 1)
        tiled_C = tiled_C.repeat(1, 1, 1, P)

        # mask x and combine it with the mask to produce a tensor of shape (N, 2*I, T, P)
        y = torch.cat((tiled_C*x, tiled_C), dim=1)

        # apply the convolution and relu layers
        y = self.initial_conv(y)
        y = self.initial_batchnorm(y)
        y = self.initial_relu(y)
        for k in range(1, self.NUM_UNITS + 1):
            y = getattr(self, f'unit{k}')(y)

        # reshape before applying the fully connected layer
        y = y.view(N, hidden_size*T*P)
        y = self.affine(y)

        # reshape to (N, I, T, P)
        y = y.view(N, I, T, P)

        return y

    def pred(self, y, C):
        # y is an array of shape (I, T) with integer entries in [0, P)
        # C is an array of shape (I, T) consisting of 0s and 1s
        # the entries of y away from the support of C should be considered 'unknown'
        # returns an array of shape (I, T) of pitches, each sampled from the distribution the net predicts for that entry

        # x is shape (I, T, P) one-hot representation of y
        compressed = y.reshape(-1)
        x = np.zeros((I*T, P))
        r = np.arange(I*T)
        x[r, compressed] = 1
        x = x.reshape(I, T, P)

        # prep x and C for plugging into the model
        x = torch.tensor(x).type(torch.FloatTensor).to(device)
        x = x.view(1, I, T, P)
        C2 = torch.tensor(C).type(torch.FloatTensor).view(1, I, T).to(device)

        # plug x and C2 into the model
        with torch.no_grad():
            out = self.forward(x, C2).view(I, T, P).cpu().numpy()
            out = out.transpose(2, 0, 1) # shape (P, I, T)
            # subtract the maximum over the pitch axis before exponentiating: the softmax
            # is unchanged, but large logits can no longer overflow to inf/nan
            out = out - out.max(axis=0, keepdims=True)
            exp_out = np.exp(out)
            probs = exp_out / exp_out.sum(axis=0) # shape (P, I, T)
            cum_probs = np.cumsum(probs, axis=0) # shape (P, I, T)
            # rounding can leave the last cumulative value slightly below 1; pin it to 1
            # so that the argmax below always finds a pitch instead of defaulting to 0
            cum_probs[-1] = 1.0
            u = np.random.rand(I, T) # shape (I, T)
            # inverse-CDF sampling: the first pitch whose cumulative probability exceeds u
            return np.argmax(cum_probs > u, axis=0)
            
        
            

In [None]:
# create the network and move it to the selected device
model = Net().to(device)

In [None]:
# load the previously trained model (comment this line out to start from randomly initialized weights)
# map_location remaps the stored tensors onto `device`, so weights saved on a GPU also load on a CPU-only machine
model.load_state_dict(torch.load('model1.pt', map_location=device))

In [None]:
 # try out the Chorale class functionality with training samples

# build a Chorale from one training sample, then score, plot and play it
track = train_tracks[18]
chorale = Chorale(track)
scores = chorale.score()
chorale.to_image()
chorale.play()



# # let's try out a chorale generated by the model, elaborating on the bass track of the last example
# print('-------------')
# new_chorale = Chorale(chorale.elaborate_on_voices([3], model))
# new_chorale.to_image()
# new_chorale.play()




In [None]:
# let's try harmonizing a simple melody. It looks like it's random unless you load a previously trained model above

melody = [66, 66, 66, 66, 71, 71, 71, 71, 73, 73, 73, 73, 75, 75, 75, 75,
         76, 76, 75, 73, 71, 71, 75, 75, 73, 73, 70, 70, 71, 71, 71, 71]

# random starting point, with the melody (shifted down by 30) in the first row
y = np.random.randint(P, size=(I, T))
y[0] = np.asarray(melody) - 30

# mask: ones on the first row (kept fixed), zeros on the other three rows (generated)
D0 = np.ones((1, T)).astype(int)
D1 = np.zeros((3, T)).astype(int)
D = np.vstack([D0, D1])

# produce, show and play three independent harmonizations
for _ in range(3):
    harmonization = harmonize(y, D, model)
    chorale = Chorale(harmonization)
    chorale.to_image()
    plt.show()
    chorale.play()

In [None]:
# print the global sizes and the current value of y for inspection
print(f'P: {P}, I: {I}, T: {T}')
print(f'y: {y}')

In [None]:
# let's do some more overfitting investigation
# this sample has a suspiciously compelling bass line

# each row is one time step, holding the midi pitches of the four voices
sample = [[74, 70, 65, 58], [74, 70, 65, 58], [74, 70, 65, 57], [74, 70, 65, 57], 
          [74, 70, 67, 55], [74, 70, 67, 55], [72, 69, 65, 53], [72, 69, 65, 53], 
          [70, 70, 67, 55], [70, 70, 67, 55], [70, 69, 67, 51], [70, 67, 67, 51], 
          [69, 69, 60, 53], [69, 69, 60, 53], [70, 65, 62, 50], [70, 65, 62, 50], 
          [72, 67, 63, 53], [72, 67, 63, 53], [72, 67, 57, 51], [72, 67, 57, 51], 
          [70, 65, 65, 46], [70, 65, 65, 46], [70, 65, 65, 46], [70, 65, 65, 46], 
          [70, 65, 62, 46], [70, 65, 62, 46], [70, 65, 62, 46], [70, 65, 62, 46], 
          [70, 65, 62, 46], [70, 65, 62, 46], [70, 65, 62, 46], [70, 65, 62, 46]]
chorale = Chorale(np.array(sample).transpose(), subtract_30=True)
chorale.play()

# same shift by 30 and transpose as above, so that sample is laid out like the rows of train_tracks
sample = (np.array(sample)-30).transpose()

# first 16 time steps (one measure) of the last voice, the bass
bass_first_measure = sample[3, :16]

# the same slice of every training sample
training_bass_first_measures = train_tracks[:, 3, :16]

# squared Euclidean distance from the sample's bass line to every training bass line
sq_diff = np.power(bass_first_measure - training_bass_first_measures, 2)
distances = np.sum(sq_diff, axis=1)

# indices of the five training samples with the smallest distance
distances_as_series = pd.Series(distances).sort_values()
candidates = list(distances_as_series.index[:5])
print(candidates)

# listen to the closest training samples for comparison
for c in candidates:
    candidate_chorale = Chorale(train_tracks[c])
    candidate_chorale.play()
#     track = train_tracks[c]
#     print((track + 30).transpose().tolist())
    
# verdict: the sample simply noticed something which recurs in the chorales, without copying it directly

In [None]:
# training setup: cross entropy loss over the P pitch classes, Adam optimizer,
# and a list that collects the loss of every training step
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)
losses = []

In [None]:
# per-voice histories of probability heatmaps; store_heatmaps appends one array to each list
soprano_probs = []
alto_probs = []
tenor_probs = []
bass_probs = []

In [None]:
# some helper functions for getting feedback data while training

def pad_number(n):
    """
    prepare numbers for better file storage

    Returns n as a string, left-padded with zeros to at least 5 characters
    (for example 50 -> '00050'). Numbers with more than 5 digits are returned unpadded.
    """
    # zfill pads to exactly the requested width; counting digits with ceil(log10(n))
    # is off by one whenever n is a power of 10 (1, 10, 100, ...) and gave 6 characters
    return str(n).zfill(5)


def return_probs(y, C):
    """
    Plugs (y, C) into model and converts the output (unnormalized scores over the pitches) to probabilities.
    y is an array of shape (I, T) with integer entries in [0, P) and C is an array of 0s and 1s of shape (I, T).
    In the output, of shape (I, T, P), the (i,j,k)th entry is the probability of getting the kth pitch when you sample for the ith voice at time j.
    Puts model into eval mode as a side effect.
    """
    # one-hot encode y into an array of shape (I, T, P)
    compressed = y.reshape(-1)
    x = np.zeros((I*T, P))
    r = np.arange(I*T)
    x[r, compressed] = 1
    x = x.reshape(I, T, P)

    # prep x and C for plugging into the model
    x = torch.tensor(x).type(torch.FloatTensor).to(device)
    x = x.view(1, I, T, P)
    C2 = torch.tensor(C).type(torch.FloatTensor).to(device)

    model.eval()
    with torch.no_grad():
        out = model.forward(x, C2).view(I, T, P).cpu().numpy()
        # softmax over the pitch axis; subtracting the maximum first leaves the result
        # unchanged but keeps np.exp from overflowing to inf/nan on large scores
        out = out - out.max(axis=2, keepdims=True)
        exp_out = np.exp(out)
        return exp_out / exp_out.sum(axis=2, keepdims=True)
    
def store_heatmaps(x, C):
    """
    Record the model's current pitch probabilities for the input (x, C).

    return_probs(x, C) gives an array of shape (I, T, P) of probabilities, i.e. one (T, P) array per voice
    (soprano, alto, tenor, bass). Each of them is transposed to (P, T), so it can be shown as a heatmap,
    and appended to the matching list soprano_probs, alto_probs, tenor_probs, bass_probs.
    """
    model.eval()
    with torch.no_grad():
        probs = return_probs(x, C)
        histories = (soprano_probs, alto_probs, tenor_probs, bass_probs)
        for history, voice_probs in zip(histories, probs):
            history.append(voice_probs.transpose())
    
def display_heatmaps():
    """
    Show the most recent heatmap of each voice (the last entries stored by store_heatmaps) side by side.
    """
    latest = (
        ('soprano', soprano_probs),
        ('alto', alto_probs),
        ('tenor', tenor_probs),
        ('bass', bass_probs),
    )
    fig, axs = plt.subplots(1, 4)
    for ax, (title, history) in zip(axs, latest):
        # flip vertically before drawing, as in the other heatmap plots of this notebook
        ax.imshow(np.flip(history[-1], axis=0), cmap='hot', interpolation='nearest')
        ax.set_title(title)
    fig.set_figheight(5)
    fig.set_figwidth(15)
    plt.show()
    



# melody = [66, 66, 66, 66, 71, 71, 71, 71, 73, 73, 73, 73, 75, 75, 75, 75,
#         76, 76, 75, 73, 71, 71, 75, 75, 73, 73, 70, 70, 71, 71, 71, 71]

def save_midi(melody, id_number):
    """
    Harmonize melody with the model and write the result to a midi file.

    melody is a sequence of T midi pitches; it is placed in the soprano line and kept fixed,
    while the other three lines are generated by harmonize.
    The file is named {id_number}midi.mid, with id_number formatted by pad_number.
    """
    start = np.random.randint(P, size=(I, T))
    start[0] = np.array(melody)-30 # subtract 30 because 30 is the minimum midi_value

    # mask with ones on the soprano row only
    soprano_fixed = np.concatenate(
        [np.ones((1, T)).astype(int), np.zeros((3, T)).astype(int)], axis=0)

    harmonized = harmonize(start, soprano_fixed, model) + 30 # 30 back on before passing to piano_roll_to_midi
    piano_roll = np.array(harmonized.transpose().tolist())
    midi_output = piano_roll_to_midi(piano_roll)
    midi_output.save(str(pad_number(id_number)) + 'midi.mid')
    
# a 32-step melody given as midi pitches
goldberg_like_line = [67, 67, 67, 67, 67, 67, 67, 67, 71, 71, 71, 71, 71, 71, 71, 71,
                      69, 69, 69, 69, 67, 67, 66, 66, 64, 64, 64, 64, 62, 62, 62, 62]    
    
    
# the same melody with every entry lowered by 30
goldberg_like_line_down = [37, 37, 37, 37, 37, 37, 37, 37, 41, 41, 41, 41, 41, 41, 41, 41,
                      39, 39, 39, 39, 37, 37, 36, 36, 34, 34, 34, 34, 32, 32, 32, 32]

In [None]:
# training loop
# the full run of 30000 iterations takes about 30 hours; the loop below only runs 3 iterations
model.train()
N = batch_size

# for i in range(30000):
for i in range(3):

    # random mask of shape (N, I, T): entries of x where C is 0 are zeroed out inside the net
    C = np.random.randint(2, size=(N, I, T))

    # batch is an np array of shape (N, I, T), entries are integers in [0, P)
    indices = np.random.choice(train_tracks.shape[0], size=N)
    batch = train_tracks[indices]

    # targets is of shape (N*I*T): the true pitch class of every entry
    targets = batch.reshape(-1)
    targets = torch.tensor(targets).to(device)

    # x is of shape (N, I, T, P): the one-hot encoding of batch
    batch = batch.reshape(N*I*T)
    x = np.zeros((N*I*T, P))
    r = np.arange(N*I*T)
    x[r, batch] = 1
    x = x.reshape(N, I, T, P)
    x = torch.tensor(x).type(torch.FloatTensor).to(device)

    # forward pass, loss over all N*I*T entries, and one optimizer step
    C2 = torch.tensor(C).type(torch.FloatTensor).to(device)
    out = model(x, C2)
    out = out.view(N*I*T, P)
    loss = loss_fn(out, targets)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # every 50 iterations: report the loss and record the predicted probabilities
    # for a fixed soprano line, so that progress can be followed visually
    if i % 50 == 0:
        print(i)
        print('loss: ', loss.item())
        D0 = np.ones((1, T))
        D1 = np.zeros((3, T))
        D = np.concatenate([D0, D1], axis=0).astype(int)
        y = np.random.randint(P, size=(I, T))
        y[0, :] = np.array(goldberg_like_line_down)
        store_heatmaps(y, D)
        display_heatmaps()
        # every 500 iterations: also save a harmonization as a midi file
        if i % 500 == 0:
            save_midi(goldberg_like_line, i)
        # the feedback helpers put the model into eval mode; switch back
        model.train()

    # adjust learning rate: multiply it by .75 after every 5000 iterations
    # (i > 0 keeps the decay from firing on the very first iteration)
    if i > 0 and i % 5000 == 0:
        for g in optimizer.param_groups:
            g['lr'] *= .75

In [None]:
# torch.save(model.state_dict(), '????????.pt')

In [None]:
# plot the loss recorded at every training step
plt.plot(losses)

In [None]:
# script to produce a gif of the probability heatmaps during training

from matplotlib.animation import FuncAnimation

# the stored heatmap history of each voice, in plotting order
heatmap_histories = [
    ('soprano', soprano_probs),
    ('alto', alto_probs),
    ('tenor', tenor_probs),
    ('bass', bass_probs),
]

# draw the first stored snapshot of every voice
fig, axs = plt.subplots(1, 4)
for ax, (title, history) in zip(axs, heatmap_histories):
    ax.imshow(np.flip(history[0], axis=0), cmap='hot', interpolation='nearest')
    ax.set_title(title)
fig.set_figheight(5)
fig.set_figwidth(15)


def update(i):
    # redraw all four heatmaps with the i-th stored snapshot
    for ax, (_, history) in zip(axs, heatmap_histories):
        ax.imshow(np.flip(history[i], axis=0), cmap='hot', interpolation='nearest')

# one frame per stored snapshot, so update never indexes past the end of the histories
anim = FuncAnimation(fig, update, interval=300, repeat=True, frames=len(soprano_probs))
anim.save('anim.gif', writer='pillow')