In [None]:
# Notebook setup: imports, plus sys.path / working-directory juggling so that the
# BigVGAN repo and the spectrogram-diffusion repo can both be imported.
# NOTE(review): `argparse` is not used anywhere in the visible notebook.
import argparse
import sys

# Make both sibling repos importable as top-level packages
# (`bigvgan` and `preprocessor` are imported from them below).
sys.path.append('/home/ubuntu/BigVGAN/')
sys.path.append('/home/ubuntu/music-spectrogram-diffusion-pytorch/')

import soundfile as sf
import yaml
from importlib import import_module
import note_seq
from preprocessor.event_codec import Codec
from preprocessor.preprocessor import preprocess
from tqdm import tqdm

import matplotlib.pyplot as plt

# NOTE(review): `sys` is already imported above; this re-import is redundant but harmless.
import sys
import os


# Import BigVGAN with its own repo as the working directory.
# NOTE(review): presumably something in the BigVGAN package resolves paths relative to
# the cwd at import time — confirm; otherwise this chdir is unnecessary.
os.chdir('/home/ubuntu/BigVGAN/')

import torch

from bigvgan import BigVGAN

# Return to the diffusion repo as the working directory (presumably so repo-relative
# paths used by its code resolve correctly).
os.chdir('/home/ubuntu/music-spectrogram-diffusion-pytorch/')
@torch.inference_mode()
def diff_main(model, tokens, segment_length, spec_frames, with_context, T=1000, verbose=True,
              device=None, vocoder=None):
    """Synthesize a waveform from token segments: diffusion model -> mel -> BigVGAN.

    Each token segment is turned into ``spec_frames`` spectrogram frames by ``model``.
    The per-segment predictions are concatenated along time, converted back to
    log-magnitude mels with ``model.mel.reverse`` and vocoded.

    Args:
        model: spectrogram predictor exposing ``scheduler.set_timesteps(T)``,
            ``mel.reverse(...)`` and the call signature used below.
        tokens: iterable of per-segment token tensors; each gets a leading batch dim.
        segment_length: length in samples of the all-zero waveform context used to
            condition the first segment when ``with_context`` is True.
        spec_frames: number of spectrogram frames to predict per segment.
        with_context: if True, every segment after the first is conditioned on the
            previous segment's prediction (``mel_context``) and the first segment on
            the zero waveform context; if False, ``mel_context=None`` is always passed.
        T: number of diffusion timesteps.
        verbose: show the progress bar, forward ``verbose`` to ``model`` and print the
            shape/range of the last segment's prediction.
        device: device for model inputs. Defaults to the notebook-global ``device`` if
            one exists, otherwise to the device of ``model``'s first parameter.
        vocoder: mel-to-waveform model. Defaults to the notebook-global ``vocoder``.

    Returns:
        ``(wav_int16, log_mels, output_specs)``: the int16 numpy waveform, the
        ``[n_mels, total_frames]`` numpy log-mel fed to the vocoder, and the list of
        raw per-segment predictions.

    Raises:
        ValueError: if no vocoder is available or ``tokens`` is empty.
    """
    # Resolve dependencies explicitly instead of silently reading notebook globals.
    if device is None:
        device = globals().get("device")
    if device is None:
        device = next(model.parameters()).device
    if vocoder is None:
        vocoder = globals().get("vocoder")
    if vocoder is None:
        # Fail before the (slow) diffusion loop rather than after it.
        raise ValueError("diff_main needs a vocoder: pass vocoder=... or define a global `vocoder`")

    # The first segment has no previous prediction to condition on, so it sees silence.
    zero_wav_context = torch.zeros(1, segment_length, device=device) if with_context else None
    mel_context = None
    output_specs = []

    model.scheduler.set_timesteps(T)

    # Generate one token-conditioned spectrogram block per segment.
    for token in tqdm(tokens, disable=not verbose):
        x = token.unsqueeze(0).to(device)
        if with_context and not output_specs:
            context = {"wav_context": zero_wav_context}
        else:
            context = {"mel_context": mel_context}
        pred = model(x, seq_length=spec_frames, rescale=False, T=T, verbose=verbose, **context)
        output_specs.append(pred)
        mel_context = pred if with_context else None

    if not output_specs:
        # Without this, `pred` below is undefined and torch.cat([]) would fail anyway.
        raise ValueError("diff_main: `tokens` is empty, nothing to synthesize")

    if verbose:
        # Diagnostics for the *last* segment only.
        print("Pred shape:", pred.shape)
        print("Pred range:", pred.min().item(), pred.max().item())

    # Stitch along time (each segment assumed [1, frames, n_mels]), then transpose to the
    # [1, n_mels, total_frames] layout that MelToDB.reverse and BigVGAN take.
    output_tensor = torch.cat(output_specs, dim=1).transpose(1, 2)

    # MelToDB.reverse maps the diffusion output from [-1, 1] back to log-magnitude mels;
    # BigVGAN is fed float32.
    log_mels_for_vocoder = model.mel.reverse(output_tensor).to(torch.float32)

    # Already running under inference_mode via the decorator.
    wav_gen = vocoder(log_mels_for_vocoder)  # [1, 1, T_audio]

    # Clamp to [-1, 1] so the int16 scaling below cannot overflow.
    wav_float = wav_gen.squeeze(0).squeeze(0).cpu().clamp(-1.0, 1.0)  # [T_audio]
    wav_int16 = (wav_float * 32767.0).numpy().astype('int16')

    return wav_int16, log_mels_for_vocoder.cpu().numpy()[0], output_specs

