In [None]:
%matplotlib widget
from pathlib import Path
from typing import Union
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader

from GANsynth_pytorch import NSynth

import crepe
from crepe.torch_backend import CREPE, build_and_load_model

def expand_path(p: Union[str, Path]) -> Path:
    """Return ``p`` as an absolute ``Path`` with a leading ``~`` expanded.

    Relative paths are anchored to the current working directory. The result
    is not normalised: ``..`` components are kept as written.
    """
    user_expanded = Path(p).expanduser()
    return user_expanded.absolute()

In [None]:
# Device used by the later cells for the input batches: first GPU when CUDA
# is available, CPU otherwise.
# NOTE(review): the model is never explicitly moved to `device` in this cell;
# presumably `build_and_load_model` handles placement -- TODO confirm.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Locations of the NSynth validation split: metadata JSON and audio directory.
NSynth_valid_json_data_path = expand_path(
    '../../data/nsynth/valid/json_wav/examples.json')
NSynth_valid_audio_path = expand_path(
    '../../data/nsynth/valid/json_wav/audio/')

# Validation dataset restricted to the 60-84 pitch range, exposing only the
# 'pitch' field as (non label-encoded) categorical data next to each sample.
dataset_valid = NSynth(
    NSynth_valid_audio_path,
    NSynth_valid_json_data_path,
    categorical_field_list=['pitch'],
    valid_pitch_range=[60, 84],
    return_full_metadata=False,
    label_encode_categorical_data=False,
)
# Small shuffled batches: every `next(iter(valid_dataloader))` below yields a
# different random selection of 5 samples.
valid_dataloader = DataLoader(dataset_valid, shuffle=True, batch_size=5)

# Load the 'full' capacity CREPE model.
# Alternative kept for reference: model = CREPE(model_capacity='tiny')
model = build_and_load_model('full')
model.zero_grad()

In [None]:
# Draw one batch from the validation loader and move both tensors to `device`.
samples, pitches_midi = next(iter(valid_dataloader))
samples, pitches_midi = samples.to(device), pitches_midi.to(device)
print(pitches_midi)

# Walk the ground-truth pitches through the helper's conversion chain
# (MIDI -> hertz -> cents -> bins), printing every intermediate result.
data_helper = model.data_helper
pitches_hz = data_helper.midi_to_hz(pitches_midi)
print(pitches_hz)
pitches_cents = data_helper.hertz_to_cents(pitches_hz)
print(pitches_cents)
pitches_cents_bins = data_helper.cents_to_bins(pitches_cents)
print(pitches_cents_bins.shape)

# Start from a clean slate, then plot the bin representation of the first
# sample of the batch against the helper's cents matrix converted to hertz.
plt.close("all")
plt.figure()
cents_matrix_cpu = data_helper._to_local_average_cents_matrix.cpu()
frequencies_hz = data_helper.cents_to_hz(cents_matrix_cpu)
first_sample_bins = pitches_cents_bins.cpu().numpy()[0]
plt.plot(frequencies_hz, first_sample_bins)
plt.show()

In [None]:
# Cut the current batch of samples into frames with the model's data helper
# and show the tensor shape before and after framing (one shape per line).
frames = model.data_helper.get_frames(samples)
print(samples.shape, frames.shape, sep='\n')

In [None]:
# Compare the PyTorch CREPE predictions with the TensorFlow reference
# implementation on a random batch of NSynth validation samples.

# obtain a new, random batch of samples
samples, pitches_midi = next(iter(valid_dataloader))
# convert the pitches of each sample to hertz to obtain the ground truth target
target_hz = model.data_helper.midi_to_hz(pitches_midi).cpu().numpy()

# disable gradient computation for evaluation
with torch.no_grad():
    # batched prediction on the samples
    model.eval()
    # use the notebook-level `device` instead of a hard-coded 'cuda:0' so that
    # this cell also runs on machines without a GPU
    logits = model.forward_audio(samples.to(device))
    activation = torch.sigmoid(logits)

# clear previous figure
plt.close()
# declare new figure, one column per sample and 4 rows:
#   row 0: detected frequencies (PyTorch)   row 1: detected frequencies (TensorFlow)
#   row 2: confidence (PyTorch)             row 3: confidence (TensorFlow)
batch_size, num_frames = activation.shape[:2]
# squeeze=False keeps `axes` two-dimensional even for a batch of one sample,
# so the `axes[row, i]` indexing below never breaks
fig, axes = plt.subplots(4, batch_size, sharex='col', sharey='row',
                         figsize=(10, 6), squeeze=False)
