<a href="https://colab.research.google.com/github/varshacvenkat-web/Varsha-Venkatapathy-Engineering-Portfolio-/blob/main/Article_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
from pystoi import stoi
from pypesq import pesq
from torch.utils.data import TensorDataset, DataLoader

# ── Constants ─────────────────────────────────────────────────────────────
SAMPLE_RATE = 8000  # audio sample rate in Hz; passed to sf.write and the PESQ/STOI scorers
FRAME_SIZE  = 256  # STFT frame length; used as n_fft throughout this file
OVERLAP     = 128  # used as the STFT hop_length throughout this file
EPSILON     = 1e-10  # subtracted after exp(); presumably the offset added before the log when features were built — TODO confirm
N_ITER      = 32  # default Griffin-Lim iteration count
WINDOW_SIZE = 9  # number of consecutive noisy frames fed to the model per predicted frame
N_BINS      = FRAME_SIZE // 2 + 1  # 129 (rows kept when slicing spectrograms/phase)

# ── Reconstruction Helpers ────────────────────────────────────────────────
def reconstruct_waveform_from_logpower(lp,
                                       n_fft=FRAME_SIZE,
                                       hop_length=OVERLAP,
                                       n_iter=N_ITER):
    """Estimate a waveform from a log-power spectrogram with Griffin-Lim.

    Args:
        lp: Log-power spectrogram array; NaN/inf entries are replaced by
            ``np.nan_to_num`` before use.
        n_fft: FFT size forwarded to ``librosa.griffinlim``.
        hop_length: Hop length forwarded to ``librosa.griffinlim``.
        n_iter: Number of Griffin-Lim iterations.

    Returns:
        The waveform returned by ``librosa.griffinlim``.
    """
    # Undo the log, remove the EPSILON offset, and floor at zero so the
    # square root below never sees a negative power.
    linear_power = np.exp(np.nan_to_num(lp)) - EPSILON
    magnitude = np.sqrt(np.maximum(linear_power, 0))
    waveform = librosa.griffinlim(
        magnitude,
        n_iter=n_iter,
        n_fft=n_fft,
        hop_length=hop_length,
    )
    return waveform

def reconstruct_with_noisy_phase(lp, noisy_phase,
                                 hop_length=OVERLAP):
    """Invert a log-power spectrogram using a supplied phase.

    Args:
        lp: Log-power spectrogram; NaN/inf entries are replaced by
            ``np.nan_to_num``.
        noisy_phase: Phase angles (radians) that must broadcast against
            the magnitude derived from ``lp``.
        hop_length: Hop length forwarded to ``librosa.istft``.

    Returns:
        The waveform returned by ``librosa.istft``.
    """
    sanitized = np.nan_to_num(lp)
    linear_power = np.exp(sanitized) - EPSILON
    # Clip at zero: subtracting EPSILON can push tiny powers negative.
    magnitude = np.sqrt(np.clip(linear_power, 0, None))
    # Polar form: magnitude * e^{j * phase}.
    complex_spec = magnitude * np.exp(1j * noisy_phase)
    return librosa.istft(complex_spec, hop_length=hop_length)

# ── Data Loading ──────────────────────────────────────────────────────────
def load_article2_dataset_from_merged(merged_dir,
                                      window_size=WINDOW_SIZE):
    """Build (noisy window -> clean frame) training pairs from ``.npy`` files.

    Walks ``merged_dir`` recursively. For every ``*_clean_logpower_*.npy``
    file that has a ``*_noisy_concat_*.npy`` sibling with the same number
    of columns, emits one sample per frame ``i >= window_size - 1``:
    the noisy columns ``i-window_size+1 .. i`` as input and clean column
    ``i`` as target.

    Directories and files are visited in sorted order so that the sample
    order (and therefore any seeded subsample taken by callers) does not
    depend on filesystem enumeration order.

    Args:
        merged_dir: Root directory to scan.
        window_size: Number of consecutive noisy frames per input sample.

    Returns:
        ``(X, Y)`` as float32 arrays. ``X`` has shape
        ``(num_samples, noisy_rows, window_size, 1)`` and ``Y`` has shape
        ``(num_samples, clean_rows)``.
    """
    clean_tag, noisy_tag = '_clean_logpower_', '_noisy_concat_'
    windows, targets = [], []
    for root, dirs, files in os.walk(merged_dir):
        dirs.sort()  # in-place sort makes os.walk descend deterministically
        for cf in sorted(files):
            if not (cf.endswith('.npy') and clean_tag in cf):
                continue
            cpath = os.path.join(root, cf)
            # Substitute the tag in the file name only, so a directory whose
            # name happens to contain the tag is left untouched.
            npath = os.path.join(root, cf.replace(clean_tag, noisy_tag))
            if not os.path.exists(npath):
                continue
            C = np.load(cpath)  # [129, T]
            N = np.load(npath)
            if C.shape[1] != N.shape[1]:
                continue
            T = C.shape[1]
            for i in range(window_size - 1, T):
                windows.append(N[:, i - window_size + 1:i + 1])
                targets.append(C[:, i])
    X = np.array(windows, dtype=np.float32)[..., np.newaxis]
    Y = np.array(targets, dtype=np.float32)
    return X, Y

# ── Normalization ────────────────────────────────────────────────────────
def normalize_data(X, Y):
    """Standardise ``X`` and ``Y`` with their global mean and std.

    A single scalar mean/std is computed over each whole array. If a
    standard deviation is exactly zero (constant data) it is replaced by
    1 so the division yields zeros instead of NaN; the returned std then
    also reads 1, which keeps ``normalized * std + mean`` an exact inverse.

    Args:
        X: Input array.
        Y: Target array.

    Returns:
        ``(Xn, Yn, Xm, Xs, Ym, Ys)``: the normalised arrays followed by
        the mean/std of ``X`` and the mean/std of ``Y``.
    """
    Xm, Xs = X.mean(), X.std()
    Ym, Ys = Y.mean(), Y.std()
    # Guard against 0/0 on constant arrays; type(...)(1) preserves the dtype.
    if Xs == 0:
        Xs = type(Xs)(1)
    if Ys == 0:
        Ys = type(Ys)(1)
    Xn     = (X - Xm) / Xs
    Yn     = (Y - Ym) / Ys
    return Xn, Yn, Xm, Xs, Ym, Ys