#============================================================================

%cd /home/ubuntu/music-spectrogram-diffusion-pytorch

config = '/home/ubuntu/music-spectrogram-diffusion-pytorch/cfg/diff_base_44k.yaml'
midi = '/home/ubuntu/music-spectrogram-diffusion-pytorch/test.mid'
ckpt = '/home/ubuntu/last.ckpt'

output = '/home/ubuntu/output.wav'

with open(config) as f:
    config = yaml.safe_load(f)

model_configs = config['model']

module_path, class_name = model_configs['class_path'].rsplit('.', 1)
module = import_module(module_path)
model = getattr(module, class_name).load_from_checkpoint(
    ckpt, **model_configs['init_args'])
model = model.cuda()
model.eval()

hop_length = model_configs['init_args']['hop_length']
n_mels = model_configs['init_args']['n_mels']
data_configs = config['data']
sr = data_configs['init_args']['sample_rate']
segment_length = data_configs['init_args']['segment_length']
spec_frames = segment_length // hop_length
resolution = 100
segment_length_in_time = segment_length / sr
codec = Codec(int(segment_length_in_time * resolution + 1))

with_context = data_configs['init_args']['with_context'] and model_configs['init_args']['with_context']

ns = note_seq.midi_file_to_note_sequence(midi)
ns = note_seq.apply_sustain_control_changes(ns)
tokens, _ = preprocess(ns, codec=codec)

repo_id = "nvidia/bigvgan_v2_44khz_128band_512x"
device = 'cuda'

vocoder = BigVGAN.from_pretrained(
        repo_id,
        use_cuda_kernel=False,
)

vocoder.h['fmax'] = 22050
vocoder.h["num_freq"]   = 1025

vocoder.cuda()
vocoder.eval()
vocoder.remove_weight_norm()

pred, db_mels, output_specs = diff_main(model, tokens, segment_length,
                 spec_frames, with_context)

sf.write(output, pred, sr)

plt.figure(figsize=(12, 4))
plt.imshow(db_mels,
           aspect='auto',
           origin='lower',
           interpolation='nearest',
           cmap='magma')
plt.colorbar(label='Amplitude')
plt.xlabel('Frame Index')
plt.ylabel('Mel Bin Index')
plt.title('Predicted Mel-Spectrogram')
plt.tight_layout()

plt.savefig('/home/ubuntu/mel_plot.png', dpi=150)
plt.close()  # free memory / avoid overplotting

In [None]:
# True only if both the data config and the model config enable segment-to-segment context.
with_context

In [None]:
# Token segments from preprocess(); diff_main runs one diffusion pass per entry.
tokens

In [None]:
# Rebuild the stitched diffusion output from the per-segment predictions returned by
# diff_main: concatenate along time, then swap to the [1, n_mels, total_frames] layout
# expected by MelToDB.reverse. model.mel.reverse is deliberately NOT applied here, so the
# raw [-1, 1]-scale output can be inspected in the following cells.
output_tensor = torch.cat(output_specs, dim=1).transpose(1, 2)

In [None]:
# Expected to be [1, n_mels, total_frames] after the transpose in the stitching cell.
output_tensor.shape

In [None]:
# Hard-coded stand-ins for output_tensor.min() / output_tensor.max(), presumably taken
# from an earlier run. Note -11.512925 ≈ ln(1e-5).
min_val = -11.512925
max_val = 2.5241776

