In [None]:
import os
import glob
import json
import random
import jupyter_black
import IPython.display

from GPT import *
from utils.data import *
from utils.utils import *
from utils.metrics import *

from scipy.io.wavfile import write
from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything

from loop_extraction.src.utils.utils import dataset_split
from loop_extraction.src.utils.utils import folder_to_file
from loop_extraction.src.utils.utils import folder_to_multiple_file

from loop_extraction.src.utils.remi import *
from loop_extraction.src.utils.constants import *
from loop_extraction.src.utils.bpe_encode import MusicTokenizer

# Auto-format every executed cell with Black, using a 100-character line length.
jupyter_black.load(line_length=100)

# Reload imported modules before each cell execution so that edits to the
# local project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
#### load config
with open("./config/config.json", "r") as f:
    config = json.load(f)

#### set seed
seed_everything(config["random_seed"])
torch.set_float32_matmul_precision("medium")

#### select the device used for inference
# The CUDA ordinal used to be hard-coded to 5, which only works on hosts with
# at least six GPUs. It can now be overridden with an optional "gpu_index"
# entry in the config (default 5, i.e. the previous behaviour); when that
# ordinal does not exist on this machine we fall back to the first GPU
# instead of failing later with an invalid-device error.
gpu_index = int(config.get("gpu_index", 5))
use_cuda = torch.cuda.is_available()
if not use_cuda:
    device = torch.device("cpu")
elif 0 <= gpu_index < torch.cuda.device_count():
    device = torch.device("cuda", gpu_index)
else:
    print(f"GPU {gpu_index} is not available on this machine; falling back to cuda:0")
    device = torch.device("cuda", 0)

#### load tokenizer
bpe_path = "./loop_extraction/tokenizer/tokenizer_meta.json"
tokenizer = MusicTokenizer(bpe_path)

# Ids of the special tokens: each one is encoded as a one-element sequence and
# the first (only) id is kept.
bar_idx = tokenizer.encode([BAR_TOKEN])[0]
pad_idx = tokenizer.encode([PAD_TOKEN])[0]
eob_idx = tokenizer.encode([EOB_TOKEN])[0]

# NOTE(review): the "+ 1" presumably reserves an id for START_TOKEN, which is
# added to the tokenizer later in the notebook - confirm against training code.
vocab_size = tokenizer.bpe_vocab.get_vocab_size() + 1

In [None]:
#### load datasets
# Each dataset directory name ends with the maximum sequence length from the config.
folder_path = "/workspace/loop_generation_old/data/"
max_length_suffix = str(config["max_length"])
datasets = [prefix + max_length_suffix for prefix in ("lmd_full_loop_", "meta_midi_loop_")]

# Collect every entry found directly inside each dataset directory.
folder_list = []
for dataset in datasets:
    pattern = os.path.join(folder_path, dataset, "*")
    folder_list.extend(glob.glob(pattern))

# NOTE(review): glob returns entries in filesystem order, so this shuffle (and
# therefore the split below) is only reproducible for a given seed on the same
# directory layout - confirm it matches the split used during training.
random.shuffle(folder_list)

#### split song into train, val, test
split = dataset_split(folder_list, train_ratio=0.98, val_ratio=0.01)
train_folder, val_folder, test_folder = split

#### get file_path of each dataset
# Only the test split is expanded into file paths in this notebook.
test_files = folder_to_multiple_file(test_folder, k=3)
print(f"test_files : {len(test_files)}")

In [None]:
#### load model skeleton
# Architecture hyper-parameters are taken verbatim from the config file.
gpt_kwargs = {
    key: config[key]
    for key in ("dim_model", "num_layers", "num_heads", "multiplier", "dropout", "max_length")
}
model = GPT(vocab_size, pad_idx, **gpt_kwargs)

# Report the model size as the total element count over all parameter tensors.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"The number of parameters : {sizeof_number(total_params)}")

In [None]:
#### load pre-trained model
folder_path = "/workspace/loop_generation_old/model"
model_path = "version_GPT-medium-random_drop_0.5-max_length_1024-epoch=4-val_loss=0.4794.ckpt"

# `load_from_checkpoint` is a classmethod that builds and returns a NEW model.
# It must be called on the class: recent Lightning releases raise a TypeError
# when it is invoked on an instance, as this cell used to do.
model = GPT.load_from_checkpoint(
    os.path.join(folder_path, model_path),
    map_location=device,
    vocab_size=vocab_size,
    pad_idx=pad_idx,
    dim_model=config["dim_model"],
    num_layers=config["num_layers"],
    num_heads=config["num_heads"],
    multiplier=config["multiplier"],
    dropout=config["dropout"],
    max_length=config["max_length"],
)

# This notebook only runs inference, so switch off dropout and other
# train-time behaviour. Harmless if `generate` already does this itself.
model.eval()