# ── Model Definition ─────────────────────────────────────────────────────
class Article2CNN(nn.Module):
    """Two-conv-layer CNN mapping a window of frames to ``N_BINS`` outputs.

    Args:
        input_shape: ``(D, T)`` — the height and width of the
            single-channel input; used only to size the first linear layer.
    """

    def __init__(self, input_shape):
        super().__init__()
        # Both convolutions have kernel width 1, so they only mix values
        # along the first spatial axis (height).
        self.conv1 = nn.Conv2d(1, 129, kernel_size=(5,1), padding=(2,0))
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(129, 43, kernel_size=(5,1),
                               stride=(3,1), padding=(2,0))
        self.relu2 = nn.ReLU()

        # Push a zero tensor through the conv stack to discover how many
        # features the first fully-connected layer receives.
        height, width = input_shape
        probe = torch.zeros(1, 1, height, width)
        n_flat = self._conv_stack(probe).view(1, -1).size(1)

        self.fc1   = nn.Linear(n_flat, 1024)
        self.relu3 = nn.ReLU()
        self.drop  = nn.Dropout(0.2)
        self.fc2   = nn.Linear(1024, N_BINS)

    def _conv_stack(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU."""
        hidden = self.relu1(self.conv1(x))
        return self.relu2(self.conv2(hidden))

    def forward(self, x):
        """Return the ``(batch, N_BINS)`` prediction for input ``x``."""
        features = self._conv_stack(x)
        features = features.view(features.size(0), -1)
        hidden = self.relu3(self.fc1(features))
        hidden = self.drop(hidden)
        return self.fc2(hidden)

# ── Evaluation (full‑utterance PESQ/STOI) ─────────────────────────────────
def evaluate_on_validation(model,
                           train_dir, val_dir,
                           window_size=WINDOW_SIZE,
                           device='cpu',
                           return_avgs=False):
    """Score the model on full validation utterances with PESQ and STOI.

    For each ``*_clean_logpower_*.npy`` / ``*_noisy_concat_*.npy`` pair in
    ``val_dir`` the noisy spectrogram is enhanced frame by frame, turned
    back into audio with the phase of a Griffin-Lim reconstruction of the
    noisy input, and compared with a Griffin-Lim reconstruction of the
    clean spectrogram.

    Args:
        model: Network mapping a ``(1, 1, rows, window_size)`` tensor to
            one normalised output frame.
        train_dir: Training directory; reloaded here to recompute the
            normalisation statistics (same seed-42 subsample of at most
            100,000 windows the training script uses).
        val_dir: Directory holding the validation ``.npy`` pairs.
        window_size: Number of noisy frames per model input.
        device: Torch device the inputs are moved to.
        return_avgs: When true, return the averages instead of printing.

    Returns:
        ``(avg_pesq, avg_stoi)`` if ``return_avgs`` is true, else ``None``
        after printing a summary. Averages are NaN when nothing was scored.
    """
    # recompute train-time stats
    X_tr, Y_tr = load_article2_dataset_from_merged(train_dir, window_size)
    idxs = np.random.default_rng(42).choice(
        len(X_tr), size=min(100_000, len(X_tr)), replace=False
    )
    X_tr, Y_tr = X_tr[idxs], Y_tr[idxs]
    _, _, Xm, Xs, Ym, Ys = normalize_data(X_tr, Y_tr)

    # minimum length = 1.5 seconds @ 8 kHz
    min_len = int(1.5 * SAMPLE_RATE)

    pesq_scores, stoi_scores = [], []

    for cf in os.listdir(val_dir):
        if not (cf.endswith('.npy') and '_clean_logpower_' in cf):
            continue
        cpath = os.path.join(val_dir, cf)
        # Swap the tag in the file name only, never in the directory part.
        npath = os.path.join(
            val_dir, cf.replace('_clean_logpower_', '_noisy_concat_')
        )
        if not os.path.exists(npath):
            continue

        C = np.nan_to_num(np.load(cpath))  # [129,T]
        N = np.nan_to_num(np.load(npath))
        T = C.shape[1]
        # Need at least one full window and frame-aligned clean/noisy data.
        if T < window_size or N.shape[1] != T:
            continue

        # reconstruct clean
        clean_wav = reconstruct_waveform_from_logpower(C)
        if len(clean_wav) < min_len or np.std(clean_wav) < 1e-4:
            continue

        # sliding-window predict
        preds = []
        with torch.no_grad():
            for i in range(window_size - 1, T):
                win = (N[:, i-window_size+1 : i+1] - Xm) / Xs
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0).to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        Pn = np.stack(preds, axis=1)
        # preds[k] is the estimate for frame k + window_size - 1, so the
        # missing frames are the FIRST window_size - 1. Pad at the front to
        # keep the spectrogram aligned with the phase and the reference.
        Pn = np.pad(Pn, ((0, 0), (T - Pn.shape[1], 0)), mode='edge')
        P = (Pn * Ys) + Ym

        # reconstruct noisy-phase baseline
        noisy_wav   = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(noisy_wav,
                                            n_fft=FRAME_SIZE,
                                            hop_length=OVERLAP))[:N_BINS]

        # reconstruct enhanced
        enhanced_wav = reconstruct_with_noisy_phase(P[:N_BINS], noisy_phase)
        if len(enhanced_wav) < min_len or np.std(enhanced_wav) < 1e-4:
            continue

        # force same length
        L = min(len(clean_wav), len(enhanced_wav))
        clean_wav    = clean_wav[:L]
        enhanced_wav = enhanced_wav[:L]

        # normalize both into [-1,1]
        clean_wav    /= (np.max(np.abs(clean_wav))    + 1e-12)
        enhanced_wav /= (np.max(np.abs(enhanced_wav)) + 1e-12)

        # score; a failing metric is recorded as NaN rather than aborting.
        # NOTE(review): the argument order here follows the pypesq import
        # at the top of this cell — confirm against the installed package.
        try:
            p = pesq(clean_wav, enhanced_wav, SAMPLE_RATE, 'nb')
        except Exception:
            p = np.nan
        try:
            s = stoi(clean_wav, enhanced_wav, SAMPLE_RATE, False)
        except Exception:
            s = np.nan

        pesq_scores.append(p)
        stoi_scores.append(s)

    def _nan_safe_mean(values):
        """Mean ignoring NaNs; NaN (without a warning) if nothing is valid."""
        arr = np.asarray(values, dtype=float)
        if arr.size == 0 or np.all(np.isnan(arr)):
            return float('nan')
        return np.nanmean(arr)

    avg_pesq = _nan_safe_mean(pesq_scores)
    avg_stoi = _nan_safe_mean(stoi_scores)
    if return_avgs:
        return avg_pesq, avg_stoi

    print(f"\nValidation on {len(pesq_scores)} utts → "
          f"PESQ {avg_pesq:.3f}, STOI {avg_stoi:.3f}")

# ── Training Loop ─────────────────────────────────────────────────────────
def train_and_validate(model,
                       train_loader, val_loader,
                       train_dir, val_dir,
                       Xm, Xs, Ym, Ys,
                       num_epochs=50,
                       device='cpu',
                       patience=5):
    """Train with Adam/MSE and early stopping on validation MSE.

    After each epoch the sample-weighted validation MSE is computed and
    ``evaluate_on_validation`` is called for PESQ/STOI. Training stops once
    the validation MSE has failed to improve for ``patience`` consecutive
    epochs.

    Args:
        model: Network to optimise (modified in place).
        train_loader: Loader yielding ``(input, target)`` training batches.
        val_loader: Loader yielding ``(input, target)`` validation batches.
        train_dir: Forwarded to ``evaluate_on_validation``.
        val_dir: Forwarded to ``evaluate_on_validation``.
        Xm, Xs, Ym, Ys: Accepted for interface compatibility; not read by
            this function.
        num_epochs: Maximum number of epochs.
        device: Torch device batches are moved to.
        patience: Non-improving epochs tolerated before stopping.

    Returns:
        ``(model, train_mse, val_mse, val_pesq, val_stoi)`` where the last
        four are per-epoch lists.
    """
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    train_mse, val_mse = [], []
    val_pesq, val_stoi = [], []
    lowest_val = float('inf')
    epochs_without_gain = 0

    for epoch_idx in range(num_epochs):
        epoch_no = epoch_idx + 1

        # ---- optimisation pass ----
        model.train()
        running_train = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            running_train += batch_loss.item() * inputs.size(0)
        epoch_train = running_train / len(train_loader.dataset)
        train_mse.append(epoch_train)

        # ---- validation pass ----
        model.eval()
        running_val = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                running_val += loss_fn(model(inputs), targets).item() * inputs.size(0)
        epoch_val = running_val / len(val_loader.dataset)
        val_mse.append(epoch_val)

        # ---- perceptual metrics on full utterances ----
        epoch_pesq, epoch_stoi = evaluate_on_validation(
            model, train_dir, val_dir,
            WINDOW_SIZE, device, True
        )
        val_pesq.append(epoch_pesq)
        val_stoi.append(epoch_stoi)

        print(f"Epoch {epoch_no:02d}: TrainMSE={epoch_train:.4f}  "
              f"ValMSE={epoch_val:.4f}  PESQ={epoch_pesq:.3f}  STOI={epoch_stoi:.3f}")

        # ---- early stopping bookkeeping ----
        if epoch_val < lowest_val:
            lowest_val = epoch_val
            epochs_without_gain = 0
            continue
        epochs_without_gain += 1
        if epochs_without_gain >= patience:
            print("Early stopping.")
            break

    return model, train_mse, val_mse, val_pesq, val_stoi

# ── Plotting ──────────────────────────────────────────────────────────────
def plot_metrics(train_mse, val_mse, val_pesq, val_stoi):
    """Show loss curves (left) and PESQ/STOI curves (right) per epoch.

    Args:
        train_mse: Per-epoch training MSE values.
        val_mse: Per-epoch validation MSE values.
        val_pesq: Per-epoch validation PESQ values.
        val_stoi: Per-epoch validation STOI values.
    """
    epochs = np.arange(1, len(train_mse) + 1)
    plt.figure(figsize=(12, 5))

    # Left panel: losses.
    plt.subplot(1, 2, 1)
    for series, label in ((train_mse, 'Train MSE'), (val_mse, 'Val   MSE')):
        plt.plot(epochs, series, '-o', label=label)
    plt.xlabel('Epoch')
    plt.ylabel('MSE')
    plt.legend()
    plt.title('Loss')

    # Right panel: perceptual quality metrics.
    plt.subplot(1, 2, 2)
    for series, label in ((val_pesq, 'Val PESQ'), (val_stoi, 'Val STOI')):
        plt.plot(epochs, series, '-o', label=label)
    plt.xlabel('Epoch')
    plt.legend()
    plt.title('Quality')

    plt.tight_layout()
    plt.show()

# ── Waveform Plots ───────────────────────────────────────────────────────
def plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device='cpu'):
    """Plot clean / noisy / denoised waveforms for up to three random files.

    Each row of the 3x3 grid shows the first 2000 samples (or fewer if the
    waveform is shorter) of one randomly chosen validation utterance.

    Args:
        model: Network mapping a ``(1, 1, rows, WINDOW_SIZE)`` tensor to
            one normalised output frame.
        val_dir: Directory containing ``*_clean_logpower_*.npy`` files and
            their ``*_noisy_concat_*.npy`` counterparts.
        Xm, Xs: Mean/std used to normalise the model input.
        Ym, Ys: Mean/std used to de-normalise the model output.
        device: Torch device the inputs are moved to.
    """
    clean_tag, noisy_tag = '_clean_logpower_', '_noisy_concat_'
    files   = [f for f in os.listdir(val_dir)
               if f.endswith('.npy') and clean_tag in f]
    samples = random.sample(files, min(3, len(files)))
    plt.figure(figsize=(12,6))
    for idx, cf in enumerate(samples):
        npath = os.path.join(val_dir, cf.replace(clean_tag, noisy_tag))
        if not os.path.exists(npath):
            print(f"Skipping {cf} - noisy counterpart not found")
            continue
        C = np.load(os.path.join(val_dir, cf))
        N = np.load(npath)
        T = C.shape[1]
        if T < WINDOW_SIZE or N.shape[1] != T:
            print(f"Skipping {cf} - too few or mismatched frames")
            continue

        preds = []
        with torch.no_grad():
            for i in range(WINDOW_SIZE-1, T):
                win = (N[:, i-WINDOW_SIZE+1 : i+1] - Xm)/Xs
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0)\
                           .to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        P = (np.stack(preds, axis=1)*Ys) + Ym
        # preds[k] belongs to frame k + WINDOW_SIZE - 1; pad the missing
        # leading frames so P has T columns and lines up with the phase.
        P = np.pad(P, ((0, 0), (T - P.shape[1], 0)), mode='edge')

        clean_wav    = reconstruct_waveform_from_logpower(C)
        noisy_wav    = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase  = np.angle(librosa.stft(
                          noisy_wav,
                          n_fft=FRAME_SIZE,
                          hop_length=OVERLAP
                        ))[:N_BINS]
        # Crop to the common frame count so magnitude and phase broadcast.
        n_frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
            P[:N_BINS, :n_frames], noisy_phase[:, :n_frames]
        )

        for col, wav, title in zip(
            range(3),
            [clean_wav, noisy_wav, enhanced_wav],
            ['Clean','Noisy','Denoised']
        ):
            # Time axis sized to the data so short clips still plot.
            n_show = min(2000, len(wav))
            t = np.arange(n_show) / SAMPLE_RATE
            plt.subplot(3,3, idx*3 + col + 1)
            plt.plot(t, wav[:n_show])
            plt.title(title); plt.ylim(-1,1)
    plt.tight_layout(); plt.show()

# ── WAV Export ───────────────────────────────────────────────────────────
def save_enhanced_audio(model, data_dir, out_dir,
                        window_size=WINDOW_SIZE,
                        device='cpu', num_samples=5,
                        Y_mean=0.0, Y_std=1.0,
                        X_mean=None, X_std=None):
    """Write clean / noisy / enhanced WAV trios for random utterances.

    Args:
        model: Network mapping a ``(1, 1, rows, window_size)`` tensor to
            one normalised output frame.
        data_dir: Directory with ``*_clean_logpower_*.npy`` files and
            their ``*_noisy_concat_*.npy`` counterparts.
        out_dir: Output directory (created if missing).
        window_size: Number of noisy frames per model input.
        device: Torch device the inputs are moved to.
        num_samples: Maximum number of utterances to export.
        Y_mean, Y_std: Mean/std used to de-normalise the model output.
        X_mean, X_std: Mean/std used to normalise the model input. They
            should be the input statistics used during training. When left
            as ``None`` they fall back to ``Y_mean``/``Y_std``, which is the
            behaviour of earlier versions of this function.
    """
    if X_mean is None:
        X_mean = Y_mean
    if X_std is None:
        X_std = Y_std

    def norm(a):
        """Peak-normalise ``a``; all-zero input is returned unchanged."""
        m = np.max(np.abs(a))
        return a/m if m>0 else a

    os.makedirs(out_dir, exist_ok=True)
    allc = [f for f in os.listdir(data_dir)
            if f.endswith('.npy') and '_clean_logpower_' in f]
    for cf in random.sample(allc, min(num_samples, len(allc))):
        npath = os.path.join(
            data_dir,
            cf.replace('_clean_logpower_','_noisy_concat_')
        )
        if not os.path.exists(npath):
            print(f"Skipping {cf} - noisy counterpart not found")
            continue
        C = np.load(os.path.join(data_dir, cf))
        N = np.load(npath)
        T = C.shape[1]
        if T < window_size or N.shape[1] != T:
            print(f"Skipping {cf} - too few or mismatched frames")
            continue

        clean_wav = reconstruct_waveform_from_logpower(C)
        noisy_wav = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(
                        noisy_wav,
                        n_fft=FRAME_SIZE,
                        hop_length=OVERLAP
                       ))[:N_BINS]

        preds = []
        with torch.no_grad():
            for i in range(window_size-1, T):
                # Inputs are normalised with the INPUT statistics.
                win = (N[:, i-window_size+1 : i+1] - X_mean)/X_std
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0)\
                           .to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        P = (np.stack(preds, axis=1)*Y_std) + Y_mean
        # preds[k] belongs to frame k + window_size - 1; pad the missing
        # leading frames so the spectrogram stays aligned with the phase.
        P = np.pad(P, ((0, 0), (T - P.shape[1], 0)), mode='edge')

        # Crop to the common frame count so magnitude and phase broadcast.
        n_frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
                           P[:N_BINS, :n_frames], noisy_phase[:, :n_frames]
                       )

        for tag, wav in zip(
            ['clean','noisy','enhanced'],
            [clean_wav, noisy_wav, enhanced_wav]
        ):
            wav = norm(wav)
            sf.write(
                os.path.join(out_dir,
                             f"{tag}_{cf.split('_clean_logpower_')[-1][:-4]}.wav"),
                wav, SAMPLE_RATE
            )
        print("Saved trio:", cf)

if __name__ == '__main__':
    train_dir = r'C:\Users\enhance\article2_merge\train'
    val_dir   = r'C:\Users\enhance\article2_merge\val'
    test_dir  = r'C:\Users\enhance\article2_merge\test'

    # Load all training windows, then keep a fixed random subset (seed 42)
    # of at most 100,000 samples.
    X_tr, Y_tr = load_article2_dataset_from_merged(train_dir)
    idxs = np.random.default_rng(42).choice(
        len(X_tr), size=min(100_000, len(X_tr)),
        replace=False
    )
    X_tr, Y_tr = X_tr[idxs], Y_tr[idxs]

    Xn_tr,Yn_tr,Xm,Xs,Ym,Ys = normalize_data(X_tr, Y_tr)

    # (samples, rows, frames, 1) -> (samples, 1, rows, frames) for Conv2d.
    Xt = torch.from_numpy(Xn_tr).permute(0,3,1,2)
    Yt = torch.from_numpy(Yn_tr)
    train_loader = DataLoader(
        TensorDataset(Xt, Yt),
        batch_size=4, shuffle=True
    )

    # Validation data is normalised with the TRAINING statistics.
    X_val, Y_val = load_article2_dataset_from_merged(val_dir)
    Xn_val = (X_val - Xm)/Xs
    Yn_val = (Y_val - Ym)/Ys
    Xv = torch.from_numpy(Xn_val).permute(0,3,1,2)
    Yv = torch.from_numpy(Yn_val)
    val_loader = DataLoader(
        TensorDataset(Xv, Yv),
        batch_size=4, shuffle=False
    )

    device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu'
    )
    D, T = Xt.shape[2], Xt.shape[3]
    model = Article2CNN((D,T)).to(device)

    model, train_mse, val_mse, val_pesq, val_stoi = train_and_validate(
        model, train_loader, val_loader,
        train_dir, val_dir,
        Xm, Xs, Ym, Ys,
        num_epochs=50, device=device,
        patience=5
    )

    plot_metrics(train_mse, val_mse, val_pesq, val_stoi)
    plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device=device)

    # Pass both input and target statistics so export normalises the
    # model input exactly as training did.
    save_enhanced_audio(
        model, test_dir,
        r'C:\Users\enhance\article2_test_wavs',
        device=device, num_samples=5,
        Y_mean=Ym, Y_std=Ys,
        X_mean=Xm, X_std=Xs
    )

    print("✅ Done – check plots & test‑WAVs")


In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
from pystoi import stoi
from pesq import pesq as pesq_nb  # Changed from pypesq import pesq
from torch.utils.data import TensorDataset, DataLoader

def evaluate_on_validation(model,
                           train_dir, val_dir,
                           window_size=WINDOW_SIZE,
                           device='cpu',
                           return_avgs=False):
    """Score the model on full validation utterances with PESQ and STOI.

    For each ``*_clean_logpower_*.npy`` / ``*_noisy_concat_*.npy`` pair in
    ``val_dir`` the noisy spectrogram is enhanced frame by frame, turned
    back into audio with the phase of a Griffin-Lim reconstruction of the
    noisy input, and compared with a Griffin-Lim reconstruction of the
    clean spectrogram. An utterance contributes only if both metrics are
    computed and neither is NaN.

    Args:
        model: Network mapping a ``(1, 1, rows, window_size)`` tensor to
            one normalised output frame.
        train_dir: Training directory; reloaded here to recompute the
            normalisation statistics (seed-42 subsample of at most
            100,000 windows).
        val_dir: Directory holding the validation ``.npy`` pairs.
        window_size: Number of noisy frames per model input.
        device: Torch device the inputs are moved to.
        return_avgs: When true, return the averages instead of printing.

    Returns:
        ``(avg_pesq, avg_stoi)`` if ``return_avgs`` is true, else ``None``
        after printing a summary. Averages are NaN when nothing was scored.
    """
    # recompute train-time stats
    X_tr, Y_tr = load_article2_dataset_from_merged(train_dir, window_size)
    idxs = np.random.default_rng(42).choice(
        len(X_tr), size=min(100_000, len(X_tr)), replace=False
    )
    X_tr, Y_tr = X_tr[idxs], Y_tr[idxs]
    _, _, Xm, Xs, Ym, Ys = normalize_data(X_tr, Y_tr)

    # minimum length = 1.5 seconds @ 8 kHz (adjusted for your data)
    min_len = int(1.5 * SAMPLE_RATE)

    pesq_scores, stoi_scores = [], []

    for cf in os.listdir(val_dir):
        if not (cf.endswith('.npy') and '_clean_logpower_' in cf):
            continue
        cpath = os.path.join(val_dir, cf)
        # Swap the tag in the file name only, never in the directory part.
        npath = os.path.join(
            val_dir, cf.replace('_clean_logpower_', '_noisy_concat_')
        )
        if not os.path.exists(npath):
            continue

        C = np.nan_to_num(np.load(cpath))  # [129,T]
        N = np.nan_to_num(np.load(npath))
        T = C.shape[1]
        # Need at least one full window and frame-aligned clean/noisy data.
        if T < window_size or N.shape[1] != T:
            continue

        # reconstruct clean
        clean_wav = reconstruct_waveform_from_logpower(C)
        if len(clean_wav) < min_len or np.std(clean_wav) < 1e-4:
            print(f"Skipping {cf} - too short or too quiet")
            continue

        # sliding-window predict
        preds = []
        with torch.no_grad():
            for i in range(window_size - 1, T):
                win = (N[:, i-window_size+1 : i+1] - Xm) / Xs
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0).to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        Pn = np.stack(preds, axis=1)
        # preds[k] is the estimate for frame k + window_size - 1, so the
        # missing frames are the FIRST window_size - 1. Pad at the front to
        # keep the spectrogram aligned with the phase and the reference.
        Pn = np.pad(Pn, ((0, 0), (T - Pn.shape[1], 0)), mode='edge')
        P = (Pn * Ys) + Ym

        # reconstruct noisy-phase baseline
        noisy_wav   = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(noisy_wav,
                                            n_fft=FRAME_SIZE,
                                            hop_length=OVERLAP))[:N_BINS]

        # reconstruct enhanced
        enhanced_wav = reconstruct_with_noisy_phase(P[:N_BINS], noisy_phase)
        if len(enhanced_wav) < min_len or np.std(enhanced_wav) < 1e-4:
            print(f"Skipping {cf} - enhanced audio too short or too quiet")
            continue

        # Force same length. Both signals already passed the min_len check,
        # so the common length cannot fall below min_len.
        L = min(len(clean_wav), len(enhanced_wav))
        clean_wav    = clean_wav[:L]
        enhanced_wav = enhanced_wav[:L]

        # normalize both into [-1,1]
        clean_max = np.max(np.abs(clean_wav))
        enh_max = np.max(np.abs(enhanced_wav))

        if clean_max < 1e-6 or enh_max < 1e-6:
            print(f"Skipping {cf} - normalization would be unstable")
            continue

        clean_wav    /= clean_max
        enhanced_wav /= enh_max

        # score
        try:
            p = pesq_nb(SAMPLE_RATE, clean_wav, enhanced_wav, 'nb')
            s = stoi(clean_wav, enhanced_wav, SAMPLE_RATE, False)
        except Exception as e:
            # Best effort: a scorer failure drops this utterance only.
            print(f"Error processing {cf}: {str(e)}")
            continue

        # Only append valid scores
        if np.isnan(p) or np.isnan(s):
            print(f"Skipping {cf} - got NaN scores")
            continue
        pesq_scores.append(p)
        stoi_scores.append(s)

    # Only compute averages if we have valid scores
    if pesq_scores and stoi_scores:
        avg_pesq = np.mean(pesq_scores)
        avg_stoi = np.mean(stoi_scores)
    else:
        print("Warning: No valid scores were computed!")
        avg_pesq = float('nan')
        avg_stoi = float('nan')

    if return_avgs:
        return avg_pesq, avg_stoi

    print(f"\nValidation on {len(pesq_scores)} utts → "
          f"PESQ {avg_pesq:.3f}, STOI {avg_stoi:.3f}")

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
from pystoi import stoi
from pesq import pesq as pesq_nb  # Changed from pypesq import pesq
from torch.utils.data import TensorDataset, DataLoader

# ── Constants ─────────────────────────────────────────────────────────────
SAMPLE_RATE = 8000  # audio sample rate in Hz; passed to sf.write and the PESQ/STOI scorers
FRAME_SIZE  = 256  # STFT frame length; used as n_fft throughout this file
OVERLAP     = 128  # used as the STFT hop_length throughout this file
EPSILON     = 1e-10  # subtracted after exp(); presumably the offset added before the log when features were built — TODO confirm
N_ITER      = 32  # default Griffin-Lim iteration count
WINDOW_SIZE = 9  # number of consecutive noisy frames fed to the model per predicted frame
N_BINS      = FRAME_SIZE // 2 + 1  # 129 (rows kept when slicing spectrograms/phase)

# ── Reconstruction Helpers ────────────────────────────────────────────────
def reconstruct_waveform_from_logpower(lp,
                                       n_fft=FRAME_SIZE,
                                       hop_length=OVERLAP,
                                       n_iter=N_ITER):
    """Estimate a waveform from a log-power spectrogram with Griffin-Lim.

    Args:
        lp: Log-power spectrogram array; NaN/inf entries are replaced by
            ``np.nan_to_num`` before use.
        n_fft: FFT size forwarded to ``librosa.griffinlim``.
        hop_length: Hop length forwarded to ``librosa.griffinlim``.
        n_iter: Number of Griffin-Lim iterations.

    Returns:
        The waveform returned by ``librosa.griffinlim``.
    """
    # Undo the log, remove the EPSILON offset, and floor at zero so the
    # square root below never sees a negative power.
    linear_power = np.exp(np.nan_to_num(lp)) - EPSILON
    magnitude = np.sqrt(np.maximum(linear_power, 0))
    waveform = librosa.griffinlim(
        magnitude,
        n_iter=n_iter,
        n_fft=n_fft,
        hop_length=hop_length,
    )
    return waveform

def reconstruct_with_noisy_phase(lp, noisy_phase,
                                 hop_length=OVERLAP):
    """Invert a log-power spectrogram using a supplied phase.

    Args:
        lp: Log-power spectrogram; NaN/inf entries are replaced by
            ``np.nan_to_num``.
        noisy_phase: Phase angles (radians) that must broadcast against
            the magnitude derived from ``lp``.
        hop_length: Hop length forwarded to ``librosa.istft``.

    Returns:
        The waveform returned by ``librosa.istft``.
    """
    sanitized = np.nan_to_num(lp)
    linear_power = np.exp(sanitized) - EPSILON
    # Clip at zero: subtracting EPSILON can push tiny powers negative.
    magnitude = np.sqrt(np.clip(linear_power, 0, None))
    # Polar form: magnitude * e^{j * phase}.
    complex_spec = magnitude * np.exp(1j * noisy_phase)
    return librosa.istft(complex_spec, hop_length=hop_length)

# ── Data Loading ──────────────────────────────────────────────────────────
def load_article2_dataset_from_merged(merged_dir,
                                      window_size=WINDOW_SIZE):
    """Build (noisy window -> clean frame) training pairs from ``.npy`` files.

    Walks ``merged_dir`` recursively. For every ``*_clean_logpower_*.npy``
    file that has a ``*_noisy_concat_*.npy`` sibling with the same number
    of columns, emits one sample per frame ``i >= window_size - 1``:
    the noisy columns ``i-window_size+1 .. i`` as input and clean column
    ``i`` as target.

    Directories and files are visited in sorted order so that the sample
    order (and therefore any seeded subsample taken by callers) does not
    depend on filesystem enumeration order.

    Args:
        merged_dir: Root directory to scan.
        window_size: Number of consecutive noisy frames per input sample.

    Returns:
        ``(X, Y)`` as float32 arrays. ``X`` has shape
        ``(num_samples, noisy_rows, window_size, 1)`` and ``Y`` has shape
        ``(num_samples, clean_rows)``.
    """
    clean_tag, noisy_tag = '_clean_logpower_', '_noisy_concat_'
    windows, targets = [], []
    for root, dirs, files in os.walk(merged_dir):
        dirs.sort()  # in-place sort makes os.walk descend deterministically
        for cf in sorted(files):
            if not (cf.endswith('.npy') and clean_tag in cf):
                continue
            cpath = os.path.join(root, cf)
            # Substitute the tag in the file name only, so a directory whose
            # name happens to contain the tag is left untouched.
            npath = os.path.join(root, cf.replace(clean_tag, noisy_tag))
            if not os.path.exists(npath):
                continue
            C = np.load(cpath)  # [129, T]
            N = np.load(npath)
            if C.shape[1] != N.shape[1]:
                continue
            T = C.shape[1]
            for i in range(window_size - 1, T):
                windows.append(N[:, i - window_size + 1:i + 1])
                targets.append(C[:, i])
    X = np.array(windows, dtype=np.float32)[..., np.newaxis]
    Y = np.array(targets, dtype=np.float32)
    return X, Y

# ── Normalization ────────────────────────────────────────────────────────
def normalize_data(X, Y):
    """Standardise ``X`` and ``Y`` with their global mean and std.

    A single scalar mean/std is computed over each whole array. If a
    standard deviation is exactly zero (constant data) it is replaced by
    1 so the division yields zeros instead of NaN; the returned std then
    also reads 1, which keeps ``normalized * std + mean`` an exact inverse.

    Args:
        X: Input array.
        Y: Target array.

    Returns:
        ``(Xn, Yn, Xm, Xs, Ym, Ys)``: the normalised arrays followed by
        the mean/std of ``X`` and the mean/std of ``Y``.
    """
    Xm, Xs = X.mean(), X.std()
    Ym, Ys = Y.mean(), Y.std()
    # Guard against 0/0 on constant arrays; type(...)(1) preserves the dtype.
    if Xs == 0:
        Xs = type(Xs)(1)
    if Ys == 0:
        Ys = type(Ys)(1)
    Xn     = (X - Xm) / Xs
    Yn     = (Y - Ym) / Ys
    return Xn, Yn, Xm, Xs, Ym, Ys

# ── Model Definition ─────────────────────────────────────────────────────
class Article2CNN(nn.Module):
    """Two-conv-layer CNN mapping a window of frames to ``N_BINS`` outputs.

    Args:
        input_shape: ``(D, T)`` — the height and width of the
            single-channel input; used only to size the first linear layer.
    """

    def __init__(self, input_shape):
        super().__init__()
        # Both convolutions have kernel width 1, so they only mix values
        # along the first spatial axis (height).
        self.conv1 = nn.Conv2d(1, 129, kernel_size=(5,1), padding=(2,0))
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(129, 43, kernel_size=(5,1),
                               stride=(3,1), padding=(2,0))
        self.relu2 = nn.ReLU()

        # Push a zero tensor through the conv stack to discover how many
        # features the first fully-connected layer receives.
        height, width = input_shape
        probe = torch.zeros(1, 1, height, width)
        n_flat = self._conv_stack(probe).view(1, -1).size(1)

        self.fc1   = nn.Linear(n_flat, 1024)
        self.relu3 = nn.ReLU()
        self.drop  = nn.Dropout(0.2)
        self.fc2   = nn.Linear(1024, N_BINS)

    def _conv_stack(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU."""
        hidden = self.relu1(self.conv1(x))
        return self.relu2(self.conv2(hidden))

    def forward(self, x):
        """Return the ``(batch, N_BINS)`` prediction for input ``x``."""
        features = self._conv_stack(x)
        features = features.view(features.size(0), -1)
        hidden = self.relu3(self.fc1(features))
        hidden = self.drop(hidden)
        return self.fc2(hidden)

# ── Evaluation (full‑utterance PESQ/STOI) ─────────────────────────────────
def evaluate_on_validation(model,
                           train_dir, val_dir,
                           window_size=WINDOW_SIZE,
                           device='cpu',
                           return_avgs=False):
    """Score the model on full validation utterances with PESQ and STOI.

    For each ``*_clean_logpower_*.npy`` / ``*_noisy_concat_*.npy`` pair in
    ``val_dir`` the noisy spectrogram is enhanced frame by frame, turned
    back into audio with the phase of a Griffin-Lim reconstruction of the
    noisy input, and compared with a Griffin-Lim reconstruction of the
    clean spectrogram. An utterance contributes only if both metrics are
    computed and neither is NaN.

    Args:
        model: Network mapping a ``(1, 1, rows, window_size)`` tensor to
            one normalised output frame.
        train_dir: Training directory; reloaded here to recompute the
            normalisation statistics (seed-42 subsample of at most
            100,000 windows).
        val_dir: Directory holding the validation ``.npy`` pairs.
        window_size: Number of noisy frames per model input.
        device: Torch device the inputs are moved to.
        return_avgs: When true, return the averages instead of printing.

    Returns:
        ``(avg_pesq, avg_stoi)`` if ``return_avgs`` is true, else ``None``
        after printing a summary. Averages are NaN when nothing was scored.
    """
    # recompute train-time stats
    X_tr, Y_tr = load_article2_dataset_from_merged(train_dir, window_size)
    idxs = np.random.default_rng(42).choice(
        len(X_tr), size=min(100_000, len(X_tr)), replace=False
    )
    X_tr, Y_tr = X_tr[idxs], Y_tr[idxs]
    _, _, Xm, Xs, Ym, Ys = normalize_data(X_tr, Y_tr)

    # minimum length = 1.5 seconds @ 8 kHz (adjusted for your data)
    min_len = int(1.5 * SAMPLE_RATE)

    pesq_scores, stoi_scores = [], []

    for cf in os.listdir(val_dir):
        if not (cf.endswith('.npy') and '_clean_logpower_' in cf):
            continue
        cpath = os.path.join(val_dir, cf)
        # Swap the tag in the file name only, never in the directory part.
        npath = os.path.join(
            val_dir, cf.replace('_clean_logpower_', '_noisy_concat_')
        )
        if not os.path.exists(npath):
            continue

        C = np.nan_to_num(np.load(cpath))  # [129,T]
        N = np.nan_to_num(np.load(npath))
        T = C.shape[1]
        # Need at least one full window and frame-aligned clean/noisy data.
        if T < window_size or N.shape[1] != T:
            continue

        # reconstruct clean
        clean_wav = reconstruct_waveform_from_logpower(C)
        if len(clean_wav) < min_len or np.std(clean_wav) < 1e-4:
            print(f"Skipping {cf} - too short or too quiet")
            continue

        # sliding-window predict
        preds = []
        with torch.no_grad():
            for i in range(window_size - 1, T):
                win = (N[:, i-window_size+1 : i+1] - Xm) / Xs
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0).to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        Pn = np.stack(preds, axis=1)
        # preds[k] is the estimate for frame k + window_size - 1, so the
        # missing frames are the FIRST window_size - 1. Pad at the front to
        # keep the spectrogram aligned with the phase and the reference.
        Pn = np.pad(Pn, ((0, 0), (T - Pn.shape[1], 0)), mode='edge')
        P = (Pn * Ys) + Ym

        # reconstruct noisy-phase baseline
        noisy_wav   = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(noisy_wav,
                                            n_fft=FRAME_SIZE,
                                            hop_length=OVERLAP))[:N_BINS]

        # reconstruct enhanced
        enhanced_wav = reconstruct_with_noisy_phase(P[:N_BINS], noisy_phase)
        if len(enhanced_wav) < min_len or np.std(enhanced_wav) < 1e-4:
            print(f"Skipping {cf} - enhanced audio too short or too quiet")
            continue

        # Force same length. Both signals already passed the min_len check,
        # so the common length cannot fall below min_len.
        L = min(len(clean_wav), len(enhanced_wav))
        clean_wav    = clean_wav[:L]
        enhanced_wav = enhanced_wav[:L]

        # normalize both into [-1,1]
        clean_max = np.max(np.abs(clean_wav))
        enh_max = np.max(np.abs(enhanced_wav))

        if clean_max < 1e-6 or enh_max < 1e-6:
            print(f"Skipping {cf} - normalization would be unstable")
            continue

        clean_wav    /= clean_max
        enhanced_wav /= enh_max

        # score
        try:
            p = pesq_nb(SAMPLE_RATE, clean_wav, enhanced_wav, 'nb')
            s = stoi(clean_wav, enhanced_wav, SAMPLE_RATE, False)
        except Exception as e:
            # Best effort: a scorer failure drops this utterance only.
            print(f"Error processing {cf}: {str(e)}")
            continue

        # Only append valid scores
        if np.isnan(p) or np.isnan(s):
            print(f"Skipping {cf} - got NaN scores")
            continue
        pesq_scores.append(p)
        stoi_scores.append(s)

    # Only compute averages if we have valid scores
    if pesq_scores and stoi_scores:
        avg_pesq = np.mean(pesq_scores)
        avg_stoi = np.mean(stoi_scores)
    else:
        print("Warning: No valid scores were computed!")
        avg_pesq = float('nan')
        avg_stoi = float('nan')

    if return_avgs:
        return avg_pesq, avg_stoi

    print(f"\nValidation on {len(pesq_scores)} utts → "
          f"PESQ {avg_pesq:.3f}, STOI {avg_stoi:.3f}")

