In [None]:
from google.colab import drive
import os
import shutil

# Unmount drive if already mounted and clear the mountpoint
if os.path.exists('/content/drive'):
    try:
        drive.flush_and_unmount()
        # It might take a moment for the unmount to complete,
        # but often clearing the directory is sufficient.
    except ValueError:
        # Drive was not mounted, no need to unmount
        pass

    # Clear the mountpoint directory
    if os.path.exists('/content/drive') and os.path.isdir('/content/drive'):
        for item in os.listdir('/content/drive'):
            item_path = os.path.join('/content/drive', item)
            try:
                if os.path.isfile(item_path) or os.path.islink(item_path):
                    os.unlink(item_path)
                elif os.path.isdir(item_path):
                    shutil.rmtree(item_path)
            except Exception as e:
                print(f"Error removing {item_path}: {e}")


drive.mount('/content/drive', force_remount=True)

print("✅ Drive mounted.")
print("Drive contents at root:")
for f in os.listdir("/content/drive"):
    print(" -", f)

# verify main project folder exists
proj = "/content/drive/MyDrive/oasis_project"
print("\nProject folder exists?", os.path.exists(proj))
if os.path.exists(proj):
    print("Contents of oasis_project:")
    for f in os.listdir(proj):
        print("   ", f)

Mounted at /content/drive
✅ Drive mounted.
Drive contents at root:
 - .shortcut-targets-by-id
 - MyDrive
 - .Trash-0
 - .Encrypted

Project folder exists? True
Contents of oasis_project:
    notebooks
    data
    outputs
    logs
    oasis2_graph_dataset.pt


In [None]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m70.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [None]:
from pathlib import Path
import os, tarfile, pandas as pd, numpy as np, nibabel as nib, torch
from skimage.transform import resize
from tqdm import tqdm
from sklearn.neighbors import kneighbors_graph
from torch_geometric.data import Data

BASE = Path("/content/drive/MyDrive/oasis_project")
RAW_DIR = BASE / "data" / "raw" / "oasis1"
PREPROC_DIR = BASE / "data" / "preproc" / "oasis1"
GRAPH_DIR = BASE / "data" / "graphs"
for p in [RAW_DIR, PREPROC_DIR, GRAPH_DIR]:
    p.mkdir(parents=True, exist_ok=True)

CLIN_PATH = Path("/content/drive/MyDrive/oasis_project/data/demographics/oasis2_demographics.xlsx")
TAR_PATH  = Path("/content/drive/MyDrive/OAS2_RAW_PART2.tar.gz")

print("Clinical file exists:", CLIN_PATH.exists())
print("MRI archive exists:", TAR_PATH.exists())


Clinical file exists: True
MRI archive exists: True


In [None]:
# Extract only once
extracted_flag = RAW_DIR / ".extracted"
if not extracted_flag.exists():
    print("Extracting .tar.gz ... this may take a few minutes")
    with tarfile.open(TAR_PATH, "r:gz") as tar:
        tar.extractall(RAW_DIR)
    extracted_flag.write_text("done")
else:
    print("Already extracted:", RAW_DIR)

# list nii or img/hdr files
nii_files = sorted(list(RAW_DIR.rglob("*.nii*")) + list(RAW_DIR.rglob("*.img")))
print("Total MRI files:", len(nii_files))
print("Sample:", [f.name for f in nii_files[:5]])


Already extracted: /content/drive/MyDrive/oasis_project/data/raw/oasis1
Total MRI files: 304
Sample: ['OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.img', 'OAS1_0001_MR1_mpr_n4_anon_sbj_111.img', 'OAS1_0001_MR1_mpr_n4_anon_111_t88_gfc.img', 'OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc.img', 'OAS1_0001_MR1_mpr-1_anon.img']


In [None]:
# Fixed metadata parsing cell (paste & run)
import pandas as pd
import numpy as np
from pathlib import Path
CLIN_PATH = Path("/content/drive/MyDrive/oasis_cross-sectional-5708aa0a98d82080 (1).xlsx")

meta = pd.read_excel(CLIN_PATH)
print("Original columns:", meta.columns.tolist())
display(meta.head(5))

# Build a mapping original_col -> canonical name
rename_map = {}
col_lower = [c.lower() for c in meta.columns]

def find_original(keys):
    for k in keys:
        for orig in meta.columns:
            if k in orig.lower():
                return orig
    return None

orig_subject = find_original(["subject", "id", "subj"])
orig_cdr     = find_original(["cdr", "cdR", "diagnosis", "label"])
orig_age     = find_original(["age"])
orig_time    = find_original(["delay", "visit", "examdate", "time", "session"])

if orig_subject: rename_map[orig_subject] = "subject_id"
if orig_cdr:     rename_map[orig_cdr]     = "cdr"
if orig_age:     rename_map[orig_age]     = "age"
if orig_time:    rename_map[orig_time]    = "timepoint"

print("Detected rename mapping (orig -> new):", rename_map)
# apply rename
meta = meta.rename(columns=rename_map)

# ensure required columns exist or create safe defaults
if "subject_id" not in meta.columns:
    raise KeyError(f"Could not detect subject id column automatically. Available columns: {list(meta.columns)}")

if "cdr" not in meta.columns:
    print("Warning: CDR column not detected — filling with NaN")
    meta["cdr"] = np.nan

if "timepoint" not in meta.columns:
    # create a default timepoint column filled with NaN (we will convert to 0.0 later)
    meta["timepoint"] = np.nan

# sanitize subject_id to a compact form that matches filenames:
# e.g., convert "OAS1_0001_MR1" -> "OAS1_0001"
def canonical_subject(s):
    s = str(s)
    # common pattern: keep first two underscore parts if present
    parts = s.split('_')
    if len(parts) >= 2:
        return "_".join(parts[:2])
    return s

meta["subject_id"] = meta["subject_id"].astype(str).map(canonical_subject)

# coerce cdr to float where possible
meta["cdr"] = pd.to_numeric(meta["cdr"], errors="coerce")

# fill missing timepoint with NaN (or 0.0 if you prefer)
meta["timepoint"] = pd.to_numeric(meta["timepoint"], errors="coerce")
# If you want a numeric default, you can uncomment next line:
# meta["timepoint"] = meta["timepoint"].fillna(0.0)

print("After normalization — columns now:", meta.columns.tolist())
print("Unique subjects:", meta["subject_id"].nunique())
display(meta[["subject_id","cdr","timepoint"]].head(10))


Original columns: ['ID', 'M/F', 'Hand', 'Age', 'Educ', 'SES', 'MMSE', 'CDR', 'eTIV', 'nWBV', 'ASF', 'Delay']


Unnamed: 0,ID,M/F,Hand,Age,Educ,SES,MMSE,CDR,eTIV,nWBV,ASF,Delay
0,OAS1_0001_MR1,F,R,74,2.0,3.0,29.0,0.0,1344,0.743,1.306,
1,OAS1_0002_MR1,F,R,55,4.0,1.0,29.0,0.0,1147,0.81,1.531,
2,OAS1_0003_MR1,F,R,73,4.0,3.0,27.0,0.5,1454,0.708,1.207,
3,OAS1_0004_MR1,M,R,28,,,,,1588,0.803,1.105,
4,OAS1_0005_MR1,M,R,18,,,,,1737,0.848,1.01,


