In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from source.network.ravepqmf import PQMF
import librosa
from source.network.metrics import spectral_distance, multiscale_stft
from matplotlib.colors import LogNorm

def load_audio(file_path, sr=44100):
    """Read an audio file through librosa and wrap it as a float torch tensor.

    Args:
        file_path: path of the audio file to load.
        sr: target sampling rate handed to ``librosa.load`` (the rate it
            returns is ignored).

    Returns:
        The loaded samples as a float tensor with two leading singleton
        dimensions added (for librosa's default mono load this is
        ``(1, 1, n_samples)``).
    """
    waveform, _returned_rate = librosa.load(file_path, sr=sr)
    as_tensor = torch.from_numpy(waveform).float()
    return as_tensor[None, None, ...]

def plot_waveform(signal, title):
    """Draw a waveform tensor against its sample index and show the figure.

    Singleton dimensions are squeezed away before plotting.
    """
    samples = signal.squeeze().numpy()
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(samples)
    ax.set_title(title)
    ax.set_xlabel('Sample')
    ax.set_ylabel('Amplitude')
    plt.show()

def plot_spectrogram(signal, sr, title):
    """Plot a dB-scaled STFT spectrogram of a waveform tensor using librosa.

    Args:
        signal: waveform tensor; singleton dimensions are squeezed before the STFT.
        sr: sampling rate in Hz, used to label the time and frequency axes.
        title: figure title.

    NOTE: a later cell redefines ``plot_spectrogram(stft, scale, title)`` with a
    different signature; once that cell has run, this version is shadowed.
    """
    # `librosa.display` is a submodule that a plain `import librosa` is not
    # guaranteed to load; import it explicitly so `specshow` resolves.
    import librosa.display

    stft_matrix = librosa.stft(signal.squeeze().numpy())
    # dB relative to the loudest bin, so the colour scale tops out at 0 dB.
    S_db = librosa.amplitude_to_db(np.abs(stft_matrix), ref=np.max)
    plt.figure(figsize=(12, 4))
    librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='hz')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.show()

def plot_frequency_domain(signal, sample_rate=44100, title=''):
    """Plot the one-sided magnitude spectrum of ``signal`` on a log-frequency axis.

    Args:
        signal: real-valued waveform tensor. Singleton dimensions are squeezed
            away before the FFT; a single channel is assumed (TODO confirm for
            multi-channel input).
        sample_rate: sampling rate in Hz, used to build the frequency axis.
        title: figure title.
    """
    samples = signal.squeeze().numpy()
    # For a real signal the negative-frequency half of a full FFT mirrors the
    # positive half and cannot be placed on a log axis, so use the one-sided rfft.
    spectrum = np.fft.rfft(samples)
    freqs = np.fft.rfftfreq(len(samples), d=1 / sample_rate)
    plt.figure(figsize=(12, 4))
    # Skip the DC bin (0 Hz): it has no position on a logarithmic x-axis.
    plt.plot(freqs[1:], np.abs(spectrum[1:]))
    plt.xscale('log')
    plt.title(title)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude')
    plt.show()

def plot_spectral_distance(x_stfts, y_stfts, scales):
    """Plot ``|y - x|`` for the first item of each STFT pair, one figure per scale.

    Args:
        x_stfts: sequence of reference STFTs (tensors or arrays), indexed in step
            with ``scales``.
        y_stfts: sequence of STFTs to compare against ``x_stfts``.
        scales: scale labels; one figure is drawn for each entry.

    The colour scale is fixed to a log range of 1e-9 .. 1, so values above 1
    saturate and non-positive values cannot be log-scaled.
    """
    for i, scale in enumerate(scales):
        # torch.abs + explicit conversion rather than np.abs(...).numpy(): the latter
        # only works while NumPy happens to hand a torch.Tensor back.
        abs_diff = torch.abs(torch.as_tensor(y_stfts[i][0]) - torch.as_tensor(x_stfts[i][0]))
        plt.figure(figsize=(12, 4))
        plt.imshow(abs_diff.detach().cpu().numpy(), aspect='auto', cmap='viridis', norm=LogNorm(1e-9, 1))
        plt.gca().invert_yaxis()
        plt.title(f'Spectral Distance (scale {scale})')
        plt.colorbar(label='Magnitude')
        plt.xlabel('Time Frame')
        plt.ylabel('Frequency Bin')
        plt.show()