# ── Training Loop ─────────────────────────────────────────────────────────
def train_and_validate(model,
                       train_loader, val_loader,
                       train_dir, val_dir,
                       Xm, Xs, Ym, Ys,
                       num_epochs=50,
                       device='cpu',
                       patience=5):
    """Train with Adam/MSE and early stopping on validation MSE.

    After each epoch the sample-weighted validation MSE is computed and
    ``evaluate_on_validation`` is called for PESQ/STOI. Training stops once
    the validation MSE has failed to improve for ``patience`` consecutive
    epochs.

    Args:
        model: Network to optimise (modified in place).
        train_loader: Loader yielding ``(input, target)`` training batches.
        val_loader: Loader yielding ``(input, target)`` validation batches.
        train_dir: Forwarded to ``evaluate_on_validation``.
        val_dir: Forwarded to ``evaluate_on_validation``.
        Xm, Xs, Ym, Ys: Accepted for interface compatibility; not read by
            this function.
        num_epochs: Maximum number of epochs.
        device: Torch device batches are moved to.
        patience: Non-improving epochs tolerated before stopping.

    Returns:
        ``(model, train_mse, val_mse, val_pesq, val_stoi)`` where the last
        four are per-epoch lists.
    """
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    train_mse, val_mse = [], []
    val_pesq, val_stoi = [], []
    lowest_val = float('inf')
    epochs_without_gain = 0

    for epoch_idx in range(num_epochs):
        epoch_no = epoch_idx + 1

        # ---- optimisation pass ----
        model.train()
        running_train = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            running_train += batch_loss.item() * inputs.size(0)
        epoch_train = running_train / len(train_loader.dataset)
        train_mse.append(epoch_train)

        # ---- validation pass ----
        model.eval()
        running_val = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                running_val += loss_fn(model(inputs), targets).item() * inputs.size(0)
        epoch_val = running_val / len(val_loader.dataset)
        val_mse.append(epoch_val)

        # ---- perceptual metrics on full utterances ----
        epoch_pesq, epoch_stoi = evaluate_on_validation(
            model, train_dir, val_dir,
            WINDOW_SIZE, device, True
        )
        val_pesq.append(epoch_pesq)
        val_stoi.append(epoch_stoi)

        print(f"Epoch {epoch_no:02d}: TrainMSE={epoch_train:.4f}  "
              f"ValMSE={epoch_val:.4f}  PESQ={epoch_pesq:.3f}  STOI={epoch_stoi:.3f}")

        # ---- early stopping bookkeeping ----
        if epoch_val < lowest_val:
            lowest_val = epoch_val
            epochs_without_gain = 0
            continue
        epochs_without_gain += 1
        if epochs_without_gain >= patience:
            print("Early stopping.")
            break

    return model, train_mse, val_mse, val_pesq, val_stoi

# ── Plotting ──────────────────────────────────────────────────────────────
def plot_metrics(train_mse, val_mse, val_pesq, val_stoi):
    """Show loss curves (left) and PESQ/STOI curves (right) per epoch.

    Args:
        train_mse: Per-epoch training MSE values.
        val_mse: Per-epoch validation MSE values.
        val_pesq: Per-epoch validation PESQ values.
        val_stoi: Per-epoch validation STOI values.
    """
    epochs = np.arange(1, len(train_mse) + 1)
    plt.figure(figsize=(12, 5))

    # Left panel: losses.
    plt.subplot(1, 2, 1)
    for series, label in ((train_mse, 'Train MSE'), (val_mse, 'Val   MSE')):
        plt.plot(epochs, series, '-o', label=label)
    plt.xlabel('Epoch')
    plt.ylabel('MSE')
    plt.legend()
    plt.title('Loss')

    # Right panel: perceptual quality metrics.
    plt.subplot(1, 2, 2)
    for series, label in ((val_pesq, 'Val PESQ'), (val_stoi, 'Val STOI')):
        plt.plot(epochs, series, '-o', label=label)
    plt.xlabel('Epoch')
    plt.legend()
    plt.title('Quality')

    plt.tight_layout()
    plt.show()

# ── Waveform Plots ───────────────────────────────────────────────────────
def plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device='cpu'):
    """Plot clean / noisy / denoised waveforms for up to three random files.

    Each row of the 3x3 grid shows the first 2000 samples (or fewer if the
    waveform is shorter) of one randomly chosen validation utterance.

    Args:
        model: Network mapping a ``(1, 1, rows, WINDOW_SIZE)`` tensor to
            one normalised output frame.
        val_dir: Directory containing ``*_clean_logpower_*.npy`` files and
            their ``*_noisy_concat_*.npy`` counterparts.
        Xm, Xs: Mean/std used to normalise the model input.
        Ym, Ys: Mean/std used to de-normalise the model output.
        device: Torch device the inputs are moved to.
    """
    clean_tag, noisy_tag = '_clean_logpower_', '_noisy_concat_'
    files   = [f for f in os.listdir(val_dir)
               if f.endswith('.npy') and clean_tag in f]
    samples = random.sample(files, min(3, len(files)))
    plt.figure(figsize=(12,6))
    for idx, cf in enumerate(samples):
        npath = os.path.join(val_dir, cf.replace(clean_tag, noisy_tag))
        if not os.path.exists(npath):
            print(f"Skipping {cf} - noisy counterpart not found")
            continue
        C = np.load(os.path.join(val_dir, cf))
        N = np.load(npath)
        T = C.shape[1]
        if T < WINDOW_SIZE or N.shape[1] != T:
            print(f"Skipping {cf} - too few or mismatched frames")
            continue

        preds = []
        with torch.no_grad():
            for i in range(WINDOW_SIZE-1, T):
                win = (N[:, i-WINDOW_SIZE+1 : i+1] - Xm)/Xs
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0)\
                           .to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        P = (np.stack(preds, axis=1)*Ys) + Ym
        # preds[k] belongs to frame k + WINDOW_SIZE - 1; pad the missing
        # leading frames so P has T columns and lines up with the phase.
        P = np.pad(P, ((0, 0), (T - P.shape[1], 0)), mode='edge')

        clean_wav    = reconstruct_waveform_from_logpower(C)
        noisy_wav    = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase  = np.angle(librosa.stft(
                          noisy_wav,
                          n_fft=FRAME_SIZE,
                          hop_length=OVERLAP
                        ))[:N_BINS]
        # Crop to the common frame count so magnitude and phase broadcast.
        n_frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
            P[:N_BINS, :n_frames], noisy_phase[:, :n_frames]
        )

        for col, wav, title in zip(
            range(3),
            [clean_wav, noisy_wav, enhanced_wav],
            ['Clean','Noisy','Denoised']
        ):
            # Time axis sized to the data so short clips still plot.
            n_show = min(2000, len(wav))
            t = np.arange(n_show) / SAMPLE_RATE
            plt.subplot(3,3, idx*3 + col + 1)
            plt.plot(t, wav[:n_show])
            plt.title(title); plt.ylim(-1,1)
    plt.tight_layout(); plt.show()

# ── WAV Export ───────────────────────────────────────────────────────────
def save_enhanced_audio(model, data_dir, out_dir,
                        window_size=WINDOW_SIZE,
                        device='cpu', num_samples=5,
                        Y_mean=0.0, Y_std=1.0,
                        X_mean=None, X_std=None):
    """Write clean / noisy / enhanced WAV trios for randomly chosen utterances.

    Args:
        model: Trained network, called on one (1, 1, rows, window_size) window
            per frame. It is switched to eval mode.
        data_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        out_dir: Output directory (created if missing).
        window_size: Number of frames per input window.
        device: Torch device the model lives on.
        num_samples: Maximum number of utterances to export.
        Y_mean, Y_std: Statistics used to de-normalise the network output.
        X_mean, X_std: Statistics used to normalise the noisy input windows.
            When omitted they fall back to ``Y_mean`` / ``Y_std`` so that
            older call sites keep working; pass the real input statistics to
            match how the training inputs were normalised.
    """
    # Backward compatibility: earlier callers only supplied target statistics.
    if X_mean is None:
        X_mean = Y_mean
    if X_std is None:
        X_std = Y_std

    def norm(a):
        # Peak-normalise; leave all-zero signals untouched.
        m = np.max(np.abs(a))
        return a/m if m>0 else a

    os.makedirs(out_dir, exist_ok=True)
    allc = [f for f in os.listdir(data_dir)
            if f.endswith('.npy') and '_clean_logpower_' in f]
    model.eval()  # inference only: make sure dropout is inactive
    for cf in random.sample(allc, min(num_samples, len(allc))):
        C = np.load(os.path.join(data_dir, cf))
        N = np.load(os.path.join(
                data_dir,
                cf.replace('_clean_logpower_','_noisy_concat_')
            ))
        T = C.shape[1]
        if T < window_size or N.shape[1] != T:
            print("Skipping (too few frames or clean/noisy mismatch):", cf)
            continue

        clean_wav = reconstruct_waveform_from_logpower(C)
        noisy_wav = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(
                        noisy_wav,
                        n_fft=FRAME_SIZE,
                        hop_length=OVERLAP
                       ))[:N_BINS]

        preds = []
        with torch.no_grad():
            for i in range(window_size-1, T):
                # Inputs are normalised with the INPUT statistics.
                win = (N[:, i-window_size+1 : i+1] - X_mean)/X_std
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0)\
                           .to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        P = (np.stack(preds, axis=1)*Y_std) + Y_mean
        # The window ending at frame i predicts frame i, so the missing
        # frames are the first window_size-1: pad at the front to keep the
        # prediction aligned with the noisy phase.
        if P.shape[1] < T:
            P = np.pad(P, ((0,0),(T-P.shape[1],0)), mode='edge')

        frames = min(T, noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
                           P[:N_BINS, :frames], noisy_phase[:, :frames]
                       )

        for tag, wav in zip(
            ['clean','noisy','enhanced'],
            [clean_wav, noisy_wav, enhanced_wav]
        ):
            wav = norm(wav)
            sf.write(
                os.path.join(out_dir,
                             f"{tag}_{cf.split('_clean_logpower_')[-1][:-4]}.wav"),
                wav, SAMPLE_RATE
            )
        print("Saved trio:", cf)