Detected rename mapping (orig -> new): {'ID': 'subject_id', 'CDR': 'cdr', 'Age': 'age', 'Delay': 'timepoint'}
After normalization — columns now: ['subject_id', 'M/F', 'Hand', 'age', 'Educ', 'SES', 'MMSE', 'cdr', 'eTIV', 'nWBV', 'ASF', 'timepoint']
Unique subjects: 416


Unnamed: 0,subject_id,cdr,timepoint
0,OAS1_0001,0.0,
1,OAS1_0002,0.0,
2,OAS1_0003,0.5,
3,OAS1_0004,,
4,OAS1_0005,,
5,OAS1_0006,,
6,OAS1_0007,,
7,OAS1_0009,,
8,OAS1_0010,0.0,
9,OAS1_0011,0.0,


In [None]:
def preprocess_mri(path, out_dir, target_shape=(96,96,96)):
    out_dir.mkdir(parents=True, exist_ok=True)
    img = nib.load(str(path))
    data = img.get_fdata().astype(np.float32)
    # simple normalization
    mask = data > np.percentile(data, 5)
    mean, std = data[mask].mean(), data[mask].std()
    data = (data - mean) / (std + 1e-6)
    data = resize(data, target_shape, order=1, preserve_range=True, anti_aliasing=True)
    out_path = out_dir / f"{path.stem}_preproc.nii.gz"
    nib.save(nib.Nifti1Image(data, affine=np.eye(4)), str(out_path))
    return out_path

preproc_files = []
for f in tqdm(nii_files, desc="Preprocessing MRIs"):
    try:
        out = preprocess_mri(f, PREPROC_DIR)
        preproc_files.append(out)
    except Exception as e:
        print("Error preprocessing", f, e)
print("Preprocessed files:", len(preproc_files))


Preprocessing MRIs: 100%|██████████| 304/304 [05:02<00:00,  1.01it/s]

Preprocessed files: 304





In [None]:
def mri_to_graph(nifti_path, subject_id, label, timepoint, patch_size=3, n_neighbors=6, sample_n=400):
    data = nib.load(str(nifti_path)).get_fdata().astype(np.float32)
    data = (data - data.mean()) / (data.std() + 1e-6)

    coords_all = np.stack(np.meshgrid(np.arange(data.shape[0]),
                                      np.arange(data.shape[1]),
                                      np.arange(data.shape[2]),
                                      indexing='ij'), axis=-1).reshape(-1,3)
    mask = data.reshape(-1) > np.percentile(data.reshape(-1), 10)
    coords = coords_all[mask]
    if coords.shape[0] > sample_n:
        idx = np.linspace(0, coords.shape[0]-1, sample_n).astype(int)
        coords = coords[idx]

    feats = []
    for (x,y,z) in coords:
        x0,x1 = max(0,x-patch_size), min(data.shape[0],x+patch_size+1)
        y0,y1 = max(0,y-patch_size), min(data.shape[1],y+patch_size+1)
        z0,z1 = max(0,z-patch_size), min(data.shape[2],z+patch_size+1)
        patch = data[x0:x1,y0:y1,z0:z1]
        feats.append([patch.mean(), patch.std(), patch.min(), patch.max()])
    X = np.array(feats, dtype=np.float32)
    if len(coords) > 1:
        A = kneighbors_graph(coords, n_neighbors=min(n_neighbors,len(coords)-1), mode='connectivity', include_self=False)
        edge_index = np.vstack(A.nonzero()).astype(np.int64)
    else:
        edge_index = np.zeros((2,0), dtype=np.int64)
    g = Data(x=torch.tensor(X), edge_index=torch.tensor(edge_index, dtype=torch.long))
    g.y = torch.tensor([float(label)], dtype=torch.float32)
    g.subject_id = str(subject_id)
    g.timepoint = float(timepoint) if not pd.isna(timepoint) else 0.0
    return g


In [None]:
graphs = []
skipped = []
for f in tqdm(preproc_files, desc="Graph building"):
    name = f.stem.split("_")[0]
    row = meta[meta["subject_id"].astype(str).str.contains(name, case=False, na=False)]
    if len(row)==0:
        skipped.append(name)
        continue
    sid = row.iloc[0]["subject_id"]
    cdr = row.iloc[0]["cdr"]
    tp  = row.iloc[0].get("timepoint", 0.0)
    try:
        g = mri_to_graph(f, sid, cdr, tp)
        graphs.append(g)
    except Exception as e:
        skipped.append((name,str(e)))
        print("Failed", name, e)

print("Built graphs:", len(graphs), "Skipped:", len(skipped))


Graph building: 100%|██████████| 304/304 [00:34<00:00,  8.72it/s]

Built graphs: 304 Skipped: 0





In [None]:
OUT_PT = GRAPH_DIR / "oasis1_1_graphs_labeled_auto_full.pt"
torch.save(graphs, OUT_PT)
print("✅ Saved:", OUT_PT)
if graphs:
    g0 = graphs[0]
    print("Example graph:", g0)
    print("x shape:", g0.x.shape, "edges:", g0.edge_index.shape[1], "label:", g0.y.item())


✅ Saved: /content/drive/MyDrive/oasis_project/data/graphs/oasis1_graphs_labeled_auto_full.pt
Example graph: Data(x=[400, 4], edge_index=[2, 2400], y=[1], subject_id='OAS1_0001', timepoint=0.0)
x shape: torch.Size([400, 4]) edges: 2400 label: 0.0


In [None]:
# Paste & run this in Colab
from pathlib import Path
import torch
from collections import Counter, defaultdict
import numpy as np

MASTER_PT = Path("/content/drive/MyDrive/oasis_project/data/graphs/oasis1_graphs_labeled_auto_full.pt")
assert MASTER_PT.exists(), f"Master file not found: {MASTER_PT}"

# Safe load (allowlist PyG storage global) — recommended
try:
    import torch_geometric
    with torch.serialization.safe_globals([torch_geometric.data.storage.GlobalStorage]):
        graphs = torch.load(MASTER_PT, map_location="cpu")
except Exception as e:
    # fallback (only if you trust file)
    print("safe_globals load failed, trying weights_only=False fallback:", e)
    graphs = torch.load(MASTER_PT, map_location="cpu", weights_only=False)

print("Loaded graphs:", len(graphs))

# Gather basic info
subject_ids = []
timepoints = []
labels = []
for g in graphs:
    sid = getattr(g, "subject_id", None)
    # canonicalize short id if it contains MR suffix
    if sid is not None:
        sid = str(sid)
        parts = sid.split("_")
        if len(parts) >= 2:
            sid = "_".join(parts[:2])
    subject_ids.append(sid)
    tp = getattr(g, "timepoint", None)
    timepoints.append(tp if tp is not None else np.nan)
    try:
        yv = float(g.y.view(-1).cpu().numpy()[0])
    except Exception:
        try:
            yv = float(g.y)
        except Exception:
            yv = np.nan
    labels.append(yv)

