In [8]:
from typing import Any, Callable, Dict, List, Optional, Union
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import repeat
from diffusers import (
    AudioLDMPipeline,
    AutoencoderKL,
    UNet2DConditionModel,
    DDIMScheduler,
)
from transformers import (
    ClapTextModelWithProjection,
    RobertaTokenizerFast,
    SpeechT5HifiGan,
)

class AudioLDM(nn.Module):
    def __init__(self, device='cuda', repo_id="cvssp/audioldm", config=None):
        super().__init__()
        self.device = torch.device(device)
        pipe = AudioLDMPipeline.from_pretrained(repo_id, use_safetensors=False)

        # Setup components and move to device
        self.pipe = pipe
        self.components = {
            'vae': (pipe.vae, AutoencoderKL),
            'tokenizer': (pipe.tokenizer, RobertaTokenizerFast),
            'text_encoder': (pipe.text_encoder, ClapTextModelWithProjection),
            'unet': (pipe.unet, UNet2DConditionModel),
            'vocoder': (pipe.vocoder, SpeechT5HifiGan),
            'scheduler': (pipe.scheduler, DDIMScheduler)
        }
        
        # Initialize and validate components
        for name, (component, expected_type) in self.components.items():
            if name in ['vae', 'text_encoder', 'unet', 'vocoder']:
                component = component.to(self.device)
            assert isinstance(component, expected_type), f"{name} type mismatch: {type(component)}"
            setattr(self, name, component)

        self.evalmode = True
        self.checkpoint_path = repo_id
        self.audio_duration = 10.24 if not config else config['duration']
        self.original_waveform_length = int(self.audio_duration * self.vocoder.config.sampling_rate)  # 10.24 * 16000 = 163840
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)  # 4
        print(f'[INFO] audioldm.py: loaded AudioLDM!')

In [9]:
from diffusers import AudioLDMPipeline
import torch

repo_id = "cvssp/audioldm"
pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm", use_safetensors=False, torch_dtype=torch.float32)
pipe = pipe.to("cuda")

prompt = "Foot steps on the wooden floor"
audio = pipe(prompt, num_inference_steps=100, audio_length_in_s=10.24).audios[0]



Loading pipeline components...: 100%|██████████| 6/6 [00:01<00:00,  5.46it/s]
100%|██████████| 100/100 [00:03<00:00, 33.18it/s]


In [10]:
print(audio.shape)

(163840,)


In [11]:
from IPython.display import Audio

Audio(audio, rate=16000)

In [12]:
import numpy as np

def normalize_wav(data: np.ndarray) -> np.ndarray:
    return data * 0.5

print(audio.min(), audio.max())
# normalized_data = normalize_wav(wav_data)

# print(normalized_data)

-0.72105396 0.55777234


In [13]:
import scipy

scipy.io.wavfile.write("Foot_steps_on_the_wooden_floor.wav", rate=16000, data=audio)

In [14]:
import torchaudio
wav, ori_sr = torchaudio.load("./Foot_steps_on_the_wooden_floor.wav", normalize=True)  # ts[C,N'±]
wav = wav.numpy()
print(wav.min(), wav.max())

wav2, ori_sr = torchaudio.load("./A_cat_meowing.wav", normalize=True)  # ts[C,N'±]
wav2 = wav2.numpy()
print(wav2.min(), wav2.max())




-0.72105396 0.55777234


LibsndfileError: Error opening './A_cat_meowing.wav': System error.

In [None]:
wav_sum = wav + wav2
print(wav_sum.min(), wav_sum.max())
print(wav_sum.dtype)
wav_sum = wav_sum.squeeze()
print(wav_sum.shape)
scipy.io.wavfile.write("Cat_n_Footstep.wav", rate=16000, data=wav_sum)


-0.5448767 0.58386886
float32
(163840,)


In [None]:
from src.data_processing.audio_processing import AudioDataProcessor
import torch

