### get data

In [1]:
import importlib
import numpy as np
from data_utils.pytorch_datasets.base_class import *
# import data_utils
# importlib.reload(data_utils)
from data_utils.utils.read_file import read_data_no_acc
from data_utils.read_pop909_data import analyze_pop909_dataset_without_acc
from data_utils.pytorch_datasets import create_form_datasets, create_counterpoint_datasets, create_leadsheet_datasets, \
    create_accompaniment_datasets
from data_utils.utils.song_analyzer import LanguageExtractor
from data_utils.pytorch_datasets.const import LANGUAGE_DATASET_PARAMS, AUTOREG_PARAMS, SHIFT_HIGH_T, SHIFT_LOW_T, SHIFT_HIGH_V, SHIFT_LOW_V
from data_utils.pytorch_datasets.form_dataset import FormDataset
from data_utils.pytorch_datasets.counterpoint_dataset import CounterpointDataset
from data_utils.pytorch_datasets.leadsheet_dataset import LeadSheetDataset


class ExtractFormCounterpointLeadsheet:
    def __init__(self, analyses):
        self.n_channels_frm = 8
        self.n_channels_ctp = 10
        self.n_channels_lsh = 12
        self.h_frm = 16
        self.h_ctp = 128
        self.h_lsh = 128

        self.frm_max_l = [analysis['languages']['form']['key_roll'].shape[1] for analysis in analyses]
        self.ctp_max_l = [analysis['languages']['counterpoint']['red_mel_roll'].shape[1] for analysis in analyses]
        self.lsh_max_l = [analysis['languages']['lead_sheet']['mel_roll'].shape[1] for analysis in analyses]

        form_langs = [analysis['languages']['form'] for analysis in analyses]
        ctpt_langs = [analysis['languages']['counterpoint'] for analysis in analyses]
        ldsht_langs = [analysis['languages']['lead_sheet'] for analysis in analyses]

        self.min_mel_pitches = [analysis['min_mel_pitch'] for analysis in analyses]
        self.max_mel_pitches = [analysis['max_mel_pitch'] for analysis in analyses]

        self.nbpms = [analysis['nbpm'] for analysis in analyses]
        self.nspbs = [analysis['nspb'] for analysis in analyses]
        self.song_names = [analysis['name'] for analysis in analyses]

        self.key_rolls = [form_lang['key_roll'] for form_lang in form_langs]
        self.key_rolls_ctp = [expand_roll(roll, nbpm) for roll, nbpm in zip(self.key_rolls, self.nbpms)]
        self.key_rolls_lsh = [expand_roll(roll, nbpm * nspb)
                          for roll, nbpm, nspb in zip(self.key_rolls, self.nbpms, self.nspbs)]

        self.phrase_rolls = [form_lang['phrase_roll'][:, :, np.newaxis] for form_lang in form_langs]
        self.phrase_rolls_ctp = [expand_roll(roll, nbpm) for roll, nbpm in zip(self.phrase_rolls, self.nbpms)]
        self.phrase_rolls_lsh = [expand_roll(roll, nbpm * nspb)
                             for roll, nbpm, nspb in zip(self.phrase_rolls, self.nbpms, self.nspbs)]

        self.red_mel_rolls = [ctpt_lang['red_mel_roll'] for ctpt_lang in ctpt_langs]
        self.red_mel_rolls_lsh = [expand_roll(roll, nspb, contain_onset=True)
                              for roll, nspb in zip(self.red_mel_rolls, self.nspbs)]
        
        self.red_chd_rolls = [ctpt_lang['red_chd_roll'] for ctpt_lang in ctpt_langs]
        self.red_chd_rolls_lsh = [expand_roll(roll, nspb, contain_onset=True)
                              for roll, nspb in zip(self.red_chd_rolls, self.nspbs)]

        self.mel_rolls = [ldsht_lang['mel_roll'] for ldsht_lang in ldsht_langs]
        self.chd_rolls = [ldsht_lang['chd_roll'] for ldsht_lang in ldsht_langs]
        self.chd_rolls_lsh = [expand_roll(roll, nspb, contain_onset=True)
                          for roll, nspb in zip(self.chd_rolls, self.nspbs)]


    def get_data_sample_frm(self, song_id=0, start_id=0, end_id=None):

        # store_key
        if self.key_rolls is not None:
            key_roll = self.key_rolls[song_id]
            self._key = np.roll(key_roll, shift=0, axis=-1)
        
        # store_phrase
        if self.phrase_rolls is not None:
            self._phrase = self.phrase_rolls[song_id]

        if end_id is None:
            end_id = self.frm_max_l[song_id]

        img = self.lang_to_img_frm(start_id, end_id=end_id, tgt_lgth=end_id-start_id)

        return img

    def get_data_sample_ctp(self, song_id=0, start_id=0, end_id=None):
        nbpm = self.nbpms[song_id]

        # store_key
        if self.key_rolls_ctp is not None:
            key_roll = self.key_rolls_ctp[song_id]
            self._key = np.roll(key_roll, shift=0, axis=-1)
        
        # store_phrase
        if self.phrase_rolls_ctp is not None:
            self._phrase = self.phrase_rolls_ctp[song_id]

        # store_red_mel
        if self.red_mel_rolls is not None:
            red_mel_roll = self.red_mel_rolls[song_id]
            self._red_mel = np.roll(red_mel_roll, shift=0, axis=-1)

        # store_red_chd
        if self.red_chd_rolls is not None:
            red_chd_roll = self.red_chd_rolls[song_id]
            self._red_chd = np.roll(red_chd_roll, shift=0, axis=-1)

        if end_id is None:
            end_id = self.ctp_max_l[song_id]
        img = self.lang_to_img_ctp(start_id, end_id, tgt_lgth=end_id-start_id)
        return img
        
    def get_data_sample_lsh(self, song_id=0, start_id=0, end_id=None):
        nbpm, nspb = self.nbpms[song_id], self.nspbs[song_id]

        # self.store_key(song_id, pitch_shift)
        if self.key_rolls_lsh is not None:
            key_roll = self.key_rolls_lsh[song_id]
            self._key = np.roll(key_roll, shift=0, axis=-1)

        # self.store_phrase(song_id)
        if self.phrase_rolls_lsh is not None:
            self._phrase = self.phrase_rolls_lsh[song_id]

        # self.store_red_mel(song_id, pitch_shift)
        if self.red_mel_rolls_lsh is not None:
            red_mel_roll = self.red_mel_rolls_lsh[song_id]
            self._red_mel = np.roll(red_mel_roll, shift=0, axis=-1)

        # self.store_red_chd(song_id, pitch_shift)
        if self.red_chd_rolls_lsh is not None:
            red_chd_roll = self.red_chd_rolls_lsh[song_id]
            self._red_chd = np.roll(red_chd_roll, shift=0, axis=-1)

        # self.store_mel(song_id, pitch_shift)
        if self.mel_rolls is not None:
            mel_roll = self.mel_rolls[song_id]
            self._mel = np.roll(mel_roll, shift=0, axis=-1)

        # self.store_chd(song_id, pitch_shift)
        if self.chd_rolls_lsh is not None:
            chd_roll = self.chd_rolls_lsh[song_id]
            self._chd = np.roll(chd_roll, shift=0, axis=-1)

        if end_id is None:
            end_id = self.lsh_max_l[song_id]

        img = self.lang_to_img_lsh(start_id, end_id=end_id, tgt_lgth=end_id-start_id)

        return img

    def lang_to_img_frm(self, start_id, end_id, tgt_lgth=None):
        key_roll = self._key[:, start_id: end_id]  # (2, L, 12)
        phrase_roll = self._phrase[:, start_id: end_id]  # (6, L, 1)

        # actual_l = self._key.shape[1]

        # to output image
        if tgt_lgth is None:
            tgt_lgth = end_id - start_id
        img = np.zeros((self.n_channels_frm, tgt_lgth, self.h_frm), dtype=np.float32)
        img[0: 2, 0: tgt_lgth, 0: 12] = key_roll
        img[2: 8, 0: tgt_lgth] = phrase_roll

        return img
    
    def lang_to_img_ctp(self, start_id, end_id, tgt_lgth=None):
        print(tgt_lgth)
        key_roll = self._key[:, start_id: end_id]  # (2, L, 12)
        phrase_roll = self._phrase[:, start_id: end_id]  # (6, L, 1)
        red_mel_roll = self._red_mel[:, start_id: end_id]  # (2, L, 128)
        red_chd_roll = self._red_chd[:, start_id: end_id]  # (6, L, 12)

        # actual_l = key_roll.shape[1]

        # to output image
        if tgt_lgth is None:
            tgt_lgth = self._key.shape[1] - start_id
        img = np.zeros((self.n_channels_ctp, tgt_lgth, 132), dtype=np.float32)
        print(tgt_lgth, red_mel_roll.shape)
        img[0: 2, 0: tgt_lgth, 0: 128] = red_mel_roll
        img[0: 2, 0: tgt_lgth, 36: 48] = red_chd_roll[2: 4]
        img[0: 2, 0: tgt_lgth, 24: 36] = red_chd_roll[4: 6]

        img[4: 10, 0: tgt_lgth] = phrase_roll

        img = img.reshape((self.n_channels_ctp, tgt_lgth, 11, 12))
        img[2: 4, 0: tgt_lgth] = key_roll[:, :, np.newaxis]
        img = img.reshape((self.n_channels_ctp, tgt_lgth, 132))
        return img[:, :, 0: self.h_ctp]
    
    def lang_to_img_lsh(self, start_id, end_id, tgt_lgth=None):
        key_roll = self._key[:, start_id: end_id]  # (2, L, 12)
        phrase_roll = self._phrase[:, start_id: end_id]  # (6, L, 1)
        red_mel_roll = self._red_mel[:, start_id: end_id]  # (2, L, 128)
        red_chd_roll = self._red_chd[:, start_id: end_id]  # (6, L, 12)
        mel_roll = self._mel[:, start_id: end_id]
        chd_roll = self._chd[:, start_id: end_id]

        # actual_l = key_roll.shape[1]

        # to output image
        if tgt_lgth is None:
            tgt_lgth = end_id - start_id
        img = np.zeros((self.n_channels_lsh, tgt_lgth, 132), dtype=np.float32)
        img[0: 2, 0: tgt_lgth, 0: 128] = mel_roll
        img[0: 2, 0: tgt_lgth, 36: 48] = chd_roll[2: 4]
        img[0: 2, 0: tgt_lgth, 24: 36] = chd_roll[4: 6]

        img[2: 4, 0: tgt_lgth, 0: 128] = red_mel_roll
        img[2: 4, 0: tgt_lgth, 36: 48] = red_chd_roll[2: 4]
        img[2: 4, 0: tgt_lgth, 24: 36] = red_chd_roll[4: 6]

        img[6: 12, 0: tgt_lgth] = phrase_roll

        img = img.reshape((self.n_channels_lsh, tgt_lgth, 11, 12))
        img[4: 6, 0: tgt_lgth] = key_roll[:, :, np.newaxis]
        img = img.reshape((self.n_channels_lsh, tgt_lgth, 132))
        return img[:, :, 0: self.h_lsh]