n_graphs = len(graphs)
unique_subjects = set([s for s in subject_ids if s is not None])
print("Total graphs:", n_graphs)
print("Unique subject IDs:", len(unique_subjects))

# label distribution
lab_counts = Counter([l for l in labels if not np.isnan(l)])
print("Label counts (non-NaN):")
for k,v in sorted(lab_counts.items()):
    print("  ", k, ":", v)
print("Labels with NaN (missing):", sum(1 for l in labels if np.isnan(l)))

# duplicates per (subject,timepoint)
pair_counts = Counter()
for sid, tp in zip(subject_ids, timepoints):
    key = (sid, float(tp) if not (tp is None or (isinstance(tp,float) and np.isnan(tp))) else "__na__")
    pair_counts[key] += 1

dupes = [(k,c) for k,c in pair_counts.items() if c>1]
print("Subjects/timepoint combinations with >1 graph (sample up to 20):", len(dupes))
for k,c in dupes[:20]:
    print(" ", k, "->", c)

# show 5 example graphs (subject, shape, edges)
print("\nExample graphs (first 5):")
for i,g in enumerate(graphs[:5]):
    sid = subject_ids[i]
    tp  = timepoints[i]
    xshape = getattr(g, "x").shape if hasattr(g, "x") else None
    ecount = g.edge_index.shape[1] if hasattr(g, "edge_index") else None
    yval = labels[i]
    print(f" {i:02d}) subject={sid} timepoint={tp} x.shape={xshape} edges={ecount} y={yval}")

# quick sanity checks:
#  - any graphs with zero nodes?
zero_nodes = [i for i,g in enumerate(graphs) if (not hasattr(g,"x") or getattr(g,"x") is None or getattr(g,"x").shape[0]==0)]
print("Graphs with zero nodes:", zero_nodes[:20])


safe_globals load failed, trying weights_only=False fallback: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL torch_geometric.data.data.DataEdgeAttr was not an allowed global by default. Please use `torch.serialization.add_safe_globals([torch_geometric.data.data.DataEdgeAttr])` or the `torch.serialization.safe_globals([torch_geometric.data.data.DataEdgeAttr])` context manager to allowlist this global if you tru

In [None]:
from pathlib import Path
import tarfile

BASE = Path("/content/drive/MyDrive/oasis_project")
RAW_ROOT = BASE / "data" / "raw" / "oasis2"
PREPROC_DIR = BASE / "data" / "preproc" / "oasis2"
GRAPH_DIR = BASE / "data" / "graphs"
for p in [RAW_ROOT, PREPROC_DIR, GRAPH_DIR]:
    p.mkdir(parents=True, exist_ok=True)

CLIN_XLSX = Path("/content/drive/MyDrive/oasis_project/data/demographics/oasis2_demographics.xlsx")
ARCHIVE   = Path("/content/drive/MyDrive/OAS2_RAW_PART2.tar.gz")

print("Clinical exists:", CLIN_XLSX.exists())
print("Archive exists:", ARCHIVE.exists())

# extract only once
out_subdir = RAW_ROOT / ARCHIVE.stem
if not out_subdir.exists():
    out_subdir.mkdir(parents=True, exist_ok=True)
extracted_flag = out_subdir / ".extracted"
if not extracted_flag.exists():
    print("Extracting", ARCHIVE, "->", out_subdir)
    with tarfile.open(ARCHIVE, "r:*") as tar:
        tar.extractall(out_subdir)
    extracted_flag.write_text("done")
else:
    print("Already extracted at", out_subdir)

# list nifti files
nii_files = sorted(list(out_subdir.rglob("*.nii*")) + list(out_subdir.rglob("*.img")))
print("Found NIfTI files:", len(nii_files))


Clinical exists: True
Archive exists: True
Extracting /content/drive/MyDrive/OAS2_RAW_PART2.tar.gz -> /content/drive/MyDrive/oasis_project/data/raw/oasis2/OAS2_RAW_PART2.tar


  tar.extractall(out_subdir)


Found NIfTI files: 596


In [None]:
import pandas as pd
import numpy as np

meta = pd.read_excel(CLIN_XLSX)
print("Clinical columns:", meta.columns.tolist())

# heuristics to find ID, label, time columns
def find_col(keylist):
    for k in keylist:
        for c in meta.columns:
            if k.lower() in c.lower():
                return c
    return None

id_col = find_col(["subject","id","subj","ptid"]) or meta.columns[0]
label_col = find_col(["cdr","label","dx","diagnosis"])
time_col  = find_col(["visit","time","date","delay","session"])

print("Using id_col:", id_col, "label_col:", label_col, "time_col:", time_col)

# canonicalize subject id to match filenames (e.g. OAS2_0001_MR1 -> OAS2_0001)
def canonical_subject(s):
    s = str(s)
    parts = s.split("_")
    return "_".join(parts[:2]) if len(parts) >= 2 else s

meta = meta.copy()
meta["subject_id"] = meta[id_col].astype(str).map(canonical_subject)
if label_col:
    meta["label"] = pd.to_numeric(meta[label_col], errors="coerce")
else:
    meta["label"] = np.nan
if time_col:
    meta["timepoint"] = pd.to_numeric(meta[time_col], errors="coerce")
else:
    meta["timepoint"] = np.nan

print("Unique subjects in clinical:", meta["subject_id"].nunique())
meta[["subject_id","label","timepoint"]].head(6)


Clinical columns: ['Subject ID', 'MRI ID', 'Group', 'Visit', 'MR Delay', 'M/F', 'Hand', 'Age', 'EDUC', 'SES', 'MMSE', 'CDR', 'eTIV', 'nWBV', 'ASF']
Using id_col: Subject ID label_col: CDR time_col: Visit
Unique subjects in clinical: 150


Unnamed: 0,subject_id,label,timepoint
0,OAS2_0001,0.0,1
1,OAS2_0001,0.0,2
2,OAS2_0002,0.5,1
3,OAS2_0002,0.5,2
4,OAS2_0002,0.5,3
5,OAS2_0004,0.0,1


In [None]:
import nibabel as nib
import numpy as np
from skimage.transform import resize
from sklearn.neighbors import kneighbors_graph
import torch
from torch_geometric.data import Data

def preprocess_mri(path, out_dir, target_shape=(96,96,96)):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    img = nib.load(str(path))
    data = img.get_fdata().astype(np.float32)
    mask = data > np.percentile(data, 5)
    if mask.sum() == 0:
        mask = data != 0
    mu = data[mask].mean() if mask.sum()>0 else data.mean()
    sd = data[mask].std() if mask.sum()>0 else data.std()
    data = (data - mu) / (sd + 1e-6)
    data = resize(data, target_shape, order=1, preserve_range=True, anti_aliasing=True)
    out_path = Path(out_dir) / f"{Path(path).stem}_preproc.nii.gz"
    nib.save(nib.Nifti1Image(data.astype(np.float32), affine=np.eye(4)), str(out_path))
    return out_path