processor = AudioDataProcessor(device='cuda', config_path='configs/audioldm.yaml')
wav = processor.read_wav_file('data/samples/A_cat_meowing.wav')
wav = processor.prepare_wav(wav)
print(wav.shape)
stft, stft_complex = processor.wav_to_stft(wav)
_wav = processor.inverse_stft(stft, stft_complex)

print(_wav.shape)

start assertions
torch.Size([1, 163840])
torch.Size([163840])
(1, 163840)


In [None]:
print(processor.spec_length)

1025


In [None]:
import numpy as np
import torch

wav = wav.cpu().numpy() if isinstance(wav, torch.Tensor) else wav
_wav = _wav.cpu().numpy() if isinstance(_wav, torch.Tensor) else _wav

# 1. Mean Absolute Error (MAE)
mae = np.mean(np.abs(wav - _wav))

# 2. Pearson Correlation Coefficient (PCC)
pcc = np.corrcoef(wav, _wav)[0, 1]

# 3. Normalized Root Mean Squared Error (NRMSE)
rmse = np.sqrt(np.mean((wav - _wav) ** 2))
nrmse = rmse / (np.max(wav) - np.min(_wav) + 1e-8)  # 0 나눔 방지

print(f'MAE: {mae:.4f}, PCC: {pcc:.4f}, NRMSE: {nrmse:.4f}')

MAE: 0.0008, PCC: 1.0000, NRMSE: 0.0014


In [10]:
import pandas as pd

# CSV 파일을 DataFrame으로 읽어오기
df = pd.read_csv('./output.csv', header=None)
df.columns = ['sdr', 'sisdr']
# 각 열에 대해 처리
for idx, col_name in enumerate(df.columns):
    # 리스트로 변환
    values = df[col_name].tolist()
    
    # min, max, mean 계산
    col_min = min(values)
    col_max = max(values)
    col_mean = df[col_name].mean()

    print(f"Column {idx+1}:")
    print(f"  List : {values}")
    print(f"  Min  : {col_min}")
    print(f"  Max  : {col_max}")
    print(f"  Mean : {col_mean}\n")

top_n = 100
top_indices_dict = {}
for col in df.columns:
    top_indices = df[col].nlargest(top_n).index.tolist()
    top_indices_dict[col] = top_indices
    print(f"Top {top_n} indices for {col}:")
    print(top_indices)
    print()

common_indices = list(set(top_indices_dict['sdr']) & set(top_indices_dict['sisdr']))
common_indices.sort()

print("공통 인덱스:")
print(common_indices)

