In [1]:
import numpy as np
import torch
import pretty_midi as pm
import matplotlib.pyplot as plt
import os

import IPython.display     # IPython's display module (for in-line audio)

# Prefer the GPU when torch can see one; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

Using device: cpu


In [88]:
from model_tr_full_ultra import VAETR

# Load the VAETR checkpoint and switch to eval mode.
# Earlier checkpoint, kept for reference:
#   'runs_my_decoder/run_2024-04-24_16-33-11_256_0.0002_6/model_tr_encoder/pgrid-epoch-model.pt'
# NOTE(review): the following cell (model.VAE) rebinds `ptvae`. Under Restart & Run All
# this VAETR instance is therefore replaced; the execution counts (In [88] here,
# In [84] for that cell) show they were run out of file order. Confirm intended model.
ptvae = VAETR(device=device).to(device)
ptvae.load_state_dict(
    torch.load('runs/our_32_model.pt', map_location=device)
)
ptvae.eval();

# Legacy initialisation path:
# ptvae = VAE.init_model().to(device)
# ptvae.load_model("model_decoder.pt")

In [84]:
from model import VAE

# Load the model.VAE checkpoint and switch to eval mode.
# NOTE(review): this rebinds `ptvae`, replacing the VAETR model from the previous cell.
# It ran first interactively (In [84] < In [88]), but in file order it runs last, so
# every later cell would use this VAE under Restart & Run All. Confirm intended model.
ptvae = VAE(device=device).to(device)
ptvae.load_state_dict(torch.load('runs/model_decoder.pt',
                                 map_location=device))
ptvae.eval();

In [28]:
# Smoke test: run the forward pass on a dummy (2, 32, 16, 6) batch and inspect the posterior.
# `torch.rand(...).long()` truncates every value in [0, 1) to 0, so a "random" batch
# built that way is all zeros; construct the zero batch explicitly instead.
# no_grad: this is inference only, so skip building the autograd graph.
with torch.no_grad():
    dummy_batch = torch.zeros(2, 32, 16, 6, dtype=torch.long, device=device)
    pitch_outs, dur_outs, dist = ptvae(dummy_batch)

dist

Normal(loc: torch.Size([2, 512]), scale: torch.Size([2, 512]))

In [9]:
# Load a pre-encoded example. The `with` block closes the .npz file handle.
# Arrays fetched by key are ordinary in-memory ndarrays, so they stay valid after the block.
with np.load('sample2.npz') as x1:
    pr_mat1 = x1['pr_mat']
    pianotree1 = x1['ptree']
    c1 = x1['c']

# Add a batch axis and move to the model's device.
pt1_t = torch.from_numpy(pianotree1).unsqueeze(0).long().to(device)

print(pt1_t.shape)

torch.Size([1, 32, 16, 6])


In [83]:
from dataset import PolyphonicDataset

# Dataset built from the `_val` .npy file. The 128-step variant, kept for reference:
# dataset = PolyphonicDataset("new_data/pop909+mlpv_t128_val.npy", 0, 0)
dataset = PolyphonicDataset("pop909+mlpv_t32_fix1/pop909+mlpv_t32_val_fix1.npy", 0, 0)
print(len(dataset))

# Pick one item and give it a leading batch axis for the model.
pianotree1 = dataset[717]
pt1_t = torch.from_numpy(pianotree1).long().unsqueeze(0).to(device)
print(pt1_t.shape)

22843
torch.Size([1, 32, 16, 6])


In [90]:
# Reconstruction of pt1: time one forward pass, then greedy-decode the outputs.
import time

# no_grad: inference only, so don't build an autograd graph (it also inflates the timing).
with torch.no_grad():
    start = time.perf_counter()
    pitch_outs, dur_outs, dist = ptvae(pt1_t, inference=True)
    if pt1_t.is_cuda:
        torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the clock
    print((time.perf_counter() - start) * 1000)  # forward-pass latency in ms

# Greedy decoding: pick the highest-scoring pitch token and duration bit per slot.
est_pitch = pitch_outs.argmax(-1, keepdim=True)
est_dur = dur_outs.argmax(-1)
recon = torch.cat([est_pitch, est_dur], dim=-1)

# Prepend a start row (pitch_sos + dur_pad) along the note axis so `recon` has the same
# layout as the input tree (row 0 of the input is the start token).
sos_line = torch.zeros(recon.size(0), ptvae.num_step, 1, 6, dtype=recon.dtype, device=recon.device)
sos_line[..., 0] = ptvae.pitch_sos
sos_line[..., 1:] = ptvae.dur_pad
recon = torch.cat([sos_line, recon], dim=2)