if __name__ == '__main__':
    # Script entry point: load data, train, plot, export test WAVs.
    train_dir = r'C:\Users\enhance\article2_merge\train'
    val_dir   = r'C:\Users\enhance\article2_merge\val'
    test_dir  = r'C:\Users\enhance\article2_merge\test'

    # Fixed-seed subsample of at most 100k training windows.
    X_tr, Y_tr = load_article2_dataset_from_merged(train_dir)
    idxs = np.random.default_rng(42).choice(
        len(X_tr), size=min(100_000, len(X_tr)),
        replace=False
    )
    X_tr, Y_tr = X_tr[idxs], Y_tr[idxs]

    Xn_tr,Yn_tr,Xm,Xs,Ym,Ys = normalize_data(X_tr, Y_tr)

    # (samples, rows, frames, 1) -> (samples, 1, rows, frames) for Conv2d.
    Xt = torch.from_numpy(Xn_tr).permute(0,3,1,2)
    Yt = torch.from_numpy(Yn_tr)
    train_loader = DataLoader(
        TensorDataset(Xt, Yt),
        batch_size=4, shuffle=True
    )

    # Validation data is normalised with the training statistics.
    X_val, Y_val = load_article2_dataset_from_merged(val_dir)
    Xn_val = (X_val - Xm)/Xs
    Yn_val = (Y_val - Ym)/Ys
    Xv = torch.from_numpy(Xn_val).permute(0,3,1,2)
    Yv = torch.from_numpy(Yn_val)
    val_loader = DataLoader(
        TensorDataset(Xv, Yv),
        batch_size=4, shuffle=False
    )

    device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu'
    )
    D, T = Xt.shape[2], Xt.shape[3]
    model = Article2CNN((D,T)).to(device)

    model, train_mse, val_mse, val_pesq, val_stoi = train_and_validate(
        model, train_loader, val_loader,
        train_dir, val_dir,
        Xm, Xs, Ym, Ys,
        num_epochs=15, device=device,
        patience=5
    )

    plot_metrics(train_mse, val_mse, val_pesq, val_stoi)
    plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device=device)

    # Pass both the input and the target statistics: the network input must
    # be normalised exactly as it was during training.
    save_enhanced_audio(
        model, test_dir,
        r'C:\Users\enhance\article2_test_wavs',
        device=device, num_samples=5,
        Y_mean=Ym, Y_std=Ys,
        X_mean=Xm, X_std=Xs
    )

    print("✅ Done – check plots & test‑WAVs")


In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
from pystoi import stoi
from pesq import pesq as pesq_nb  # Changed from pypesq import pesq
from torch.utils.data import TensorDataset, DataLoader

# ── Constants ─────────────────────────────────────────────────────────────
SAMPLE_RATE = 8000    # sample rate handed to the WAV writer and to PESQ / STOI
FRAME_SIZE  = 256     # used as n_fft for the STFT / Griffin-Lim calls
OVERLAP     = 128     # used as hop_length for the STFT / ISTFT calls
EPSILON     = 1e-10   # subtracted again after exp() when inverting log-power
N_ITER      = 32      # default number of Griffin-Lim iterations
WINDOW_SIZE = 9       # number of consecutive frames fed to the model per step
N_BINS      = 1 + FRAME_SIZE // 2  # 129 one-sided frequency bins

# ── Reconstruction Helpers ────────────────────────────────────────────────
def reconstruct_waveform_from_logpower(lp,
                                       n_fft=FRAME_SIZE,
                                       hop_length=OVERLAP,
                                       n_iter=N_ITER):
    """Estimate a waveform from a log-power spectrogram with Griffin-Lim.

    Args:
        lp: Log-power spectrogram; NaN / inf entries are replaced by
            ``np.nan_to_num`` before use.
        n_fft: FFT size passed to ``librosa.griffinlim``.
        hop_length: Hop length passed to ``librosa.griffinlim``.
        n_iter: Number of Griffin-Lim iterations.

    Returns:
        The time-domain signal returned by ``librosa.griffinlim``.
    """
    finite_lp = np.nan_to_num(lp)
    # Undo the log, remove the EPSILON offset, and floor at zero before the
    # square root that turns power into magnitude.
    magnitude = np.sqrt(np.clip(np.exp(finite_lp) - EPSILON, 0, None))
    return librosa.griffinlim(
        magnitude,
        n_iter=n_iter,
        hop_length=hop_length,
        n_fft=n_fft,
    )

def reconstruct_with_noisy_phase(lp, noisy_phase,
                                 hop_length=OVERLAP):
    """Invert a log-power spectrogram using an externally supplied phase.

    Args:
        lp: Log-power spectrogram; NaN / inf entries are replaced by
            ``np.nan_to_num``.
        noisy_phase: Phase angles (radians) combined element-wise with the
            magnitude, so it must broadcast against ``lp``.
        hop_length: Hop length passed to ``librosa.istft``.

    Returns:
        The time-domain signal returned by ``librosa.istft``.
    """
    linear_power = np.exp(np.nan_to_num(lp)) - EPSILON
    # Floor at zero so the square root never sees a negative power.
    magnitude = np.sqrt(np.maximum(linear_power, 0))
    spectrum = magnitude * np.exp(1j * noisy_phase)
    return librosa.istft(spectrum, hop_length=hop_length)

# ── Data Loading ──────────────────────────────────────────────────────────
def load_article2_dataset_from_merged(merged_dir,
                                      window_size=WINDOW_SIZE):
    """Build (noisy window, clean frame) training pairs from a directory tree.

    Every ``*_clean_logpower_*.npy`` file below ``merged_dir`` is paired with
    the file of the same name in which ``_clean_logpower_`` is replaced by
    ``_noisy_concat_``. Pairs with a missing noisy file or a differing frame
    count are skipped. For each frame index ``i >= window_size - 1`` the
    input is the noisy columns ``i-window_size+1 .. i`` and the target is
    clean column ``i``.

    Args:
        merged_dir: Root directory to walk.
        window_size: Number of consecutive frames per input window.

    Returns:
        ``(X, Y)`` as float32 arrays. ``X`` has a trailing singleton axis
        appended; ``Y`` holds one clean column per window.
    """
    Xs, Ys = [], []
    for root, dirs, files in os.walk(merged_dir):
        # os.walk yields entries in arbitrary order; sort so that the sample
        # order (and any fixed-seed subsampling on top of it) is reproducible.
        dirs.sort()
        for cf in sorted(files):
            if not (cf.endswith('.npy') and '_clean_logpower_' in cf):
                continue
            cpath = os.path.join(root, cf)
            # Substitute in the file name only, so a directory whose name
            # happens to contain the marker is not rewritten as well.
            npath = os.path.join(
                root, cf.replace('_clean_logpower_', '_noisy_concat_')
            )
            if not os.path.exists(npath):
                continue
            C = np.load(cpath)  # [129, T]
            N = np.load(npath)
            if C.shape[1] != N.shape[1]:
                continue
            T = C.shape[1]
            for i in range(window_size - 1, T):
                Xs.append(N[:, i-window_size+1:i+1])
                Ys.append(C[:, i])
    X = np.array(Xs, dtype=np.float32)[..., np.newaxis]
    Y = np.array(Ys, dtype=np.float32)
    return X, Y

# ── Normalization ────────────────────────────────────────────────────────
def normalize_data(X, Y):
    """Standardise inputs and targets with global (scalar) mean and std.

    Args:
        X: Input array.
        Y: Target array.

    Returns:
        ``(Xn, Yn, Xm, Xs, Ym, Ys)``: the standardised arrays followed by the
        scalar mean / std of ``X`` and of ``Y``. A standard deviation of
        exactly zero (constant data) is reported as 1 so the division stays
        finite and the normalised array is all zeros instead of NaN.
    """
    Xm, Xs = X.mean(), X.std()
    Ym, Ys = Y.mean(), Y.std()
    # Constant data has zero spread; dividing by it would yield NaN / inf.
    if Xs == 0:
        Xs = Xs.dtype.type(1)
    if Ys == 0:
        Ys = Ys.dtype.type(1)
    Xn     = (X - Xm) / Xs
    Yn     = (Y - Ym) / Ys
    return Xn, Yn, Xm, Xs, Ym, Ys

