In [None]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [None]:
!ls -lh /content/drive/MyDrive/oasis_project/data/demographics


total 52K
-rw-r--r-- 1 root root 50K Oct  9 04:46 oasis2_demographics.xlsx


In [None]:
from google.colab import drive
import os
import shutil

# Unmount drive if already mounted and clear the mountpoint
if os.path.exists('/content/drive'):
    try:
        drive.flush_and_unmount()
        # It might take a moment for the unmount to complete,
        # but often clearing the directory is sufficient.
    except ValueError:
        # Drive was not mounted, no need to unmount
        pass

    # Clear the mountpoint directory
    if os.path.exists('/content/drive') and os.path.isdir('/content/drive'):
        for item in os.listdir('/content/drive'):
            item_path = os.path.join('/content/drive', item)
            try:
                if os.path.isfile(item_path) or os.path.islink(item_path):
                    os.unlink(item_path)
                elif os.path.isdir(item_path):
                    shutil.rmtree(item_path)
            except Exception as e:
                print(f"Error removing {item_path}: {e}")


drive.mount('/content/drive', force_remount=True)

print("✅ Drive mounted.")
print("Drive contents at root:")
for f in os.listdir("/content/drive"):
    print(" -", f)

# verify main project folder exists
proj = "/content/drive/MyDrive/oasis_project"
print("\nProject folder exists?", os.path.exists(proj))
if os.path.exists(proj):
    print("Contents of oasis_project:")
    for f in os.listdir(proj):
        print("   ", f)

Drive not mounted, so nothing to flush and unmount.
Mounted at /content/drive
✅ Drive mounted.
Drive contents at root:
 - MyDrive
 - .shortcut-targets-by-id
 - .Trash-0
 - .Encrypted

Project folder exists? True
Contents of oasis_project:
    notebooks
    data
    outputs
    logs
    oasis2_graph_dataset.pt


In [None]:
from google.colab import files

print("📤 Please upload your demographics Excel file (e.g., oasis_longitudinal_demographics-8d83e569fa2e2d30.xlsx)")
uploaded = files.upload()


📤 Please upload your demographics Excel file (e.g., oasis_longitudinal_demographics-8d83e569fa2e2d30.xlsx)


Saving oasis_longitudinal_demographics-8d83e569fa2e2d30.xlsx to oasis_longitudinal_demographics-8d83e569fa2e2d30.xlsx


In [None]:
import os, shutil

BASE = "/content/drive/MyDrive/oasis_project"
DEM_DIR = os.path.join(BASE, "data", "demographics")
os.makedirs(DEM_DIR, exist_ok=True)

# Detect uploaded file automatically
uploaded_name = list(uploaded.keys())[0]
src = f"/content/{uploaded_name}"
dst = os.path.join(DEM_DIR, "oasis2_demographics.xlsx")

shutil.move(src, dst)
print(f"✅ Moved demographics file to: {dst}")


✅ Moved demographics file to: /content/drive/MyDrive/oasis_project/data/demographics/oasis2_demographics.xlsx


In [None]:
# FINALIZE LABELS: try safe fallback mapping (0003 -> OAS2_0003) and export final .pt
import os, glob, re, numpy as np, pandas as pd, torch, shutil
from torch_geometric.data import Data
from tqdm import tqdm

BASE = "/content/drive/MyDrive/oasis_project"
NPZ_DIR = os.path.join(BASE, "data", "graphs", "npz_graphs_resave")
LABELED_DIR = os.path.join(BASE, "data", "graphs", "labeled_npz_auto")
DEM_PATH = os.path.join(BASE, "data", "demographics", "oasis2_demographics.xlsx")
AUTO_INDEX_CSV = os.path.join(BASE, "data", "graphs", "graph_label_index_auto.csv")
AUTO_MISSES_CSV = os.path.join(BASE, "data", "graphs", "label_misses_auto.csv")
FINAL_INDEX_CSV = os.path.join(BASE, "data", "graphs", "final_graph_label_index_auto.csv")
FINAL_MISSES_CSV = os.path.join(BASE, "data", "graphs", "final_label_misses_auto.csv")
OUT_PT = os.path.join(BASE, "data", "graphs", "oasis2_graphs_labeled_auto_final.pt")

os.makedirs(LABELED_DIR, exist_ok=True)

# --- load demographics keys ---
if DEM_PATH.lower().endswith('.csv'):
    df_dem = pd.read_csv(DEM_PATH)
else:
    df_dem = pd.read_excel(DEM_PATH)
df_dem.rename(columns={c:c.strip() for c in df_dem.columns}, inplace=True)
cols = list(df_dem.columns)
id_col = next((c for c in cols if any(tok in c.upper() for tok in ("MRI","MRI_ID","MRIID","SUBJ","SUBJECT","ID"))), cols[0])
label_col = next((c for c in cols if any(tok in c.upper() for tok in ("CDR","CDR_GLOBAL","DEMENTIA","DIAG","SEVERITY"))), None)
if label_col is None:
    numeric_cols = [c for c in cols if np.issubdtype(df_dem[c].dtype, np.number)]
    label_col = numeric_cols[0] if numeric_cols else cols[-1]
dem_map = { str(x).strip(): df_dem.loc[i, label_col] for i,x in df_dem[id_col].items() if pd.notna(x) }
dem_keys = set(dem_map.keys())

print("Loaded dem keys sample:", list(dem_keys)[:8], " total:", len(dem_keys))

# --- load auto index + misses produced earlier ---
auto_idx = pd.read_csv(AUTO_INDEX_CSV) if os.path.exists(AUTO_INDEX_CSV) else pd.DataFrame()
misses = pd.read_csv(AUTO_MISSES_CSV) if os.path.exists(AUTO_MISSES_CSV) else pd.DataFrame()

print("Auto-index rows:", len(auto_idx), "Auto-misses rows:", len(misses))

# --- helper: extract sid from npz filename used earlier ---
def sid_from_fname(fname):
    bn = os.path.basename(fname)
    m = re.search(r'(subj[_-]\d+|sub[_-]\d+|subj\d+|sub\d+|\d{2,})', bn, flags=re.IGNORECASE)
    return (m.group(0) if m else os.path.splitext(bn)[0])

def normalize(x):
    if pd.isna(x): return None
    return str(x).strip()

# --- Attempt aggressive fallback mapping for misses: numeric -> OAS2_ + zero-pad(4) (only if key exists) ---
added_rows = []
remaining_misses = []
for _, row in misses.iterrows():
    p = row.get('npz_file') if 'npz_file' in row else row.get('file', None)
    sid = row.get('npz_sid') if 'npz_sid' in row else sid_from_fname(p)
    sid = normalize(sid)
    mapped = None
    if sid and sid.isdigit():
        cand = "OAS2_" + sid.zfill(4)
        if cand in dem_keys:
            mapped = cand
    # also try numeric-only without prefix
    if mapped is None and sid and sid.isdigit() and sid in dem_keys:
        mapped = sid
    # also try "OAS2-" variant
    if mapped is None and sid and sid.isdigit():
        cand2 = "OAS2-" + sid.zfill(4)
        if cand2 in dem_keys:
            mapped = cand2

    if mapped is not None:
        # label and write labeled npz
        arr = np.load(p, allow_pickle=True)
        x = arr.get('x', np.array([]))
        pos = arr.get('pos', np.array([]))
        ei = arr.get('edge_index', np.empty((2,0), dtype=np.int64))
        y = np.array([dem_map[mapped]])
        outname = os.path.join(LABELED_DIR, os.path.basename(p).replace(".npz","_labeled.npz"))
        np.savez_compressed(outname, x=x, pos=pos, edge_index=ei, y=y)
        added_rows.append({"npz_file": p, "labeled_npz": outname, "npz_sid": sid, "dem_key": mapped, "y": float(dem_map[mapped])})
    else:
        remaining_misses.append({"npz_file": p, "npz_sid": sid, "suggestions": row.get('suggestions') if 'suggestions' in row else []})

print("Aggressive fallback labeled:", len(added_rows), "Remaining misses after fallback:", len(remaining_misses))

# --- combine index and save final CSVs ---
final_index = pd.concat([
    auto_idx if not auto_idx.empty else pd.DataFrame(columns=["npz_file","labeled_npz","npz_sid","dem_key","y"]),
    pd.DataFrame(added_rows)
], ignore_index=True)

final_misses_df = pd.DataFrame(remaining_misses)

final_index.to_csv(FINAL_INDEX_CSV, index=False)
final_misses_df.to_csv(FINAL_MISSES_CSV, index=False)
print("Saved final index:", FINAL_INDEX_CSV)
print("Saved final misses:", FINAL_MISSES_CSV)
print("Total labeled (auto + fallback):", len(final_index))
print("Remaining misses:", len(final_misses_df))

# --- Optional: build final .pt from *all* labeled NPZs (both originally auto-labeled and newly labeled) ---
# We'll gather labeled files from final_index['npz_file'] to ensure consistent set
labeled_files = final_index['npz_file'].tolist()
data_list = []
for p in tqdm(labeled_files, desc="Creating Data objects"):
    try:
        arr = np.load(p, allow_pickle=True)
        x = arr.get('x', np.array([]))
        pos = arr.get('pos', np.array([]))
        ei = arr.get('edge_index', np.empty((2,0), dtype=np.int64))
        y = final_index.loc[final_index['npz_file']==p, 'y'].values
        # convert to tensors when possible
        try:
            xt = torch.tensor(x, dtype=torch.float) if getattr(x, "size", 0) else None
            pt = torch.tensor(pos, dtype=torch.float) if getattr(pos, "size", 0) else None
            eit = torch.tensor(ei, dtype=torch.long) if getattr(ei, "size", 0) else None
            yt = torch.tensor(y, dtype=torch.float) if len(y)>0 else None
            g = Data(x=xt, edge_index=eit, pos=pt, y=yt)
            # set subject id if available
            sid = sid_from_fname(p)
            g.subject_id = sid
            data_list.append(g)
        except Exception as e:
            # skip conversion errors, but continue
            pass
    except Exception as e:
        print("Failed loading npz", p, e)

if data_list:
    torch.save(data_list, OUT_PT)
    print("Saved final labeled .pt with", len(data_list), "graphs to:", OUT_PT)
else:
    print("No Data objects created (nothing saved as .pt).")

# --- If misses remain, give instructions and small template to edit manually ---
if len(final_misses_df) > 0:
    print("\nThere are remaining misses. Edit the CSV at:\n  ", FINAL_MISSES_CSV)
    print("Add a column named 'mapped_dem_key' and put the dem key (e.g. OAS2_0123) for each npz_file you want to map.")
    print("When done, run the small apply-manual-mapping block below (I included it here).")

    # show small apply-manual-mapping snippet for convenience
    print("\n--- To apply manual mappings (run after you edit), execute this block: ---\n")
    print(r"""
# APPLY manual mappings from final_label_misses_auto.csv (after editing)
import pandas as pd, numpy as np, os
BASE = "/content/drive/MyDrive/oasis_project"
LABELED_DIR = os.path.join(BASE, "data", "graphs", "labeled_npz_auto")
FINAL_MISSES_CSV = os.path.join(BASE, "data", "graphs", "final_label_misses_auto.csv")
DEM_PATH = os.path.join(BASE, "data", "demographics", "oasis2_demographics.xlsx")
dfm = pd.read_csv(FINAL_MISSES_CSV)
# expected columns: npz_file, npz_sid, suggestions, mapped_dem_key
for _, r in dfm.iterrows():
    p = r['npz_file']
    mkey = r.get('mapped_dem_key')
    if pd.isna(mkey) or not mkey: continue
    # load dem and write labeled npz
    if DEM_PATH.lower().endswith('.csv'):
        dem = pd.read_csv(DEM_PATH)
    else:
        dem = pd.read_excel(DEM_PATH)
    dem.rename(columns={c:c.strip() for c in dem.columns}, inplace=True)
    idcol = next((c for c in dem.columns if any(tok in c.upper() for tok in ("MRI","MRI_ID","MRIID","SUBJ","SUBJECT","ID"))), dem.columns[0])
    labcol = next((c for c in dem.columns if any(tok in c.upper() for tok in ("CDR","CDR_GLOBAL","DEMENTIA","DIAG","SEVERITY"))), dem.columns[-1])
    val = dem.loc[dem[idcol].astype(str).str.strip()==str(mkey).strip(), labcol]
    if val.size==0:
        print("mapped key not found in demographics:", mkey); continue
    y = np.array([float(val.iloc[0])])
    arr = np.load(p, allow_pickle=True)
    x = arr.get('x', np.array([])); pos = arr.get('pos', np.array([])); ei = arr.get('edge_index', np.empty((2,0),dtype=np.int64))
    outname = os.path.join(LABELED_DIR, os.path.basename(p).replace(".npz","_labeled.npz"))
    np.savez_compressed(outname, x=x, pos=pos, edge_index=ei, y=y)
    print("Saved manual-labeled:", outname)
""")

print("\nFinished finalization step.")


Loaded dem keys sample: ['OAS2_0081', 'OAS2_0035', 'OAS2_0114', 'OAS2_0112', 'OAS2_0175', 'OAS2_0181', 'OAS2_0046', 'OAS2_0022']  total: 150
Auto-index rows: 150 Auto-misses rows: 59
Aggressive fallback labeled: 0 Remaining misses after fallback: 59
Saved final index: /content/drive/MyDrive/oasis_project/data/graphs/final_graph_label_index_auto.csv
Saved final misses: /content/drive/MyDrive/oasis_project/data/graphs/final_label_misses_auto.csv
Total labeled (auto + fallback): 150
Remaining misses: 59


