In [1]:
import os
import torch
from torch.utils.data import TensorDataset
import librosa
import librosa.display
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
import glob

# Function to preprocess LibriSpeech dataset into Mel spectrograms and save them as tensors
def librispeech_to_mel_and_save(
    root_dir,
    output_dir,
    target_sr=16000,
    n_mels=1024,       # Increased from 64 to 128 for higher frequency resolution
    n_fft=8192,       # Increased from 1024 to 2048 for higher frequency resolution
    hop_length=256,   # Decreased from 512 to 256 for higher time resolution
    target_length=256 # Adjusted to accommodate the change in hop_length
):
    """
    Convert LibriSpeech audio files to Mel spectrograms, save spectrograms as tensors, and save original audio.
    """
    flac_files = glob.glob(os.path.join(root_dir, "**", "*.flac"), recursive=True)
    os.makedirs(output_dir, exist_ok=True)  # Create output directory if it doesn't exist
    mel_spectrograms = []  # To store Mel spectrogram tensors

    if len(flac_files) == 0:
        raise ValueError(f"No FLAC files found in directory: {root_dir}")

    for i, file in enumerate(flac_files[:10]):  # Process only the first 10 files
        try:
            # Load the audio file
            audio, sr = librosa.load(file, sr=target_sr)
            audio = audio / np.max(np.abs(audio))  # Normalize audio to [-1, 1]

            # Save the original audio
            audio_output_path = os.path.join(output_dir, f"original_{i}.wav")
            sf.write(audio_output_path, audio, sr)
            print(f"Original audio saved to {audio_output_path}")

            # Convert to Mel spectrogram
            mel_spec = librosa.feature.melspectrogram(
                y=audio,
                sr=sr,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels,
                power=1.0,
            )
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

            # Normalize using fixed range [mel_spec_db.min(), 0] to [-1, 1]
            min_db = mel_spec_db.min()
            mel_spec_db_norm = (mel_spec_db - min_db) / -min_db  # Normalize to [0, 1]
            mel_spec_db_norm = mel_spec_db_norm * 2 - 1           # Scale to [-1, 1]

            # Ensure the Mel spectrogram has a fixed length
            if mel_spec_db_norm.shape[1] < target_length:
                pad_width = target_length - mel_spec_db_norm.shape[1]
                mel_spec_db_norm = np.pad(
                    mel_spec_db_norm, ((0, 0), (0, pad_width)), mode="constant"
                )
            else:
                mel_spec_db_norm = mel_spec_db_norm[:, :target_length]

            # Save the Mel spectrogram as a tensor
            tensor_output_path = os.path.join(output_dir, f"mel_{i}.pt")
            torch.save(
                {
                    'mel_spec_db_norm': torch.tensor(mel_spec_db_norm, dtype=torch.float32),
                    'min_db': min_db  # Save min_db for de-normalization
                },
                tensor_output_path
            )
            print(f"Mel spectrogram tensor saved to {tensor_output_path}")

            # Add to the list of tensors
            mel_spectrograms.append(
                torch.tensor(mel_spec_db_norm, dtype=torch.float32).unsqueeze(0)
            )  # Add channel dim

        except Exception as e:
            print(f"Error processing file {file}: {e}")

    if len(mel_spectrograms) == 0:
        raise ValueError(
            "No valid audio files were processed into Mel spectrograms."
        )

    # Return TensorDataset of Mel spectrograms
    mel_spectrograms = torch.stack(
        mel_spectrograms
    )  # Shape: [num_samples, 1, n_mels, target_length]
    return TensorDataset(mel_spectrograms)


def reconstruct_audio_from_mel_tensors(
    output_dir,
    target_sr=16000,
    n_fft=8192,
    hop_length=256,
    n_iter=500,  # Increased iterations for better phase estimation
    target_length=256,
):
    """
    Reconstruct audio from Mel spectrogram tensors and save as WAV, ensuring correct length and preserving amplitude.
    """
    tensor_files = glob.glob(os.path.join(output_dir, "mel_*.pt"))
    tensor_files.sort()  # Ensure consistent order

    for i, tensor_file in enumerate(tensor_files):
        try:
            # Load the Mel spectrogram tensor
            data = torch.load(tensor_file)
            mel_spec_db_norm = data['mel_spec_db_norm'].numpy()
            min_db = data['min_db']

            # De-normalize from [-1, 1] back to the dB range
            mel_spec_db = (mel_spec_db_norm + 1) / 2  # Scale to [0, 1]
            mel_spec_db = mel_spec_db * -min_db + min_db  # Scale back to original dB range

            # Convert dB back to power
            mel_spec_power = librosa.db_to_power(mel_spec_db)  # Convert to power scale

            # Reconstruct audio from Mel spectrogram
            audio = librosa.feature.inverse.mel_to_audio(
                mel_spec_power,
                sr=target_sr,
                n_fft=n_fft,
                hop_length=hop_length,
                n_iter=n_iter,
                power=1.0,
            )

            # Match amplitude to original audio
            original_audio_path = os.path.join(output_dir, f"original_{i}.wav")
            original_audio, _ = librosa.load(original_audio_path, sr=target_sr)

            # Compute RMS of original and reconstructed audio
            original_rms = np.sqrt(np.mean(original_audio**2))
            reconstructed_rms = np.sqrt(np.mean(audio**2))

            if reconstructed_rms > 0:  # Prevent division by zero
                audio = audio * (original_rms / reconstructed_rms)

            # Ensure the audio matches the expected length
            expected_length = target_length * hop_length
            audio = librosa.util.fix_length(audio, size=expected_length)

            # Save the reconstructed audio
            reconstructed_audio_path = os.path.join(
                output_dir, f"reconstructed_{i}.wav"
            )
            sf.write(reconstructed_audio_path, audio, target_sr)
            print(f"Reconstructed audio saved to {reconstructed_audio_path}")

            # Plot Mel spectrograms for comparison
            plot_mel_spectrograms_comparison(
                original_audio_path,
                reconstructed_audio_path,
                sr=target_sr,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=1024,  # Adjusted to match input
                i=i,
                output_dir=output_dir,
            )

        except Exception as e:
            print(f"Error reconstructing audio from file {tensor_file}: {e}")