# Linearly map [min_val, max_val] onto [0, 1], then onto [-1, 1].
span = max_val - min_val
normalized_0_to_1 = (output_tensor - min_val) / span
output_tensor_normalized = 2 * normalized_0_to_1 - 1

In [None]:
# Map the raw output_tensor from [-1, 1] onto [min_val, max_val] (the inverse of the
# previous cell's mapping, applied to output_tensor rather than to that cell's result).
# NOTE(review): despite its name, the result is *de*-normalized and it overwrites the
# previous cell's output_tensor_normalized; the plotting cell reads this value.
output_tensor_normalized = min_val + (output_tensor + 1.0) / 2.0 * (max_val - min_val)

In [None]:
# Plot the raw diffusion output after rescaling it to [min_val, max_val] by hand
# (not via model.mel.reverse), for comparison with the mel fed to the vocoder.
# NOTE(review): this writes to the same PNG path as the main synthesis cell and
# overwrites that figure; use a different path if both plots are needed.
fig, ax = plt.subplots(figsize=(12, 4))
image = ax.imshow(output_tensor_normalized[0].cpu().numpy(),
                  aspect='auto',
                  origin='lower',
                  interpolation='nearest',
                  cmap='magma')
fig.colorbar(image, ax=ax, label='Log magnitude')
ax.set(xlabel='Frame Index',
       ylabel='Mel Bin Index',
       title='Predicted Mel-Spectrogram (raw output rescaled to [min_val, max_val])')
fig.tight_layout()

fig.savefig('/home/ubuntu/mel_plot.png', dpi=150)
plt.close(fig)  # close this specific figure to free memory / avoid overplotting

In [None]:
# Row 64 of batch item 0 — one mel bin across all frames, given the [1, n_mels, frames] layout.
output_tensor[0, 64]

In [None]:
# Value range of the stitched, transposed diffusion output, before model.mel.reverse.
low, high = output_tensor.min().item(), output_tensor.max().item()
print(f"Final output_tensor range (before reverse): {low} to {high}")

In [None]:
# After the main cell, `pred` holds the int16 numpy waveform returned by diff_main,
# not the per-segment spectrogram prediction that the same-named local held inside it.
print("Waveform shape:", pred.shape)
print("Waveform range (int16):", pred.min().item(), pred.max().item())

In [None]:
# BigVGAN hyperparameter dict, including the fmax / num_freq overrides set in the main cell.
vocoder.h

In [None]:
# Probe how BigVGAN responds to mel inputs on very different scales, to see which
# input range it expects. Inputs are seeded so the probe is reproducible.
probe_gen = torch.Generator().manual_seed(0)
probe_mels = vocoder.h.get("num_mels", 128)  # match the vocoder's mel count
probe_frames = 100

test_mel_normalized = torch.randn(1, probe_mels, probe_frames, generator=probe_gen).clamp(-1, 1)  # clipped to [-1, 1]
test_mel_db = torch.randn(1, probe_mels, probe_frames, generator=probe_gen) * 20 - 40  # mean -40, std 20 (unbounded, dB-like)
test_mel_linear = torch.rand(1, probe_mels, probe_frames, generator=probe_gen) * 10  # uniform in [0, 10), linear magnitude

vocoder_device = next(vocoder.parameters()).device
probe_outputs = {}
with torch.inference_mode():
    for label, mel in (("normalized", test_mel_normalized),
                       ("dB-like", test_mel_db),
                       ("linear", test_mel_linear)):
        try:
            wav = vocoder(mel.to(vocoder_device))
        except Exception as exc:  # report, then keep probing the remaining inputs
            print(f"{label}: vocoder failed: {exc}")
            continue
        probe_outputs[label] = wav
        print(f"{label}: output range {wav.min().item():.4f} to {wav.max().item():.4f}")

In [None]:
import matplotlib.pyplot as plt

# Plot the model's log-SNR against diffusion time t on a 1000-point grid over [0, 1].
# NOTE(review): the title calls the schedule "cosine"; confirm against model.get_log_snr.
t = torch.linspace(0, 1, 1000)
log_snr = model.get_log_snr(t)

fig, ax = plt.subplots()
ax.plot(t.cpu(), log_snr.cpu())
ax.set_title("Cosine log-SNR schedule")
ax.set_xlabel("t")
ax.set_ylabel("log SNR")
plt.show()
