In [1]:
# ABRA_batch_threshold_inference.ipynb
# Minimal dependencies: tensorflow/keras, numpy, pandas, sklearn, scipy

import os
import numpy as np
import pandas as pd
from scipy.interpolate import CubicSpline
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from tensorflow.keras.models import load_model

In [None]:

# ---------------------------
# 1. Interpolation function
# ---------------------------
def interpolate_and_smooth(final, target_length=244):
    """
    Resample a 1-D waveform to exactly `target_length` samples.

    - Longer inputs: linearly interpolated onto `target_length + 2` evenly spaced
      positions spanning [0, len(final)], keeping only the first `target_length`.
    - Shorter inputs: upsampled with a cubic spline evaluated at `target_length`
      evenly spaced positions spanning [0, len(final) - 1] (CubicSpline needs at
      least two samples).
    - Inputs already of length `target_length`: returned as-is, converted to float.

    Returns a new float ndarray; the input is not modified.
    """
    n_in = len(final)

    if n_in == target_length:
        return np.array(final, dtype=float)

    if n_in < target_length:
        spline = CubicSpline(np.arange(n_in), final)
        resampled = spline(np.linspace(0, n_in - 1, target_length))
    else:
        # NOTE(review): the grid ends at n_in, one past the last index, where np.interp
        # clamps to the final sample. The last two grid points are also discarded, so the
        # kept points do not span the whole waveform. Presumably this matches the model's
        # training preprocessing — confirm before changing.
        grid = np.linspace(0, n_in, target_length + 2)
        resampled = np.interp(grid, np.arange(n_in), final)[:target_length]

    return np.array(resampled, dtype=float)

In [None]:

# ---------------------------
# 2. Thresholding function
# ---------------------------
# Loaded Keras models keyed by file path. The thresholding function may be called once per
# frequency, so caching avoids re-reading the model from disk on every call.
_THRESHOLDING_MODEL_CACHE = {}

# Each waveform is truncated to, then resampled to, this many samples before prediction.
_WAVEFORM_LENGTH = 244


def _get_thresholding_model(model_path):
    """Return the Keras model stored at `model_path`, loading it only on first use."""
    model = _THRESHOLDING_MODEL_CACHE.get(model_path)
    if model is None:
        model = load_model(model_path)
        model.steps_per_execution = 1
        _THRESHOLDING_MODEL_CACHE[model_path] = model
    return model


def _normalize_waves(waves):
    """Z-score all samples of all waves jointly, then min-max scale them into [0, 1]."""
    flat = waves.reshape(-1, 1)
    zscored = StandardScaler().fit_transform(flat)
    scaled = MinMaxScaler(feature_range=(0, 1)).fit_transform(zscored)
    return scaled.reshape(waves.shape)


def _lowest_responding_level(y_pred, db_levels):
    """
    Walk `db_levels` in order alongside the binary predictions `y_pred`.

    Returns the last level predicted as a response. The walk stops at the first two
    consecutive non-responses; a single isolated non-response is tolerated. If no level
    is predicted as a response, `db_levels[0]` is returned.
    """
    lowest_db = db_levels[0]
    previous_prediction = None
    for prediction, level in zip(y_pred, db_levels):
        if prediction == 0:
            if previous_prediction == 0:
                break
            previous_prediction = 0
        else:
            lowest_db = level
            previous_prediction = 1
    return lowest_db


