<a href="https://colab.research.google.com/github/varshacvenkat-web/Varsha-Venkatapathy-Engineering-Portfolio-/blob/main/article2_ipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import librosa
import soundfile as sf
from pystoi import stoi
from pypesq import pesq   # pip install pystoi pypesq
from torch.utils.data import TensorDataset

# ── Constants ─────────────────────────────────────────────────────────────
# Sampling rate handed to the PESQ and STOI scorers.
SAMPLE_RATE = 8000
# FFT size used for every STFT / Griffin-Lim call in this file.
FRAME_SIZE  = 256
# Passed as `hop_length` to librosa (frame advance in samples).
OVERLAP     = 128
# Subtracted after exp() when converting log-power back to power;
# presumably the offset added before the log at feature-extraction time — TODO confirm.
EPSILON     = 1e-10
# Number of Griffin-Lim iterations.
N_ITER      = 32
# Number of consecutive noisy frames fed to the model per prediction.
WINDOW_SIZE = 9
# One-sided spectrum size for FRAME_SIZE-point FFTs.
N_BINS      = FRAME_SIZE // 2 + 1  # 129

# ── Reconstruction Helpers ────────────────────────────────────────────────
def reconstruct_waveform_from_logpower(lp, n_fft=FRAME_SIZE, hop_length=OVERLAP, n_iter=N_ITER):
    """Turn a log-power spectrogram into a waveform using Griffin-Lim.

    No phase is required: the magnitude is recovered from *lp* and
    ``librosa.griffinlim`` estimates a phase for it.

    Args:
        lp: Log-power spectrogram (assumed [freq_bins, frames] — TODO confirm
            against the feature-extraction code). NaN/inf entries are
            replaced by ``np.nan_to_num`` before conversion.
        n_fft: FFT size forwarded to Griffin-Lim.
        hop_length: Hop size forwarded to Griffin-Lim.
        n_iter: Number of Griffin-Lim iterations.

    Returns:
        The waveform returned by ``librosa.griffinlim``.
    """
    finite_lp = np.nan_to_num(lp)
    # Undo the log, remove the EPSILON offset, and floor at zero so the
    # square root below never sees a negative power.
    linear_power = np.exp(finite_lp) - EPSILON
    magnitude = np.sqrt(np.clip(linear_power, 0, None))
    return librosa.griffinlim(
        magnitude,
        n_iter=n_iter,
        n_fft=n_fft,
        hop_length=hop_length,
    )

def reconstruct_waveform_using_phase(lp, phase, n_fft=FRAME_SIZE, hop_length=OVERLAP):
    """Combine a log-power spectrogram with a given phase and invert it.

    Args:
        lp: Log-power spectrogram (assumed [freq_bins, frames] — TODO confirm).
            NaN/inf entries are replaced by ``np.nan_to_num``.
        phase: Phase in radians, same number of rows as *lp*. If the two
            arrays have a different number of columns, both are trimmed to
            the shorter one.
        n_fft: FFT size forwarded to ``librosa.istft``. Previously this
            argument was accepted but ignored; it is now honoured so the
            function agrees with ``reconstruct_waveform_from_logpower``.
        hop_length: Hop size forwarded to ``librosa.istft``.

    Returns:
        The time-domain signal returned by ``librosa.istft``.
    """
    lp    = np.nan_to_num(lp)
    # Undo the log, remove the EPSILON offset, floor at zero before sqrt.
    power = np.exp(lp) - EPSILON
    mag   = np.sqrt(np.maximum(power, 0))
    # Use only the frames present in both magnitude and phase.
    Tc    = min(mag.shape[1], phase.shape[1])
    S     = mag[:, :Tc] * np.exp(1j * phase[:, :Tc])
    return librosa.istft(S, n_fft=n_fft, hop_length=hop_length)

# ── Data Loader (same as training) ─────────────────────────────────────────
def load_article2_dataset_from_merged(merged_dir, window_size=WINDOW_SIZE):
    """Build (noisy context window, clean frame) training pairs from .npy files.

    Walks *merged_dir* recursively. Every ``*_clean_logpower_*.npy`` file is
    paired with the file obtained by replacing ``_clean_logpower_`` with
    ``_noisy_concat_`` in its full path. Pairs whose noisy file is missing,
    or whose second dimension differs, are skipped.

    For each frame index ``i >= window_size - 1`` one sample is produced:
    the noisy columns ``i-window_size+1 .. i`` as input and clean column
    ``i`` as target.

    Args:
        merged_dir: Root directory to search.
        window_size: Number of noisy frames per input sample.

    Returns:
        ``(X, Y)`` as float32 arrays; X has a trailing singleton axis added
        (``[N, 129, 9, 1]`` / ``[N, 129]`` for the file's default settings).
    """
    clean_tag = "_clean_logpower_"
    noisy_tag = "_noisy_concat_"
    contexts = []
    targets = []

    for folder, _subdirs, names in os.walk(merged_dir):
        for name in names:
            if clean_tag not in name or not name.endswith(".npy"):
                continue
            clean_path = os.path.join(folder, name)
            noisy_path = clean_path.replace(clean_tag, noisy_tag)
            if not os.path.exists(noisy_path):
                continue

            clean = np.load(clean_path)
            noisy = np.load(noisy_path)
            n_frames = clean.shape[1]
            if n_frames != noisy.shape[1]:
                continue

            # `end` is the index of the newest frame in each context window.
            window_ends = range(window_size - 1, n_frames)
            contexts.extend(
                noisy[:, end - window_size + 1:end + 1] for end in window_ends
            )
            targets.extend(clean[:, end] for end in window_ends)

    X = np.array(contexts, dtype=np.float32)[..., np.newaxis]  # [N,129,9,1]
    Y = np.array(targets, dtype=np.float32)                    # [N,129]
    return X, Y