def plot_spectrum(stft):
    """Show the magnitude of ``stft[0]``, transposed, as an image.

    Args:
        stft: torch tensor or NumPy array. Its first entry along axis 0 is
            plotted after ``abs``, so complex input is shown as magnitude.

    NOTE(review): plot_spectral_distance and the later plot_spectrogram draw
    ``stft[0]`` untransposed under the same "Frequency" y-label, but this
    function transposes first. One of the two layouts is likely wrong; TODO
    confirm which STFT axis is frequency.
    """
    # torch.abs + explicit conversion instead of np.abs(...).numpy(): the latter
    # breaks if NumPy returns an ndarray, and always breaks for ndarray input.
    magnitude = torch.abs(torch.as_tensor(stft[0])).detach().cpu().numpy()
    plt.figure(figsize=(12, 4))
    plt.imshow(magnitude.T, aspect='auto', cmap='viridis')
    plt.gca().invert_yaxis()
    plt.colorbar(label='Magnitude')
    plt.xlabel('Time Frame')
    plt.ylabel('Frequency Bin')
    plt.show()

def lin_distance(x, y):
    """Relative L2 error of ``y`` with respect to the reference ``x``.

    Computes ``||x - y|| / ||x||`` over all elements and returns a 0-d tensor.
    This is not symmetric: only ``x`` appears in the denominator, so an
    all-zero ``x`` gives ``inf`` or ``nan``.
    """
    residual = x - y
    return residual.norm() / x.norm()

def log_distance(x, y, eps=1e-7):
    """Mean absolute difference of the log-magnitudes of ``x`` and ``y``.

    Args:
        x, y: tensors of the same (or broadcastable) shape. They are presumably
            non-negative magnitudes such as STFT magnitudes; entries below
            ``-eps`` would produce NaN.
        eps: offset added before the log so that zero entries stay finite.
            Defaults to 1e-7, the value previously hard-coded.

    Returns:
        0-d tensor ``mean(|log(x + eps) - log(y + eps)|)``.
    """
    log_x = torch.log(x + eps)
    log_y = torch.log(y + eps)
    return torch.abs(log_x - log_y).mean()



## Spectral Distance

In [None]:
# Load the dry (original) and wet (processed) renders, squeezed for the 1-D
# spectrogram and manual STFT below.
signal = load_audio("80s Beat 90 bpm_dry.wav").squeeze()
wet = load_audio("80s Beat 90 bpm_wet.wav").squeeze()

nfft = 1024
window = np.hanning(nfft)
sampling_rate = 44100
n_windows = 5  # NOTE(review): not used by any visible cell
hop_length = nfft // 2

# Spectrogram of the dry signal. Matplotlib's `noverlap` is the number of samples
# shared by consecutive frames, i.e. nfft - hop_length. It equals hop_length only
# because the hop is nfft / 2 here.
plt.figure(figsize=(12, 8))
plt.specgram(signal, NFFT=nfft, Fs=sampling_rate, noverlap=nfft - hop_length, cmap='viridis')
plt.title('Spectrogram of 80s Beat')
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
# specgram's default mode/scale render a power spectral density in dB, not amplitude.
plt.colorbar(label='Power spectral density (dB/Hz)')
plt.show()

In [None]:
# Manual STFT of the dry signal, as a cross-check of the spectrogram above:
# Hann-windowed frames of `nfft` samples every `hop_length` samples, with the
# one-sided magnitude taken via rfft. This expects `signal` to be the squeezed
# 1-D waveform from the previous cell.
signal_np = np.asarray(signal)
# `+ 1` keeps the last frame that fits exactly (range's stop is exclusive).
frame_starts = range(0, len(signal_np) - nfft + 1, hop_length)
stft = np.abs(np.array([np.fft.rfft(signal_np[start:start + nfft] * window) for start in frame_starts]))
print(stft.shape)  # (n_frames, nfft // 2 + 1)

In [None]:
# Reload without squeezing, keeping the two leading singleton dims that
# load_audio adds. This is presumably the batched layout spectral_distance /
# multiscale_stft expect; TODO confirm in source.network.metrics.
signal = load_audio("80s Beat 90 bpm_dry.wav")
wet = load_audio("80s Beat 90 bpm_wet.wav")

scales = [128, 64, 32, 16, 8]
distance = spectral_distance(signal, wet, scales)
# `wet` is the processed render of the same beat, not a reconstruction.
print(f"Spectral distance between original and processed: {distance.item()}")

# The third argument (.75) is presumably the frame overlap ratio; TODO confirm against multiscale_stft.
x_stfts = multiscale_stft(signal, scales, .75)
y_stfts = multiscale_stft(wet, scales, .75)
print(f"Number of STFT scales: {len(x_stfts)}")

plot_spectral_distance(x_stfts, y_stfts, scales)



In [None]:
from matplotlib.colors import LogNorm

