## 11 Pitch Consistency

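Explores a simple question: does the generated audio match the target pitch? For a random sample of training clips, the model autoregressively generates 4 s of audio conditioned on the clip's target pitch and timbre. CREPE then estimates the pitch of the output, and we report the mean absolute error in MIDI semitones.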

In [3]:
import numpy as np 
import torch 
import torch.nn.functional as F
import pandas as pd 
import sys 
import os 
import yaml 
import time
from IPython.display import Audio, display
import torchcrepe
from dotenv import load_dotenv

# Data locations (music_path, sf_path, sample_path) are read from a .env file via
# python-dotenv. `load_dotenv` expects the path of the .env *file*; given a directory
# it silently loads nothing, so resolve a directory to `<dir>/.env`.
dotenv_path = '/home/robbizorg/classes/RT_MusicGen'
if os.path.isdir(dotenv_path):
    dotenv_path = os.path.join(dotenv_path, '.env')
if not load_dotenv(dotenv_path=dotenv_path):
    # The variables may still be exported in the shell environment, so warn instead of failing.
    print(f'Warning: no variables loaded from {dotenv_path}')

music_path = os.getenv("music_path")
sf_path = os.getenv('sf_path')
sample_path = os.getenv('sample_path')

# Make the project's `src` package importable from this notebook's directory.
sys.path.append('../../')
from src.spectral_ops import ISTFT, STFT
from src.models import Vocos
from src.encoder import TimbreEncoder
from src.dataset import Midi_Seg, train_collate_fn

In [4]:
# Load the base experiment config. A malformed file must raise: printing the error
# and carrying on would leave `config` undefined (or stale from an earlier run).
yaml_name = 'midi_vocos_1st.yaml'
with open(os.path.join('../../yamls', yaml_name), "r") as stream:
    config = yaml.safe_load(stream)
if not isinstance(config, dict):
    raise ValueError(f'{yaml_name} did not parse to a mapping (got {type(config).__name__})')

device = 'cpu'  # Running all tests on the CPU
sample_rate = config['sample_rate']
buffer_size = config['buffer_size']
prev_ratio = config.get('prev_ratio', 2.0)
comment = config['comment']
ckpt_path = os.path.join('../../ckpt', comment)
ckpt_epoch = '30'  # checkpoint sub-directory to evaluate (presumably the training epoch)

vocos_config = config['vocos_config']
timbre_config = config['timbre_config']


def _load_weights(module, filename):
    """Load `<ckpt_path>/<ckpt_epoch>/<filename>` into `module` on `device`; return it in eval mode."""
    state = torch.load(os.path.join(ckpt_path, ckpt_epoch, filename), map_location=device)
    module.load_state_dict(state)
    return module.eval()


# Load models
model = _load_weights(Vocos(vocos_config).to(device), 'VocosSynth.pth')
tmbr_encoder = _load_weights(TimbreEncoder(timbre_config).to(device), 'VocosTimbre.pth')

# STFT front-end shared by the synthesiser and the timbre encoder
stft_transform = STFT(
    n_fft=vocos_config['head']['n_fft'],
    hop_length=vocos_config['head']['hop_length'],
    win_length=vocos_config['head']['n_fft']
).to(device)

# Load dataset
train_dataset = Midi_Seg(sf_path=sample_path,
                         sr=sample_rate,
                         buffer_size=buffer_size,
                         prev_ratio=prev_ratio)

In [None]:
# Evaluation settings
N_EVAL = 100        # number of clips to evaluate
GEN_SECONDS = 4     # seconds of audio to generate per clip
SEED = 0            # fixes which clips are drawn, so runs are comparable
# torch has no 'gpu' device string; run CREPE on CUDA when it is available.
CREPE_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# CREPE's pitch bins span roughly 31.7 Hz to 2006 Hz (torchcrepe.MAX_FMAX).
# NOTE(review): an fmin below the first bin (e.g. 25 Hz) maps to a negative bin
# index in torchcrepe's post-processing and masks almost every bin, so keep fmin
# inside the supported range. Targets below about MIDI 24 (C1, ~32.7 Hz) cannot
# be tracked by CREPE at all.
CREPE_FMIN = 32.0
CREPE_FMAX = 2006.0


def hz_to_midi(freq_hz):
    """Convert a tensor of frequencies in Hz to fractional MIDI note numbers (A4 = 440 Hz = 69)."""
    return 69.0 + 12.0 * torch.log2(freq_hz / 440.0)


def generate_clip(x_raw, pitch, n_samples):
    """Autoregressively generate at least `n_samples` samples for one dataset example.

    The context window starts as silence (the clip is left-padded by
    `buffer_size * prev_ratio` samples), and the timbre embedding comes from the
    first `sample_rate` samples of the real clip. Each step feeds the pitch map
    plus the STFT of the context to the model and keeps the first `buffer_size`
    output samples, which are then appended to the context.

    Args:
        x_raw: waveform with a leading batch dim, presumably (1, T) — TODO confirm.
        pitch: target pitch with a leading batch dim. A non-3-D value is broadcast
            to one value per STFT frame; a 3-D tensor is used as-is.
        n_samples: minimum number of samples to generate.

    Returns:
        (audio, step_times): generated audio concatenated along the last dim, and
        the wall-clock duration in seconds of every generation step.
    """
    pad_len = int(buffer_size * prev_ratio)
    x = F.pad(x_raw, pad=(pad_len, 0))
    prev_x = x[:, :pad_len].float()                          # padding only -> zeros
    timbre_x = x[:, pad_len:pad_len + sample_rate].float()   # start of the real clip

    chunks, step_times = [], []
    n_generated = 0
    with torch.no_grad():
        timbre_emb = tmbr_encoder(stft_transform(timbre_x))
        while n_generated < n_samples:
            start_time = time.perf_counter()
            prev_spec = stft_transform(prev_x)
            if pitch.dim() == 3:
                pitch_map = pitch
            else:
                pitch_map = pitch[:, None, None].repeat(1, 1, prev_spec.shape[-1])
            in_feats = torch.cat([pitch_map.float().to(device), prev_spec], dim=1)
            out = model(in_feats, timbre_emb=timbre_emb)
            x_hat = out[:, :buffer_size]
            step_times.append(time.perf_counter() - start_time)

            if x_hat.shape[-1] == 0:
                raise RuntimeError('model produced no samples; generation would never terminate')
            # Slide the context window forward by the newly generated samples.
            prev_x = torch.cat([prev_x[:, buffer_size:], x_hat], dim=1)
            chunks.append(x_hat)
            n_generated += x_hat.shape[-1]
    return torch.cat(chunks, dim=1), step_times


def estimate_midi(audio):
    """Median CREPE pitch of `audio` (a batch of one clip), in fractional MIDI note numbers."""
    f0_hz = torchcrepe.predict(audio.to(CREPE_DEVICE),
                               sample_rate,
                               hop_length=sample_rate // 200,  # 5 ms hop
                               fmin=CREPE_FMIN,
                               fmax=CREPE_FMAX,
                               model='tiny',                   # tiny model for speed
                               batch_size=1024,                # frames per batch, not clips
                               device=CREPE_DEVICE)
    # torchcrepe returns Hz, but the target is a MIDI note: convert before comparing.
    # The median is robust to outlier frames (octave errors, an unsettled start).
    return hz_to_midi(f0_hz).median().item()


rng = np.random.default_rng(SEED)
# Draw distinct clips (np.random.choice samples with replacement by default).
idxes = rng.choice(len(train_dataset), size=min(N_EVAL, len(train_dataset)), replace=False)
tgt_pitches = []
est_pitches = []
gen_times = []
for idx in idxes:
    idx = int(idx)
    print(f'Loading {train_dataset.files[idx]}')
    x_raw, pitch, _, _ = train_dataset[idx]

    audio, step_times = generate_clip(x_raw.unsqueeze(0), pitch.unsqueeze(0),
                                      sample_rate * GEN_SECONDS)
    gen_times.extend(step_times)

    # NOTE(review): assumes the dataset pitch is a MIDI note number (file names such
    # as 57_80.wav suggest so) — TODO confirm.
    tgt_pitches.append(float(pitch.float().mean()))
    est_pitches.append(estimate_midi(audio))

tgt_pitches = np.array(tgt_pitches)
est_pitches = np.array(est_pitches)

avg_err = np.abs(tgt_pitches - est_pitches).mean()
print(f'Average Absolute Midi Error: {avg_err}')

Loading /data/robbizorg/music/samples/DSoundfont_Ultimate/T9-Snare-Drum/57_80.wav
Loading /data/robbizorg/music/samples/Touhou/Timpani/39_80.wav


KeyboardInterrupt: 

In [None]:
# Mean wall-clock time of one generation step (one buffer), in milliseconds.
# Scale to ms *before* rounding: rounding the value in seconds first (0.012 s -> 0)
# collapses the result to a multiple of 1000 ms.
avg_gen_ms = np.mean(gen_times) * 1000
print(f'Average Generation Time: {avg_gen_ms:.2f}ms')
# For real-time use, a step presumably has to finish within the duration of the buffer it produces.
print(f'Buffer duration (real-time budget): {buffer_size / sample_rate * 1000:.2f}ms')