### 1. Complete the 'Prerequisites' section of README.md

### 2. Copy the MIDI files to the `data/{type}/{type_short}/` folder
> Examples: `data/midi_data_x/x/`, `data/midi_data_y/y/`, `data/midi_data_xy/xy/`, `data/midi_data_y_neg/y_neg/`

### 3. Prepare name list

```
find data/midi_data_x/x -type f \( -name '*.mid' -o -name '*.xml' \) | cut -c 20- > data/midi_data_x/original-names.txt
```

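`cut -c 20-` keeps each path from its 20th character on, which drops the 19-character prefix `data/midi_data_x/x/` and leaves paths relative to the input folder. For a different input folder, use the length of its path (without the trailing slash) plus 2.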
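If counting characters is error-prone, the same list can be built in Python. This is a sketch, not part of the `mmt` scripts; `write_name_list` is a hypothetical helper. It writes paths relative to the input folder, sorted (the order of `find` output is not guaranteed, so line order may differ from the shell version).

```python
from pathlib import Path


def write_name_list(input_dir, output_file, suffixes=(".mid", ".xml")):
    """Write every .mid/.xml file under input_dir to output_file, one path per
    line, relative to input_dir (like the find | cut pipeline above)."""
    root = Path(input_dir)
    names = sorted(
        path.relative_to(root).as_posix()
        for path in root.rglob("*")
        if path.is_file() and path.suffix in suffixes
    )
    Path(output_file).write_text("".join(name + "\n" for name in names))
    return names
```

Usage: `write_name_list("data/midi_data_x/x", "data/midi_data_x/original-names.txt")`.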

### 4. Convert the data to json

Make sure the path `data/midi_data_x/processed/json/` exists and run:

```
python mmt/convert_lmd_full.py -n data/midi_data_x/original-names.txt -i data/midi_data_x/x/ -o data/midi_data_x/processed/json/
```
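This step and several later ones require an output folder to exist before the script runs. A minimal sketch that creates all the folders named in this guide at once, run from the repository root; `make_required_dirs` and `REQUIRED_DIRS` are hypothetical names, and the list only covers the `midi_data_x` and `sod` paths used here.

```python
from pathlib import Path

# Folders the steps of this guide expect to exist (paths relative to the repo root).
REQUIRED_DIRS = [
    "data/midi_data_x/processed/json",
    "data/midi_data_x/processed/notes",
    "data/sod/processed/json",
    "data/sod/processed/notes",
    "exp/sod/ape/checkpoints",
    "exp/midi_data_x/ape/checkpoints",
]


def make_required_dirs(base="."):
    """Create each required folder under base; existing folders are left alone."""
    created = []
    for rel in REQUIRED_DIRS:
        path = Path(base) / rel
        path.mkdir(parents=True, exist_ok=True)
        created.append(path)
    return created
```

Calling it a second time is harmless because `exist_ok=True` skips folders that already exist.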

### 5. Extract notes

Make sure the path `data/midi_data_x/processed/notes/` exists and run:

```
python mmt/extract.py -d midi_data_x
```

### 6. Split training/validation/test sets

```
python mmt/split.py -d midi_data_x -v 0 -t 1
```

`-v 0` and `-t 1` are the validation and test set ratios, respectively, so every file is assigned to the test set and none to training or validation.
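To illustrate what the two ratios mean, here is a hypothetical ratio-based split. It is a sketch for intuition only and is NOT the code in `mmt/split.py`; the shuffling and rounding choices (`int()` truncation, fixed seed) are assumptions.

```python
import random


def split_names(names, valid_ratio, test_ratio, seed=0):
    """Hypothetical ratio-based split into (train, valid, test).

    Illustrates what the -v/-t ratios mean; it is NOT the code of mmt/split.py.
    """
    names = list(names)
    random.Random(seed).shuffle(names)
    n_valid = int(len(names) * valid_ratio)
    n_test = int(len(names) * test_ratio)
    valid = names[:n_valid]
    test = names[n_valid:n_valid + n_test]
    train = names[n_valid + n_test:]
    return train, valid, test
```

With `valid_ratio=0` and `test_ratio=1`, the training and validation lists come back empty and every name lands in the test list.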

### 7. Download pretrained model

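Download the pretrained models from https://drive.google.com/drive/folders/1HoKfghXOmiqi028oc_Wv0m2IlLdcJglQ?usp=share_link with `gdown` (if your `gdown` version rejects `--id`, pass the folder URL instead: `gdown --folder <folder URL>`):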

```
gdown --id 1HoKfghXOmiqi028oc_Wv0m2IlLdcJglQ --folder
```

Copy the `sod-ape` model checkpoint (`best_model.pt`) to `exp/midi_data_x/ape/checkpoints/`.

### 8. Download pre-processed sod dataset

Download the pre-processed sod dataset (`sod_json.zip` and `sod_notes.zip`) from https://drive.google.com/drive/folders/1owWu-Ne8wDoBYCFiF9z11fruJo62m_uK?usp=share_link with `gdown`:

```
gdown --id 1owWu-Ne8wDoBYCFiF9z11fruJo62m_uK --folder
```

Extract `sod_json.zip` into `data/sod/processed/json/` and `sod_notes.zip` into `data/sod/processed/notes/`.
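The extraction can also be scripted with the standard-library `zipfile` module. A sketch, assuming both archives sit in `download_dir` (wherever `gdown` saved them) and store their files at the top level; if an archive wraps everything in its own folder, move that folder's contents up afterwards. `extract_sod` is a hypothetical helper.

```python
import zipfile
from pathlib import Path

# Archive name -> sub-folder of data/sod/processed it should be extracted into.
SOD_ARCHIVES = {"sod_json.zip": "json", "sod_notes.zip": "notes"}


def extract_sod(download_dir, processed_dir="data/sod/processed"):
    """Extract the two sod archives from download_dir into processed_dir."""
    for archive, sub in SOD_ARCHIVES.items():
        target = Path(processed_dir) / sub
        target.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(Path(download_dir) / archive) as zf:
            zf.extractall(target)
```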


### 9. Start training the sod-ape model to generate the train-args.json file

Make sure the folder `exp/sod/ape/checkpoints` exists and then run:

```
python mmt/train.py -d sod -o exp/sod/ape -g 0
```

This writes `train-args.json` to `exp/sod/ape/`; once the file is there, the training run can be stopped. Copy the file to `exp/midi_data_x/ape/`.
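Before generating samples, a quick check that both files from steps 7 and 9 are in place can save a failed run. A sketch, assuming `mmt/generate.py` reads the checkpoint and `train-args.json` from these locations under the `-o` folder; `check_generation_inputs` is a hypothetical helper.

```python
import json
from pathlib import Path


def check_generation_inputs(exp_dir="exp/midi_data_x/ape"):
    """List problems with the checkpoint and train-args.json placed in steps 7 and 9."""
    exp = Path(exp_dir)
    problems = []
    checkpoint = exp / "checkpoints" / "best_model.pt"
    if not checkpoint.is_file():
        problems.append(f"missing checkpoint: {checkpoint}")
    args_file = exp / "train-args.json"
    if not args_file.is_file():
        problems.append(f"missing {args_file}")
    else:
        # a half-written file (e.g. training killed mid-write) fails to parse
        try:
            json.loads(args_file.read_text())
        except json.JSONDecodeError as err:
            problems.append(f"{args_file} is not valid JSON: {err}")
    return problems
```

An empty list means both files are present and `train-args.json` parses as JSON.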

### 10. Generate samples using the pre-trained model

```
python mmt/generate.py -d midi_data_x -o exp/midi_data_x/ape -g 0 -ns 4
```

### 11. Compute H

The cells below compute H from the files in `exp/midi_data_x/ape/samples/logits/` and `exp/midi_data_x/ape/samples/npy/`.

In [10]:
import numpy as np
import scipy.stats as stats
from scipy.special import softmax
import os

In [11]:
tuple_encoding = {"type" : 0,
                  "beat" : 1, 
                  "position": 2,
                  "pitch" : 3,
                  "duration" : 4,
                  "instrument": 5}

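Both `calc_entropy` and `calc_9` pick, for each note, the probability of its true token with NumPy integer-array indexing: row `i` of a field's probabilities is paired with the column given by that field's entry in truth row `i` (the column index comes from `tuple_encoding`). A small standalone sketch of that step with toy numbers; `FIELDS` and `truth_probs` are hypothetical names.

```python
import numpy as np

# Column order of a truth row, mirroring tuple_encoding above.
FIELDS = ["type", "beat", "position", "pitch", "duration", "instrument"]


def truth_probs(probs, truth, field):
    """Probability assigned to the true token of `field`, one value per row.

    probs: (n, vocab) probabilities for one field (rows sum to 1).
    truth: (n, 6) integer note tuples.
    """
    col = FIELDS.index(field)
    # pair row i with column truth[i, col]
    return probs[np.arange(len(truth)), truth[:, col]]


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
truth = np.array([[1, 0, 0, 0, 0, 0],
                  [1, 2, 0, 0, 0, 0]])
beat_probs = truth_probs(probs, truth, "beat")  # picks probs[0, 0] and probs[1, 2]
```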
In [12]:
def read_logits(folder):
    """
    The logits folder (e.g. exp/midi_data_x/ape/samples/logits/) contains files like:
    0_16-beat-continuation-beat_logits.npy
    0_16-beat-continuation-duration_logits.npy
    0_16-beat-continuation-instrument_logits.npy
    0_16-beat-continuation-pitch_logits.npy
    0_16-beat-continuation-position_logits.npy
    0_16-beat-continuation-type_logits.npy
    1_16-beat-continuation-beat_logits.npy
    1_16-beat-continuation-duration_logits.npy
    1_16-beat-continuation-instrument_logits.npy
    1_16-beat-continuation-pitch_logits.npy
    1_16-beat-continuation-position_logits.npy
    1_16-beat-continuation-type_logits.npy
    .
    .
    .
    n_16-beat-continuation-beat_logits.npy
    n_16-beat-continuation-duration_logits.npy
    n_16-beat-continuation-instrument_logits.npy
    n_16-beat-continuation-pitch_logits.npy
    n_16-beat-continuation-position_logits.npy
    n_16-beat-continuation-type_logits.npy

    This function will read the logits and return a dictionary of the form:
    {0 : 
        {"type" : np.array(),
        "beat" : np.array(),
        "position" : np.array(),
        "pitch" : np.array(),
        "duration" : np.array(),
        "instrument" : np.array()},
    1 :
        {"type" : np.array(),
        "beat" : np.array(),
        "position" : np.array(),
        "pitch" : np.array(),
        "duration" : np.array(),
        "instrument" : np.array()},
    .
    .
    .
    n : 
        {"type" : np.array(),
        "beat" : np.array(),
        "position" : np.array(),
        "pitch" : np.array(),
        "duration" : np.array(),
        "instrument" : np.array()}
    }
    """
    # only consider logits files, so stray entries (e.g. .ipynb_checkpoints) are skipped
    filenames = [fname for fname in os.listdir(folder) if fname.endswith("_logits.npy")]

    logits = dict()
    for fname in filenames:
        # "0_16-beat-continuation-pitch_logits.npy" -> sample id 0, field "pitch"
        sample_id = int(fname.split('_')[0])
        l_type = fname.split('-')[-1].split('_')[0]
        logits.setdefault(sample_id, dict())[l_type] = np.load(os.path.join(folder, fname)).squeeze()

    return logits

def read_truth(folder):
    """
    The npy folder (e.g. exp/midi_data_x/ape/samples/npy/) contains files like:
    0_truth.npy
    1_truth.npy
    .
    .
    .
    n_truth.npy 

    in addition to the generated files.

    This function will read the truth files and return a dictionary of the form:
    {0 : np.array(),
    1 : np.array(),
    .
    .
    .
    n : np.array()}
    """
    filenames = os.listdir(folder)
    
    truths = dict()
    for fname in filenames:
        if fname.endswith("_truth.npy"):
            id = int(fname.split('_')[0])
            truths[id] = np.load(os.path.join(folder, fname))

    return truths

def calc_entropy(logits, truth):
    """
    Given logits[i] of the form:

    {"type" : np.array(),
    "beat" : np.array(),
    "position" : np.array(),
    "pitch" : np.array(),
    "duration" : np.array(),
    "instrument" : np.array()}

    and truth[i] of the form: np.array(),

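    this function computes, for each of the 6 fields of the tuple representation,
    the base-2 entropy of the probabilities the model assigned to the true tokens
    (rows of truth after the first 32):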

    { "type" : entropy,
      "beat" : entropy,
      "position" : entropy,
      "pitch" : entropy,
      "duration" : entropy,
      "instrument" : entropy}
    """

    # select notes from truth after the first 16 beats
    # selected_truth = truth[truth[:, 1] >= 16]

    # select notes from truth after the first 32 elements
    selected_truth = truth[32:]

    # select rows from logits corresponding to the selected notes
    selected_logits = {l_type : logits[l_type][:len(selected_truth)] 
                       for l_type in logits.keys()}

    # convert logits to probabilities
    probs = {l_type : softmax(selected_logits[l_type], axis=-1) for l_type in selected_logits.keys()}

    # select the probability of the truth note from probs
    selected_probs = {l_type : probs[l_type][np.arange(len(selected_truth)), 
                                                       selected_truth[:, tuple_encoding[l_type]]]
                        for l_type in probs.keys()}

    # compute entropy
    entropies = {l_type : stats.entropy(selected_probs[l_type], base=2) for l_type in selected_probs.keys()}

    return entropies
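
`scipy.stats.entropy` normalises its input so that it sums to 1, so `calc_entropy` returns the entropy of the renormalised vector of true-token probabilities. That is not the model's cross-entropy on the true tokens (the average number of bits it spends per true token). If the cross-entropy is the quantity wanted, a NumPy-only sketch follows; `cross_entropy_bits` is a hypothetical helper, not part of the original notebook.

```python
import numpy as np


def cross_entropy_bits(logits, truth_tokens):
    """Mean negative log2-probability of the true tokens (bits per token).

    logits: (n, vocab) unnormalised scores for one field.
    truth_tokens: (n,) integer indices of the true tokens.
    """
    # log-softmax, shifted by the row max for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    true_log_probs = log_probs[np.arange(len(truth_tokens)), truth_tokens]
    # convert from nats to bits
    return float(-true_log_probs.mean() / np.log(2))
```

With the same alignment as `calc_entropy`, one field of one sample would be `cross_entropy_bits(logits[0]["pitch"][:len(truths[0]) - 32], truths[0][32:, tuple_encoding["pitch"]])`. Uniform logits over 4 tokens give exactly 2 bits.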
    

In [13]:
def calc_9(logits, truth, add_noise='all'):
    """
    Given logits[i] of the form:

    {"type" : np.array(),
    "beat" : np.array(),
    "position" : np.array(),
    "pitch" : np.array(),
    "duration" : np.array(),
    "instrument" : np.array()}

    and truth[i] of the form: np.array(),

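    this function computes, for each of the 6 fields of the tuple representation
    and over the rows of truth after the first 32,

        mean(p_true) - log(mean(exp(p_noise)))

    where p_true are the probabilities the model assigned to the true tokens and
    p_noise are the probabilities it assigned to uniformly random tokens. Only the
    fields chosen by add_noise are noised ('all' = pitch, duration, beat and
    position); for every other field p_noise is p_true. The result has the form:

    { "type" : value,
      "beat" : value,
      "position" : value,
      "pitch" : value,
      "duration" : value,
      "instrument" : value}
    """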

    assert add_noise in ['all', 'pitch', 'duration', 'instrument', 'type', 'beat', 'position']

    # select notes from truth after the first 16 beats
    # selected_truth = truth[truth[:, 1] >= 16]

    # select notes from truth after the first 32 elements
    selected_truth = truth[32:]

    # select rows from logits corresponding to the selected notes
    selected_logits = {l_type : logits[l_type][:len(selected_truth)] 
                       for l_type in logits.keys()}

    # convert logits to probabilities
    probs = {l_type : softmax(selected_logits[l_type], axis=-1) for l_type in selected_logits.keys()}

    # select the probability of the truth note from probs
    selected_probs = {l_type : probs[l_type][np.arange(len(selected_truth)), 
                                                       selected_truth[:, tuple_encoding[l_type]]]
                        for l_type in probs.keys()}

    # compute first_term
    first_term = {l_type : np.mean(selected_probs[l_type]) for l_type in selected_probs.keys()}

    # add noise
    noise = {}
    if add_noise == 'all':
        noise_keys = ['pitch', 'duration', 'beat', 'position']
        for l_type in noise_keys:
            noise[l_type] = np.random.randint(0, probs[l_type].shape[1], size=probs[l_type].shape[0])
    else:
        noise_keys = [add_noise]
        noise[add_noise] = np.random.randint(0, probs[add_noise].shape[1], size=probs[add_noise].shape[0])

    noisy_probs = dict(selected_probs)  # shallow copy, so selected_probs itself is left unchanged
    for l_type in noise_keys:
        noisy_probs[l_type] = probs[l_type][np.arange(len(selected_truth)), noise[l_type]] 

    # compute second_term
    second_term = {l_type : np.log(np.mean(np.exp(noisy_probs[l_type]))) for l_type in noisy_probs.keys()}

    # compute 9  
    nine = {l_type : first_term[l_type] - second_term[l_type] for l_type in first_term.keys()}

    return nine
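
`calc_9` returns `mean(p_true) - log(mean(exp(p_noise)))`, which has the same form as the Donsker-Varadhan lower bound on a KL divergence. Two consequences show up in the outputs below. For a field that is not noised, both terms use the same values, and by Jensen's inequality `log(mean(exp(x))) >= mean(x)`, so the result is at most 0; this matches the small negative `type` and `instrument` values. Also, the noise comes from NumPy's unseeded global random state, so the noised fields change on every call: `beat`, `duration`, `pitch` and `position` differ between In [16]/In [18] and In [24], while `type` and `instrument` do not. A sketch of both pieces with an explicit, seedable generator; `dv_estimate` and `noise_tokens` are hypothetical helpers.

```python
import numpy as np


def dv_estimate(t_true, t_noise):
    """mean(t_true) - log(mean(exp(t_noise))), the Donsker-Varadhan form."""
    return float(np.mean(t_true) - np.log(np.mean(np.exp(t_noise))))


def noise_tokens(n_rows, vocab_size, rng):
    """Uniformly random token indices in [0, vocab_size) from a numpy Generator."""
    return rng.integers(0, vocab_size, size=n_rows)
```

To make `calc_9` repeatable, one could create `rng = np.random.default_rng(seed)` once per call and replace `np.random.randint(0, probs[l_type].shape[1], size=probs[l_type].shape[0])` with `noise_tokens(probs[l_type].shape[0], probs[l_type].shape[1], rng)`.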

In [14]:
logits = read_logits('../exp/midi_data_x/ape/samples/logits/')

for key in sorted(logits.keys()):
    print(key)
    for l_type in sorted(logits[key].keys()):
        print(l_type, logits[key][l_type].shape)

0
beat (992, 257)
duration (992, 33)
instrument (992, 65)
pitch (992, 129)
position (992, 13)
type (992, 5)
1
beat (528, 257)
duration (528, 33)
instrument (528, 65)
pitch (528, 129)
position (528, 13)
type (528, 5)
2
beat (992, 257)
duration (992, 33)
instrument (992, 65)
pitch (992, 129)
position (992, 13)
type (992, 5)
3
beat (603, 257)
duration (603, 33)
instrument (603, 65)
pitch (603, 129)
position (603, 13)
type (603, 5)


In [15]:
truths = read_truth('../exp/midi_data_x/ape/samples/npy/')

for key in sorted(truths.keys()):
    print(key, truths[key].shape)

0 (166, 6)
1 (282, 6)
2 (128, 6)
3 (149, 6)
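
Both `calc_entropy` and `calc_9` pair logits row `i` with truth row `32 + i` and index `len(truth) - 32` rows, so a sample whose logits are shorter than that (or that has no truth file) raises an error. For the shapes printed above every sample qualifies (sample 0, for instance, needs 166 - 32 = 134 rows and has 992). A sketch of a check to run before looping over samples; `check_alignment` is a hypothetical helper.

```python
def check_alignment(logits, truths, offset=32):
    """Return (sample_id, problem) pairs that would break calc_entropy/calc_9.

    logits: {id: {field: (n_rows, vocab) array}}
    truths: {id: (n_notes, 6) array}
    """
    problems = []
    for sample_id in sorted(logits):
        if sample_id not in truths:
            problems.append((sample_id, "no truth file"))
            continue
        needed = max(len(truths[sample_id]) - offset, 0)
        for field, arr in logits[sample_id].items():
            if arr.shape[0] < needed:
                # too few logit rows for the selected truth notes
                problems.append((sample_id, field))
    return problems
```

An empty result, e.g. from `check_alignment(logits, truths)`, means every sample can be scored.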


In [16]:
print('9a:')
for id in sorted(logits.keys()):
    print(id)
    nine_a = calc_9(logits[id], truths[id])
    for key in sorted(nine_a.keys()):
        print(key, nine_a[key])

9a:
0
beat 0.019994609
duration 0.24111746
instrument -0.0027303696
pitch 0.038895123
position 0.23206127
type -0.0027343035
1
beat 0.015090554
duration 0.36677882
instrument -0.0014672875
pitch 0.066336006
position 0.3866105
type -0.00146842
2
beat 0.034893207
duration -0.0122243
instrument -0.0038087368
pitch 0.051265247
position -0.047401913
type -0.0038104653
3
beat 0.0019428665
duration 0.08776606
instrument -0.003128171
pitch 0.1077864
position -0.0076160207
type -0.0031298995


In [17]:
logits_xy = read_logits('../exp/midi_data_xy/ape/samples/logits/')

for key in sorted(logits_xy.keys()):
    print(key)
    for l_type in sorted(logits_xy[key].keys()):
        print(l_type, logits_xy[key][l_type].shape)

0
beat (955, 257)
duration (955, 33)
instrument (955, 65)
pitch (955, 129)
position (955, 13)
type (955, 5)
1
beat (970, 257)
duration (970, 33)
instrument (970, 65)
pitch (970, 129)
position (970, 13)
type (970, 5)
2
beat (478, 257)
duration (478, 33)
instrument (478, 65)
pitch (478, 129)
position (478, 13)
type (478, 5)
3
beat (607, 257)
duration (607, 33)
instrument (607, 65)
pitch (607, 129)
position (607, 13)
type (607, 5)


In [18]:
print('9b:')
for id in sorted(logits_xy.keys()):
    print(id)
    nine_b = calc_9(logits_xy[id], truths[id])
    for key in sorted(nine_b.keys()):
        print(key, nine_b[key])

9b:
0
beat 4.1296067e-07
duration 0.2551444
instrument -0.01188606
pitch 0.057947773
position 0.22950651
type -0.0027342439
1
beat -0.00041193492
duration 0.2381254
instrument -0.005526066
pitch 0.04531589
position 0.25653473
type -0.00146842
2
beat -2.0796447e-06
duration 0.057223387
instrument -0.0047407746
pitch 0.013494772
position -0.12175432
type -0.0038104057
3
beat 0.0594617
duration 0.2480256
instrument -0.0031609535
pitch 0.035239223
position 0.01201351
type -0.0031297207
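
The next cell prints 9a - 9b sample by sample. To summarise across samples, a sketch that averages the per-field differences over the sample ids both runs share; `mean_difference` is a hypothetical helper.

```python
from statistics import mean


def mean_difference(results_a, results_b):
    """Per-field mean of (a - b) over the sample ids present in both dicts.

    results_a, results_b: {sample_id: {field: value}}, e.g. calc_9 outputs.
    """
    common = sorted(set(results_a) & set(results_b))
    if not common:
        return {}
    fields = sorted(results_a[common[0]])
    return {field: mean(float(results_a[i][field]) - float(results_b[i][field]) for i in common)
            for field in fields}
```

For example, build `nine_a = {i: calc_9(logits[i], truths[i]) for i in logits}` and `nine_b = {i: calc_9(logits_xy[i], truths[i]) for i in logits_xy}`, then call `mean_difference(nine_a, nine_b)`.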


In [24]:
for id in sorted(logits.keys()):
    print(id, ':')
    nine_a = calc_9(logits[id], truths[id])
    nine_b = calc_9(logits_xy[id], truths[id])
    for key in sorted(nine_a.keys()):
        print()
        print(key, ':')
        print('9a - 9b:', nine_a[key] - nine_b[key])
        print('9a:', nine_a[key])
        print('9b:', nine_b[key])
    print('------------------')

0 :

beat :
9a - 9b: 0.028855793
9a: 0.01612781
9b: -0.012727983

duration :
9a - 9b: -0.05012089
9a: 0.19804078
9b: 0.24816167

instrument :
9a - 9b: 0.009155691
9a: -0.0027303696
9b: -0.01188606

pitch :
9a - 9b: -0.031020664
9a: 0.030762639
9b: 0.061783303

position :
9a - 9b: 0.020113796
9a: 0.1850583
9b: 0.1649445

type :
9a - 9b: -5.9604645e-08
9a: -0.0027343035
9b: -0.0027342439
------------------
1 :

beat :
9a - 9b: 0.015371662
9a: 0.018886268
9b: 0.0035146063

duration :
9a - 9b: 0.11748311
9a: 0.3740636
9b: 0.2565805

instrument :
9a - 9b: 0.0040587783
9a: -0.0014672875
9b: -0.005526066

pitch :
9a - 9b: 0.009807136
9a: 0.063003376
9b: 0.05319624

position :
9a - 9b: 0.13990793
9a: 0.40103772
9b: 0.2611298

type :
9a - 9b: 0.0
9a: -0.00146842
9b: -0.00146842
------------------
2 :

beat :
9a - 9b: 0.059939414
9a: 0.031412527
9b: -0.02852689

duration :
9a - 9b: -0.07888188
9a: -0.012873998
9b: 0.06600788

instrument :
9a - 9b: 0.00093203783
9a: -0.0038087368
9b: -0.004740774