In [2]:
BEAT_PER_MEASURE = 4
song_data = read_data_no_acc("preprocessing/external_data/pig_mel", num_beat_per_measure=BEAT_PER_MEASURE, num_step_per_beat=4,
              clean_chord_unit=None, song_name=None, label=1)
lang_extractor = LanguageExtractor(song_data)
hie_lang = lang_extractor.analyze_without_acc()

In [3]:
print(hie_lang['languages']['form']['key_roll'].shape)
print(hie_lang['languages']['counterpoint']['red_mel_roll'].shape)
print(hie_lang['languages']['lead_sheet']['mel_roll'].shape)

(2, 23, 12)
(2, 92, 128)
(2, 368, 128)


In [4]:
mega_data = ExtractFormCounterpointLeadsheet([hie_lang])
start_measure = 0
end_measure = 8
frm = mega_data.get_data_sample_frm(start_id=start_measure, end_id=end_measure)
ctp = mega_data.get_data_sample_ctp(start_id=start_measure, end_id=end_measure*BEAT_PER_MEASURE)
lsh = mega_data.get_data_sample_lsh(start_id=start_measure, end_id=end_measure*BEAT_PER_MEASURE*4)

32
32 (2, 32, 128)


In [5]:
frm = np.expand_dims(frm, axis=0)
ctp = np.expand_dims(ctp, axis=0)
lsh = np.expand_dims(lsh, axis=0)

