In [1]:
"""
ChatGPT chat histories that helped me
https://chatgpt.com/share/686a7921-3030-8005-9b79-65d7198f586c
https://chatgpt.com/share/686a7813-b100-8005-ad6b-add8985b711c
https://chatgpt.com/share/686a7871-638c-8005-89d6-78e0567f5124
https://chatgpt.com/share/686a790c-4ce8-8005-b5ae-4bb3fc107c7e
https://chatgpt.com/share/686a7885-1654-8005-8078-30985754675e
https://chatgpt.com/share/686a7899-f834-8005-a729-877113dc0878
"""

'\nChatGPT chat histories that helped me\nhttps://chatgpt.com/share/686a7921-3030-8005-9b79-65d7198f586c\nhttps://chatgpt.com/share/686a7813-b100-8005-ad6b-add8985b711c\nhttps://chatgpt.com/share/686a7871-638c-8005-89d6-78e0567f5124\nhttps://chatgpt.com/share/686a790c-4ce8-8005-b5ae-4bb3fc107c7e\nhttps://chatgpt.com/share/686a7885-1654-8005-8078-30985754675e\nhttps://chatgpt.com/share/686a7899-f834-8005-a729-877113dc0878\n'

In [2]:
import os
import torch
import torchaudio
import numpy as np
from tqdm import tqdm

In [3]:
# --- Feature-extraction configuration -------------------------------------
sample_rate = 22050
n_fft = 1024
hop_length = 512
n_mels = 128
target_duration = 5  # seconds
target_length = sample_rate * target_duration  # 110250 samples per clip

# Each clip's mel spectrogram is cut into `timesteps` windows of
# `frames_per_timestep` consecutive frames (see reshape_mel). With torchaudio's
# default centered STFT, a clip of `target_length` samples yields
# 1 + target_length // hop_length frames (216 here), so 36 * 6 must match that
# count exactly; otherwise reshape_mel silently zero-pads or drops frames.
timesteps = 36
frames_per_timestep = 6
expected_mel_frames = 1 + target_length // hop_length
assert timesteps * frames_per_timestep == expected_mel_frames, (
    f"timesteps * frames_per_timestep = {timesteps * frames_per_timestep}, "
    f"but a {target_duration}s clip yields {expected_mel_frames} mel frames"
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
# Mel filterbank front-end shared by every clip, configured from the settings
# cell. No amplitude-to-dB step is applied, so the output stays on
# torchaudio's linear power scale and is min-max normalized later.
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    hop_length=hop_length,
    n_mels=n_mels,
    window_fn=torch.hann_window,
)
# Module.to moves the filterbank/window buffers and returns the same module,
# so waveforms already on `device` can be transformed without a copy back.
mel_transform = mel_transform.to(device)