def plot_mel_spectrograms_comparison(original_audio_path, reconstructed_audio_path, sr, n_fft, hop_length, n_mels, i, output_dir):
    """
    Plot and save Mel spectrograms of the original and reconstructed audio for comparison.
    """
    # Load original and reconstructed audio
    original_audio, _ = librosa.load(original_audio_path, sr=sr)
    reconstructed_audio, _ = librosa.load(reconstructed_audio_path, sr=sr)

    # Match lengths of original and reconstructed audio
    min_length = min(len(original_audio), len(reconstructed_audio))
    original_audio = original_audio[:min_length]
    reconstructed_audio = reconstructed_audio[:min_length]

    # Compute Mel spectrograms
    original_mel_spec = librosa.feature.melspectrogram(
        y=original_audio,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        power=1.0,
    )
    original_mel_spec_db = librosa.power_to_db(original_mel_spec, ref=np.max)

    reconstructed_mel_spec = librosa.feature.melspectrogram(
        y=reconstructed_audio,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        power=1.0,
    )
    reconstructed_mel_spec_db = librosa.power_to_db(reconstructed_mel_spec, ref=np.max)

    # Plotting
    fig, axs = plt.subplots(2, 1, figsize=(10, 8))

    img1 = librosa.display.specshow(
        original_mel_spec_db,
        x_axis='time',
        y_axis='mel',
        sr=sr,
        hop_length=hop_length,
        ax=axs[0]
    )
    axs[0].set_title(f'Original Audio Mel Spectrogram {i}')
    fig.colorbar(img1, ax=axs[0], format='%+2.0f dB')

    img2 = librosa.display.specshow(
        reconstructed_mel_spec_db,
        x_axis='time',
        y_axis='mel',
        sr=sr,
        hop_length=hop_length,
        ax=axs[1]
    )
    axs[1].set_title(f'Reconstructed Audio Mel Spectrogram {i}')
    fig.colorbar(img2, ax=axs[1], format='%+2.0f dB')

    plt.tight_layout()
    plot_path = os.path.join(output_dir, f'mel_spectrogram_comparison_{i}.png')
    plt.savefig(plot_path)
    plt.close(fig)
    print(f'Mel spectrogram comparison plot saved to {plot_path}')



# Main Function
def main():
    data_dir = r"data/LibriSpeech/LibriSpeech/dev-clean"  # Input directory with audio files
    output_dir = r"output_mel_spectrograms"  # Output directory for spectrograms and reconstructed audio

    # Step 1: Convert audio to Mel spectrograms and save
    print("Processing audio files and saving Mel spectrograms...")
    mel_dataset = librispeech_to_mel_and_save(data_dir, output_dir)
    print(f"Total number of processed samples: {len(mel_dataset)}")

    # Step 2: Reconstruct audio from saved Mel spectrograms and plot comparisons
    print("Reconstructing audio from Mel spectrograms and plotting comparisons...")
    reconstruct_audio_from_mel_tensors(output_dir)

if __name__ == "__main__":
    main()


Processing audio files and saving Mel spectrograms...
Original audio saved to output_mel_spectrograms/original_0.wav
Mel spectrogram tensor saved to output_mel_spectrograms/mel_0.pt
Original audio saved to output_mel_spectrograms/original_1.wav
Mel spectrogram tensor saved to output_mel_spectrograms/mel_1.pt
Original audio saved to output_mel_spectrograms/original_2.wav
Mel spectrogram tensor saved to output_mel_spectrograms/mel_2.pt
Original audio saved to output_mel_spectrograms/original_3.wav
Mel spectrogram tensor saved to output_mel_spectrograms/mel_3.pt
Original audio saved to output_mel_spectrograms/original_4.wav
Mel spectrogram tensor saved to output_mel_spectrograms/mel_4.pt
Original audio saved to output_mel_spectrograms/original_5.wav
Mel spectrogram tensor saved to output_mel_spectrograms/mel_5.pt
Original audio saved to output_mel_spectrograms/original_6.wav
Mel spectrogram tensor saved to output_mel_spectrograms/mel_6.pt
Original audio saved to output_mel_spectrograms/or

  data = torch.load(tensor_file)


Reconstructed audio saved to output_mel_spectrograms/reconstructed_0.wav
Mel spectrogram comparison plot saved to output_mel_spectrograms/mel_spectrogram_comparison_0.png
Reconstructed audio saved to output_mel_spectrograms/reconstructed_1.wav
Mel spectrogram comparison plot saved to output_mel_spectrograms/mel_spectrogram_comparison_1.png
Reconstructed audio saved to output_mel_spectrograms/reconstructed_2.wav
Mel spectrogram comparison plot saved to output_mel_spectrograms/mel_spectrogram_comparison_2.png
Reconstructed audio saved to output_mel_spectrograms/reconstructed_3.wav
Mel spectrogram comparison plot saved to output_mel_spectrograms/mel_spectrogram_comparison_3.png
Reconstructed audio saved to output_mel_spectrograms/reconstructed_4.wav
Mel spectrogram comparison plot saved to output_mel_spectrograms/mel_spectrogram_comparison_4.png
Reconstructed audio saved to output_mel_spectrograms/reconstructed_5.wav
Mel spectrogram comparison plot saved to output_mel_spectrograms/mel_spe