def mri_to_graph(nifti_path, subject_id, label, timepoint, patch_size=3, n_neighbors=6, sample_n=400):
    data = nib.load(str(nifti_path)).get_fdata().astype(np.float32)
    data = (data - data.mean()) / (data.std() + 1e-6)
    coords_all = np.stack(np.meshgrid(np.arange(data.shape[0]),
                                      np.arange(data.shape[1]),
                                      np.arange(data.shape[2]),
                                      indexing='ij'), axis=-1).reshape(-1,3)
    mask = data.reshape(-1) > np.percentile(data.reshape(-1), 10)
    coords = coords_all[mask]
    if coords.shape[0] > sample_n:
        idx = np.linspace(0, coords.shape[0]-1, sample_n).astype(int)
        coords = coords[idx]
    feats = []
    valid_coords = []
    for (x,y,z) in coords:
        x,y,z = int(x), int(y), int(z)
        x0,x1 = max(0,x-patch_size), min(data.shape[0], x+patch_size+1)
        y0,y1 = max(0,y-patch_size), min(data.shape[1], y+patch_size+1)
        z0,z1 = max(0,z-patch_size), min(data.shape[2], z+patch_size+1)
        patch = data[x0:x1, y0:y1, z0:z1]
        feats.append([patch.mean(), patch.std(), patch.min(), patch.max()])
        valid_coords.append([x,y,z])
    X = np.array(feats, dtype=np.float32)
    if len(valid_coords) > 1:
        A = kneighbors_graph(np.array(valid_coords), n_neighbors=min(n_neighbors, len(valid_coords)-1),
                             mode='connectivity', include_self=False)
        rows, cols = A.nonzero()
        edge_index = np.vstack([rows, cols]).astype(np.int64)
    else:
        edge_index = np.zeros((2,0), dtype=np.int64)
    g = Data(x=torch.tensor(X), edge_index=torch.tensor(edge_index, dtype=torch.long))
    g.y = torch.tensor([float(label) if not np.isnan(label) else np.nan], dtype=torch.float32)
    g.subject_id = str(subject_id)
    g.timepoint = float(timepoint) if (timepoint is not None and not np.isnan(timepoint)) else 0.0
    # track origin
    g.source_archive = str(Path(nifti_path).parts[1] if len(Path(nifti_path).parts)>1 else Path(nifti_path).parent.name)
    return g


In [None]:
from tqdm import tqdm
import re

# build quick lookup by subject id (first occurrence)
meta_index = meta.set_index("subject_id")

graphs_part = []
skipped = []

# try to match file -> subject by substring (subject id should appear in filename)
for nii in tqdm(nii_files, desc="Process files"):
    stem = Path(nii).name
    # find candidate subject id(s)
    matched_sid = None
    for sid in meta["subject_id"].unique():
        if sid in stem:
            matched_sid = sid
            break
    if matched_sid is None:
        # fallback: try regex of digits e.g. '0001' and match
        m = re.search(r"\d{3,4}", stem)
        if m:
            token = m.group(0)
            cand = meta[meta["subject_id"].str.contains(token, na=False)]
            if len(cand)>0:
                matched_sid = cand.iloc[0]["subject_id"]
    if matched_sid is None:
        skipped.append((nii, "no_subject_match"))
        continue

    # take first matching clinical row for label/time
    row = meta_index.loc[matched_sid]
    label = row["label"] if "label" in row.index else np.nan
    tp = row["timepoint"] if "timepoint" in row.index else 0.0

    try:
        preproc = preprocess_mri(nii, PREPROC_DIR, target_shape=(96,96,96))
        g = mri_to_graph(preproc, matched_sid, label, tp, patch_size=3, n_neighbors=6, sample_n=400)
        graphs_part.append(g)
    except Exception as e:
        skipped.append((nii, str(e)))

print("Built graphs in this part:", len(graphs_part), "Skipped:", len(skipped))


Process files: 100%|██████████| 596/596 [00:00<00:00, 13180.94it/s]

Built graphs in this part: 0 Skipped: 596





In [None]:
import torch
from pathlib import Path
import tempfile
import numpy as np

PART_PT = GRAPH_DIR / "oasis2_graphs_labeled_part2.pt"
MASTER_PT = GRAPH_DIR / "oasis2_graphs_labeled_auto_full.pt"
GLOBAL_PT = GRAPH_DIR / "all_graphs_master.pt"

# save the part file
torch.save(graphs_part, PART_PT)
print("Saved part graphs:", PART_PT, "count:", len(graphs_part))

# load existing master if exists (use safe_globals)
existing = []
try:
    import torch_geometric
    with torch.serialization.safe_globals([torch_geometric.data.storage.GlobalStorage]):
        if MASTER_PT.exists():
            existing = torch.load(MASTER_PT, map_location="cpu")
except Exception:
    if MASTER_PT.exists():
        existing = torch.load(MASTER_PT, map_location="cpu", weights_only=False)

print("Existing master count:", len(existing))

# merge with dedupe by (subject_id, timepoint, source_archive) preferring new graphs if label present
merged = {}
def key_for(g, idx):
    sid = getattr(g, "subject_id", None) or f"__idx__{idx}"
    tp = getattr(g, "timepoint", None) or 0.0
    src = getattr(g, "source_archive", "")
    return (str(sid), float(tp), str(src))

# add existing
for i,g in enumerate(existing):
    merged[key_for(g,i)] = g

# add new, overriding when appropriate (prefer labeled or prefer new)
for i,g in enumerate(graphs_part):
    k = key_for(g,i)
    if k not in merged:
        merged[k] = g
    else:
        prev = merged[k]
        try:
            prev_y = float(prev.y.view(-1).cpu().numpy()[0])
        except Exception:
            prev_y = np.nan
        try:
            new_y = float(g.y.view(-1).cpu().numpy()[0])
        except Exception:
            new_y = np.nan
        if (not np.isnan(new_y) and np.isnan(prev_y)) or (not np.isnan(new_y) and not np.isnan(prev_y)):
            # prefer new if it has label or equally labeled (choose new)
            merged[k] = g

merged_list = list(merged.values())
print("Merged master count (after adding this part):", len(merged_list))

# atomic save
tmp = tempfile.mktemp(suffix=".pt")
torch.save(merged_list, tmp)
Path(tmp).replace(MASTER_PT)
print("Saved merged master to:", MASTER_PT)

# also append to global master (all_graphs_master.pt)
global_existing = []
if GLOBAL_PT.exists():
    try:
        with torch.serialization.safe_globals([torch_geometric.data.storage.GlobalStorage]):
            global_existing = torch.load(GLOBAL_PT, map_location="cpu")
    except Exception:
        global_existing = torch.load(GLOBAL_PT, map_location="cpu", weights_only=False)
# append all new graphs (no dedupe across datasets unless you want it)
global_combined = list(global_existing) + graphs_part
torch.save(global_combined, GLOBAL_PT)
print("Appended to global master (count now):", len(global_combined))


Saved part graphs: /content/drive/MyDrive/oasis_project/data/graphs/oasis2_graphs_labeled_part2.pt count: 0
Existing master count: 209
Merged master count (after adding this part): 209


OSError: [Errno 18] Invalid cross-device link: '/tmp/tmpjrbncr1w.pt' -> '/content/drive/MyDrive/oasis_project/data/graphs/oasis2_graphs_labeled_auto_full.pt'