# Drop the batch axis (batch size is 1 here) and move to numpy for rendering.
recon = recon.squeeze(0).cpu().numpy()
recon.shape

52.82783508300781


(32, 16, 6)

In [91]:
# Compare the input tree and its reconstruction at a single time step.
i = 0  # time-step index to inspect

print('input:')
print(pianotree1[i])
# Label the second array so the two printed matrices can be told apart.
print('recon:')
print(recon[i])

input:
[[128   2   2   2   2   2]
 [ 46   0   0   1   1   1]
 [ 70   0   0   1   1   1]
 [129   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]
 [130   2   2   2   2   2]]
[[128   2   2   2   2   2]
 [ 46   0   0   1   1   1]
 [ 70   0   0   1   1   1]
 [129   0   0   1   1   1]
 [129   0   0   1   1   1]
 [129   0   0   1   1   1]
 [129   0   0   1   1   1]
 [129   0   0   1   1   1]
 [129   0   0   1   1   1]
 [129   0   0   1   1   1]
 [ 39   0   0   1   1   1]
 [ 68   0   0   1   1   1]
 [ 68   0   0   1   0   1]
 [ 68   0   0   1   1   1]
 [ 39   0   0   1   0   1]
 [ 39   0   0   1   0   1]]


In [92]:
def pianotree_to_notes(pno_tree, bpm=60., start=0., velocity=100):
    """Decode a piano-tree array into a list of ``pretty_midi.Note`` objects.

    Args:
        pno_tree: integer array indexed as ``[time_step, slot, field]``. Field 0 is a
            pitch token; fields 1: are binary duration digits (MSB first). Slot 0 of
            every step is skipped (it holds the start row).
        bpm: tempo used to convert steps to seconds; one step lasts 0.25 * 60 / bpm s.
        start: time offset in seconds added to every note.
        velocity: MIDI velocity given to every note (default 100, as before).

    Returns:
        List of ``pm.Note``. Each note starts at ``start + t * step`` and lasts
        ``int(duration_bits, 2) + 1`` steps.

    Decoding of a time step stops at the first slot whose pitch token is >= 128.
    Such tokens are not valid MIDI pitches; in this data 129/130 appear after the notes.
    """
    alpha = 0.25 * 60 / bpm  # seconds per step (a quarter of a beat)
    notes = []
    # Iterate over the tree's actual length instead of assuming 32 steps.
    for t in range(pno_tree.shape[0]):
        for n in range(1, pno_tree.shape[1]):
            note = pno_tree[t, n]
            pitch = int(note[0])
            if pitch >= 128:
                # Not a MIDI pitch (end/pad/start token): nothing more in this step.
                break
            dur = int(''.join(str(int(b)) for b in note[1:]), 2) + 1
            notes.append(
                pm.Note(velocity, pitch,
                        start + t * alpha, start + (t + dur) * alpha))
    return notes


bpm = 90

# Render the reconstruction from the reconstruction cell.
# To hear the input instead, use `pianotree_to_notes(pianotree1, bpm=bpm)`.
notes1 = pianotree_to_notes(recon, bpm=bpm)

midi = pm.PrettyMIDI(initial_tempo=bpm)
# Label the track by what it actually contains (decoded output, not the input).
ins1 = pm.Instrument(0, name='reconstruction')
ins1.notes = notes1
midi.instruments = [ins1]

wave = midi.fluidsynth(fs=44100.0, sf2_path='./FluidR3_GM.sf2')
# wave = midi.synthesize(fs=44100.0)
IPython.display.display(IPython.display.Audio(data=wave, rate=44100))

import soundfile as sf
# Clip, then scale by the largest representable int32 (2**31 - 1).
# A sample at exactly +1.0 times 2**31 overflows int32 and wraps to the most negative value.
wave_int32 = (np.clip(wave, -1.0, 1.0) * (2**31 - 1)).astype(np.int32)

os.makedirs('samples', exist_ok=True)  # sf.write does not create missing directories
sf.write('samples/sample_32_4_ours.mp3', wave_int32, 44100)