# generate whole song

In [6]:
# whole_song_gen_notebook.ipynb

# Import necessary libraries
from experiments.whole_song_gen import WholeSongGeneration
import torch

# Default model folders and demo directory
DEFAULT_FRM_MODEL_FOLDER = 'results_default/frm---/v-default'
DEFAULT_CTP_MODEL_FOLDER = 'results_default/ctp-a-b-/v-default'
DEFAULT_LSH_MODEL_FOLDER = 'results_default/lsh-a-b-/v-default'
DEFAULT_ACC_MODEL_FOLDER = 'results_default/acc-a-b-/v-default'
DEFAULT_DEMO_DIR = 'demo'

# Set the argument values directly
args = {
    'demo_dir': DEFAULT_DEMO_DIR,
    'mpath0': DEFAULT_FRM_MODEL_FOLDER,
    'mid0': 'default',
    'mpath1': DEFAULT_CTP_MODEL_FOLDER,
    'mid1': 'default',
    'mpath2': DEFAULT_LSH_MODEL_FOLDER,
    'mid2': 'default',
    'mpath3': DEFAULT_ACC_MODEL_FOLDER,
    'mid3': 'default',
    'nsample': 1,
    'pstring': None,
    'nbpm': 4,
    'key': 0,
    'minor': False,
    'debug': False
}

