In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import librosa
from source.network.metrics import spectral_distance, multiscale_stft
from matplotlib.colors import LogNorm
import matplotlib.colors as colors

def load_audio(file_path, sr=44100):
    """Decode an audio file into a float32 torch tensor with a leading axis of size 1.

    Args:
        file_path: Path to an audio file that ``librosa.load`` can decode.
        sr: Target sample rate in Hz; librosa resamples to it while loading.

    Returns:
        ``torch.Tensor`` with an extra leading dimension added via ``unsqueeze(0)``.
        With librosa's default mono downmix this is presumably ``(1, num_samples)``.
    """
    samples = librosa.load(file_path, sr=sr)[0]  # drop the returned sample rate
    mono = torch.from_numpy(samples).float()
    return mono.unsqueeze(0)

def plot_waveform(signal, title):
    """Draw the raw samples of ``signal`` against their sample index.

    Args:
        signal: CPU torch tensor of samples; singleton dimensions are squeezed away.
        title: Figure title.
    """
    samples = signal.squeeze().numpy()
    _, ax = plt.subplots(figsize=(12, 4))
    ax.plot(samples)
    ax.set_title(title)
    ax.set_xlabel('Sample')
    ax.set_ylabel('Amplitude')
    plt.show()

def plot_spectrogram(signal, sr, title):
    """Show a dB-scaled, linear-frequency STFT spectrogram of ``signal``.

    Args:
        signal: Audio tensor; singleton dimensions are squeezed so librosa
            receives the sample array.
        sr: Sample rate in Hz, used to label the time and frequency axes.
        title: Figure title.
    """
    # ``librosa.display`` is a submodule that older librosa releases do not load
    # on a bare ``import librosa``; import it explicitly before calling specshow.
    import librosa.display

    # detach/cpu so grad-tracking or GPU tensors can also be converted to numpy.
    samples = signal.detach().cpu().squeeze().numpy()
    D = librosa.stft(samples)
    # Decibels relative to the loudest bin, so the peak maps to 0 dB.
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    plt.figure(figsize=(12, 4))
    librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='hz')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.show()

def plot_frequency_domain(signal, sample_rate=44100, title=''):
    fft = np.fft.fft(signal.squeeze().numpy())
    freqs = np.fft.fftfreq(len(fft), 1/sample_rate)
    plt.figure(figsize=(12, 4))
    plt.plot(freqs, np.abs(fft))
    plt.title(title)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude')
    plt.xscale('log')
    plt.show()

def plot_spectral_distance(x_stfts, y_stfts, scales):
    """Plot ``|y - x|`` for element ``[0]`` of each STFT pair, one figure per scale.

    Args:
        x_stfts: Sequence of reference STFTs (tensors or arrays).
        y_stfts: Sequence of STFTs to compare, parallel to ``x_stfts``.
            Element ``[0]`` of each entry is plotted; presumably the first
            batch item — TODO confirm against multiscale_stft's output layout.
        scales: Scale labels, one per STFT pair, used only in the titles.
    """
    for scale, x_stft, y_stft in zip(scales, x_stfts, y_stfts):
        # torch ops + detach/cpu so GPU or grad-tracking tensors also plot.
        diff = (torch.as_tensor(y_stft[0]) - torch.as_tensor(x_stft[0])).abs()
        plt.figure(figsize=(12, 4))
        plt.imshow(diff.detach().cpu().numpy(), aspect='auto', origin='lower',
                   cmap='viridis', norm=LogNorm(1e-9, 1))
        plt.title(f'Spectral Distance (scale {scale})')
        plt.colorbar(label='Magnitude')
        plt.xlabel('Time Frame')
        plt.ylabel('Frequency Bin')
        plt.show()

def plot_stft(stft, scale, title):
    """Show element ``[0]`` of an STFT as a log-scaled image.

    Args:
        stft: STFT tensor or array; element ``[0]`` is drawn (presumably the
            first batch item — TODO confirm).
        scale: Scale label appended to the title.
        title: Base title for the figure.
    """
    # as_tensor accepts arrays or tensors; detach/cpu makes GPU/grad tensors plottable.
    image = torch.as_tensor(stft[0]).detach().cpu().numpy()
    plt.figure(figsize=(10, 4))
    plt.imshow(image, aspect='auto', origin='lower', cmap='viridis',
               norm=LogNorm(1e-9, 1, clip=False), interpolation='none')
    plt.colorbar()
    plt.title(f'{title} (Scale: {scale})')
    plt.xlabel('Time')
    plt.ylabel('Frequency')
    plt.tight_layout()
    plt.show()

