# Query-and-reArrange

* Tutorial notebook for Zhao et al., [Q&A: Query-Based Representation Learning for Multi-Track Symbolic Music re-Arrangement](https://arxiv.org/abs/2306.01635), accepted at the IJCAI 2023 Special Track on AI, the Arts and Creativity.

* Based on composition style transfer, Q&A is a generic model for a range of symbolic rearrangement problems, including 1) **orchestration**, 2) **piano cover generation**, 3) **re-instrumentation**, and 4) **voice separation**. We will demonstrate each case in this notebook.

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import torch
from torch.utils.data import DataLoader
from model import Query_and_reArrange
from dataset import Slakh2100_Pop909_Dataset, collate_fn_inference, EMBED_PROGRAM_MAPPING
SLAKH_CLASS_MAPPING = {v: k for k, v in EMBED_PROGRAM_MAPPING.items()}
from utils.format_convert import matrix2midi_with_dynamics, dataitem2midi
from utils.inferring import mixture_function_prior, search_reference, velocity_adaption
import datetime
import json
import warnings
warnings.filterwarnings("ignore")

## 1. Symbolic multi-track music rearrangement

* In the following, we apply Q&A to perform **rearrangement** on $8$-bar music samples (i.e., `SAMPLE_BAR_LEN`=$8$). Demos are saved to `./demo`.

* We first sample a *source* piece $x$ as the donor of content, and then sample a *reference* piece $y$ as the donor of track functions (style). We then apply Q&A to generate the *target* piece $\hat{x}$, a rearrangement of $x$ in the style of $y$.

* We load the datasets with `debug_mode=True`, which reads only a small portion of each dataset. If you have sufficient RAM, you may turn this setting off to generate more diverse results.
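
* The dataset constructors in this notebook receive `16*SAMPLE_BAR_LEN` as the sample length. A minimal sketch of that arithmetic follows, assuming the factor 16 is the number of time steps per bar (for example, 16th-note resolution in 4/4). The names `STEPS_PER_BAR` and `sample_length_in_steps` are hypothetical and not part of the repository.

```python
STEPS_PER_BAR = 16  # assumption: inferred from the `16*SAMPLE_BAR_LEN` argument


def sample_length_in_steps(n_bars, steps_per_bar=STEPS_PER_BAR):
    """Return the number of time steps covered by `n_bars` bars."""
    if n_bars <= 0:
        raise ValueError("n_bars must be positive")
    return n_bars * steps_per_bar


print(sample_length_in_steps(8))  # 128
```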

In [2]:
POP909_DIR = "./data/POP909"
SLAKH2100_DIR = "./data/Slakh2100"
with open("./data/slakh_melody_check.json", 'r') as f:
    MEL_CHECK = json.load(f)
SAVE_DIR = './demo'

SAMPLE_BAR_LEN = 8

MODEL_DIR = "./checkpoints/Q&A_epoch_029.pt"
DEVICE = 'cuda:0'
model = Query_and_reArrange(name='inference_model', device=DEVICE, trf_layers=2)
model.load_state_dict(torch.load(MODEL_DIR, map_location='cpu'))
model.to(DEVICE)
model.eval();

### 1.1 Orchestration

* For orchestration, we sample a piano clip $x$ from POP909 and a multi-track clip $y$ from Slakh2100, and then orchestrate $x$ using $y$'s style.

In [3]:
# load piano dataset. A piano piece x is the donor of content.
x_set = Slakh2100_Pop909_Dataset(None, POP909_DIR, 16*SAMPLE_BAR_LEN, debug_mode=True, split='validation', mode='inference', with_dynamics=True)
# load multi-track dataset. A multi-track piece y is the donor of style.
y_set = Slakh2100_Pop909_Dataset(SLAKH2100_DIR, None, 16*SAMPLE_BAR_LEN, debug_mode=True, split='validation', mode='inference', with_dynamics=True)
# Prepare for the heuristic sampling of y
y_set_loader = DataLoader(y_set, batch_size=1, shuffle=False, collate_fn=lambda b: collate_fn_inference(b, DEVICE))
y_prior_set = mixture_function_prior(y_set_loader)

loading Pop909 Dataset ...


100%|██████████| 10/10 [00:00<00:00, 63.68it/s]


loading Slakh2100 Dataset ...


100%|██████████| 10/10 [00:00<00:00, 22.61it/s]


Rendering sample space for style references ...


100%|██████████| 857/857 [00:03<00:00, 240.04it/s]


* Sampling source piece $x$ from POP909

In [4]:
# get a random x sample
IDX = np.random.randint(len(x_set))
x = x_set.__getitem__(IDX)
(x_mix, x_instr, x_fp, x_ft), x_dyn, x_dir = collate_fn_inference(batch = [(x)], device = DEVICE)
# save x
save_path = os.path.join(SAVE_DIR, f"orchestration-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]}")
os.makedirs(save_path, exist_ok=True)
x_recon = dataitem2midi(*x, SLAKH_CLASS_MAPPING)
x_recon.write(os.path.join(save_path, '01_source.mid'))
print(f'saved to {save_path}.')

saved to ./demo\orchestration-230604164637.
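
* Every demo in this notebook builds a timestamped output folder in the same way. The sketch below wraps that pattern in a helper; `make_demo_dir` is a hypothetical name, not a function from the repository. It uses `%y` for the two-digit year, which yields the same digits as slicing `[2:]` off a `%Y`-formatted string.

```python
import datetime
import os
import tempfile


def make_demo_dir(root, task, now=None):
    """Create and return `<root>/<task>-<YYMMDDHHMMSS>`."""
    now = now or datetime.datetime.now()
    stamp = now.strftime('%y%m%d%H%M%S')  # two-digit year, then month .. second
    path = os.path.join(root, f"{task}-{stamp}")
    os.makedirs(path, exist_ok=True)  # no error if the folder already exists
    return path


# Usage with a temporary root so that nothing is left on disk.
with tempfile.TemporaryDirectory() as root:
    fixed = datetime.datetime(2023, 6, 4, 16, 46, 37)
    print(os.path.basename(make_demo_dir(root, "orchestration", fixed)))
    # orchestration-230604164637
```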


* Calling Q&A for **orchestration** after sampling reference piece $y$

In [5]:
# heuristic sampling for y (i.e., Equation (8) in the paper)
y_anchor = search_reference(x_fp, x_ft, y_prior_set)
y = y_set.__getitem__(y_anchor)
(y_mix, y_instr, y_fp, y_ft), y_dyn, y_dir = collate_fn_inference(batch=[(y)], device=DEVICE)
# replace y's melody track function with x's in order to preserve the theme melody after rearrangement.
x_mel, y_mel = 0, MEL_CHECK[y_dir.replace('\\', '/').split('/')[-1].replace('.npz', '')]
y_fp[:, y_mel] = x_fp[:, x_mel]
y_ft[:, y_mel] = x_ft[:, x_mel]
# save y
y_recon = dataitem2midi(*y, SLAKH_CLASS_MAPPING)
y_recon.write(os.path.join(save_path, '02_reference.mid'))

# Q&A model inference
output = model.inference(x_mix, y_instr, y_fp, y_ft, mel_id=y_mel)
# apply y's dynamics to the rearrangement result
velocity = velocity_adaption(y_dyn[..., 0], output, y_mel)
cc = y_dyn[..., 1]
output = np.stack([output, velocity, cc], axis=-1)
# reconstruct MIDI
midi_recon = matrix2midi_with_dynamics(
    matrices=output, 
    programs=[SLAKH_CLASS_MAPPING[item.item()] for item in y_instr[0]], 
    init_tempo=100)
midi_recon.write(os.path.join(save_path, '03_target.mid'))
print(f'saved to {save_path}.')

saved to ./demo\orchestration-230604164637.
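
* The melody index of the reference is looked up in `MEL_CHECK` by file name, after normalizing path separators and dropping the `.npz` extension. A minimal sketch of that key extraction is given below. The helper name `melody_check_key`, the table `mel_check`, and the example path are hypothetical; we assume the table maps a file stem to the index of its melody track, as the indexing in the cell above suggests.

```python
def melody_check_key(path):
    """Return the file name of `path` without directories or a trailing '.npz'."""
    name = path.replace('\\', '/').split('/')[-1]  # accept both separators
    return name[:-len('.npz')] if name.endswith('.npz') else name


# Hypothetical lookup table and path, for illustration only.
mel_check = {"Track00001": 2}
y_dir = "data\\Slakh2100\\validation\\Track00001.npz"
y_mel = mel_check[melody_check_key(y_dir)]
print(y_mel)  # 2
```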


### 1.2 Piano Cover Generation

* For piano cover generation, we sample a multi-track clip $x$ from Slakh2100 and a piano clip $y$ from POP909, and then rearrange $x$ using $y$'s textures.

In [6]:
# load multi-track dataset. A multi-track piece x is the donor of content.
x_set = Slakh2100_Pop909_Dataset(SLAKH2100_DIR, None, 16*SAMPLE_BAR_LEN, debug_mode=True, split='validation', mode='inference', with_dynamics=True)
# load piano dataset. A piano piece y is the donor of style.
y_set = Slakh2100_Pop909_Dataset(None, POP909_DIR, 16*SAMPLE_BAR_LEN, debug_mode=True, split='validation', mode='inference', with_dynamics=True)
# Prepare for the heuristic sampling of y
y_set_loader = DataLoader(y_set, batch_size=1, shuffle=False, collate_fn=lambda b: collate_fn_inference(b, DEVICE))
y_prior_set = mixture_function_prior(y_set_loader)

loading Slakh2100 Dataset ...


100%|██████████| 10/10 [00:00<00:00, 20.86it/s]


loading Pop909 Dataset ...


100%|██████████| 10/10 [00:00<00:00, 59.51it/s]


Rendering sample space for style references ...


100%|██████████| 727/727 [00:02<00:00, 267.19it/s]


* Sampling source piece $x$ from Slakh2100

In [7]:
# get a random x sample
IDX = np.random.randint(len(x_set))
x = x_set.__getitem__(IDX)
(x_mix, x_instr, x_fp, x_ft), x_dyn, x_dir = collate_fn_inference(batch = [(x)], device = DEVICE)
# save x
save_path = os.path.join(SAVE_DIR, f"pianocover-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]}")
os.makedirs(save_path, exist_ok=True)
x_recon = dataitem2midi(*x, SLAKH_CLASS_MAPPING)
x_recon.write(os.path.join(save_path, '01_source.mid'))
print(f'saved to {save_path}.')

saved to ./demo\pianocover-230604164651.


* Calling Q&A for **piano cover generation** after sampling $y$

In [8]:
# heuristic sampling for y (i.e., Equation (8) in the paper)
y_anchor = search_reference(x_fp, x_ft, y_prior_set)
y = y_set.__getitem__(y_anchor)
(y_mix, y_instr, y_fp, y_ft), y_dyn, y_dir = collate_fn_inference(batch=[(y)], device=DEVICE)
# replace y's melody track function with x's in order to preserve the theme melody after rearrangement.
x_mel, y_mel = MEL_CHECK[x_dir.replace('\\', '/').split('/')[-1].replace('.npz', '')], 0
y_fp[:, y_mel] = x_fp[:, x_mel]
y_ft[:, y_mel] = x_ft[:, x_mel]
# save y
y_recon = dataitem2midi(*y, SLAKH_CLASS_MAPPING)
y_recon.write(os.path.join(save_path, '02_reference.mid'))

# Q&A model inference
output = model.inference(x_mix, y_instr, y_fp, y_ft, mel_id=y_mel)
# apply y's dynamics to the rearrangement result
velocity = velocity_adaption(y_dyn[..., 0], output, y_mel)
cc = y_dyn[..., 1]
output = np.stack([output, velocity, cc], axis=-1)
# reconstruct MIDI
midi_recon = matrix2midi_with_dynamics(
    matrices=output, 
    programs=[SLAKH_CLASS_MAPPING[item.item()] for item in y_instr[0]], 
    init_tempo=100)
midi_recon.write(os.path.join(save_path, '03_target.mid'))
print(f'saved to {save_path}.')

saved to ./demo\pianocover-230604164651.
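
* Before inference, the melody track function of $x$ is copied into the melody slot of $y$, so that the theme melody survives rearrangement. The sketch below mirrors `y_fp[:, y_mel] = x_fp[:, x_mel]` on plain nested lists laid out as `[batch][track]`. This layout is an assumption read off the indexing; the notebook itself operates on tensors, and `transplant_melody_function` is a hypothetical name.

```python
def transplant_melody_function(y_func, x_func, y_mel, x_mel):
    """Overwrite y's melody slot with x's melody function, per batch item."""
    for y_item, x_item in zip(y_func, x_func):
        y_item[y_mel] = x_item[x_mel]
    return y_func


# Toy example: strings stand in for per-track function vectors.
x_func = [["x_bass", "x_melody", "x_pad"]]  # melody of x is track 1
y_func = [["y_track0", "y_track1"]]         # melody slot of y is track 0
print(transplant_melody_function(y_func, x_func, y_mel=0, x_mel=1))
# [['x_melody', 'y_track1']]
```

Only the melody slot changes; every other track of $y$ keeps its own function.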


### 1.3 Re-Instrumentation

* For re-instrumentation, we sample multi-track clips $x$ and $y$ both from Slakh2100, and then rearrange $x$ using $y$'s style.

In [9]:
# load multi-track dataset. A multi-track piece x is the donor of content.
x_set = Slakh2100_Pop909_Dataset(SLAKH2100_DIR, None, 16*SAMPLE_BAR_LEN, debug_mode=True, split='test', mode='inference', with_dynamics=True)
# load multi-track dataset. A multi-track piece y is the donor of style.
y_set = Slakh2100_Pop909_Dataset(SLAKH2100_DIR, None, 16*SAMPLE_BAR_LEN, debug_mode=True, split='validation', mode='inference', with_dynamics=True)
# Prepare for the heuristic sampling of y
y_set_loader = DataLoader(y_set, batch_size=1, shuffle=False, collate_fn=lambda b: collate_fn_inference(b, DEVICE))
y_prior_set = mixture_function_prior(y_set_loader)

loading Slakh2100 Dataset ...


100%|██████████| 10/10 [00:00<00:00, 12.90it/s]


loading Slakh2100 Dataset ...


100%|██████████| 10/10 [00:00<00:00, 18.16it/s]


Rendering sample space for style references ...


100%|██████████| 857/857 [00:04<00:00, 186.59it/s]


* Sampling source piece $x$ from Slakh2100

In [10]:
# get a random x sample
IDX = np.random.randint(len(x_set))
x = x_set.__getitem__(IDX)
(x_mix, x_instr, x_fp, x_ft), x_dyn, x_dir = collate_fn_inference(batch = [(x)], device = DEVICE)
# save x
save_path = os.path.join(SAVE_DIR, f"reinstrumentation-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]}")
os.makedirs(save_path, exist_ok=True)
x_recon = dataitem2midi(*x, SLAKH_CLASS_MAPPING)
x_recon.write(os.path.join(save_path, '01_source.mid'))
print(f'saved to {save_path}.')

saved to ./demo\reinstrumentation-230604164702.


* Calling Q&A for **re-instrumentation** after sampling $y$

In [11]:
# heuristic sampling for y (i.e., Equation (8) in the paper)
y_anchor = search_reference(x_fp, x_ft, y_prior_set)
y = y_set.__getitem__(y_anchor)
(y_mix, y_instr, y_fp, y_ft), y_dyn, y_dir = collate_fn_inference(batch=[(y)], device=DEVICE)
# replace y's melody track function with x's in order to preserve the theme melody after rearrangement.
x_mel, y_mel = MEL_CHECK[x_dir.replace('\\', '/').split('/')[-1].replace('.npz', '')], MEL_CHECK[y_dir.replace('\\', '/').split('/')[-1].replace('.npz', '')]
y_fp[:, y_mel] = x_fp[:, x_mel]
y_ft[:, y_mel] = x_ft[:, x_mel]
# save y
y_recon = dataitem2midi(*y, SLAKH_CLASS_MAPPING)
y_recon.write(os.path.join(save_path, '02_reference.mid'))

# Q&A model inference
output = model.inference(x_mix, y_instr, y_fp, y_ft, mel_id=y_mel)
# apply y's dynamics to the rearrangement result
velocity = velocity_adaption(y_dyn[..., 0], output, y_mel)
cc = y_dyn[..., 1]
output = np.stack([output, velocity, cc], axis=-1)
# reconstruct MIDI
midi_recon = matrix2midi_with_dynamics(
    matrices=output, 
    programs=[SLAKH_CLASS_MAPPING[item.item()] for item in y_instr[0]], 
    init_tempo=100)
midi_recon.write(os.path.join(save_path, '03_target.mid'))
print(f'saved to {save_path}.')

saved to ./demo\reinstrumentation-230604164702.


## 2. Voice separation

* By inferring track functions as voice hints, Q&A can additionally handle voice separation.
* We assume a preset total number of voices, which is 4 in our case.
* In the following, we apply Q&A for voice separation on Bach chorales and string quartets. Demos are saved under `./demo`, in folders prefixed with `voiceseparation-`.
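
* To make the input and output of the task concrete, here is a toy baseline that is **not** Q&A: at each time step it ranks the sounding pitches and assigns them to voices from highest to lowest. It only illustrates the mapping from one mixture to four voice tracks. The function name and the list-based data layout are hypothetical.

```python
def split_by_pitch_rank(mixture, n_voices=4):
    """Naive voice separation by pitch rank.

    mixture: list of time steps, each a list of MIDI pitches sounding then.
    Returns `n_voices` lists; voice 0 holds the highest pitch of each step,
    and None marks a rest.
    """
    voices = [[] for _ in range(n_voices)]
    for pitches in mixture:
        ranked = sorted(set(pitches), reverse=True)[:n_voices]
        for v in range(n_voices):
            voices[v].append(ranked[v] if v < len(ranked) else None)
    return voices


mixture = [[60, 67, 72, 48], [62, 65, 71, 50], [64, 72]]
for voice in split_by_pitch_rank(mixture):
    print(voice)
# [72, 71, 72]
# [67, 65, 64]
# [60, 62, None]
# [48, 50, None]
```

Such a rule breaks down as soon as voices cross or rest, which is where hints inferred from the music itself become useful.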

In [12]:
from model import Query_and_reArrange_vocie_separation
from dataset import Voice_Separation_Dataset
from utils.format_convert import matrix2midi, mixture2midi

### 2.1 Bach chorales

* Loading Bach Chorales dataset

In [13]:
BACH_DIR = "./data/Bach_Chorales"
QUARTETS_DIR = None
SAVE_DIR = './demo'

DEVICE = 'cuda:0'
MODEL_DIR = "./checkpoints/Q&A_chorales_epoch_041.pt"
model = Query_and_reArrange_vocie_separation(name='inference_model', device=DEVICE, trf_layers=2)
model.load_state_dict(torch.load(MODEL_DIR, map_location='cpu'))
model.to(DEVICE)
model.eval();

x_set = Voice_Separation_Dataset(BACH_DIR, QUARTETS_DIR, 'full', split='validation', mode='inference')

loading Bach Chorale Dataset ...


100%|██████████| 41/41 [00:00<00:00, 99.02it/s]


* Sample a mixture

In [14]:
# get a random x sample
IDX = np.random.randint(len(x_set))
x = x_set.__getitem__(IDX)
(x_mix, x_instr, _, _), _, x_dir = collate_fn_inference(batch = [(x)], device = DEVICE)
# save mixture
save_path = os.path.join(SAVE_DIR, f"voiceseparation-{x_dir.replace('.npz', '')}-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]}")
os.makedirs(save_path, exist_ok=True)
x_recon = mixture2midi(x_mix)
x_recon.write(os.path.join(save_path, '01_source.mid'))
print(f'saved to {save_path}.')

saved to ./demo\voiceseparation-bwv430-230604164718.


* Calling Q&A for **voice separation**

In [15]:
output = model.inference(x_mix, x_instr)
midi_recon = matrix2midi(output, programs=[52]*4, init_tempo=100)
midi_recon.write(os.path.join(save_path, '02_target.mid'))
print(f'saved to {save_path}.')

saved to ./demo\voiceseparation-bwv430-230604164718.


### 2.2 String Quartets

* Loading String Quartets dataset

In [16]:
BACH_DIR = None
QUARTETS_DIR = './data/String_Quartets'
MODEL_DIR = "./checkpoints/Q&A_quartets_epoch_029.pt"
SAVE_DIR = './demo'

DEVICE = 'cuda:0'
model = Query_and_reArrange_vocie_separation(name='inference_model', device=DEVICE, trf_layers=2)
model.load_state_dict(torch.load(MODEL_DIR, map_location='cpu'))
model.to(DEVICE)
model.eval();

x_set = Voice_Separation_Dataset(BACH_DIR, QUARTETS_DIR, 'full', split='validation', mode='inference')

loading String Quartets Dataset ...


100%|██████████| 6/6 [00:00<00:00, 33.86it/s]


* Sample a mixture

In [17]:
# get a random x sample
IDX = np.random.randint(len(x_set))
x = x_set.__getitem__(IDX)
(x_mix, x_instr, _, _), _, x_dir = collate_fn_inference(batch = [(x)], device = DEVICE)
# save mixture
save_path = os.path.join(SAVE_DIR, f"voiceseparation-{x_dir.replace('.npz', '')}-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')[2:]}")
os.makedirs(save_path, exist_ok=True)
x_recon = mixture2midi(x_mix)
x_recon.write(os.path.join(save_path, '01_source.mid'))
print(f'saved to {save_path}.')

saved to ./demo\voiceseparation-Ravel-2180_gr_rqtf4-230604164725.


* Calling Q&A for **voice separation**

In [18]:
output = model.inference(x_mix, x_instr)
midi_recon = matrix2midi(output, programs=[40, 40, 41, 42], init_tempo=100)
midi_recon.write(os.path.join(save_path, '02_target.mid'))
print(f'saved to {save_path}.')

saved to ./demo\voiceseparation-Ravel-2180_gr_rqtf4-230604164725.
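
* The `programs` arguments above select the instrument of each separated voice. Assuming they are zero-indexed General MIDI program numbers (an assumption about `matrix2midi`, which this notebook does not show), the values used for the chorales and the quartets read as follows. `GM_PROGRAM_NAMES` and `describe_programs` are hypothetical names for illustration.

```python
# Zero-indexed General MIDI program numbers used in this notebook.
GM_PROGRAM_NAMES = {
    40: "Violin",
    41: "Viola",
    42: "Cello",
    52: "Choir Aahs",
}


def describe_programs(programs):
    """Map program numbers to instrument names, keeping unknown ones visible."""
    return [GM_PROGRAM_NAMES.get(p, f"program {p}") for p in programs]


print(describe_programs([52] * 4))
# ['Choir Aahs', 'Choir Aahs', 'Choir Aahs', 'Choir Aahs']
print(describe_programs([40, 40, 41, 42]))
# ['Violin', 'Violin', 'Viola', 'Cello']
```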