Creating Data objects: 100%|██████████| 150/150 [00:01<00:00, 82.05it/s]


Saved final labeled .pt with 150 graphs to: /content/drive/MyDrive/oasis_project/data/graphs/oasis2_graphs_labeled_auto_final.pt

There are remaining misses. Edit the CSV at:
   /content/drive/MyDrive/oasis_project/data/graphs/final_label_misses_auto.csv
Add a column named 'mapped_dem_key' and put the dem key (e.g. OAS2_0123) for each npz_file you want to map.
When done, run the small apply-manual-mapping block below (I included it here).

--- To apply manual mappings (run after you edit), execute this block: ---


# APPLY manual mappings from final_label_misses_auto.csv (after editing)
import pandas as pd, numpy as np, os
BASE = "/content/drive/MyDrive/oasis_project"
LABELED_DIR = os.path.join(BASE, "data", "graphs", "labeled_npz_auto")
FINAL_MISSES_CSV = os.path.join(BASE, "data", "graphs", "final_label_misses_auto.csv")
DEM_PATH = os.path.join(BASE, "data", "demographics", "oasis2_demographics.xlsx")
dfm = pd.read_csv(FINAL_MISSES_CSV)
# expected columns: npz_file, npz_sid, suggesti

In [None]:
# FUZZY RELABEL: produce suggestions for remaining misses and optionally apply high-confidence matches
import os, glob, re, numpy as np, pandas as pd, difflib
from tqdm import tqdm

BASE = "/content/drive/MyDrive/oasis_project"
DEM_PATH = os.path.join(BASE, "data", "demographics", "oasis2_demographics.xlsx")
FINAL_MISSES = os.path.join(BASE, "data", "graphs", "final_label_misses_auto.csv")
FINAL_INDEX = os.path.join(BASE, "data", "graphs", "final_graph_label_index_auto.csv")
NPZ_DIR = os.path.join(BASE, "data", "graphs", "npz_graphs_resave")
LABELED_DIR = os.path.join(BASE, "data", "graphs", "labeled_npz_auto")
OUT_SUGGESTIONS = os.path.join(BASE, "data", "graphs", "label_misses_fuzzy_suggestions.csv")
OUT_UPDATED_INDEX = os.path.join(BASE, "data", "graphs", "final_graph_label_index_auto_postfuzzy.csv")
OUT_UPDATED_MISSES = os.path.join(BASE, "data", "graphs", "final_label_misses_auto_postfuzzy.csv")
os.makedirs(LABELED_DIR, exist_ok=True)

# Config
TOP_K = 5                 # how many fuzzy candidates to keep per miss
AUTO_APPLY = False        # set True to automatically label high-confidence matches
AUTO_THRESHOLD = 0.85     # similarity threshold (0..1) for auto-apply
PRINT_SAMPLE = True

# --- load demographics keys & label values ---
if DEM_PATH.lower().endswith('.csv'):
    dem_df = pd.read_csv(DEM_PATH)
else:
    dem_df = pd.read_excel(DEM_PATH)
dem_df.rename(columns={c:c.strip() for c in dem_df.columns}, inplace=True)
cols = list(dem_df.columns)
id_col = next((c for c in cols if any(tok in c.upper() for tok in ("MRI","MRI_ID","MRIID","SUBJ","SUBJECT","ID"))), cols[0])
label_col = next((c for c in cols if any(tok in c.upper() for tok in ("CDR","CDR_GLOBAL","DEMENTIA","DIAG","SEVERITY"))), None)
if label_col is None:
    numeric_cols = [c for c in cols if np.issubdtype(dem_df[c].dtype, np.number)]
    label_col = numeric_cols[0] if numeric_cols else cols[-1]
# build canonical key list (strings)
dem_keys = [str(x).strip() for x in dem_df[id_col].astype(str).tolist()]
dem_map = {str(k).strip(): dem_df.loc[i, label_col] for i,k in dem_df[id_col].astype(str).items()}
print("Dem keys sample:", dem_keys[:8], "count:", len(dem_keys))

# --- load misses dataframe produced earlier ---
if not os.path.exists(FINAL_MISSES):
    raise SystemExit("Missing file: " + FINAL_MISSES + " — run earlier finalize step first.")
miss_df = pd.read_csv(FINAL_MISSES)
# normalize npz_file column name
if 'npz_file' not in miss_df.columns and 'file' in miss_df.columns:
    miss_df = miss_df.rename(columns={'file':'npz_file'})
if 'npz_sid' not in miss_df.columns:
    # try extracting sid from filename
    def sid_from_fname(fname):
        bn = os.path.basename(fname)
        m = re.search(r'(subj[_-]\d+|sub[_-]\d+|subj\d+|sub\d+|\d{2,})', bn, flags=re.IGNORECASE)
        return (m.group(0) if m else os.path.splitext(bn)[0])
    miss_df['npz_sid'] = miss_df['npz_file'].apply(lambda p: sid_from_fname(p) if isinstance(p, str) else "")

print("Miss rows:", len(miss_df))

# candidate generator (common variants)
def candidate_variants(sid):
    sid = str(sid).strip()
    variants = set()
    if not sid: return []
    variants.add(sid)
    # numeric zero-pad forms
    digits = re.sub(r'\D+', '', sid)
    if digits:
        variants.add(digits)
        variants.add(digits.zfill(4))
        variants.add("OAS2_" + digits.zfill(4))
        variants.add("OAS2-" + digits.zfill(4))
        variants.add("OAS2" + digits.zfill(4))
    # uppercase/lowercase, prefixes, leading zeros
    variants.add(sid.upper())
    variants.add(sid.replace("subj","OAS2_").upper())
    variants.add("OAS2_" + sid)
    # strip leading zeros
    variants.add(sid.lstrip("0"))
    return [v for v in variants if v]

# fuzzy similarity wrapper (difflib SequenceMatcher -> [0,1])
def sim(a,b):
    try:
        return difflib.SequenceMatcher(None, str(a).lower(), str(b).lower()).ratio()
    except:
        return 0.0

rows_out = []
auto_applied = []
for _, r in tqdm(miss_df.iterrows(), total=len(miss_df), desc="Fuzzy matching"):
    npz_file = r.get('npz_file')
    sid = str(r.get('npz_sid') or "")
    cand_list = []
    # first try exact variant matches
    for v in candidate_variants(sid):
        if v in dem_map:
            cand_list.append((v, 1.0))
    # if none exact, compute fuzzy top-K vs dem_keys (but compute against reasonable subset)
    if not cand_list:
        # reduce search set: dem_keys that contain any digits from sid or share tokens with sid
        subset = dem_keys
        digits = re.sub(r'\D+','', sid)
        if digits:
            subset = [k for k in dem_keys if digits in re.sub(r'\D+','',k)]
        # if subset empty or too small, fallback to full dem_keys
        if not subset or len(subset) < 10:
            subset = dem_keys
        # compute top-K
        sims = [(k, sim(sid, k)) for k in subset]
        sims = sorted(sims, key=lambda x: x[1], reverse=True)[:TOP_K]
        cand_list.extend(sims)
    # dedupe and keep top-K by score
    seenk = set(); final_cands = []
    for k,score in cand_list:
        if k in seenk: continue
        seenk.add(k)
        final_cands.append((k, float(score)))
    final_cands = sorted(final_cands, key=lambda x: x[1], reverse=True)[:TOP_K]
    rows_out.append({"npz_file": npz_file, "npz_sid": sid, "candidates": ";".join([f"{k}|{s:.3f}" for k,s in final_cands])})

    # auto-apply if high confidence
    if AUTO_APPLY and final_cands:
        best_key, best_score = final_cands[0]
        if best_score >= AUTO_THRESHOLD:
            try:
                arr = np.load(npz_file, allow_pickle=True)
                x = arr.get('x', np.array([])); pos = arr.get('pos', np.array([]))
                ei = arr.get('edge_index', np.empty((2,0), dtype=np.int64))
                y = np.array([float(dem_map[best_key])])
                outp = os.path.join(LABELED_DIR, os.path.basename(npz_file).replace(".npz","_labeled.npz"))
                np.savez_compressed(outp, x=x, pos=pos, edge_index=ei, y=y)
                auto_applied.append({"npz_file": npz_file, "dem_key": best_key, "score": best_score, "out": outp})
            except Exception as e:
                print("Auto-apply failed for", npz_file, e)

# save suggestions CSV
sugg_df = pd.DataFrame(rows_out)
sugg_df.to_csv(OUT_SUGGESTIONS, index=False)
print("Saved fuzzy suggestions to:", OUT_SUGGESTIONS)
if PRINT_SAMPLE:
    print("\nSample suggestions (first 8):")
    print(sugg_df.head(8).to_dict(orient='records'))

# if auto-applied, update index and misses
if AUTO_APPLY and auto_applied:
    print("Auto-applied mappings:", len(auto_applied))
    # load existing final index
    idx = pd.read_csv(FINAL_INDEX) if os.path.exists(FINAL_INDEX) else pd.DataFrame(columns=["npz_file","labeled_npz","npz_sid","dem_key","y"])
    for a in auto_applied:
        idx = idx.append({"npz_file": a['npz_file'], "labeled_npz": a['out'], "npz_sid": re.sub(r'\D+','', os.path.basename(a['npz_file'])), "dem_key": a['dem_key'], "y": float(dem_map[a['dem_key']])}, ignore_index=True)
    idx.to_csv(OUT_UPDATED_INDEX, index=False)
    # recompute misses that remain (those in miss_df not auto_applied)
    applied_set = set([a['npz_file'] for a in auto_applied])
    remain = [row for row in rows_out if row['npz_file'] not in applied_set]
    pd.DataFrame(remain).to_csv(OUT_UPDATED_MISSES, index=False)
    print("Wrote updated index:", OUT_UPDATED_INDEX, "updated misses:", OUT_UPDATED_MISSES)
else:
    print("AUTO_APPLY disabled or no auto mappings applied. Review suggestions CSV and either (A) set AUTO_APPLY=True and re-run, or (B) manually edit the suggestions CSV to pick mapped_dem_key and run the manual apply block from earlier.")

print("\nDone. If you'd like I can also (1) re-run with a more permissive fuzzy method (Levenshtein) or (2) prepare a small interactive review table to let you pick matches in the notebook. Tell me which you prefer.")


Dem keys sample: ['OAS2_0001', 'OAS2_0001', 'OAS2_0002', 'OAS2_0002', 'OAS2_0002', 'OAS2_0004', 'OAS2_0004', 'OAS2_0005'] count: 373
Miss rows: 59


Fuzzy matching: 100%|██████████| 59/59 [00:00<00:00, 216.94it/s]

Saved fuzzy suggestions to: /content/drive/MyDrive/oasis_project/data/graphs/label_misses_fuzzy_suggestions.csv

