In [None]:
from IPython.display import display, HTML
# Inject a CSS override so the notebook's `.container` element spans the full page width.
full_width_css = "<style>.container { width:100% !important; }</style>"
display(HTML(full_width_css))

In [None]:
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from rhythmic_complements.data import PairDataset

# The two parts whose rhythms are paired; further down they are labelled
# "input" (part_1) and "prediction" (part_2) in the rendered output.
part_1 = 'Bass'
part_2 = 'Piano'

# Keyword arguments forwarded verbatim to PairDataset.
dataset_config = dict(
    dataset_name="babyslakh_20_1bar_24res",
    part_1=part_1,
    part_2=part_2,
    repr_1="hits",
    repr_2="hits",
)

data = PairDataset(**dataset_config)
# batch_size=1: every iteration of the loader yields a single (x, y) pair.
loader = DataLoader(dataset=data, batch_size=1)

# Side length of the co-occurrence table; the values found in x / y are used
# directly as row / column indices, so they must lie in range(n).
n = 2
N = torch.zeros((n, n), dtype=torch.int32)

# N[a, b] counts how often value a in x coincides with value b in y at the
# same (flattened) position.
for x, y in tqdm(loader):
    x = x.to(torch.int32)
    y = y.to(torch.int32)
    for ix1, ix2 in zip(x.flatten().tolist(), y.flatten().tolist()):
        N[ix1, ix2] += 1

# Add-one smoothing, then normalise every row so that P[a] sums to 1.
smoothed = (N + 1).float()
P = smoothed / smoothed.sum(dim=1, keepdim=True)

# Draw P as a heat map and annotate every cell with its "<row><col>" index
# pair and the probability rounded to three decimals.
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(P, cmap='Blues')
for cell in range(n * n):
    i, j = divmod(cell, n)  # row-major walk over the n x n grid
    cell_label = f"{i}{j}"
    ax.text(j, i, cell_label, ha="center", va="bottom", color='gray')
    ax.text(j, i, round(P[i, j].item(), 3), ha="center", va="top", color='gray')
ax.axis('off');

In [None]:
import wave
from base64 import b64encode
from io import BytesIO
from IPython.display import Audio, HTML, display

import numpy as np
from chord_progressions.io.audio import mk_arpeggiated_chord_buffer, combine_buffers

SAMPLE_RATE = 44100  # frames per second written into the WAV header by default


def mk_wav(arr, sample_rate=SAMPLE_RATE):
    """Encode a float signal as an in-memory, mono, 16-bit PCM WAV file.

    Adapted from https://github.com/ipython/ipython/blob/main/IPython/lib/display.py#L146

    Args:
        arr: array-like of samples, nominally in [-1.0, 1.0]. Values outside
            that range are clipped; without clipping the int16 cast below
            would overflow and wrap around, producing loud artefacts.
        sample_rate: frame rate stored in the WAV header. Defaults to
            SAMPLE_RATE.

    Returns:
        bytes: the complete WAV file (header followed by little-endian
        int16 frames).
    """
    samples = np.clip(np.asarray(arr), -1.0, 1.0)
    frames = (samples * 32767).astype("<h").tobytes()

    fp = BytesIO()
    # The context manager finalises the header and releases the writer even
    # if one of the calls below raises. It does not close `fp`, because the
    # wave module only closes files it opened itself.
    with wave.open(fp, mode='wb') as waveobj:
        waveobj.setnchannels(1)
        waveobj.setframerate(sample_rate)
        waveobj.setsampwidth(2)
        waveobj.setcomptype('NONE', 'NONE')
        waveobj.writeframes(frames)

    return fp.getvalue()

def get_audio_el(audio):
    """Return an HTML <audio> element that embeds `audio` as a base64 WAV data URI.

    `audio` is passed straight to mk_wav; the resulting bytes are
    base64-encoded and inlined in the element's <source> tag.
    """
    b64 = b64encode(mk_wav(audio)).decode('ascii')
    return f'<audio controls="controls"><source src="data:audio/wav;base64,{b64}" type="audio/wav"/></audio>'


# Dedicated, seeded generator so the sampled predictions are reproducible.
g = torch.Generator().manual_seed(15987348311)

n_samples = 5
# Length of the first item of the first element in the first batch.
first_batch = next(iter(loader))
pattern_len = len(first_batch[0][0])

audio_duration = 3  # seconds
n_overtones = 0

# Hand-written input rhythm; each entry is used as a row index into P when sampling.
input_hits = [1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0]
inhitstr = ''.join(str(hit) for hit in input_hits)
inbuff = mk_arpeggiated_chord_buffer(['C4'], audio_duration, [input_hits], n_overtones)

# Maps each sampled pattern (as a digit string) to its mixed audio buffer.
# Because the pattern string is the key, identical samples overwrite each
# other, so fewer than n_samples entries may remain.
buffs = {}
for sample_ix in range(n_samples):
    # Draw one output value per input step from the row of P selected by the
    # input value. (Swap P[el] for torch.ones(2) / 2. to sample uniformly.)
    out = [
        torch.multinomial(P[el], num_samples=1, replacement=True, generator=g).item()
        for el in input_hits
    ]

    hbuff = mk_arpeggiated_chord_buffer(['G4'], audio_duration, [out], n_overtones)
    pattern_key = ''.join(map(str, out))
    buffs[pattern_key] = combine_buffers([inbuff, hbuff])
    
# One audio player plus a caption per sampled pattern. The caption shows the
# input hit string and the sampled hit string on separate lines.
# Markup notes: <p> may not be nested inside <p>, and a line break is written
# <br/> (</br> is not a valid tag), so the player and the caption are emitted
# as sibling paragraphs.
html = ''.join(
    f"<p>{get_audio_el(v)}</p><p>{inhitstr} {part_1} input<br/>{k} {part_2} prediction</p>"
    for k, v in buffs.items()
)

HTML(html)