In [2]:
# whole_song_gen_notebook.ipynb

# Import necessary libraries
from experiments.whole_song_gen import WholeSongGeneration
import torch

# Default model folders and demo directory
DEFAULT_FRM_MODEL_FOLDER = 'results_default/frm---/v-default'
DEFAULT_CTP_MODEL_FOLDER = 'results_default/ctp-a-b-/v-default'
DEFAULT_LSH_MODEL_FOLDER = 'results_default/lsh-a-b-/v-default'
DEFAULT_ACC_MODEL_FOLDER = 'results_default/acc-a-b-/v-default'
DEFAULT_DEMO_DIR = 'demo'

# Set the argument values directly
args = {
    'demo_dir': DEFAULT_DEMO_DIR,
    'mpath0': DEFAULT_FRM_MODEL_FOLDER,
    'mid0': 'default',
    'mpath1': DEFAULT_CTP_MODEL_FOLDER,
    'mid1': 'default',
    'mpath2': DEFAULT_LSH_MODEL_FOLDER,
    'mid2': 'default',
    'mpath3': DEFAULT_ACC_MODEL_FOLDER,
    'mid3': 'default',
    'nsample': 1,
    'pstring': None,
    'nbpm': 4,
    'key': 0,
    'minor': False,
    'debug': False
}

# Check available device
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Initialize the whole song generation pipeline
whole_song_expr = WholeSongGeneration.init_pipeline(
    frm_model_folder=args['mpath0'],
    ctp_model_folder=args['mpath1'],
    lsh_model_folder=args['mpath2'],
    acc_model_folder=args['mpath3'],
    frm_model_id=args['mid0'],
    ctp_model_id=args['mid1'],
    lsh_model_id=args['mid2'],
    acc_model_id=args['mid3'],
    debug_mode=args['debug'],
    device=device
)


default default default default
Description of the experiment is: m0-v-default-default
m1-v-default-default
m2-v-default-default
m3-v-default-default


In [3]:
whole_song_expr.frm_op.data_params = {'max_l': 8, 'h': 16, 'n_channel': 8, 'cur_channel': 8}

##### here we break down the main() method of WholeSongGeneration class


# Generate whole song
whole_song_expr.main(
    n_sample=args['nsample'],
    nbpm=args['nbpm'],
    nspb=4,  # assuming nspb is a constant value
    phrase_string=args['pstring'],
    key=args['key'],
    is_major=args['minor'],
    demo_dir=args['demo_dir']
)


In [4]:
import numpy as np
from inference.utils import quantize_generated_form_batch, specify_form

n_sample=args['nsample']
nbpm=args['nbpm']
nspb=4  # assuming nspb is a constant value
phrase_string=args['pstring']
key=args['key']
is_major=args['minor']
demo_dir=args['demo_dir']
bpm = 90


In [5]:
## form generation
print("Form generation...")
frm_canvas, slices, gen_max_l = whole_song_expr.frm_op.create_canvas(n_sample=1, prompt=None)
frm_1 = whole_song_expr.frm_op.generation(frm_canvas, slices, gen_max_l, quantize=False, n_sample=1)
frm_2, lengths, phrase_labels = quantize_generated_form_batch(frm_1)
print(f"Length of the song: {lengths[0]}, phrase_label:\n{phrase_labels[0]}")
frm = frm_2[:, :, 0: lengths[0]]
phrase_string = phrase_labels[0]

Form generation...


Length of the song: 8, phrase_label:
0: i8



In [6]:
## ctp generation