def plot_stft_constant(original_stft, processed_stft, spectral_distance, scale):
    # # Find the global min and max values across all three tensors
    # vmin = min(original_stft.min().item(), processed_stft.min().item(), spectral_distance.min().item())
    # vmax = max(original_stft.max().item(), processed_stft.max().item(), spectral_distance.max().item())

    # # Create a common normalization
    # norm = LogNorm(vmin=max(vmin, 1e-9), vmax=vmax)

    # # Create figure and axes
    # fig, axs = plt.subplots(3, 1, figsize=(12, 18))
    # fig.subplots_adjust(right=0.85)  # Make room for colorbar

    # images = []
    # titles = ['Original Signal STFT', 'Processed Signal STFT', 'Spectral Distance']
    # data = [original_stft, processed_stft, spectral_distance]
    
    # for ax, img_data, title in zip(axs, data, titles):
    #     im = ax.imshow(img_data.squeeze().cpu().numpy(), aspect='auto', origin='lower', cmap='viridis', norm=norm)
    #     ax.set_title(f'{title} (Scale: {scale})')
    #     ax.set_xlabel('Time')
    #     ax.set_ylabel('Frequency')
    #     images.append(im)
    
    # # Add a single colorbar for all subplots
    # cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    # fig.colorbar(images[0], cax=cbar_ax, label='Magnitude')
    
    # plt.tight_layout(rect=[0, 0, 0.85, 1])  # Adjust layout to not overlap with colorbar
    # plt.show()



    def find_non_zero_min(tensor):
        non_zero = tensor[tensor > 0]
        return non_zero.min().item() if len(non_zero) > 0 else tensor.min().item()

    # Find the global non-zero min and max values across all three tensors
    vmin = min(find_non_zero_min(original_stft), 
               find_non_zero_min(processed_stft), 
               find_non_zero_min(spectral_distance))
    vmax = max(original_stft.max().item(), processed_stft.max().item(), spectral_distance.max().item())

    print(vmin)

    # Create a common normalization with the non-zero minimum
    norm = LogNorm(vmin=max(vmin, 1e-9), vmax=vmax)  # Ensure vmin is not zero

    # Create figure and axes
    fig, axs = plt.subplots(3, 1, figsize=(12, 18))
    fig.subplots_adjust(right=0.85)  # Make room for colorbar

    images = []
    titles = ['Original Signal STFT', 'Processed Signal STFT', 'Spectral Distance']
    data = [original_stft, processed_stft, spectral_distance]
    
    for ax, img_data, title in zip(axs, data, titles):
        im = ax.imshow(img_data.squeeze().cpu().numpy(), aspect='auto', origin='lower', cmap='viridis', norm=norm)
        ax.set_title(f'{title} (Scale: {scale})')
        ax.set_xlabel('Time')
        ax.set_ylabel('Frequency')
        images.append(im)
    
    # Add a single colorbar for all subplots
    cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    fig.colorbar(images[0], cax=cbar_ax, label='Magnitude')
    
    plt.tight_layout(rect=[0, 0, 0.85, 1])  # Adjust layout to not overlap with colorbar
    plt.show()

    
def lin_distance(x, y):
    """Relative L2 (Frobenius) error of ``y`` with respect to the reference ``x``.

    Returns ``||x - y|| / ||x||`` as a 0-d tensor. The result is inf/NaN when
    ``x`` is all zeros.
    """
    residual = x - y
    return residual.norm() / x.norm()

def log_distance(x, y):
    """Mean absolute difference of natural logarithms of ``x`` and ``y``.

    Adds 1e-7 to both inputs before taking the log, so zero entries stay finite.
    """
    eps = 1e-7
    log_ratio = torch.log(x + eps) - torch.log(y + eps)
    return log_ratio.abs().mean()

def summarize_stft(stft_tensor):
    """Collect the shape and basic scalar statistics of a tensor into a dict.

    Keys, in order: ``shape`` (a ``torch.Size``), then ``mean``, ``std``
    (torch's default, unbiased), ``min`` and ``max`` as Python floats.
    """
    stats = {"shape": stft_tensor.shape}
    reducers = (
        ("mean", torch.mean),
        ("std", torch.std),
        ("min", torch.min),
        ("max", torch.max),
    )
    for key, reduce in reducers:
        stats[key] = reduce(stft_tensor).item()
    return stats