def calculate_hearing_threshold(df, freq, level_col="Level(dB)", units="Microvolts", multiply_y_factor=1,
                                calibration_levels=None, model_path="../models/abr_thresholding.keras"):
    """
    Run the thresholding model on one frequency's ABRs and return the estimated threshold.

    Parameters
    ----------
    df : pandas.DataFrame
        ABR data with a "Freq(Hz)" column, the `level_col` column, and waveform samples
        in the columns from "0" onward. Every column after "0" is treated as waveform data.
    freq :
        Value of "Freq(Hz)" to analyze.
    level_col : str
        "Level(dB)" (levels walked loudest first, i.e. descending) or "PostAtten(dB)"
        (attenuations walked ascending).
    units : str
        "Nanovolts" divides samples by 1000. Any other value leaves samples unscaled
        (treated as microvolts).
    multiply_y_factor : float
        Optional scale applied to every resampled waveform.
    calibration_levels : dict or None
        Only used with "PostAtten(dB)". Maps `(df.name, freq)` to a calibration level.
        Returned levels are then `calibration - attenuation`.
    model_path : str
        Path of the Keras thresholding model. The model is loaded once per path and cached.

    Returns
    -------
    The last level, in walk order, at which the model predicts a response. The walk
    stops at two consecutive non-responses. If no response is found, the first walked
    level is returned.

    Raises
    ------
    ValueError
        If no waveform exists for `freq`.
    """
    thresholding_model = _get_thresholding_model(model_path)

    df_filtered = df[df["Freq(Hz)"] == freq]

    # NaN levels are dropped. Previously they stayed in db_levels without producing a
    # wave, which misaligned levels and predictions.
    db_levels = sorted(df_filtered[level_col].dropna().unique(), reverse=(level_col == "Level(dB)"))

    waves = []
    for db in db_levels:
        # db comes from this column's own values, so the exact match is never empty.
        # One wave is appended per level, keeping `waves` index-aligned with `db_levels`.
        rows = df_filtered[df_filtered[level_col] == db]
        # If several recordings share this level, use the last one (selected positionally,
        # so duplicate index labels cannot yield a DataFrame).
        # Coerce before dropping NaNs so non-numeric cells do not leak NaN into the wave.
        samples = pd.to_numeric(rows.iloc[-1].loc["0":], errors="coerce").dropna()
        final = interpolate_and_smooth(
            samples.to_numpy(dtype=np.float64)[:_WAVEFORM_LENGTH],
            target_length=_WAVEFORM_LENGTH,
        )
        final *= multiply_y_factor
        if units == "Nanovolts":
            final /= 1000
        waves.append(final)

    if not waves:
        raise ValueError(f"No ABR waveforms found for Freq(Hz) == {freq!r}")

    # Model input has shape (n_levels, _WAVEFORM_LENGTH, 1).
    model_input = np.expand_dims(_normalize_waves(np.array(waves)), axis=2)
    prediction = thresholding_model.predict(model_input, verbose=0)
    y_pred = (prediction > 0.5).astype(int).flatten()

    if level_col == "PostAtten(dB)" and calibration_levels is not None:
        # Express levels as calibration level minus attenuation.
        # NOTE(review): df.name is set on groups passed by groupby(...).apply. On a plain
        # DataFrame it resolves to a column named "name" or raises AttributeError — confirm callers.
        db_levels = calibration_levels[(df.name, freq)] - np.asarray(db_levels)

    return _lowest_responding_level(y_pred, db_levels)


In [None]:

# ---------------------------
# 3. Batch process function
# ---------------------------
def batch_thresholds(df, freqs, level_col="Level(dB)", **threshold_kwargs):
    """
    Estimate a hearing threshold for each frequency in `freqs`.

    Extra keyword arguments (e.g. `units`, `multiply_y_factor`, `calibration_levels`)
    are forwarded unchanged to `calculate_hearing_threshold`, so PostAtten calibration
    and nanovolt data can also be processed in batch.

    Returns a DataFrame with one row per entry of `freqs`, in the same order, and
    columns "frequency" and "threshold_dB". The columns are present even when `freqs`
    is empty.
    """
    records = [
        {
            "frequency": f,
            "threshold_dB": calculate_hearing_threshold(df, f, level_col=level_col, **threshold_kwargs),
        }
        for f in freqs
    ]
    return pd.DataFrame(records, columns=["frequency", "threshold_dB"])