# 加一个prompt
# 先改成C调
# frm[0][:2,:, :] = frm[0][:2,:, [i for i in range(3,12)] + [0, 1, 2, 12, 13, 14, 15]]
# def index_to_one_hot(index, length=128):
#     array = np.zeros(length)
#     if index>=0:
#         array[index] = 1
#     return array
# prompt = np.array([[[index_to_one_hot(60),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(62),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(64),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(60),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(60),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(62),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(64),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(60),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(64),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(65),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(67),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(64),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(65),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(67),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(-1)
#                      ],
#                     [index_to_one_hot(-1),
#                      index_to_one_hot(60),
#                      index_to_one_hot(60),
#                      index_to_one_hot(60),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(62),
#                      index_to_one_hot(62),
#                      index_to_one_hot(62),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(64),
#                      index_to_one_hot(64),
#                      index_to_one_hot(64),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(60),
#                      index_to_one_hot(60),
#                      index_to_one_hot(60),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(60),
#                      index_to_one_hot(60),
#                      index_to_one_hot(60),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(62),
#                      index_to_one_hot(62),
#                      index_to_one_hot(62),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(64),
#                      index_to_one_hot(64),
#                      index_to_one_hot(64),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(60),
#                      index_to_one_hot(60),
#                      index_to_one_hot(60),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(64),
#                      index_to_one_hot(64),
#                      index_to_one_hot(64),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(65),
#                      index_to_one_hot(65),
#                      index_to_one_hot(65),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(64),
#                      index_to_one_hot(64),
#                      index_to_one_hot(64),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(65),
#                      index_to_one_hot(65),
#                      index_to_one_hot(65),
#                      index_to_one_hot(-1),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67),
#                      index_to_one_hot(67)
#                      ]]])

print("Counterpoint generation...")
background_cond = whole_song_expr.ctp_op.expand_background(frm, nbpm)

ctp_canvas, slices, gen_max_l = \
    whole_song_expr.ctp_op.create_canvas(background_cond, n_sample, nbpm, None, whole_song_expr.random_n_autoreg)
print(f"Number of iterations: {len(slices)}")
ctp = whole_song_expr.ctp_op.generation(ctp_canvas, slices, gen_max_l)
ctp = np.stack(ctp, 0)


Counterpoint generation...
Number of iterations: 1


In [7]:
for time in ctp[0][0,:12,:]:
    for note in range(len(time)):
        if time[note]>0:
            print(note, end=" ")
print("\n")
for time in ctp[0][1,:12,:]:
    for note in range(len(time)):
        if time[note]>0:
            print(note, end=" ")

35 39 42 47 25 37 40 44 28 37 40 45 30 37 42 46 

35 39 42 47 35 39 42 47 35 39 42 47 25 37 40 44 25 37 40 44 25 37 40 44 28 37 40 45 30 37 42 46 

In [8]:
ctp.shape

(1, 10, 32, 128)

In [9]:
## Lead Sheet generation
print("Lead Sheet generation...")
background_cond = whole_song_expr.lsh_op.expand_background(ctp, nspb)
lsh_canvas, slices, gen_max_l = \
    whole_song_expr.lsh_op.create_canvas(background_cond, n_sample, nbpm, nspb, None, whole_song_expr.random_n_autoreg)
print(f"Number of iterations: {len(slices)}")
lsh = whole_song_expr.lsh_op.generation(lsh_canvas, slices, gen_max_l)
lsh = np.stack(lsh, 0)

Lead Sheet generation...
Number of iterations: 1


In [10]:
## Accompaniment generation
print("Accompaniment generation...")
acc_canvas, slices, gen_max_l = \
    whole_song_expr.acc_op.create_canvas(lsh, n_sample, nbpm, nspb, None, whole_song_expr.random_n_autoreg)
print(f"Number of iterations: {len(slices)}")
acc = whole_song_expr.acc_op.generation(acc_canvas, slices, gen_max_l)

Accompaniment generation...
Number of iterations: 1


In [11]:
whole_song_expr.output(acc, phrase_string, key, is_major, demo_dir, bpm)

In [18]:
69%12

9

In [14]:
from data_utils.midi_output import piano_roll_to_note_mat, note_mat_to_notes, create_pm_object

piano_roll_to_note_mat(acc[0][0:2], True)

