### 1. Make sure you have completed the 'Prerequisites' section of README.md

### 2. Copy the MIDI files to the `data/{type}/{type_short}/` folder
> Examples: `data/midi_data_x/x/`, `data/midi_data_y/y/`, `data/midi_data_xy/xy/`, `data/midi_data_y_neg/y_neg/`

### 3. Prepare the name list

```
find data/midi_data_x/x -type f \( -name '*.mid' -o -name '*.xml' \) | cut -c 20- > data/midi_data_x/original-names.txt
```

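`cut -c 20-` keeps everything from the 20th character on, which strips the 19-character prefix `data/midi_data_x/x/` so that the names are relative to the input folder. For a different folder, use the length of its path (including the trailing `/`) plus one, e.g. `cut -c 22-` for `data/midi_data_xy/xy`.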
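The hard-coded `cut` offset has to change whenever the folder changes. A minimal Python alternative, sketched with `pathlib`, writes the same kind of relative name list for any folder. `write_name_list` is a hypothetical helper, not part of `mmt`, and unlike `find` it sorts the names.

```python
from pathlib import Path


def write_name_list(input_dir, output_file, suffixes=(".mid", ".xml")):
    """Write the paths of all .mid/.xml files under input_dir, relative to
    input_dir, one per line (what `find ... | cut -c 20-` produces)."""
    root = Path(input_dir)
    names = sorted(
        p.relative_to(root).as_posix()  # e.g. "sub/song.mid"
        for p in root.rglob("*")
        if p.is_file() and p.suffix in suffixes
    )
    Path(output_file).write_text("".join(name + "\n" for name in names))
    return names
```

Usage: `write_name_list("data/midi_data_x/x", "data/midi_data_x/original-names.txt")`.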

### 4. Convert the data to JSON

Create the output folder and run the conversion:

```
mkdir -p data/midi_data_x/processed/json/
python mmt/convert_lmd_full.py -n data/midi_data_x/original-names.txt -i data/midi_data_x/x/ -o data/midi_data_x/processed/json/
```
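A wrong `cut` offset in step 3 produces names that do not resolve. Assuming `convert_lmd_full.py` looks each name up relative to the `-i` folder, a quick Python check before the conversion can list the entries that do not exist. `missing_names` is a hypothetical helper, not part of `mmt`.

```python
from pathlib import Path


def missing_names(names_file, input_dir):
    """Return the entries of names_file that are not files under input_dir."""
    root = Path(input_dir)
    names = [line.strip() for line in Path(names_file).read_text().splitlines()]
    # skip empty lines; report every name that does not resolve to an existing file
    return [name for name in names if name and not (root / name).is_file()]
```

Usage: `missing_names("data/midi_data_x/original-names.txt", "data/midi_data_x/x/")` should return an empty list.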

### 5. Extract notes

Create the output folder and run:

```
mkdir -p data/midi_data_x/processed/notes/
python mmt/extract.py -d midi_data_x
```

### 6. Split training/validation/test sets

```
python mmt/split.py -d midi_data_x -v 0 -t 1
```

`-v 0` and `-t 1` set the validation and test set ratios, respectively, so all files are assigned to the test set.
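To make the effect of the ratios concrete, here is an illustrative ratio-based split in Python. `split_files` is a hypothetical helper, not the code in `mmt/split.py`, whose shuffling and rounding may differ; it only shows why `-v 0 -t 1` leaves the training and validation sets empty.

```python
import random


def split_files(filenames, val_ratio, test_ratio, seed=0):
    """Illustrative ratio-based split (not the logic of mmt/split.py)."""
    assert 0 <= val_ratio and 0 <= test_ratio and val_ratio + test_ratio <= 1
    files = sorted(filenames)
    random.Random(seed).shuffle(files)  # deterministic shuffle for a given seed
    n_val = round(len(files) * val_ratio)
    n_test = round(len(files) * test_ratio)
    val = files[:n_val]
    test = files[n_val:n_val + n_test]
    train = files[n_val + n_test:]  # whatever is left goes to training
    return train, val, test
```

With `val_ratio=0` and `test_ratio=1`, `n_test` equals the number of files, so every file lands in `test`.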

### 7. Download the pretrained model

Download the pretrained models from https://drive.google.com/drive/folders/1HoKfghXOmiqi028oc_Wv0m2IlLdcJglQ?usp=share_link using `gdown`:

```
gdown --folder https://drive.google.com/drive/folders/1HoKfghXOmiqi028oc_Wv0m2IlLdcJglQ
```

Copy the `sod-ape` model (`best_model.pt`) to `exp/midi_data_x/ape/checkpoints/`.

### 8. Download the pre-processed SOD dataset

Download the pre-processed SOD dataset (`sod_json.zip` and `sod_notes.zip`) from https://drive.google.com/drive/folders/1owWu-Ne8wDoBYCFiF9z11fruJo62m_uK?usp=share_link using `gdown`:

```
gdown --folder https://drive.google.com/drive/folders/1owWu-Ne8wDoBYCFiF9z11fruJo62m_uK
```

Extract `sod_json.zip` to `data/sod/processed/json/` and `sod_notes.zip` to `data/sod/processed/notes/`.
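The extraction can also be scripted with Python's standard `zipfile` module. This is a minimal sketch with a hypothetical helper, `extract_to`. Where the zip files end up depends on the folder `gdown` creates, so `path/to/` in the usage line is a placeholder. If an archive contains a top-level folder, its files land one level deeper than the target, so check the result against the paths above.

```python
import zipfile
from pathlib import Path


def extract_to(zip_path, target_dir):
    """Extract every member of zip_path into target_dir (created if needed)
    and return the list of member names."""
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        members = zf.namelist()
        zf.extractall(target)
    return members
```

Usage: `extract_to("path/to/sod_json.zip", "data/sod/processed/json")` and `extract_to("path/to/sod_notes.zip", "data/sod/processed/notes")`.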


### 9. Start training the sod-ape model to generate the `train-args.json` file

Create the checkpoint folder and start training:

```
mkdir -p exp/sod/ape/checkpoints
python mmt/train.py -d sod -o exp/sod/ape -g 0
```

Training writes `train-args.json` to `exp/sod/ape/`. Only this file is needed here, so the run does not have to finish. Copy it to `exp/midi_data_x/ape/`.
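Before generating samples, `exp/midi_data_x/ape/` is expected to hold both the copied `train-args.json` and the `checkpoints/best_model.pt` from step 7. This is a minimal Python sketch that copies the file and lists anything still missing. `copy_train_args` and `missing_experiment_files` are hypothetical helpers; the paths in the usage line are the ones used in this guide.

```python
import shutil
from pathlib import Path


def copy_train_args(src_exp_dir, dst_exp_dir):
    """Copy train-args.json from one experiment folder to another."""
    dst = Path(dst_exp_dir)
    dst.mkdir(parents=True, exist_ok=True)
    return Path(shutil.copy2(Path(src_exp_dir) / "train-args.json", dst / "train-args.json"))


def missing_experiment_files(exp_dir):
    """List the files this guide places in exp_dir that are not there yet."""
    exp = Path(exp_dir)
    required = [exp / "train-args.json", exp / "checkpoints" / "best_model.pt"]
    return [str(p) for p in required if not p.is_file()]
```

Usage: `copy_train_args("exp/sod/ape", "exp/midi_data_x/ape")`, then `missing_experiment_files("exp/midi_data_x/ape")` should return an empty list.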

### 10. Generate samples using the pre-trained model

```
python mmt/generate.py -d midi_data_x -o exp/midi_data_x/ape -g 0 -ns 4
```

### 11. Compute H

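The notebook cells below read the files in `exp/midi_data_x/ape/samples/logits/` and `exp/midi_data_x/ape/samples/npy/`. The comparison cells also read `exp/midi_data_xy/ape/samples/logits/` and `exp/midi_data_xy_neg/ape/samples/logits/`, which are produced by running the steps above with `midi_data_xy` and `midi_data_xy_neg` in place of `midi_data_x`.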

```python
import os

import numpy as np
import scipy.stats as stats
from scipy.special import softmax
```

```python
# column index of each field in the 6-element note tuples stored in the truth arrays
tuple_encoding = {"type": 0,
                  "beat": 1,
                  "position": 2,
                  "pitch": 3,
                  "duration": 4,
                  "instrument": 5}
```

```python
def read_logits(folder):
    """
    Read the per-field logits saved when generating samples.

    The logits folder (e.g. exp/midi_data_x/ape/samples/logits/) contains files like:
    0_16-beat-continuation-beat_logits.npy
    0_16-beat-continuation-duration_logits.npy
    0_16-beat-continuation-instrument_logits.npy
    0_16-beat-continuation-pitch_logits.npy
    0_16-beat-continuation-position_logits.npy
    0_16-beat-continuation-type_logits.npy
    1_16-beat-continuation-beat_logits.npy
    ...
    n_16-beat-continuation-type_logits.npy

    Returns a dictionary of the form:
    {0: {"type": np.array(),
         "beat": np.array(),
         "position": np.array(),
         "pitch": np.array(),
         "duration": np.array(),
         "instrument": np.array()},
     1: {...},
     ...
     n: {...}}
    """
    logits = dict()
    for fname in os.listdir(folder):
        # skip anything that is not a logits file (e.g. .ipynb_checkpoints)
        if not fname.endswith("_logits.npy"):
            continue
        sample_id = int(fname.split('_')[0])
        # "0_16-beat-continuation-pitch_logits.npy" -> "pitch"
        l_type = fname.split('-')[-1].split('_')[0]
        logits.setdefault(sample_id, dict())[l_type] = np.load(os.path.join(folder, fname)).squeeze()

    return logits


def read_truth(folder):
    """
    Read the ground-truth note arrays.

    The npy folder (e.g. exp/midi_data_x/ape/samples/npy/) contains files like
    0_truth.npy, 1_truth.npy, ..., n_truth.npy, in addition to the generated files.

    Returns a dictionary of the form:
    {0: np.array(), 1: np.array(), ..., n: np.array()}
    """
    truths = dict()
    for fname in os.listdir(folder):
        if fname.endswith("_truth.npy"):
            sample_id = int(fname.split('_')[0])
            truths[sample_id] = np.load(os.path.join(folder, fname))

    return truths


def calc_entropy(logits, truth):
    """
    Given logits[i] of the form:

    {"type": np.array(),
     "beat": np.array(),
     "position": np.array(),
     "pitch": np.array(),
     "duration": np.array(),
     "instrument": np.array()}

    and truth[i] of the form np.array(), compute the entropy (in bits) for each
    of the 6 fields of the tuple representation:

    {"type": entropy,
     "beat": entropy,
     "position": entropy,
     "pitch": entropy,
     "duration": entropy,
     "instrument": entropy}
    """

    # select notes from truth after the first 16 beats
    # selected_truth = truth[truth[:, 1] >= 16]

    # select notes from truth after the first 32 elements
    selected_truth = truth[32:]

    # select rows from logits corresponding to the selected notes
    selected_logits = {l_type: logits[l_type][:len(selected_truth)]
                       for l_type in logits.keys()}

    # convert logits to probabilities
    probs = {l_type: softmax(selected_logits[l_type], axis=-1) for l_type in selected_logits.keys()}

    # select the probability assigned to the true value of each field
    selected_probs = {l_type: probs[l_type][np.arange(len(selected_truth)),
                                            selected_truth[:, tuple_encoding[l_type]]]
                      for l_type in probs.keys()}

    # compute entropy; note that stats.entropy first normalizes selected_probs to sum to 1
    entropies = {l_type: stats.entropy(selected_probs[l_type], base=2) for l_type in selected_probs.keys()}

    return entropies
```
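For reference, here is what `calc_entropy` returns for a field $f$, written out from the code. $z_i$ is row $i$ of the field's logits, $y_i$ is the true value of the field in row $i$ of `truth[32:]` (column `tuple_encoding[f]`), and $N$ is the number of selected rows. Because `scipy.stats.entropy` normalizes its input to sum to 1, the result is the entropy of the normalized true-value probabilities $q_i$. It is not the average cross-entropy $-\frac{1}{N}\sum_i \log_2 p_i$.

$$
p_i = \operatorname{softmax}(z_i)_{y_i}, \qquad
q_i = \frac{p_i}{\sum_{j=1}^{N} p_j}, \qquad
H_f = -\sum_{i=1}^{N} q_i \log_2 q_i
$$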
    

```python
def calc_9(logits, truth, add_noise='all'):
    """
    Given logits[i] of the form:

    {"type": np.array(),
     "beat": np.array(),
     "position": np.array(),
     "pitch": np.array(),
     "duration": np.array(),
     "instrument": np.array()}

    and truth[i] of the form np.array(), compute for each of the 6 fields of the
    tuple representation the mean logit of the true values minus the log of the
    mean exponentiated logit of randomly drawn wrong values ("noise"):

    {"type": value,
     "beat": value,
     "position": value,
     "pitch": value,
     "duration": value,
     "instrument": value}

    add_noise selects the field(s) that get noise; 'all' means pitch, duration,
    beat and position. Fields without noise use the true values in both terms.
    """

    assert add_noise in ['all', 'pitch', 'duration', 'instrument', 'type', 'beat', 'position']

    # select notes from truth after the first 16 beats
    # selected_truth = truth[truth[:, 1] >= 16]

    # select notes from truth after the first 32 elements
    selected_truth = truth[32:]

    # select rows from logits corresponding to the selected notes
    selected_logits = {l_type: logits[l_type][:len(selected_truth)]
                       for l_type in logits.keys()}

    # unlike calc_entropy, no softmax is applied: the raw logits are used as scores
    probs = {l_type: selected_logits[l_type] for l_type in selected_logits.keys()}

    # select the score of the true value of each field
    selected_probs = {l_type: probs[l_type][np.arange(len(selected_truth)),
                                            selected_truth[:, tuple_encoding[l_type]]]
                      for l_type in probs.keys()}

    # compute first_term
    first_term = {l_type: np.mean(selected_probs[l_type]) for l_type in selected_probs.keys()}

    # add noise: for every row, draw a class index different from the true one
    noise_keys = ['pitch', 'duration', 'beat', 'position'] if add_noise == 'all' else [add_noise]
    r = np.random.RandomState(42)
    noise = {}
    for l_type in noise_keys:
        n_classes = probs[l_type].shape[1]
        true_values = selected_truth[:, tuple_encoding[l_type]]
        choices = [r.choice(list(set(range(n_classes)).difference({true_value})))
                   for true_value in true_values]
        noise[l_type] = np.array(choices)

    # copy so that selected_probs is not modified in place
    noisy_probs = dict(selected_probs)
    for l_type in noise_keys:
        noisy_probs[l_type] = probs[l_type][np.arange(len(selected_truth)), noise[l_type]]

    # compute second_term
    second_term = {l_type: np.log(np.mean(np.exp(noisy_probs[l_type]))) for l_type in noisy_probs.keys()}

    # compute 9
    nine = {l_type: first_term[l_type] - second_term[l_type] for l_type in first_term.keys()}

    return nine
```
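Written out from the code, `calc_9` returns the quantity below for a field $f$. $z_i$ is row $i$ of the field's raw logits (no softmax), $y_i$ is the true value of the field in row $i$ of `truth[32:]`, $\tilde{y}_i$ is a value drawn uniformly from the classes other than $y_i$, $N$ is the number of selected rows, and $\log$ is the natural logarithm (`np.log`). For fields that receive no noise, $\tilde{y}_i = y_i$. If the raw logits are read as a critic function $T$, this has the same form as the Donsker-Varadhan lower bound $\mathrm{KL}(P \,\|\, Q) \ge \mathbb{E}_P[T] - \log \mathbb{E}_Q[e^{T}]$. That reading is an interpretation, not something stated in this section.

$$
D_f = \frac{1}{N} \sum_{i=1}^{N} z_{i, y_i} - \log\left( \frac{1}{N} \sum_{i=1}^{N} \exp\left(z_{i, \tilde{y}_i}\right) \right)
$$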

```python
logits = read_logits('../exp/midi_data_x/ape/samples/logits/')

for key in sorted(logits.keys()):
    print(key)
    for l_type in sorted(logits[key].keys()):
        print(l_type, logits[key][l_type].shape)
```

```
0
beat (134, 257)
duration (134, 33)
instrument (134, 65)
pitch (134, 129)
position (134, 13)
type (134, 5)
1
beat (251, 257)
duration (251, 33)
instrument (251, 65)
pitch (251, 129)
position (251, 13)
type (251, 5)
2
beat (97, 257)
duration (97, 33)
instrument (97, 65)
pitch (97, 129)
position (97, 13)
type (97, 5)
3
beat (118, 257)
duration (118, 33)
instrument (118, 65)
pitch (118, 129)
position (118, 13)
type (118, 5)
```


```python
truths = read_truth('../exp/midi_data_x/ape/samples/npy/')

for key in sorted(truths.keys()):
    print(key, truths[key].shape)
```

```
0 (166, 6)
1 (282, 6)
2 (128, 6)
3 (149, 6)
```


```python
logits_xy = read_logits('../exp/midi_data_xy/ape/samples/logits/')

for key in sorted(logits_xy.keys()):
    print(key)
    for l_type in sorted(logits_xy[key].keys()):
        print(l_type, logits_xy[key][l_type].shape)
```

```
0
beat (317, 257)
duration (317, 33)
instrument (317, 65)
pitch (317, 129)
position (317, 13)
type (317, 5)
1
beat (293, 257)
duration (293, 33)
instrument (293, 65)
pitch (293, 129)
position (293, 13)
type (293, 5)
2
beat (176, 257)
duration (176, 33)
instrument (176, 65)
pitch (176, 129)
position (176, 13)
type (176, 5)
3
beat (207, 257)
duration (207, 33)
instrument (207, 65)
pitch (207, 129)
position (207, 13)
type (207, 5)
```


```python
logits_xy_neg = read_logits('../exp/midi_data_xy_neg/ape/samples/logits/')

for key in sorted(logits_xy_neg.keys()):
    print(key)
    for l_type in sorted(logits_xy_neg[key].keys()):
        print(l_type, logits_xy_neg[key][l_type].shape)
```

```
0
beat (142, 257)
duration (142, 33)
instrument (142, 65)
pitch (142, 129)
position (142, 13)
type (142, 5)
1
beat (382, 257)
duration (382, 33)
instrument (382, 65)
pitch (382, 129)
position (382, 13)
type (382, 5)
2
beat (184, 257)
duration (184, 33)
instrument (184, 65)
pitch (184, 129)
position (184, 13)
type (184, 5)
3
beat (188, 257)
duration (188, 33)
instrument (188, 65)
pitch (188, 129)
position (188, 13)
type (188, 5)
```


```python
print('positive:')
for sample_id in sorted(logits.keys()):
    print(sample_id, ':')
    nine_a = calc_9(logits[sample_id], truths[sample_id])
    nine_b = calc_9(logits_xy[sample_id], truths[sample_id])
    for key in sorted(['beat', 'position', 'pitch', 'duration']):
        print()
        print(key, ':')
        print('9b - 9a:', nine_b[key] - nine_a[key])
        print('9a:', nine_a[key])
        print('9b:', nine_b[key])
    print('------------------')
```

```
positive:
0 :

beat :
9b - 9a: -11.768776
9a: 6.547868
9b: -5.2209077

duration :
9b - 9a: -1.320541
9a: 1.2540367
9b: -0.06650436

pitch :
9b - 9a: -6.7951803
9a: 4.99835
9b: -1.7968299

position :
9b - 9a: -4.0066066
9a: 2.3895078
9b: -1.6170988
------------------
1 :

beat :
9b - 9a: -15.347685
9a: 12.409351
9b: -2.938333

duration :
9b - 9a: -1.4329405
9a: 5.624644
9b: 4.1917033

pitch :
9b - 9a: -6.6002417
9a: 4.441007
9b: -2.1592348

position :
9b - 9a: -4.9100003
9a: 3.7843647
9b: -1.1256356
------------------
2 :

beat :
9b - 9a: -18.360144
9a: 10.532637
9b: -7.8275065

duration :
9b - 9a: -3.9163632
9a: 1.7990894
9b: -2.1172738

pitch :
9b - 9a: -6.6549954
9a: 2.3911252
9b: -4.2638702

position :
9b - 9a: -7.447062
9a: 2.8707662
9b: -4.576296
------------------
3 :

beat :
9b - 9a: -12.031067
9a: 13.100759
9b: 1.0696914

duration :
9b - 9a: -0.6862011
9a: 2.3727117
9b: 1.6865106

pitch :
9b - 9a: -4.194921
9a: 3.6773572
9b: -0.5175638

position :
9b - 9a: -2.7487712
9a: 1.2811
```

```python
print('negative:')
for sample_id in sorted(logits.keys()):
    print(sample_id, ':')
    nine_a = calc_9(logits[sample_id], truths[sample_id])
    nine_b = calc_9(logits_xy_neg[sample_id], truths[sample_id])
    for key in sorted(['beat', 'position', 'pitch', 'duration']):
        print()
        print(key, ':')
        print('9b - 9a:', nine_b[key] - nine_a[key])
        print('9a:', nine_a[key])
        print('9b:', nine_b[key])
    print('------------------')
```

```
negative:
0 :

beat :
9b - 9a: -7.104929
9a: 6.547868
9b: -0.5570612

duration :
9b - 9a: -1.2930149
9a: 1.2540367
9b: -0.03897822

pitch :
9b - 9a: -5.0748615
9a: 4.99835
9b: -0.07651138

position :
9b - 9a: -6.254874
9a: 2.3895078
9b: -3.8653665
------------------
1 :

beat :
9b - 9a: -13.497978
9a: 12.409351
9b: -1.0886269

duration :
9b - 9a: -0.94711494
9a: 5.624644
9b: 4.677529

pitch :
9b - 9a: -3.9298115
9a: 4.441007
9b: 0.51119566

position :
9b - 9a: -4.254596
9a: 3.7843647
9b: -0.47023153
------------------
2 :

beat :
9b - 9a: -15.548332
9a: 10.532637
9b: -5.0156956

duration :
9b - 9a: -2.1232486
9a: 1.7990894
9b: -0.3241592

pitch :
9b - 9a: -1.0726395
9a: 2.3911252
9b: 1.3184857

position :
9b - 9a: -6.0927744
9a: 2.8707662
9b: -3.2220082
------------------
3 :

beat :
9b - 9a: -14.186874
9a: 13.100759
9b: -1.0861156

duration :
9b - 9a: -0.92983747
9a: 2.3727117
9b: 1.4428742

pitch :
9b - 9a: -3.8526323
9a: 3.6773572
9b: -0.17527509

position :
9b - 9a: -3.891138
9a: 1
```

```python
import plotly.graph_objects as go

positive = {'beat': [], 'position': [], 'pitch': [], 'duration': []}
negative = {'beat': [], 'position': [], 'pitch': [], 'duration': []}

# per sample: difference between the xy (or xy_neg) model and the x model
for sample_id in sorted(logits.keys()):
    nine_a = calc_9(logits[sample_id], truths[sample_id])
    nine_b_pos = calc_9(logits_xy[sample_id], truths[sample_id])
    nine_b_neg = calc_9(logits_xy_neg[sample_id], truths[sample_id])
    for key in sorted(['beat', 'position', 'pitch', 'duration']):
        positive[key].append(nine_b_pos[key] - nine_a[key])
        negative[key].append(nine_b_neg[key] - nine_a[key])

# one grouped bar chart per field
for key in sorted(['beat', 'position', 'pitch', 'duration']):
    fig = go.Figure()
    fig.add_trace(go.Bar(y=positive[key], text=[round(i, 4) for i in positive[key]], name='positive'))
    fig.add_trace(go.Bar(y=negative[key], text=[round(i, 4) for i in negative[key]], name='negative'))
    fig.update_layout(title=key.capitalize(),
                      xaxis_title='Sample',
                      yaxis_title='D_(Y|X) - D_(Y)',
                      width=1000,
                      height=500,
                      font_size=18)
    fig.show()
```
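The bar charts show one value per sample. To condense them into one number per field, a small helper can report the mean difference and the share of samples where the difference is negative. `summarize_differences` is a hypothetical helper, not part of the original notebook; it would be called on the `positive` and `negative` dictionaries built in the plotting cell.

```python
import numpy as np


def summarize_differences(diffs):
    """For each field, return the mean of D_(Y|X) - D_(Y) over samples and the
    fraction of samples where that difference is negative."""
    return {
        key: {
            "mean": float(np.mean(values)),
            "negative_share": float(np.mean(np.asarray(values) < 0)),
        }
        for key, values in diffs.items()
    }
```

Usage: `summarize_differences(positive)` and `summarize_differences(negative)`.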