# Near-Death Classification — Featureless Model (scikit-learn)

We build a **featureless model** that uses raw behavioral time series
(X, Y, Speed, turning_angle) over 900 frames (30 minutes) for each segment.

Each segment is flattened into a single high-dimensional vector and fed into
a standard classifier (GradientBoosting), without using hand-crafted aging features.

Label:
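- `1` = segment is **close to death** (time to death, measured from the segment's last frame, ≤ `CLOSE_THRESHOLD_FRAMES` = 20 segments × 900 frames = 18,000 frames)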
- `0` = otherwise

Train/val/test splits are grouped by worm to avoid leakage.


In [2]:
import os
import re
import glob
import numpy as np
import pandas as pd

from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.ensemble import GradientBoostingClassifier

print("✓ Imports loaded.")


✓ Imports loaded.


Define where the data live: the lifespan summary (`lifespan_summary.csv`) and the preprocessed segments.

- Fix the segment length at 900 frames.
- `FEATURE_COLS` = the raw columns we want to use.
- Define the near-death threshold: if at most 20 segments (20 × 900 = 18,000 frames) remain before death, then `close_to_death = 1`.

In [3]:
DATA_DIR = "TERBINAFINE"
LIFESPAN_FILE = os.path.join(DATA_DIR, "lifespan_summary.csv")
SEGMENTS_DIR = "preprocessed_data/segments"

SEGMENT_LENGTH = 900
FEATURE_COLS = ["X", "Y", "Speed", "turning_angle"]

PROXIMITY_SEGMENTS = 20
CLOSE_THRESHOLD_FRAMES = PROXIMITY_SEGMENTS * SEGMENT_LENGTH

print("DATA_DIR:", DATA_DIR)
print("LIFESPAN_FILE:", LIFESPAN_FILE)
print("SEGMENTS_DIR:", SEGMENTS_DIR)
print("CLOSE_THRESHOLD_FRAMES:", CLOSE_THRESHOLD_FRAMES)


DATA_DIR: TERBINAFINE
LIFESPAN_FILE: TERBINAFINE/lifespan_summary.csv
SEGMENTS_DIR: preprocessed_data/segments
CLOSE_THRESHOLD_FRAMES: 18000


Helpers that turn the different filename formats into one consistent identifier for joining segments to lifespans (not sure this is actually necessary: only `normalize_filename` is used below).

In [4]:
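from typing import Optional

def normalize_filename(s: str) -> str:
    """Strip whitespace, the file extension and any leading '/' from a filename."""
    s = s.strip()
    s = os.path.splitext(s)[0]
    s = s.lstrip("/")
    return s

def extract_worm_id_from_source(source_file: str) -> Optional[str]:
    """Extract an ID such as '20240924_piworm09_1' from a source filename.

    Falls back to the normalized filename when the pattern is not found.
    """
    if source_file is None or not isinstance(source_file, str):
        return None
    base = source_file.strip()
    base = os.path.splitext(base)[0]
    m = re.search(r"(\d{8}_piworm\d+_\d+)", base)
    if m:
        return m.group(1)
    base = base.lstrip("/")
    return base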


- load the file with the lifespan of each worm
- automatically find the filename column and the lifespan-in-frames column
- normalize the filename and create a `Worm_ID` column
- build a dictionary `lifespan_map[Worm_ID] = lifespan_in_frames`

This lets us look up, for each segment, the total lifespan of its worm.

In [5]:
lifespan_df = pd.read_csv(LIFESPAN_FILE)
lifespan_df.columns = lifespan_df.columns.str.strip()
colmap = {c.lower(): c for c in lifespan_df.columns}

filename_col = None
life_col = None
for key, col in colmap.items():
    if "filename" in key:
        filename_col = col
    if "lifespan" in key and "frames" in key:
        life_col = col

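if filename_col is None or life_col is None:
    raise KeyError(f"Could not detect filename/lifespan columns in {list(lifespan_df.columns)}")

print("Filename col:", filename_col)
print("Lifespan col:", life_col)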

lifespan_df["Worm_ID"] = lifespan_df[filename_col].astype(str).apply(normalize_filename)
lifespan_df[life_col] = lifespan_df[life_col].astype(float)
lifespan_map = dict(zip(lifespan_df["Worm_ID"], lifespan_df[life_col]))

list(lifespan_map.items())[:5]


Filename col: Filename
Lifespan col: LifespanInFrames


[('20240924_piworm09_1', 49500.0),
 ('20240924_piworm09_2', 63000.0),
 ('20240924_piworm09_3', 57600.0),
 ('20240924_piworm09_4', 50400.0),
 ('20240924_piworm09_5', 48900.0)]

Sanity check (I think this cell can be deleted):

In [10]:
print("Example Worm_ID keys from lifespan_map:")
for k in list(lifespan_map.keys())[:10]:
    print("  ", k)


Example Worm_ID keys from lifespan_map:
   20240924_piworm09_1
   20240924_piworm09_2
   20240924_piworm09_3
   20240924_piworm09_4
   20240924_piworm09_5
   20240924_piworm09_6
   20240924_piworm10_1
   20240924_piworm10_2
   20240924_piworm10_3
   20240924_piworm10_4


Sanity check (I think this cell can be deleted):

In [11]:
# Inspect one segment to see the 'source_file' field (glob, os and pandas are imported in the first cell)

one_seg = glob.glob(os.path.join(SEGMENTS_DIR, "*.csv"))[0]
print("Example segment file:", one_seg)

df_example = pd.read_csv(one_seg)
print("Columns:", df_example.columns.tolist())
print("source_file example:", df_example["source_file"].iloc[0])


Example segment file: preprocessed_data/segments/coordinates_highestspeed_20250205_9_3_with_time_speed-fragment4.0-preprocessed.csv
Columns: ['GlobalFrame', 'Timestamp', 'Speed', 'X', 'Y', 'condition', 'source_file', 'Segment_index', 'turning_angle', 'worm_id', 'Segment']
source_file example: coordinates_highestspeed_20250205_9_3_with_time_speed.csv
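The regex in `extract_worm_id_from_source` expects IDs such as `20240924_piworm09_1`, but the `source_file` value printed above has a different layout. The sketch below is a standalone copy of the same pattern (`match_worm_id` is a hypothetical name); it matches the lifespan-style ID but not the `source_file` string, which may be why the loading loop below keys on the `worm_id` column instead.

```python
import re

# Same pattern as in extract_worm_id_from_source
WORM_ID_PATTERN = re.compile(r"(\d{8}_piworm\d+_\d+)")

def match_worm_id(name):
    """Return the piworm-style ID found in `name`, or None if there is none."""
    m = WORM_ID_PATTERN.search(name)
    return m.group(1) if m else None

lifespan_style = "/20240924_piworm09_1.csv"
source_style = "coordinates_highestspeed_20250205_9_3_with_time_speed.csv"

print(match_worm_id(lifespan_style))  # 20240924_piworm09_1
print(match_worm_id(source_style))    # None
```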


Load all segments and build `all_segments`

- iterate over all segment files (8150 here)
- for each segment:
    - read the CSV
    - get the `worm_id`
    - look up the worm's lifespan in `lifespan_map`
    - sort by `GlobalFrame`
    - keep the columns X, Y, Speed, turning_angle
    - truncate or zero-pad to 900 frames
    - compute the time remaining before death, measured from the segment's last real frame
    - create the binary label: 1 = near death, 0 = not near death
    - store everything in the `all_segments` list

At the end, it prints a summary with the number of segments collected and the number of files skipped for each reason (e.g. worm not found in `lifespan_map`).
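The padding/truncation step can be isolated into a small function, sketched below (`pad_or_truncate` is a hypothetical helper; the loop inlines the same logic). It also returns the number of real frames, which is what the label computation uses to locate the segment's last frame. One caveat: once flattened, zero-padded rows are indistinguishable from real samples where X = Y = Speed = turning_angle = 0.

```python
import numpy as np

def pad_or_truncate(feats, length=900):
    """Return (fixed_size_array, n_real_frames) for a (n_frames, n_signals) array.

    Short segments are zero-padded at the end; long ones keep their first `length` frames.
    """
    n_frames, n_signals = feats.shape
    if n_frames >= length:
        return feats[:length], length
    pad = np.zeros((length - n_frames, n_signals), dtype=feats.dtype)
    return np.vstack([feats, pad]), n_frames

short, n_short = pad_or_truncate(np.ones((850, 4)))
long_, n_long = pad_or_truncate(np.ones((1000, 4)))
print(short.shape, n_short)  # (900, 4) 850
print(long_.shape, n_long)   # (900, 4) 900
```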

In [13]:
segment_files = glob.glob(os.path.join(SEGMENTS_DIR, "*.csv"))
print(f"Found {len(segment_files)} raw segment files in {SEGMENTS_DIR}")

all_segments = []

cnt_no_wormid = 0
cnt_worm_not_found = 0
cnt_no_globalframe = 0
cnt_missing_features = 0
cnt_ok = 0

print("\nInspecting a few segment files:")
for p in segment_files[:5]:
    print("  -", p)

for seg_path in segment_files:
    df_seg = pd.read_csv(seg_path)

    # 1) Is the 'worm_id' column present?
    if "worm_id" not in df_seg.columns:
        cnt_no_wormid += 1
        continue

    worm_id = str(df_seg["worm_id"].iloc[0]).strip()

    # Normalize the ID a little, just in case
    worm_id_norm = normalize_filename(worm_id)

    if worm_id_norm in lifespan_map:
        worm_key = worm_id_norm
    elif worm_id in lifespan_map:
        worm_key = worm_id
    else:
        cnt_worm_not_found += 1
        continue

    total_lifespan = float(lifespan_map[worm_key])

    # 2) Is the GlobalFrame column present?
    if "GlobalFrame" not in df_seg.columns:
        cnt_no_globalframe += 1
        continue

    df_seg = df_seg.sort_values("GlobalFrame")

    # 3) Are all feature columns present?
    if not all(col in df_seg.columns for col in FEATURE_COLS):
        cnt_missing_features += 1
        continue

    feats = df_seg[FEATURE_COLS].values.astype(float)

    # Padding / truncation
    if feats.shape[0] < SEGMENT_LENGTH:
        pad_len = SEGMENT_LENGTH - feats.shape[0]
        pad = np.zeros((pad_len, feats.shape[1]), dtype=feats.dtype)
        feats_padded = np.vstack([feats, pad])
        valid_length = feats.shape[0]
    else:
        feats_padded = feats[:SEGMENT_LENGTH]
        valid_length = SEGMENT_LENGTH

    start_frame = int(df_seg["GlobalFrame"].iloc[0])
    last_frame = start_frame + valid_length - 1
    time_to_death = max(total_lifespan - last_frame, 0.0)

    label = 1 if time_to_death <= CLOSE_THRESHOLD_FRAMES else 0

    all_segments.append({
        "features": feats_padded,
        "y": label,
        "worm_id": worm_key,
        "start_frame": start_frame,
        "lifespan": total_lifespan
    })
    cnt_ok += 1

print("\n=== Segment loading summary ===")
print("Total files          :", len(segment_files))
print("Segments collected   :", cnt_ok)
print("Missing worm_id col  :", cnt_no_wormid)
print("Worm not in map      :", cnt_worm_not_found)
print("Missing GlobalFrame  :", cnt_no_globalframe)
print("Missing feature cols :", cnt_missing_features)


Found 8150 raw segment files in preprocessed_data/segments

Inspecting a few segment files:
  - preprocessed_data/segments/coordinates_highestspeed_20250205_9_3_with_time_speed-fragment4.0-preprocessed.csv
  - preprocessed_data/segments/coordinates_highestspeed_20240924_11_1_with_time_speed-fragment33.0-preprocessed.csv
  - preprocessed_data/segments/coordinates_highestspeed_20240924_12_6_with_time_speed-fragment19.0-preprocessed.csv
  - preprocessed_data/segments/coordinates_highestspeed_20250415_9_1_with_time_speed-fragment56.0-preprocessed.csv
  - preprocessed_data/segments/coordinates_highestspeed_20250205_12_5_with_time_speed-fragment5.0-preprocessed.csv

=== Segment loading summary ===
Total files          : 8150
Segments collected   : 8150
Missing worm_id col  : 0
Worm not in map      : 0
Missing GlobalFrame  : 0
Missing feature cols : 0


Build `X` and `y`
- check that at least one segment was collected
- stack the features into `X_raw` of shape `(N_segments, 900, 4)`: 900 frames and 4 signals (X, Y, Speed, turning_angle)
- binary target in `y`
- `worm_groups`, `start_frames` and `lifespans` are used for the split by worm and for the analysis by life stage (early/mid/late)

Featureless choice: `X_flat` = each segment flattened into a vector of dimension 900 × 4 = 3600, so the model sees the raw series directly, without aging features.
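Because `reshape` uses row-major (C) order, the frame index is the outer one: column `t * 4 + f` of `X_flat` holds signal `FEATURE_COLS[f]` at frame `t`. A small sketch of the inverse mapping (`column_to_frame_signal` is a hypothetical helper), which could help read per-column importances such as `clf.feature_importances_` back as (frame, signal) pairs:

```python
import numpy as np

FEATURE_COLS = ["X", "Y", "Speed", "turning_angle"]

def column_to_frame_signal(col, feature_cols=FEATURE_COLS):
    """Map a column of the flattened (T * F) vector back to (frame, signal name)."""
    frame, f = divmod(col, len(feature_cols))
    return frame, feature_cols[f]

# Toy array with the same (N, T, F) layout as X_raw
N, T, F = 2, 900, len(FEATURE_COLS)
X_raw = np.arange(N * T * F, dtype=float).reshape(N, T, F)
X_flat = X_raw.reshape(N, T * F)

print(column_to_frame_signal(42))  # (10, 'Speed')
```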

In [14]:
if len(all_segments) == 0:
    raise RuntimeError("No segments collected. Check paths / parsing.")

X_raw = np.stack([s["features"] for s in all_segments])   # (N, 900, 4)
y = np.array([s["y"] for s in all_segments], dtype=int)
worm_groups = np.array([s["worm_id"] for s in all_segments])
start_frames = np.array([s["start_frame"] for s in all_segments])
lifespans = np.array([s["lifespan"] for s in all_segments])

print("X_raw shape:", X_raw.shape)
print("y distribution:", np.bincount(y))
print("unique worms:", len(np.unique(worm_groups)))

# Featureless: flatten each 900 × 4 time series into a single vector
N, T, F = X_raw.shape
X_flat = X_raw.reshape(N, T * F)   # (N, 3600)

print("X_flat shape:", X_flat.shape)


X_raw shape: (8150, 900, 4)
y distribution: [6257 1893]
unique worms: 104
X_flat shape: (8150, 3600)


Remarks:
- `X_raw` shape: (number of segments, frames, signals)
- `X_flat` shape: each segment flattened: (number of segments, frames × signals)
- y distribution: 1893 near-death (1) and 6257 not near-death (0), i.e. about 23% positives
- 104 distinct worms
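With about 23% positives, one common option (not used in this notebook) is to reweight samples so that each class carries the same total weight. Below is a plain-NumPy sketch of the usual "balanced" formula, w_c = n / (n_classes × n_c); `balanced_sample_weights` is a hypothetical helper. The result could be passed as `sample_weight` to `GradientBoostingClassifier.fit`, and scikit-learn's `compute_sample_weight("balanced", y)` (in `sklearn.utils.class_weight`) should give the same weights.

```python
import numpy as np

def balanced_sample_weights(y):
    """Per-sample weights n / (n_classes * n_c), so each class sums to n / n_classes."""
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    class_weight = len(y) / (len(classes) * counts)
    # classes is sorted, so searchsorted finds each sample's class index exactly
    return class_weight[np.searchsorted(classes, y)]

# Same class counts as the y distribution printed above: 6257 zeros, 1893 ones
y_demo = np.array([0] * 6257 + [1] * 1893)
w = balanced_sample_weights(y_demo)
print(round(w[y_demo == 0][0], 3), round(w[y_demo == 1][0], 3))  # 0.651 2.153
```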

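Train/val/test split by worm
- use `GroupShuffleSplit` to split by worm:
    - train = 70% of the worms
    - val + test = the remaining 30%, then split 50/50
- this keeps segments from the same worm from ending up in both train and test
- build `X_train`, `X_val`, `X_test`, the matching `y_*` arrays, and the per-segment worm IDs, start frames and lifespans (for the life-stage analysis)

The percentages apply to worms, not segments, so the number of segments per split varies (5743 / 1301 / 1106 here).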
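A cheap guard against leakage is to check that no worm ID appears in two splits. A minimal sketch (`assert_disjoint_groups` is a hypothetical helper, shown on toy IDs); in the notebook it would be called as `assert_disjoint_groups(train=worm_train, val=worm_val, test=worm_test)` after the split cell.

```python
import numpy as np

def assert_disjoint_groups(**splits):
    """Raise AssertionError if any group (worm ID) appears in more than one split."""
    names = list(splits)
    sets = {name: set(np.unique(groups).tolist()) for name, groups in splits.items()}
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            shared = sets[a] & sets[b]
            if shared:
                raise AssertionError(
                    f"{len(shared)} worm(s) shared by {a} and {b}: {sorted(shared)[:5]}"
                )

# Toy segment-level worm IDs for each split
train = np.array(["w1", "w1", "w2", "w3"])
val = np.array(["w4", "w4"])
test = np.array(["w5"])
assert_disjoint_groups(train=train, val=val, test=test)  # passes silently
```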

In [15]:
gss = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
train_idx, temp_idx = next(gss.split(X_flat, y, groups=worm_groups))

temp_worms = worm_groups[temp_idx]
gss_val_test = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=43)
val_idx_rel, test_idx_rel = next(gss_val_test.split(X_flat[temp_idx], y[temp_idx], groups=temp_worms))

val_idx = temp_idx[val_idx_rel]
test_idx = temp_idx[test_idx_rel]

X_train, X_val, X_test = X_flat[train_idx], X_flat[val_idx], X_flat[test_idx]
y_train, y_val, y_test = y[train_idx], y[val_idx], y[test_idx]

worm_train, worm_val, worm_test = worm_groups[train_idx], worm_groups[val_idx], worm_groups[test_idx]
start_train, start_val, start_test = start_frames[train_idx], start_frames[val_idx], start_frames[test_idx]
lifespan_train, lifespan_val, lifespan_test = lifespans[train_idx], lifespans[val_idx], lifespans[test_idx]

print("Train:", X_train.shape, " y:", y_train.shape)
print("Val  :", X_val.shape, " y:", y_val.shape)
print("Test :", X_test.shape, " y:", y_test.shape)
print("unique worms train/val/test:", len(np.unique(worm_train)), len(np.unique(worm_val)), len(np.unique(worm_test)))


Train: (5743, 3600)  y: (5743,)
Val  : (1301, 3600)  y: (1301,)
Test : (1106, 3600)  y: (1106,)
unique worms train/val/test: 72 16 16


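Standardization + GradientBoosting model
(VERY SLOW TO RUN!)

- standardize the features (mean 0, variance 1) with `StandardScaler`
- create a `GradientBoostingClassifier`
- train it on the training set
- compute validation performance: accuracy, F1, AUC

Note: the trees in gradient boosting only compare each feature with split thresholds, so per-feature standardization should not change the fitted model; the scaler is harmless here but not needed. If runtime is a problem, scikit-learn's `HistGradientBoostingClassifier`, which bins features before searching for splits, is usually much faster than `GradientBoostingClassifier`, although its results will not be identical.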
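A quick numerical illustration of the note on standardization, on synthetic data rather than the worm features: a tree split `x <= t` selects exactly the same samples as `z <= (t - mean) / std` on the standardized feature, and the sort order of the values is unchanged, so the candidate splits available to a tree are the same.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=1000)   # one synthetic raw feature
mu, sigma = x.mean(), x.std()
z = (x - mu) / sigma                            # what StandardScaler does per column

t = 6.2                                         # an arbitrary split threshold
left_raw = x <= t
left_std = z <= (t - mu) / sigma                # same threshold in standardized units

print(np.array_equal(left_raw, left_std))  # True
print(np.array_equal(np.argsort(x, kind="stable"), np.argsort(z, kind="stable")))  # True
```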

In [16]:
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)

clf = GradientBoostingClassifier(
    n_estimators=300,
    learning_rate=0.05,
    max_depth=3,
    random_state=42
)

clf.fit(X_train_scaled, y_train)

proba_val = clf.predict_proba(X_val_scaled)[:, 1]
pred_val = (proba_val >= 0.5).astype(int)

acc_val = accuracy_score(y_val, pred_val)
f1_val = f1_score(y_val, pred_val)
auc_val = roc_auc_score(y_val, proba_val)

print("=== Validation performance ===")
print(f"Accuracy: {acc_val:.3f}")
print(f"F1-score: {f1_val:.3f}")
print(f"AUC:      {auc_val:.3f}")


=== Validation performance ===
Accuracy: 0.779
F1-score: 0.443
AUC:      0.845


Test performance
- use the trained model to predict the test set
- compute the final test metrics (accuracy, F1, AUC)
- also print the distribution of the true labels and of the predictions

In [17]:
proba_test = clf.predict_proba(X_test_scaled)[:, 1]
pred_test = (proba_test >= 0.5).astype(int)

acc = accuracy_score(y_test, pred_test)
f1 = f1_score(y_test, pred_test)
auc = roc_auc_score(y_test, proba_test)

print("=== Test performance (overall) ===")
print(f"Accuracy: {acc:.3f}")
print(f"F1-score: {f1:.3f}")
print(f"AUC:      {auc:.3f}")
print("y_test distribution:", np.bincount(y_test))
print("pred_test distribution:", np.bincount(pred_test))


=== Test performance (overall) ===
Accuracy: 0.821
F1-score: 0.615
AUC:      0.855
y_test distribution: [794 312]
pred_test distribution: [904 202]


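Remarks:
- the model misclassifies about 18% of the test segments (accuracy 0.821); for comparison, always predicting 0 would already give 794 / 1106 ≈ 0.718 on this test set
- F1 = 0.615 is reasonable for the minority class; with the 0.5 cut-off the model predicts 202 positives against 312 actual ones, so it under-predicts near-death segments
- AUC = 0.855 is a good score: the model separates near-death and not-near-death segments fairly well from the raw series, without aging features
- with only 16 worms in each of the validation and test sets, these estimates are noisy (validation F1 = 0.443 vs test F1 = 0.615)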
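The 0.5 cut-off is not necessarily the best one for the minority class. A common remedy, sketched here as an assumption rather than something this notebook does, is to choose the cut-off that maximizes F1 on the validation probabilities and then apply it unchanged to the test set; tuning it on the test set itself would bias the reported F1 upward. `best_f1_threshold` and `f1_at_threshold` are hypothetical plain-NumPy helpers (scikit-learn's `precision_recall_curve` offers a similar route).

```python
import numpy as np

def f1_at_threshold(y_true, proba, thr):
    """F1 of the positive class when predicting 1 for proba >= thr."""
    y_true = np.asarray(y_true)
    pred = np.asarray(proba) >= thr
    tp = np.sum(pred & (y_true == 1))
    fp = np.sum(pred & (y_true == 0))
    fn = np.sum(~pred & (y_true == 1))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else 0.0

def best_f1_threshold(y_true, proba, grid=None):
    """Return (threshold, F1) maximizing F1 over a grid of candidate cut-offs."""
    if grid is None:
        grid = np.linspace(0.05, 0.95, 19)   # 0.05, 0.10, ..., 0.95
    scores = [f1_at_threshold(y_true, proba, t) for t in grid]
    best = int(np.argmax(scores))           # first best cut-off if several tie
    return float(grid[best]), float(scores[best])

# Toy validation set: at 0.5 one positive is missed, a lower cut-off recovers it
y_val_demo = np.array([0, 0, 0, 1, 1])
proba_val_demo = np.array([0.10, 0.20, 0.33, 0.42, 0.80])
print(round(f1_at_threshold(y_val_demo, proba_val_demo, 0.5), 3))  # 0.667
print(best_f1_threshold(y_val_demo, proba_val_demo))

# In the notebook (not run here):
# thr_val, _ = best_f1_threshold(y_val, proba_val)
# pred_test_tuned = (proba_test >= thr_val).astype(int)
```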

Analysis by life stage
- for each test segment, compute where it falls in the worm's life:
    - life_fraction = start_frame / lifespan
    - early (< 25% of life)
    - mid (25% to 75%)
    - late (≥ 75%)
- for each stage, recompute accuracy / F1 / AUC
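The same early/mid/late binning can be written without a Python loop using `np.digitize`, sketched below (`life_stages` is a hypothetical helper). With the default `right=False`, a fraction of exactly 0.25 falls in "mid" and exactly 0.75 in "late", matching the `<` comparisons in `life_stage`. The example reuses the first lifespan in `lifespan_map` (49,500 frames) with made-up start frames.

```python
import numpy as np

def life_stages(start_frames, lifespans, bounds=(0.25, 0.75)):
    """Label each segment 'early', 'mid' or 'late' from start_frame / lifespan."""
    frac = np.asarray(start_frames, dtype=float) / np.maximum(
        np.asarray(lifespans, dtype=float), 1e-6
    )
    # digitize: 0 if frac < 0.25, 1 if 0.25 <= frac < 0.75, 2 if frac >= 0.75
    idx = np.digitize(frac, bounds)
    return np.array(["early", "mid", "late"])[idx]

stages_demo = life_stages([0, 12375, 37125, 45000], [49500.0] * 4)
print(stages_demo.tolist())  # ['early', 'mid', 'late', 'late']
```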

In [18]:
life_fraction_test = start_test / np.maximum(lifespan_test, 1e-6)

def life_stage(frac):
    if frac < 0.25:
        return "early"
    elif frac < 0.75:
        return "mid"
    else:
        return "late"

stages = np.array([life_stage(f) for f in life_fraction_test])

for stage_name in ["early", "mid", "late"]:
    idx = np.where(stages == stage_name)[0]
    if len(idx) == 0:
        continue

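    y_s = y_test[idx]
    proba_s = proba_test[idx]
    pred_s = (proba_s >= 0.5).astype(int)

    acc_s = accuracy_score(y_s, pred_s)
    # zero_division=0: report F1 = 0 instead of warning when precision or recall is 0/0
    f1_s = f1_score(y_s, pred_s, zero_division=0)
    # ROC AUC is undefined if this stage contains only one class (e.g. no near-death segments)
    auc_s = roc_auc_score(y_s, proba_s) if len(np.unique(y_s)) == 2 else float("nan")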

    print(f"\n=== {stage_name.upper()} life stage ===")
    print(f"n segments: {len(idx)}")
    print(f"Accuracy: {acc_s:.3f}")
    print(f"F1-score: {f1_s:.3f}")
    print(f"AUC:      {auc_s:.3f}")



=== EARLY life stage ===
n segments: 286
Accuracy: 0.993
F1-score: 0.000
AUC:      nan

=== MID life stage ===
n segments: 561
Accuracy: 0.831
F1-score: 0.144
AUC:      0.622

=== LATE life stage ===
n segments: 259
Accuracy: 0.610
F1-score: 0.748
AUC:      0.650




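Remarks:

Early:
- AUC is NaN because no early-stage test segment is near death: with only one class present, ROC AUC is undefined.
- F1 = 0 because there are no true positives; the ~0.7% of errors are a few false positives.

This is expected: early in life the worms are far from death, so there is nothing to detect in this stage, rather than a failure of the model.

Mid:
- the model starts to pick up some signal, but the separation is weak (AUC 0.622, F1 0.144).

Late:
- it detects near-death segments well (F1 = 0.748, so both recall and precision are reasonably high) but makes more errors on not-near-death segments, hence the lower accuracy (0.610).

Overall: within each stage, AUC (0.62 to 0.65) is well below the overall test AUC of 0.855, which suggests that much of the overall score comes from telling life stages apart rather than from pinpointing proximity to death within a stage.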
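Why AUC is NaN for the early stage: ROC AUC equals the probability that a randomly chosen positive segment gets a higher score than a randomly chosen negative one (ties counted as one half), which is undefined when a group has no positives. A minimal NumPy sketch of that definition (`safe_auc` is a hypothetical helper; it builds an n_pos × n_neg comparison matrix, which is fine for a few thousand segments but not for very large sets):

```python
import numpy as np

def safe_auc(y_true, scores):
    """ROC AUC as P(score_pos > score_neg) + 0.5 * P(tie); NaN if a class is missing."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true == 1], scores[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")                      # undefined with a single class
    diff = pos[:, None] - neg[None, :]           # all positive/negative pairs
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())

print(safe_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
print(safe_auc([0, 0, 0], [0.1, 0.2, 0.3]))           # nan
```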


Test with different proximity thresholds:
- recompute the near-death label for each new threshold
- rebuild the train/test split by worm (a single 70/30 split, without a separate validation set)
- retrain a featureless model (GradientBoosting)
- compute AUC / F1 / accuracy
- build a simple table with columns `threshold`, `Accuracy`, `F1`, `AUC`

(also very slow to run)
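Relabeling for a new threshold only needs `start_frames` and `lifespans`. Below is a vectorized sketch of the rule used while loading the segments (time to death measured from the last frame, `start + 900 - 1` for a full-length segment); `near_death_labels` is a hypothetical helper, and the toy example reuses the 49,500-frame lifespan from `lifespan_map` with made-up start frames. The boundary case shows that a segment ending exactly `thr × 900` frames before death counts as near death.

```python
import numpy as np

SEGMENT_LENGTH = 900

def near_death_labels(start_frames, lifespans, n_segments, segment_length=SEGMENT_LENGTH):
    """1 if at most n_segments * segment_length frames remain after the segment's last frame."""
    last_frames = np.asarray(start_frames) + segment_length - 1
    time_to_death = np.maximum(np.asarray(lifespans, dtype=float) - last_frames, 0.0)
    return (time_to_death <= n_segments * segment_length).astype(int)

starts = np.array([30600, 30601, 31000, 49000])
lifespans_demo = np.full(len(starts), 49500.0)
for thr in (5, 10, 20):
    print(thr, near_death_labels(starts, lifespans_demo, thr).tolist())
# 5 [0, 0, 0, 1]
# 10 [0, 0, 0, 1]
# 20 [0, 1, 1, 1]
```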

In [19]:
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GroupShuffleSplit

thresholds = [5, 10, 15, 20, 25, 30]

results_thr = []

for thr in thresholds:
    print("\n==============================")
    print(f" Threshold = {thr} segments")
    print("==============================")

    # Recompute labels with the loading rule: time to death is measured from the
    # segment's last frame (start + SEGMENT_LENGTH - 1, i.e. assuming full-length segments)
    time_to_death = np.maximum(lifespans - (start_frames + SEGMENT_LENGTH - 1), 0.0)
    y_thr = (time_to_death <= thr * SEGMENT_LENGTH).astype(int)

    print("Label distribution:", np.bincount(y_thr))

    # Split by worm (the split depends only on the worm groups, so it is the same for every threshold)
    gss_thr = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
    tr_idx, te_idx = next(gss_thr.split(X_flat, y_thr, groups=worm_groups))

    # "_thr"-style names avoid overwriting X_train, y_test, clf, ... from the cells above
    X_tr, X_te = X_flat[tr_idx], X_flat[te_idx]
    y_tr, y_te = y_thr[tr_idx], y_thr[te_idx]

    scaler_thr = StandardScaler()
    X_tr_sc = scaler_thr.fit_transform(X_tr)
    X_te_sc = scaler_thr.transform(X_te)

    clf_thr = GradientBoostingClassifier(
        n_estimators=300,
        learning_rate=0.05,
        max_depth=3,
        random_state=42
    )

    clf_thr.fit(X_tr_sc, y_tr)

    proba_thr = clf_thr.predict_proba(X_te_sc)[:, 1]
    preds_thr = (proba_thr >= 0.5).astype(int)

    results_thr.append({
        "threshold": thr,
        "Accuracy": accuracy_score(y_te, preds_thr),
        "F1": f1_score(y_te, preds_thr),
        "AUC": roc_auc_score(y_te, proba_thr),
    })

results_thresholds_featureless = pd.DataFrame(results_thr)
display(results_thresholds_featureless)



 Threshold = 5 segments
Label distribution: [7661  489]

 Threshold = 10 segments
Label distribution: [7201  949]


KeyboardInterrupt: 