def normalize_data(X, Y):
    """Standardize X and Y with their own global (scalar) mean and std.

    Args:
        X: Input array.
        Y: Target array.

    Returns:
        ``(X_norm, Y_norm, X_mean, X_std, Y_mean, Y_std)`` — the two
        standardized arrays followed by the four scalar statistics, so the
        caller can apply or invert the same transform later.
    """
    x_mean = X.mean()
    x_std = X.std()
    y_mean = Y.mean()
    y_std = Y.std()

    X_norm = (X - x_mean) / x_std
    Y_norm = (Y - y_mean) / y_std
    return X_norm, Y_norm, x_mean, x_std, y_mean, y_std

# ── Model Definition (exactly your train‐time architecture) ────────────────
class Article2CNN(nn.Module):
    """Two-conv-layer CNN mapping a noisy context window to one clean frame.

    Input is ``[batch, 1, N_BINS, WINDOW_SIZE]``; output is
    ``[batch, N_BINS]``. Both convolutions use (5, 1) kernels, i.e. they
    slide along dimension 2 only; the second one uses stride 3 on that
    dimension. The flattened conv features go through a 1024-unit hidden
    layer with ReLU and dropout (p=0.2) before the output layer.

    Layer attribute names (conv1, conv2, fc1, fc2, ...) determine the
    state-dict keys, so they must stay in sync with the saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        # Convolutional front end.
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=129,
            kernel_size=(5, 1), padding=(2, 0),
        )
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            in_channels=129, out_channels=43,
            kernel_size=(5, 1), stride=(3, 1), padding=(2, 0),
        )
        self.relu2 = nn.ReLU()

        # Probe the conv stack with a zero input of the expected shape to
        # learn how many features the first linear layer receives.
        with torch.no_grad():
            probe = torch.zeros(1, 1, N_BINS, WINDOW_SIZE)
            flat = self._conv_features(probe).numel()

        # Fully connected head.
        self.fc1   = nn.Linear(flat, 1024)
        self.relu3 = nn.ReLU()
        self.drop  = nn.Dropout(0.2)
        self.fc2   = nn.Linear(1024, N_BINS)

    def _conv_features(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU."""
        x = self.relu1(self.conv1(x))
        return self.relu2(self.conv2(x))

    def forward(self, x):
        """Return the predicted frame for each input window in the batch."""
        feats = self._conv_features(x)
        feats = feats.view(feats.size(0), -1)
        hidden = self.relu3(self.fc1(feats))
        hidden = self.drop(hidden)
        return self.fc2(hidden)

# ── Validation Routine ────────────────────────────────────────────────────
def _predict_logpower(model, noisy_lp, x_mean, x_std, device, batch_size=64):
    """Run the model on every full context window of *noisy_lp*.

    Windows are normalized with ``(w - x_mean) / x_std`` and fed to the
    model in chunks of *batch_size* so activations stay small.

    Returns an array of shape ``[bins, T - WINDOW_SIZE + 1]`` whose column
    ``k`` is the (still normalized) prediction for frame
    ``k + WINDOW_SIZE - 1`` of *noisy_lp*.
    """
    n_frames = noisy_lp.shape[1]
    window_ends = range(WINDOW_SIZE - 1, n_frames)
    outputs = []
    with torch.no_grad():
        for start in range(0, len(window_ends), batch_size):
            ends = window_ends[start:start + batch_size]
            wins = np.stack([noisy_lp[:, e - WINDOW_SIZE + 1:e + 1] for e in ends])
            wins = ((wins - x_mean) / x_std).astype(np.float32)
            batch = torch.from_numpy(wins).unsqueeze(1).to(device)  # [B,1,bins,W]
            outputs.append(model(batch).cpu().numpy())               # [B,bins]
    return np.concatenate(outputs, axis=0).T