def normalize_stfts(stfts):
    """Scale each STFT so that its largest value becomes 1.

    Args:
        stfts: Iterable of tensors (presumably magnitude STFTs, one per scale).

    Returns:
        List of tensors, each divided by its own maximum. A tensor whose maximum
        is exactly 0 (e.g. an all-zero STFT) is returned as zeros instead of the
        NaNs that 0/0 would produce, because those NaNs would poison every
        downstream distance.
    """
    normalized = []
    for stft in stfts:
        peak = torch.max(stft)
        normalized.append(stft / peak if peak != 0 else torch.zeros_like(stft))
    return normalized

def calculate_spectral_distance(signal, wet, scales):
    """Build peak-normalised multiscale STFTs of two signals and their absolute difference.

    Args:
        signal: Reference (dry) audio tensor handed to ``multiscale_stft``.
        wet: Processed audio tensor, same layout as ``signal``.
        scales: Scales forwarded unchanged to ``multiscale_stft``.

    Returns:
        Tuple ``(normalized_x_stfts, normalized_y_stfts, normalized_distance)``.
        Each element is a list, presumably one entry per scale. Each STFT is
        normalised by ``normalize_stfts`` before the element-wise ``|x - y|``.
    """
    # NOTE(review): the third multiscale_stft argument (0.75) is opaque from here;
    # presumably a window-overlap ratio — confirm in source.network.metrics.
    raw_x = multiscale_stft(signal, scales, 0.75)
    raw_y = multiscale_stft(wet, scales, 0.75)

    norm_x = normalize_stfts(raw_x)
    norm_y = normalize_stfts(raw_y)

    per_scale_distance = [torch.abs(a - b) for a, b in zip(norm_x, norm_y)]
    return norm_x, norm_y, per_scale_distance

def analyze_and_plot_spectral_distance(signal, wet, scales):
    """Report and plot per-scale spectral differences between two signals.

    For every scale this does three things:
    - prints summary statistics of the normalised original and processed STFTs;
    - draws both STFTs plus their absolute difference on a shared colour scale;
    - prints the linear (relative L2) and log distances between them.

    Args:
        signal: Reference (dry) audio tensor.
        wet: Processed audio tensor, same layout as ``signal``.
        scales: Scales passed through to ``calculate_spectral_distance``.
    """
    x_stfts, y_stfts, distances = calculate_spectral_distance(signal, wet, scales)

    print("STFT Summary for Original and Processed Signals:")
    for scale, x_stft, y_stft, distance in zip(scales, x_stfts, y_stfts, distances):
        print(f"\nScale: {scale}")
        print("Original Signal STFT:")
        print(summarize_stft(x_stft))
        print("Processed Signal STFT:")
        print(summarize_stft(y_stft))

        plot_stft_constant(x_stft, y_stft, distance, scale)

        # .item() prints a plain float, matching the spectral_distance cell,
        # rather than a "tensor(...)" repr.
        print(f"Linear distance: {lin_distance(x_stft, y_stft).item()}")
        print(f"Log distance: {log_distance(x_stft, y_stft).item()}")

In [None]:
# Dry (original) and wet (processed) renders of the same 80s beat.
dry_path, wet_path = "80s Beat 90 bpm_dry.wav", "80s Beat 90 bpm_wet.wav"
signal, wet = map(load_audio, (dry_path, wet_path))

# Scales for multiscale_stft (presumably FFT sizes): 2**11 = 2048 down to 2**3 = 8.
scales = [2 ** k for k in range(11, 2, -1)]

## Basic plots

In [None]:
# Raw waveform of the dry (original) recording.
plot_waveform(signal, "Original Signal Waveform")

In [None]:
# sr must match the rate used by load_audio (its default, 44100 Hz).
plot_spectrogram(signal, sr=44100, title="Original Signal Spectrogram")

In [None]:
# Magnitude spectrum of the dry recording; sample_rate falls back to its 44100 Hz default.
plot_frequency_domain(signal, title="Original Signal Frequency Domain")

## Spectral distance

In [None]:
# spectral_distance / multiscale_stft receive 3-D tensors here, presumably
# (batch, channels, samples) — TODO confirm against source.network.metrics.
# load_audio returns a single row (1, samples), so reshape(1, 1, -1) equals
# unsqueeze(0). Unlike the in-place unsqueeze_, it is idempotent: re-running
# this cell does not keep adding leading dimensions.
signal = signal.reshape(1, 1, -1)
wet = wet.reshape(1, 1, -1)

distance = spectral_distance(signal, wet, scales)
print(f"Spectral distance between original and processed: {distance.item()}")

In [None]:
# Per-scale statistics, shared-scale plots and distances for the dry vs. wet signals.
analyze_and_plot_spectral_distance(signal, wet, scales)