def find_audio_files_with_label(root_dir, extensions=('.wav', '.mp3')):
    """Collect audio paths under ``root_dir/real`` and ``root_dir/fake``.

    Files anywhere below ``real/`` get label 0, files below ``fake/`` get
    label 1. Directories and file names are visited in sorted order, so the
    returned order (and therefore the row order of any arrays built from it)
    is reproducible across machines. ``os.walk`` on its own yields entries
    in arbitrary filesystem order.

    Args:
        root_dir: Dataset root containing ``real`` and/or ``fake`` folders.
            A missing folder simply contributes no files.
        extensions: File suffix or tuple of suffixes to keep, matched
            case-insensitively. Defaults to ``('.wav', '.mp3')``.

    Returns:
        ``(files, labels)``: parallel lists of file paths and int labels.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    files = []
    labels = []
    for label_dir, label in [('real', 0), ('fake', 1)]:
        full_dir = os.path.join(root_dir, label_dir)
        for root, dirnames, filenames in os.walk(full_dir):
            # Sorting in place makes os.walk descend into subdirs in order.
            dirnames.sort()
            for fname in sorted(filenames):
                if fname.lower().endswith(suffixes):
                    files.append(os.path.join(root, fname))
                    labels.append(label)
    return files, labels

def load_audio(filepath, target_sr=sample_rate, mono=True):
    """Load an audio file and resample it to ``target_sr``.

    Args:
        filepath: Path to a file ``torchaudio.load`` can decode.
        target_sr: Output sample rate in Hz. The default is the notebook's
            ``sample_rate``, captured when this cell is executed.
        mono: If True (default), average all channels into one so the result
            always has a single channel. Without this, a multichannel file
            produces a multichannel mel spectrogram that ``.squeeze(0)``
            cannot reduce to 2-D, and the padding step in ``reshape_mel``
            then raises.

    Returns:
        Tensor of shape ``(channels, num_samples)``; ``channels == 1`` when
        ``mono`` is True.
    """
    waveform, sr = torchaudio.load(filepath)
    if mono and waveform.shape[0] > 1:
        # Downmix first so resampling only has to process one channel.
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    return waveform

def fix_audio_length(waveform, target_length):
    """Force ``waveform`` to exactly ``target_length`` samples on axis 1.

    Shorter inputs are right-padded with zeros; longer (or equal) inputs are
    truncated from the end. Every channel on axis 0 is treated the same.
    """
    shortfall = target_length - waveform.shape[1]
    if shortfall <= 0:
        return waveform[:, :target_length]
    return torch.nn.functional.pad(waveform, (0, shortfall))

def normalize(x):
    """Min-max scale ``x`` into the half-open range [0, 1).

    The 1e-6 added to the denominator prevents division by zero, so a
    constant input maps to all zeros instead of NaN.
    """
    lo = x.min()
    span = x.max() - lo
    return (x - lo) / (span + 1e-6)

def reshape_mel(mel, n_timesteps=None, frames_per_step=None):
    """Split a 2-D mel spectrogram into consecutive time windows.

    The frame axis is first padded with zeros or truncated to exactly
    ``n_timesteps * frames_per_step`` frames. It is then cut into
    ``n_timesteps`` windows of ``frames_per_step`` consecutive frames, so that
    ``out[t, m, f, 0] == mel[m, t * frames_per_step + f]``.

    Args:
        mel: 2-D array of shape ``(n_bins, frames)``. The number of bins is
            taken from the array itself rather than the global ``n_mels``.
        n_timesteps: Number of windows; defaults to the global ``timesteps``.
        frames_per_step: Frames per window; defaults to the global
            ``frames_per_timestep``.

    Returns:
        Array of shape ``(n_timesteps, n_bins, frames_per_step, 1)``, laid out
        as (timesteps, height, width, channels).

    Raises:
        ValueError: if ``mel`` is not 2-D (e.g. a multichannel spectrogram).
    """
    if n_timesteps is None:
        n_timesteps = timesteps
    if frames_per_step is None:
        frames_per_step = frames_per_timestep
    if mel.ndim != 2:
        raise ValueError(
            f"expected a 2-D (n_mels, frames) array, got shape {mel.shape}"
        )

    n_bins, total_frames = mel.shape
    expected_frames = n_timesteps * frames_per_step
    if total_frames < expected_frames:
        mel = np.pad(mel, ((0, 0), (0, expected_frames - total_frames)))
    else:
        mel = mel[:, :expected_frames]
    # Row-major reshape groups each run of `frames_per_step` adjacent frames.
    mel = mel.reshape(n_bins, n_timesteps, frames_per_step)
    mel = mel.transpose(1, 0, 2)  # (timesteps, height, width)
    return mel[..., np.newaxis]   # (timesteps, height, width, channels)

In [5]:
dataset_dir = 'test_dataset'

audio_files, labels = find_audio_files_with_label(dataset_dir)
if not audio_files:
    # Fail loudly: a wrong dataset_dir would otherwise save empty arrays whose
    # shape (0,) does not match the 5-D layout expected downstream.
    raise FileNotFoundError(
        f"No audio files found under {os.path.join(dataset_dir, 'real')!r} "
        f"or {os.path.join(dataset_dir, 'fake')!r}"
    )

X = []
y = []

# Feature extraction needs no gradients; inference_mode disables autograd
# tracking for the whole loop.
with torch.inference_mode():
    for file, label in tqdm(zip(audio_files, labels), total=len(audio_files),
                            desc="Processing audio files"):
        waveform = load_audio(file).to(device)
        waveform = fix_audio_length(waveform, target_length)
        mel_spec = mel_transform(waveform).squeeze(0).cpu().numpy()
        mel_spec = normalize(mel_spec)
        mel_spec = reshape_mel(mel_spec)
        X.append(mel_spec)
        y.append(label)

X = np.stack(X)  # Shape: (samples, timesteps, height, width, channels)
y = np.array(y)

# Sanity-check the layout before writing anything to disk.
assert X.shape == (len(audio_files), timesteps, n_mels, frames_per_timestep, 1), X.shape
assert y.shape == (len(X),), y.shape

np.save('mel_spectrograms_test.npy', X)
np.save('labels_test.npy', y)

Processing audio files: 100%|██████████| 20046/20046 [01:48<00:00, 185.06it/s]
