In [None]:
import torch
from dataset_tool import compute_loudness, compute_centroid
from IPython.display import Audio
import pickle
import librosa as li
from noisebandnet.model import NoiseBandNet
import torch.nn.functional as F

In [None]:
def load_audio(path, fs, max_len, norm=True):
    """Load `path` as a mono signal resampled to `fs`.

    Args:
        path: audio file to read (passed straight to ``librosa.load``).
        fs: target sampling rate handed to ``librosa.load`` as ``sr``.
        max_len: keep at most this many samples; values <= 0 disable truncation.
        norm: if True, rescale the signal with ``librosa.util.normalize``
            (librosa's default settings are used).

    Returns:
        The sample array from ``librosa.load``; the returned sampling rate is discarded.
    """
    signal, _ = li.load(path, sr=fs, mono=True)
    # Slicing past the end of an array is a no-op, so no explicit length check is needed.
    if max_len > 0:
        signal = signal[:max_len]
    return li.util.normalize(signal) if norm else signal

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

TRAIN_PATH = 'trained_models/metal'
MODEL_PATH = f'{TRAIN_PATH}/model_10000.ckpt'
CONFIG_PATH = f'{TRAIN_PATH}/config.pickle'

# Path to the training data used to train the model.
AUDIO_PATH = 'training_data/metal.wav'

# NOTE: pickle.load can execute arbitrary code -- only load config files you produced or trust.
with open(CONFIG_PATH, "rb") as f:
    config = pickle.load(f)
FS = config.sampling_rate

# Reference audio: mono, resampled to the model's rate, truncated to the first 2**19 samples.
x_audio = load_audio(path=AUDIO_PATH, fs=FS, max_len=2**19)
x_audio = torch.from_numpy(x_audio).unsqueeze(0)  # add a leading axis
Audio(x_audio[0].numpy(), rate=FS)

In [None]:
# This example works for models trained with loudness, centroid, or both.
# For user-defined control parameters you need to load them manually.

def frame_control_param(feature):
    """Resample a control feature to the synth's frame rate as a float tensor on `device`.

    Adds a leading axis, linearly interpolates the last axis by a factor of
    1/config.synth_window, then swaps the last two axes so time ends up on axis 1.
    Every control tensor is moved to `device` because the synth is moved there too.
    """
    feature = feature.unsqueeze(0).float()
    feature = F.interpolate(input=feature, scale_factor=1/config.synth_window, mode='linear')
    return feature.permute(0, 2, 1).to(device)

# auto_control_params may be stored either as a single name or as a sequence of names.
requested_params = config.auto_control_params
if isinstance(requested_params, str):
    requested_params = [requested_params]

# Order is fixed (loudness, then centroid) regardless of how the config lists them --
# presumably the order used at training time; TODO confirm against the training script.
control_params = []
if 'loudness' in requested_params:
    loudness, _, _ = compute_loudness(audio_data=x_audio, sampling_rate=FS)
    control_params.append(frame_control_param(loudness))
if 'centroid' in requested_params:
    centroid, _, _ = compute_centroid(audio_data=x_audio, sampling_rate=FS)
    control_params.append(frame_control_param(centroid))

if not control_params:
    raise ValueError(
        f'Unsupported auto_control_params: {config.auto_control_params!r}; '
        'load user-defined control parameters manually.'
    )

In [None]:
# Rebuild the network architecture from the saved training config (weights are not loaded here).
synth = NoiseBandNet(
    hidden_size=config.hidden_size,
    n_band=config.n_band,
    synth_window=config.synth_window,
    n_control_params=config.n_control_params,
)
synth = synth.to(device).float()

In [None]:
# map_location places the checkpoint tensors directly on `device`, so a checkpoint saved
# on a GPU can still be loaded on a CPU-only machine (and vice versa).
synth.load_state_dict(torch.load(MODEL_PATH, map_location=device))

## Stereo generation

In [None]:
# Render two independent randomised passes over the same controls; they become
# the two channels of the stereo output.
with torch.no_grad():
    y_audio = [
        synth.forward_random(
            control_params=control_params,
            frame_len=control_params[0].shape[1],
            frequency_shifts=0,
            k_amplitudes=10,
            k_low_mult=0.95,
            k_high_mult=1.15,
            init_f_shifts=0,
        )
        for _ in range(2)
    ]
# Stack the passes on axis 0, then swap the first two axes so y_audio[0] holds both passes.
y_audio = torch.cat(y_audio).permute(1, 0, 2)
Audio(y_audio[0].detach().cpu().numpy(), rate=FS)

## Amplitude randomisation

In [None]:
audio_chunks = 2
# Length of axis 1 of the control tensor, i.e. the number of control frames (not audio samples).
audio_len = control_params[0].shape[1]
# frame_len = control length split into `audio_chunks` pieces (integer division); how
# forward_random uses it is defined in NoiseBandNet -- presumably per-chunk randomisation.
frame_len = audio_len // audio_chunks

with torch.no_grad():
    y_audio = synth.forward_random(
        control_params=control_params,
        frame_len=frame_len,
        frequency_shifts=1,
        k_amplitudes=100,
        k_low_mult=0.,
        k_high_mult=1.,
        init_f_shifts=10,
    )
Audio(y_audio[0].detach().cpu().numpy(), rate=FS)