# ── Model Definition ─────────────────────────────────────────────────────
class Article2CNN(nn.Module):
    """Two-layer CNN followed by two fully connected layers.

    The network maps an input of shape (batch, 1, D, T) to N_BINS outputs
    per sample. Both convolutions use (5, 1) kernels, so they mix values
    along the D axis only; the second one strides by 3 along that axis.
    """

    def __init__(self, input_shape):
        """Build the layers for an input of ``input_shape == (D, T)``."""
        super().__init__()
        self.conv1 = nn.Conv2d(1, 129, kernel_size=(5, 1), padding=(2, 0))
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(129, 43, kernel_size=(5, 1),
                               stride=(3, 1), padding=(2, 0))
        self.relu2 = nn.ReLU()

        # Push an all-zero dummy batch through the conv stack to learn how
        # many features the first linear layer receives.
        n_rows, n_cols = input_shape
        probe = self._features(torch.zeros(1, 1, n_rows, n_cols))
        n_flat = probe.view(1, -1).size(1)

        self.fc1   = nn.Linear(n_flat, 1024)
        self.relu3 = nn.ReLU()
        self.drop  = nn.Dropout(0.2)
        self.fc2   = nn.Linear(1024, N_BINS)

    def _features(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU."""
        hidden = self.relu1(self.conv1(x))
        return self.relu2(self.conv2(hidden))

    def forward(self, x):
        """Return the (batch, N_BINS) prediction for a batch of windows."""
        feats = self._features(x)
        flat = feats.view(feats.size(0), -1)
        hidden = self.drop(self.relu3(self.fc1(flat)))
        return self.fc2(hidden)

# ── Evaluation (full‑utterance PESQ/STOI) ─────────────────────────────────
def evaluate_on_validation(model,
                           train_dir, val_dir,
                           window_size=WINDOW_SIZE,
                           device='cpu',
                           return_avgs=False):
    """Score the model on every utterance in ``val_dir`` with PESQ and STOI.

    Each utterance is enhanced frame by frame, resynthesised with the phase
    of the (Griffin-Lim reconstructed) noisy signal and compared against the
    Griffin-Lim reconstruction of the clean spectrogram.

    Args:
        model: Trained network; switched to eval mode here.
        train_dir: Training directory used to derive normalisation statistics.
        val_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        window_size: Number of frames per input window.
        device: Torch device the model lives on.
        return_avgs: When true, return ``(avg_pesq, avg_stoi)`` instead of
            printing a summary line.

    Returns:
        ``(avg_pesq, avg_stoi)`` if ``return_avgs`` is true (NaN for both
        when no utterance could be scored), otherwise ``None``.
    """
    # Training statistics are constant for a given directory; the helper
    # computes them once and reuses them on later calls (e.g. every epoch).
    Xm, Xs, Ym, Ys = _train_norm_stats(train_dir, window_size)

    # Utterances shorter than 1.5 s are not scored.
    min_len = int(1.5 * SAMPLE_RATE)

    pesq_scores, stoi_scores = [], []
    model.eval()  # inference only: make sure dropout is inactive

    for cf in os.listdir(val_dir):
        if not (cf.endswith('.npy') and '_clean_logpower_' in cf):
            continue
        cpath = os.path.join(val_dir, cf)
        npath = cpath.replace('_clean_logpower_','_noisy_concat_')
        if not os.path.exists(npath):
            continue

        C = np.nan_to_num(np.load(cpath))  # [129,T]
        N = np.nan_to_num(np.load(npath))
        T = C.shape[1]
        if T < window_size or N.shape[1] != T:
            continue

        # reconstruct clean
        clean_wav = reconstruct_waveform_from_logpower(C)
        if len(clean_wav) < min_len or np.std(clean_wav) < 1e-4:
            print(f"Skipping {cf} - too short or too quiet")
            continue

        # sliding-window predict
        preds = []
        with torch.no_grad():
            for i in range(window_size - 1, T):
                win = N[:, i-window_size+1 : i+1]
                win = (win - Xm) / Xs
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0).to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        Pn = np.stack(preds, axis=1)
        if Pn.shape[1] < T:
            # The window ending at frame i predicts frame i, so the frames
            # without a prediction are the FIRST window_size-1. Pad at the
            # front to keep the estimate time-aligned with the reference.
            Pn = np.pad(Pn, ((0,0),(T-Pn.shape[1],0)), mode='edge')
        else:
            Pn = Pn[:,:T]
        P = (Pn * Ys) + Ym

        # reconstruct noisy-phase baseline
        noisy_wav   = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(noisy_wav,
                                            n_fft=FRAME_SIZE,
                                            hop_length=OVERLAP))[:N_BINS]

        # reconstruct enhanced (magnitude and phase trimmed to a common
        # frame count so the element-wise product is always well defined)
        frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
            P[:N_BINS, :frames], noisy_phase[:, :frames]
        )
        if len(enhanced_wav) < min_len or np.std(enhanced_wav) < 1e-4:
            print(f"Skipping {cf} - enhanced audio too short or too quiet")
            continue

        # force same length and ensure minimum length
        L = min(len(clean_wav), len(enhanced_wav))
        if L < min_len:
            print(f"Skipping {cf} - final length {L} samples too short")
            continue

        clean_wav    = clean_wav[:L]
        enhanced_wav = enhanced_wav[:L]

        # normalize both into [-1,1]
        clean_max = np.max(np.abs(clean_wav))
        enh_max = np.max(np.abs(enhanced_wav))

        if clean_max < 1e-6 or enh_max < 1e-6:
            print(f"Skipping {cf} - normalization would be unstable")
            continue

        clean_wav    /= clean_max
        enhanced_wav /= enh_max

        # score; metric failures on one utterance must not abort the run
        try:
            p = pesq_nb(SAMPLE_RATE, clean_wav, enhanced_wav, 'nb')
            s = stoi(clean_wav, enhanced_wav, SAMPLE_RATE, False)

            # Only append valid scores
            if not np.isnan(p) and not np.isnan(s):
                pesq_scores.append(p)
                stoi_scores.append(s)
            else:
                print(f"Skipping {cf} - got NaN scores")

        except Exception as e:
            print(f"Error processing {cf}: {str(e)}")
            continue

    # Only compute averages if we have valid scores
    if len(pesq_scores) > 0 and len(stoi_scores) > 0:
        avg_pesq = np.mean(pesq_scores)
        avg_stoi = np.mean(stoi_scores)
    else:
        print("Warning: No valid scores were computed!")
        avg_pesq = float('nan')
        avg_stoi = float('nan')

    if return_avgs:
        return avg_pesq, avg_stoi

    print(f"\nValidation on {len(pesq_scores)} utts → "
          f"PESQ {avg_pesq:.3f}, STOI {avg_stoi:.3f}")


# Normalisation statistics keyed by (absolute train dir, window size).
_TRAIN_STATS_CACHE = {}


def _train_norm_stats(train_dir, window_size):
    """Return (Xm, Xs, Ym, Ys) for the fixed-seed training subsample, cached."""
    key = (os.path.abspath(train_dir), window_size)
    if key not in _TRAIN_STATS_CACHE:
        X_tr, Y_tr = load_article2_dataset_from_merged(train_dir, window_size)
        idxs = np.random.default_rng(42).choice(
            len(X_tr), size=min(100_000, len(X_tr)), replace=False
        )
        _, _, Xm, Xs, Ym, Ys = normalize_data(X_tr[idxs], Y_tr[idxs])
        _TRAIN_STATS_CACHE[key] = (Xm, Xs, Ym, Ys)
    return _TRAIN_STATS_CACHE[key]

# ── Training Loop ─────────────────────────────────────────────────────────
def train_and_validate(model,
                       train_loader, val_loader,
                       train_dir, val_dir,
                       Xm, Xs, Ym, Ys,
                       num_epochs=50,
                       device='cpu',
                       patience=5):
    """Train with MSE loss and early stopping on the validation MSE.

    Args:
        model: Network to train (modified in place).
        train_loader, val_loader: Loaders yielding ``(input, target)`` batches.
        train_dir, val_dir: Forwarded to ``evaluate_on_validation`` for the
            per-epoch PESQ / STOI computation.
        Xm, Xs, Ym, Ys: Accepted for interface compatibility; not read here.
        num_epochs: Maximum number of epochs.
        device: Torch device for the batches.
        patience: Number of consecutive epochs without a new best validation
            MSE after which training stops.

    Returns:
        ``(model, train_mse, val_mse, val_pesq, val_stoi)`` where the lists
        hold one value per completed epoch and ``model`` carries the weights
        of the epoch with the lowest validation MSE.
    """
    import copy  # local import: only needed for the best-weights snapshot

    crit = nn.MSELoss()
    opt  = torch.optim.Adam(model.parameters(), lr=1e-5)
    best, stale = float('inf'), 0
    best_state = None
    train_mse, val_mse, val_pesq, val_stoi = [],[],[],[]

    for ep in range(1, num_epochs+1):
        model.train()
        t_loss = 0.0
        for bx, by in train_loader:
            bx,by = bx.to(device), by.to(device)
            opt.zero_grad()
            out   = model(bx)
            loss  = crit(out, by)
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            t_loss += loss.item()*bx.size(0)
        t_loss /= len(train_loader.dataset)
        train_mse.append(t_loss)

        model.eval()
        v_loss = 0.0
        with torch.no_grad():
            for vx, vy in val_loader:
                vx,vy = vx.to(device), vy.to(device)
                vo    = model(vx)
                v_loss+= crit(vo,vy).item()*vx.size(0)
        v_loss /= len(val_loader.dataset)
        val_mse.append(v_loss)

        p,s = evaluate_on_validation(
                  model, train_dir, val_dir,
                  WINDOW_SIZE, device, True
              )
        val_pesq.append(p)
        val_stoi.append(s)

        print(f"Epoch {ep:02d}: TrainMSE={t_loss:.4f}  "
              f"ValMSE={v_loss:.4f}  PESQ={p:.3f}  STOI={s:.3f}")

        if v_loss < best:
            best, stale = v_loss, 0
            # state_dict() returns references to the live tensors, so a deep
            # copy is required to keep this epoch's weights.
            best_state = copy.deepcopy(model.state_dict())
        else:
            stale += 1
            if stale >= patience:
                print("Early stopping.")
                break

    # Hand back the best weights rather than those of the last epoch, which
    # after early stopping are `patience` epochs past the optimum.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, train_mse, val_mse, val_pesq, val_stoi

# ── Plotting ──────────────────────────────────────────────────────────────
def plot_metrics(train_mse, val_mse, val_pesq, val_stoi):
    """Show per-epoch loss curves (left) and PESQ / STOI curves (right).

    All four sequences are plotted against epoch numbers 1..len(train_mse).
    """
    epochs = np.arange(1, len(train_mse) + 1)
    # (title, y-label or None, ((series, legend label), ...)) per panel.
    panels = (
        ('Loss', 'MSE', ((train_mse, 'Train MSE'), (val_mse, 'Val   MSE'))),
        ('Quality', None, ((val_pesq, 'Val PESQ'), (val_stoi, 'Val STOI'))),
    )

    plt.figure(figsize=(12, 5))
    for position, (title, ylabel, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for series, label in curves:
            plt.plot(epochs, series, '-o', label=label)
        plt.xlabel('Epoch')
        if ylabel is not None:
            plt.ylabel(ylabel)
        plt.legend()
        plt.title(title)
    plt.tight_layout()
    plt.show()

# ── Waveform Plots ───────────────────────────────────────────────────────
def plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device='cpu'):
    """Plot clean, noisy and denoised waveforms for up to three random utterances.

    Args:
        model: Trained network. It is called on one window of shape
            (1, 1, rows, WINDOW_SIZE) per frame and is switched to eval mode.
        val_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        Xm, Xs: Mean / std used to normalise the noisy input windows.
        Ym, Ys: Mean / std used to de-normalise the network output.
        device: Torch device the model lives on.

    Utterances with fewer than WINDOW_SIZE frames, or whose clean and noisy
    arrays disagree on the frame count, are skipped with a message.
    """
    files   = [f for f in os.listdir(val_dir)
               if f.endswith('.npy') and '_clean_logpower_' in f]
    samples = random.sample(files, min(3, len(files)))
    model.eval()  # inference only: make sure dropout is inactive
    plt.figure(figsize=(12,6))
    for idx, cf in enumerate(samples):
        C = np.load(os.path.join(val_dir, cf))
        N = np.load(os.path.join(
                val_dir,
                cf.replace('_clean_logpower_','_noisy_concat_')
            ))
        T = C.shape[1]
        if T < WINDOW_SIZE or N.shape[1] != T:
            print(f"Skipping {cf} - too few frames or clean/noisy mismatch")
            continue

        preds = []
        with torch.no_grad():
            for i in range(WINDOW_SIZE-1, T):
                win = (N[:, i-WINDOW_SIZE+1 : i+1] - Xm)/Xs
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0)\
                           .to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        P = (np.stack(preds, axis=1)*Ys) + Ym
        # The window ending at frame i predicts frame i, so the first
        # WINDOW_SIZE-1 frames have no prediction. Pad at the FRONT so that
        # column k of P lines up with frame k of the noisy phase.
        P = np.pad(P, ((0,0),(T-P.shape[1],0)), mode='edge')

        clean_wav    = reconstruct_waveform_from_logpower(C)
        # The first N_BINS rows of the noisy array are used as its spectrum.
        noisy_wav    = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase  = np.angle(librosa.stft(
                          noisy_wav,
                          n_fft=FRAME_SIZE,
                          hop_length=OVERLAP
                        ))[:N_BINS]
        # Magnitude and phase must have the same number of frames, otherwise
        # the element-wise product inside the helper cannot broadcast.
        frames       = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
                           P[:N_BINS, :frames], noisy_phase[:, :frames]
                       )

        for col, (wav, title) in enumerate(zip(
            [clean_wav, noisy_wav, enhanced_wav],
            ['Clean','Noisy','Denoised']
        )):
            seg = wav[:2000]
            # Build the time axis from the segment itself so clips shorter
            # than 2000 samples still plot.
            t = np.arange(len(seg)) / SAMPLE_RATE
            plt.subplot(3,3, idx*3 + col + 1)
            plt.plot(t, seg)
            plt.title(title); plt.ylim(-1,1)
    plt.tight_layout(); plt.show()

# ── WAV Export ───────────────────────────────────────────────────────────
def save_enhanced_audio(model, data_dir, out_dir,
                        window_size=WINDOW_SIZE,
                        device='cpu', num_samples=5,
                        Y_mean=0.0, Y_std=1.0,
                        X_mean=None, X_std=None):
    """Write clean / noisy / enhanced WAV trios for randomly chosen utterances.

    Args:
        model: Trained network, called on one (1, 1, rows, window_size) window
            per frame. It is switched to eval mode.
        data_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        out_dir: Output directory (created if missing).
        window_size: Number of frames per input window.
        device: Torch device the model lives on.
        num_samples: Maximum number of utterances to export.
        Y_mean, Y_std: Statistics used to de-normalise the network output.
        X_mean, X_std: Statistics used to normalise the noisy input windows.
            When omitted they fall back to ``Y_mean`` / ``Y_std`` so that
            older call sites keep working; pass the real input statistics to
            match how the training inputs were normalised.
    """
    # Backward compatibility: earlier callers only supplied target statistics.
    if X_mean is None:
        X_mean = Y_mean
    if X_std is None:
        X_std = Y_std

    def norm(a):
        # Peak-normalise; leave all-zero signals untouched.
        m = np.max(np.abs(a))
        return a/m if m>0 else a

    os.makedirs(out_dir, exist_ok=True)
    allc = [f for f in os.listdir(data_dir)
            if f.endswith('.npy') and '_clean_logpower_' in f]
    model.eval()  # inference only: make sure dropout is inactive
    for cf in random.sample(allc, min(num_samples, len(allc))):
        C = np.load(os.path.join(data_dir, cf))
        N = np.load(os.path.join(
                data_dir,
                cf.replace('_clean_logpower_','_noisy_concat_')
            ))
        T = C.shape[1]
        if T < window_size or N.shape[1] != T:
            print("Skipping (too few frames or clean/noisy mismatch):", cf)
            continue

        clean_wav = reconstruct_waveform_from_logpower(C)
        noisy_wav = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(
                        noisy_wav,
                        n_fft=FRAME_SIZE,
                        hop_length=OVERLAP
                       ))[:N_BINS]

        preds = []
        with torch.no_grad():
            for i in range(window_size-1, T):
                # Inputs are normalised with the INPUT statistics.
                win = (N[:, i-window_size+1 : i+1] - X_mean)/X_std
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0)\
                           .to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        P = (np.stack(preds, axis=1)*Y_std) + Y_mean
        # The window ending at frame i predicts frame i, so the missing
        # frames are the first window_size-1: pad at the front to keep the
        # prediction aligned with the noisy phase.
        if P.shape[1] < T:
            P = np.pad(P, ((0,0),(T-P.shape[1],0)), mode='edge')

        frames = min(T, noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
                           P[:N_BINS, :frames], noisy_phase[:, :frames]
                       )

        for tag, wav in zip(
            ['clean','noisy','enhanced'],
            [clean_wav, noisy_wav, enhanced_wav]
        ):
            wav = norm(wav)
            sf.write(
                os.path.join(out_dir,
                             f"{tag}_{cf.split('_clean_logpower_')[-1][:-4]}.wav"),
                wav, SAMPLE_RATE
            )
        print("Saved trio:", cf)

if __name__ == '__main__':
    # Script entry point: load data, train, plot, export test WAVs.
    train_dir = r'C:\Users\enhance\article2_merge\train'
    val_dir   = r'C:\Users\enhance\article2_merge\val'
    test_dir  = r'C:\Users\enhance\article2_merge\test'

    # Fixed-seed subsample of at most 100k training windows.
    X_tr, Y_tr = load_article2_dataset_from_merged(train_dir)
    idxs = np.random.default_rng(42).choice(
        len(X_tr), size=min(100_000, len(X_tr)),
        replace=False
    )
    X_tr, Y_tr = X_tr[idxs], Y_tr[idxs]

    Xn_tr,Yn_tr,Xm,Xs,Ym,Ys = normalize_data(X_tr, Y_tr)

    # (samples, rows, frames, 1) -> (samples, 1, rows, frames) for Conv2d.
    Xt = torch.from_numpy(Xn_tr).permute(0,3,1,2)
    Yt = torch.from_numpy(Yn_tr)
    train_loader = DataLoader(
        TensorDataset(Xt, Yt),
        batch_size=4, shuffle=True
    )

    # Validation data is normalised with the training statistics.
    X_val, Y_val = load_article2_dataset_from_merged(val_dir)
    Xn_val = (X_val - Xm)/Xs
    Yn_val = (Y_val - Ym)/Ys
    Xv = torch.from_numpy(Xn_val).permute(0,3,1,2)
    Yv = torch.from_numpy(Yn_val)
    val_loader = DataLoader(
        TensorDataset(Xv, Yv),
        batch_size=4, shuffle=False
    )

    device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu'
    )
    D, T = Xt.shape[2], Xt.shape[3]
    model = Article2CNN((D,T)).to(device)

    model, train_mse, val_mse, val_pesq, val_stoi = train_and_validate(
        model, train_loader, val_loader,
        train_dir, val_dir,
        Xm, Xs, Ym, Ys,
        num_epochs=15, device=device,
        patience=5
    )

    plot_metrics(train_mse, val_mse, val_pesq, val_stoi)
    plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device=device)

    # Pass both the input and the target statistics: the network input must
    # be normalised exactly as it was during training.
    save_enhanced_audio(
        model, test_dir,
        r'C:\Users\enhance\article2_test_wavs',
        device=device, num_samples=5,
        Y_mean=Ym, Y_std=Ys,
        X_mean=Xm, X_std=Xs
    )

    print("✅ Done – check plots & test‑WAVs")


In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
from pystoi import stoi
from pesq import pesq as pesq_nb  # Changed from pypesq import pesq
from torch.utils.data import TensorDataset, DataLoader

# ── Constants ─────────────────────────────────────────────────────────────
SAMPLE_RATE = 8000    # sample rate handed to PESQ / STOI
FRAME_SIZE  = 256     # used as n_fft for the STFT / Griffin-Lim calls
OVERLAP     = 128     # used as hop_length for the STFT / ISTFT calls
EPSILON     = 1e-10   # subtracted again after exp() when inverting log-power
N_ITER      = 32      # default number of Griffin-Lim iterations
WINDOW_SIZE = 9       # number of consecutive frames fed to the model per step
N_BINS      = 1 + FRAME_SIZE // 2  # 129 one-sided frequency bins

# ── Loss Functions ─────────────────────────────────────────────────────────
def envelope_loss(S1, S2):
    """Mean squared difference between peak-normalised envelopes (aimed at STOI).

    For each input the absolute values are summed along ``dim=1`` and the
    resulting vector is divided by its own maximum (plus 1e-8) before the
    two are compared.
    """
    def _peak_normalised(spec):
        envelope = spec.abs().sum(dim=1)
        return envelope / (envelope.max() + 1e-8)

    gap = _peak_normalised(S1) - _peak_normalised(S2)
    return (gap ** 2).mean()

def spectral_loss(S1, S2):
    """Mean absolute difference between two spectra (aimed at PESQ)."""
    residual = S1 - S2
    return residual.abs().mean()

class EnhancementLoss(nn.Module):
    """Weighted sum of a log-domain L1 term and two exp-domain terms.

    ``loss = L1(pred, target) + 0.3 * spectral_loss + 0.3 * envelope_loss``,
    where the last two terms are evaluated on ``exp(pred)`` / ``exp(target)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the scalar combined loss for a batch of predictions."""
        # L1 distance on the values exactly as the network emits them.
        log_term = (pred - target).abs().mean()

        # NOTE(review): exp() is applied to whatever scale the caller passes
        # in; if the targets are standardised this is not true linear power.
        # Confirm against the training script.
        pred_exp, target_exp = pred.exp(), target.exp()
        spectral_term = spectral_loss(pred_exp, target_exp)   # aimed at PESQ
        envelope_term = envelope_loss(pred_exp, target_exp)   # aimed at STOI

        return log_term + 0.3 * spectral_term + 0.3 * envelope_term

# ── Reconstruction Helpers ────────────────────────────────────────────────
def reconstruct_waveform_from_logpower(lp,
                                       n_fft=FRAME_SIZE,
                                       hop_length=OVERLAP,
                                       n_iter=N_ITER):
    """Estimate a waveform from a log-power spectrogram with Griffin-Lim.

    Args:
        lp: Log-power spectrogram; NaN / inf entries are replaced by
            ``np.nan_to_num`` before use.
        n_fft: FFT size passed to ``librosa.griffinlim``.
        hop_length: Hop length passed to ``librosa.griffinlim``.
        n_iter: Number of Griffin-Lim iterations.

    Returns:
        The time-domain signal returned by ``librosa.griffinlim``.
    """
    finite_lp = np.nan_to_num(lp)
    # Undo the log, remove the EPSILON offset, and floor at zero before the
    # square root that turns power into magnitude.
    magnitude = np.sqrt(np.clip(np.exp(finite_lp) - EPSILON, 0, None))
    return librosa.griffinlim(
        magnitude,
        n_iter=n_iter,
        hop_length=hop_length,
        n_fft=n_fft,
    )

def reconstruct_with_noisy_phase(lp, noisy_phase,
                                 hop_length=OVERLAP):
    """Invert a log-power spectrogram using an externally supplied phase.

    Args:
        lp: Log-power spectrogram; NaN / inf entries are replaced by
            ``np.nan_to_num``.
        noisy_phase: Phase angles (radians) combined element-wise with the
            magnitude, so it must broadcast against ``lp``.
        hop_length: Hop length passed to ``librosa.istft``.

    Returns:
        The time-domain signal returned by ``librosa.istft``.
    """
    linear_power = np.exp(np.nan_to_num(lp)) - EPSILON
    # Floor at zero so the square root never sees a negative power.
    magnitude = np.sqrt(np.maximum(linear_power, 0))
    spectrum = magnitude * np.exp(1j * noisy_phase)
    return librosa.istft(spectrum, hop_length=hop_length)

# ── Data Loading ──────────────────────────────────────────────────────────
def load_article2_dataset_from_merged(merged_dir,
                                      window_size=WINDOW_SIZE):
    """Build (noisy window, clean frame) training pairs from a directory tree.

    Every ``*_clean_logpower_*.npy`` file below ``merged_dir`` is paired with
    the file of the same name in which ``_clean_logpower_`` is replaced by
    ``_noisy_concat_``. Pairs with a missing noisy file or a differing frame
    count are skipped. For each frame index ``i >= window_size - 1`` the
    input is the noisy columns ``i-window_size+1 .. i`` and the target is
    clean column ``i``.

    Args:
        merged_dir: Root directory to walk.
        window_size: Number of consecutive frames per input window.

    Returns:
        ``(X, Y)`` as float32 arrays. ``X`` has a trailing singleton axis
        appended; ``Y`` holds one clean column per window.
    """
    Xs, Ys = [], []
    for root, dirs, files in os.walk(merged_dir):
        # os.walk yields entries in arbitrary order; sort so that the sample
        # order (and any fixed-seed subsampling on top of it) is reproducible.
        dirs.sort()
        for cf in sorted(files):
            if not (cf.endswith('.npy') and '_clean_logpower_' in cf):
                continue
            cpath = os.path.join(root, cf)
            # Substitute in the file name only, so a directory whose name
            # happens to contain the marker is not rewritten as well.
            npath = os.path.join(
                root, cf.replace('_clean_logpower_', '_noisy_concat_')
            )
            if not os.path.exists(npath):
                continue
            C = np.load(cpath)  # [129, T]
            N = np.load(npath)
            if C.shape[1] != N.shape[1]:
                continue
            T = C.shape[1]
            for i in range(window_size - 1, T):
                Xs.append(N[:, i-window_size+1:i+1])
                Ys.append(C[:, i])
    X = np.array(Xs, dtype=np.float32)[..., np.newaxis]
    Y = np.array(Ys, dtype=np.float32)
    return X, Y

# ── Normalization ────────────────────────────────────────────────────────
def normalize_data(X, Y):
    """Standardise inputs and targets with global (scalar) mean and std.

    Args:
        X: Input array.
        Y: Target array.

    Returns:
        ``(Xn, Yn, Xm, Xs, Ym, Ys)``: the standardised arrays followed by the
        scalar mean / std of ``X`` and of ``Y``. A standard deviation of
        exactly zero (constant data) is reported as 1 so the division stays
        finite and the normalised array is all zeros instead of NaN.
    """
    Xm, Xs = X.mean(), X.std()
    Ym, Ys = Y.mean(), Y.std()
    # Constant data has zero spread; dividing by it would yield NaN / inf.
    if Xs == 0:
        Xs = Xs.dtype.type(1)
    if Ys == 0:
        Ys = Ys.dtype.type(1)
    Xn     = (X - Xm) / Xs
    Yn     = (Y - Ym) / Ys
    return Xn, Yn, Xm, Xs, Ym, Ys

# ── Model Definition ─────────────────────────────────────────────────────
class Article2CNNWithAttention(nn.Module):
    """CNN with batch norm and a sigmoid gate ahead of two linear layers.

    The network maps an input of shape (batch, 1, D, T) to N_BINS outputs
    per sample. Both convolutions use (5, 1) kernels; the second strides by
    3 along the D axis. A 1x1 convolution followed by a sigmoid produces one
    gate value per spatial position, which scales all 43 feature channels.
    """

    def __init__(self, input_shape):
        """Build the layers for an input of ``input_shape == (D, T)``."""
        super().__init__()
        self.conv1 = nn.Conv2d(1, 129, kernel_size=(5,1), padding=(2,0))
        self.bn1 = nn.BatchNorm2d(129)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(129, 43, kernel_size=(5,1),
                              stride=(3,1), padding=(2,0))
        self.bn2 = nn.BatchNorm2d(43)
        self.relu2 = nn.ReLU()

        # Gate: one sigmoid weight per spatial position.
        self.attention = nn.Sequential(
            nn.Conv2d(43, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Probe the flattened feature size with the convolutions only.
        # BatchNorm and ReLU do not change the shape, and running a dummy
        # all-zero batch through BatchNorm in training mode would write that
        # batch into its running mean / variance before training starts.
        D, T = input_shape
        with torch.no_grad():
            probe = self.conv2(self.conv1(torch.zeros(1,1,D,T)))
        flat = probe.view(1,-1).size(1)

        self.fc1 = nn.Linear(flat, 1024)
        self.bn3 = nn.BatchNorm1d(1024)
        self.relu3 = nn.ReLU()
        self.drop = nn.Dropout(0.2)
        self.fc2 = nn.Linear(1024, N_BINS)

    def forward(self, x):
        """Return the (batch, N_BINS) prediction for a batch of windows."""
        # Convolutional layers with batch norm
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))

        # Scale every channel by the per-position gate (broadcast over dim 1)
        att = self.attention(x)
        x = x * att

        # Fully connected layers
        x = x.view(x.size(0), -1)
        x = self.relu3(self.bn3(self.fc1(x)))
        x = self.drop(x)
        return self.fc2(x)

# ── Evaluation (full‑utterance PESQ/STOI) ─────────────────────────────────
def evaluate_on_validation(model,
                           train_dir, val_dir,
                           window_size=WINDOW_SIZE,
                           device='cpu',
                           return_avgs=False):
    """Score the model on every utterance in ``val_dir`` with PESQ and STOI.

    Each utterance is enhanced frame by frame, resynthesised with the phase
    of the (Griffin-Lim reconstructed) noisy signal and compared against the
    Griffin-Lim reconstruction of the clean spectrogram.

    Args:
        model: Trained network; switched to eval mode here (required for the
            batch-norm layers to accept single-window batches).
        train_dir: Training directory used to derive normalisation statistics.
        val_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        window_size: Number of frames per input window.
        device: Torch device the model lives on.
        return_avgs: When true, return ``(avg_pesq, avg_stoi)`` instead of
            printing a summary line.

    Returns:
        ``(avg_pesq, avg_stoi)`` if ``return_avgs`` is true (NaN for both
        when no utterance could be scored), otherwise ``None``.
    """
    # Training statistics are constant for a given directory; the helper
    # computes them once and reuses them on later calls (e.g. every epoch).
    Xm, Xs, Ym, Ys = _train_norm_stats(train_dir, window_size)

    # Utterances shorter than 1.5 s are not scored.
    min_len = int(1.5 * SAMPLE_RATE)

    pesq_scores, stoi_scores = [], []
    model.eval()

    for cf in os.listdir(val_dir):
        if not (cf.endswith('.npy') and '_clean_logpower_' in cf):
            continue
        cpath = os.path.join(val_dir, cf)
        npath = cpath.replace('_clean_logpower_','_noisy_concat_')
        if not os.path.exists(npath):
            continue

        C = np.nan_to_num(np.load(cpath))  # [129,T]
        N = np.nan_to_num(np.load(npath))
        T = C.shape[1]
        if T < window_size or N.shape[1] != T:
            continue

        # reconstruct clean
        clean_wav = reconstruct_waveform_from_logpower(C)
        if len(clean_wav) < min_len or np.std(clean_wav) < 1e-4:
            print(f"Skipping {cf} - too short or too quiet")
            continue

        # sliding-window predict
        preds = []
        with torch.no_grad():
            for i in range(window_size - 1, T):
                win = N[:, i-window_size+1 : i+1]
                win = (win - Xm) / Xs
                inp = torch.from_numpy(win.astype(np.float32))\
                           .unsqueeze(0).unsqueeze(0).to(device)
                preds.append(model(inp).squeeze(0).cpu().numpy())
        Pn = np.stack(preds, axis=1)
        if Pn.shape[1] < T:
            # The window ending at frame i predicts frame i, so the frames
            # without a prediction are the FIRST window_size-1. Pad at the
            # front to keep the estimate time-aligned with the reference.
            Pn = np.pad(Pn, ((0,0),(T-Pn.shape[1],0)), mode='edge')
        else:
            Pn = Pn[:,:T]
        P = (Pn * Ys) + Ym

        # reconstruct noisy-phase baseline
        noisy_wav   = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(noisy_wav,
                                            n_fft=FRAME_SIZE,
                                            hop_length=OVERLAP))[:N_BINS]

        # reconstruct enhanced (magnitude and phase trimmed to a common
        # frame count so the element-wise product is always well defined)
        frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
            P[:N_BINS, :frames], noisy_phase[:, :frames]
        )
        if len(enhanced_wav) < min_len or np.std(enhanced_wav) < 1e-4:
            print(f"Skipping {cf} - enhanced audio too short or too quiet")
            continue

        # force same length and ensure minimum length
        L = min(len(clean_wav), len(enhanced_wav))
        if L < min_len:
            print(f"Skipping {cf} - final length {L} samples too short")
            continue

        clean_wav    = clean_wav[:L]
        enhanced_wav = enhanced_wav[:L]

        # normalize both into [-1,1]
        clean_max = np.max(np.abs(clean_wav))
        enh_max = np.max(np.abs(enhanced_wav))

        if clean_max < 1e-6 or enh_max < 1e-6:
            print(f"Skipping {cf} - normalization would be unstable")
            continue

        clean_wav    /= clean_max
        enhanced_wav /= enh_max

        # score; metric failures on one utterance must not abort the run
        try:
            p = pesq_nb(SAMPLE_RATE, clean_wav, enhanced_wav, 'nb')
            s = stoi(clean_wav, enhanced_wav, SAMPLE_RATE, False)

            # Only append valid scores
            if not np.isnan(p) and not np.isnan(s):
                pesq_scores.append(p)
                stoi_scores.append(s)
            else:
                print(f"Skipping {cf} - got NaN scores")

        except Exception as e:
            print(f"Error processing {cf}: {str(e)}")
            continue

    # Only compute averages if we have valid scores
    if len(pesq_scores) > 0 and len(stoi_scores) > 0:
        avg_pesq = np.mean(pesq_scores)
        avg_stoi = np.mean(stoi_scores)
    else:
        print("Warning: No valid scores were computed!")
        avg_pesq = float('nan')
        avg_stoi = float('nan')

    if return_avgs:
        return avg_pesq, avg_stoi

    print(f"\nValidation on {len(pesq_scores)} utts → "
          f"PESQ {avg_pesq:.3f}, STOI {avg_stoi:.3f}")


# Normalisation statistics keyed by (absolute train dir, window size).
_TRAIN_STATS_CACHE = {}


def _train_norm_stats(train_dir, window_size):
    """Return (Xm, Xs, Ym, Ys) for the fixed-seed training subsample, cached."""
    key = (os.path.abspath(train_dir), window_size)
    if key not in _TRAIN_STATS_CACHE:
        X_tr, Y_tr = load_article2_dataset_from_merged(train_dir, window_size)
        idxs = np.random.default_rng(42).choice(
            len(X_tr), size=min(100_000, len(X_tr)), replace=False
        )
        _, _, Xm, Xs, Ym, Ys = normalize_data(X_tr[idxs], Y_tr[idxs])
        _TRAIN_STATS_CACHE[key] = (Xm, Xs, Ym, Ys)
    return _TRAIN_STATS_CACHE[key]

# ── Training Loop ─────────────────────────────────────────────────────────
def train_and_validate(model,
                      train_loader, val_loader,
                      train_dir, val_dir,
                      Xm, Xs, Ym, Ys,
                      num_epochs=50,
                      device='cpu',
                      patience=5):
    """Train with ``EnhancementLoss`` and select the model by PESQ + STOI.

    Args:
        model: Network to train (modified in place).
        train_loader, val_loader: Loaders yielding ``(input, target)`` batches.
        train_dir, val_dir: Forwarded to ``evaluate_on_validation`` for the
            per-epoch PESQ / STOI computation.
        Xm, Xs, Ym, Ys: Accepted for interface compatibility; not read here.
        num_epochs: Maximum number of epochs.
        device: Torch device for the batches.
        patience: Number of consecutive epochs without a new best
            ``PESQ + STOI`` after which training stops.

    Returns:
        ``(model, train_loss, val_loss, val_pesq, val_stoi)``; the lists hold
        one value per completed epoch and ``model`` carries the weights of
        the best-scoring epoch (or the last epoch if no score was valid).
    """
    import copy  # local import: only needed for the best-weights snapshot

    # Use custom loss
    crit = EnhancementLoss()

    opt = torch.optim.Adam(model.parameters(), lr=5e-5)

    # Halve the learning rate when PESQ + STOI stops improving. The class
    # lives in torch.optim.lr_scheduler, not directly in torch.optim. The
    # deprecated `verbose` flag is not passed; changes are reported below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode='max',
        factor=0.5,
        patience=2,
        threshold=1e-3
    )

    best_score = float('-inf')
    best_model_state = None
    stale = 0
    train_loss, val_loss, val_pesq, val_stoi = [], [], [], []

    for ep in range(1, num_epochs+1):
        # Training
        model.train()
        t_loss = 0.0
        for bx, by in train_loader:
            bx, by = bx.to(device), by.to(device)
            opt.zero_grad()
            out = model(bx)
            loss = crit(out, by)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            opt.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            t_loss += loss.item() * bx.size(0)
        t_loss /= len(train_loader.dataset)
        train_loss.append(t_loss)

        # Validation
        model.eval()
        v_loss = 0.0
        with torch.no_grad():
            for vx, vy in val_loader:
                vx, vy = vx.to(device), vy.to(device)
                vo = model(vx)
                v_loss += crit(vo, vy).item() * vx.size(0)
        v_loss /= len(val_loader.dataset)
        val_loss.append(v_loss)

        # Compute PESQ/STOI
        p, s = evaluate_on_validation(
            model, train_dir, val_dir,
            WINDOW_SIZE, device, True
        )
        val_pesq.append(p)
        val_stoi.append(s)

        print(f"Epoch {ep:02d}: TrainLoss={t_loss:.4f}  "
              f"ValLoss={v_loss:.4f}  PESQ={p:.3f}  STOI={s:.3f}")

        # Use combined metric for scheduling and model selection
        current_score = p + s  # Equal weight to PESQ and STOI
        lr_before = opt.param_groups[0]['lr']
        scheduler.step(current_score)
        lr_after = opt.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}")

        # A NaN score (no utterance could be scored) compares false here and
        # therefore counts as an epoch without improvement.
        if current_score > best_score:
            best_score = current_score
            # state_dict() returns references to the live tensors; deep-copy
            # so later optimizer steps cannot overwrite the snapshot.
            best_model_state = {
                'epoch': ep,
                'model_state_dict': copy.deepcopy(model.state_dict()),
                'optimizer_state_dict': copy.deepcopy(opt.state_dict()),
                'pesq': p,
                'stoi': s
            }
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                if best_model_state is not None:
                    print(f"Early stopping. Best PESQ: {best_model_state['pesq']:.3f}, "
                          f"Best STOI: {best_model_state['stoi']:.3f}")
                else:
                    print("Early stopping. No valid PESQ/STOI score was computed.")
                break

    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state['model_state_dict'])
        print(f"Restored best model from epoch {best_model_state['epoch']} "
              f"with PESQ={best_model_state['pesq']:.3f}, "
              f"STOI={best_model_state['stoi']:.3f}")

    return model, train_loss, val_loss, val_pesq, val_stoi

# ── Plotting ──────────────────────────────────────────────────────────────
def plot_metrics(train_loss, val_loss, val_pesq, val_stoi):
    """Draw the per-epoch training curves in a two-panel figure.

    The left panel shows train and validation loss, the right panel shows
    validation PESQ and STOI. Every series is plotted against the epoch
    numbers ``1 .. len(train_loss)``.
    """
    epochs = np.arange(1, len(train_loss) + 1)

    # (title, y-axis label or None, ((values, legend label), ...)) per panel
    panels = (
        ('Loss', 'Loss', ((train_loss, 'Train Loss'),
                          (val_loss, 'Val Loss'))),
        ('Quality', None, ((val_pesq, 'Val PESQ'),
                           (val_stoi, 'Val STOI'))),
    )

    plt.figure(figsize=(12,5))
    for position, (title, ylabel, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for values, label in curves:
            plt.plot(epochs, values, '-o', label=label)
        plt.xlabel('Epoch')
        if ylabel is not None:
            plt.ylabel(ylabel)
        plt.legend()
        plt.title(title)
    plt.tight_layout()
    plt.show()

# ── Waveform Plots ───────────────────────────────────────────────────────
def plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device='cpu'):
    """Plot clean, noisy and denoised waveforms for up to three random files.

    Args:
        model: Trained network mapping a ``(batch, 1, D, WINDOW_SIZE)`` window
            of normalised noisy features to one normalised output frame.
        val_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        Xm, Xs: Mean / std used to normalise the network input.
        Ym, Ys: Mean / std used to de-normalise the network output.
        device: Torch device the model lives on.

    Each sampled file fills one row of a 3x3 grid (Clean / Noisy / Denoised);
    only the first 2000 samples of every waveform are drawn.
    """
    files   = [f for f in os.listdir(val_dir)
               if f.endswith('.npy') and '_clean_logpower_' in f]
    samples = random.sample(files, min(3, len(files)))
    max_plot_samples = 2000
    chunk_size = 256  # frames per forward pass; bounds activation memory

    # Inference must use BatchNorm running statistics and disable dropout.
    was_training = model.training
    model.eval()

    plt.figure(figsize=(12,6))
    for idx, cf in enumerate(samples):
        C = np.load(os.path.join(val_dir, cf))
        N = np.load(os.path.join(
                val_dir,
                cf.replace('_clean_logpower_','_noisy_concat_')
            ))
        T = C.shape[1]
        if T < WINDOW_SIZE or N.shape[1] < T:
            # Not even one full context window can be formed.
            print(f"Skipping {cf} - not enough frames to plot")
            continue

        # Build every context window at once: (frames, 1, D, WINDOW_SIZE).
        Nn = ((N[:, :T] - Xm) / Xs).astype(np.float32)
        windows = np.lib.stride_tricks.sliding_window_view(
            Nn, WINDOW_SIZE, axis=1
        )
        windows = np.array(windows.transpose(1, 0, 2), order='C')[:, np.newaxis]

        outs = []
        with torch.no_grad():
            for start in range(0, len(windows), chunk_size):
                batch = torch.from_numpy(windows[start:start + chunk_size])
                outs.append(model(batch.to(device)).cpu().numpy())
        Pn = np.concatenate(outs, axis=0).T  # (bins, T - WINDOW_SIZE + 1)

        # Prediction k describes frame k + WINDOW_SIZE - 1 (the last frame of
        # its window), so the missing frames are at the START of the clip.
        Pn = np.pad(Pn, ((0,0),(T - Pn.shape[1], 0)), mode='edge')
        P = (Pn * Ys) + Ym

        clean_wav    = reconstruct_waveform_from_logpower(C)
        noisy_wav    = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase  = np.angle(librosa.stft(
                          noisy_wav,
                          n_fft=FRAME_SIZE,
                          hop_length=OVERLAP
                        ))[:N_BINS]
        # Magnitude and phase must have the same number of frames.
        frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
            P[:N_BINS, :frames], noisy_phase[:, :frames]
        )

        for col, (wav, title) in enumerate(zip(
            [clean_wav, noisy_wav, enhanced_wav],
            ['Clean','Noisy','Denoised']
        )):
            segment = wav[:max_plot_samples]
            # Per-waveform time axis: clips may be shorter than 2000 samples.
            t = np.arange(len(segment)) / SAMPLE_RATE
            plt.subplot(3,3, idx*3 + col + 1)
            plt.plot(t, segment)
            plt.title(title); plt.ylim(-1,1)

    model.train(was_training)
    plt.tight_layout(); plt.show()

# ── WAV Export ───────────────────────────────────────────────────────────
def save_enhanced_audio(model, data_dir, out_dir,
                        window_size=WINDOW_SIZE,
                        device='cpu', num_samples=5,
                        Y_mean=0.0, Y_std=1.0,
                        X_mean=None, X_std=None):
    """Write clean / noisy / enhanced WAV trios for randomly chosen files.

    Args:
        model: Trained network mapping a ``(batch, 1, D, window_size)`` window
            of normalised noisy features to one normalised output frame.
        data_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        out_dir: Destination directory (created if missing).
        window_size: Number of context frames fed to the model.
        device: Torch device the model lives on.
        num_samples: Maximum number of files to export.
        Y_mean, Y_std: Statistics used to de-normalise the model output.
        X_mean, X_std: Statistics used to normalise the model input. When
            omitted they fall back to ``Y_mean`` / ``Y_std`` so existing
            callers keep working, but the input statistics used in training
            should be passed.
    """
    # Backward compatibility: older callers only supplied target statistics.
    if X_mean is None:
        X_mean = Y_mean
    if X_std is None:
        X_std = Y_std

    def norm(a):
        # Peak-normalise; silent signals are returned unchanged.
        m = np.max(np.abs(a))
        return a/m if m>0 else a

    chunk_size = 256  # frames per forward pass; bounds activation memory

    os.makedirs(out_dir, exist_ok=True)
    allc = [f for f in os.listdir(data_dir)
            if f.endswith('.npy') and '_clean_logpower_' in f]

    # Inference must use BatchNorm running statistics and disable dropout.
    was_training = model.training
    model.eval()

    for cf in random.sample(allc, min(num_samples, len(allc))):
        npath = os.path.join(
            data_dir,
            cf.replace('_clean_logpower_','_noisy_concat_')
        )
        if not os.path.exists(npath):
            print(f"Skipping {cf} - noisy counterpart not found")
            continue
        C = np.load(os.path.join(data_dir, cf))
        N = np.load(npath)
        T = C.shape[1]
        if T < window_size or N.shape[1] < T:
            # Not even one full context window can be formed.
            print(f"Skipping {cf} - not enough frames")
            continue

        clean_wav = reconstruct_waveform_from_logpower(C)
        noisy_wav = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(
                        noisy_wav,
                        n_fft=FRAME_SIZE,
                        hop_length=OVERLAP
                       ))[:N_BINS]

        # Build every context window at once: (frames, 1, D, window_size).
        Nn = ((N[:, :T] - X_mean) / X_std).astype(np.float32)
        windows = np.lib.stride_tricks.sliding_window_view(
            Nn, window_size, axis=1
        )
        windows = np.array(windows.transpose(1, 0, 2), order='C')[:, np.newaxis]

        outs = []
        with torch.no_grad():
            for start in range(0, len(windows), chunk_size):
                batch = torch.from_numpy(windows[start:start + chunk_size])
                outs.append(model(batch.to(device)).cpu().numpy())
        Pn = np.concatenate(outs, axis=0).T  # (bins, T - window_size + 1)

        # Prediction k describes frame k + window_size - 1 (the last frame of
        # its window), so the missing frames are at the START of the clip.
        Pn = np.pad(Pn, ((0,0),(T - Pn.shape[1], 0)), mode='edge')
        P = (Pn * Y_std) + Y_mean

        # Magnitude and phase must have the same number of frames.
        frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
                           P[:N_BINS, :frames], noisy_phase[:, :frames]
                       )

        stem = cf.split('_clean_logpower_')[-1][:-4]
        for tag, wav in zip(
            ['clean','noisy','enhanced'],
            [clean_wav, noisy_wav, enhanced_wav]
        ):
            sf.write(
                os.path.join(out_dir, f"{tag}_{stem}.wav"),
                norm(wav), SAMPLE_RATE
            )
        print("Saved trio:", cf)

    model.train(was_training)

if __name__ == '__main__':
    train_dir = r'C:\Users\enhance\article2_merge\train'
    val_dir   = r'C:\Users\enhance\article2_merge\val'
    test_dir  = r'C:\Users\enhance\article2_merge\test'

    # Load the training windows and keep a fixed-seed subset of at most
    # 100k of them.
    X_tr, Y_tr = load_article2_dataset_from_merged(train_dir)
    idxs = np.random.default_rng(42).choice(
        len(X_tr), size=min(100_000, len(X_tr)),
        replace=False
    )
    X_tr, Y_tr = X_tr[idxs], Y_tr[idxs]

    Xn_tr,Yn_tr,Xm,Xs,Ym,Ys = normalize_data(X_tr, Y_tr)

    # (N, D, T, 1) -> (N, 1, D, T) as expected by Conv2d.
    Xt = torch.from_numpy(Xn_tr).permute(0,3,1,2)
    Yt = torch.from_numpy(Yn_tr)
    train_loader = DataLoader(
        TensorDataset(Xt, Yt),
        batch_size=4, shuffle=True,
        # BatchNorm1d cannot normalise a training batch of one sample, so a
        # trailing size-1 batch would abort the epoch.
        drop_last=True
    )

    # Validation data is normalised with the TRAINING statistics.
    X_val, Y_val = load_article2_dataset_from_merged(val_dir)
    Xn_val = (X_val - Xm)/Xs
    Yn_val = (Y_val - Ym)/Ys
    Xv = torch.from_numpy(Xn_val).permute(0,3,1,2)
    Yv = torch.from_numpy(Yn_val)
    val_loader = DataLoader(
        TensorDataset(Xv, Yv),
        batch_size=4, shuffle=False
    )

    device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu'
    )
    D, T = Xt.shape[2], Xt.shape[3]
    # Use the new model with attention
    model = Article2CNNWithAttention((D,T)).to(device)

    model, train_loss, val_loss, val_pesq, val_stoi = train_and_validate(
        model, train_loader, val_loader,
        train_dir, val_dir,
        Xm, Xs, Ym, Ys,
        num_epochs=15, device=device,
        patience=5
    )

    plot_metrics(train_loss, val_loss, val_pesq, val_stoi)
    plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device=device)

    # Inputs are normalised with the input statistics, outputs are
    # de-normalised with the target statistics - the same as in training.
    save_enhanced_audio(
        model, test_dir,
        r'C:\Users\enhance\article2_test_wavs',
        device=device, num_samples=5,
        Y_mean=Ym, Y_std=Ys,
        X_mean=Xm, X_std=Xs
    )

    print("✅ Done – check plots & test‑WAVs")

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
from pystoi import stoi
from pesq import pesq as pesq_nb  # Changed from pypesq import pesq
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

# ── Constants ─────────────────────────────────────────────────────────────
# Sample rate passed to sf.write and to the PESQ/STOI scorers.
SAMPLE_RATE = 8000
# Used as n_fft for librosa.stft / librosa.griffinlim.
FRAME_SIZE  = 256
# Used as hop_length for the STFT, inverse STFT and Griffin-Lim.
OVERLAP     = 128
# Subtracted from exp(log-power) when converting back to power.
# NOTE(review): presumably the floor added before taking the log when the
# features were extracted - confirm against the feature-extraction code.
EPSILON     = 1e-10
# Default number of Griffin-Lim iterations.
N_ITER      = 32
# Number of consecutive noisy frames in one model input window.
WINDOW_SIZE = 9
# Number of one-sided FFT bins; also the model's output dimension.
N_BINS      = FRAME_SIZE // 2 + 1  # 129

# ── Loss Functions ─────────────────────────────────────────────────────────
def envelope_loss(S1, S2):
    """Mean squared difference between peak-normalised envelopes (aimed at STOI).

    The envelope of each input is the sum of absolute values over ``dim=1``,
    divided by its own global maximum (plus 1e-8 for stability).
    """
    def _unit_peak_envelope(spec):
        envelope = spec.abs().sum(dim=1)
        return envelope / (envelope.max() + 1e-8)

    gap = _unit_peak_envelope(S1) - _unit_peak_envelope(S2)
    return (gap ** 2).mean()

def spectral_loss(S1, S2):
    """Mean absolute element-wise difference between two spectra (aimed at PESQ)."""
    return (S1 - S2).abs().mean()

class EnhancementLoss(nn.Module):
    """Composite loss: L1 on the raw inputs plus two exp-domain terms.

    ``total = L1(pred, target)
              + 0.3 * spectral_loss(exp(pred), exp(target))
              + 0.3 * envelope_loss(exp(pred), exp(target))``

    NOTE(review): the training script in this file feeds z-normalised
    log-power targets, so ``exp`` is applied to normalised values rather
    than to raw log-power - confirm this is intended.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        # L1 distance in the domain the network actually predicts.
        log_term = (pred - target).abs().mean()

        # Exponentiate both sides for the spectral and envelope terms.
        pred_exp, target_exp = pred.exp(), target.exp()
        spectral_term = spectral_loss(pred_exp, target_exp)
        envelope_term = envelope_loss(pred_exp, target_exp)

        return log_term + 0.3 * spectral_term + 0.3 * envelope_term

# ── Reconstruction Helpers ────────────────────────────────────────────────
def reconstruct_waveform_from_logpower(lp,
                                       n_fft=FRAME_SIZE,
                                       hop_length=OVERLAP,
                                       n_iter=N_ITER):
    """Estimate a waveform from a log-power spectrogram with Griffin-Lim.

    Non-finite entries of ``lp`` are replaced via ``np.nan_to_num``; the
    log-power is mapped back to a magnitude as ``sqrt(max(exp(lp) - EPSILON, 0))``
    and handed to ``librosa.griffinlim``, which estimates the phase.
    """
    finite_lp = np.nan_to_num(lp)
    magnitude = np.sqrt(np.clip(np.exp(finite_lp) - EPSILON, 0, None))
    return librosa.griffinlim(
        magnitude,
        n_iter=n_iter,
        n_fft=n_fft,
        hop_length=hop_length,
    )

def reconstruct_with_noisy_phase(lp, noisy_phase,
                                 hop_length=OVERLAP):
    """Invert a log-power spectrogram using an externally supplied phase.

    ``lp`` and ``noisy_phase`` are combined element-wise, so they must have
    broadcast-compatible shapes. The magnitude is
    ``sqrt(max(exp(lp) - EPSILON, 0))`` with non-finite ``lp`` entries
    replaced via ``np.nan_to_num``.
    """
    finite_lp = np.nan_to_num(lp)
    magnitude = np.sqrt(np.maximum(np.exp(finite_lp) - EPSILON, 0))
    complex_spec = magnitude * np.exp(1j * noisy_phase)
    return librosa.istft(complex_spec, hop_length=hop_length)

# ── Data Loading ──────────────────────────────────────────────────────────
def load_article2_dataset_from_merged(merged_dir,
                                      window_size=WINDOW_SIZE):
    """Build (noisy window, clean frame) training pairs from paired .npy files.

    Walks ``merged_dir`` recursively. For every ``*_clean_logpower_*.npy``
    file whose ``*_noisy_concat_*.npy`` counterpart exists in the same
    directory and has the same number of frames, each frame index
    ``i >= window_size - 1`` yields one pair: the noisy frames
    ``i - window_size + 1 .. i`` as input and clean frame ``i`` as target.

    Returns:
        X: float32 array of shape ``(num_pairs, D, window_size, 1)``.
        Y: float32 array of shape ``(num_pairs, bins)``.
    """
    Xs, Ys = [], []
    for root, dirs, files in os.walk(merged_dir):
        # Sort in place so traversal (and therefore any seeded subsampling
        # of the result) does not depend on filesystem ordering.
        dirs.sort()
        for cf in sorted(files):
            if not (cf.endswith('.npy') and '_clean_logpower_' in cf):
                continue
            cpath = os.path.join(root, cf)
            # Substitute in the file name only; replacing in the full path
            # would also rewrite a directory that contains the marker.
            npath = os.path.join(
                root, cf.replace('_clean_logpower_', '_noisy_concat_')
            )
            if not os.path.exists(npath):
                continue
            C = np.load(cpath)  # [129, T]
            N = np.load(npath)
            if C.shape[1] != N.shape[1]:
                continue
            T = C.shape[1]
            for i in range(window_size - 1, T):
                Xs.append(N[:, i-window_size+1:i+1])
                Ys.append(C[:, i])
    X = np.array(Xs, dtype=np.float32)[..., np.newaxis]
    Y = np.array(Ys, dtype=np.float32)
    return X, Y

# ── Normalization ────────────────────────────────────────────────────────
def normalize_data(X, Y):
    """Z-normalise inputs and targets with one global mean/std each.

    Args:
        X: Input array.
        Y: Target array.

    Returns:
        ``(Xn, Yn, Xm, Xs, Ym, Ys)``: the normalised arrays followed by the
        mean and standard deviation used for each. A standard deviation that
        is zero (constant data) or NaN is replaced by 1 so the division
        cannot produce NaN/inf; the replaced value is what is returned, so
        callers that reuse the statistics stay consistent.
    """
    def _mean_and_safe_std(arr):
        mean, std = arr.mean(), arr.std()
        if not std > 0:
            # Keep the scalar's dtype so downstream arithmetic is unchanged.
            std = std.dtype.type(1)
        return mean, std

    Xm, Xs = _mean_and_safe_std(X)
    Ym, Ys = _mean_and_safe_std(Y)
    Xn     = (X - Xm) / Xs
    Yn     = (Y - Ym) / Ys
    return Xn, Yn, Xm, Xs, Ym, Ys

# ── Model Definition ─────────────────────────────────────────────────────
class Article2CNNWithAttention(nn.Module):
    """Two-layer CNN with a sigmoid gate, followed by a two-layer MLP head.

    Input is ``(batch, 1, D, T)``; output is ``(batch, N_BINS)``. Both
    convolutions use ``(5, 1)`` kernels, i.e. they filter along the third
    axis only; the second one strides that axis by 3.

    Args:
        input_shape: ``(D, T)`` of a single input window, used to size the
            first fully connected layer.
    """

    def __init__(self, input_shape):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 129, kernel_size=(5,1), padding=(2,0))
        self.bn1 = nn.BatchNorm2d(129)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(129, 43, kernel_size=(5,1),
                              stride=(3,1), padding=(2,0))
        self.bn2 = nn.BatchNorm2d(43)
        self.relu2 = nn.ReLU()

        # A 1x1 convolution + sigmoid yields one gate value per position,
        # shared across all 43 channels.
        self.attention = nn.Sequential(
            nn.Conv2d(43, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Infer the flattened feature size from the convolutions alone.
        # BatchNorm and ReLU do not change shapes, and running a dummy batch
        # through BatchNorm in training mode would pollute its running
        # statistics before any real data is seen.
        D, T = input_shape
        with torch.no_grad():
            probe = self.conv2(self.conv1(torch.zeros(1,1,D,T)))
        flat = probe.view(1,-1).size(1)

        self.fc1 = nn.Linear(flat, 1024)
        self.bn3 = nn.BatchNorm1d(1024)
        self.relu3 = nn.ReLU()
        self.drop = nn.Dropout(0.2)
        self.fc2 = nn.Linear(1024, N_BINS)

    def forward(self, x):
        # Convolutional layers with batch norm
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))

        # Gate the feature map (the gate broadcasts over the channel axis)
        att = self.attention(x)
        x = x * att

        # Fully connected layers
        x = x.view(x.size(0), -1)
        x = self.relu3(self.bn3(self.fc1(x)))
        x = self.drop(x)
        return self.fc2(x)