In [None]:
# --- Fixer: match by parent folder name, build graphs for PART2, and save safely ---
from pathlib import Path
import re, collections, math, shutil, tempfile
import nibabel as nib, numpy as np, torch
from tqdm import tqdm
from sklearn.neighbors import kneighbors_graph
from skimage.transform import resize
from torch_geometric.data import Data

# --- CONFIG: adjust if yours differ ---
GRAPH_DIR = Path("/content/drive/MyDrive/oasis_project/data/graphs")
RAW_ROOT  = Path("/content/drive/MyDrive/oasis_project/data/raw/oasis2")
PART_NAME = "OAS2_RAW_PART2"   # folder name inside RAW_ROOT created when archive extracted
PART_OUT  = GRAPH_DIR / "oasis2_graphs_labeled_part2.pt"
CLIN_XLSX = Path("/content/drive/MyDrive/oasis_project/data/demographics/oasis2_demographics.xlsx")

# locate extracted folder for part2
candidates = list(RAW_ROOT.rglob(f"*{PART_NAME}*"))
if len(candidates)==0:
    raise FileNotFoundError(f"Could not find extracted folder for {PART_NAME} under {RAW_ROOT}. Check extraction path.")
extracted_dir = candidates[0]
print("Using extracted dir:", extracted_dir)

# load clinical meta and canonicalize subject ids (keep first two underscore parts)
import pandas as pd
meta = pd.read_excel(CLIN_XLSX)
# detect id column
id_col = next((c for c in meta.columns if any(tok in c.lower() for tok in ("id","subject","subj","ptid"))), meta.columns[0])
label_col = next((c for c in meta.columns if "cdr" in c.lower() or "label" in c.lower() or "dx" in c.lower()), None)
time_col = next((c for c in meta.columns if "visit" in c.lower() or "delay" in c.lower() or "date" in c.lower()), None)

def canonical_subject(s):
    s = str(s)
    parts = s.split('_')
    return "_".join(parts[:2]) if len(parts)>=2 else s

meta["subject_id"] = meta[id_col].astype(str).map(canonical_subject)
if label_col:
    meta["label"] = pd.to_numeric(meta[label_col], errors="coerce")
else:
    meta["label"] = np.nan
if time_col:
    meta["timepoint"] = pd.to_numeric(meta[time_col], errors="coerce")
else:
    meta["timepoint"] = np.nan

meta_index = meta.set_index("subject_id", drop=False)
print("Clinical subjects available:", meta["subject_id"].nunique())

# helper: preprocess + graphify (same logic as before)
def preprocess_mri(path, out_dir, target_shape=(96,96,96)):
    out_dir = Path(out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    img = nib.load(str(path))
    data = img.get_fdata().astype(np.float32)
    mask = data > np.percentile(data, 5)
    if mask.sum()==0:
        mask = data!=0
    mu = data[mask].mean() if mask.sum()>0 else data.mean()
    sd = data[mask].std() if mask.sum()>0 else data.std()
    data = (data - mu) / (sd + 1e-6)
    data = resize(data, target_shape, order=1, preserve_range=True, anti_aliasing=True)
    out_path = out_dir / f"{Path(path).stem}_preproc.nii.gz"
    nib.save(nib.Nifti1Image(data.astype(np.float32), affine=np.eye(4)), str(out_path))
    return out_path

def mri_to_graph(nifti_path, subject_id, label, timepoint, patch_size=3, n_neighbors=6, sample_n=400):
    data = nib.load(str(nifti_path)).get_fdata().astype(np.float32)
    data = (data - data.mean()) / (data.std() + 1e-6)
    coords_all = np.stack(np.meshgrid(np.arange(data.shape[0]),
                                      np.arange(data.shape[1]),
                                      np.arange(data.shape[2]),
                                      indexing='ij'), axis=-1).reshape(-1,3)
    mask = data.reshape(-1) > np.percentile(data.reshape(-1), 10)
    coords = coords_all[mask]
    if coords.shape[0] > sample_n:
        idx = np.linspace(0, coords.shape[0]-1, sample_n).astype(int)
        coords = coords[idx]
    feats = []; valid_coords=[]
    for (x,y,z) in coords:
        x,y,z = int(x), int(y), int(z)
        x0,x1 = max(0,x-patch_size), min(data.shape[0], x+patch_size+1)
        y0,y1 = max(0,y-patch_size), min(data.shape[1], y+patch_size+1)
        z0,z1 = max(0,z-patch_size), min(data.shape[2], z+patch_size+1)
        patch = data[x0:x1, y0:y1, z0:z1]
        feats.append([patch.mean(), patch.std(), patch.min(), patch.max()])
        valid_coords.append([x,y,z])
    X = np.array(feats, dtype=np.float32)
    if len(valid_coords) > 1:
        A = kneighbors_graph(np.array(valid_coords), n_neighbors=min(n_neighbors, len(valid_coords)-1),
                             mode='connectivity', include_self=False)
        rows, cols = A.nonzero()
        edge_index = np.vstack([rows, cols]).astype(np.int64)
    else:
        edge_index = np.zeros((2,0), dtype=np.int64)
    g = Data(x=torch.tensor(X), edge_index=torch.tensor(edge_index, dtype=torch.long))
    g.y = torch.tensor([float(label) if not np.isnan(label) else np.nan], dtype=torch.float32)
    g.subject_id = str(subject_id)
    g.timepoint = float(timepoint) if (not pd.isna(timepoint)) else 0.0
    g.source_archive = PART_NAME
    return g

# scan all image files under extracted_dir
all_imgs = sorted(list(extracted_dir.rglob("*.nii*")) + list(extracted_dir.rglob("*.img")) + list(extracted_dir.rglob("*.hdr")))
print("Total image-like files found:", len(all_imgs))

# match by searching parent directory names for pattern like 'OAS2_0100' or 'OAS2_0100_MR1'
pattern = re.compile(r"(OAS[12]_[0-9]{3,4})", flags=re.IGNORECASE)

matched_subject_files = {}   # sid -> first image path
skipped = []

for p in all_imgs:
    # search among parents for subject folder token
    sid_candidate = None
    for part in reversed(p.parts):  # iterate path parts from leaf to root
        m = pattern.search(part)
        if m:
            sid_candidate = m.group(1)
            break
    if sid_candidate is None:
        # also try filename
        m = pattern.search(p.name)
        if m:
            sid_candidate = m.group(1)
    if sid_candidate is None:
        skipped.append((p, "no_subject_token_in_path"))
        continue
    # canonicalize (keep first two underscore parts)
    parts = sid_candidate.split("_")
    sid_can = "_".join(parts[:2]) if len(parts)>=2 else sid_candidate
    # try match to clinical IDs (meta["subject_id"]) - exact or case-insensitive
    if sid_can in meta["subject_id"].values:
        if sid_can not in matched_subject_files:
            matched_subject_files[sid_can] = p
    else:
        # try case-insensitive / partial numeric match (e.g., '0100' -> 'OAS2_0100')
        digits = re.search(r"\d{3,4}", sid_candidate)
        if digits:
            token = digits.group(0)
            cand = meta[meta["subject_id"].str.contains(token, na=False)]
            if len(cand) > 0:
                sid_match = cand.iloc[0]["subject_id"]
                if sid_match not in matched_subject_files:
                    matched_subject_files[sid_match] = p
            else:
                skipped.append((p, f"no_clinical_match_for_{sid_candidate}"))
        else:
            skipped.append((p, f"candidate_{sid_candidate}_no_digits"))

print("Unique subjects matched to an image:", len(matched_subject_files))
print("Sample matches (up to 10):")
for k,v in list(matched_subject_files.items())[:10]:
    print(" ", k, "->", v.name)

# build graphs for matched subjects (one image per subject)
graphs_part = []
skipped_build = []
for sid, img_path in tqdm(matched_subject_files.items(), desc="Building graphs"):
    try:
        row = meta_index.loc[sid]
        label = row.get("label", np.nan)
        tp = row.get("timepoint", 0.0)
        preproc = preprocess_mri(img_path, PREPROC_DIR, target_shape=(96,96,96))
        g = mri_to_graph(preproc, sid, label, tp, patch_size=3, n_neighbors=6, sample_n=400)
        graphs_part.append(g)
    except Exception as e:
        skipped_build.append((sid, str(e)))

print("Built graphs:", len(graphs_part), "skipped builds:", len(skipped_build))
print("Top skip reasons (first 20):", skipped[:20] + skipped_build[:20])

# Save the PART file atomically into same GRAPH_DIR (avoid cross-device replace)
if len(graphs_part) > 0:
    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".pt", dir=str(GRAPH_DIR))
    os.close(tmp_fd)
    torch.save(graphs_part, tmp_path)
    Path(tmp_path).replace(PART_OUT)
    print("Saved part graphs to:", PART_OUT)