In [80]:
def interp_path(z1, z2, interpolation_count=10):
    """Interpolate between two latent arrays along a great-circle path.

    The direction is interpolated with spherical linear interpolation (slerp) between
    the unit-normalised inputs. The norm is interpolated linearly in log-space. The
    first and last points therefore equal ``z1`` and ``z2`` up to floating-point error.

    Args:
        z1: start latent (numpy array, any shape; must not be all zeros).
        z2: end latent with the same number of elements as ``z1`` (not all zeros).
        interpolation_count: number of points on the path, endpoints included.

    Returns:
        numpy array of shape ``(interpolation_count, *z1.shape)``.
    """
    result_shape = z1.shape
    z1 = z1.reshape(-1)
    z2 = z2.reshape(-1)

    def slerp2(p0, p1, t):
        """Slerp from p0 to p1 for every fraction in the 1-D array ``t``."""
        # Rounding can push the cosine of (nearly) parallel unit vectors just outside
        # [-1, 1], where arccos returns NaN; clip it back into range.
        cos_omega = np.clip(
            np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1)), -1.0, 1.0)
        omega = np.arccos(cos_omega)
        so = np.sin(omega)
        if so < 1e-8:
            # Parallel or anti-parallel directions: the slerp weights become 0/0.
            # Fall back to linear interpolation. This is exact for parallel inputs; for
            # anti-parallel ones it passes through the origin (great circle undefined).
            return (1.0 - t)[:, None] * p0[None] + t[:, None] * p1[None]
        return (np.sin((1.0 - t) * omega)[:, None] / so * p0[None]
                + np.sin(t * omega)[:, None] / so * p1[None])

    percentages = np.linspace(0.0, 1.0, interpolation_count)

    normalized_z1 = z1 / np.linalg.norm(z1)
    normalized_z2 = z2 / np.linalg.norm(z2)

    dirs = slerp2(normalized_z1, normalized_z2, percentages)
    # Geometric interpolation of the norms: linear in log-space.
    length = np.linspace(np.log(np.linalg.norm(z1)),
                         np.log(np.linalg.norm(z2)), interpolation_count)
    return (dirs * np.exp(length[:, None])).reshape(
        [interpolation_count] + list(result_shape))


# Alternative endpoints from the bundled .npz samples:
# with np.load('sample1.npz') as x1:
#     pianotree_start = x1['ptree']
# with np.load('sample2.npz') as x2:
#     pianotree_end = x2['ptree']

# Endpoints of the interpolation: two dataset items. They use their own names so this
# cell does not overwrite `pianotree1` / `pt1_t`, which the reconstruction cells use.
pianotree_start = dataset[777]
pt_start = torch.from_numpy(pianotree_start).unsqueeze(0).long().to(device)

pianotree_end = dataset[717]
pt_end = torch.from_numpy(pianotree_end).unsqueeze(0).long().to(device)

# NOTE(review): the reconstruction cell calls ptvae(..., inference=True) but this one uses
# the default; confirm the encoder posterior does not depend on that flag.
with torch.no_grad():
    _, _, dist_1 = ptvae(pt_start)
    _, _, dist_2 = ptvae(pt_end)

# Use the posterior means (not samples) as deterministic endpoints.
z1 = dist_1.mean
z2 = dist_2.mean

z_path = interp_path(z1.cpu().numpy().squeeze(0), z2.cpu().numpy().squeeze(0), 10)
z_path = torch.from_numpy(z_path).to(z1.device).float()
z_path.size()

torch.Size([10, 512])

In [81]:
# Decode every latent on the interpolation path in a single batch.
# NOTE(review): this rebinds `recon` from the single reconstruction to a batch with one
# tree per point of `z_path`; the rendering cell below relies on the new meaning.
# If the model's decoder instead returns (pitch_outs, dur_outs), greedy-decode with
# argmax and prepend the start row, as the reconstruction cell does.
with torch.no_grad():
    recon = ptvae.decoder(z_path).cpu().numpy()

recon.shape

(10, 32, 16, 6)

In [82]:
# Render the decoded interpolation path as one MIDI track, with segments back to back.
# Each segment spans recon.shape[1] steps of a quarter beat (8 beats for 32 steps),
# i.e. this many seconds at `bpm`:
segment_sec = recon.shape[1] * 0.25 * 60 / bpm
recon_path_notes = [pianotree_to_notes(pt, bpm, start=i * segment_sec) for i, pt in enumerate(recon)]
recon_path_notes = [note for notes in recon_path_notes for note in notes]  # flatten

midi = pm.PrettyMIDI(initial_tempo=bpm)
ins = pm.Instrument(0, name='interpolation')
ins.notes = recon_path_notes
midi.instruments = [ins]

wave = midi.fluidsynth(fs=44100.0, sf2_path='./FluidR3_GM.sf2')
# wave = midi.synthesize(fs=44100.0)
IPython.display.display(IPython.display.Audio(data=wave, rate=44100))

import soundfile as sf
# Clip, then scale by the largest representable int32 (2**31 - 1).
# A sample at exactly +1.0 times 2**31 overflows int32 and wraps to the most negative value.
wave_int32 = (np.clip(wave, -1.0, 1.0) * (2**31 - 1)).astype(np.int32)

os.makedirs('samples', exist_ok=True)  # sf.write does not create missing directories
sf.write('samples/inter_32_2_gru.mp3', wave_int32, 44100)