def plot_spectrogram(stft, scale, title):
    """Show ``stft[0]`` as a grayscale image on a fixed log colour scale.

    Args:
        stft: tensor/array whose first entry along axis 0 is plotted
            (origin at the bottom, no interpolation).
        scale: scale label appended to the title.
        title: base figure title.

    The colour range is fixed to 1e-9 .. 1 on a log scale, so non-positive
    entries cannot be placed on it.

    NOTE: this redefines (and shadows) the earlier
    ``plot_spectrogram(signal, sr, title)`` defined at the top of the notebook.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    log_scale = LogNorm(vmin=1e-9, vmax=1, clip=False)
    image = ax.imshow(
        stft[0],
        aspect='auto',
        origin='lower',
        cmap='gray',
        norm=log_scale,
        interpolation='none',
    )
    fig.colorbar(image, ax=ax)
    ax.set_title(f'{title} (Scale: {scale})')
    ax.set_xlabel('Time')
    ax.set_ylabel('Frequency')
    fig.tight_layout()
    plt.show()


In [None]:
import matplotlib.colors as colors
import matplotlib

In [None]:
def summarize_stft(stft_tensor):
    """Collect basic descriptive statistics of a tensor into a dict.

    Returns a dict with keys, in order: ``shape`` (the tensor's shape), then
    ``mean``, ``std`` (torch's default, unbiased), ``min`` and ``max`` as
    Python floats taken over all elements.
    """
    reducers = (
        ("mean", torch.mean),
        ("std", torch.std),
        ("min", torch.min),
        ("max", torch.max),
    )
    stats = {"shape": stft_tensor.shape}
    for key, reduce_fn in reducers:
        stats[key] = reduce_fn(stft_tensor).item()
    return stats

# Compare dry vs. wet STFTs over many resolutions. At each scale both STFTs are
# divided by the *dry* signal's peak, so they share one reference and the
# distance stays comparable.
scales = [2048, 1024, 512, 256, 128, 64, 32, 16, 8]
x_stfts = multiscale_stft(signal, scales, .75)
y_stfts = multiscale_stft(wet, scales, .75)

normalized_x_stfts = []
normalized_y_stfts = []
for x_stft, y_stft in zip(x_stfts, y_stfts):
    # NOTE: a silent dry signal has peak 0, which would give inf/nan here.
    reference_peak = torch.max(x_stft)
    normalized_x_stfts.append(x_stft / reference_peak)
    normalized_y_stfts.append(y_stft / reference_peak)

normalized_distance = [
    torch.abs(x_norm - y_norm)
    for x_norm, y_norm in zip(normalized_x_stfts, normalized_y_stfts)
]

# Same log range (1e-9 .. 1) the plotting helpers use.
norm = LogNorm(1e-9, 1)
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; use the colormap registry instead.
cmap = matplotlib.colormaps['viridis']

# Spot-check a few bins of the largest scale: the raw values, where LogNorm
# maps them, and the resulting RGBA colours. Frames -3 and -7 are ad-hoc probes.
probes = (normalized_distance[0], normalized_x_stfts[0], normalized_y_stfts[0])
for probe in probes:
    print(probe.shape)
for probe in probes:
    print(probe[0, 0:5, -3])
for probe in probes:
    print(norm(probe[0, 0:5, -3]))
for probe in probes:
    print(cmap(norm(probe[0, 0:5, -7])))

    

In [None]:


# Per-scale report: statistics and plots of the normalised dry and wet STFTs,
# their absolute difference, and the two scalar distance measures.
# plot_spectrogram already calls plt.show(), so no extra show() calls are needed.
print("STFT Summary for Original and Processed Signals:")
for scale, x_stft, y_stft in zip(scales, normalized_x_stfts, normalized_y_stfts):
    print(f"\nScale: {scale}")
    print("Original Signal STFT:")
    print(summarize_stft(x_stft))
    plot_spectrogram(x_stft, scale, 'Original Signal STFT')

    print("Processed Signal STFT:")
    print(summarize_stft(y_stft))
    plot_spectrogram(y_stft, scale, 'Processed Signal STFT')

    distance = torch.abs(x_stft - y_stft)
    # Spot-check the same bins as the previous cell (frame -3, first five rows).
    print(distance[0, 0:5, -3])
    print(x_stft[0, 0:5, -3])
    print(y_stft[0, 0:5, -3])
    plot_spectrogram(distance, scale, 'Spectral Distance')

    print(f"Linear distance: {lin_distance(x_stft, y_stft).item()}")
    print(f"Log distance: {log_distance(x_stft, y_stft).item()}")