# Check available device
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Initialize the whole song generation pipeline
whole_song_expr = WholeSongGeneration.init_pipeline(
    frm_model_folder=args['mpath0'],
    ctp_model_folder=args['mpath1'],
    lsh_model_folder=args['mpath2'],
    acc_model_folder=args['mpath3'],
    frm_model_id=args['mid0'],
    ctp_model_id=args['mid1'],
    lsh_model_id=args['mid2'],
    acc_model_id=args['mid3'],
    debug_mode=args['debug'],
    device=device
)


default default default default
Description of the experiment is: m0-v-default-default
m1-v-default-default
m2-v-default-default
m3-v-default-default


In [7]:
# whole_song_expr.frm_op.data_params = {'max_l': 8, 'h': 16, 'n_channel': 8, 'cur_channel': 8}

In [8]:
import numpy as np
from inference.utils import quantize_generated_form_batch, specify_form

n_sample=args['nsample']
nbpm=BEAT_PER_MEASURE
nspb=4  # assuming nspb is a constant value
phrase_string="B34"
key=args['key']
is_major=args['minor']
demo_dir=args['demo_dir']
bpm = 110


In [10]:
## form generation
print("Form generation...")
frm_canvas, slices, gen_max_l = whole_song_expr.frm_op.create_canvas(n_sample=1, prompt=None)
frm_1 = whole_song_expr.frm_op.generation(frm_canvas, slices, gen_max_l, quantize=False, n_sample=1)
frm_2, lengths, phrase_labels = quantize_generated_form_batch(frm_1)
print(f"Length of the song: {lengths[0]}, phrase_label:\n{phrase_labels[0]}")
frm = frm_2[:, :, 0: lengths[0]]
phrase_string = phrase_labels[0]

Form generation...


Length of the song: 8, phrase_label:
0: i1
1: i7



In [11]:
# ctp generation
print("Counterpoint generation...")
background_cond = whole_song_expr.ctp_op.expand_background(frm, nbpm)

ctp_canvas, slices, gen_max_l = \
    whole_song_expr.ctp_op.create_canvas(background_cond, n_sample, nbpm, None, whole_song_expr.random_n_autoreg)
print(f"Number of iterations: {len(slices)}")
ctp = whole_song_expr.ctp_op.generation(ctp_canvas, slices, gen_max_l)
ctp = np.stack(ctp, 0)


Counterpoint generation...
Number of iterations: 1


In [27]:
## Lead Sheet generation
print("Lead Sheet generation...")
background_cond = whole_song_expr.lsh_op.expand_background(ctp, nspb)
lsh_canvas, slices, gen_max_l = \
    whole_song_expr.lsh_op.create_canvas(background_cond, n_sample, nbpm, nspb, None, whole_song_expr.random_n_autoreg)
print(f"Number of iterations: {len(slices)}")
lsh = whole_song_expr.lsh_op.generation(lsh_canvas, slices, gen_max_l)
lsh = np.stack(lsh, 0)

Lead Sheet generation...
Number of iterations: 4


In [9]:
## Accompaniment generation
print("Accompaniment generation...")
acc_canvas, slices, gen_max_l = \
    whole_song_expr.acc_op.create_canvas(lsh, n_sample, nbpm, nspb, None, whole_song_expr.random_n_autoreg)
print(f"Number of iterations: {len(slices)}")
acc = whole_song_expr.acc_op.generation(acc_canvas, slices, gen_max_l)

Accompaniment generation...
Number of iterations: 1


In [10]:
midi_file = whole_song_expr.output(acc, phrase_string, key, is_major, demo_dir, bpm)