#### Save Generated Samples

In [None]:
def trim_tails(loop, end_time, fs):
    """Synthesize ``loop`` to audio and cut off most of the tail after ``end_time``.

    The first ``int(fs * end_time)`` samples are always kept. If the
    synthesized audio is longer than that, a random fraction (between 0 and
    20 %) of the remaining samples is kept as well.

    Args:
        loop: Object exposing ``fluidsynth(fs=...)`` that returns an array-like
            with a ``shape`` attribute (presumably a pretty_midi object
            produced by ``remi2midi`` - confirm).
        end_time: Nominal duration of the loop; multiplied by ``fs`` to get a
            sample count.
        fs: Sample rate passed to the synthesizer.

    Returns:
        The leading slice of the synthesized audio.
    """
    keep = int(fs * end_time)

    audio = loop.fluidsynth(fs=fs)
    tail = audio.shape[0] - keep

    # The random draw happens unconditionally, so the global RNG advances by
    # one step per call whether or not there is a tail to keep.
    fraction = random.uniform(0, 0.2)
    extra = 0
    if tail > 0:
        extra = int(tail * fraction)

    return audio[: keep + extra]

In [None]:
# Sample rate used to render the loops to audio.
fs = 44100.0

gen_path = os.path.join("./evaluation", "gen_sample_cond_full")
true_path = os.path.join("./evaluation", "extension1")

# Number of test files to process; set to None to process all of them.
# This replaces a leftover debugging `break` that silently stopped the loop
# after the first file and made the progress report below unreachable.
max_files = 1
# Set to True to write the rendered audio to disk.
save_audio = False

tokenizer.add_tokens([START_TOKEN])
start_token = tokenizer.encode_meta(START_TOKEN)

if save_audio:
    # scipy's `write` does not create missing directories.
    os.makedirs(gen_path, exist_ok=True)
    os.makedirs(true_path, exist_ok=True)

for i, file_path in enumerate(test_files):
    if max_files is not None and i >= max_files:
        break

    # File name without directory and without anything after the first dot.
    file_name = os.path.basename(file_path).split(".")[0]

    # NOTE(security): pickle.load can execute arbitrary code - only run this
    # on dataset files you produced yourself.
    with open(file_path, "rb") as f:
        x = pickle.load(f)

    ### generate conditions
    # Prompt layout: start token, condition tokens, then a bar token.
    cond = convert_cond(x, tokenizer, random_drop=0)
    cond = start_token + cond + [bar_idx]
    cond = torch.tensor(cond, dtype=torch.long)

    # Alternative prompt that conditions on the instrument tokens only:
    # cond = sorted(tokenizer.encode(x["inst"]))
    # cond = start_token + cond + [bar_idx]
    # cond = torch.tensor(cond, dtype=torch.long)

    ## generate samples
    gen_loop = generate(
        cond,
        model,
        device,
        [bar_idx, eob_idx],
        temp=1,
        top_k=10,
        sample=True,
        max_length=config["max_length"],
    )

    # The reference loop is repeated twice before being converted.
    ori_loop_remi, ori_end_time = remi2midi(x["loop"] + x["loop"])
    gen_loop_remi, gen_end_time = remi2midi(gen_loop)

    ori_loop_fs = trim_tails(ori_loop_remi, ori_end_time, fs)
    gen_loop_fs = trim_tails(gen_loop_remi, gen_end_time, fs)

    # The arrays hold rendered audio that is written with scipy's WAV writer,
    # so the files get a ".wav" extension (they were previously named ".mid").
    ori_file_path = os.path.join(true_path, "true_loop_" + file_name) + ".wav"
    gen_file_path = os.path.join(gen_path, "gen_cond_full_" + file_name) + ".wav"

    if save_audio:
        write(ori_file_path, int(fs), ori_loop_fs.astype("float32"))
        write(gen_file_path, int(fs), gen_loop_fs.astype("float32"))

    if i % 100 == 0:
        print(f"We are processing the {i}th file!")

In [None]:
# One sample rate for both synthesis and playback. The loops used to be
# rendered with `fs` from the generation cell while being played back with
# `rate`; if the two ever diverged the audio would play at the wrong
# speed and pitch.
rate = 44100.0

# `x` and `gen_loop` are the sample and generated tokens left over from the
# last iteration of the generation cell, so that cell must be run first.
ori_loop_remi, ori_end_time = remi2midi(x["loop"])
gen_loop_remi, gen_end_time = remi2midi(gen_loop)

ori_loop_fs = trim_tails(ori_loop_remi, ori_end_time, rate)
gen_loop_fs = trim_tails(gen_loop_remi, gen_end_time, rate)

# Original loop first, generated loop second.
IPython.display.display(IPython.display.Audio(ori_loop_fs, rate=rate))
IPython.display.display(IPython.display.Audio(gen_loop_fs, rate=rate))