[[34, 52, 1],
 [38, 52, 1],
 [40, 52, 1],
 [114, 52, 1],
 [115, 52, 1],
 [116, 52, 1],
 [118, 52, 1],
 [119, 52, 1],
 [120, 52, 1],
 [41, 54, 1],
 [42, 54, 1],
 [44, 54, 1],
 [46, 54, 1],
 [98, 54, 1],
 [122, 54, 1],
 [64, 56, 2],
 [96, 56, 2],
 [99, 56, 1],
 [124, 56, 2],
 [99, 57, 1],
 [100, 57, 1],
 [102, 57, 1],
 [105, 57, 1],
 [0, 58, 1],
 [0, 59, 2],
 [2, 59, 1],
 [3, 59, 1],
 [4, 59, 1],
 [6, 59, 1],
 [8, 59, 1],
 [9, 59, 1],
 [10, 59, 1],
 [13, 59, 1],
 [14, 59, 1],
 [50, 59, 1],
 [51, 59, 1],
 [52, 59, 1],
 [54, 59, 1],
 [60, 59, 3],
 [70, 59, 1],
 [72, 59, 1],
 [73, 59, 1],
 [76, 59, 3],
 [106, 59, 1],
 [108, 59, 1],
 [110, 59, 1],
 [121, 59, 1],
 [126, 59, 1],
 [18, 49, 1],
 [19, 49, 1],
 [20, 49, 1],
 [25, 49, 1],
 [28, 49, 3],
 [68, 49, 1],
 [80, 49, 1],
 [81, 49, 1],
 [82, 49, 1],
 [83, 49, 1],
 [84, 49, 1],
 [86, 49, 1],
 [92, 49, 1],
 [95, 49, 1],
 [126, 50, 1],
 [32, 52, 1],
 [33, 52, 1],
 [34, 52, 1],
 [35, 52, 1],
 [36, 52, 2],
 [38, 52, 1],
 [112, 52, 1],
 [114, 52,

In [20]:
acc[0][6:8, 0:5, :12]

array([[[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]],

       [[1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1.],
        [1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1.],
        [1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1.],
        [1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1.],
        [1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1.]]])

In [12]:
done

NameError: name 'done' is not defined

In [None]:
ctp[0].shape

(10, 40, 128)

In [None]:
ctp[0].repeat(4,axis=1)[:,32,:]

array([[0. , 0. , 0. , ..., 0. , 0. , 0. ],
       [0. , 0. , 0. , ..., 0. , 0. , 0. ],
       [0. , 0. , 0. , ..., 0. , 0. , 0. ],
       ...,
       [0.8, 0.8, 0.8, ..., 0.8, 0.8, 0.8],
       [0. , 0. , 0. , ..., 0. , 0. , 0. ],
       [0. , 0. , 0. , ..., 0. , 0. , 0. ]])

In [None]:
lsh[0][:,32,:]

array([[0. , 0. , 0. , ..., 0. , 0. , 0. ],
       [0. , 0. , 0. , ..., 0. , 0. , 0. ],
       [0. , 0. , 0. , ..., 0. , 0. , 0. ],
       ...,
       [0.8, 0.8, 0.8, ..., 0.8, 0.8, 0.8],
       [0. , 0. , 0. , ..., 0. , 0. , 0. ],
       [0. , 0. , 0. , ..., 0. , 0. , 0. ]])

In [None]:
acc[0].shape

(14, 160, 128)

In [None]:
lsh[0].shape

(12, 160, 128)

In [None]:
difference = ctp[0].repeat(4,axis=1) != lsh[0][2:]

# Find the indices where the arrays differ
different_indices = np.argwhere(difference)
different_indices

array([[  0,   1,  35],
       [  0,   1,  39],
       [  0,   1,  42],
       [  0,   1,  46],
       [  0,   1,  47],
       [  0,   2,  35],
       [  0,   2,  39],
       [  0,   2,  42],
       [  0,   2,  46],
       [  0,   2,  47],
       [  0,   3,  35],
       [  0,   3,  39],
       [  0,   3,  42],
       [  0,   3,  46],
       [  0,   3,  47],
       [  0,  17,  32],
       [  0,  17,  39],
       [  0,  17,  42],
       [  0,  17,  44],
       [  0,  17,  47],
       [  0,  18,  32],
       [  0,  18,  39],
       [  0,  18,  42],
       [  0,  18,  44],
       [  0,  18,  47],
       [  0,  19,  32],
       [  0,  19,  39],
       [  0,  19,  42],
       [  0,  19,  44],
       [  0,  19,  47],
       [  0,  33,  35],
       [  0,  33,  39],
       [  0,  33,  42],
       [  0,  33,  47],
       [  0,  34,  35],
       [  0,  34,  39],
       [  0,  34,  42],
       [  0,  34,  47],
       [  0,  35,  35],
       [  0,  35,  39],
       [  0,  35,  42],
       [  0,  35

In [None]:
(ctp[0].repeat(4,axis=1) == lsh[0][2:]).all()

False

In [None]:
lsh[0][2:].shape

(10, 160, 128)