# Drift vs Lipid Dynamics – Final Runnable Pipeline

This notebook contains the **complete, runnable pipeline**:
1. Drift detection using boundary alignment residuals
2. PCA on contour features (only if drift is not dominant)

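You only need to provide your existing AFM helper functions — in particular `contour_stats(mask, pixel_size=None)`, which the processing cell calls but this notebook does not define.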

## Imports

In [17]:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from skimage.morphology import binary_erosion
from skimage.registration import phase_cross_correlation
from scipy.ndimage import shift as ndi_shift

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

import gwyfile
import numpy as np
import matplotlib.pyplot as plt
from skimage import io, color
from skimage.feature import match_template
from skimage.filters import threshold_otsu
from skimage.morphology import remove_small_objects, remove_small_holes, binary_opening, binary_closing, disk
import numpy as np
from skimage.measure import find_contours
from scipy.spatial import ConvexHull
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

### Loading in the data

In [18]:
def Loading_data(data_count, plot=True, path=None):
    """
    Load the '/0/data' channel of one Gwyddion (.gwy) file as a 2D array.

    Parameters
    ----------
    data_count : str or int
        File stem; the file that is read is ``<path>/<data_count>.gwy``.
    plot : bool
        If True, show the raw image in grayscale with a colorbar.
    path : str, optional
        Directory that contains the .gwy files. When omitted, the original
        hard-coded raw-data folder is used so existing calls keep working.

    Returns
    -------
    raw_height : 2D numpy array
        The channel data reshaped to (yres, xres).
    """
    import os  # local import: only needed to join the directory and file name

    if path is None:
        # Historical default; pass `path=` to run the notebook on another machine.
        path = 'C:/Users/yevhe/Desktop/Hackathon/Feature detection/Domain evolution data/Raw data/'
    image = '%s.gwy' % data_count

    obj = gwyfile.load(os.path.join(path, image))
    field = obj['/0/data']
    channel = field['data']
    yres = field['yres']
    xres = field['xres']
    # The channel is reshaped row-major into (rows, columns).
    raw_height = channel.reshape((yres, xres))

    if plot:
        im = plt.imshow(raw_height, cmap='gray')
        plt.title("Raw data file")
        plt.colorbar(im, fraction=0.046)
        plt.tight_layout()
        plt.show()
    return raw_height

# Processing the data

### Level data by mean plane subtraction

In [19]:
def mean_plane_subtraction(Z, mask=None, plot=True):
    """
    Remove the least-squares plane z = a*x + b*y + c from a height image.

    Parameters
    ----------
    Z : 2D numpy array
        Height image; x is the column index and y is the row index.
    mask : 2D boolean array, optional
        Pixels that are True take part in the plane fit. All pixels are
        used when mask is None. The plane is always subtracted everywhere.
    plot : bool
        If True, show the original, the fitted plane and the levelled result.

    Returns
    -------
    Z_level : 2D numpy array
        Z (as float) minus the fitted plane.
    plane : 2D numpy array
        The fitted plane evaluated on the full pixel grid.
    (a, b, c) : tuple
        Plane coefficients.
    """
    heights = Z.astype(float)
    row_idx, col_idx = np.indices(heights.shape)

    # Pick the samples that constrain the fit.
    if mask is not None:
        keep = mask.astype(bool)
        xs = col_idx[keep].ravel()
        ys = row_idx[keep].ravel()
        zs = heights[keep].ravel()
    else:
        xs = col_idx.ravel()
        ys = row_idx.ravel()
        zs = heights.ravel()

    # Design matrix columns are [x, y, 1] so the solution is [a, b, c].
    design = np.column_stack([xs, ys, np.ones_like(xs)])
    solution = np.linalg.lstsq(design, zs, rcond=None)
    a, b, c = solution[0]

    plane = a * col_idx + b * row_idx + c
    Z_level = heights - plane

    if plot:
        panels = (
            (heights, "Original"),
            (plane, "Fitted mean plane"),
            (Z_level, "Levelled (mean plane sub.)"),
        )
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        drawn = []
        for ax, (data, title) in zip(axs, panels):
            drawn.append(ax.imshow(data, cmap="viridis"))
            ax.set_title(title)
            ax.set_xticks([])
            ax.set_yticks([])
        for ax, im in zip(axs, drawn):
            plt.colorbar(im, ax=ax, fraction=0.046)
        plt.tight_layout()
        plt.show()

    return Z_level, plane, (a, b, c)

#Z_level, plane, coeffs = mean_plane_subtraction(raw_height, plot=True)
#print("Plane coeffs (a,b,c):", coeffs)


### Flattening the image

In [20]:
def poly_flatten_rows(h, order=1,plot=True):
    """
    Flatten an image line by line by removing a polynomial baseline per row.

    Parameters
    ----------
    h : 2D array-like
        Height image; each row is fitted independently against its column index.
    order : int
        Degree of the polynomial fitted to every row.
    plot : bool
        If True, show the flattened image in grayscale with a colorbar.

    Returns
    -------
    out : 2D numpy array (float)
        Row-wise residuals, i.e. each row minus its fitted polynomial.
    """
    image = np.asarray(h)
    _, n_cols = image.shape
    columns = np.arange(n_cols)
    out = np.empty_like(image, dtype=float)

    for r, line in enumerate(image):
        fit = np.polyfit(columns, line, order)
        out[r, :] = line - np.polyval(fit, columns)

    if plot:
        shown = plt.imshow(out, cmap='gray')
        plt.title("Polynomial flattened image")
        plt.colorbar(shown, fraction=0.046)
        plt.tight_layout()
        plt.show()
    return out


  # “horizontal” style


### Level data by fitting plane through 3 points

In [21]:
import numpy as np
import matplotlib.pyplot as plt

def level_by_3_points(Z, p1, p2, p3, show_plane=False,plot=True):
    """
    Subtract the plane passing through three reference pixels of a height image.

    Parameters
    ----------
    Z : 2D numpy array
        Height image.
    p1, p2, p3 : (x, y)
        Reference points in pixel coordinates (x = column, y = row).
        Heights are sampled at the rounded coordinates, while the plane
        equations use the coordinates exactly as given.
    show_plane : bool
        If True (and plot is True), add a third panel with the fitted plane.
    plot : bool
        If True, show the original image with the reference points marked
        and the levelled image.

    Returns
    -------
    Z_level : 2D numpy array
        Z minus the fitted plane.
    plane : 2D numpy array
        The plane z = a*x + b*y + c evaluated on the pixel grid.
    """
    # Split the three (x, y) pairs into coordinate sequences.
    xs, ys = zip(*[(px, py) for px, py in (p1, p2, p3)])

    # Sample the image at the nearest pixel of each reference point.
    heights = [Z[int(round(py)), int(round(px))] for px, py in zip(xs, ys)]

    # Three equations a*x + b*y + c = z determine the plane exactly.
    system = np.array([[px, py, 1] for px, py in zip(xs, ys)], dtype=float)
    rhs = np.array(heights, dtype=float)
    a, b_coef, c = np.linalg.solve(system, rhs)

    row_idx, col_idx = np.indices(Z.shape)
    plane = a * col_idx + b_coef * row_idx + c
    Z_level = Z - plane

    if plot:
        n_panels = 2 + int(show_plane)
        fig, axs = plt.subplots(1, n_panels, figsize=(12, 4))

        # Panel 0: input image with the reference points overlaid.
        original_ax = axs[0]
        shown_original = original_ax.imshow(Z, cmap='viridis')
        original_ax.scatter(list(xs), list(ys),
                            c='red', s=60, marker='x', label='Reference points')
        original_ax.set_title("Original image")
        original_ax.legend()
        plt.colorbar(shown_original, ax=original_ax, fraction=0.046)

        # Panel 1: levelled result.
        levelled_ax = axs[1]
        shown_levelled = levelled_ax.imshow(Z_level, cmap='gray')
        levelled_ax.set_title("Levelled image (3-point plane)")
        plt.colorbar(shown_levelled, ax=levelled_ax, fraction=0.046)

        # Panel 2 (optional): the plane itself.
        if show_plane:
            plane_ax = axs[2]
            shown_plane = plane_ax.imshow(plane, cmap='viridis')
            plane_ax.set_title("Fitted plane")
            plt.colorbar(shown_plane, ax=plane_ax, fraction=0.046)

        for panel in axs:
            panel.set_xticks([])
            panel.set_yticks([])

        plt.tight_layout()
        plt.show()

    return Z_level, plane




### Converting the image to binary

### Cropping data

In [22]:
# --- crop helpers (keep yours if already defined) ---
def center_crop(img, crop_size=80):
    """
    Return the centred crop_size x crop_size window of a 2D image.

    The row and column offsets are computed independently from the image
    height and width, using floor division.
    """
    n_rows, n_cols = img.shape
    top = (n_rows - crop_size) // 2
    left = (n_cols - crop_size) // 2
    window = (slice(top, top + crop_size), slice(left, left + crop_size))
    return img[window]

def random_crop(img, crop_size=80, rng=None):
    """
    Cut a crop_size x crop_size window at a uniformly random position.

    Parameters
    ----------
    img : 2D numpy array
    crop_size : int
        Side length of the square window.
    rng : numpy.random.Generator, optional
        Source of randomness; a fresh default generator is created when
        omitted. The row offset is drawn first, then the column offset.

    Returns
    -------
    crop : 2D numpy array
        The selected window.
    (r0, c0) : tuple of int
        Top-left corner of the window in the original image.
    """
    if rng is None:
        rng = np.random.default_rng()
    n_rows, n_cols = img.shape

    # Upper bounds are exclusive, hence the +1 so the last valid offset is reachable.
    top = rng.integers(0, n_rows - crop_size + 1)
    left = rng.integers(0, n_cols - crop_size + 1)

    crop = img[top:top + crop_size, left:left + crop_size]
    return crop, (int(top), int(left))


In [23]:
def plot_binary(Z_level,plot=True):
    """
    Binarise a levelled height image with an Otsu threshold and clean the mask.

    Steps: the Otsu threshold is computed on the pixel values strictly
    between the 0.5th and 99.5th percentiles, the full image is thresholded
    with it, and the mask is cleaned by removing small objects and holes
    followed by a morphological opening and closing.

    Parameters
    ----------
    Z_level : 2D numpy array
        Levelled height image.
    plot : bool
        If True, show the clipped histogram with the threshold, the image
        with the mask boundary, and the cleaned mask. The axis labels assume
        heights in nm - TODO confirm units.

    Returns
    -------
    mask : 2D boolean array
        Cleaned binary mask (True where the height exceeds the threshold).
    thr : float
        The Otsu threshold.
    """
    # Flattened copy used only for the histogram / threshold estimate.
    values = Z_level.astype(float).ravel()

    # Discard the extreme tails so outliers do not bias the threshold.
    p_low, p_high = np.percentile(values, [0.5, 99.5])
    in_range = (values > p_low) & (values < p_high)
    h_clip = values[in_range]

    thr = threshold_otsu(h_clip)

    # The threshold is applied to every pixel, including the clipped tails.
    mask = Z_level > thr

    # Clean-up sequence; the order of these operations matters.
    mask = remove_small_objects(mask, min_size=200)
    mask = remove_small_holes(mask, area_threshold=200)
    mask = binary_opening(mask, disk(1))
    mask = binary_closing(mask, disk(2))

    if plot:
        fig, axes = plt.subplots(
            1, 3, figsize=(14, 4),
            gridspec_kw=dict(width_ratios=[1, 1.2, 1])
        )
        hist_ax, overlay_ax, mask_ax = axes

        # Left: clipped height distribution with the threshold marked.
        hist_ax.hist(h_clip, bins=200)
        hist_ax.axvline(thr, color='red', linestyle='--', linewidth=2,
                        label=f'Otsu = {thr:.3f} nm')
        hist_ax.set_xlabel("Height (nm)")
        hist_ax.set_ylabel("Count")
        hist_ax.legend()

        # Middle: image with the mask outline drawn on top.
        overlay_ax.imshow(Z_level, cmap='gray')
        overlay_ax.contour(mask, levels=[0.5], colors='r', linewidths=1)
        overlay_ax.set_title("Levelled image + mask boundary")
        overlay_ax.axis('off')

        # Right: the cleaned binary mask.
        mask_ax.imshow(mask, cmap='gray')
        mask_ax.set_title("Binary mask (cleaned)")
        mask_ax.axis('off')

        fig.tight_layout()
        plt.show()

    return mask, thr


# Statistics on the Data

# Data analysis

### Executing the data processing

In [24]:
import numpy as np
import pandas as pd


# --- your file selection ---
data_list = ['1','2','3','4','5','6','7','8']

# Fail fast with an actionable message: contour_stats is one of the AFM helper
# functions you are expected to supply; it is not defined in this notebook.
if "contour_stats" not in globals():
    raise NameError(
        "name 'contour_stats' is not defined - define or import your "
        "contour_stats(mask, pixel_size=None) helper before running this cell"
    )

rows = []
rng = np.random.default_rng(0)       # reproducible random crops
n_random_per_frame = 5               # set to 1 if you only want one random crop each
crop_size = 80                       # side length (px) of every crop

for data_count in data_list:
    # Same preprocessing chain for every frame: plane -> row offsets -> 3-point plane.
    raw_height = Loading_data(data_count, plot=False)
    Z_level, plane, coeffs = mean_plane_subtraction(raw_height, plot=False)
    flattened = poly_flatten_rows(Z_level, order=0, plot=False)
    leveled_3p, plane = level_by_3_points(
        flattened, p1=(10, 20), p2=(25, 50), p3=(30, 90),
        show_plane=True, plot=False
    )

    # ----------------
    # CENTER crop stats
    # ----------------
    center_img = center_crop(leveled_3p, crop_size=crop_size)
    mask_c = plot_binary(center_img, plot=False)[0]
    stats_c, _ = contour_stats(mask_c, pixel_size=None)

    if stats_c.get("ok", False):
        # Record the real crop origin instead of assuming a 100x100 frame;
        # for 100x100 input this still evaluates to (10, 10).
        frame_h, frame_w = leveled_3p.shape
        stats_c["count"] = data_count
        stats_c["crop_type"] = "center"
        stats_c["rep"] = 0                  # always 0 for center
        stats_c["crop_r0"] = (frame_h - crop_size) // 2
        stats_c["crop_c0"] = (frame_w - crop_size) // 2
        rows.append(stats_c)

    # ----------------
    # RANDOM crop stats (repeatable)
    # ----------------
    for rep in range(n_random_per_frame):
        rand_img, (r0, c0) = random_crop(leveled_3p, crop_size=crop_size, rng=rng)
        mask_r = plot_binary(rand_img, plot=False)[0]
        stats_r, _ = contour_stats(mask_r, pixel_size=None)

        # Skip crops for which the helper reports no usable contour.
        if not stats_r.get("ok", False):
            continue

        stats_r["count"] = data_count
        stats_r["crop_type"] = "random"
        stats_r["rep"] = rep
        stats_r["crop_r0"] = r0
        stats_r["crop_c0"] = c0
        rows.append(stats_r)

# Without this guard an empty table would fail inside set_index with an
# uninformative KeyError.
if not rows:
    raise ValueError(
        "contour_stats returned ok=False for every crop; there is nothing to tabulate"
    )

df = pd.DataFrame(rows)

# Multi-index makes splitting super easy later
df = df.set_index(["count", "crop_type", "rep"]).sort_index()

print(df.tail())



NameError: name 'contour_stats' is not defined

## Alignment and residual feature functions

In [None]:

def mask_to_boundary(mask):
    """
    Return the boundary pixels of a binary mask as a float image.

    A pixel belongs to the boundary when it is set in the mask but not in
    the eroded mask. The result holds 1.0 on the boundary and 0.0 elsewhere.
    """
    filled = mask.astype(bool)
    interior = binary_erosion(filled)
    edge = np.logical_xor(filled, interior)
    return edge.astype(float)

def estimate_shift(ref_img, mov_img, upsample_factor=10):
    """
    Estimate the translation between two images by phase cross-correlation.

    Returns
    -------
    shift_rc : array
        Shift as returned by phase_cross_correlation (row, column order).
    error : float
        The error value returned by phase_cross_correlation.
    """
    # The third return value is not used by this pipeline.
    shift_rc, corr_error, _unused = phase_cross_correlation(
        ref_img,
        mov_img,
        upsample_factor=upsample_factor,
    )
    error = float(corr_error)
    return shift_rc, error

def residual_after_translation(mask_ref, mask_mov):
    """
    Measure how well two mask boundaries agree once a pure translation is removed.

    The boundaries of both masks are extracted, the shift between them is
    estimated, the moving boundary is shifted by it (nearest-neighbour,
    zero padding), and the overlap of the two boundaries is scored.

    Returns
    -------
    dict with float values:
        shift_row, shift_col : estimated shift components
        phasecorr_error      : error reported by the shift estimate
        boundary_mismatch    : fraction of all pixels where the boundaries differ
        boundary_iou         : intersection over union of the boundaries
                               (1.0 when both boundaries are empty)
    """
    boundary_ref = mask_to_boundary(mask_ref)
    boundary_mov = mask_to_boundary(mask_mov)

    shift_rc, pc_error = estimate_shift(boundary_ref, boundary_mov)

    # order=0 keeps the shifted boundary binary; pixels shifted in from outside are 0.
    shifted = ndi_shift(
        boundary_mov, shift=shift_rc, order=0, mode="constant", cval=0.0
    )

    on_ref = boundary_ref > 0.5
    on_mov = shifted > 0.5

    differing = np.logical_xor(on_ref, on_mov)
    n_both = np.count_nonzero(np.logical_and(on_ref, on_mov))
    n_either = np.count_nonzero(np.logical_or(on_ref, on_mov))

    if n_either:
        iou = n_both / n_either
    else:
        iou = 1.0

    result = {}
    result["shift_row"] = float(shift_rc[0])
    result["shift_col"] = float(shift_rc[1])
    result["phasecorr_error"] = float(pc_error)
    result["boundary_mismatch"] = float(differing.mean())
    result["boundary_iou"] = float(iou)
    return result


## Crop utility

In [None]:

def center_crop(img, crop_size=80):
    """
    Return the centred crop_size x crop_size window of a 2D image.

    The row and column offsets are computed separately from the image
    height and width, so the window is centred for non-square images too.
    Reusing the row offset for the columns is only correct for square input.
    """
    h, w = img.shape
    r0 = (h - crop_size) // 2
    c0 = (w - crop_size) // 2
    return img[r0:r0 + crop_size, c0:c0 + crop_size]


## Build masks from AFM frames

In [None]:

def build_center_masks(data_list, crop_size=80,
                       ref_points=((10, 20), (25, 50), (30, 90))):
    """
    Build one cleaned binary mask per AFM frame from its centre crop.

    Each frame is loaded, levelled (mean plane, per-row offset, 3-point
    plane), centre-cropped and binarised, all with plotting disabled.

    Parameters
    ----------
    data_list : iterable
        Frame identifiers passed to Loading_data.
    crop_size : int
        Side length of the centre crop. Defaults to 80.
    ref_points : sequence of three (x, y) pairs
        Reference pixels for the 3-point plane. Defaults to the points
        (10, 20), (25, 50), (30, 90).

    Returns
    -------
    masks : list of 2D boolean arrays, in the order of data_list.
    """
    p1, p2, p3 = ref_points
    masks = []
    for idx in data_list:
        raw = Loading_data(idx, plot=False)
        Z, *_ = mean_plane_subtraction(raw, plot=False)
        Z = poly_flatten_rows(Z, order=0, plot=False)
        Z, _ = level_by_3_points(Z, p1, p2, p3, plot=False)

        crop = center_crop(Z, crop_size)
        mask = plot_binary(crop, plot=False)[0]
        masks.append(mask)
    return masks


## Drift feature extraction

In [None]:

def drift_features_from_masks(masks):
    """
    Compute alignment-residual features for every pair of consecutive masks.

    Parameters
    ----------
    masks : sequence of 2D boolean arrays
        One mask per frame, in acquisition order.

    Returns
    -------
    pandas.DataFrame indexed by "frame"
        Row i holds the features of masks[i] aligned against masks[i-1].
        With fewer than two masks the frame is empty but still has the
        feature columns and the "frame" index, so downstream code that
        selects those columns keeps working.
    """
    # Keys produced by residual_after_translation; needed for the empty case.
    feature_columns = [
        "shift_row",
        "shift_col",
        "phasecorr_error",
        "boundary_mismatch",
        "boundary_iou",
    ]

    rows = []
    for i in range(1, len(masks)):
        feats = residual_after_translation(masks[i-1], masks[i])
        feats["frame"] = i
        rows.append(feats)

    if not rows:
        # set_index("frame") on a column-less frame would raise KeyError.
        return pd.DataFrame(
            columns=feature_columns,
            index=pd.Index([], dtype="int64", name="frame"),
            dtype=float,
        )
    return pd.DataFrame(rows).set_index("frame")


## Drift diagnostics

In [None]:

def plot_drift_diagnostics(df):
    """
    Show two diagnostic figures for a drift-feature table.

    Figure 1 plots the estimated row and column shift against the frame
    index. Figure 2 scatters boundary mismatch against boundary IoU.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns "shift_row", "shift_col", "boundary_iou"
        and "boundary_mismatch"; the index is used as the frame axis.
    """
    # Per-frame shift components.
    shift_fig, shift_ax = plt.subplots(figsize=(6,3))
    for column, label in (("shift_row", "Δrow"), ("shift_col", "Δcol")):
        shift_ax.plot(df.index, df[column], "-o", label=label)
    shift_ax.legend()
    shift_ax.set_xlabel("Frame")
    shift_ax.set_ylabel("Shift (px)")
    shift_ax.set_title("Estimated drift per frame")
    shift_fig.tight_layout()
    plt.show()

    # Residual space after alignment.
    resid_fig, resid_ax = plt.subplots(figsize=(4,4))
    resid_ax.scatter(df["boundary_iou"], df["boundary_mismatch"], s=60, edgecolor="k")
    resid_ax.set_xlabel("Boundary IoU")
    resid_ax.set_ylabel("Boundary mismatch")
    resid_ax.set_title("Alignment residual space")
    resid_fig.tight_layout()
    plt.show()


## Drift gate

In [None]:

def drift_gate(df, iou_thresh=0.9, mismatch_thresh=0.08, frac_thresh=0.6):
    """
    Decide whether translation-like frames dominate a drift-feature table.

    A frame counts as drift-like when its boundary IoU is at least
    iou_thresh and its boundary mismatch is at most mismatch_thresh.

    Returns
    -------
    (is_drift, frac)
        frac is the fraction of drift-like frames (0.0 for an empty table);
        is_drift is True when frac reaches frac_thresh.
    """
    overlaps_well = df["boundary_iou"].ge(iou_thresh)
    differs_little = df["boundary_mismatch"].le(mismatch_thresh)
    drift_like = overlaps_well & differs_little

    if len(df):
        frac = drift_like.mean()
    else:
        frac = 0.0
    return frac >= frac_thresh, frac


## PCA analysis

In [None]:

def run_pca(df, index_col="count"):
    """
    Standardise the numeric columns of a feature table and run a full PCA.

    Parameters
    ----------
    df : pandas.DataFrame
        Feature table. Every numeric column is used as a PCA feature, so
        numeric bookkeeping columns, if present, are included as well.
    index_col : str
        If this name is a column of df it is moved to the index first and
        therefore excluded from the features.

    Returns
    -------
    pca : fitted sklearn PCA object
    loadings : pandas.DataFrame
        Component loadings, one row per feature and one column per
        component ("PC1", "PC2", ...).

    Two figures are shown: cumulative explained variance, and the scores
    on the first two components.
    """
    table = df.copy()
    if index_col in table.columns:
        table = table.set_index(index_col)

    features = table.select_dtypes(include=[np.number])
    standardized = StandardScaler().fit_transform(features)

    pca = PCA()
    scores = pca.fit_transform(standardized)

    component_names = [f"PC{k}" for k in range(1, pca.n_components_ + 1)]
    loadings = pd.DataFrame(
        pca.components_.T,
        index=features.columns,
        columns=component_names,
    )

    # Cumulative explained variance across components.
    var_fig, var_ax = plt.subplots(figsize=(5,3))
    var_ax.plot(np.cumsum(pca.explained_variance_ratio_), marker='o')
    var_ax.set_xlabel("Components")
    var_ax.set_ylabel("Cumulative explained variance")
    var_fig.tight_layout()
    plt.show()

    # Samples projected on the first two components.
    score_fig, score_ax = plt.subplots(figsize=(5,4))
    score_ax.scatter(scores[:,0], scores[:,1])
    score_ax.set_xlabel("PC1")
    score_ax.set_ylabel("PC2")
    score_fig.tight_layout()
    plt.show()

    return pca, loadings


## Run full pipeline

In [None]:

# data_list = ['1','2','3','4','5','6','7','8']
# df = <your contour feature dataframe>

# Stage 1: drift detection from consecutive centre-crop masks.
masks = build_center_masks(data_list)
df_drift = drift_features_from_masks(masks)

plot_drift_diagnostics(df_drift)

drift_detected, frac = drift_gate(df_drift)
print(f"Drift detected: {drift_detected} (fraction={frac:.2f})")

# Stage 2: PCA on the contour features, only when drift does not dominate.
if not drift_detected:
    print("Running PCA...")
    # crop_r0 / crop_c0 record where each crop was cut from the frame. They are
    # numeric bookkeeping, not contour features, and run_pca would otherwise
    # treat them as features because it uses every numeric column.
    pca_input = df.drop(columns=["crop_r0", "crop_c0"], errors="ignore")
    pca, loadings = run_pca(pca_input)
    for pc in loadings.columns[:3]:
        print(f"\nTop contributors to {pc}:")
        print(loadings[pc].abs().sort_values(ascending=False).head(8))
else:
    print("PCA skipped due to drift dominance.")