Column 1:
  List : [5.1745358, 4.938553, -0.072019726, 3.8968942, 4.0985007, 4.4469557, -5.230257, 0.96649003, -0.010994993, 4.2537494, 10.482616, 9.3321705, 2.8375123, -7.2876053, 7.0423803, 5.007015, 3.230535, -0.8949924, 7.7470436, 2.9047976, 0.7426441, 6.1531444, 6.2870264, 6.8650956, 0.9379583, 2.1382139, -3.8619494, 3.1181364, 3.8559976, 4.857766, 3.3090296, 1.6351025, 0.39957702, 1.5171443, 1.1501591, 10.2819395, -4.0753202, -1.4676024, 6.914149, -5.2394953, -3.9802976, -1.188281, 2.0688937, 11.697668, 0.71216035, 2.0796568, -0.4248659, 6.0563135, -3.61853, 5.3345013, 5.612132, 6.907404, 1.1273618, -11.00503, 4.296626, -1.6389083, 2.3314223, 2.1040397, -2.4831128, -3.0109046, 1.1309744, 2.43755, 5.0695076, -7.943907, 5.222909, 8.845705, 5.172761, 2.3688917, -6.4914274, 4.648074, 5.129834, 5.2231884, 4.880218, 4.9207296, -4.5867105, 5.5279436, 1.1327202, 5.142906, 6.6814613, 1.9476868, 5.1909685, -0.1479365, -0.41313663, 1.0607333, 4.2121773, 3.4379659, 1.5241592, 2.540207, 4.009

In [14]:
import os
import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display as ld
from src.utils import read_wav_file, calculate_sdr, calculate_sisdr

def plot_wav_mel(wav_arrays, sr=16000, save_path="./test/result/mel_05lev_5it_thres.png", score=(0,0), idx=000):
    fig, axes = plt.subplots(2, len(wav_arrays), figsize=(4 * len(wav_arrays), 6))

    clip_duration = 10.24  # 클리핑 길이 (초)
    hop_length = 512       # Hop length 설정

    for i, wav in enumerate(wav_arrays):
        # NumPy 배열 확인 후 1D 변환
        if len(wav.shape) > 1:
            wav = wav.squeeze()

        # 샘플링 레이트 및 오디오 데이터 설정
        duration = len(wav) / sr  # 오디오 길이(초)

        # Clip to first 10.24 seconds if longer
        if duration > clip_duration:
            wav = wav[: int(clip_duration * sr)]  # 앞 10.24초만 유지

        time = np.linspace(0, len(wav) / sr, num=len(wav))

        # **Waveform 플로팅**
        axes[0, i].plot(time, wav, lw=0.5)
        axes[0, i].set_title(f"Waveform {i+1}")
        axes[0, i].set_xlabel("Time (s)")
        axes[0, i].set_ylabel("Amplitude")
        axes[0, i].set_ylim([-1, 1])

        # **Mel Spectrogram 계산 및 플로팅**
        mel_spec = librosa.feature.melspectrogram(y=wav, sr=sr, n_mels=128, hop_length=hop_length)
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # ✅ specshow()에 전달할 때 (n_mels, frames) 차원 확인
        ld.specshow(
            mel_spec_db,
            sr=sr,
            hop_length=hop_length,
            x_axis="time",
            y_axis="mel",
            vmin=-80,   # 최소값
            vmax=0,     # 최대값
            ax=axes[1, i]
        )
        
        axes[1, i].set_title(f"Mel Spectrogram {i+1}")

    plt.suptitle(f"Index: {idx} | SDR: {score[0]:.2f}, SISDR: {score[1]:.2f}", fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

# mix = "./data/samples/Cat_n_Footstep.wav"
# sep = "./test/result/cat_separated.wav"
# gt = "./data/samples/A_cat_meowing.wav"
# mix_wav = read_wav_file(filename=mix, target_duration=10.24, target_sr=16000)
# sep_wav = read_wav_file(filename=sep, target_duration=10.24, target_sr=16000)
# gt_wav = read_wav_file(filename=gt, target_duration=10.24, target_sr=16000)
# sdr = calculate_sdr(mix_wav, sep_wav)
# sisdr = calculate_sisdr(mix_wav, sep_wav)
# print(f'SDR: {sdr:.4f}, SI-SDR: {sisdr:.4f}')
# sdr = calculate_sdr(gt_wav, sep_wav)
# sisdr = calculate_sisdr(gt_wav, sep_wav)
# print(f'SDR: {sdr:.4f}, SI-SDR: {sisdr:.4f}')
# wav_paths = [mix_wav, sep_wav, gt_wav]
# plot_wav_mel(wav_paths, save_path=f"./test/result/audioldm_05lev.png")


import csv
from tqdm import tqdm
import torchaudio

metadata_pth='./src/benchmarks/metadata/vggsound_eval.csv'
audio_dir='./data/vggsound'

with open(metadata_pth) as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    eval_list = [row for row in csv_reader][1:]

for i, eval_data in enumerate(tqdm(eval_list)):
    if i not in common_indices:
        continue  # common_indices에 없으면 skip

    file_id, mix_wav, s0_wav, s0_text, s1_wav, s1_text = eval_data
    mixture_path = os.path.join(audio_dir, mix_wav)
    source_path = os.path.join(audio_dir, s0_wav)

    wav = read_wav_file(filename=mixture_path, target_duration=10.24, target_sr=16000)
    sep_pth = f"./test/vgg_result/{file_id}.wav"
    wav_sep = read_wav_file(filename=sep_pth, target_duration=10.24, target_sr=16000)
    wav_gt = read_wav_file(filename=source_path, target_duration=10.24, target_sr=16000)

    assert wav.shape == wav_sep.shape == wav_gt.shape, f"shape must be same: {wav.shape}, {wav_sep.shape}, {wav_gt.shape}"

    wav_paths = [wav, wav_sep, wav_gt]
    sdr = calculate_sdr(wav_sep, wav_gt)
    sisdr = calculate_sisdr(wav_sep, wav_gt)
    print(f'SDR: {sdr:.4f}, SI-SDR: {sisdr:.4f}')

    plot_wav_mel(
        wav_paths,
        save_path=f"./test/vgg_mels/{file_id}.png",
        score=(sdr, sisdr),
        idx=file_id
    )

  0%|          | 0/1000 [00:00<?, ?it/s]

SDR: 9.4700, SI-SDR: 10.3062


  1%|          | 11/1000 [00:01<01:46,  9.26it/s]

SDR: 8.4296, SI-SDR: 8.8579


  1%|          | 12/1000 [00:02<04:00,  4.11it/s]

SDR: 7.3298, SI-SDR: 7.0000


  2%|▏         | 19/1000 [00:03<03:19,  4.91it/s]

SDR: 3.8387, SI-SDR: 7.1377


  2%|▏         | 24/1000 [00:06<05:58,  2.72it/s]

SDR: 9.4880, SI-SDR: 9.9362


  4%|▎         | 36/1000 [00:08<03:29,  4.59it/s]

SDR: 9.7918, SI-SDR: 14.2825


  4%|▍         | 44/1000 [00:09<03:06,  5.11it/s]

SDR: 3.9267, SI-SDR: 7.1357


  5%|▌         | 52/1000 [00:10<02:44,  5.75it/s]

SDR: 8.0160, SI-SDR: 8.2596


  7%|▋         | 66/1000 [00:11<02:05,  7.43it/s]

SDR: 12.1138, SI-SDR: 13.0895


 10%|█         | 102/1000 [00:12<01:02, 14.45it/s]

SDR: 9.1688, SI-SDR: 8.6712


 11%|█         | 107/1000 [00:13<01:18, 11.43it/s]

SDR: 6.7914, SI-SDR: 9.2915


 11%|█         | 109/1000 [00:14<01:43,  8.64it/s]

SDR: 8.8537, SI-SDR: 8.7653


 12%|█▏        | 118/1000 [00:16<01:50,  8.01it/s]

SDR: 5.9333, SI-SDR: 8.0789


 14%|█▍        | 141/1000 [00:17<01:15, 11.40it/s]

SDR: 2.1534, SI-SDR: 10.1788


 15%|█▍        | 146/1000 [00:18<01:32,  9.27it/s]

SDR: 9.8273, SI-SDR: 10.6714


 15%|█▌        | 151/1000 [00:19<01:45,  8.05it/s]

SDR: 4.7591, SI-SDR: 7.1369


 15%|█▌        | 153/1000 [00:20<02:15,  6.27it/s]

SDR: 5.8832, SI-SDR: 7.5510


 16%|█▌        | 156/1000 [00:22<02:57,  4.76it/s]

SDR: 7.5138, SI-SDR: 7.5240


 16%|█▌        | 157/1000 [00:23<03:46,  3.72it/s]

SDR: 5.5340, SI-SDR: 9.3962


 16%|█▌        | 160/1000 [00:24<04:00,  3.49it/s]

SDR: 5.9888, SI-SDR: 8.6331


 18%|█▊        | 179/1000 [00:25<01:50,  7.44it/s]

SDR: 10.7683, SI-SDR: 14.2174


 21%|██        | 212/1000 [00:26<00:56, 13.92it/s]

SDR: 15.0685, SI-SDR: 16.8045


 22%|██▏       | 215/1000 [00:27<01:14, 10.59it/s]

SDR: 9.7223, SI-SDR: 10.2887


 22%|██▏       | 218/1000 [00:28<01:35,  8.18it/s]

SDR: 8.4620, SI-SDR: 8.4763


 22%|██▏       | 220/1000 [00:30<02:05,  6.22it/s]

SDR: 12.2945, SI-SDR: 12.0797


 25%|██▌       | 252/1000 [00:31<00:56, 13.18it/s]

SDR: 8.0525, SI-SDR: 7.3497


 25%|██▌       | 254/1000 [00:32<01:15,  9.86it/s]

SDR: 17.2718, SI-SDR: 17.3017


 26%|██▌       | 256/1000 [00:33<01:38,  7.55it/s]

SDR: 5.5677, SI-SDR: 17.5176


 26%|██▌       | 257/1000 [00:34<02:11,  5.63it/s]

SDR: 7.2956, SI-SDR: 7.0276


 26%|██▌       | 259/1000 [00:35<02:45,  4.47it/s]

SDR: 13.7787, SI-SDR: 13.6070


 26%|██▌       | 260/1000 [00:36<03:33,  3.47it/s]

SDR: 3.5700, SI-SDR: 7.9903


 28%|██▊       | 285/1000 [00:37<01:17,  9.20it/s]

SDR: 5.5579, SI-SDR: 7.2482


 30%|██▉       | 299/1000 [00:39<01:20,  8.75it/s]

SDR: 6.7611, SI-SDR: 7.9576


 33%|███▎      | 332/1000 [00:40<00:47, 14.11it/s]

SDR: 4.5453, SI-SDR: 7.4255


 34%|███▍      | 342/1000 [00:41<00:50, 12.97it/s]

SDR: 6.5421, SI-SDR: 7.0403


 35%|███▌      | 350/1000 [00:42<00:56, 11.50it/s]

SDR: 7.1431, SI-SDR: 11.4341


 36%|███▌      | 361/1000 [00:43<00:57, 11.20it/s]

SDR: 5.2015, SI-SDR: 6.8250


 36%|███▋      | 364/1000 [00:44<01:12,  8.79it/s]

SDR: 10.6433, SI-SDR: 10.8627


 37%|███▋      | 368/1000 [00:45<01:26,  7.32it/s]

SDR: 7.7053, SI-SDR: 8.0963


 37%|███▋      | 370/1000 [00:46<01:52,  5.60it/s]

SDR: 8.2131, SI-SDR: 8.3369


 38%|███▊      | 375/1000 [00:48<02:00,  5.19it/s]

SDR: 4.6532, SI-SDR: 9.8059


 38%|███▊      | 381/1000 [00:49<02:01,  5.10it/s]

SDR: 3.7981, SI-SDR: 7.3202


 38%|███▊      | 383/1000 [00:50<02:32,  4.03it/s]

SDR: 2.4303, SI-SDR: 10.5563


 41%|████      | 410/1000 [00:51<00:59,  9.86it/s]

SDR: 3.2409, SI-SDR: 8.9240


 42%|████▏     | 416/1000 [00:52<01:09,  8.45it/s]

SDR: 7.2236, SI-SDR: 8.0192


 42%|████▏     | 417/1000 [00:53<01:34,  6.16it/s]

SDR: 5.3728, SI-SDR: 7.6401


 42%|████▏     | 420/1000 [00:55<01:54,  5.08it/s]

SDR: 10.7778, SI-SDR: 10.6648


 42%|████▏     | 422/1000 [00:56<02:20,  4.12it/s]

SDR: 7.5326, SI-SDR: 6.9604


 43%|████▎     | 430/1000 [00:57<01:55,  4.95it/s]

SDR: 3.1062, SI-SDR: 8.9653


 44%|████▍     | 439/1000 [00:58<01:36,  5.79it/s]

SDR: 5.7422, SI-SDR: 7.9748


 47%|████▋     | 468/1000 [00:59<00:46, 11.38it/s]

SDR: 4.3212, SI-SDR: 7.8734


 50%|████▉     | 495/1000 [01:01<00:38, 13.22it/s]

SDR: 5.2074, SI-SDR: 8.2176


 54%|█████▍    | 538/1000 [01:02<00:22, 20.24it/s]

SDR: 8.5997, SI-SDR: 8.3222


 57%|█████▋    | 569/1000 [01:03<00:19, 22.45it/s]

SDR: 4.4769, SI-SDR: 7.1197


 60%|██████    | 600/1000 [01:04<00:16, 23.69it/s]

SDR: 11.5247, SI-SDR: 12.1940


 62%|██████▏   | 622/1000 [01:05<00:16, 22.64it/s]

SDR: 6.3759, SI-SDR: 7.2559


 64%|██████▎   | 635/1000 [01:06<00:18, 19.52it/s]

SDR: 5.4995, SI-SDR: 6.8391


 68%|██████▊   | 681/1000 [01:08<00:12, 24.63it/s]

SDR: 5.0371, SI-SDR: 6.8078


 68%|██████▊   | 685/1000 [01:09<00:17, 18.43it/s]

SDR: 7.2202, SI-SDR: 6.8394


 72%|███████▏  | 723/1000 [01:10<00:12, 22.72it/s]

SDR: 5.9718, SI-SDR: 7.1236


 79%|███████▉  | 792/1000 [01:11<00:06, 34.09it/s]

SDR: 7.2563, SI-SDR: 8.0287


 81%|████████  | 811/1000 [01:12<00:06, 28.50it/s]

SDR: 9.6704, SI-SDR: 9.2124


 87%|████████▋ | 869/1000 [01:13<00:03, 35.96it/s]

SDR: 5.6266, SI-SDR: 7.2561


 87%|████████▋ | 873/1000 [01:15<00:04, 26.20it/s]

SDR: 9.5741, SI-SDR: 9.3903


 88%|████████▊ | 877/1000 [01:16<00:06, 19.76it/s]

SDR: 1.9100, SI-SDR: 13.1918


 88%|████████▊ | 883/1000 [01:17<00:07, 15.21it/s]

SDR: 9.7092, SI-SDR: 10.5889


 89%|████████▉ | 890/1000 [01:18<00:09, 11.95it/s]

SDR: 9.4843, SI-SDR: 12.1320


 90%|████████▉ | 899/1000 [01:19<00:09, 10.46it/s]

SDR: 5.0685, SI-SDR: 9.7111


100%|██████████| 1000/1000 [01:20<00:00, 12.35it/s]


In [2]:
from src.utils import calculate_sdr, calculate_sisdr

sdr = calculate_sdr(wav_sep, wav_gt)
sisdr = calculate_sisdr(wav_sep, wav_gt)
print(f'SDR: {sdr:.4f}, SI-SDR: {sisdr:.4f}')

SDR: 1.8818, SI-SDR: 13.3102


In [6]:
sdr = calculate_sdr(_wav, _wav)
sisdr = calculate_sisdr(_wav, _wav)
print(f'SDR: {sdr:.4f}, SI-SDR: {sisdr:.4f}')

SDR: 77.1523, SI-SDR: 98.5334


In [7]:
import numpy as np
from scipy.signal import correlate

def align_wav_signals(wav1, wav2):
    assert wav1.shape == wav2.shape, "두 신호의 길이가 같아야 합니다."
    
    # 1D 배열로 변환
    wav1, wav2 = wav1.flatten(), wav2.flatten()
    
    # Cross-Correlation 계산
    correlation = correlate(wav1, wav2, mode="full")
    shift = correlation.argmax() - (len(wav1) - 1)  # 최대 상관점의 오프셋

    # shift 만큼 이동 (앞뒤 zero-padding 적용)
    if shift > 0:
        aligned_wav2 = np.pad(wav2[shift:], (0, shift), mode="constant")
    elif shift < 0:
        aligned_wav2 = np.pad(wav2[:shift], (-shift, 0), mode="constant")
    else:
        aligned_wav2 = wav2  # 이미 정렬됨

    return aligned_wav2, shift

aligned_wav2, applied_shift = align_wav_signals(wav, _wav)
sdr = calculate_sdr(wav, aligned_wav2)
sisdr = calculate_sisdr(wav, aligned_wav2)
print(f"적용된 시간 shift: {applied_shift} samples")
print(f'SDR: {sdr:.4f}, SI-SDR: {sisdr:.4f}')


적용된 시간 shift: 0 samples
SDR: 34.4945, SI-SDR: 49.7118
