In [None]:
import os
import glob
import json
import random
import pickle
import jupyter_black

from utils.utils import *
from utils.metrics import *

from prdc import compute_prdc
from pytorch_lightning import seed_everything
from torch.nn.utils.rnn import pad_sequence

from loop_extraction.src.BERT import BERT_Lightning
from loop_extraction.src.utils.utils import dataset_split
from loop_extraction.src.utils.utils import folder_to_multiple_file

from loop_extraction.src.utils.remi import *
from loop_extraction.src.utils.vocab import *
from loop_extraction.src.utils.constants import *
from loop_extraction.src.utils.bpe_encode import MusicTokenizer

jupyter_black.load(line_length=100)

%load_ext autoreload
%autoreload 2

In [None]:
# Fix the global random seed (pytorch_lightning.seed_everything) before any
# shuffling or sampling performed by the cells below.
seed_everything(0)

In [None]:
def get_bins(KEY):
    """Look up the default bin table that belongs to an attribute key.

    Args:
        KEY: Attribute key; compared against TEMPO_KEY, VELOCITY_KEY and DURATION_KEY.

    Returns:
        DEFAULT_TEMPO_BINS, DEFAULT_VELOCITY_BINS or DEFAULT_DURATION_BINS for the
        matching key, or None when KEY equals none of the three.
    """
    if KEY == TEMPO_KEY:
        return DEFAULT_TEMPO_BINS
    if KEY == VELOCITY_KEY:
        return DEFAULT_VELOCITY_BINS
    if KEY == DURATION_KEY:
        return DEFAULT_DURATION_BINS

    # Keys without a bin table fall through to None.
    return None

In [None]:
def get_mean_value(tokens, KEY):
    """Average the values carried by every token whose text contains KEY.

    The integer after the last "_" of each matching token is read as an index.
    For PITCH_KEY the indices themselves are averaged; for any other key each
    index is first mapped through the bin table returned by get_bins(KEY).

    Args:
        tokens: Iterable of token strings.
        KEY: Substring selecting the tokens of interest.

    Returns:
        The np.mean of the selected values.
    """
    indices = []
    for token in tokens:
        if KEY in token:
            indices.append(int(token.split("_")[-1]))

    # Pitch indices are used as-is; they have no bin table.
    if KEY == PITCH_KEY:
        return np.mean(indices)

    bins = get_bins(KEY)
    values = [bins[index] for index in indices]
    return np.mean(values)

#### Controllability

In [None]:
# Directory holding one pickled sample per generated loop.
path = "/workspace/loop_generation/samples/GPT-medium_random_drop_0-cond_full"
# glob returns files in arbitrary order; sort so the rows of `results` are stable between runs.
test_files = sorted(glob.glob(os.path.join(path, "*")))

# One row per sample:
# [instrument Jaccard, |pitch|, |tempo|, |velocity|, |duration| errors, bar-length hit]
results = []
for file_path in test_files:
    # NOTE: pickle.load can execute arbitrary code - only read sample dumps you trust.
    with open(file_path, "rb") as f:
        x = pickle.load(f)

    # 1. instrument - Jaccard similarity between generated and conditioned instruments
    pred_inst = {token for token in x["loop"] if INSTRUMENT_KEY in token}
    true_inst = set(x["inst"])

    nom = len(pred_inst & true_inst)
    denom = len(pred_inst | true_inst)
    # Both sets empty means prediction and condition agree; avoid dividing by zero.
    inst = nom / denom if denom else 1.0

    # 2. mean_family - mean attribute values measured on the generated loop
    pred_results = [
        get_mean_value(x["loop"], key)
        for key in [PITCH_KEY, TEMPO_KEY, VELOCITY_KEY, DURATION_KEY]
    ]

    # Target values decoded from the conditioning tokens (same attribute order as above).
    true_results = []
    for token in [x["mean_pitch"], x["mean_tempo"], x["mean_velocity"], x["mean_duration"]]:
        # Split on the last "_" only, so a key that itself contains "_" still parses.
        key, index = token[0].rsplit("_", 1)
        # NOTE(review): assumes the prefix equals TEMPO_KEY / VELOCITY_KEY / DURATION_KEY
        # for binned attributes; otherwise the raw index is used - TODO confirm.
        bins = get_bins(key)

        if bins is not None:
            true_results.append(bins[int(index)])
        else:
            true_results.append(int(index))

    # distance - absolute error per attribute
    pred_results = np.array(pred_results)
    true_results = np.array(true_results)
    mean_family = np.abs(pred_results - true_results).tolist()

    # 3. bar_length - 1 when the number of bar tokens equals the conditioned length
    pred_length = sum(1 for token in x["loop"] if BAR_KEY in token)
    true_length = int(x["bar_length"][0].split("_")[-1])
    bar_length = 1 if pred_length == true_length else 0

    results.append([inst] + mean_family + [bar_length])

# Column-wise average over all samples.
print(f"Results : {np.mean(np.array(results), axis=0)}")

#### Precision & Recall

In [None]:
# load the hyper-parameter config used to build the BERT-stranger model
with open("./loop_extraction/config/config.json", "r") as f:
    config = json.load(f)

# pick the compute device: device index 0 when CUDA is available, otherwise the CPU
use_cuda = torch.cuda.is_available()
device = torch.device(0 if use_cuda else "cpu")

# tokenizer paths (bpe_meta_path is not used in the cells shown here)
bpe_path = "./loop_extraction/tokenizer/tokenizer.json"
bpe_meta_path = "./loop_extraction/tokenizer/tokenizer_meta.json"

