# ECG Digitizer — Colab Inference Backend

> Convert a scanned or photographed 12-lead ECG into structured time-series data.  
> Upload an image → get back 12 calibrated voltage traces in mV, ready for analysis or EHR export.

---

## How it works

The pipeline is built on top of [hengck23's PhysioNet 2024 solution](https://www.kaggle.com/code/hengck23/demo-submission), extended with a serving layer and a lead-localization module. Three networks run in sequence:

```
Input image
    │
    ▼
[Stage 0]  ResNet keypoint detector
           Finds the four corners of the ECG paper → homography warp
    │
    ▼
[Stage 1]  Grid alignment network
           Detects the mm-grid intersections → rectifies any remaining distortion
    │
    ▼
[Stage 2]  ResNet34-UNet (Net3)
           Pixel-level segmentation of 4 simultaneous waveform rows
           → peak tracking → mV conversion → Savitzky-Golay smoothing
    │
    ▼
12-lead signals  +  per-lead bounding boxes  +  step-by-step images
```

The key insight from the original competition solution: rather than detecting each lead separately, Stage 2 predicts **4 probability maps simultaneously**, one per printed row. The first three rows each hold 4 leads side by side and the fourth is a full-length Lead II rhythm strip, so a single forward pass recovers all 12 leads.
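A minimal NumPy sketch of how a 4-row output becomes 12 named leads, following the standard print layout that `series_dict` (section 4) assumes. The `(4, 5000)` array is random stand-in data, not real Stage 2 output, and unlike `series_dict` (which keeps the short segment as `II_short`) this sketch simply lets the rhythm strip overwrite it.

```python
import numpy as np

# Standard printed layout: 3 rows x 4 columns, plus a 4th row that holds
# a full-length Lead II rhythm strip.
ROW_LEADS = [
    ['I',   'aVR', 'V1', 'V4'],
    ['II',  'aVL', 'V2', 'V5'],
    ['III', 'aVF', 'V3', 'V6'],
]

def split_rows(series_4row):
    """Split a (4, N) array of row traces into per-lead arrays."""
    leads = {}
    for row_names, row in zip(ROW_LEADS, series_4row[:3]):
        # each printed row spans 10 s as 4 consecutive 2.5 s segments, one lead each
        for name, seg in zip(row_names, np.array_split(row, 4)):
            leads[name] = seg
    leads['II'] = series_4row[3]   # rhythm strip replaces the short Lead II segment
    return leads

rng = np.random.default_rng(0)
leads = split_rows(rng.normal(size=(4, 5000)))
print(len(leads), len(leads['I']), len(leads['II']))
```

With 5000 samples per row this yields 12 leads: 11 short segments of 1250 samples and one 5000-sample Lead II.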

---

## Before you start

1. **GPU runtime** — Runtime → Change runtime type → T4 GPU (or A100 if available)
2. **`kaggle.json`** — download from kaggle.com → Settings → API → Create New Token
3. **ngrok authtoken** — free account at ngrok.com → Dashboard → Your Authtoken

Run the cells top to bottom. The last cell starts the server and prints the public URL.


## 1 — Install dependencies

`connected-components-3d` is needed by the Stage 2 post-processing step to trace waveform centerlines.  
Everything else is the standard serving and inference stack (PyTorch, OpenCV and SciPy come preinstalled on Colab).


In [None]:
!pip install -q fastapi uvicorn pyngrok python-multipart nest-asyncio
!pip install -q timm connected-components-3d kaggle
print('✅ done')


## 2 — Download model weights

The three checkpoint files (~400 MB total) are hosted as a Kaggle dataset that mirrors hengck23's original submission.  
Upload your `kaggle.json` when prompted. The cell below copies it into `/root/.kaggle/` and restricts its permissions to owner-only (`600`), as the Kaggle CLI expects.


In [None]:
import os
from google.colab import files

print('Upload kaggle.json  (kaggle.com → Settings → API → Create New Token)')
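uploaded = files.upload()
if not uploaded:
    raise RuntimeError('No file uploaded; re-run this cell and select kaggle.json')

# Write the uploaded bytes directly instead of copying by name: a browser-downloaded
# token may be saved as e.g. "kaggle (1).json", which would break `cp kaggle.json`.
os.makedirs('/root/.kaggle', exist_ok=True)
with open('/root/.kaggle/kaggle.json', 'wb') as f:
    f.write(next(iter(uploaded.values())))
os.chmod('/root/.kaggle/kaggle.json', 0o600)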
print('✅ Kaggle API configured')


In [None]:
WEIGHTS_DIR = '/content/weights'
os.makedirs(WEIGHTS_DIR, exist_ok=True)

print('Downloading weights (~5–10 min)...')
!kaggle datasets download -d kami1976/hengck23-submit-physionet -p {WEIGHTS_DIR} --unzip -q

HENGCK_DIR = f'{WEIGHTS_DIR}/hengck23-submit-physionet'
WEIGHT_DIR = f'{HENGCK_DIR}/weight'

print('\nVerifying:')
all_ok = True
for fname in [
    'stage0-last.checkpoint.pth',
    'stage1-last.checkpoint.pth',
    'stage2-00005810.checkpoint.pth',
    '../stage0_common.py',
]:
    path = f'{WEIGHT_DIR}/{fname}'
    ok = os.path.exists(path)
    all_ok &= ok   # report success only if every file is present
    print(f'  {"✅" if ok else "❌"}  {fname}')

print('\n✅ All files present, proceed to section 3' if all_ok
      else '\n❌ Some files are missing; the download likely failed, re-run this cell')


## 3 — Configure paths and imports

All weight paths are set once here and reused across later cells.  
`sys.path` is updated so the helper modules bundled inside the Kaggle dataset (`stage0_common.py`, etc.) are importable.


In [None]:
import sys, os, gc, cv2, base64
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import timm
from scipy.signal import savgol_filter

WEIGHTS_DIR = '/content/weights'
HENGCK_DIR  = f'{WEIGHTS_DIR}/hengck23-submit-physionet'
WEIGHT_DIR  = f'{HENGCK_DIR}/weight'

STAGE0_W = f'{WEIGHT_DIR}/stage0-last.checkpoint.pth'
STAGE1_W = f'{WEIGHT_DIR}/stage1-last.checkpoint.pth'
STAGE2_W = f'{WEIGHT_DIR}/stage2-00005810.checkpoint.pth'

if HENGCK_DIR not in sys.path:
    sys.path.insert(0, HENGCK_DIR)

import stage0_common as s0c
import stage1_common as s1c
import stage2_common as s2c
from stage0_model import Net as Stage0Net
from stage1_model import Net as Stage1Net
from stage2_model import MyCoordUnetDecoder, encode_with_resnet

LEADS_ORDER = ['I','II','III','aVR','aVL','aVF','V1','V2','V3','V4','V5','V6']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

# quick sanity check
for label, path in [('stage0 weight', STAGE0_W), ('stage1 weight', STAGE1_W), ('stage2 weight', STAGE2_W)]:
    assert os.path.exists(path), f'Missing: {path}'
print('✅ paths OK')


## 4 — Image preprocessing utilities

Real-world ECG photos come from a wide range of sources — clinical scanners, phone cameras, aging paper — so a one-size-fits-all preprocessing step doesn't work well. This section defines a small toolkit of corrections and a dispatch function (`preprocess_by_source`) that picks the right combination based on the image origin.

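**CLAHE on the V channel** is the most broadly useful step. Instead of stretching the global histogram, it equalizes small tiles independently (an 8×8 grid of tiles here) and clips each tile's histogram, which recovers local contrast on faded or unevenly lit paper without blowing out already-bright regions. `change_color` also applies light non-local-means denoising first and scales the clip limit with the channel's standard deviation (between 2.0 and 3.5).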
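To see why tiling matters, here is a simplified NumPy sketch: plain histogram equalization applied per tile, without the clip limit or the bilinear blending between tiles that `cv2.createCLAHE` adds. The half-dark, half-bright test image is synthetic, and `equalize`/`tiled_equalize` are illustrative names, not part of this notebook.

```python
import numpy as np

def equalize(block):
    """Plain histogram equalization of a uint8 array via its cumulative histogram."""
    cdf = np.bincount(block.ravel(), minlength=256).cumsum()
    cdf_min = cdf[cdf > 0][0]                 # count of the darkest occurring value
    denom = max(int(cdf[-1] - cdf_min), 1)
    lut = np.round((cdf - cdf_min) / denom * 255).clip(0, 255).astype(np.uint8)
    return lut[block]

def tiled_equalize(v, tiles=8):
    """Equalize each tile independently (CLAHE minus clipping and interpolation)."""
    out = np.empty_like(v)
    ys = np.linspace(0, v.shape[0], tiles + 1).astype(int)
    xs = np.linspace(0, v.shape[1], tiles + 1).astype(int)
    for y0, y1 in zip(ys[:-1], ys[1:]):
        for x0, x1 in zip(xs[:-1], xs[1:]):
            out[y0:y1, x0:x1] = equalize(v[y0:y1, x0:x1])
    return out

def span(a):
    return int(a.max()) - int(a.min())

# Synthetic V channel: a faint gradient on dark paper (left) and on bright paper (right)
left = np.tile(np.linspace(20, 40, 32).astype(np.uint8), (64, 1))
v = np.hstack([left, left + 180])

# Contrast recovered in the dark half: global equalization vs per-tile equalization
print(span(equalize(v)[:, :32]), span(tiled_equalize(v, tiles=2)[:, :32]))
```

Global equalization has to share the 0–255 range between both halves, so the faint gradient on the dark side gets only about half of it; per-tile equalization stretches it to the full range.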

`stage1_quality` gives a cheap edge-density + anisotropy score used later to decide which of two preprocessing paths produced the cleaner input for Stage 1.


In [None]:
# ── Contrast / color corrections ──────────────────────────────────────────────

def change_color(image_rgb):
    """CLAHE on the HSV V channel — the core contrast enhancement used before Stage 0."""
    hsv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HSV)
    h, s, v = cv2.split(hsv)
    v_denoised = cv2.fastNlMeansDenoising(v, h=5.46)
    clip_limit = max(1.0, min(3.5, 2.0 + np.std(v_denoised) / 25))
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(8, 8))
    return cv2.cvtColor(cv2.merge([h, s, clahe.apply(v_denoised)]), cv2.COLOR_HSV2RGB)

def clahe_luminance_bgr(img_bgr, clip=2.0, tile=8):
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=float(clip), tileGridSize=(int(tile), int(tile)))
    return cv2.cvtColor(cv2.merge([clahe.apply(l), a, b]), cv2.COLOR_LAB2BGR)

def grayworld_white_balance(img_bgr):
    img = img_bgr.astype(np.float32)
    b, g, r = cv2.split(img)
    m = (b.mean() + g.mean() + r.mean()) / 3.0
    b *= m / (b.mean() + 1e-6)
    g *= m / (g.mean() + 1e-6)
    r *= m / (r.mean() + 1e-6)
    return np.clip(cv2.merge([b, g, r]), 0, 255).astype(np.uint8)

def denoise_median(img_bgr, k=3):
    k = int(k); k = k if k % 2 == 1 else k + 1
    return cv2.medianBlur(img_bgr, k)

def denoise_bilateral(img_bgr, d=7, sigmaColor=50, sigmaSpace=50):
    return cv2.bilateralFilter(img_bgr, int(d), float(sigmaColor), float(sigmaSpace))

def illumination_strength(img_bgr, sigma=35):
    """Estimate global illumination unevenness via low-frequency std."""
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
    return float(np.std(cv2.GaussianBlur(gray, (0, 0), sigma)))

def bg_correct_lab_l(img_bgr, k=81):
    """Subtract a morphological background estimate from the L channel."""
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    k = int(k); k = k if k % 2 == 1 else k + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
    bg = cv2.morphologyEx(l, cv2.MORPH_OPEN, kernel)
    l_corr = cv2.normalize(cv2.subtract(l, bg), None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    return cv2.cvtColor(cv2.merge([l_corr, a, b]), cv2.COLOR_LAB2BGR)


# ── Source-aware dispatch ──────────────────────────────────────────────────────
# Source codes map to equipment types in the PhysioNet dataset.
# When no classifier is available, '0001' (identity) is the safe default.

def preprocess_by_source(img_bgr, source='0001'):
    s = str(source)
    if s in ('0001', '0004', '0012'):
        return img_bgr
    if s in ('0003', '0011'):
        return clahe_luminance_bgr(grayworld_white_balance(img_bgr), clip=1.2)
    if s == '0006':
        return clahe_luminance_bgr(denoise_bilateral(img_bgr, d=5, sigmaColor=25, sigmaSpace=25), clip=1.2)
    if s in ('0005', '0010'):
        x = img_bgr
        if illumination_strength(x) > 0.14:
            x = bg_correct_lab_l(x, k=81)
        if cv2.cvtColor(x, cv2.COLOR_BGR2GRAY).std() < 30:
            x = clahe_luminance_bgr(x, clip=1.1)
        return x
    if s == '0009':
        x = img_bgr
        if illumination_strength(x) > 0.14:
            x = bg_correct_lab_l(x, k=101)
        return denoise_median(x, k=3)
    return img_bgr


# ── Post-pipeline helpers ──────────────────────────────────────────────────────

def stage1_quality(s1_rgb):
    """
    Score used to pick the better of two preprocessing branches.
    Combines Canny edge density (favors sharp waveforms) with a horizontal/vertical
    gradient anisotropy term (favors well-aligned grids).
    """
    g = cv2.cvtColor(s1_rgb.astype(np.uint8), cv2.COLOR_RGB2GRAY)
    density = cv2.Canny(g, 50, 150).mean() / 255.0
    gx = cv2.Sobel(g, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(g, cv2.CV_32F, 0, 1, ksize=3)
    ax, ay = float(np.mean(np.abs(gx))), float(np.mean(np.abs(gy)))
    anis = max(ax, ay) / (min(ax, ay) + 1e-6)
    return float(density * 0.7 + np.tanh(anis - 1.0) * 0.3)

def series_dict(series_4row):
    """
    Unpack the (4, N) output of Stage 2 into a named dict.

    The ECG is printed in 3 data rows × 4 columns + 1 full-length rhythm strip (row 3 = Lead II).
    Each data row gets split into 4 equal segments, then named by standard 12-lead convention.
    """
    series_4row = np.asarray(series_4row)
    if series_4row.ndim == 3:
        series_4row = series_4row[0]
    if series_4row.shape[0] != 4 and series_4row.shape[1] == 4:
        series_4row = series_4row.T
    d = {}
    names = [['I', 'aVR', 'V1', 'V4'], ['II_short', 'aVL', 'V2', 'V5'], ['III', 'aVF', 'V3', 'V6']]
    for r in range(3):
        for lead, arr in zip(names[r], np.array_split(series_4row[r], 4)):
            d[lead] = np.asarray(arr, dtype=np.float32)
    d['II'] = np.asarray(series_4row[3], dtype=np.float32)  # full-length rhythm strip
    return d

def dw(d, alpha=0.33):
    """
    Soft Einthoven correction.
    II = I + III in theory, but digitization errors accumulate in each channel independently.
    The residual e = II - (I + III) is spread across the three limb leads, which shrinks it
    to (1 - 3*alpha) * e: alpha = 1/3 enforces the constraint exactly, and the default
    0.33 removes 99% of the residual.
    """
    if all(k in d for k in ['I', 'II_short', 'III']):
        L1, L2s, L3 = d['I'], d['II_short'], d['III']
        e = L2s - (L1 + L3)
        d['I']         = L1  + alpha * e
        d['III']       = L3  + alpha * e
        d['II_short']  = L2s - alpha * e
    return d

def img_to_b64(img_rgb: np.ndarray, max_width=1200) -> str:
    h, w = img_rgb.shape[:2]
    if w > max_width:
        img_rgb = cv2.resize(img_rgb, (max_width, int(h * max_width / w)))
    _, buf = cv2.imencode('.jpg', cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR),
                          [cv2.IMWRITE_JPEG_QUALITY, 88])
    return 'data:image/jpeg;base64,' + base64.b64encode(buf).decode()

def heatmap_overlay(img_rgb: np.ndarray, pixel_4ch: np.ndarray) -> str:
    """Max-pool the 4 activation channels and overlay as a JET colormap."""
    h, w = img_rgb.shape[:2]
    heat = pixel_4ch.max(axis=0)
    heat = cv2.resize(heat, (w, h))
    heat = np.uint8(255 * heat / (heat.max() + 1e-6))
    heat_color = cv2.cvtColor(cv2.applyColorMap(heat, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    overlay = np.uint8(img_rgb * 0.55 + heat_color * 0.45)
    return img_to_b64(overlay)

print('✅ utilities defined')
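Because `dw` moves all three limb leads by the same `alpha * e`, the residual shrinks by a fixed factor: e' = II' - (I' + III') = e - 3·alpha·e = (1 - 3·alpha)·e. A quick NumPy check, using random stand-in signals rather than real ECG data and a hypothetical `soft_einthoven` that repeats the update from `dw`:

```python
import numpy as np

def soft_einthoven(I, II, III, alpha=0.33):
    """Same update as dw(): spread the Einthoven residual over the three limb leads."""
    e = II - (I + III)
    return I + alpha * e, II - alpha * e, III + alpha * e

rng = np.random.default_rng(1)
I, III = rng.normal(size=1250), rng.normal(size=1250)
II = I + III + rng.normal(scale=0.1, size=1250)   # simulated digitization error

ratios = {}
for alpha in (0.1, 0.33, 1 / 3):
    I2, II2, III2 = soft_einthoven(I, II, III, alpha)
    # ratio of the largest residual after vs. before the correction
    ratios[alpha] = np.abs(II2 - (I2 + III2)).max() / np.abs(II - (I + III)).max()
    print(f'alpha={alpha:.4f}  residual ratio={ratios[alpha]:.4f}')
```

The ratios come out as |1 - 3·alpha|, so the default is close to a hard constraint; a noticeably "softer" blend would need a smaller `alpha`.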


## 5 — Net3: the waveform segmentation model

Net3 is a standard ResNet34 encoder paired with a UNet-style decoder, with one small but important addition: a **coordinate channel** appended to the final feature map before the output convolution.

The coordinate channel is a vertical gradient (0 at top → 1 at bottom) broadcast across the full width. This gives the model an explicit signal about vertical position, which matters because the four ECG rows have different absolute y-positions in the image. Without it, the model would need to infer row identity purely from context.
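Why the coordinate channel helps can be seen from the final 1×1 convolution alone, which is the same linear map applied at every pixel. A NumPy sketch with made-up shapes and random weights (not the trained Net3 weights): two pixels with identical decoder features but different heights receive identical logits without the coordinate channel, and different logits with it.

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W = 32, 8, 5
features = np.ones((C, H, W))          # identical decoder features at every pixel

def conv1x1(x, weight, bias):
    """A 1x1 convolution: one (out, in) matrix applied independently at each (y, x)."""
    return np.einsum('oc,chw->ohw', weight, x) + bias[:, None, None]

# Without coordinates: 4 output channels computed from C feature channels
w_plain = rng.normal(size=(4, C))
out_plain = conv1x1(features, w_plain, np.zeros(4))

# With coordinates: append a (1, H, W) channel running from 0 (top) to 1 (bottom)
coord = np.broadcast_to(np.linspace(0, 1, H)[None, :, None], (1, H, W))
w_coord = rng.normal(size=(4, C + 1))
out_coord = conv1x1(np.concatenate([features, coord]), w_coord, np.zeros(4))

# Compare the logits in the top row and the bottom row of the map
print(np.allclose(out_plain[:, 0], out_plain[:, -1]),
      np.allclose(out_coord[:, 0], out_coord[:, -1]))
```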

The output is 4 channels (one probability map per printed row), not 12. The individual leads are recovered afterward by slicing each row into column segments in `series_dict`.


In [None]:
class Net3(nn.Module):
    def __init__(self, pretrained=False):
        super().__init__()
        encoder_dim = [64, 128, 256, 512]
        decoder_dim = [256, 128, 64, 32]
        self.encoder = timm.create_model(
            'resnet34.a3_in1k', pretrained=pretrained,
            in_chans=3, num_classes=0, global_pool=''
        )
        self.decoder = MyCoordUnetDecoder(
            in_channel=encoder_dim[-1],
            skip_channel=encoder_dim[:-1][::-1] + [0],
            out_channel=decoder_dim,
            scale=[2, 2, 2, 2],
        )
        self.pixel = nn.Conv2d(decoder_dim[-1] + 1, 4, 1)  # +1 for the coord channel

    def forward(self, image):
        encode = encode_with_resnet(self.encoder, image)
        last, _ = self.decoder(feature=encode[-1], skip=encode[:-1][::-1] + [None])
        B, C, H, W = last.shape
        # vertical coordinate channel: shape (B, 1, H, W), values in [0, 1]
        coord = torch.linspace(0, 1, H, device=last.device).view(1, 1, H, 1).expand(B, 1, H, W)
        last = torch.cat([last, coord], dim=1)
        return self.pixel(last)

print('✅ Net3 defined')


## 6 — PhysioPipeline

The pipeline class wires the three stages together and handles the calibration arithmetic.

**Calibration constants** (`zero_mv`, `mv_to_pixel`, `t0`/`t1`) come from the physical properties of standard ECG paper:
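- Paper speed: 25 mm/s. Stage 2 stretches the 2176-px-wide Stage 1 crop to 4352 px, and the usable span `t1 - t0` = 3926 px covers ~10 s of signal, so each second is ~392.6 px (≈ 15.7 px/mm in the stretched space, ≈ 7.85 px/mm in the Stage 1 image)
- Amplitude scale: 10 mm/mV → `mv_to_pixel = 78.8` px/mV (measured empirically from the dataset), i.e. ≈ 7.9 px/mm vertically, consistent with the horizontal scale
- `zero_mv` lists the y-pixel of each row's isoelectric baseline (rows 0–3) in the 4352×1696 resized space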
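The calibration arithmetic above can be checked in a few lines. This sketch copies the constants from `PhysioPipeline.__init__` and assumes a 10 s strip (the default `sig_len=5000` at 500 Hz); `y_to_mv` is an illustrative helper that mirrors the conversion in `_run_stage2`.

```python
# Constants copied from PhysioPipeline.__init__
T0, T1 = 235, 4161            # usable signal x-range in the 4352-px-wide Stage 2 space
ZERO_MV = [703.5, 987.5, 1271.5, 1531.5]
MV_TO_PIXEL = 78.8            # px per mV, vertical
SECONDS = 10.0                # assumption: sig_len 5000 samples at 500 Hz

px_per_s = (T1 - T0) / SECONDS        # horizontal, in the stretched 4352-px space
px_per_mm_x = px_per_s / 25.0         # 25 mm/s paper speed
px_per_mm_y = MV_TO_PIXEL / 10.0      # 10 mm/mV gain

# Stage 2 stretches the 2176-px Stage 1 crop 2x horizontally, so dividing by 2
# gives the horizontal scale in the Stage 1 image.
print(f'{px_per_s:.1f} px/s | x: {px_per_mm_x / 2:.2f} px/mm | y: {px_per_mm_y:.2f} px/mm')

def y_to_mv(y_px, row):
    """Pixel row -> millivolts: baseline minus y, divided by the px/mV gain."""
    return (ZERO_MV[row] - y_px) / MV_TO_PIXEL

# A trace point 78.8 px above row 0's baseline corresponds to about 1 mV
print(round(y_to_mv(703.5 - 78.8, row=0), 6))
```

The two axes agree to within about 0.5%, which is a useful sanity check that the empirically measured `mv_to_pixel` matches the paper-speed geometry.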

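**Dual-path preprocessing**: Stage 0 and Stage 1 run twice, once on the raw image and once on the output of `preprocess_by_source`, and `stage1_quality` picks the winner; the preprocessed branch must score more than 2% higher to be chosen. This costs a few extra seconds but noticeably improves robustness on difficult inputs (e.g. strong background illumination gradients). Note that with the default source code `'0001'`, `preprocess_by_source` returns the image unchanged, so both branches are identical unless another source code is used.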

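**`_compute_lead_boxes`** works backwards from these same calibration constants to compute where each of the 12 leads sits in the image, without any additional inference. The x-boundaries come from `t0`/`t1` divided into 4 equal columns. For the y-boundaries, the data area (from 0.85 row gaps above the first baseline to 0.55 row gaps below the third) is divided into three equal bands; midpoints between baselines are avoided because tall R peaks cross them, which would make the boxes clip into neighboring leads. The result is normalized to [0, 1] so the frontend can map the boxes onto any display size.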
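A standalone sketch of the same box arithmetic, with constants copied from `PhysioPipeline`. It assumes the Stage 1 image is 1696 px tall (the height Stage 2 reads), and `lead_box_edges` is an illustrative name rather than a method of the class.

```python
# Constants copied from PhysioPipeline; img_h=1696 below is an assumption
T0, T1, W_RS = 235, 4161, 4352
ZERO_MV = [703.5, 987.5, 1271.5, 1531.5]

def lead_box_edges(img_h):
    """Normalized column x-edges and row y-bands, following _compute_lead_boxes."""
    col_w = (T1 - T0) / 4.0
    # x * (img_w / W_RS) / img_w reduces to x / W_RS, so img_w drops out
    x_edges = [(T0 + i * col_w) / W_RS for i in range(5)]
    gap = ZERO_MV[1] - ZERO_MV[0]
    top = ZERO_MV[0] - 0.85 * gap        # clear the patient-info header
    bot = ZERO_MV[2] + 0.55 * gap        # padding below the third data row
    row_h = (bot - top) / 3.0
    y_bands = [((top + r * row_h) / img_h, (top + (r + 1) * row_h) / img_h)
               for r in range(3)]
    return x_edges, y_bands

x_edges, y_bands = lead_box_edges(img_h=1696)
for r, (y1, y2) in enumerate(y_bands):
    print(f'row {r}: y {y1:.3f}-{y2:.3f}, baseline at {ZERO_MV[r] / 1696:.3f}')
```

Adjacent bands share an edge exactly, so boxes never overlap. Each band (~322 px) is taller than the 284 px baseline spacing, and every baseline sits in the lower half of its band, leaving more room above it for upward R peaks.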


## 6b — Clinical metrics

Computed server-side from the Lead II rhythm strip so the frontend doesn't need to implement peak detection.


In [None]:
from scipy.signal import find_peaks

def compute_metrics(lead_ii: list, fs: int = 500) -> dict:
    """
    Compute HR, RR interval, and QRS duration from the Lead II rhythm strip.

    Uses scipy.signal.find_peaks with physiologically-grounded constraints:
      - minimum peak distance: 250 samples (= 0.5 s, equivalent to 120 bpm ceiling)
      - adaptive height threshold: 60th percentile of absolute signal values,
        so the threshold scales with whatever amplitude the digitizer produced

    RR and HR use the median rather than the mean to be robust against the
    occasional missed or spurious peak at the start/end of the strip.

    QRS duration is estimated by walking outward from each R peak until the
    signal drops below 10% of that peak's amplitude, averaging across all beats.
    """
    sig = np.array(lead_ii, dtype=np.float32)
    if len(sig) < fs:
        return {'hr_bpm': None, 'rr_ms': None, 'qrs_ms': None}

    # Adaptive height threshold — scales with signal amplitude
    height_thresh = np.percentile(np.abs(sig), 60)

    peaks, _ = find_peaks(
        sig,
        distance=int(fs * 0.5),   # no two R-peaks closer than 0.5 s
        height=height_thresh,
    )

    if len(peaks) < 2:
        return {'hr_bpm': None, 'rr_ms': None, 'qrs_ms': None}

    rr_samples = np.diff(peaks)
    rr_ms      = float(np.median(rr_samples) / fs * 1000)
    hr_bpm     = round(60_000 / rr_ms, 1)

    # QRS duration: walk left and right from each peak until signal < 10% of peak
    qrs_durations = []
    for pk in peaks:
        amp = sig[pk]
        threshold = amp * 0.10

        # walk left
        left = pk
        while left > 0 and sig[left] > threshold:
            left -= 1

        # walk right
        right = pk
        while right < len(sig) - 1 and sig[right] > threshold:
            right += 1

        width_ms = (right - left) / fs * 1000
        # sanity bounds: 40–200 ms is the physiological range
        if 40 <= width_ms <= 200:
            qrs_durations.append(width_ms)

    qrs_ms = round(float(np.median(qrs_durations)), 1) if qrs_durations else None

    return {
        'hr_bpm': hr_bpm,
        'rr_ms':  round(rr_ms, 1),
        'qrs_ms': qrs_ms,
    }

print('✅ compute_metrics defined')
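The QRS walk in `compute_metrics` measures the complex's width at 10% of its peak height, not its full base. A NumPy sketch on a synthetic triangular QRS (an idealized shape, not real ECG data) shows the effect; `width_at_fraction` is an illustrative helper that repeats the walk from `compute_metrics`.

```python
import numpy as np

def width_at_fraction(sig, pk, frac=0.10):
    """Walk outward from a peak while the signal stays above frac * peak amplitude."""
    threshold = sig[pk] * frac
    left = pk
    while left > 0 and sig[left] > threshold:
        left -= 1
    right = pk
    while right < len(sig) - 1 and sig[right] > threshold:
        right += 1
    return right - left

fs = 500
half = 20                                    # 20 samples = 40 ms on each side (80 ms base)
sig = np.zeros(200)
pk = 100
j = np.arange(-half, half + 1)
sig[pk + j] = 1.0 - np.abs(j) / half         # synthetic triangular QRS, 1 mV tall

width_ms = width_at_fraction(sig, pk) / fs * 1000
print(width_ms)   # ~72 ms, i.e. 90% of the 80 ms base for a triangle
```

For real complexes the bias depends on the slope of the QRS edges and on where the baseline sits, so the reported `qrs_ms` is best read as an approximate width.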


In [None]:
class PhysioPipeline:
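    # Standard 12-lead print layout: 3 data rows × 4 columns, as (lead, row, col).
    # 'II' here is the short segment in row 1; the full-length rhythm strip (row 3)
    # has no entry, so no box is computed for it.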
    LEAD_LAYOUT = [
        ('I',   0, 0), ('aVR', 0, 1), ('V1', 0, 2), ('V4', 0, 3),
        ('II',  1, 0), ('aVL', 1, 1), ('V2', 1, 2), ('V5', 1, 3),
        ('III', 2, 0), ('aVF', 2, 1), ('V3', 2, 2), ('V6', 2, 3),
    ]

    def __init__(self, device='cuda:0'):
        self.device = device
        self.stage0_net = self.stage1_net = self.stage2_net = None

        # Stage2 reads this ROI from the Stage1 output
        self.x0, self.x1 = 0, 2176
        self.y0, self.y1 = 0, 1696

        # Calibration: isoelectric baseline y-positions in the 4352×1696 resized space
        # and the pixel-per-millivolt conversion factor
        self.zero_mv     = [703.5, 987.5, 1271.5, 1531.5]
        self.mv_to_pixel = 78.8

        # Usable signal x-range in the resized space (trims left/right borders and scale bars)
        self.t0, self.t1 = 235, 4161

        self.resize = T.Resize((1696, 4352), interpolation=T.InterpolationMode.BILINEAR)

    # ── Model loading ──────────────────────────────────────────────────────────

    def load_models(self, stage0_w, stage1_w, stage2_w):
        print('  Stage0 (rotation correction)...')
        self.stage0_net = s0c.load_net(Stage0Net(pretrained=False), stage0_w).to(self.device).eval()
        print('  Stage1 (grid alignment)...')
        self.stage1_net = s1c.load_net(Stage1Net(pretrained=False), stage1_w).to(self.device).eval()
        print('  Stage2 (Net3 waveform segmentation)...')
        self.stage2_net = Net3(pretrained=False).to(self.device).eval()
        st = torch.load(stage2_w, map_location='cpu')
        if isinstance(st, dict) and 'state_dict' in st:
            st = st['state_dict']
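        # strict=False tolerates naming differences, but report any mismatch so a
        # partially initialized network does not go unnoticed
        result = self.stage2_net.load_state_dict(st, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f'  ⚠️  Stage2: {len(result.missing_keys)} missing / '
                  f'{len(result.unexpected_keys)} unexpected keys')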
        print('✅ all models loaded')

    # ── Stage inference ────────────────────────────────────────────────────────

    def _run_stage0(self, img_bgr):
        try:
            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
            batch   = s0c.image_to_batch(change_color(img_rgb))
            with torch.no_grad(), torch.amp.autocast(self.device.split(':')[0], dtype=torch.float32):
                output = self.stage0_net(batch)
            rotated, keypoint = s0c.output_to_predict(img_rgb, batch, output)
            keypoint = keypoint.astype(np.float32) if hasattr(keypoint, 'astype') else keypoint
            normalised, _, _ = s0c.normalise_by_homography(rotated, keypoint)
            return normalised
        except Exception as e:
            print(f'  ⚠️  Stage0 failed ({e}), falling back to resized original')
            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
            return cv2.resize(img_rgb, (2176, 1696))

    def _run_stage1(self, stage0_rgb):
        batch = {'image': torch.from_numpy(
            np.ascontiguousarray(stage0_rgb.transpose(2, 0, 1))).unsqueeze(0)}
        with torch.no_grad(), torch.amp.autocast(self.device.split(':')[0], dtype=torch.float32):
            output = self.stage1_net(batch)
        gridpoint_xy, _ = s1c.output_to_predict(stage0_rgb, batch, output)
        return s1c.rectify_image(stage0_rgb, gridpoint_xy)

    def _run_stage2(self, stage1_rgb, length):
        img   = stage1_rgb[self.y0:self.y1, self.x0:self.x1] / 255.0
        batch = self.resize(
            torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1))).unsqueeze(0)
        ).float().to(self.device)
        with torch.no_grad(), torch.amp.autocast(self.device.split(':')[0], dtype=torch.float32):
            output = self.stage2_net(batch)
        pixel = torch.sigmoid(output).float().cpu().numpy()[0]   # (4, H, W)
        # pixel_to_series traces the waveform centerline in each row
        series_px = s2c.pixel_to_series(pixel[..., self.t0:self.t1], self.zero_mv, length)
        # convert pixel displacement from baseline → millivolts
        series = (np.array(self.zero_mv).reshape(4, 1) - series_px) / self.mv_to_pixel
        for i in range(4):
            series[i] = savgol_filter(series[i], window_length=7, polyorder=2)
        return series, pixel

    # ── Lead bounding boxes ────────────────────────────────────────────────────

    def _compute_lead_boxes(self, img_h, img_w):
        """
        Derive normalized [0,1] bounding boxes for all 12 leads using the same
        calibration constants that Stage2 uses for extraction — no extra inference.

        x-boundaries: divide the signal x-range [t0, t1] into 4 equal columns,
                      then scale from the internal 4352-px width to the actual image width.
        y-boundaries: divide the 3-row data area into equal thirds rather than using
                      midpoints between baselines. Using midpoints fails when R-peaks
                      extend beyond the halfway mark between rows, causing boxes to
                      visually clip into adjacent lead regions.
        """
        W_RS = 4352  # Stage2 internal width
        sx   = img_w / W_RS

        col_w   = (self.t1 - self.t0) / 4.0
        x_edges = [(self.t0 + i * col_w) * sx for i in range(5)]

        zm      = self.zero_mv
        row_gap = zm[1] - zm[0]

        # Define the usable data area top and bottom, then divide into 3 equal rows.
        # This guarantees no overlap regardless of how far waveforms extend from baseline.
        data_top = zm[0] - row_gap * 0.85   # clear the patient-info header
        data_bot = zm[2] + row_gap * 0.55   # some padding below the last data row
        row_h    = (data_bot - data_top) / 3.0

        y_tops = [data_top + r * row_h for r in range(3)] + [(zm[2] + zm[3]) / 2.0]
        y_bots = [data_top + (r + 1) * row_h for r in range(3)] + [float(img_h)]

        boxes = {}
        for lead, row, col in self.LEAD_LAYOUT:
            boxes[lead] = {
                'x1': round(max(0.0, x_edges[col]     / img_w), 4),
                'y1': round(max(0.0, y_tops[row]       / img_h), 4),
                'x2': round(min(1.0, x_edges[col + 1] / img_w), 4),
                'y2': round(min(1.0, y_bots[row]       / img_h), 4),
            }
        return boxes

    # ── Signal overlay ────────────────────────────────────────────────────────

    def _draw_overlay(self, s1_rgb, series_4row, pixel_4ch):
        """
        Draw the 12 extracted lead signals back onto the Stage 1 image.

        Three-layer noise handling:
          1. Two-pass clip (5th–95th percentile, then ±2×IQR around the median): robust to extreme values
          2. Median filter — removes spike artifacts without smearing edges like mean smoothing
          3. Confidence masking — uses Stage 2 activation maps to skip low-confidence regions
             rather than drawing a noisy trace that would mislead the viewer
        """
        from scipy.signal import medfilt

        img   = s1_rgb.copy().astype(np.uint8)
        img_h, img_w = img.shape[:2]

        sx    = img_w / 4352.0
        col_w = (self.t1 - self.t0) / 4.0
        zm    = self.zero_mv
        series = np.asarray(series_4row)

        # Row y-boundaries: equal thirds as in _compute_lead_boxes, but the data area
        # here ends at the midpoint between row 2 and the rhythm strip (slightly tighter)
        row_gap  = zm[1] - zm[0]
        data_top = zm[0] - row_gap * 0.85
        data_bot = (zm[2] + zm[3]) / 2.0
        row_h    = (data_bot - data_top) / 3.0
        row_y_top = [int(data_top + r * row_h) for r in range(3)]
        row_y_bot = [int(data_top + (r+1) * row_h) for r in range(3)]
        strip_top = int((zm[2] + zm[3]) / 2.0)

        COLORS = [
            [(255,59,48),  (255,149,0),  (255,204,0),  (48,209,88)],
            [(0,199,190),  (0,122,255),  (88,86,214),  (175,82,222)],
            [(255,45,85),  (255,107,107),(50,173,230),  (76,217,100)],
        ]

        # Per-column confidence from the Stage 2 activation maps.
        # pixel_4ch has shape (4, H, W) in the 4352-wide resized space, so columns are
        # sliced directly with the t0/t1 constants; no resizing is needed.
        # Slice BOTH x (column range) AND y (row band) before scoring: waveform pixels
        # are <0.2% of a full-height column, so without the y slice they are buried in
        # background and any percentile/mean is useless.
        CONF_THRESHOLD = 0.25   # max activation in the row band below this → skip

        act_col_w = (self.t1 - self.t0) / 4.0

        # pixel_4ch is in 4352-wide space; height matches s1_rgb (not resized)
        act_h = pixel_4ch.shape[1]   # typically 1696

        def col_confidence(row_idx, col):
            """Max activation inside the correct row band and column range."""
            x_start = int(self.t0 + col * act_col_w)
            x_end   = int(self.t0 + (col + 1) * act_col_w)
            # y band: ±1.5 × half-row-gap around this row's baseline
            half = int(row_gap * 0.75)
            y_center = int(zm[row_idx])
            y_start  = max(0, y_center - half)
            y_end    = min(act_h, y_center + half)
            region   = pixel_4ch[row_idx, y_start:y_end, x_start:x_end]
            return float(np.max(region)) if region.size else 0.0

        # Skip this many samples at the start of each column — covers the 1mV
        # calibration pulse printed at the left edge of each lead region.
        CALIB_SKIP = 80

        def clean(sig):
            """
            Two-pass outlier removal + median filter.
            Pass 1: 5–95th percentile clip.
            Pass 2: ±2×IQR around median to catch residual spikes that survive pass 1.
            """
            s = np.array(sig, dtype=np.float32)
            if len(s) < 5:
                return s
            lo, hi = np.percentile(s, 5), np.percentile(s, 95)
            s = np.clip(s, lo, hi)
            q25, q75 = np.percentile(s, 25), np.percentile(s, 75)
            iqr = q75 - q25
            if iqr > 0:
                med = np.median(s)
                s = np.clip(s, med - 2 * iqr, med + 2 * iqr)
            k = min(15, len(s) if len(s) % 2 == 1 else len(s) - 1)
            k = max(k, 5)
            return medfilt(s, kernel_size=k)

        for row in range(3):
            row_signal = series[row]
            seg_len    = len(row_signal) // 4
            y_min, y_max = row_y_top[row], row_y_bot[row]
            for col in range(4):
                if col_confidence(row, col) < CONF_THRESHOLD:
                    continue

                seg   = clean(row_signal[col * seg_len : (col + 1) * seg_len])
                color = COLORS[row][col]
                x0_rs = self.t0 + col * col_w
                pts   = []
                for i, mv in enumerate(seg):
                    if i < CALIB_SKIP:
                        continue
                    x_px = int((x0_rs + i / max(len(seg)-1, 1) * col_w) * sx)
                    y_px = int(zm[row] - mv * self.mv_to_pixel)
                    y_px = max(y_min, min(y_max, y_px))
                    pts.append((x_px, y_px))
                for k in range(1, len(pts)):
                    cv2.line(img, pts[k-1], pts[k], color, thickness=2, lineType=cv2.LINE_AA)

        # Lead II rhythm strip — use row 3 activation for confidence
        rhythm_act      = pixel_4ch[3] if pixel_4ch.shape[0] > 3 else pixel_4ch[2]
        rhythm_y_center = int(zm[3])
        rhythm_y_start  = max(0, rhythm_y_center - int(row_gap * 0.75))
        rhythm_y_end    = min(act_h, rhythm_y_center + int(row_gap * 0.75))
        rhythm_region   = rhythm_act[rhythm_y_start:rhythm_y_end, self.t0:self.t1]
        rhythm_conf     = float(np.max(rhythm_region)) if rhythm_region.size else 0.0
        if rhythm_conf >= CONF_THRESHOLD:
            rhythm = clean(series[3])
            n_r    = len(rhythm)
            pts    = []
            for i, mv in enumerate(rhythm):
                if i < CALIB_SKIP:
                    continue
                x_px = int((self.t0 + i / max(n_r-1, 1) * (self.t1 - self.t0)) * sx)
                y_px = int(zm[3] - mv * self.mv_to_pixel)
                y_px = max(strip_top, min(img_h - 1, y_px))
                pts.append((x_px, y_px))
            for k in range(1, len(pts)):
                cv2.line(img, pts[k-1], pts[k], (0,199,190), thickness=2, lineType=cv2.LINE_AA)

        return img_to_b64(img)

    # ── Full inference ─────────────────────────────────────────────────────────

    def run_full(self, img_bgr, fs=500, sig_len=5000):
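        """
        Returns a dict with:
            steps       : list of 7 dicts {title, description, image (base64 JPEG)}
            leads       : {lead_name: [float, ...]} in mV at `fs` Hz; Lead II is the
                          full-length rhythm strip (sig_len samples), the other 11 leads
                          are short segments of roughly sig_len / 4 samples
            lead_boxes  : {lead_name: {x1, y1, x2, y2}} normalized to the Stage 1 image
            metrics     : {hr_bpm, rr_ms, qrs_ms} computed from Lead II
            sampling_hz : fs
        """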
        steps = []
        img_rgb_orig = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        steps.append({
            'title':       'Raw Input Image',
            'description': 'The original upload. May be rotated, unevenly lit, or photographed at an angle.',
            'image':        img_to_b64(img_rgb_orig),
        })
        steps.append({
            'title':       'Contrast Enhancement (CLAHE)',
            'description': 'Adaptive histogram equalization on the HSV luminance channel sharpens the waveform traces without overexposing bright areas.',
            'image':        img_to_b64(change_color(img_rgb_orig)),
        })

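        # With source='0001', preprocess_by_source returns the image unchanged, so the
        # two Stage 0/1 branches below are identical; another source code is needed to
        # exercise the source-specific corrections.
        img_pp_bgr = preprocess_by_source(img_bgr.copy(), source='0001')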
        steps.append({
            'title':       'Image Preprocessing',
            'description': 'Source-specific corrections: noise removal, white balance, illumination normalization.',
            'image':        img_to_b64(cv2.cvtColor(img_pp_bgr, cv2.COLOR_BGR2RGB)),
        })

        # Run both preprocessing paths through Stage 0 + 1 and keep the better result
        s0_raw, s0_pp = self._run_stage0(img_bgr), self._run_stage0(img_pp_bgr)
        s1_raw, s1_pp = self._run_stage1(s0_raw),  self._run_stage1(s0_pp)
        if stage1_quality(s1_pp) > stage1_quality(s1_raw) * 1.02:
            s0_best, s1_best = s0_pp, s1_pp
        else:
            s0_best, s1_best = s0_raw, s1_raw

        steps.append({
            'title':       'Rotation Correction + Perspective Warp (Stage 0)',
            'description': 'Keypoint network detects the four paper corners and applies a homography to produce a fronto-parallel view.',
            'image':        img_to_b64(s0_best),
        })
        steps.append({
            'title':       'ECG Grid Alignment (Stage 1)',
            'description': 'Grid-point network corrects residual distortion in the mm-grid, making the pixel→mV conversion accurate.',
            'image':        img_to_b64(s1_best),
        })

        series_4row, pixel_4ch = self._run_stage2(s1_best, length=sig_len)
        steps.append({
            'title':       'Waveform Segmentation Heatmap (Stage 2 · ResNet34-UNet)',
            'description': 'Four-channel sigmoid output overlaid on the input. Orange-red regions are the pixels classified as ECG waveform.',
            'image':        heatmap_overlay(s1_best, pixel_4ch),
        })
        steps.append({
            'title':       'Extraction Overlay — Signal Reconstructed onto Source',
            'description': 'The 12 extracted lead signals drawn back onto the Stage 1 image using the exact same calibration constants used during extraction. If the colored traces align with the printed waveforms, extraction is accurate.',
            'image':        self._draw_overlay(s1_best, series_4row, pixel_4ch),
        })

        d = dw(series_dict(series_4row))
        leads_out = {}
        for lead in LEADS_ORDER:
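            if lead == 'II':
                # Try the 'II' trace first, then 'II_short'; zero-fill if neither exists
                arr = d.get('II', d.get('II_short', np.zeros(sig_len)))
            else:
                # Missing leads are zero-filled (sig_len // 4 samples) rather than omitted
                arr = d.get(lead, np.zeros(sig_len // 4))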
            leads_out[lead] = arr.tolist()

        img_h, img_w   = s1_best.shape[:2]
        lead_boxes     = self._compute_lead_boxes(img_h, img_w)

        metrics = compute_metrics(leads_out['II'], fs)

        return {
            'steps':       steps,
            'leads':       leads_out,
            'lead_boxes':  lead_boxes,
            'metrics':     metrics,
            'sampling_hz': fs,
        }

print('✅ PhysioPipeline defined')
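
The Stage 0/1 selection in `run_full` switches to the preprocessed image only when its `stage1_quality` score exceeds the raw image's score by more than 2%, so ties and marginal wins go to the unmodified input. Below is a minimal standalone sketch of that rule, assuming higher scores are better (as the `>` comparison implies); `pick_stage1_candidate` is a hypothetical helper, not part of the notebook.

```python
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")


def pick_stage1_candidate(
    raw: T,
    preprocessed: T,
    quality: Callable[[T], float],
    margin: float = 1.02,
) -> Tuple[T, str]:
    """Return the preprocessed candidate only if it clearly wins.

    Mirrors the rule in run_full: the preprocessed path must score more than
    `margin` times the raw score; otherwise the raw path is kept.
    """
    q_raw = quality(raw)
    q_pp = quality(preprocessed)
    if q_pp > q_raw * margin:
        return preprocessed, "preprocessed"
    return raw, "raw"


if __name__ == "__main__":
    # Hypothetical quality scores keyed by candidate name
    scores = {"raw": 0.80, "pp": 0.81}
    best, label = pick_stage1_candidate("raw", "pp", scores.__getitem__)
    print(label)  # 0.81 is not > 0.80 * 1.02 = 0.816, so the raw path is kept
```

Because the margin is multiplicative, this rule presumes non-negative scores; with a negative raw score, multiplying by 1.02 would lower the bar instead of raising it.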


## 7 — Load models

This takes 1–2 minutes. The Stage 2 weight file is the largest (~350 MB) and uses a slightly non-standard checkpoint format, so we check for the `state_dict` key before loading.
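
The `state_dict` check amounts to unwrapping the checkpoint before calling `load_state_dict`. A minimal sketch, assuming the loaded checkpoint is either a bare mapping of parameters or a dict that wraps one under `'state_dict'`; `unwrap_state_dict` is a hypothetical helper, and the notebook's own `load_models` may handle other details.

```python
from collections.abc import Mapping
from typing import Any


def unwrap_state_dict(checkpoint: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the parameter mapping from a loaded checkpoint.

    Assumes two layouts: a bare state dict, or a wrapper dict that stores
    the parameters under the 'state_dict' key (alongside extras such as epoch).
    """
    inner = checkpoint.get("state_dict")
    if isinstance(inner, Mapping):
        return inner
    return checkpoint


# Typical use (requires torch and a model instance; not run here):
#   ckpt = torch.load(STAGE2_W, map_location="cpu")
#   model.load_state_dict(unwrap_state_dict(ckpt))
```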


In [None]:
pipeline = PhysioPipeline(device='cuda:0' if device == 'cuda' else 'cpu')
pipeline.load_models(STAGE0_W, STAGE1_W, STAGE2_W)


## 8 — Smoke test *(optional)*

Drop an ECG image into the Colab file panel (left sidebar), update the path, and run this cell to verify the full pipeline works end to end before starting the server.


In [None]:
TEST_IMG = '/content/test_ecg.png'  # ← update path if needed

if os.path.exists(TEST_IMG):
    result = pipeline.run_full(cv2.imread(TEST_IMG))
    print(f'steps:         {len(result["steps"])}')
    print(f'leads:         {list(result["leads"].keys())}')
    print(f'Lead II pts:   {len(result["leads"]["II"])}')
    print(f'lead_boxes[I]: {result["lead_boxes"]["I"]}')
else:
    print(f'⚠️  {TEST_IMG} not found — skipping (safe to proceed)')
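
Beyond printing, the smoke test can check the response shape that the API reference documents. This is a sketch under stated assumptions: 12 leads are expected, lead II is expected to have exactly `sig_len` samples (its zero-filled fallback in `run_full` has that length), and every `lead_boxes` entry is assumed to carry normalized `x1, y1, x2, y2` values. `check_result` is a hypothetical helper.

```python
def check_result(result: dict, sig_len: int = 5000) -> list:
    """Return a list of problems found in a run_full() result (empty = OK)."""
    problems = []
    leads = result.get("leads", {})
    if len(leads) != 12:
        problems.append(f"expected 12 leads, got {len(leads)}")
    n_ii = len(leads.get("II", []))
    if n_ii != sig_len:
        problems.append(f"lead II has {n_ii} samples, expected {sig_len}")
    for name, box in result.get("lead_boxes", {}).items():
        x1, y1, x2, y2 = (box[k] for k in ("x1", "y1", "x2", "y2"))
        # A valid box is non-empty and lies inside the unit square
        if not (0.0 <= x1 < x2 <= 1.0 and 0.0 <= y1 < y2 <= 1.0):
            problems.append(f"lead_boxes[{name}] is not a valid normalized box: {box}")
    return problems
```

Call it inside the `if os.path.exists(TEST_IMG):` branch, for example `print(check_result(result, sig_len=5000) or '✅ structure OK')`, passing whatever `sig_len` the pipeline actually used.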


## 9 — Start the API server

The server is a minimal FastAPI app exposed over ngrok.

A few things worth noting:
- `nest_asyncio` patches the Colab event loop so `await uvicorn.Server(config).serve()` works in a notebook cell
- CORS is fully open (`allow_origins=['*']`) — fine for a demo, tighten for production
- The ngrok URL changes on every Colab restart; paste the new one into your frontend each time

**Replace `YOUR_NGROK_TOKEN_HERE`** with your token from [dashboard.ngrok.com/authtokens](https://dashboard.ngrok.com/authtokens).


In [None]:
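import os
# Lets the PyTorch CUDA caching allocator expand segments in place, reducing fragmentation.
# NOTE: the allocator reads this only when CUDA is first initialized. The models are already
# on the GPU by this cell, so for it to take effect, set it in the first cell instead
# (older PyTorch releases use the name PYTORCH_CUDA_ALLOC_CONF).
os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True'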

NGROK_TOKEN = 'YOUR_NGROK_TOKEN_HERE'  # ← paste your token here

from pyngrok import ngrok, conf
conf.get_default().auth_token = NGROK_TOKEN

import gc
import nest_asyncio, uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

nest_asyncio.apply()

app = FastAPI(title='ECG Digitizer')
app.add_middleware(CORSMiddleware, allow_origins=['*'], allow_methods=['*'], allow_headers=['*'])

@app.get('/')
def health():
    return {'status': 'ok', 'device': device}

@app.post('/digitize')
async def digitize(file: UploadFile = File(...), fs: int = 500, sig_len: int = 5000):
    # content_type can be None if the client omits it
    if not (file.content_type or '').startswith('image/'):
        raise HTTPException(400, detail='image files only')
    contents = await file.read()
    img_bgr  = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(422, detail='could not decode image')
    try:
        gc.collect()
        torch.cuda.empty_cache()
        result = pipeline.run_full(img_bgr, fs=fs, sig_len=sig_len)
        return JSONResponse(content=result)
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(500, detail=str(e))
    finally:
        # Release cached GPU memory whether or not inference succeeded
        gc.collect()
        torch.cuda.empty_cache()

PORT   = 8000
tunnel = ngrok.connect(PORT, 'http')
url    = tunnel.public_url

print(f'  API   →  {url}/digitize')
print(f'  Docs  →  {url}/docs')
print('\nKeep this cell running. The URL changes every time Colab restarts.')

config = uvicorn.Config(app, host='0.0.0.0', port=PORT, log_level='warning')
await uvicorn.Server(config).serve()


---

## API reference

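### `GET /`

Health check. Returns `{"status": "ok", "device": ...}`, where `device` echoes the notebook's `device` variable (e.g. `"cuda"`).

### `POST /digitize`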

**Request** — `multipart/form-data`

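| parameter | sent as | type | description |
|-----------|---------|------|-------------|
| `file` | multipart form field | image file | JPEG, PNG, or any format OpenCV can decode; the part's content type must be `image/*` |
| `fs` | query string | int (default 500) | target sampling rate in Hz, echoed back as `sampling_hz` |
| `sig_len` | query string | int (default 5000) | number of output samples for lead II; a lead that cannot be extracted comes back zero-filled (`sig_len` samples for II, `sig_len // 4` for the others) |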
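
Because `fs` and `sig_len` are plain `int` parameters in the `digitize` signature rather than `Form(...)` fields, FastAPI reads them from the query string; only the image travels in the multipart body. Below is a standard-library sketch of building such a request. `build_digitize_request` is a hypothetical helper, and the base URL is whatever ngrok printed.

```python
import json
import mimetypes
import uuid
from urllib.parse import urlencode
from urllib.request import Request, urlopen


def build_digitize_request(base_url: str, image_path: str, image_bytes: bytes,
                           fs: int = 500, sig_len: int = 5000) -> Request:
    """Build a POST /digitize request: image in the multipart body, fs/sig_len in the query."""
    boundary = uuid.uuid4().hex
    # The server rejects parts whose content type is not image/*; fall back to image/png
    # for unknown extensions, since cv2.imdecode detects the real format from the bytes.
    ctype = mimetypes.guess_type(image_path)[0] or "image/png"
    filename = image_path.rsplit("/", 1)[-1]
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="file"; filename="{filename}"\r\n'
        f"Content-Type: {ctype}\r\n\r\n"
    ).encode()
    body = head + image_bytes + f"\r\n--{boundary}--\r\n".encode()
    query = urlencode({"fs": fs, "sig_len": sig_len})
    url = f"{base_url.rstrip('/')}/digitize?{query}"
    return Request(url, data=body, method="POST",
                   headers={"Content-Type": f"multipart/form-data; boundary={boundary}"})


# Usage (performs a real network call):
#   with open("ecg.jpg", "rb") as f:
#       req = build_digitize_request("https://<your-ngrok-url>", "ecg.jpg", f.read())
#   with urlopen(req) as resp:
#       data = json.load(resp)
```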

**Response** — JSON

```json
{
  "steps": [
    { "title": "...", "description": "...", "image": "data:image/jpeg;base64,..." },
    ...
  ],
  "leads": {
    "I":   [0.012, -0.003, ...],
    "II":  [0.021,  0.008, ...],
    ...
  },
  "lead_boxes": {
    "I":   { "x1": 0.054, "y1": 0.264, "x2": 0.279, "y2": 0.498 },
    ...
  },
  "metrics": { ... },
  "sampling_hz": 500
}
```
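
Each lead is a plain list of mV samples spaced `1 / sampling_hz` seconds apart, so a time axis and per-lead duration follow directly. With the defaults, lead II's 5000 samples at 500 Hz span 10 s. A minimal sketch, assuming `data` is the parsed response; `lead_time_axis` and `lead_durations` are hypothetical helpers.

```python
def lead_time_axis(data: dict, lead: str) -> list:
    """Return sample times in seconds for one lead of a /digitize response."""
    fs = data["sampling_hz"]
    return [i / fs for i in range(len(data["leads"][lead]))]


def lead_durations(data: dict) -> dict:
    """Duration in seconds of every lead: number of samples / sampling rate."""
    fs = data["sampling_hz"]
    return {name: len(samples) / fs for name, samples in data["leads"].items()}
```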

**`steps` index**

| i | stage |
|---|-------|
| 0 | Raw input image |
| 1 | CLAHE contrast enhancement |
| 2 | Source-aware preprocessing |
| 3 | Stage 0 — rotation + perspective warp |
| 4 | Stage 1 — grid alignment |
| 5 | Stage 2 — waveform segmentation heatmap |
| 6 | Extraction overlay (extracted signals redrawn on the Stage 1 image) |
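
Each `steps[i].image` is a base64 data URI (`data:image/jpeg;base64,...` in the response example). Below is a sketch for decoding one to raw bytes so it can be saved for inspection. It assumes that `data:<mime>;base64,<payload>` shape, and `decode_step_image` is a hypothetical helper.

```python
import base64
import binascii


def decode_step_image(data_uri: str) -> tuple:
    """Split a base64 data URI into (mime_type, raw_bytes)."""
    header, sep, payload = data_uri.partition(",")
    if not sep or not header.startswith("data:") or not header.endswith(";base64"):
        raise ValueError("expected a data:<mime>;base64,<payload> URI")
    mime = header[len("data:"):-len(";base64")]
    try:
        return mime, base64.b64decode(payload, validate=True)
    except binascii.Error as exc:
        raise ValueError(f"invalid base64 payload: {exc}") from exc


# Example: save the Stage 1 image (steps[4]) that lead_boxes refer to
#   mime, raw = decode_step_image(data["steps"][4]["image"])
#   with open("stage1.jpg" if mime == "image/jpeg" else "stage1.png", "wb") as f:
#       f.write(raw)
```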

**`lead_boxes` coordinate system**

All values are normalized to `[0, 1]` relative to the **Stage 1 image** (`steps[4].image`).  
To get pixel coordinates for a rendered image of size `W × H`:

```javascript
const { x1, y1, x2, y2 } = data.lead_boxes['I'];
const box = {
  left:   x1 * W,
  top:    y1 * H,
  width:  (x2 - x1) * W,
  height: (y2 - y1) * H,
};
```
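
On the Python side, the same conversion is useful for cropping one lead out of the decoded Stage 1 image. This sketch also rounds outward to whole pixels and clamps to the image bounds. `box_to_pixels` is hypothetical, and the commented crop assumes a NumPy `H × W × 3` array such as OpenCV produces.

```python
import math


def box_to_pixels(box: dict, width: int, height: int) -> tuple:
    """Convert a normalized lead box to integer (left, top, right, bottom) pixels.

    floor/ceil make the crop cover every pixel the box touches; clamping
    guards against values that fall slightly outside [0, 1].
    """
    left = max(0, math.floor(box["x1"] * width))
    top = max(0, math.floor(box["y1"] * height))
    right = min(width, math.ceil(box["x2"] * width))
    bottom = min(height, math.ceil(box["y2"] * height))
    return left, top, right, bottom


# Crop lead I from a decoded Stage 1 image (NumPy array, rows = y, cols = x):
#   l, t, r, b = box_to_pixels(data["lead_boxes"]["I"], img.shape[1], img.shape[0])
#   lead_i = img[t:b, l:r]
```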

### Quick test with `curl`

```bash
curl -sS -X POST "https://<your-ngrok-url>/digitize?fs=500&sig_len=5000" \
     -F "file=@/path/to/ecg.jpg" \
     | python3 -c "import sys,json; d=json.load(sys.stdin); print(list(d['leads'].keys()))"
```
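
The leads can also be written to CSV with the standard library, for example as a starting point for export. Leads may differ in length (lead II is the full-length trace), so shorter columns are padded with empty cells. The time column simply counts from each lead's first returned sample; it is not a shared clock across leads printed at different positions on the strip. `leads_to_csv` is a hypothetical helper.

```python
import csv
import io


def leads_to_csv(data: dict) -> str:
    """Render the leads of a /digitize response as CSV text.

    One row per sample index, a time column in seconds, one column per lead.
    Leads shorter than the longest one are padded with empty cells.
    """
    fs = data["sampling_hz"]
    names = list(data["leads"])
    n_rows = max((len(v) for v in data["leads"].values()), default=0)
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(["time_s"] + [f"{name}_mV" for name in names])
    for i in range(n_rows):
        row = [i / fs]
        for name in names:
            samples = data["leads"][name]
            row.append(samples[i] if i < len(samples) else "")
        writer.writerow(row)
    return buf.getvalue()
```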