# ── Evaluation (full‑utterance PESQ/STOI) ─────────────────────────────────
# (absolute train_dir, window_size) -> (Xm, Xs, Ym, Ys)
_TRAIN_STATS_CACHE = {}


def _train_norm_stats(train_dir, window_size):
    """Return the training normalisation stats, computing them once per session."""
    key = (os.path.abspath(train_dir), window_size)
    if key not in _TRAIN_STATS_CACHE:
        X_tr, Y_tr = load_article2_dataset_from_merged(train_dir, window_size)
        idxs = np.random.default_rng(42).choice(
            len(X_tr), size=min(100_000, len(X_tr)), replace=False
        )
        _TRAIN_STATS_CACHE[key] = normalize_data(X_tr[idxs], Y_tr[idxs])[2:]
    return _TRAIN_STATS_CACHE[key]


def _predict_aligned_logpower(model, N, T, window_size,
                              Xm, Xs, Ym, Ys, device, chunk_size=256):
    """Run the model over every window of N[:, :T]; return T aligned frames."""
    Nn = ((N[:, :T] - Xm) / Xs).astype(np.float32)
    # (D, frames, window) -> (frames, 1, D, window)
    windows = np.lib.stride_tricks.sliding_window_view(Nn, window_size, axis=1)
    windows = np.array(windows.transpose(1, 0, 2), order='C')[:, np.newaxis]

    outs = []
    with torch.no_grad():
        # Chunked so activation memory stays bounded for long utterances.
        for start in range(0, len(windows), chunk_size):
            batch = torch.from_numpy(windows[start:start + chunk_size])
            outs.append(model(batch.to(device)).cpu().numpy())
    Pn = np.concatenate(outs, axis=0).T

    # Prediction k describes frame k + window_size - 1 (the last frame of its
    # window), so the frames without a prediction are at the START.
    Pn = np.pad(Pn, ((0,0),(T - Pn.shape[1], 0)), mode='edge')
    return (Pn * Ys) + Ym