def evaluate_on_validation(model, train_dir, val_dir, device="cpu"):
    """Score the model on a validation folder with PESQ and STOI.

    Steps:
      1. Reload the training pairs from *train_dir*, draw a seeded subset of
         at most 100,000 samples, and compute scalar mean/std statistics.
      2. For every ``*_clean_logpower_*.npy`` file directly inside *val_dir*
         that has a matching ``_noisy_concat_`` file, predict an enhanced
         log-power spectrogram from the noisy one.
      3. Resynthesize the enhanced waveform with a phase taken from a
         Griffin-Lim reconstruction of the noisy spectrogram, and compare it
         against a Griffin-Lim reconstruction of the clean spectrogram.
      4. Print the number of utterances and the NaN-ignoring mean scores.

    Args:
        model: Network taking ``[B, 1, N_BINS, WINDOW_SIZE]`` tensors; it is
            switched to eval mode and must already live on *device*.
        train_dir: Directory passed to ``load_article2_dataset_from_merged``.
        val_dir: Directory listed (non-recursively) for validation files.
        device: Torch device the input tensors are moved to.

    Raises:
        ValueError: If no training pairs are found, since normalization
            statistics cannot be computed from an empty set.
    """
    model.eval()

    # 1) Re-compute train-time normalization stats *exactly* as during training
    X_tr, Y_tr = load_article2_dataset_from_merged(train_dir)
    if len(X_tr) == 0:
        raise ValueError(
            f"No training pairs found under {train_dir!r}; "
            "cannot compute normalization statistics."
        )
    idxs = np.random.default_rng(42).choice(len(X_tr),
                                            size=min(100_000, len(X_tr)),
                                            replace=False)
    _, _, Xm, Xs, Ym, Ys = normalize_data(X_tr[idxs], Y_tr[idxs])
    # The full training set is only needed for the statistics.
    del X_tr, Y_tr

    pesq_scores = []
    stoi_scores = []

    # 2) Loop over the validation files
    for fname in os.listdir(val_dir):
        if not fname.endswith(".npy") or "_clean_logpower_" not in fname:
            continue
        cpath = os.path.join(val_dir, fname)
        npath = cpath.replace("_clean_logpower_", "_noisy_concat_")
        if not os.path.exists(npath):
            continue

        C = np.load(cpath)   # [129, T]
        N = np.load(npath)   # [129, T]
        # Same consistency rule as the training loader: frame counts must match.
        if C.shape[1] != N.shape[1]:
            continue
        T = C.shape[1]
        if T < WINDOW_SIZE:
            continue

        # Reconstruct the "ground-truth" clean waveform for scoring.
        clean_wav = reconstruct_waveform_from_logpower(C)

        # 3) Model inference over all full context windows.
        Pn = _predict_logpower(model, N, Xm, Xs, device)   # [129, T-(W-1)]
        # The first W-1 frames have no full context, so the gap is at the
        # START. Pad on the left so column i really is the prediction for
        # frame i (padding on the right would shift everything W-1 frames early).
        Pn = np.pad(Pn, ((0, 0), (T - Pn.shape[1], 0)), mode="edge")

        # 4) De-normalize.
        P = Pn * Ys + Ym    # [129, T]

        # 5) Phase from a Griffin-Lim reconstruction of the noisy spectrogram.
        # Use librosa's default centring here: griffinlim and istft both use
        # it, so this yields T frames aligned one-to-one with the columns of
        # P. (center=False gave T-2 frames offset by one frame.)
        temp_noisy = reconstruct_waveform_from_logpower(N)
        S_noisy    = librosa.stft(temp_noisy,
                                  n_fft=FRAME_SIZE,
                                  hop_length=OVERLAP)
        phase_noisy = np.angle(S_noisy)

        # 6) Reconstruct the enhanced waveform.
        enhanced = reconstruct_waveform_using_phase(P, phase_noisy[:, :T])

        # 7) Score on a common length; STOI rejects signals of unequal length.
        n_samples = min(len(clean_wav), len(enhanced))
        ref = clean_wav[:n_samples]
        deg = enhanced[:n_samples]

        # NOTE(review): the (ref, deg, fs, 'nb') argument order is kept from
        # the original call — confirm it matches the installed pypesq API.
        try:
            p = pesq(ref, deg, SAMPLE_RATE, 'nb')
        except Exception as exc:
            print(f"[warn] PESQ failed for {fname}: {exc}")
            p = np.nan
        try:
            s = stoi(ref, deg, SAMPLE_RATE, extended=False)
        except Exception as exc:
            print(f"[warn] STOI failed for {fname}: {exc}")
            s = np.nan

        pesq_scores.append(p)
        stoi_scores.append(s)

    # Final averages.
    print(f"\nValidated on {len(pesq_scores)} utterances.")
    if not pesq_scores:
        print(f"No usable clean/noisy pairs found in {val_dir!r}.")
        return
    print(f" → PESQ = {np.nanmean(pesq_scores):.3f}")
    print(f" → STOI = {np.nanmean(stoi_scores):.3f}")

# ── Main: load your checkpoint and run ─────────────────────────────────────
if __name__ == "__main__":
    # Locations of the trained checkpoint and the merged feature folders.
    checkpoint_path = r"C:\Users\enhance\article2_model_trained.pth"
    train_dir = r"C:\Users\enhance\article2_merge\train"
    val_dir = r"C:\Users\enhance\article2_merge\val"

    # Use the GPU when one is available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the network and restore its weights.
    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # from a trusted source.
    model = Article2CNN().to(device)
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt)
    print("✅ Loaded trained weights.")

    # Run the validation pass.
    evaluate_on_validation(
        model,
        train_dir=train_dir,
        val_dir=val_dir,
        device=device,
    )