else:
    print("No graphs built - part file not saved (graphs_part empty). Inspect skip reasons above.")


Using extracted dir: /content/drive/MyDrive/oasis_project/data/raw/oasis2/OAS2_RAW_PART2.tar
Clinical subjects available: 150
Total image-like files found: 1192
Unique subjects matched to an image: 68
Sample matches (up to 10):
  OAS2_0100 -> mpr-1.nifti.hdr
  OAS2_0101 -> mpr-1.nifti.hdr
  OAS2_0102 -> mpr-1.nifti.hdr
  OAS2_0103 -> mpr-1.nifti.hdr
  OAS2_0104 -> mpr-1.nifti.hdr
  OAS2_0105 -> mpr-1.nifti.hdr
  OAS2_0106 -> mpr-1.nifti.hdr
  OAS2_0108 -> mpr-1.nifti.hdr
  OAS2_0109 -> mpr-1.nifti.hdr
  OAS2_0111 -> mpr-1.nifti.hdr


Building graphs: 100%|██████████| 68/68 [01:12<00:00,  1.07s/it]

Built graphs: 0 skipped builds: 68
Top skip reasons (first 20): [('OAS2_0100', 'The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().'), ('OAS2_0101', 'The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().'), ('OAS2_0102', 'The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().'), ('OAS2_0103', 'The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().'), ('OAS2_0104', 'The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().'), ('OAS2_0105', 'The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().'), ('OAS2_0106', 'The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().'), ('OAS2_0108', 'The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().'), ('OAS2_0109', 'The truth value 




In [None]:
# Rebuild graphs_part safely for PART2; fixes ambiguous Series error by selecting first clinical row
from pathlib import Path
import tempfile, os, traceback
import nibabel as nib
import numpy as np
import torch
from tqdm import tqdm
from skimage.transform import resize
from sklearn.neighbors import kneighbors_graph
from torch_geometric.data import Data

GRAPH_DIR = Path("/content/drive/MyDrive/oasis_project/data/graphs")
PART_OUT  = GRAPH_DIR / "oasis2_graphs_labeled_part2.pt"
PREPROC_DIR = Path("/content/drive/MyDrive/oasis_project/data/preproc/oasis2")
PREPROC_DIR.mkdir(parents=True, exist_ok=True)

# load clinical meta if not in memory
import pandas as pd
CLIN_XLSX = Path("/content/drive/MyDrive/oasis_project/data/demographics/oasis2_demographics.xlsx")
meta = pd.read_excel(CLIN_XLSX) if 'meta' not in globals() else meta.copy()
# ensure canonical subject_id exists
if "subject_id" not in meta.columns:
    id_col = next((c for c in meta.columns if any(tok in c.lower() for tok in ("id","subject","subj","ptid"))), meta.columns[0])
    def canonical_subject(s):
        s = str(s)
        parts = s.split('_')
        return "_".join(parts[:2]) if len(parts)>=2 else s
    meta["subject_id"] = meta[id_col].astype(str).map(canonical_subject)
# optional: coerce label/timepoint names to 'label'/'timepoint' if present
if 'label' not in meta.columns and any("cdr" in c.lower() for c in meta.columns):
    cdr_col = next(c for c in meta.columns if "cdr" in c.lower())
    meta['label'] = pd.to_numeric(meta[cdr_col], errors='coerce')
if 'timepoint' not in meta.columns:
    # try 'delay' or date-like columns
    tp_col = next((c for c in meta.columns if any(tok in c.lower() for tok in ("delay","visit","time","date","session"))), None)
    if tp_col is not None:
        meta['timepoint'] = pd.to_numeric(meta[tp_col], errors='coerce')
    else:
        meta['timepoint'] = np.nan

meta_index = meta.set_index("subject_id", drop=False)

# matched_subject_files dictionary should exist from prior run.
# If not present, attempt to recompute by scanning extracted dir for the PART folder pattern:
if 'matched_subject_files' not in globals() or not matched_subject_files:
    print("matched_subject_files not found — attempting to compute it again from extracted PART folder.")
    RAW_ROOT = Path("/content/drive/MyDrive/oasis_project/data/raw/oasis2")
    PART_NAME = "OAS2_RAW_PART2"
    candidates = list(RAW_ROOT.rglob(f"*{PART_NAME}*"))
    if len(candidates)==0:
        raise FileNotFoundError(f"Could not locate extracted folder for {PART_NAME} under {RAW_ROOT}")
    extracted_dir = candidates[0]
    all_imgs = sorted(list(extracted_dir.rglob("*.nii*")) + list(extracted_dir.rglob("*.img")) + list(extracted_dir.rglob("*.hdr")))
    import re
    pattern = re.compile(r"(OAS2?_[0-9]{3,4})", flags=re.IGNORECASE)
    matched_subject_files = {}
    for p in all_imgs:
        sid_candidate = None
        for part in reversed(p.parts):
            m = pattern.search(part)
            if m:
                sid_candidate = m.group(1)
                break
        if sid_candidate:
            sid_can = "_".join(sid_candidate.split("_")[:2])
            # match to clinical index by exact or by numeric token
            if sid_can in meta["subject_id"].values:
                matched_subject_files.setdefault(sid_can, p)
            else:
                digits = re.search(r"\d{3,4}", sid_candidate)
                if digits:
                    token = digits.group(0)
                    cand = meta[meta["subject_id"].str.contains(token, na=False)]
                    if len(cand)>0:
                        matched_subject_files.setdefault(cand.iloc[0]["subject_id"], p)
    print("Recomputed matched_subject_files size:", len(matched_subject_files))

print(f"Processing {len(matched_subject_files)} matched subjects into graphs...")

# helper functions (same as before)
def preprocess_mri(path, out_dir, target_shape=(96,96,96)):
    out_dir = Path(out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    img = nib.load(str(path))
    data = img.get_fdata().astype(np.float32)
    mask = data > np.percentile(data, 5)
    if mask.sum()==0:
        mask = data != 0
    mu = data[mask].mean() if mask.sum()>0 else data.mean()
    sd = data[mask].std() if mask.sum()>0 else data.std()
    data = (data - mu) / (sd + 1e-6)
    data = resize(data, target_shape, order=1, preserve_range=True, anti_aliasing=True)
    out_path = out_dir / f"{Path(path).stem}_preproc.nii.gz"
    nib.save(nib.Nifti1Image(data.astype(np.float32), affine=np.eye(4)), str(out_path))
    return out_path

def mri_to_graph(nifti_path, subject_id, label, timepoint, patch_size=3, n_neighbors=6, sample_n=400):
    data = nib.load(str(nifti_path)).get_fdata().astype(np.float32)
    data = (data - data.mean()) / (data.std() + 1e-6)
    coords_all = np.stack(np.meshgrid(np.arange(data.shape[0]),
                                      np.arange(data.shape[1]),
                                      np.arange(data.shape[2]),
                                      indexing='ij'), axis=-1).reshape(-1,3)
    mask = data.reshape(-1) > np.percentile(data.reshape(-1), 10)
    coords = coords_all[mask]
    if coords.shape[0] > sample_n:
        idx = np.linspace(0, coords.shape[0]-1, sample_n).astype(int)
        coords = coords[idx]
    feats = []; valid_coords=[]
    for (x,y,z) in coords:
        x,y,z = int(x), int(y), int(z)
        x0,x1 = max(0,x-patch_size), min(data.shape[0], x+patch_size+1)
        y0,y1 = max(0,y-patch_size), min(data.shape[1], y+patch_size+1)
        z0,z1 = max(0,z-patch_size), min(data.shape[2], z+patch_size+1)
        patch = data[x0:x1, y0:y1, z0:z1]
        feats.append([patch.mean(), patch.std(), patch.min(), patch.max()])
        valid_coords.append([x,y,z])
    X = np.array(feats, dtype=np.float32)
    if len(valid_coords) > 1:
        A = kneighbors_graph(np.array(valid_coords), n_neighbors=min(n_neighbors, len(valid_coords)-1),
                             mode='connectivity', include_self=False)
        rows, cols = A.nonzero()
        edge_index = np.vstack([rows, cols]).astype(np.int64)
    else:
        edge_index = np.zeros((2,0), dtype=np.int64)
    g = Data(x=torch.tensor(X), edge_index=torch.tensor(edge_index, dtype=torch.long))
    g.y = torch.tensor([float(label) if not np.isnan(label) else np.nan], dtype=torch.float32)
    g.subject_id = str(subject_id)
    g.timepoint = float(timepoint) if (not pd.isna(timepoint)) else 0.0
    g.source_archive = "OAS2_RAW_PART2"
    return g

# Build graphs_part with robust row selection
graphs_part = []
skipped_build = []
for sid, img_path in tqdm(matched_subject_files.items(), desc="Building graphs (fixed)"):
    try:
        # select first matching clinical row (avoid ambiguous Series issue)
        rows = meta[meta["subject_id"] == sid]
        if len(rows) == 0:
            skipped_build.append((sid, "no_clinical_row"))
            continue
        row = rows.iloc[0]   # pick first row if there are duplicates
        label = row.get("label", np.nan)
        tp = row.get("timepoint", np.nan)
        preproc = preprocess_mri(img_path, PREPROC_DIR, target_shape=(96,96,96))
        g = mri_to_graph(preproc, sid, label, tp, patch_size=3, n_neighbors=6, sample_n=400)
        graphs_part.append(g)
    except Exception as e:
        skipped_build.append((sid, str(e)))
        traceback.print_exc()

print("Built graphs:", len(graphs_part), "skipped builds:", len(skipped_build))
if skipped_build:
    print("Sample skipped build reasons (up to 20):")
    for s in skipped_build[:20]:
        print(" ", s)

# Save PART file atomically inside GRAPH_DIR to avoid cross-device replace
if len(graphs_part) > 0:
    fd, tmp_path = tempfile.mkstemp(suffix=".pt", dir=str(GRAPH_DIR))
    os.close(fd)
    torch.save(graphs_part, tmp_path)
    Path(tmp_path).replace(PART_OUT)
    print("Saved PART file to:", PART_OUT)
else:
    print("No graphs built; PART file not saved. Inspect skipped_build above.")


Processing 68 matched subjects into graphs...


Building graphs (fixed): 100%|██████████| 68/68 [01:13<00:00,  1.08s/it]

Built graphs: 68 skipped builds: 0
Saved PART file to: /content/drive/MyDrive/oasis_project/data/graphs/oasis2_graphs_labeled_part2.pt





In [None]:
# Merge PART into MASTER safely, dedupe by (subject_id, timepoint), prefer labeled/new graphs.
from pathlib import Path
import os, tempfile, shutil
import torch, numpy as np
from collections import Counter

GRAPH_DIR = Path("/content/drive/MyDrive/oasis_project/data/graphs")
PART_PT = GRAPH_DIR / "oasis2_graphs_labeled_part2.pt"
MASTER_PT = GRAPH_DIR / "oasis2_graphs_labeled_auto_full.pt"

assert PART_PT.exists(), f"Part file missing: {PART_PT}"
assert MASTER_PT.exists(), f"Master file missing: {MASTER_PT}"

# Load with PyG safe context if possible
def safe_load(p):
    try:
        import torch_geometric
        with torch.serialization.safe_globals([torch_geometric.data.storage.GlobalStorage]):
            return torch.load(p, map_location="cpu")
    except Exception:
        return torch.load(p, map_location="cpu", weights_only=False)

master = safe_load(MASTER_PT)
part   = safe_load(PART_PT)
print("Loaded master count:", len(master))
print("Loaded part count:", len(part))

# Build keyed dict: key = (subject_id, timepoint)
def get_key(g, idx):
    sid = getattr(g, "subject_id", None)
    if sid is None:
        sid = f"__idx__{idx}"
    tp = getattr(g, "timepoint", None)
    try:
        tp_f = float(tp) if tp is not None else 0.0
    except Exception:
        tp_f = 0.0
    return (str(sid), tp_f)

merged = {}
source_of = {}  # track which source (master or part) provided the entry

# add existing master first (lower priority)
for i,g in enumerate(master):
    merged[get_key(g,i)] = g
    source_of[get_key(g,i)] = "master"

# merge part: override according to policy (prefer labeled; else prefer part)
added=0; replaced=0; kept=0
for i,g in enumerate(part):
    k = get_key(g, i)
    try:
        new_y = float(g.y.view(-1).cpu().numpy()[0])
    except Exception:
        try:
            new_y = float(g.y)
        except Exception:
            new_y = np.nan
    if k not in merged:
        merged[k] = g
        source_of[k] = "part"
        added += 1
    else:
        prev = merged[k]
        try:
            prev_y = float(prev.y.view(-1).cpu().numpy()[0])
        except Exception:
            try:
                prev_y = float(prev.y)
            except Exception:
                prev_y = np.nan
        # decide: prefer graph with a label; if both labeled or both unlabeled, prefer the part (new)
        prev_has = not np.isnan(prev_y)
        new_has = not np.isnan(new_y)
        if new_has and not prev_has:
            merged[k] = g; source_of[k] = "part"; replaced += 1
        elif new_has == prev_has:
            # prefer newer (part)
            merged[k] = g; source_of[k] = "part"; replaced += 1
        else:
            kept += 1

merged_list = list(merged.values())
print(f"Merged master count (after): {len(merged_list)} (added {added}, replaced {replaced}, kept {kept})")

# Atomic save: write temp inside GRAPH_DIR to avoid cross-device link error
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".pt", dir=str(GRAPH_DIR))
os.close(tmp_fd)
torch.save(merged_list, tmp_path)
os.replace(tmp_path, str(MASTER_PT))
print("Master saved to:", MASTER_PT)

# Summary: show a few changes
from collections import defaultdict
count_by_source = Counter(source_of.values())
print("Entries by source after merge (sample):", dict(count_by_source))

# Show subject IDs that were added from the part (up to 20)
added_sids = [k[0] for k,v in source_of.items() if v=="part"]
print("Sample subject IDs now coming from part (up to 20):", added_sids[:20])

# Quick sanity: label distribution in merged master
labels=[]
for g in merged_list:
    try:
        labels.append(float(g.y.view(-1).cpu().numpy()[0]))
    except Exception:
        try:
            labels.append(float(g.y))
        except Exception:
            labels.append(np.nan)
lbl_counts = Counter([l for l in labels if not np.isnan(l)])
print("Label counts (non-NaN) in merged master (top items):")
for k,v in sorted(lbl_counts.items()):
    print(" ", k, ":", v)


Loaded master count: 209
Loaded part count: 68
Merged master count (after): 277 (added 68, replaced 0, kept 0)
Master saved to: /content/drive/MyDrive/oasis_project/data/graphs/oasis2_graphs_labeled_auto_full.pt
Entries by source after merge (sample): {'master': 209, 'part': 68}
Sample subject IDs now coming from part (up to 20): ['OAS2_0100', 'OAS2_0101', 'OAS2_0102', 'OAS2_0103', 'OAS2_0104', 'OAS2_0105', 'OAS2_0106', 'OAS2_0108', 'OAS2_0109', 'OAS2_0111', 'OAS2_0112', 'OAS2_0113', 'OAS2_0114', 'OAS2_0116', 'OAS2_0117', 'OAS2_0118', 'OAS2_0119', 'OAS2_0120', 'OAS2_0121', 'OAS2_0122']
Label counts (non-NaN) in merged master (top items):
  0.0 : 144
  0.5 : 102
  1.0 : 28
  2.0 : 3


In [None]:
# Merge helper: call this after you save oasis2_graphs_labeled_partX.pt for a new part
from pathlib import Path
import os, tempfile, shutil
import torch, numpy as np
from collections import Counter

GRAPH_DIR = Path("/content/drive/MyDrive/oasis_project/data/graphs")
MASTER_PT = GRAPH_DIR / "oasis2_graphs_labeled_auto_full.pt"
GLOBAL_PT = GRAPH_DIR / "all_graphs_master.pt"

def safe_load(p):
    p = Path(p)
    if not p.exists():
        raise FileNotFoundError(f"File not found: {p}")
    try:
        import torch_geometric
        with torch.serialization.safe_globals([torch_geometric.data.storage.GlobalStorage]):
            return torch.load(p, map_location="cpu")
    except Exception:
        return torch.load(p, map_location="cpu", weights_only=False)

def merge_part_to_master(part_pt_path):
    part_pt = Path(part_pt_path)
    assert part_pt.exists(), f"Part file not found: {part_pt}"
    print("Loading files...")
    master = safe_load(MASTER_PT) if MASTER_PT.exists() else []
    part   = safe_load(part_pt)
    print(f"Master count: {len(master)}, Part count: {len(part)}")

    def key_for(g, idx):
        sid = getattr(g, "subject_id", None) or f"__idx__{idx}"
        tp = getattr(g, "timepoint", None)
        try:
            tpf = float(tp) if tp is not None else 0.0
        except Exception:
            tpf = 0.0
        return (str(sid), tpf)

    merged = {}
    source = {}

    # add existing
    for i,g in enumerate(master):
        merged[key_for(g,i)] = g
        source[key_for(g,i)] = "master"

    added = replaced = kept = 0
    # merge part entries
    for i,g in enumerate(part):
        k = key_for(g,i)
        try:
            new_y = float(g.y.view(-1).cpu().numpy()[0])
        except Exception:
            try: new_y = float(g.y)
            except Exception: new_y = np.nan
        if k not in merged:
            merged[k] = g; source[k] = "part"; added += 1
        else:
            prev = merged[k]
            try:
                prev_y = float(prev.y.view(-1).cpu().numpy()[0])
            except Exception:
                try: prev_y = float(prev.y)
                except Exception: prev_y = np.nan
            prev_has = not np.isnan(prev_y)
            new_has = not np.isnan(new_y)
            # prefer labeled; if same preference, prefer the part (new)
            if (new_has and not prev_has) or (new_has == prev_has):
                merged[k] = g; source[k] = "part"; replaced += 1
            else:
                kept += 1

    merged_list = list(merged.values())
    print(f"Merged count: {len(merged_list)} (added {added}, replaced {replaced}, kept {kept})")

    # atomic save inside GRAPH_DIR to avoid cross-device error
    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".pt", dir=str(GRAPH_DIR))
    os.close(tmp_fd)
    torch.save(merged_list, tmp_path)
    os.replace(tmp_path, str(MASTER_PT))
    print("Master saved to:", MASTER_PT)

    # append to global master (no dedupe) and save
    global_existing = safe_load(GLOBAL_PT) if GLOBAL_PT.exists() else []
    global_combined = list(global_existing) + list(part)
    tmp_fd2, tmp_path2 = tempfile.mkstemp(suffix=".pt", dir=str(GRAPH_DIR))
    os.close(tmp_fd2)
    torch.save(global_combined, tmp_path2)
    os.replace(tmp_path2, str(GLOBAL_PT))
    print("Global master updated:", GLOBAL_PT)

    # short summary of labels
    labels = []
    for g in merged_list:
        try: labels.append(float(g.y.view(-1).cpu().numpy()[0]))
        except Exception:
            try: labels.append(float(g.y))
            except Exception: labels.append(np.nan)
    from collections import Counter
    print("Label counts (non-NaN) in merged master:", dict(Counter([l for l in labels if not np.isnan(l)])) )

# Example usage:
# merge_part_to_master("/content/drive/MyDrive/oasis_project/data/graphs/oasis2_graphs_labeled_part3.pt")