# tokenizer built from the BPE file; get_feat uses it to encode REMI events
tokenizer = MusicTokenizer(bpe_path)

In [None]:
#### load datasets
folder_path = "./data/"
datasets = ["lmd_full_loop_1024", "meta_midi_loop_1024"]

folder_list = []
for dataset in datasets:
    # glob returns entries in arbitrary (filesystem) order; sort each listing so the
    # seeded shuffle below produces the same permutation on every machine.
    # NOTE: the resulting split can differ from one produced from an unsorted listing.
    folder_list += sorted(glob.glob(os.path.join(folder_path, dataset, "*")))

random.shuffle(folder_list)

#### split song into train, val, test
train_folder, val_folder, test_folder = dataset_split(folder_list, train_ratio=0.98, val_ratio=0.01)

#### get file_path of each dataset
# k=1 presumably selects one file per folder - TODO confirm against folder_to_multiple_file
train_files = folder_to_multiple_file(train_folder, k=1)
# keep at most as many real files as there are generated samples (`test_files` comes
# from the controllability cell above)
train_files = train_files[: len(test_files)]

print(f"train_files : {len(train_files)}")

In [None]:
# define model
# Build BERT_Lightning from the loaded config: the per-head width is dim / heads
# and the MLP width is 4 * dim.
model = BERT_Lightning(
    dim=config["dim"],
    depth=config["depth"],
    heads=config["heads"],
    dim_head=int(config["dim"] / config["heads"]),
    mlp_dim=int(4 * config["dim"]),
    max_len=config["max_len"],
    rate=config["rate"],
    bpe_path=bpe_path,
)

In [None]:
# Restore the trained weights and move the model to the selected device.
# load_from_checkpoint is a classmethod: call it on the class, not on an instance
# (recent Lightning releases raise an error when it is invoked on an instance).
model = BERT_Lightning.load_from_checkpoint(
    "./loop_extraction//model/BERT-stranger.ckpt",
    dim=config["dim"],
    depth=config["depth"],
    heads=config["heads"],
    dim_head=int(config["dim"] / config["heads"]),
    mlp_dim=int(4 * config["dim"]),
    max_len=config["max_len"],
    rate=config["rate"],
    bpe_path=bpe_path,
).to(device)

In [None]:
def get_feat(files):
    """Extract BERT features for every pickled loop in `files`.

    Each file is unpickled and its "loop" event list is cut into one context per
    bar token, each capped at MAX_TOKEN_LEN - 1 events. Every context gets an
    EOB_TOKEN appended when it has none, is encoded with the global `tokenizer`,
    and the contexts of one loop are padded into a batch and fed to the global
    `model`. The second output of the model is collected for each loop.

    Files whose event list contains no bar token are skipped with a warning
    instead of raising an IndexError.

    Args:
        files: Iterable of paths to pickle files holding a mapping with a "loop" entry.

    Returns:
        np.ndarray: the collected model outputs stacked along the first axis
        (presumably one row per bar - TODO confirm against the model's forward).

    Raises:
        ValueError: raised by np.vstack when no file produced any features.
    """
    import warnings

    # Loop-invariant setup, done once instead of once per file.
    # NOTE(review): padding uses the RemiVocab PAD index although the sequences are
    # tokenizer (BPE) ids - TODO confirm both vocabularies share that index.
    pad_idx = RemiVocab().to_i(PAD_TOKEN)
    max_context = MAX_TOKEN_LEN - 1
    model.eval()

    results = []

    for file_path in files:
        # NOTE: pickle.load can execute arbitrary code - only read files you trust.
        with open(file_path, "rb") as f:
            events = pickle.load(f)["loop"]

        bars = [pos for pos, event in enumerate(events) if f"{BAR_KEY}_" in event]
        if not bars:
            warnings.warn(f"no bar token found, skipping: {file_path}")
            continue

        # Each context runs from one bar token to the next; the last one runs to the end.
        bounds = zip(bars, bars[1:] + [len(events)])

        music = []
        for start, end in bounds:
            # Truncate over-long bars so one slot remains for the EOB token.
            bar = events[start : min(end, start + max_context)]

            if EOB_TOKEN not in bar:
                bar = bar + [EOB_TOKEN]

            # REMI to BPE tokens
            bar = tokenizer.encode(bar)
            bar = torch.tensor(bar, dtype=torch.long).to(device)
            music.append(bar)

        music = pad_sequence(music, batch_first=True, padding_value=pad_idx)

        with torch.no_grad():
            _, h = model(music)

        results.append(h.detach().cpu().numpy())

    return np.vstack(results)

In [None]:
# get features
# Run the real (training) loops and the generated loops through the same get_feat pipeline.
train_feat = get_feat(train_files)
gen_feat = get_feat(test_files)

# len() counts stacked feature rows returned by get_feat, not files.
print(f"train_feat : {len(train_feat)}, gen_feat : {len(gen_feat)}")

In [None]:
# get precision & recall & density & coverage
# Real features are the training loops, fake features are the generated loops;
# neighbourhoods are built from the 5 nearest neighbours.
metrics = compute_prdc(real_features=train_feat, fake_features=gen_feat, nearest_k=5)
print(f"prdc : {metrics}")

In [None]:
# get FID
# compute_fid is not imported by name here; presumably it comes from the
# `utils.metrics` star import - TODO confirm.
fid = compute_fid(train_feat, gen_feat)
print(f"fid : {fid}")