def evaluate_on_validation(model,
                           train_dir, val_dir,
                           window_size=WINDOW_SIZE,
                           device='cpu',
                           return_avgs=False):
    """Score the model on whole validation utterances with PESQ and STOI.

    Args:
        model: Network mapping ``(batch, 1, D, window_size)`` normalised
            noisy windows to one normalised output frame.
        train_dir: Training directory; normalisation statistics are derived
            from it (seeded 100k-window subsample) and cached per session,
            so changes to its files after the first call are not picked up.
        val_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        window_size: Number of context frames per model input.
        device: Torch device the model lives on.
        return_avgs: When true, return ``(avg_pesq, avg_stoi)``; otherwise
            print a summary and return ``None``.

    Utterances that are too short, too quiet, or that the scorers reject are
    skipped with a message. If nothing could be scored both averages are NaN.
    """
    Xm, Xs, Ym, Ys = _train_norm_stats(train_dir, window_size)

    # minimum length ≔ 1.5 seconds @ 8 kHz (adjusted for your data)
    min_len = int(1.5 * SAMPLE_RATE)

    pesq_scores, stoi_scores = [], []

    # Inference must use BatchNorm running statistics and disable dropout.
    was_training = model.training
    model.eval()

    for cf in os.listdir(val_dir):
        if not (cf.endswith('.npy') and '_clean_logpower_' in cf):
            continue
        cpath = os.path.join(val_dir, cf)
        npath = os.path.join(
            val_dir, cf.replace('_clean_logpower_','_noisy_concat_')
        )
        if not os.path.exists(npath):
            continue

        C = np.nan_to_num(np.load(cpath))  # [129,T]
        N = np.nan_to_num(np.load(npath))
        T = C.shape[1]
        if T < window_size or N.shape[1] < T:
            continue

        # reconstruct clean
        clean_wav = reconstruct_waveform_from_logpower(C)
        if len(clean_wav) < min_len or np.std(clean_wav) < 1e-4:
            print(f"Skipping {cf} - too short or too quiet")
            continue

        P = _predict_aligned_logpower(
            model, N, T, window_size, Xm, Xs, Ym, Ys, device
        )

        # phase taken from a Griffin-Lim reconstruction of the noisy input
        noisy_wav   = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(noisy_wav,
                                            n_fft=FRAME_SIZE,
                                            hop_length=OVERLAP))[:N_BINS]

        # reconstruct enhanced (magnitude and phase need equal frame counts)
        frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
            P[:N_BINS, :frames], noisy_phase[:, :frames]
        )
        if len(enhanced_wav) < min_len or np.std(enhanced_wav) < 1e-4:
            print(f"Skipping {cf} - enhanced audio too short or too quiet")
            continue

        # force same length and ensure minimum length
        L = min(len(clean_wav), len(enhanced_wav))
        if L < min_len:
            print(f"Skipping {cf} - final length {L} samples too short")
            continue

        # normalize both into [-1,1]
        clean_max = np.max(np.abs(clean_wav[:L]))
        enh_max = np.max(np.abs(enhanced_wav[:L]))
        if clean_max < 1e-6 or enh_max < 1e-6:
            print(f"Skipping {cf} - normalization would be unstable")
            continue

        clean_wav    = clean_wav[:L] / clean_max
        enhanced_wav = enhanced_wav[:L] / enh_max

        # score; the scorers raise library-specific errors on degenerate
        # input, so failures are reported per file instead of aborting
        try:
            p = pesq_nb(SAMPLE_RATE, clean_wav, enhanced_wav, 'nb')
            s = stoi(clean_wav, enhanced_wav, SAMPLE_RATE, False)
        except Exception as e:
            print(f"Error processing {cf}: {str(e)}")
            continue

        # Only append valid scores
        if np.isnan(p) or np.isnan(s):
            print(f"Skipping {cf} - got NaN scores")
            continue
        pesq_scores.append(p)
        stoi_scores.append(s)

    model.train(was_training)

    # Only compute averages if we have valid scores
    if pesq_scores and stoi_scores:
        avg_pesq = np.mean(pesq_scores)
        avg_stoi = np.mean(stoi_scores)
    else:
        print("Warning: No valid scores were computed!")
        avg_pesq = float('nan')
        avg_stoi = float('nan')

    if return_avgs:
        return avg_pesq, avg_stoi

    print(f"\nValidation on {len(pesq_scores)} utts → "
          f"PESQ {avg_pesq:.3f}, STOI {avg_stoi:.3f}")