Sample suggestions (first 8):
[{'npz_file': '/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0000_subj_0.npz', 'npz_sid': '', 'candidates': 'OAS2_0001|0.000;OAS2_0002|0.000'}, {'npz_file': '/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0003_subj_3.npz', 'npz_sid': '3', 'candidates': 'OAS2_0013|0.200;OAS2_0023|0.200'}, {'npz_file': '/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0006_subj_6.npz', 'npz_sid': '6', 'candidates': 'OAS2_0016|0.200;OAS2_0026|0.200;OAS2_0036|0.200'}, {'npz_file': '/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0011_subj_11.npz', 'npz_sid': '11', 'candidates': 'OAS2_0111|0.364;OAS2_0112|0.364;OAS2_0113|0.364'}, {'npz_file': '/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0015_subj_15.npz', 'npz_sid': '15', 'candidates': 'O




In [None]:
# INTERACTIVE REVIEW + APPLY (one cell)
# - recomputes better fuzzy scores (rapidfuzz) and rewrites suggestions CSV
# - lets you interactively accept a mapping for any miss row by index
# - saves labeled npz + updated index + updated misses
import os, glob, re, numpy as np, pandas as pd
from pathlib import Path
from IPython.display import display, HTML

BASE = "/content/drive/MyDrive/oasis_project"
DEM_PATH = os.path.join(BASE, "data", "demographics", "oasis2_demographics.xlsx")
MISSES_CSV = os.path.join(BASE, "data", "graphs", "final_label_misses_auto.csv")
SUGG_CSV = os.path.join(BASE, "data", "graphs", "label_misses_fuzzy_suggestions.csv")
NPZ_DIR = os.path.join(BASE, "data", "graphs", "npz_graphs_resave")
LABELED_DIR = os.path.join(BASE, "data", "graphs", "labeled_npz_auto")
INDEX_CSV = os.path.join(BASE, "data", "graphs", "final_graph_label_index_auto.csv")
OUT_SUGG_UPDATED = SUGG_CSV.replace(".csv","_rf.csv")

os.makedirs(LABELED_DIR, exist_ok=True)

# install rapidfuzz if available / needed
try:
    from rapidfuzz import fuzz, process
except Exception:
    print("Installing rapidfuzz...")
    !pip -q install rapidfuzz
    from rapidfuzz import fuzz, process

# load demographics and build map
if not os.path.exists(DEM_PATH):
    raise SystemExit(f"Demographics file not found: {DEM_PATH}")
if DEM_PATH.lower().endswith('.csv'):
    dem_df = pd.read_csv(DEM_PATH)
else:
    dem_df = pd.read_excel(DEM_PATH)
dem_df.rename(columns={c:c.strip() for c in dem_df.columns}, inplace=True)
cols = list(dem_df.columns)
id_col = next((c for c in cols if any(tok in c.upper() for tok in ("MRI","MRI_ID","MRIID","SUBJ","SUBJECT","ID"))), cols[0])
label_col = next((c for c in cols if any(tok in c.upper() for tok in ("CDR","CDR_GLOBAL","DEMENTIA","DIAG","SEVERITY"))), None)
if label_col is None:
    numeric_cols = [c for c in cols if np.issubdtype(dem_df[c].dtype, np.number)]
    label_col = numeric_cols[0] if numeric_cols else cols[-1]
dem_keys = [str(x).strip() for x in dem_df[id_col].astype(str).tolist()]
dem_map = {k: dem_df.loc[i, label_col] for i,k in dem_df[id_col].astype(str).items()}

print("Loaded demographics:", DEM_PATH)
print("id_col:", id_col, "label_col:", label_col, "dem keys:", len(dem_keys))

# load misses/suggestions
if not os.path.exists(MISSES_CSV):
    raise SystemExit("Missing misses CSV: " + MISSES_CSV)
miss_df = pd.read_csv(MISSES_CSV)
if 'npz_file' not in miss_df.columns and 'file' in miss_df.columns:
    miss_df = miss_df.rename(columns={'file':'npz_file'})
if 'npz_sid' not in miss_df.columns:
    # attempt to extract sid
    def sid_from_fname(fname):
        if not isinstance(fname, str): return ""
        bn = os.path.basename(fname)
        m = re.search(r'(OAS2[_-]?\d+|subj[_-]?\d+|sub[_-]?\d+|\d{2,})', bn, flags=re.IGNORECASE)
        return m.group(0) if m else os.path.splitext(bn)[0]
    miss_df['npz_sid'] = miss_df['npz_file'].apply(lambda p: sid_from_fname(p) if isinstance(p,str) else "")

print("Miss rows loaded:", len(miss_df))

# recompute improved suggestions using rapidfuzz (top 6)
from rapidfuzz import process as rf_process
TOP_K = 6
rows=[]
for i,row in miss_df.iterrows():
    npz_file = row['npz_file']
    sid = str(row.get('npz_sid','') or "")
    # candidate strings to try
    candidates = set()
    candidates.add(sid)
    digits = re.sub(r'\D+','', sid)
    if digits:
        candidates.add(digits)
        candidates.add(digits.zfill(4))
        candidates.add("OAS2_" + digits.zfill(4))
    candidates.update([sid.upper(), sid.lower(), sid.replace("subj","").replace("sub","")])
    # perform rapidfuzz against dem_keys (restrict set if digits present)
    pool = dem_keys
    if digits:
        subset = [k for k in dem_keys if digits in re.sub(r'\D+','',k)]
        if subset: pool = subset
    matches = rf_process.extract(str(sid), pool, scorer=fuzz.WRatio, limit=TOP_K)
    matches = [(m[0], float(m[1]/100.0)) for m in matches]  # convert 0-100 -> 0-1
    rows.append({"idx":int(i), "npz_file":npz_file, "npz_sid":sid, "candidates": ";".join([f"{k}|{s:.3f}" for k,s in matches])})
sugg_df = pd.DataFrame(rows)
sugg_df.to_csv(OUT_SUGG_UPDATED, index=False)
print("Wrote re-scored suggestions (rapidfuzz) ->", OUT_SUGG_UPDATED)
display(sugg_df.head(8))

# helper: show one miss with candidate list (index from miss_df)
def show_miss(idx):
    if idx not in miss_df.index:
        print("Index not in misses. Use a numeric index from 0..", len(miss_df)-1)
        return
    row = miss_df.loc[idx]
    npz_file = row['npz_file']
    sid = str(row['npz_sid'] or "")
    print(f"MISS idx={idx}  npz_file={npz_file}  npz_sid={sid}")
    entry = sugg_df[sugg_df['idx']==idx]
    if entry.empty:
        print("No suggestion row (unexpected).")
        return
    cand_str = entry.iloc[0]['candidates']
    cand_list = []
    for part in cand_str.split(";"):
        if not part: continue
        k,s = part.split("|")
        cand_list.append((k.strip(), float(s)))
    if not cand_list:
        print("No candidates found.")
        return
    print("Top candidates (score 0..1):")
    for i,(k,s) in enumerate(cand_list):
        print(f"  [{i}] {k}   score={s:.3f}   label={dem_map.get(k,'<not-in-dem>')}")
    print("\nTo accept a candidate, call: review_and_apply(idx, pick) where pick is candidate index (0..). Example: review_and_apply(3, 0)")

# helper: apply chosen mapping
def review_and_apply(idx, pick):
    idx = int(idx)
    # validate
    row = miss_df.loc[idx]
    npz_file = row['npz_file']
    entry = sugg_df[sugg_df['idx']==idx]
    if entry.empty:
        print("No suggestion row for idx", idx); return
    cand_str = entry.iloc[0]['candidates']
    cand = [p for p in cand_str.split(";") if p]
    if pick < 0 or pick >= len(cand):
        print("pick out of range:", pick); return
    key, score = cand[pick].split("|")
    key = key.strip(); score = float(score)
    print(f"Applying mapping: {os.path.basename(npz_file)} -> {key} (score={score:.3f})")
    # load npz and write labeled npz
    arr = np.load(npz_file, allow_pickle=True)
    x = arr.get('x', np.array([])); pos = arr.get('pos', np.array([]))
    ei = arr.get('edge_index', np.empty((2,0), dtype=np.int64))
    val = dem_map.get(key)
    if val is None:
        print("Selected dem key not in dem_map (abort).")
        return
    y = np.array([float(val)])
    outname = os.path.join(LABELED_DIR, os.path.basename(npz_file).replace(".npz","_labeled.npz"))
    np.savez_compressed(outname, x=x, pos=pos, edge_index=ei, y=y)
    print("Saved labeled npz ->", outname)
    # append to index CSV
    idx_row = {"npz_file": npz_file, "labeled_npz": outname, "npz_sid": row.get('npz_sid'), "dem_key": key, "y": float(val)}
    # load or create
    if os.path.exists(INDEX_CSV):
        index_df = pd.read_csv(INDEX_CSV)
    else:
        index_df = pd.DataFrame(columns=list(idx_row.keys()))
    index_df = index_df.append(idx_row, ignore_index=True)
    index_df.to_csv(INDEX_CSV, index=False)
    print("Appended to index CSV:", INDEX_CSV)
    # remove row from misses and save updated misses CSV
    new_miss_df = miss_df.drop(index=idx)
    new_miss_df.to_csv(MISSES_CSV, index=False)
    print("Removed from misses and saved updated misses CSV:", MISSES_CSV)
    # refresh in-memory miss_df and sugg_df (note: this does not mutate outer miss_df variable; re-run cell to refresh if needed)
    print("DONE. If you want to continue, call show_miss(idx2) and review_and_apply(idx2, pick).")

# quick usage examples (do not auto-run mapping)
print("\nREADY. Examples:")
print("  show_miss(0)           # print first miss + candidates")
print("  review_and_apply(0,0)  # accept candidate 0 for miss 0 (writes labeled npz + updates CSVs)\n")

# display a short table of all misses with their top candidate
summary = sugg_df.copy()
summary['top_candidate'] = summary['candidates'].str.split(";").str[0].fillna("")
display(summary[['idx','npz_file','npz_sid','top_candidate']].head(12))


Installing rapidfuzz...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m32.0 MB/s[0m eta [36m0:00:00[0m
[?25hLoaded demographics: /content/drive/MyDrive/oasis_project/data/demographics/oasis2_demographics.xlsx
id_col: Subject ID label_col: CDR dem keys: 373
Miss rows loaded: 59
Wrote re-scored suggestions (rapidfuzz) -> /content/drive/MyDrive/oasis_project/data/graphs/label_misses_fuzzy_suggestions_rf.csv


Unnamed: 0,idx,npz_file,npz_sid,candidates
0,0,/content/drive/MyDrive/oasis_project/data/grap...,,OAS2_0001|0.000;OAS2_0001|0.000;OAS2_0002|0.00...
1,1,/content/drive/MyDrive/oasis_project/data/grap...,3.0,OAS2_0013|0.600;OAS2_0013|0.600;OAS2_0013|0.60...
2,2,/content/drive/MyDrive/oasis_project/data/grap...,6.0,OAS2_0016|0.600;OAS2_0016|0.600;OAS2_0026|0.60...
3,3,/content/drive/MyDrive/oasis_project/data/grap...,11.0,OAS2_0111|0.900;OAS2_0111|0.900;OAS2_0112|0.90...
4,4,/content/drive/MyDrive/oasis_project/data/grap...,15.0,OAS2_0150|0.900;OAS2_0150|0.900;OAS2_0152|0.90...
5,5,/content/drive/MyDrive/oasis_project/data/grap...,19.0,OAS2_0119|0.900;OAS2_0119|0.900;OAS2_0119|0.900
6,6,/content/drive/MyDrive/oasis_project/data/grap...,24.0,OAS2_0124|0.900;OAS2_0124|0.900
7,7,/content/drive/MyDrive/oasis_project/data/grap...,25.0,OAS2_0002|0.600;OAS2_0002|0.600;OAS2_0002|0.60...



READY. Examples:
  show_miss(0)           # print first miss + candidates
  review_and_apply(0,0)  # accept candidate 0 for miss 0 (writes labeled npz + updates CSVs)



Unnamed: 0,idx,npz_file,npz_sid,top_candidate
0,0,/content/drive/MyDrive/oasis_project/data/grap...,,OAS2_0001|0.000
1,1,/content/drive/MyDrive/oasis_project/data/grap...,3.0,OAS2_0013|0.600
2,2,/content/drive/MyDrive/oasis_project/data/grap...,6.0,OAS2_0016|0.600
3,3,/content/drive/MyDrive/oasis_project/data/grap...,11.0,OAS2_0111|0.900
4,4,/content/drive/MyDrive/oasis_project/data/grap...,15.0,OAS2_0150|0.900
5,5,/content/drive/MyDrive/oasis_project/data/grap...,19.0,OAS2_0119|0.900
6,6,/content/drive/MyDrive/oasis_project/data/grap...,24.0,OAS2_0124|0.900
7,7,/content/drive/MyDrive/oasis_project/data/grap...,25.0,OAS2_0002|0.600
8,8,/content/drive/MyDrive/oasis_project/data/grap...,33.0,OAS2_0133|0.900
9,9,/content/drive/MyDrive/oasis_project/data/grap...,38.0,OAS2_0138|0.900


In [None]:
# ONE CELL: interactive review + apply fuzzy suggestions + finalize .pt
import os, glob, json, numpy as np, pandas as pd, shutil, re
from pathlib import Path
from IPython.display import display, HTML
from collections import OrderedDict

BASE = "/content/drive/MyDrive/oasis_project"
NPZ_DIR = os.path.join(BASE, "data", "graphs", "npz_graphs_resave")
LABELED_DIR = os.path.join(BASE, "data", "graphs", "labeled_npz_auto")
DEM_PATH = os.path.join(BASE, "data", "demographics", "oasis2_demographics.xlsx")
INDEX_CSV = os.path.join(BASE, "data", "graphs", "final_graph_label_index_auto.csv")
MISSES_CSV = os.path.join(BASE, "data", "graphs", "final_label_misses_auto.csv")
FUZZY_CSV_RF = os.path.join(BASE, "data", "graphs", "label_misses_fuzzy_suggestions_rf.csv")
FUZZY_CSV = os.path.join(BASE, "data", "graphs", "label_misses_fuzzy_suggestions.csv")
AUTO_INDEX_CSV = os.path.join(BASE, "data", "graphs", "graph_label_index_auto.csv")
AUTO_MISSES_CSV = os.path.join(BASE, "data", "graphs", "label_misses_auto.csv")

os.makedirs(LABELED_DIR, exist_ok=True)

# Helper: try to load the best fuzzy suggestions CSV that exists
fuzzy_candidates = [FUZZY_CSV_RF, FUZZY_CSV, AUTO_MISSES_CSV, MISSES_CSV]
fuzzy_path = next((p for p in fuzzy_candidates if os.path.exists(p)), None)

if fuzzy_path is None:
    raise SystemExit("No fuzzy/misses CSV found. Look for files like label_misses_fuzzy_suggestions_rf.csv or final_label_misses_auto.csv in data/graphs/")

print("Using fuzzy/misses file:", fuzzy_path)

# Load fuzzy suggestions (expected columns: npz_file, npz_sid, candidates or suggestions)
fdf = pd.read_csv(fuzzy_path)
# normalize columns
if 'candidates' not in fdf.columns and 'suggestions' in fdf.columns:
    fdf = fdf.rename(columns={'suggestions': 'candidates'})

# If candidates are stored as semicolon-delimited 'key|score;key|score' strings
def parse_candidates(cell):
    if pd.isna(cell): return []
    if isinstance(cell, (list,tuple)): return list(cell)
    s = str(cell)
    # sometimes stored as JSON list of dicts
    try:
        parsed = json.loads(s)
        if isinstance(parsed, list):
            # try to convert list of dicts -> list of "key|score"
            out=[]
            for item in parsed:
                if isinstance(item, dict):
                    k = next(iter(item.values())) if len(item)>0 else None
                    # fallback: create representation
                    out.append(json.dumps(item))
                else:
                    out.append(str(item))
            return out
    except Exception:
        pass
    parts = []
    # split by ; or |
    # detect patterns like "OAS2_0111|0.900;OAS2_0112|0.900"
    for token in re.split(r'[;\\n]+', s):
        token = token.strip()
        if not token: continue
        parts.append(token)
    return parts

# Create an actionable dataframe: each row has npz_file, npz_sid, candidates(list of (key,score))
rows = []
for i, row in fdf.iterrows():
    npz_file = row.get('npz_file') if 'npz_file' in row else row.get('file') if 'file' in row else None
    npz_sid = row.get('npz_sid') if 'npz_sid' in row else ''
    cand_cell = row.get('candidates') if 'candidates' in row else row.get('suggestions') if 'suggestions' in row else ''
    cand_list = parse_candidates(cand_cell)
    # transform cand_list items into (key,score) tuples
    parsed = []
    for item in cand_list:
        item = str(item)
        # if JSON-like dict string, try to extract 'key' and 'score'
        if '|' in item and item.count('|')>=1:
            parts = item.split('|')
            key = parts[0].strip()
            try:
                score = float(parts[1])
            except:
                score = None
            parsed.append((key, score))
        else:
            # sometimes item like "OAS2_0001" or "OAS2_0001:0.9"
            m = re.match(r'^(?P<k>[A-Za-z0-9_\-]+)[\|: ]?(?P<s>[0-9\.]+)?', item)
            if m:
                key = m.group('k')
                s = m.group('s')
                score = float(s) if s is not None else None
                parsed.append((key, score))
            else:
                parsed.append((item, None))
    rows.append({"idx": i, "npz_file": npz_file, "npz_sid": str(npz_sid), "candidates": parsed})

acts = pd.DataFrame(rows)

# Load demographics for label lookup
if not os.path.exists(DEM_PATH):
    print("Warning: demographics file not found at expected DEM_PATH:", DEM_PATH)
    dem_df = None
else:
    try:
        dem_df = pd.read_excel(DEM_PATH)
        dem_df.rename(columns={c:c.strip() for c in dem_df.columns}, inplace=True)
        id_col = next((c for c in dem_df.columns if any(tok in c.upper() for tok in ("MRI","MRI_ID","MRIID","SUBJ","SUBJECT","ID"))), dem_df.columns[0])
        lab_col = next((c for c in dem_df.columns if any(tok in c.upper() for tok in ("CDR","CDR_GLOBAL","DEMENTIA","DIAG","SEVERITY"))), dem_df.columns[-1])
        dem_map = {str(v).strip(): dem_df.iloc[i][lab_col] for i,v in enumerate(dem_df[id_col].astype(str))}
        print("Loaded demographics: {}, id_col = {}, label_col = {}, dem entries = {}".format(DEM_PATH, id_col, lab_col, len(dem_map)))
    except Exception as e:
        print("Could not read demographics:", e)
        dem_df = None
        dem_map = {}

# Index / misses CSVs (create if missing)
if not os.path.exists(INDEX_CSV):
    idx_df = pd.DataFrame(columns=["subject_id","source","file","n_nodes","feat_dim","y"])
    idx_df.to_csv(INDEX_CSV, index=False)
else:
    idx_df = pd.read_csv(INDEX_CSV)

if not os.path.exists(MISSES_CSV):
    misses_df = acts[['idx','npz_file','npz_sid']].rename(columns={'idx':'miss_idx'}).copy()
    misses_df.to_csv(MISSES_CSV, index=False)
else:
    misses_df = pd.read_csv(MISSES_CSV)

# utility: pretty print one miss
def show_miss(i):
    """Display miss row i and candidate list with dem label if available."""
    row = acts[acts['idx']==i]
    if row.empty:
        print(f"No miss with idx={i} (rows available: {acts['idx'].tolist()[:20]} ... )")
        return
    row = row.iloc[0]
    npzf = row['npz_file']; sid = row['npz_sid']; cands = row['candidates']
    print(f"MISS idx={i}\n npz_file={npzf}\n npz_sid='{sid}'\nCandidates (best first):")
    for j, (key, score) in enumerate(cands):
        label_val = dem_map.get(key) if dem_map else None
        print(f"  [{j}] {key}  score={score}  dem_label={label_val}")
    # also show a tiny preview of the NPZ's shapes if possible
    if npzf and os.path.exists(npzf):
        try:
            arr = np.load(npzf, allow_pickle=True)
            x = arr.get('x'); pos = arr.get('pos'); ei = arr.get('edge_index') or arr.get('edges') or arr.get('edge_idx')
            print("NPZ preview shapes: x:", getattr(x,'shape',None), "pos:", getattr(pos,'shape',None), "edge_index:", getattr(ei,'shape',None))
        except Exception as e:
            print("Could not load npz preview:", e)
    else:
        print("NPZ file not found on disk.")

# utility: accept candidate j for miss i -> writes labeled NPZ and updates CSVs
def review_and_apply(i, j):
    """Accept candidate j for miss index i. Writes labeled npz and updates CSVs."""
    row = acts[acts['idx']==i]
    if row.empty:
        print("No miss with idx=", i); return
    row = row.iloc[0]
    npzf = row['npz_file']; sid = row['npz_sid']; cands = row['candidates']
    if j<0 or j>=len(cands):
        print("Candidate index out of range."); return
    dem_key, score = cands[j]
    # find label value from demographics
    if dem_map and dem_key in dem_map:
        yval = dem_map[dem_key]
    else:
        # try to match numeric part (like OAS2_0001 -> 0001)
        digits = re.sub(r'\D+','', dem_key)
        yval = dem_map.get(dem_key) or dem_map.get(digits) or None
    if yval is None:
        print("WARNING: could not find dem label for key", dem_key, "— still proceeding but y will be saved as string.")
    # load npz
    if not npzf or not os.path.exists(npzf):
        print("NPZ file missing:", npzf); return
    arr = np.load(npzf, allow_pickle=True)
    x = arr.get('x', np.array([]))
    pos = arr.get('pos', np.array([]))
    ei = arr.get('edge_index', arr.get('edges', arr.get('edge_idx', np.empty((2,0), dtype=np.int64))))
    y = np.array([float(yval)]) if (yval is not None and (isinstance(yval, (int,float)) or (isinstance(yval, np.number)))) else np.array([yval])
    # write labeled npz
    outname = os.path.join(LABELED_DIR, os.path.basename(npzf).replace(".npz","_labeled.npz"))
    np.savez_compressed(outname, x=x, pos=pos, edge_index=ei, y=y)
    print("Wrote labeled npz:", outname, "mapped_dem_key:", dem_key, "score:", score, "y:", y)
    # update index CSV (append)
    # compute n_nodes / feat_dim safely
    try:
        n_nodes = int(np.asarray(x).shape[0])
        feat_dim = int(np.asarray(x).shape[1]) if np.asarray(x).ndim>1 else 1
    except Exception:
        n_nodes = 0; feat_dim = 0
    newrow = {"subject_id": dem_key, "source": "auto_fuzzy", "file": outname, "n_nodes": n_nodes, "feat_dim": feat_dim, "y": float(y) if (isinstance(y, np.ndarray) and np.issubdtype(y.dtype, np.number)) else y}
    # append to CSVs in memory and disk
    global idx_df, misses_df
    idx_df = pd.concat([idx_df, pd.DataFrame([newrow])], ignore_index=True)
    idx_df.to_csv(INDEX_CSV, index=False)
    # remove from misses: delete acts row and rewrite misses CSV
    acts.drop(acts[acts['idx']==i].index, inplace=True)
    # also remove entry from misses_df if exists (match by npz_file)
    misses_df = misses_df[misses_df['npz_file'] != npzf]
    misses_df.to_csv(MISSES_CSV, index=False)
    # update fuzzy CSV file by removing that row
    # read original CSV and drop matching npz_file row
    try:
        orig = pd.read_csv(fuzzy_path)
        orig = orig[orig['npz_file'] != npzf]
        orig.to_csv(fuzzy_path, index=False)
    except Exception:
        pass
    # update index CSV on disk
    print("Index CSV and misses CSV updated.")
    return outname

# convenience: auto-apply highly confident suggestions
def auto_apply_threshold(threshold=0.85, dry_run=False, max_apply=500):
    """Auto-apply candidate[0] for rows where candidate score >= threshold.
       dry_run=True will just print; returns list of applied (idx,dem_key)."""
    applied=[]
    candidates_to_apply = []
    for _, row in acts.iterrows():
        i = row['idx']; cands = row['candidates']
        if not cands: continue
        key, score = cands[0]
        score_val = score if score is not None else 0.0
        if score_val >= threshold:
            candidates_to_apply.append((i, key, score_val))
    # sort by score desc
    candidates_to_apply = sorted(candidates_to_apply, key=lambda x: -x[2])[:max_apply]
    print(f"Auto-applying {len(candidates_to_apply)} matches with score >= {threshold}. dry_run={dry_run}")
    for i,key,score in candidates_to_apply:
        print(" ->", i, key, score)
        if not dry_run:
            # find candidate index in acts
            row = acts[acts['idx']==i].iloc[0]
            # find which candidate entry matches key
            cand_idx = next((k for k,(kk,sc) in enumerate(row['candidates']) if kk==key), 0)
            review_and_apply(i, cand_idx)
            applied.append((i,key,score))
    return applied

# finalizer: build .pt from all labeled npz files in LABELED_DIR and save into oasis2_graphs_labeled_final.pt
def build_final_pt(out_pt=None, overwrite=False):
    """Load all labeled npz in LABELED_DIR and save a single .pt list of torch_geometric.data.Data"""
    try:
        import torch
        from torch_geometric.data import Data
    except Exception as e:
        print("Could not import torch/torch_geometric:", e)
        return
    out_pt = out_pt or os.path.join(BASE, "data", "graphs", "oasis2_graphs_labeled_final.pt")
    if os.path.exists(out_pt) and not overwrite:
        print("Out .pt already exists:", out_pt, "use overwrite=True to replace.")
        return
    npz_files = sorted(glob.glob(os.path.join(LABELED_DIR, "*_labeled.npz")))
    if not npz_files:
        print("No labeled npz files found in", LABELED_DIR)
        return
    data_list = []
    for p in npz_files:
        try:
            arr = np.load(p, allow_pickle=True)
            x = arr.get('x', np.array([]))
            pos = arr.get('pos', np.array([]))
            ei = arr.get('edge_index', np.empty((2,0),dtype=np.int64))
            y = arr.get('y', np.array([]))
            xt = torch.tensor(x, dtype=torch.float) if getattr(x,'size', lambda:0)() else None
            post = torch.tensor(pos, dtype=torch.float) if getattr(pos,'size', lambda:0)() else None
            eit = torch.tensor(ei, dtype=torch.long) if getattr(ei,'size', lambda:0)() else None
            yt = torch.tensor(y, dtype=torch.float) if getattr(y,'size', lambda:0)() else None
            g = Data(x=xt, edge_index=eit, pos=post, y=yt)
            # subject id from filename or stored in y/npz name
            subj = os.path.splitext(os.path.basename(p))[0].replace("_labeled","")
            g.subject_id = subj
            data_list.append(g)
        except Exception as e:
            print("Skipping", p, "due to error:", e)
    if not data_list:
        print("No Data objects constructed.")
        return
    torch.save(data_list, out_pt)
    print("Saved final labeled .pt with", len(data_list), "graphs to:", out_pt)
    return out_pt

# Print small usage instruction
print("\nREADY. Usage examples:")
print("  show_miss(3)             # inspect miss index 3")
print("  review_and_apply(3, 0)   # accept candidate 0 for miss 3 (writes labeled npz + updates CSVs)")
print("  auto_apply_threshold(0.9) # auto accept high-confidence matches (dry_run default False)")
print("  build_final_pt(overwrite=True)  # aggregate labeled npz -> single .pt\n")

# expose a light HTML table for the first few misses
try:
    display(HTML(acts[['idx','npz_file','npz_sid']].head(12).to_html(index=False)))
except Exception:
    pass


Using fuzzy/misses file: /content/drive/MyDrive/oasis_project/data/graphs/label_misses_fuzzy_suggestions_rf.csv
Loaded demographics: /content/drive/MyDrive/oasis_project/data/demographics/oasis2_demographics.xlsx, id_col = Subject ID, label_col = CDR, dem entries = 150

READY. Usage examples:
  show_miss(3)             # inspect miss index 3
  review_and_apply(3, 0)   # accept candidate 0 for miss 3 (writes labeled npz + updates CSVs)
  auto_apply_threshold(0.9) # auto accept high-confidence matches (dry_run default False)
  build_final_pt(overwrite=True)  # aggregate labeled npz -> single .pt



idx,npz_file,npz_sid
0,/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0000_subj_0.npz,
1,/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0003_subj_3.npz,3.0
2,/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0006_subj_6.npz,6.0
3,/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0011_subj_11.npz,11.0
4,/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0015_subj_15.npz,15.0
5,/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0019_subj_19.npz,19.0
6,/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0024_subj_24.npz,24.0
7,/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0025_subj_25.npz,25.0
8,/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0033_subj_33.npz,33.0
9,/content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0038_subj_38.npz,38.0


In [None]:
# ONE-CELL TOOL: Export .pt -> npz, auto-match demographics with fuzzy suggestions,
# interactive helpers: show_miss, review_and_apply, auto_apply_threshold, build_final_pt
# Drop this into one notebook cell and run. It will:
#  - find / recover the .pt dataset if possible
#  - export graphs to NPZ (if .pt available)
#  - load demographics (prompt looks in project dem folder or root drive)
#  - perform strict + fuzzy matching to propose dem keys for each NPZ
#  - provide helper functions to accept suggestions and build a final labeled .pt
#
# NOTE: This code is defensive and prints progress. Edit PATHS near top if needed.

import os, glob, re, shutil, json, numpy as np, pandas as pd, traceback
from collections import defaultdict
from pprint import pprint
from pathlib import Path

# optional imports (faster fuzzy matching). Will fallback to difflib if missing.
try:
    from rapidfuzz import process as rf_process, fuzz as rf_fuzz
    HAVE_RAPIDFUZZ = True
except Exception:
    import difflib
    HAVE_RAPIDFUZZ = False

try:
    import torch
    from torch_geometric.data import Data
except Exception:
    torch = None
    Data = None

# ---------------- CONFIG (change here if you want) ----------------
BASE = "/content/drive/MyDrive/oasis_project"
SRC_PT_CANON = os.path.join(BASE, "oasis2_graph_dataset.pt")   # canonical expected path
OUTPUT_DIR = os.path.join(BASE, "data", "graphs")
NPZ_OUT = os.path.join(OUTPUT_DIR, "npz_graphs_resave")
LABELED_OUT = os.path.join(OUTPUT_DIR, "labeled_npz_auto")
FINAL_PT = os.path.join(OUTPUT_DIR, "oasis2_graphs_labeled_auto_final.pt")
DEM_PREF = os.path.join(BASE, "data", "demographics", "oasis2_demographics.xlsx")

# suggestion / index CSVs
FUZZY_CSV = os.path.join(OUTPUT_DIR, "label_misses_fuzzy_suggestions_rf.csv")
INDEX_CSV = os.path.join(OUTPUT_DIR, "final_graph_label_index_auto.csv")
MISSES_CSV = os.path.join(OUTPUT_DIR, "final_label_misses_auto.csv")

os.makedirs(NPZ_OUT, exist_ok=True)
os.makedirs(LABELED_OUT, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------------- helpers ----------------
def safe_load_torch(path):
    """Try to load a torch file safely; map storages to cpu if needed."""
    if torch is None:
        raise RuntimeError("torch not available in environment.")
    try:
        return torch.load(path, weights_only=False)
    except RuntimeError as e:
        # try mapping to CPU (common when checkpoint saved on CUDA)
        try:
            return torch.load(path, map_location=torch.device('cpu'), weights_only=False)
        except Exception:
            raise

def find_candidate_pt():
    """Search known locations for a dataset .pt containing list-of-graphs."""
    # candidates: canonical path, outputs folder copies
    cands = []
    if os.path.exists(SRC_PT_CANON):
        cands.append(SRC_PT_CANON)
    out_dir = os.path.join(BASE, "outputs")
    if os.path.isdir(out_dir):
        for p in sorted(glob.glob(os.path.join(out_dir, "*.pt"))):
            cands.append(p)
    # also check top-level for anything named oasis*dataset*.pt
    for p in glob.glob(os.path.join(BASE, "**", "*graph_dataset*.pt"), recursive=True):
        cands.append(p)
    # de-duplicate preserving order
    seen = set(); uniq=[]
    for p in cands:
        if p not in seen and os.path.exists(p):
            uniq.append(p); seen.add(p)
    return uniq

def export_pt_to_npz(pt_path, npz_out=NPZ_OUT, overwrite=False):
    """Load a .pt list-of-Data and export each graph to compressed npz. Returns list of npz paths."""
    print("Exporting .pt -> npz from:", pt_path)
    raw = safe_load_torch(pt_path)
    if not isinstance(raw, (list, tuple)):
        raise RuntimeError("Loaded .pt is not a list/tuple of graphs. File: " + pt_path)
    graphs = list(raw)
    print("Loaded graphs count:", len(graphs))
    out_files = []
    for i,g in enumerate(graphs):
        # robust extraction
        sid = getattr(g, 'subject_id', None) or getattr(g, 'sid', None) or getattr(g, 'filename', None) or f"subj_{i}"
        # get tensors/arrays to numpy
        def to_np(v):
            if v is None: return np.array([])
            try:
                return v.cpu().numpy()
            except Exception:
                return np.asarray(v)
        x = to_np(getattr(g, 'x', None))
        pos = to_np(getattr(g, 'pos', None))
        # try multiple edge attr names
        ei = None
        for nm in ('edge_index','edges','edge_idx'):
            if hasattr(g, nm):
                ei = getattr(g, nm)
                break
        ei = to_np(ei)
        # normalize edge shape
        try:
            ei = np.asarray(ei)
            if ei.ndim == 1 and ei.size % 2 == 0:
                ei = ei.reshape(2,-1)
            elif ei.ndim == 2 and ei.shape[0] != 2 and ei.shape[1] == 2:
                ei = ei.T
            elif ei.ndim != 2:
                ei = np.empty((2,0), dtype=np.int64)
        except Exception:
            ei = np.empty((2,0), dtype=np.int64)
        safe_sid = re.sub(r'[^0-9A-Za-z_\-\.]', '_', str(sid))
        outp = os.path.join(npz_out, f"graph_{i:04d}_{safe_sid}.npz")
        if os.path.exists(outp) and not overwrite:
            out_files.append(outp)
            continue
        np.savez_compressed(outp, x=x, pos=pos, edge_index=ei)
        out_files.append(outp)
    print("Exported npz examples:", out_files[:3])
    return out_files

def find_demographics(pref=DEM_PREF):
    """Return path to demographics file (xlsx/csv) or None."""
    if pref and os.path.exists(pref):
        return pref
    # search typical places: project folder and MyDrive root
    candidates=[]
    for root in (BASE, os.path.join("/content/drive","MyDrive"), "/content"):
        for ext in ("*.xlsx","*.xls","*.csv"):
            for p in glob.glob(os.path.join(root, "**", ext), recursive=False):
                candidates.append(p)
    # deeper scan for oasis+dem names (limited)
    for ext in ("*.xlsx","*.xls","*.csv"):
        for p in glob.glob(os.path.join("/content/drive","MyDrive","**", ext), recursive=True):
            name = os.path.basename(p).lower()
            if 'oasis' in name and ('dem' in name or 'demograph' in name or 'demographics' in name):
                candidates.append(p)
    candidates = [c for c in sorted(set(candidates)) if os.path.exists(c)]
    return candidates[0] if candidates else None

def read_dem(dfpath):
    if dfpath.lower().endswith('.csv'):
        df = pd.read_csv(dfpath)
    else:
        df = pd.read_excel(dfpath)
    df.rename(columns={c:c.strip() for c in df.columns}, inplace=True)
    return df

def detect_id_label_columns(df):
    cols = list(df.columns)
    id_candidates = [c for c in cols if any(tok in c.upper() for tok in ("MRI","MRI_ID","MRIID","SUBJ","SUBJECT","ID","OASIS"))]
    if not id_candidates:
        id_candidates = cols[:1]
    id_col = id_candidates[0]
    label_candidates = [c for c in cols if any(tok in c.upper() for tok in ("CDR","CDR_GLOBAL","DEMENTIA","DIAG","SEVERITY"))]
    if not label_candidates:
        numeric_cols = [c for c in cols if np.issubdtype(df[c].dtype, np.number)]
        label_candidates = numeric_cols if numeric_cols else [cols[-1]]
    label_col = label_candidates[0]
    return id_col, label_col

def normalize_dem_key(x):
    if pd.isna(x): return None
    s = str(x).strip()
    # common dem keys look like OAS2_0123 or numeric 123; keep both forms as alternatives
    return s

# ---------------- core matching logic ----------------
def build_dem_map(dem_path):
    df = read_dem(dem_path)
    id_col, label_col = detect_id_label_columns(df)
    dem_map = {}
    dem_keys = []
    for _, r in df.iterrows():
        key = normalize_dem_key(r.get(id_col))
        if key is None: continue
        dem_keys.append(key)
        val = r.get(label_col)
        dem_map[key] = val
    dem_keys = sorted(list(set(dem_keys)))
    return dem_map, dem_keys, id_col, label_col, df

def extract_npz_sid(fname):
    """Extract the numeric-ish id from filename (commonly last token like subj_12 or 0012)."""
    b = os.path.basename(fname)
    m = re.search(r'([0-9]{2,}|subj[_-]?\d+|sub[_-]?\d+)', b, flags=re.IGNORECASE)
    if not m:
        # fallback: basename without extension
        return os.path.splitext(b)[0]
    s = m.group(0)
    # trim "subj_" -> keep numeric if possible
    s2 = re.sub(r'^(subj[_-]?|sub[_-]?)', '', s, flags=re.IGNORECASE)
    return s2

def fuzzy_candidates_for_sid(sid, dem_keys, topk=5):
    """Return list of (dem_key, score) best matches for sid (score 0-1)."""
    if sid is None or sid == "":
        return []
    sid_str = str(sid)
    if HAVE_RAPIDFUZZ:
        # rapidfuzz returns scores up to 100; prefer ratio on whole string
        matches = rf_process.extract(sid_str, dem_keys, scorer=rf_fuzz.ratio, limit=topk)
        # convert to 0..1 floats and return
        return [(m[0], float(m[1]) / 100.0) for m in matches]
    else:
        # use difflib SequenceMatcher ratio
        scored = []
        for k in dem_keys:
            score = difflib.SequenceMatcher(a=sid_str.lower(), b=str(k).lower()).ratio()
            scored.append((k, score))
        scored.sort(key=lambda x: x[1], reverse=True)
        return scored[:topk]

# ---------------- RUN pipeline (export if needed) ----------------
print("STARTING pipeline. Base:", BASE)
# 1) Ensure we have NPZs: attempt to use existing, else export from a pt if present
npz_files = sorted(glob.glob(os.path.join(NPZ_OUT, "*.npz")))
if len(npz_files) == 0:
    print("No npz files in", NPZ_OUT, "-> searching for .pt dataset to export.")
    pts = find_candidate_pt()
    if not pts:
        print("❌ No .pt candidates found in project. If you have a .pt dataset upload it to", BASE)
    else:
        recovered = False
        for p in pts:
            try:
                print("Trying to load candidate .pt:", p)
                raw = safe_load_torch(p)
                # if it's a list, we can export
                if isinstance(raw, list):
                    # if canonical not existing, copy to canonical for bookkeeping
                    if not os.path.exists(SRC_PT_CANON):
                        try:
                            shutil.copyfile(p, SRC_PT_CANON)
                            print("Copied candidate to canonical path:", SRC_PT_CANON)
                        except Exception:
                            pass
                    export_pt_to_npz(p, npz_out=NPZ_OUT, overwrite=False)
                    recovered = True
                    break
            except Exception as e:
                print("Failed loading candidate:", p, "->", e)
                continue
        if not recovered:
            print("❌ Could not recover a list-of-graphs .pt. Aborting NPZ creation.")
npz_files = sorted(glob.glob(os.path.join(NPZ_OUT, "*.npz")))
print("Found NPZ count:", len(npz_files))
if len(npz_files) == 0:
    raise SystemExit("No graphs (.npz) available. Provide a .pt or pre-exported .npz files in: " + NPZ_OUT)

# 2) Load demographics
dem_path = find_demographics(DEM_PREF)
if dem_path is None:
    raise SystemExit("Demographics file not found. Upload your oasis_longitudinal_demographics* file to Drive or data/demographics/ and re-run.")
print("Using demographics file:", dem_path)
dem_map, dem_keys, id_col, label_col, dem_df = build_dem_map(dem_path)
print(f"Demographics loaded. id_col = {id_col}, label_col = {label_col}, dem entries = {len(dem_keys)}")
# quick sample
print("Dem keys sample:", dem_keys[:12])

# 3) Perform strong direct mapping (exact match) and collect misses for fuzzy
auto_matches = []
miss_rows = []
index_rows = []
for p in npz_files:
    sid_raw = extract_npz_sid(p)
    sid_norm = str(sid_raw).strip() if sid_raw is not None else ""
    label_val = None
    chosen_key = None
    # try exact OAS2_0001 style or numeric token matches
    if sid_norm in dem_map:
        chosen_key = sid_norm
        label_val = dem_map[sid_norm]
    else:
        # try common padded numeric forms (e.g. '1' -> 'OAS2_0001')
        digits = re.sub(r'\D+', '', sid_norm or "")
        if digits:
            # attempt variants:
            variants = [digits, digits.zfill(4), f"OAS2_{digits.zfill(4)}", f"OAS_{digits.zfill(4)}"]
            for v in variants:
                if v in dem_map:
                    chosen_key = v; label_val = dem_map[v]; break
    if chosen_key is not None:
        # save labeled npz immediately
        arr = np.load(p, allow_pickle=True)
        x = arr.get('x', np.array([])); pos = arr.get('pos', np.array([]))
        ei = arr.get('edge_index', np.empty((2,0), dtype=np.int64))
        y = np.array([label_val])
        outname = os.path.join(LABELED_OUT, os.path.basename(p).replace(".npz","_labeled.npz"))
        np.savez_compressed(outname, x=x, pos=pos, edge_index=ei, y=y)
        auto_matches.append({"npz_file":p, "labeled_npz": outname, "npz_sid": sid_norm, "dem_key": chosen_key, "y": label_val})
        index_rows.append({"npz_file":p, "npz_sid":sid_norm, "dem_key": chosen_key, "y": label_val})
    else:
        miss_rows.append({"npz_file": p, "npz_sid": sid_norm})

print("Strict auto-labeled count:", len(auto_matches), "Misses:", len(miss_rows))

# 4) Fuzzy-match suggestions for misses
fuzzy_rows = []
for m in miss_rows:
    p = m['npz_file']; sid = str(m.get('npz_sid') or "")
    cands = fuzzy_candidates_for_sid(sid, dem_keys, topk=6)
    fuzzy_rows.append({"npz_file": p, "npz_sid": sid, "candidates": ";".join([f"{k}|{s:.3f}" for k,s in cands])})
# write suggestions csv
fdf = pd.DataFrame(fuzzy_rows)
fdf.to_csv(FUZZY_CSV, index=False)
print("Saved fuzzy suggestions to:", FUZZY_CSV)

# 5) Provide interactive helpers (in-memory state + CSVs saved)
# Build in-memory miss list with parsed candidates
miss_list = []
for idx, row in fdf.iterrows():
    cand_text = row['candidates'] or ""
    cands = []
    if cand_text:
        for token in cand_text.split(";"):
            if "|" in token:
                k, s = token.split("|")
                try: s = float(s)
                except: s = 0.0
                cands.append((k, s))
    miss_list.append({"idx": int(idx), "npz_file": row['npz_file'], "npz_sid": row['npz_sid'], "candidates": cands, "applied": None})

# keep track of applied decisions (updates CSVs when applied)
applied_index_rows = list(index_rows)  # start with strict matches
applied_misses = []

# helper functions to expose to notebook user
def show_miss(i):
    """Print details for miss index i (use index from CSV / printed table)."""
    if i < 0 or i >= len(miss_list):
        print("Index out of range:", i); return
    m = miss_list[i]
    print("MISS INDEX:", i)
    print(" NPZ file: ", m['npz_file'])
    print(" NPZ extracted sid:", m['npz_sid'])
    print(" Candidates (dem_key | score):")
    for ci, (k,s) in enumerate(m['candidates']):
        # show demographics label value too
        demo_val = dem_map.get(k, None)
        print(f"  [{ci}] {k}  score={s:.3f}  -> label={demo_val}")
    # quick preview of npz structure
    try:
        arr = np.load(m['npz_file'], allow_pickle=True)
        print(" NPZ keys:", list(arr.keys()))
        x = arr.get('x', None)
        if x is not None and getattr(x, 'shape', None) is not None:
            print("  x.shape:", np.asarray(x).shape, " dtype:", np.asarray(x).dtype)
        pos = arr.get('pos', None)
        if pos is not None and getattr(pos, 'shape', None) is not None:
            print("  pos.shape:", np.asarray(pos).shape)
        ei = arr.get('edge_index', None)
        if ei is not None:
            try:
                ei_a = np.asarray(ei)
                print("  edge_index.shape:", ei_a.shape)
            except:
                print("  edge_index: (could not parse shape)")
    except Exception as e:
        print(" Could not open npz:", e)

def review_and_apply(miss_idx, cand_idx):
    """Accept candidate cand_idx for miss miss_idx. Writes labeled npz and updates CSVs."""
    if miss_idx < 0 or miss_idx >= len(miss_list):
        print("miss idx out of range"); return
    m = miss_list[miss_idx]
    if cand_idx < 0 or cand_idx >= len(m['candidates']):
        print("cand idx out of range"); return
    dem_key, score = m['candidates'][cand_idx]
    label_val = dem_map.get(dem_key, None)
    # load npz and save labeled npz
    arr = np.load(m['npz_file'], allow_pickle=True)
    x = arr.get('x', np.array([])); pos = arr.get('pos', np.array([]))
    ei = arr.get('edge_index', np.empty((2,0), dtype=np.int64))
    y = np.array([label_val])
    outname = os.path.join(LABELED_OUT, os.path.basename(m['npz_file']).replace(".npz","_labeled.npz"))
    np.savez_compressed(outname, x=x, pos=pos, edge_index=ei, y=y)
    # record
    applied_index_rows.append({"npz_file": m['npz_file'], "npz_sid": m['npz_sid'], "dem_key": dem_key, "y": label_val})
    applied_misses.append({"npz_file": m['npz_file'], "npz_sid": m['npz_sid'], "mapped_dem_key": dem_key, "y": label_val, "score": score})
    m['applied'] = (dem_key, score)
    # persist small updates to CSVs
    pd.DataFrame(applied_index_rows).to_csv(INDEX_CSV, index=False)
    pd.DataFrame(applied_misses).to_csv(MISSES_CSV, index=False)
    print("Applied mapping:", os.path.basename(m['npz_file']), "->", dem_key, " score=", score)
    print("Wrote labeled npz:", outname)
    print("Updated index CSV:", INDEX_CSV, "and misses CSV:", MISSES_CSV)

def auto_apply_threshold(threshold=0.90, dry_run=True):
    """Auto-accept candidates whose fuzzy score >= threshold. dry_run=True just prints what would be applied."""
    to_apply = []
    for m in miss_list:
        if m['applied'] is not None: continue
        if not m['candidates']: continue
        top = m['candidates'][0]
        if top[1] >= threshold:
            to_apply.append((m, top))
    print(f"Found {len(to_apply)} auto-apply candidates with score >= {threshold} (dry_run={dry_run})")
    if dry_run:
        for m,top in to_apply[:50]:
            print(" WOULD APPLY:", os.path.basename(m['npz_file']), "->", top[0], "score=", top[1])
        return to_apply
    else:
        for m,top in to_apply:
            idx = next((i for i,mm in enumerate(miss_list) if mm['npz_file']==m['npz_file']), None)
            if idx is not None:
                ci = 0  # top candidate
                review_and_apply(idx, ci)
        return len(to_apply)

def build_final_pt(overwrite=True):
    """Aggregate labeled npz files in LABELED_OUT into a single .pt (torch_geometric.Data list)."""
    npz_labeled = sorted(glob.glob(os.path.join(LABELED_OUT, "*_labeled.npz")))
    if not npz_labeled:
        print("No labeled npz files found in", LABELED_OUT)
        return None
    data_list = []
    for p in npz_labeled:
        arr = np.load(p, allow_pickle=True)
        x = arr.get('x', np.array([]))
        pos = arr.get('pos', np.array([]))
        ei = arr.get('edge_index', np.empty((2,0), dtype=np.int64))
        y = arr.get('y', np.array([]))
        # convert to torch geometric Data where possible
        if torch is not None:
            try:
                xt = torch.tensor(x, dtype=torch.float) if getattr(x,'size', None) and np.asarray(x).size else None
                post = torch.tensor(pos, dtype=torch.float) if getattr(pos,'size', None) and np.asarray(pos).size else None
                eit = torch.tensor(ei, dtype=torch.long) if getattr(ei,'size', None) and np.asarray(ei).size else None
                yt = torch.tensor(y, dtype=torch.float) if getattr(y,'size', None) and np.asarray(y).size else None
                g = Data(x=xt, edge_index=eit, pos=post, y=yt)
                # store a subject id attribute for traceability
                g.subject_id = os.path.basename(p).replace("_labeled.npz","")
                data_list.append(g)
            except Exception:
                # fallback: store a dict-like object
                data_list.append({"x": x, "edge_index": ei, "pos": pos, "y": y, "subject_id": os.path.basename(p)})
        else:
            data_list.append({"x": x, "edge_index": ei, "pos": pos, "y": y, "subject_id": os.path.basename(p)})
    # save
    if os.path.exists(FINAL_PT) and not overwrite:
        print("Final .pt exists and overwrite=False:", FINAL_PT)
    else:
        if torch is not None and isinstance(data_list[0], Data):
            torch.save(data_list, FINAL_PT)
            print("Saved final labeled .pt (list-of-Data) to:", FINAL_PT)
        else:
            # save generic pickled list
            import pickle
            with open(FINAL_PT, "wb") as f:
                pickle.dump(data_list, f)
            print("Saved final labeled pickled list to:", FINAL_PT)
    return FINAL_PT

# Save initial index/miss CSVs (pre-application)
pd.DataFrame(index_rows).to_csv(INDEX_CSV, index=False)
pd.DataFrame([{"npz_file":m["npz_file"], "npz_sid":m["npz_sid"], "candidates": ";".join([f"{k}|{s:.3f}" for k,s in m['candidates']])} for m in miss_list]).to_csv(MISSES_CSV, index=False)

print("Pipeline complete.")
print("Counts: strict_labeled =", len(auto_matches), "misses_with_suggestions =", len(miss_list))
print("Fuzzy suggestions saved to:", FUZZY_CSV)
print("Index CSV (initial):", INDEX_CSV)
print("Miss CSV (initial):", MISSES_CSV)
print("")
print("Usage (examples):")
print("  show_miss(0)           # inspect miss 0")
print("  review_and_apply(0, 0) # accept candidate 0 for miss 0")
print("  auto_apply_threshold(0.90, dry_run=True)  # preview")
print("  auto_apply_threshold(0.90, dry_run=False) # apply")
print("  build_final_pt(overwrite=True)            # aggregate labeled npz -> final .pt")


STARTING pipeline. Base: /content/drive/MyDrive/oasis_project
Found NPZ count: 209
Using demographics file: /content/drive/MyDrive/oasis_project/data/demographics/oasis2_demographics.xlsx
Demographics loaded. id_col = Subject ID, label_col = CDR, dem entries = 150
Dem keys sample: ['OAS2_0001', 'OAS2_0002', 'OAS2_0004', 'OAS2_0005', 'OAS2_0007', 'OAS2_0008', 'OAS2_0009', 'OAS2_0010', 'OAS2_0012', 'OAS2_0013', 'OAS2_0014', 'OAS2_0016']
Strict auto-labeled count: 150 Misses: 59
Saved fuzzy suggestions to: /content/drive/MyDrive/oasis_project/data/graphs/label_misses_fuzzy_suggestions_rf.csv
Pipeline complete.
Counts: strict_labeled = 150 misses_with_suggestions = 59
Fuzzy suggestions saved to: /content/drive/MyDrive/oasis_project/data/graphs/label_misses_fuzzy_suggestions_rf.csv
Index CSV (initial): /content/drive/MyDrive/oasis_project/data/graphs/final_graph_label_index_auto.csv
Miss CSV (initial): /content/drive/MyDrive/oasis_project/data/graphs/final_label_misses_auto.csv

Usage (exam

In [None]:
auto_apply_threshold(0.90, dry_run=True)


Found 0 auto-apply candidates with score >= 0.9 (dry_run=True)


[]

In [None]:
auto_apply_threshold(0.90, dry_run=False)


Found 0 auto-apply candidates with score >= 0.9 (dry_run=False)


0

In [None]:
show_miss(3)


MISS INDEX: 3
 NPZ file:  /content/drive/MyDrive/oasis_project/data/graphs/npz_graphs_resave/graph_0011_subj_11.npz
 NPZ extracted sid: 0011
 Candidates (dem_key | score):
  [0] OAS2_0001  score=0.462  -> label=0.0
  [1] OAS2_0010  score=0.462  -> label=0.5
  [2] OAS2_0012  score=0.462  -> label=0.0
  [3] OAS2_0013  score=0.462  -> label=0.0
  [4] OAS2_0014  score=0.462  -> label=1.0
  [5] OAS2_0016  score=0.462  -> label=0.5
 NPZ keys: ['x', 'pos', 'edge_index']
  x.shape: (1575, 50)  dtype: float32
  pos.shape: (0,)
  edge_index.shape: (2, 22062)


In [None]:
review_and_apply(miss_idx=3, cand_idx=0)


Applied mapping: graph_0011_subj_11.npz -> OAS2_0001  score= 0.462
Wrote labeled npz: /content/drive/MyDrive/oasis_project/data/graphs/labeled_npz_auto/graph_0011_subj_11_labeled.npz
Updated index CSV: /content/drive/MyDrive/oasis_project/data/graphs/final_graph_label_index_auto.csv and misses CSV: /content/drive/MyDrive/oasis_project/data/graphs/final_label_misses_auto.csv


In [None]:
build_final_pt(overwrite=True)


Saved final labeled .pt (list-of-Data) to: /content/drive/MyDrive/oasis_project/data/graphs/oasis2_graphs_labeled_auto_final.pt


'/content/drive/MyDrive/oasis_project/data/graphs/oasis2_graphs_labeled_auto_final.pt'

In [None]:
# ONE-CELL: Aggressive / fuzzy + manual-apply relabeler for 04 -> attempt FULL coverage
# Paste & run in 04_label_map_and_checks.ipynb (Colab). Installs rapidfuzz if missing.
import os, glob, re, numpy as np, pandas as pd, shutil, math
from pathlib import Path
from tqdm import tqdm

BASE = "/content/drive/MyDrive/oasis_project"
GRAPHS_DIR = os.path.join(BASE, "data", "graphs")
NPZ_FOLDER = os.path.join(GRAPHS_DIR, "npz_graphs_resave")
LABELED_DIR = os.path.join(GRAPHS_DIR, "labeled_npz_full")
DEM_PATH = os.path.join(BASE, "data", "demographics", "oasis2_demographics.xlsx")
OUT_INDEX = os.path.join(GRAPHS_DIR, "final_graph_label_index_auto.csv")
OUT_MISSES = os.path.join(GRAPHS_DIR, "final_label_misses_auto.csv")
OUT_SUGGEST = os.path.join(GRAPHS_DIR, "label_misses_fuzzy_suggestions_rf.csv")
FINAL_PT = os.path.join(GRAPHS_DIR, "oasis2_graphs_labeled_auto_full.pt")
MANUAL_MAP_CSV = os.path.join(GRAPHS_DIR, "manual_mappings_to_apply.csv")  # optional - you can edit this and re-run

os.makedirs(LABELED_DIR, exist_ok=True)

# install rapidfuzz if not present (fast fuzzy)
try:
    from rapidfuzz import process, fuzz
except Exception:
    print("Installing rapidfuzz...")
    !pip -q install rapidfuzz
    from rapidfuzz import process, fuzz

# helpers
def safe_load_npz(p):
    try:
        return np.load(p, allow_pickle=True)
    except Exception as e:
        print("Failed np.load:", p, e); return {}

def normalize_dem_key(k):
    if k is None: return None
    k = str(k).strip()
    return k

def generate_variants(key):
    """Return a set of string variants for a dem key to maximize matching chance."""
    if key is None: return set()
    s = str(key).strip()
    variants = set()
    variants.add(s)
    variants.add(s.upper())
    variants.add(s.lower())
    # strip prefixes/suffixes and common separators
    variants.add(re.sub(r'[^0-9A-Za-z]', '', s))
    # zero-pad numeric part if found
    m = re.search(r'(\d+)', s)
    if m:
        num = m.group(1)
        for z in (0,1,2,3,4):
            variants.add(num.zfill(len(num)+z))
        # prefix forms
        variants.add("OAS2_" + num.zfill(4))
        variants.add("OAS2" + num.zfill(4))
        variants.add("OAS2-" + num.zfill(4))
        variants.add(num)
    # drop leading non-digits
    variants.add(re.sub(r'^[^\d]+', '', s))
    # letters-only, digits-only
    variants.add(re.sub(r'\d+', '', s))
    variants.add(''.join(ch for ch in s if ch.isalnum()))
    # small-case digits-only
    variants = set(v for v in variants if v is not None and len(str(v))>0)
    return set(str(v) for v in variants)

# Load demographics
if not os.path.exists(DEM_PATH):
    raise SystemExit(f"Demographics file not found at {DEM_PATH}. Place it there or update DEM_PATH and re-run.")
print("Loading demographics:", DEM_PATH)
if DEM_PATH.lower().endswith('.csv'):
    dem = pd.read_csv(DEM_PATH)
else:
    dem = pd.read_excel(DEM_PATH)
dem.rename(columns={c:c.strip() for c in dem.columns}, inplace=True)
# detect columns
cols = list(dem.columns)
id_col = next((c for c in cols if any(tok in c.upper() for tok in ("MRI","MRI_ID","MRIID","SUBJ","SUBJECT","ID"))), cols[0])
label_col = next((c for c in cols if any(tok in c.upper() for tok in ("CDR","CDR_GLOBAL","DEMENTIA","DIAG","SEVERITY"))), cols[-1])
print("Detected id_col:", id_col, "label_col:", label_col)
# build mapping with many variants
dem_map = {}
for _, row in dem.iterrows():
    raw_key = row.get(id_col)
    val = row.get(label_col)
    if pd.isna(raw_key) or pd.isna(val): continue
    key = normalize_dem_key(raw_key)
    for v in generate_variants(key):
        dem_map[str(v)] = val
# also create a reverse list of canonical keys for fuzzy candidates
dem_keys = sorted(set(normalize_dem_key(k) for k in dem[id_col].astype(str).tolist() if not pd.isna(k)))
print("Demographic entries (canonical):", len(dem_keys))

# discover npz files
npz_files = sorted(glob.glob(os.path.join(NPZ_FOLDER, "*.npz")))
if len(npz_files)==0:
    raise SystemExit(f"No npz files found in {NPZ_FOLDER}")
print("Found NPZ count:", len(npz_files))

# function to extract candidate ids from filename and inside file
def extract_candidates_from_npz(p):
    fname = os.path.basename(p)
    sids = set()
    # common patterns: OAS2_0123, subj_12, sub_12, 12
    for m in re.finditer(r'(OAS2[_\-]?\d{1,6})', fname, flags=re.IGNORECASE):
        sids.add(m.group(0))
    for m in re.finditer(r'(\d{1,6})', fname):
        sids.add(m.group(0))
    for m in re.finditer(r'(subj[_\-]?\d+|sub[_\-]?\d+|subj\d+|sub\d+)', fname, flags=re.IGNORECASE):
        sids.add(m.group(0))
    # try to read subject inside npz if present
    try:
        arr = safe_load_npz(p)
        # keys often 'subject_id' or 'sid' or 'filename'
        for k in ('subject_id','sid','id','filename','name'):
            if k in arr:
                v = arr[k]
                if isinstance(v, (np.ndarray, list)) and len(np.asarray(v))>0:
                    v = np.asarray(v).tolist()
                sids.add(str(v))
        # also some NPZs store arrays with dtype object in 'pos' that include label – ignore generally
    except Exception:
        pass
    # normalize
    sids = set(str(x).strip() for x in sids if x is not None and str(x).strip()!='')
    return sids

# cascade matching
auto_matches = []
misses = []
fuzzy_suggestions = []

# fuzzy helper using rapidfuzz
def fuzzy_best(query, choices, limit=5):
    if not query or len(choices)==0: return []
    res = process.extract(query, choices, scorer=fuzz.WRatio, limit=limit)
    # res = list of (candidate, score, idx)
    return [(r[0], float(r[1])) for r in res]

# main loop
for p in tqdm(npz_files, desc="Auto-matching NPZs"):
    arr = safe_load_npz(p)
    fname = os.path.basename(p)
    # extract x if present - only for stats
    try:
        x = arr.get('x', None)
        n_nodes = int(np.asarray(x).shape[0]) if x is not None else 0
        feat_dim = int(np.asarray(x).shape[1]) if (x is not None and np.asarray(x).ndim>1) else 1
    except Exception:
        n_nodes=0; feat_dim=0
    candidates = extract_candidates_from_npz(p)
    matched = False
    chosen_key = None
    chosen_val = None

    # Strategy 1: direct exact match using many variants
    for c in list(candidates):
        for variant in generate_variants(c):
            if variant in dem_map:
                chosen_key = variant; chosen_val = dem_map[variant]; matched=True; break
        if matched: break

    # Strategy 2: digits-only exact match
    if not matched:
        for c in list(candidates):
            digits = re.sub(r'\D+','', str(c))
            if digits and digits in dem_map:
                chosen_key=digits; chosen_val=dem_map[digits]; matched=True; break

    # Strategy 3: try concatenating prefix forms (OAS2_...) if dem keys look like that
    if not matched:
        for c in list(candidates):
            digits = re.sub(r'\D+','', str(c))
            if digits:
                trial = f"OAS2_{digits.zfill(4)}"
                if trial in dem_map:
                    chosen_key=trial; chosen_val=dem_map[trial]; matched=True; break

    # Strategy 4: try exact match on filename base (strip non-alnum)
    if not matched:
        base = re.sub(r'[^0-9A-Za-z]', '', os.path.splitext(fname)[0])
        if base in dem_map:
            chosen_key=base; chosen_val=dem_map[base]; matched=True

    # Strategy 5: fuzzy match using rapidfuzz between extracted candidate tokens and canonical dem_keys
    fuzzy_candidates = []
    if not matched:
        # build queries: tokens + filename base
        queries = list(candidates) + [os.path.splitext(fname)[0], re.sub(r'[^0-9A-Za-z]','',os.path.splitext(fname)[0])]
        # dedupe
        queries = [q for q in dict.fromkeys(queries) if q]
        all_matches = []
        for q in queries:
            out = fuzzy_best(q, dem_keys, limit=5)
            for cand,score in out:
                all_matches.append((q,cand,score))
        # reduce to best candidate by score
        if all_matches:
            best = sorted(all_matches, key=lambda x: x[2], reverse=True)[0]
            q,cand,score = best
            # apply high confidence threshold (>=90)
            if score >= 90:
                chosen_key = cand; chosen_val = dem.loc[dem[id_col].astype(str).str.strip()==cand, label_col].iloc[0] if (dem[id_col].astype(str).str.strip()==cand).any() else dem_map.get(cand)
                matched=True
            # otherwise record suggestions
            else:
                # prepare top suggestions string
                grouped = {}
                for q0,cand0,score0 in sorted(all_matches, key=lambda x:-x[2])[:10]:
                    grouped[cand0] = max(grouped.get(cand0, 0), score0)
                candstr = ";".join(f"{k}|{v:.3f}" for k,v in sorted(grouped.items(), key=lambda x:-x[1])[:10])
                fuzzy_suggestions.append({"npz_file":p, "npz_sid": ";".join(sorted(candidates)) if candidates else "", "candidates": candstr})
    # done strategies

    if matched and chosen_key is not None:
        # write labeled npz
        y = np.array([float(chosen_val)]) if not pd.isna(chosen_val) else np.array([str(chosen_val)])
        outp = os.path.join(LABELED_DIR, os.path.basename(p).replace(".npz","_labeled.npz"))
        pos = arr.get('pos', np.array([]))
        ei = arr.get('edge_index', np.empty((2,0), dtype=np.int64))
        xdata = arr.get('x', np.array([]))
        np.savez_compressed(outp, x=xdata, pos=pos, edge_index=ei, y=y)
        auto_matches.append({"npz_file":p, "npz_sid": ";".join(sorted(candidates)) if candidates else "", "dem_key": chosen_key, "y": chosen_val, "labeled_npz": outp, "n_nodes": n_nodes, "feat_dim": feat_dim})
    else:
        misses.append({"npz_file":p, "npz_sid": ";".join(sorted(candidates)) if candidates else ""})

# If we created fuzzy_suggestions, merge with misses list entries
# ensure there is one suggestion row per miss
fuzzy_df = pd.DataFrame(fuzzy_suggestions)
miss_df = pd.DataFrame(misses)
if not miss_df.empty:
    if not fuzzy_df.empty:
        # join on npz_file
        merged = miss_df.merge(fuzzy_df, on="npz_file", how="left")
    else:
        merged = miss_df.copy()
        merged["candidates"] = ""
else:
    merged = pd.DataFrame(columns=["npz_file","npz_sid","candidates"])

# Save intermediate CSVs
index_rows = auto_matches.copy()
pd.DataFrame(index_rows).to_csv(OUT_INDEX, index=False)
merged.to_csv(OUT_MISSES, index=False)
pd.DataFrame(fuzzy_suggestions).to_csv(OUT_SUGGEST, index=False)
print("Auto-labeled:", len(auto_matches), "Misses:", len(miss_df))
print("Saved index:", OUT_INDEX)
print("Saved misses:", OUT_MISSES)
print("Saved fuzzy suggestions:", OUT_SUGGEST)

# APPLY manual_mappings.csv if present (columns: npz_file,mapped_dem_key) - user can create/edit this CSV and re-run cell
if os.path.exists(MANUAL_MAP_CSV):
    print("Found manual mappings CSV - applying manual mappings from:", MANUAL_MAP_CSV)
    mm = pd.read_csv(MANUAL_MAP_CSV)
    applied = 0
    for _, r in mm.iterrows():
        p = r.get("npz_file")
        map_key = r.get("mapped_dem_key") or r.get("mapped_key") or r.get("dem_key")
        if pd.isna(p) or pd.isna(map_key): continue
        p = str(p).strip()
        map_key = str(map_key).strip()
        if not os.path.exists(p):
            print("Manual mapping: npz file not found:", p); continue
        # find canonical dem value
        demval = None
        # try exact match on id_col in dem
        row_sel = dem[dem[id_col].astype(str).str.strip()==map_key]
        if not row_sel.empty:
            demval = float(row_sel[label_col].iloc[0])
        else:
            # try dem_map lookup for variants
            for v in generate_variants(map_key):
                if v in dem_map:
                    demval = float(dem_map[v]); break
        if demval is None:
            print("Mapped dem key not found in demographics:", map_key); continue
        # write labeled npz
        arr = safe_load_npz(p)
        y = np.array([demval])
        outp = os.path.join(LABELED_DIR, os.path.basename(p).replace(".npz","_labeled.npz"))
        np.savez_compressed(outp, x=arr.get('x', np.array([])), pos=arr.get('pos', np.array([])), edge_index=arr.get('edge_index', np.empty((2,0),dtype=np.int64)), y=y)
        applied += 1
    print("Applied manual mappings:", applied)

# After auto + manual, rebuild final index and misses and create final .pt
# collect all labeled npz in LABELED_DIR
labeled_npzs = sorted(glob.glob(os.path.join(LABELED_DIR, "*_labeled.npz")))
print("Found labeled npzs (after manual apply):", len(labeled_npzs))

# build final index rows
final_index = []
for lp in labeled_npzs:
    arr = safe_load_npz(lp)
    fname = os.path.basename(lp)
    # try to recover original npz_file by searching name prefix inside NPZ_FOLDER
    original_candidates = sorted(glob.glob(os.path.join(NPZ_FOLDER, fname.replace("_labeled.npz","*.npz"))))
    orig = original_candidates[0] if original_candidates else ""
    y = arr.get('y', None)
    yv = float(np.asarray(y).reshape(-1)[0]) if y is not None and np.asarray(y).size>0 else None
    x = arr.get('x', None)
    n_nodes = int(np.asarray(x).shape[0]) if x is not None else 0
    feat_dim = int(np.asarray(x).shape[1]) if (x is not None and np.asarray(x).ndim>1) else 1
    final_index.append({"npz_file": orig, "labeled_npz": lp, "n_nodes": n_nodes, "feat_dim": feat_dim, "y": yv})

# recompute misses (npz files without labeled counterpart)
all_npz_set = set(npz_files)
labeled_orig_set = set(r["npz_file"] for r in final_index if r["npz_file"])
remaining_npzs = sorted(list(all_npz_set - labeled_orig_set))
final_misses_rows = []
for p in remaining_npzs:
    final_misses_rows.append({"npz_file": p, "npz_sid": ";".join(extract_candidates_from_npz(p))})

pd.DataFrame(final_index).to_csv(OUT_INDEX, index=False)
pd.DataFrame(final_misses_rows).to_csv(OUT_MISSES, index=False)
print("Final index saved:", OUT_INDEX)
print("Final misses saved:", OUT_MISSES, "count:", len(final_misses_rows))

# build final .pt (PyG Data objects) from labeled_npzs
try:
    import torch
    from torch_geometric.data import Data
    data_list = []
    for rec in final_index:
        lp = rec["labeled_npz"]
        if not lp or not os.path.exists(lp): continue
        arr = np.load(lp, allow_pickle=True)
        x = arr.get('x', np.array([]))
        pos = arr.get('pos', np.array([]))
        ei = arr.get('edge_index', np.empty((2,0),dtype=np.int64))
        y = arr.get('y', np.array([]))
        # convert to tensors where possible
        try:
            xt = torch.tensor(x, dtype=torch.float) if np.asarray(x).size else None
        except Exception:
            xt = None
        try:
            post = torch.tensor(pos, dtype=torch.float) if np.asarray(pos).size else None
        except Exception:
            post = None
        try:
            eit = torch.tensor(ei, dtype=torch.long) if np.asarray(ei).size else None
        except Exception:
            eit = None
        try:
            yt = torch.tensor(np.asarray(y).reshape(-1), dtype=torch.float) if np.asarray(y).size else None
        except Exception:
            yt = None
        # create Data (best-effort)
        try:
            g = Data(x=xt, edge_index=eit, pos=post, y=yt)
            # attach subject id guessed from filename
            g.subject_id = os.path.basename(lp).replace("_labeled.npz","")
            data_list.append(g)
        except Exception:
            continue
    if data_list:
        torch.save(data_list, FINAL_PT)
        print("Saved final labeled .pt with", len(data_list), "graphs to:", FINAL_PT)
    else:
        print("No Data objects created; final .pt not saved.")
except Exception as e:
    print("Could not build final .pt:", e)

print("DONE. Summary:")
print("  Auto-labeled (initial):", len(auto_matches))
print("  Labeled npzs (final):", len(labeled_npzs))
print("  Remaining misses:", len(final_misses_rows))
print("Files to inspect / edit manually:")
print("  - Fuzzy suggestions:", OUT_SUGGEST)
print("  - Final misses CSV (edit to provide mapped_dem_key or use manual_mappings_to_apply.csv):", OUT_MISSES)
print("  - To manually supply mappings, create/edit:", MANUAL_MAP_CSV, "with columns: npz_file,mapped_dem_key and re-run this cell.")


Loading demographics: /content/drive/MyDrive/oasis_project/data/demographics/oasis2_demographics.xlsx
Detected id_col: Subject ID label_col: CDR
Demographic entries (canonical): 150
Found NPZ count: 209


Auto-matching NPZs: 100%|██████████| 209/209 [00:11<00:00, 17.54it/s]


Auto-labeled: 163 Misses: 46
Saved index: /content/drive/MyDrive/oasis_project/data/graphs/final_graph_label_index_auto.csv
Saved misses: /content/drive/MyDrive/oasis_project/data/graphs/final_label_misses_auto.csv
Saved fuzzy suggestions: /content/drive/MyDrive/oasis_project/data/graphs/label_misses_fuzzy_suggestions_rf.csv
Found labeled npzs (after manual apply): 163
Final index saved: /content/drive/MyDrive/oasis_project/data/graphs/final_graph_label_index_auto.csv
Final misses saved: /content/drive/MyDrive/oasis_project/data/graphs/final_label_misses_auto.csv count: 46
Saved final labeled .pt with 163 graphs to: /content/drive/MyDrive/oasis_project/data/graphs/oasis2_graphs_labeled_auto_full.pt
DONE. Summary:
  Auto-labeled (initial): 163
  Labeled npzs (final): 163
  Remaining misses: 46
Files to inspect / edit manually:
  - Fuzzy suggestions: /content/drive/MyDrive/oasis_project/data/graphs/label_misses_fuzzy_suggestions_rf.csv
  - Final misses CSV (edit to provide mapped_dem_key

In [None]:
# ONE-CELL: Auto-accept top fuzzy suggestions above threshold and rebuild final .pt (run in 04)
import os, glob, pandas as pd, numpy as np, re, shutil
from pathlib import Path
BASE = "/content/drive/MyDrive/oasis_project"
GRAPHS_DIR = os.path.join(BASE, "data", "graphs")
NPZ_FOLDER = os.path.join(GRAPHS_DIR, "npz_graphs_resave")
LABELED_DIR = os.path.join(GRAPHS_DIR, "labeled_npz_full")   # same dir used earlier
FUZZY_CSV = os.path.join(GRAPHS_DIR, "label_misses_fuzzy_suggestions_rf.csv")
FINAL_INDEX = os.path.join(GRAPHS_DIR, "final_graph_label_index_auto.csv")
FINAL_MISSES = os.path.join(GRAPHS_DIR, "final_label_misses_auto.csv")
FINAL_PT = os.path.join(GRAPHS_DIR, "oasis2_graphs_labeled_auto_full.pt")
DEM_PATH = os.path.join(BASE, "data", "demographics", "oasis2_demographics.xlsx")
THRESH = 0.90   # adjust: lower to accept more fuzzy matches (e.g. 0.80)

os.makedirs(LABELED_DIR, exist_ok=True)

if not os.path.exists(FUZZY_CSV):
    raise SystemExit(f"Fuzzy suggestions CSV not found: {FUZZY_CSV}")

print("Reading fuzzy suggestions:", FUZZY_CSV)
df = pd.read_csv(FUZZY_CSV)
# expected 'npz_file' and 'candidates' columns; candidates formatted like "OAS2_0111|0.900;OAS2_0112|0.900;..."
applied = []
skipped = []
# load demographics to map candidate -> value
if DEM_PATH.lower().endswith('.csv'):
    dem = pd.read_csv(DEM_PATH)
else:
    dem = pd.read_excel(DEM_PATH)
dem.rename(columns={c:c.strip() for c in dem.columns}, inplace=True)
cols = list(dem.columns)
id_col = next((c for c in cols if any(tok in c.upper() for tok in ("MRI","MRI_ID","MRIID","SUBJ","SUBJECT","ID"))), cols[0])
label_col = next((c for c in cols if any(tok in c.upper() for tok in ("CDR","CDR_GLOBAL","DEMENTIA","DIAG","SEVERITY"))), cols[-1])

def dem_value_for_key(k):
    # k expected like 'OAS2_0111' or '0111' etc.
    k = str(k).strip()
    # exact
    sel = dem[dem[id_col].astype(str).str.strip()==k]
    if not sel.empty:
        return float(sel[label_col].iloc[0])
    # try digits-only
    digits = re.sub(r'\D+','',k)
    if digits:
        sel = dem[dem[id_col].astype(str).str.contains(digits, na=False)]
        if not sel.empty:
            return float(sel[label_col].iloc[0])
    return None

for _, row in df.iterrows():
    p = row.get('npz_file')
    candstr = row.get('candidates', "")
    if pd.isna(p) or not p: continue
    # parse top candidate
    parts = [s for s in str(candstr).split(';') if s]
    if not parts:
        skipped.append(p); continue
    top = parts[0]  # "OAS2_0111|0.900"
    if '|' in top:
        k, s = top.split('|',1)
        score = float(s)
    else:
        k = top; score = 0.0
    if score >= THRESH:
        val = dem_value_for_key(k)
        if val is None:
            # try to strip leading non-digits and lookup
            k2 = re.sub(r'^[^\d]+','',k)
            val = dem_value_for_key(k2)
        if val is None:
            skipped.append(p)
            continue
        # load original npz, write labeled npz
        try:
            arr = np.load(p, allow_pickle=True)
        except Exception as e:
            print("Failed load npz:", p, e); skipped.append(p); continue
        x = arr.get('x', np.array([])); pos = arr.get('pos', np.array([]))
        ei = arr.get('edge_index', arr.get('edges', np.empty((2,0), dtype=np.int64)))
        y = np.array([float(val)])
        outp = os.path.join(LABELED_DIR, os.path.basename(p).replace(".npz","_labeled.npz"))
        np.savez_compressed(outp, x=x if x is not None else np.array([]), pos=pos if pos is not None else np.array([]),
                            edge_index=ei if ei is not None else np.empty((2,0), dtype=np.int64), y=y)
        applied.append({"npz_file":p, "mapped_key":k, "score":score, "y":val, "labeled_npz":outp})
    else:
        skipped.append(p)

print("Auto-applied matches above threshold:", len(applied))
print("Skipped (below threshold or failed):", len(skipped))

# Rebuild index + misses and .pt (same logic as prior cell)
# gather all labeled npzs
labeled_npzs = sorted(glob.glob(os.path.join(LABELED_DIR, "*_labeled.npz")))
final_index_rows = []
for lp in labeled_npzs:
    arr = np.load(lp, allow_pickle=True)
    y = arr.get('y', np.array([]))
    x = arr.get('x', np.array([]))
    # try find original npz by name prefix
    baseprefix = os.path.basename(lp).replace("_labeled.npz","")
    candidates = sorted(glob.glob(os.path.join(NPZ_FOLDER, f"*{baseprefix}*.npz")))
    orig = candidates[0] if candidates else ""
    n_nodes = int(np.asarray(x).shape[0]) if x is not None and np.asarray(x).size else 0
    feat_dim = int(np.asarray(x).shape[1]) if (x is not None and np.asarray(x).ndim>1) else 1
    yv = float(np.asarray(y).reshape(-1)[0]) if np.asarray(y).size else None
    final_index_rows.append({"npz_file": orig, "labeled_npz": lp, "n_nodes": n_nodes, "feat_dim": feat_dim, "y": yv})

all_npz_set = set(sorted(glob.glob(os.path.join(NPZ_FOLDER, "*.npz"))))
labeled_orig_set = set(r["npz_file"] for r in final_index_rows if r["npz_file"])
remaining_npzs = sorted(list(all_npz_set - labeled_orig_set))
final_misses_rows = [{"npz_file":p, "npz_sid": ""} for p in remaining_npzs]

pd.DataFrame(final_index_rows).to_csv(FINAL_INDEX, index=False)
pd.DataFrame(final_misses_rows).to_csv(FINAL_MISSES, index=False)
print("Wrote final index:", FINAL_INDEX)
print("Wrote final misses:", FINAL_MISSES, "count:", len(final_misses_rows))

# Build final .pt from labeled npzs
try:
    import torch
    from torch_geometric.data import Data
    data_list = []
    for rec in final_index_rows:
        lp = rec["labeled_npz"]
        if not lp or not os.path.exists(lp): continue
        arr = np.load(lp, allow_pickle=True)
        x = arr.get('x', np.array([])); pos = arr.get('pos', np.array([])); ei = arr.get('edge_index', np.empty((2,0),dtype=np.int64)); y = arr.get('y', np.array([]))
        try:
            xt = torch.tensor(np.asarray(x), dtype=torch.float) if np.asarray(x).size else None
        except Exception:
            xt = None
        try: post = torch.tensor(np.asarray(pos), dtype=torch.float) if np.asarray(pos).size else None
        except: post = None
        try: eit = torch.tensor(np.asarray(ei), dtype=torch.long) if np.asarray(ei).size else None
        except: eit = None
        try: yt = torch.tensor(np.asarray(y).reshape(-1), dtype=torch.float) if np.asarray(y).size else None
        except: yt = None
        try:
            g = Data(x=xt, edge_index=eit, pos=post, y=yt)
            g.subject_id = os.path.basename(lp).replace("_labeled.npz","")
            data_list.append(g)
        except Exception:
            pass
    if data_list:
        torch.save(data_list, FINAL_PT)
        print("Saved final .pt with", len(data_list), "graphs to:", FINAL_PT)
    else:
        print("No Data objects created; .pt not saved.")
except Exception as e:
    print("Failed building .pt:", e)

print("Done. Summary:")
print("  Newly applied:", len(applied))
print("  Total labeled npzs now:", len(labeled_npzs))
print("  Remaining misses:", len(final_misses_rows))
print("If remaining misses > 0, open the file and either (1) edit and create manual_mappings_to_apply.csv or (2) manually pick keys from the fuzzy suggestions CSV.")


Reading fuzzy suggestions: /content/drive/MyDrive/oasis_project/data/graphs/label_misses_fuzzy_suggestions_rf.csv
Auto-applied matches above threshold: 46
Skipped (below threshold or failed): 0
Wrote final index: /content/drive/MyDrive/oasis_project/data/graphs/final_graph_label_index_auto.csv
Wrote final misses: /content/drive/MyDrive/oasis_project/data/graphs/final_label_misses_auto.csv count: 0
Saved final .pt with 209 graphs to: /content/drive/MyDrive/oasis_project/data/graphs/oasis2_graphs_labeled_auto_full.pt
Done. Summary:
  Newly applied: 46
  Total labeled npzs now: 209
  Remaining misses: 0
If remaining misses > 0, open the file and either (1) edit and create manual_mappings_to_apply.csv or (2) manually pick keys from the fuzzy suggestions CSV.
