In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
from scipy.interpolate import CubicSpline
from sklearn.preprocessing import StandardScaler, MinMaxScaler

In [None]:
# ---------------------------
# 1. Define CNN model 
# ---------------------------
class CNN(nn.Module):
    """1-D convolutional regressor producing one scalar per input waveform.

    Architecture: two ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) -> Dropout``
    stages followed by two fully connected layers. In this notebook the scalar
    output is used as the predicted Wave I sample index.

    Args:
        filter1: Number of output channels of the first convolution.
        filter2: Number of output channels of the second convolution.
        dropout1: Dropout probability after the first conv stage.
        dropout2: Dropout probability after the second conv stage.
        dropout_fc: Dropout probability after the first fully connected layer.
        input_length: Number of samples per input waveform. The default of 244
            yields ``filter2 * 61`` flattened features, so layer shapes (and
            therefore saved ``state_dict`` files) are unchanged.

    Input shape is ``(batch, 1, input_length)``; output shape is ``(batch, 1)``.
    """

    def __init__(self, filter1=128, filter2=32, dropout1=0.5, dropout2=0.3,
                 dropout_fc=0.1, input_length=244):
        super().__init__()
        # Each of the two stride-2 poolings halves the length (floor division).
        pooled_length = input_length // 4
        if pooled_length < 1:
            raise ValueError(
                f"input_length must be at least 4 samples, got {input_length}"
            )

        # Attribute names and registration order are kept so existing checkpoints load.
        self.conv1 = nn.Conv1d(1, filter1, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(filter1, filter2, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(filter2 * pooled_length, 128)
        self.fc2 = nn.Linear(128, 1)
        self.dropout1 = nn.Dropout(dropout1)
        self.dropout2 = nn.Dropout(dropout2)
        self.dropout_fc = nn.Dropout(dropout_fc)
        self.batch_norm1 = nn.BatchNorm1d(filter1)
        self.batch_norm2 = nn.BatchNorm1d(filter2)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(batch, 1, input_length)``.

        Raises:
            ValueError: If the flattened feature count does not match ``fc1``,
                i.e. the waveform length differs from the one the model was
                built for.
        """
        x = self.pool(torch.relu(self.batch_norm1(self.conv1(x))))
        x = self.dropout1(x)
        x = self.pool(torch.relu(self.batch_norm2(self.conv2(x))))
        x = self.dropout2(x)

        # Flatten per sample. The previous `view(-1, in_features)` could silently
        # fold a wrong-length waveform into extra batch rows instead of failing.
        x = x.flatten(start_dim=1)
        if x.size(1) != self.fc1.in_features:
            raise ValueError(
                f"Expected {self.fc1.in_features} flattened features per waveform, "
                f"got {x.size(1)}; check the input waveform length."
            )

        x = torch.relu(self.fc1(x))
        x = self.dropout_fc(x)
        return self.fc2(x)

In [20]:
# ---------------------------
# 2. Load trained peak-finding model
# ---------------------------
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network with its default hyperparameters, move it to the chosen
# device, then restore the saved weights (remapped onto that same device).
peak_finding_model = CNN()
peak_finding_model.to(device)
peak_finding_model.load_state_dict(
    torch.load("../models/waveI_CNN.pth", map_location=device)
)
# Inference mode: disables dropout and makes batch norm use its running statistics.
peak_finding_model.eval()

CNN(
  (conv1): Conv1d(1, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(128, 32, kernel_size=(3,), stride=(1,), padding=(1,))
  (fc1): Linear(in_features=1952, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=1, bias=True)
  (dropout1): Dropout(p=0.5, inplace=False)
  (dropout2): Dropout(p=0.3, inplace=False)
  (dropout_fc): Dropout(p=0.1, inplace=False)
  (batch_norm1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batch_norm2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [None]:
# ---------------------------
# 3. Resample and normalize waveform before inference
# ---------------------------
def interpolate_and_smooth(wave, target_points=244, total_ms=10.0):
    """
    Resample a waveform onto ``target_points`` evenly spaced samples.

    Both the original and the resampled time axes run from 0 to ``total_ms``.
    Waveforms shorter than ``target_points`` are upsampled with a cubic spline;
    all others (including equal length) go through linear interpolation.
    Despite the name, no smoothing is applied here.

    Args:
        wave: 1-D sequence of samples.
        target_points: Number of samples in the returned waveform.
        total_ms: Duration assigned to the waveform, in milliseconds.

    Returns:
        numpy.ndarray of length ``target_points``.
    """
    n_samples = len(wave)
    source_times = np.linspace(0, total_ms, n_samples)
    resampled_times = np.linspace(0, total_ms, target_points)

    # Enough (or more than enough) samples: linear interpolation.
    if n_samples >= target_points:
        return np.interp(resampled_times, source_times, wave)

    # Too few samples: fit a cubic spline and evaluate it on the denser grid.
    spline = CubicSpline(source_times, wave)
    return spline(resampled_times)

def normalize_waveform(wave):
    """Z-score a 1-D waveform, then rescale the result into the range [0, 1].

    Args:
        wave: 1-D numpy array of samples.

    Returns:
        1-D numpy array with the same number of samples as ``wave``.
    """
    # The scikit-learn scalers work on 2-D (n_samples, n_features) input, so the
    # waveform is handled as a single feature column and flattened afterwards.
    as_column = wave.reshape(-1, 1)
    standardized = StandardScaler().fit_transform(as_column).flatten()

    unit_range_scaler = MinMaxScaler(feature_range=(0, 1))
    rescaled = unit_range_scaler.fit_transform(standardized.reshape(-1, 1))
    return rescaled.flatten()

In [None]:
# ---------------------------
# 4. Peak finding with normalization + smoothing
# ---------------------------
def peak_finding(wave):
    """Locate Wave I plus up to four later peaks, and the trough after each peak.

    Steps:
      1. Normalize ``wave`` with ``normalize_waveform`` and feed it to
         ``peak_finding_model`` as a ``(1, 1, len(wave))`` tensor; the rounded
         output is the predicted Wave I sample index.
      2. Gaussian-smooth the original (unnormalized) waveform.
      3. Pick, among the smoothed peaks nearest the prediction, the tallest one
         as Wave I. The search window starts at 0.25 ms and grows by 0.25 ms
         until it contains a peak (samples-per-ms assumes the trace spans 10 ms).
      4. Add up to four of the tallest peaks that come after Wave I.
      5. For each selected peak, take the first trough after it and before the
         next selected peak (for the last peak: the first trough after it).

    Args:
        wave: 1-D sequence of voltage samples. Its length must match what
            ``peak_finding_model`` expects.

    Returns:
        tuple ``(peaks, troughs)`` of integer numpy arrays holding sample
        indices into ``wave``. ``peaks`` is sorted ascending with Wave I first.
        ``troughs`` may be shorter than ``peaks`` when a peak has no trough
        before the next peak. Both arrays are empty when the smoothed waveform
        has no peak in the search region.
    """
    # Work in float: gaussian_filter1d returns the input dtype, so integer
    # samples would otherwise be smoothed with truncation.
    wave = np.asarray(wave, dtype=float)

    # Normalize before feeding to CNN
    norm_wave = normalize_waveform(wave)
    waveform_torch = (
        torch.tensor(norm_wave, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
    )

    # CNN prediction for wave 1 index (no gradients needed for inference)
    with torch.no_grad():
        outputs = peak_finding_model(waveform_torch)
    prediction = int(round(outputs[0, 0].item()))
    # Keep the prediction inside the trace so the widening search below stays
    # bounded even if the model extrapolates outside the valid index range.
    prediction = min(max(prediction, 0), len(wave) - 1)

    # Gaussian smoothing on original (unnormalized) waveform
    smoothed_waveform = gaussian_filter1d(wave, sigma=1.0)

    # Find peaks from shortly before the prediction onward, and troughs everywhere
    peak_distance = 16    # minimum spacing between detected peaks, in samples
    trough_distance = 7   # minimum spacing between detected troughs, in samples
    window = 10           # samples before the prediction that are still searched
    start_point = max(0, prediction - window)
    smoothed_peaks, _ = find_peaks(smoothed_waveform[start_point:], distance=peak_distance)
    smoothed_troughs, _ = find_peaks(-smoothed_waveform, distance=trough_distance)

    # Shift peak indices back to positions in the full waveform.
    candidate_peaks = smoothed_peaks + start_point
    if candidate_peaks.size == 0:
        # Without this guard the widening loop below would never terminate.
        return np.array([], dtype=int), np.array([], dtype=int)

    # Widen the window around the prediction until it holds at least one peak.
    peaks_within_ms = np.array([])
    ms_cutoff = 0.25
    while len(peaks_within_ms) == 0:
        ms_window = int(ms_cutoff * len(smoothed_waveform) / 10)
        within_ms_mask = np.abs(candidate_peaks - prediction) <= ms_window
        peaks_within_ms = candidate_peaks[within_ms_mask]
        ms_cutoff += 0.25
    tallest_peak_idx = np.argmax(smoothed_waveform[peaks_within_ms])
    pk1 = peaks_within_ms[tallest_peak_idx]

    # Up to four tallest peaks after Wave I, then everything back in time order.
    later_peaks = candidate_peaks[candidate_peaks > pk1]
    by_height = np.argsort(smoothed_waveform[later_peaks])
    highest_smoothed_peaks = np.sort(np.concatenate(
        ([pk1], later_peaks[by_height[-min(4, later_peaks.size):]])
    ))

    # First trough after each peak. For every peak but the last, the trough must
    # also come before the next selected peak. The last peak is identified by
    # position rather than by a fixed index of 4, so it is handled correctly
    # when fewer than five peaks were found.
    relevant_troughs = []
    last_position = len(highest_smoothed_peaks) - 1
    for position, peak in enumerate(highest_smoothed_peaks):
        following = smoothed_troughs[smoothed_troughs > peak]
        if position < last_position:
            following = following[following < highest_smoothed_peaks[position + 1]]
        if following.size:
            relevant_troughs.append(int(following[0]))

    return highest_smoothed_peaks, np.array(relevant_troughs, dtype=int)

In [None]:
# ---------------------------
# 5. Convert peaks → latency & amplitude (Wave I only)
# ---------------------------
def extract_latency_amplitude(wave):
    """Compute Wave I latency and peak-to-trough amplitude for one waveform.

    The waveform is resampled to 244 points with ``interpolate_and_smooth``
    (a 244-sample input passes through with its values unchanged) and handed to
    ``peak_finding``. The first returned peak and the first returned trough are
    used.

    Args:
        wave: 1-D sequence of voltage samples. It is assumed to span 10 ms,
            which is what the resampling and latency conversion rely on.

    Returns:
        dict with keys ``"peak_idx"``, ``"trough_idx"``, ``"latency_ms"`` and
        ``"amplitude_uV"``. All four values are ``None`` when the waveform has
        fewer than two samples or when no peak or no trough was found. Indices
        refer to the 244-point resampled waveform.
    """
    no_result = {
        "peak_idx": None,
        "trough_idx": None,
        "latency_ms": None,
        "amplitude_uV": None
    }

    wave = np.asarray(wave, dtype=float)
    if wave.size < 2:
        # Nothing to resample or search in an empty or single-sample trace.
        return no_result

    # peak_finding returns exactly (peaks, troughs); the resampled waveform used
    # for the amplitude is produced here.
    interp_wave = interpolate_and_smooth(wave)
    peaks, troughs = peak_finding(interp_wave)

    if len(peaks) == 0 or len(troughs) == 0:
        return no_result

    # Take the FIRST peak and the first trough reported by peak_finding.
    p = int(peaks[0])
    t = int(troughs[0])

    latency_ms = p * (10.0 / 244)   # 10 ms / 244 samples
    amplitude_uV = interp_wave[p] - interp_wave[t]

    return {
        "peak_idx": p,
        "trough_idx": t,
        "latency_ms": latency_ms,
        "amplitude_uV": amplitude_uV
    }


In [None]:
# ---------------------------
# 6. Batch process folder
# ---------------------------
def batch_process(input_dir, output_csv="peak_results.csv"):
    """Run Wave I extraction on every ``.csv`` file in a folder and save a table.

    For each file, the first 244 values of its ``"Voltage"`` column are passed to
    ``extract_latency_amplitude``. The resulting dict, plus a ``"file"`` entry
    holding the file name, becomes one row of the output CSV.

    Args:
        input_dir: Folder containing the waveform CSV files.
        output_csv: Path of the CSV file to write.

    Returns:
        None. The results are written to ``output_csv`` and a confirmation line
        is printed.
    """
    output_path = os.path.abspath(output_csv)
    all_results = []
    # Sorted so the row order does not depend on the filesystem's listing order.
    for fname in sorted(os.listdir(input_dir)):
        if not fname.endswith(".csv"):
            continue
        path = os.path.join(input_dir, fname)
        if os.path.abspath(path) == output_path:
            # A previous run's results file is not a waveform; skip it.
            continue
        df = pd.read_csv(path)
        wave = df["Voltage"].values[:244]
        res = extract_latency_amplitude(wave)
        # `res` is a single dict (one row per file), not a list of dicts.
        all_results.append({**res, "file": fname})
    pd.DataFrame(all_results).to_csv(output_csv, index=False)
    print(f"Saved results to {output_csv}")

In [None]:

# ---------------------------
# 7. Optional visualization
# ---------------------------
def plot_example(fname):
    """Plot one waveform file and mark the detected Wave I peak and trough.

    Args:
        fname: Path to a CSV file with ``"Time"`` and ``"Voltage"`` columns.
            Only the first 244 rows are used.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    df = pd.read_csv(fname)
    wave = df["Voltage"].values[:244]
    # `res` is a single dict; its values are None when nothing was detected.
    res = extract_latency_amplitude(wave)

    # Same sample-to-millisecond conversion that latency_ms is computed with.
    ms_per_sample = 10.0 / 244

    fig, ax = plt.subplots()
    ax.plot(df["Time"][:244], wave, label="ABR waveform")

    peak_idx = res["peak_idx"]
    trough_idx = res["trough_idx"]
    # Markers are drawn only for indices that exist in the plotted samples.
    if peak_idx is not None and peak_idx < len(wave):
        ax.scatter(res["latency_ms"], wave[peak_idx], c="r", label="Peak")
    if trough_idx is not None and trough_idx < len(wave):
        # Place the trough at its own time, not at the peak's latency.
        ax.scatter(trough_idx * ms_per_sample, wave[trough_idx], c="b", label="Trough")

    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("Voltage (µV)")
    ax.set_title(f"{fname} | Peaks & Troughs")
    ax.legend()
    plt.show()