fig.suptitle(f"CREPE ('{model.model_capacity}' capacity) detected frequencies -- random NSynth samples")

# convert activations to frequencies and confidence levels
time, frequencies, confidence = [
    t.cpu().numpy() for t in model.data_helper.interpret_activation(activation)]

# iterate over the batch of samples
for i, (sample_time, sample_frequencies, sample_confidence) in enumerate(zip(time, frequencies, confidence)):
    # TF implementation predictions (step size is expected in milliseconds)
    sample_time_tf, sample_frequencies_tf, sample_confidence_tf, sample_activation_tf = crepe.predict(
        samples[i].cpu().numpy(), sr=model.fs_hz, model_capacity=model.model_capacity,
        step_size=model.hop_length_s*1000, backend='tf'
    )
    # plot the detected frequencies; the TensorFlow curves are drawn against
    # the time axis returned by `crepe.predict` itself, so a differing number
    # of frames between the two implementations cannot cause a shape mismatch
    axes[0, i].plot(sample_time, sample_frequencies, color='cornflowerblue')
    axes[1, i].plot(sample_time_tf, sample_frequencies_tf, color='yellowgreen')
    # add the target as a horizontal line
    axes[0, i].plot(sample_time, [target_hz[i]] * num_frames,
                    color='red', linestyle='dotted')
    axes[1, i].plot(sample_time_tf, [target_hz[i]] * len(sample_time_tf),
                    color='red', linestyle='dotted')
    # plot the confidence levels on the third and fourth rows
    axes[2, i].plot(sample_time, sample_confidence, color='cornflowerblue')
    axes[3, i].plot(sample_time_tf, sample_confidence_tf, color='yellowgreen')

# column titles go on the first row only
for i, ax in enumerate(axes.flat[:batch_size]):
    ax.set_title(f'Sample {i}')

# `axes.flat` iterates row by row: the first `batch_size` axes of each slice
# below belong to the PyTorch row, the remaining ones to the TensorFlow row
for i, ax in enumerate(axes.flat[:2*batch_size]):
    ax.set(ylabel='Frequency (Hz)\n' + ('(PyTorch)' if i < batch_size else '(TensorFlow)'))

for i, ax in enumerate(axes.flat[2*batch_size:]):
    ax.set(ylabel='Confidence\n' + ('(PyTorch)' if i < batch_size else '(TensorFlow)'))

# Hide x labels and tick labels for top plots and y ticks for right plots.
for ax in axes.flat:
    ax.set(xlabel='Time (s)')
    ax.label_outer()

In [None]:
# Compute the binary cross-entropy between the model logits and the targets
# built from the ground-truth pitches, then overlay predicted and target
# activations of a single frame for every sample of the batch.
criterion = torch.nn.BCEWithLogitsLoss()
with torch.no_grad():
    # obtain a new, random batch of samples
    samples, pitches_midi = next(iter(valid_dataloader))

    # batched prediction on the samples
    model.eval()
    # use the notebook-level `device` instead of a hard-coded 'cuda:0' so that
    # this cell also runs on machines without a GPU
    logits = model.forward_audio(samples.to(device))
    activation = torch.sigmoid(logits)

    targets = model.data_helper.make_targets(samples, pitches_midi.to(device))
    print(activation.shape)
    # spread of the index of the strongest target entry along the last axis
    print(targets.argmax(-1).float().std())
    # the criterion applies the sigmoid itself, hence logits (not activations)
    loss = criterion(logits, targets)
    print(loss)

# Frame to inspect: index 200, clamped to the last available frame so that
# shorter inputs do not raise an IndexError.
frame_index = min(200, activation.shape[1] - 1)

plt.close()
plt.figure()
bin_frequencies_hz = model.data_helper.cents_to_hz(model.data_helper._to_local_average_cents_matrix).cpu()
for i, (predicted, target) in enumerate(zip(activation[:, frame_index], targets[:, frame_index])):
    plt.plot(bin_frequencies_hz, predicted.cpu(), label=f'predicted, sample {i}')
    plt.plot(bin_frequencies_hz, target.cpu(), label=f'target, sample {i}', linestyle='dotted')

plt.xlabel('Frequency (Hz)')
plt.ylabel('Activation')
plt.title(f'Predicted vs. target activations at frame {frame_index}')
plt.legend()
plt.show()