# ── Training Loop ─────────────────────────────────────────────────────────
def train_and_validate(model,
                      train_loader, val_loader,
                      train_dir, val_dir,
                      Xm, Xs, Ym, Ys,
                      num_epochs=50,
                      device='cpu',
                      patience=5):
    """Train with early stopping on validation PESQ + STOI.

    Args:
        model: Network to train (modified in place and returned).
        train_loader, val_loader: Loaders yielding ``(input, target)`` batches.
        train_dir, val_dir: Directories forwarded to
            ``evaluate_on_validation`` for full-utterance scoring.
        Xm, Xs, Ym, Ys: Accepted for interface compatibility; not used here
            (``evaluate_on_validation`` derives its own statistics from
            ``train_dir``).
        num_epochs: Maximum number of epochs.
        device: Torch device for the batches.
        patience: Epochs without a new best score before stopping.

    Returns:
        ``(model, train_loss, val_loss, val_pesq, val_stoi)`` where the model
        carries the weights of the best-scoring epoch (if any epoch produced
        a valid score) and the lists hold one value per completed epoch.
    """
    import copy  # local import keeps this function self-contained

    # Use custom loss
    crit = EnhancementLoss()

    # Lower initial learning rate
    opt = torch.optim.Adam(model.parameters(), lr=5e-5)

    # Scheduler based on PESQ and STOI
    scheduler = ReduceLROnPlateau(
        opt, mode='max',
        factor=0.5,
        patience=2,
        verbose=True,
        threshold=1e-3
    )

    best_score = float('-inf')
    best_model_state = None
    stale = 0
    train_loss, val_loss, val_pesq, val_stoi = [], [], [], []

    for ep in range(1, num_epochs+1):
        # Training
        model.train()
        t_loss, n_train = 0.0, 0
        for bx, by in train_loader:
            bx, by = bx.to(device), by.to(device)
            opt.zero_grad()
            out = model(bx)
            loss = crit(out, by)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            opt.step()
            t_loss += loss.item() * bx.size(0)
            n_train += bx.size(0)
        # Average over the samples actually seen, which stays correct for
        # loaders that drop or resample items.
        t_loss /= max(n_train, 1)
        train_loss.append(t_loss)

        # Validation
        model.eval()
        v_loss, n_val = 0.0, 0
        with torch.no_grad():
            for vx, vy in val_loader:
                vx, vy = vx.to(device), vy.to(device)
                vo = model(vx)
                v_loss += crit(vo, vy).item() * vx.size(0)
                n_val += vx.size(0)
        v_loss /= max(n_val, 1)
        val_loss.append(v_loss)

        # Compute PESQ/STOI
        p, s = evaluate_on_validation(
            model, train_dir, val_dir,
            WINDOW_SIZE, device, True
        )
        val_pesq.append(p)
        val_stoi.append(s)

        print(f"Epoch {ep:02d}: TrainLoss={t_loss:.4f}  "
              f"ValLoss={v_loss:.4f}  PESQ={p:.3f}  STOI={s:.3f}")

        # Use combined metric for scheduling and model selection
        current_score = p + s  # Equal weight to PESQ and STOI
        scheduler.step(current_score)

        # A NaN score (no utterance could be scored) compares false here and
        # therefore counts as a stale epoch.
        if current_score > best_score:
            best_score = current_score
            # state_dict() returns references to the live tensors; without a
            # deep copy the snapshot would keep changing as training goes on
            # and the restore at the end would be a no-op.
            best_model_state = {
                'epoch': ep,
                'model_state_dict': copy.deepcopy(model.state_dict()),
                'optimizer_state_dict': copy.deepcopy(opt.state_dict()),
                'pesq': p,
                'stoi': s
            }
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                if best_model_state is None:
                    # Every epoch so far produced NaN scores.
                    print("Early stopping. No valid PESQ/STOI score was obtained.")
                else:
                    print(f"Early stopping. Best PESQ: {best_model_state['pesq']:.3f}, "
                          f"Best STOI: {best_model_state['stoi']:.3f}")
                break

    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state['model_state_dict'])
        print(f"Restored best model from epoch {best_model_state['epoch']} "
              f"with PESQ={best_model_state['pesq']:.3f}, "
              f"STOI={best_model_state['stoi']:.3f}")

    return model, train_loss, val_loss, val_pesq, val_stoi

# ── Plotting ──────────────────────────────────────────────────────────────
def plot_metrics(train_loss, val_loss, val_pesq, val_stoi):
    """Draw the per-epoch training curves in a two-panel figure.

    The left panel shows train and validation loss, the right panel shows
    validation PESQ and STOI. Every series is plotted against the epoch
    numbers ``1 .. len(train_loss)``.
    """
    epochs = np.arange(1, len(train_loss) + 1)

    # (title, y-axis label or None, ((values, legend label), ...)) per panel
    panels = (
        ('Loss', 'Loss', ((train_loss, 'Train Loss'),
                          (val_loss, 'Val Loss'))),
        ('Quality', None, ((val_pesq, 'Val PESQ'),
                           (val_stoi, 'Val STOI'))),
    )

    plt.figure(figsize=(12,5))
    for position, (title, ylabel, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for values, label in curves:
            plt.plot(epochs, values, '-o', label=label)
        plt.xlabel('Epoch')
        if ylabel is not None:
            plt.ylabel(ylabel)
        plt.legend()
        plt.title(title)
    plt.tight_layout()
    plt.show()

# ── Waveform Plots ───────────────────────────────────────────────────────
def plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device='cpu'):
    """Plot clean, noisy and denoised waveforms for up to three random files.

    Args:
        model: Trained network mapping a ``(batch, 1, D, WINDOW_SIZE)`` window
            of normalised noisy features to one normalised output frame.
        val_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        Xm, Xs: Mean / std used to normalise the network input.
        Ym, Ys: Mean / std used to de-normalise the network output.
        device: Torch device the model lives on.

    Each sampled file fills one row of a 3x3 grid (Clean / Noisy / Denoised);
    only the first 2000 samples of every waveform are drawn.
    """
    files   = [f for f in os.listdir(val_dir)
               if f.endswith('.npy') and '_clean_logpower_' in f]
    samples = random.sample(files, min(3, len(files)))
    max_plot_samples = 2000
    chunk_size = 256  # frames per forward pass; bounds activation memory

    # Inference must use BatchNorm running statistics and disable dropout.
    was_training = model.training
    model.eval()

    plt.figure(figsize=(12,6))
    for idx, cf in enumerate(samples):
        C = np.load(os.path.join(val_dir, cf))
        N = np.load(os.path.join(
                val_dir,
                cf.replace('_clean_logpower_','_noisy_concat_')
            ))
        T = C.shape[1]
        if T < WINDOW_SIZE or N.shape[1] < T:
            # Not even one full context window can be formed.
            print(f"Skipping {cf} - not enough frames to plot")
            continue

        # Build every context window at once: (frames, 1, D, WINDOW_SIZE).
        Nn = ((N[:, :T] - Xm) / Xs).astype(np.float32)
        windows = np.lib.stride_tricks.sliding_window_view(
            Nn, WINDOW_SIZE, axis=1
        )
        windows = np.array(windows.transpose(1, 0, 2), order='C')[:, np.newaxis]

        outs = []
        with torch.no_grad():
            for start in range(0, len(windows), chunk_size):
                batch = torch.from_numpy(windows[start:start + chunk_size])
                outs.append(model(batch.to(device)).cpu().numpy())
        Pn = np.concatenate(outs, axis=0).T  # (bins, T - WINDOW_SIZE + 1)

        # Prediction k describes frame k + WINDOW_SIZE - 1 (the last frame of
        # its window), so the missing frames are at the START of the clip.
        Pn = np.pad(Pn, ((0,0),(T - Pn.shape[1], 0)), mode='edge')
        P = (Pn * Ys) + Ym

        clean_wav    = reconstruct_waveform_from_logpower(C)
        noisy_wav    = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase  = np.angle(librosa.stft(
                          noisy_wav,
                          n_fft=FRAME_SIZE,
                          hop_length=OVERLAP
                        ))[:N_BINS]
        # Magnitude and phase must have the same number of frames.
        frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
            P[:N_BINS, :frames], noisy_phase[:, :frames]
        )

        for col, (wav, title) in enumerate(zip(
            [clean_wav, noisy_wav, enhanced_wav],
            ['Clean','Noisy','Denoised']
        )):
            segment = wav[:max_plot_samples]
            # Per-waveform time axis: clips may be shorter than 2000 samples.
            t = np.arange(len(segment)) / SAMPLE_RATE
            plt.subplot(3,3, idx*3 + col + 1)
            plt.plot(t, segment)
            plt.title(title); plt.ylim(-1,1)

    model.train(was_training)
    plt.tight_layout(); plt.show()

# ── WAV Export ───────────────────────────────────────────────────────────
def save_enhanced_audio(model, data_dir, out_dir,
                        window_size=WINDOW_SIZE,
                        device='cpu', num_samples=5,
                        Y_mean=0.0, Y_std=1.0,
                        X_mean=None, X_std=None):
    """Write clean / noisy / enhanced WAV trios for randomly chosen files.

    Args:
        model: Trained network mapping a ``(batch, 1, D, window_size)`` window
            of normalised noisy features to one normalised output frame.
        data_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        out_dir: Destination directory (created if missing).
        window_size: Number of context frames fed to the model.
        device: Torch device the model lives on.
        num_samples: Maximum number of files to export.
        Y_mean, Y_std: Statistics used to de-normalise the model output.
        X_mean, X_std: Statistics used to normalise the model input. When
            omitted they fall back to ``Y_mean`` / ``Y_std`` so existing
            callers keep working, but the input statistics used in training
            should be passed.
    """
    # Backward compatibility: older callers only supplied target statistics.
    if X_mean is None:
        X_mean = Y_mean
    if X_std is None:
        X_std = Y_std

    def norm(a):
        # Peak-normalise; silent signals are returned unchanged.
        m = np.max(np.abs(a))
        return a/m if m>0 else a

    chunk_size = 256  # frames per forward pass; bounds activation memory

    os.makedirs(out_dir, exist_ok=True)
    allc = [f for f in os.listdir(data_dir)
            if f.endswith('.npy') and '_clean_logpower_' in f]

    # Inference must use BatchNorm running statistics and disable dropout.
    was_training = model.training
    model.eval()

    for cf in random.sample(allc, min(num_samples, len(allc))):
        npath = os.path.join(
            data_dir,
            cf.replace('_clean_logpower_','_noisy_concat_')
        )
        if not os.path.exists(npath):
            print(f"Skipping {cf} - noisy counterpart not found")
            continue
        C = np.load(os.path.join(data_dir, cf))
        N = np.load(npath)
        T = C.shape[1]
        if T < window_size or N.shape[1] < T:
            # Not even one full context window can be formed.
            print(f"Skipping {cf} - not enough frames")
            continue

        clean_wav = reconstruct_waveform_from_logpower(C)
        noisy_wav = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(
                        noisy_wav,
                        n_fft=FRAME_SIZE,
                        hop_length=OVERLAP
                       ))[:N_BINS]

        # Build every context window at once: (frames, 1, D, window_size).
        Nn = ((N[:, :T] - X_mean) / X_std).astype(np.float32)
        windows = np.lib.stride_tricks.sliding_window_view(
            Nn, window_size, axis=1
        )
        windows = np.array(windows.transpose(1, 0, 2), order='C')[:, np.newaxis]

        outs = []
        with torch.no_grad():
            for start in range(0, len(windows), chunk_size):
                batch = torch.from_numpy(windows[start:start + chunk_size])
                outs.append(model(batch.to(device)).cpu().numpy())
        Pn = np.concatenate(outs, axis=0).T  # (bins, T - window_size + 1)

        # Prediction k describes frame k + window_size - 1 (the last frame of
        # its window), so the missing frames are at the START of the clip.
        Pn = np.pad(Pn, ((0,0),(T - Pn.shape[1], 0)), mode='edge')
        P = (Pn * Y_std) + Y_mean

        # Magnitude and phase must have the same number of frames.
        frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
                           P[:N_BINS, :frames], noisy_phase[:, :frames]
                       )

        stem = cf.split('_clean_logpower_')[-1][:-4]
        for tag, wav in zip(
            ['clean','noisy','enhanced'],
            [clean_wav, noisy_wav, enhanced_wav]
        ):
            sf.write(
                os.path.join(out_dir, f"{tag}_{stem}.wav"),
                norm(wav), SAMPLE_RATE
            )
        print("Saved trio:", cf)

    model.train(was_training)

if __name__ == '__main__':
    train_dir = r'C:\Users\enhance\article2_merge\train'
    val_dir   = r'C:\Users\enhance\article2_merge\val'
    test_dir  = r'C:\Users\enhance\article2_merge\test'

    # Load the training windows and keep a fixed-seed subset of at most
    # 100k of them.
    X_tr, Y_tr = load_article2_dataset_from_merged(train_dir)
    idxs = np.random.default_rng(42).choice(
        len(X_tr), size=min(100_000, len(X_tr)),
        replace=False
    )
    X_tr, Y_tr = X_tr[idxs], Y_tr[idxs]

    Xn_tr,Yn_tr,Xm,Xs,Ym,Ys = normalize_data(X_tr, Y_tr)

    # (N, D, T, 1) -> (N, 1, D, T) as expected by Conv2d.
    Xt = torch.from_numpy(Xn_tr).permute(0,3,1,2)
    Yt = torch.from_numpy(Yn_tr)
    train_loader = DataLoader(
        TensorDataset(Xt, Yt),
        batch_size=4, shuffle=True,
        # BatchNorm1d cannot normalise a training batch of one sample, so a
        # trailing size-1 batch would abort the epoch.
        drop_last=True
    )

    # Validation data is normalised with the TRAINING statistics.
    X_val, Y_val = load_article2_dataset_from_merged(val_dir)
    Xn_val = (X_val - Xm)/Xs
    Yn_val = (Y_val - Ym)/Ys
    Xv = torch.from_numpy(Xn_val).permute(0,3,1,2)
    Yv = torch.from_numpy(Yn_val)
    val_loader = DataLoader(
        TensorDataset(Xv, Yv),
        batch_size=4, shuffle=False
    )

    device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu'
    )
    D, T = Xt.shape[2], Xt.shape[3]
    # Use the new model with attention
    model = Article2CNNWithAttention((D,T)).to(device)

    model, train_loss, val_loss, val_pesq, val_stoi = train_and_validate(
        model, train_loader, val_loader,
        train_dir, val_dir,
        Xm, Xs, Ym, Ys,
        num_epochs=15, device=device,
        patience=5
    )

    plot_metrics(train_loss, val_loss, val_pesq, val_stoi)
    plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device=device)

    # Inputs are normalised with the input statistics, outputs are
    # de-normalised with the target statistics - the same as in training.
    save_enhanced_audio(
        model, test_dir,
        r'C:\Users\enhance\article2_test_wavs',
        device=device, num_samples=5,
        Y_mean=Ym, Y_std=Ys,
        X_mean=Xm, X_std=Xs
    )

    print("✅ Done – check plots & test‑WAVs")

In [None]:
def plot_waveforms(model, val_dir, Xm, Xs, Ym, Ys, device='cpu'):
    """Plot clean, noisy and denoised waveforms for up to three random files.

    Args:
        model: Trained network mapping a ``(batch, 1, D, WINDOW_SIZE)`` window
            of normalised noisy features to one normalised output frame.
        val_dir: Directory with ``*_clean_logpower_*.npy`` files and their
            ``*_noisy_concat_*.npy`` counterparts.
        Xm, Xs: Mean / std used to normalise the network input.
        Ym, Ys: Mean / std used to de-normalise the network output.
        device: Torch device the model lives on.

    Each sampled file fills one row of a 3x3 grid (Clean / Noisy / Denoised);
    only the first 2000 samples of every waveform are drawn.
    """
    files   = [f for f in os.listdir(val_dir)
               if f.endswith('.npy') and '_clean_logpower_' in f]
    samples = random.sample(files, min(3, len(files)))
    max_plot_samples = 2000
    chunk_size = 256  # frames per forward pass; bounds activation memory

    # Inference must use BatchNorm running statistics and disable dropout.
    was_training = model.training
    model.eval()

    plt.figure(figsize=(12,6))

    for idx, cf in enumerate(samples):
        # load clean & noisy log‑power
        C = np.load(os.path.join(val_dir, cf))                  # [129, T]
        N = np.load(os.path.join(val_dir, cf.replace(
                '_clean_logpower_','_noisy_concat_')))
        T = C.shape[1]
        if T < WINDOW_SIZE or N.shape[1] < T:
            # Not even one full context window can be formed.
            print(f"Skipping {cf} - not enough frames to plot")
            continue

        # sliding‑window predict, all windows at once:
        # (D, frames, window) -> (frames, 1, D, window)
        Nn = ((N[:, :T] - Xm) / Xs).astype(np.float32)
        windows = np.lib.stride_tricks.sliding_window_view(
            Nn, WINDOW_SIZE, axis=1
        )
        windows = np.array(windows.transpose(1, 0, 2), order='C')[:, np.newaxis]

        outs = []
        with torch.no_grad():
            for start in range(0, len(windows), chunk_size):
                batch = torch.from_numpy(windows[start:start + chunk_size])
                outs.append(model(batch.to(device)).cpu().numpy())
        Pn = np.concatenate(outs, axis=0).T   # shape = (129, T‑WINDOW+1)

        # Prediction k describes frame k + WINDOW_SIZE - 1 (the last frame of
        # its window), so the frames to fill in are at the START of the clip;
        # padding at the end would shift the output against the phase.
        Pn = np.pad(Pn, ((0,0),(T - Pn.shape[1], 0)), mode='edge')
        P = (Pn * Ys) + Ym                    # back to log‑power domain

        # reconstruct waveforms
        clean_wav   = reconstruct_waveform_from_logpower(C)
        noisy_wav   = reconstruct_waveform_from_logpower(N[:N_BINS])
        noisy_phase = np.angle(librosa.stft(
                          noisy_wav,
                          n_fft=FRAME_SIZE,
                          hop_length=OVERLAP
                       ))[:N_BINS]
        # Magnitude and phase must have the same number of frames.
        frames = min(P.shape[1], noisy_phase.shape[1])
        enhanced_wav = reconstruct_with_noisy_phase(
            P[:N_BINS, :frames], noisy_phase[:, :frames]
        )

        # plot clean, noisy, enhanced (zoomed into the first 2000 samples)
        for col, (wav, title) in enumerate(zip(
            [clean_wav, noisy_wav, enhanced_wav],
            ['Clean','Noisy','Denoised']
        )):
            segment = wav[:max_plot_samples]
            # Per-waveform time axis: the three signals can differ in length.
            t = np.arange(len(segment)) / SAMPLE_RATE
            plt.subplot(3,3, idx*3 + col + 1)
            plt.plot(t, segment)
            plt.title(title)
            plt.ylim(-1,1)

    model.train(was_training)
    plt.tight_layout()
    plt.show()
