In [None]:
# Cell A — mount and imports
# Mounts Google Drive at /content/drive; only paths under that mount persist on Drive.
# force_remount=False presumably keeps an already-mounted Drive as-is.
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

from pathlib import Path
# NOTE(review): fnmatch, os and tqdm are not used in the cells shown.
import tarfile, zipfile, shutil, fnmatch, os
import nibabel as nib  # loads .hdr/.img pairs and writes .nii.gz volumes
import pandas as pd
from tqdm import tqdm



Mounted at /content/drive


In [None]:
# Cell B — config (edit ARCHIVE_PATH if you uploaded the tar to a known location)
# Google Drive is mounted at /content/drive (Cell A). A path such as '/MyDrive/...' is NOT
# on Drive: it lives on the Colab VM's local disk and disappears when the runtime resets.
# Root every project path at the mount so extracted/converted data actually persists.
# NOTE(review): cells further down that hardcode '/MyDrive/oasis_project/...' should be
# pointed at PROJECT_ROOT as well.
DRIVE_ROOT = Path('/content/drive/MyDrive')
PROJECT_ROOT = DRIVE_ROOT / 'oasis_project'
RAW_ROOT = PROJECT_ROOT / 'data' / 'raw'
OASIS2_OUT = RAW_ROOT / 'oasis2'
OASIS2_OUT.mkdir(parents=True, exist_ok=True)

# If you uploaded OASIS tar to Drive, set that path:
ARCHIVE_PATH = DRIVE_ROOT / 'OAS2_RAW_PART1.tar.gz'   # <- edit if different
# Alternative: search Drive for anything matching OAS2_RAW*
DRIVE_SEARCH_BASE = DRIVE_ROOT


In [None]:
# Cell C — helper to inspect archive contents (dry-run)
def list_archive_members(archive_path):
    """Return the member names of a zip or tar archive, in archive order.

    Zip archives are recognised with ``zipfile.is_zipfile``; anything else is
    opened with ``tarfile`` in 'r:*' mode (compression auto-detected).
    Nothing is extracted.
    """
    if not zipfile.is_zipfile(archive_path):
        with tarfile.open(archive_path, 'r:*') as archive:
            names = [info.name for info in archive.getmembers()]
        return names
    with zipfile.ZipFile(archive_path, 'r') as archive:
        names = archive.namelist()
    return names
# Example usage: print first 30 entries
# (ARCHIVE_PATH comes from Cell B; this only lists names, it does not extract anything)
members = list_archive_members(ARCHIVE_PATH)
print("Archive sample (first 30):")
for m in members[:30]:
    print(m)


Archive sample (first 30):
OAS2_RAW_PART1
OAS2_RAW_PART1/OAS2_0001_MR1
OAS2_RAW_PART1/OAS2_0001_MR1/RAW
OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-3.nifti.hdr
OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-3.nifti.img
OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-2.nifti.img
OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-2.nifti.hdr
OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-1.nifti.img
OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-1.nifti.hdr
OAS2_RAW_PART1/OAS2_0001_MR2
OAS2_RAW_PART1/OAS2_0001_MR2/RAW
OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-1.nifti.hdr
OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-1.nifti.img
OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-2.nifti.hdr
OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-2.nifti.img
OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-3.nifti.hdr
OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-3.nifti.img
OAS2_RAW_PART1/OAS2_0002_MR1
OAS2_RAW_PART1/OAS2_0002_MR1/RAW
OAS2_RAW_PART1/OAS2_0002_MR1/RAW/mpr-1.nifti.hdr
OAS2_RAW_PART1/OAS2_0002_MR1/RAW/mpr-1.nifti.img
OAS2_RAW_PART1/OAS2_0002_MR1/RAW/mpr-2.nifti.hdr
OAS2_RAW_PART1/OAS2_0002_MR1/RAW/mpr-

In [None]:
# Cell D — extraction + convert .hdr/.img to .nii.gz (streams only needed files)
def _is_mpr1_member(name: str) -> bool:
    """True for archive entries that look like an mpr-1 image (.hdr/.img/.nii/.nii.gz or '*nifti*')."""
    return 'mpr-1' in name and (
        'nifti' in name or name.endswith(('.hdr', '.img', '.nii', '.nii.gz'))
    )


def _session_from_member(member_name: str) -> str:
    """Return the session folder (e.g. 'OAS2_0001_MR1') that contains an archive entry.

    Entries look like 'OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-1.nifti.hdr'. The top-level
    folder names the archive part, not the session, so it must not be used as the key —
    otherwise every session writes to the same RAW/mpr-1.* files and overwrites the others.
    Falls back to the nearest folder that is not 'RAW', then to the file stem.
    """
    folders = Path(member_name).parts[:-1]
    for part in reversed(folders):
        upper = part.upper()
        if upper.startswith('OAS2_') and '_MR' in upper:
            return part
    for part in reversed(folders):
        if part.upper() != 'RAW':
            return part
    return Path(member_name).stem


def extract_and_collect(archive_path: Path, out_base: Path):
    """Extract the mpr-1 scan of every session in an archive and list usable NIfTI volumes.

    Entries accepted by ``_is_mpr1_member`` are streamed to
    ``out_base/<session>/RAW/<file name>``. Each session written by this call is then
    scanned: an existing ``*mpr-1*.nii*`` file is reported as 'native'; otherwise every
    ``*mpr-1*.hdr`` with a sibling ``.img`` is converted with nibabel to
    ``<hdr with .nii.gz suffix>`` and reported as 'converted'. Folders under ``out_base``
    that this archive did not touch are ignored.

    Args:
        archive_path: zip archive, or anything ``tarfile.open(..., 'r:*')`` can read.
        out_base: root directory for the per-session output folders (created if missing).

    Returns:
        list of ``(session_id, nifti_path, str(archive_path), note)`` tuples, one per volume.
    """
    rows = []
    out_base = Path(out_base)
    out_base.mkdir(parents=True, exist_ok=True)
    sessions = set()

    def _write_fileobj(fobj, member_name: str):
        # Write one archive entry to <out_base>/<session>/RAW/<basename>; using only the
        # basename means member paths containing '..' cannot escape out_base.
        session = _session_from_member(member_name)
        dest = out_base / session / 'RAW' / Path(member_name).name
        dest.parent.mkdir(parents=True, exist_ok=True)
        with open(dest, 'wb') as dst:
            shutil.copyfileobj(fobj, dst)
        sessions.add(session)

    if zipfile.is_zipfile(archive_path):
        with zipfile.ZipFile(archive_path, 'r') as zf:
            for info in zf.infolist():
                if info.is_dir() or not _is_mpr1_member(info.filename):
                    continue
                with zf.open(info) as src:
                    _write_fileobj(src, info.filename)
    else:
        with tarfile.open(archive_path, 'r:*') as tf:
            for member in tf.getmembers():
                if member.isdir() or not _is_mpr1_member(member.name):
                    continue
                src = tf.extractfile(member)
                if src is None:  # special members (devices, fifos) have no payload
                    continue
                with src:
                    _write_fileobj(src, member.name)

    # After extraction, report exactly one usable volume per session written above.
    for session in sorted(sessions):
        raw = out_base / session / 'RAW'
        # prefer a NIfTI that already exists (.nii / .nii.gz)
        nii_candidates = sorted(raw.glob('**/*mpr-1*.nii*'))
        if nii_candidates:
            rows.append((session, str(nii_candidates[0]), str(archive_path), 'native'))
            continue
        # else convert each .hdr/.img pair
        for hdr in sorted(raw.glob('*mpr-1*.hdr')):
            img_file = hdr.with_suffix('.img')
            # nibabel reads the data from the header's sibling .img, so another .img is no substitute
            if not img_file.exists():
                print("Missing image file for", hdr, "- expected", img_file)
                continue
            nii_out = hdr.with_suffix('.nii.gz')
            try:
                nib.save(nib.load(str(hdr)), str(nii_out))
            except Exception as e:
                print("Conversion failed for", hdr, ":", e)
                continue
            rows.append((session, str(nii_out), str(archive_path), 'converted'))
    return rows

# Run extraction (example) on the archive configured in Cell B, writing under OASIS2_OUT
archive = ARCHIVE_PATH
collected = extract_and_collect(archive_path=archive, out_base=OASIS2_OUT)
print(f"Collected {len(collected)} entries")


Collected 1 entries


In [None]:
# Cell E — build & save manifest (one row per tuple returned by extract_and_collect)
manifest_df = pd.DataFrame.from_records(
    collected,
    columns=['subject_id', 'filepath', 'source_archive', 'note'],
)
manifest_df.to_csv(RAW_ROOT.joinpath('oasis2_manifest.csv'), index=False)
manifest_df.head()


Unnamed: 0,subject_id,filepath,source_archive,note
0,OAS2_RAW_PART1,/MyDrive/oasis_project/data/raw/oasis2/OAS2_RA...,/content/drive/MyDrive/OAS2_RAW_PART1.tar.gz,converted


### B — Load the precomputed graph dataset (if present)

In [None]:
from pathlib import Path
# Reuse graphs built earlier if the cached file is on Drive; otherwise tell the user to build them.
p = Path('/content/drive/MyDrive/oasis2_graph_dataset.pt')
if not p.exists():
    print("No precomputed graph dataset found — run graph building.")
else:
    import torch
    # weights_only=False lets torch.load unpickle arbitrary Python objects (presumably
    # required for the saved graph objects) — only do this for files you created yourself.
    graphs = torch.load(str(p), weights_only=False)
    print("Loaded precomputed graphs:", len(graphs))



Loaded precomputed graphs: 209


### C — Inspect the archive listing

In [None]:
from pathlib import Path
import tarfile, zipfile, itertools

ARCHIVE_PATH = Path('/content/drive/MyDrive/OAS2_RAW_PART1.tar.gz')  # <- adjust if needed

def list_archive_sample(archive_path, n=200):
    """Print the entry count and the first ``n`` member names of a zip or tar archive.

    Args:
        archive_path: zip archive, or anything ``tarfile.open(..., 'r:*')`` can read.
        n: how many names to print (all of them if the archive is smaller).

    Returns:
        The full list of member names, in archive order.
    """
    print("Archive:", archive_path)
    if zipfile.is_zipfile(archive_path):
        # context manager so the zip file handle is closed as soon as the names are read
        with zipfile.ZipFile(archive_path, 'r') as zf:
            members = zf.namelist()
    else:
        with tarfile.open(archive_path, 'r:*') as tf:
            members = [m.name for m in tf.getmembers()]
    print(f"Total entries in archive: {len(members)}")
    # Print first n entries (or all if smaller), numbered from 0001
    for i, m in enumerate(members[:n], start=1):
        print(f"{i:04d}", m)
    return members

# Print the first 300 names of the archive above; `members` keeps the complete list.
members = list_archive_sample(ARCHIVE_PATH, n=300)

Archive: /content/drive/MyDrive/OAS2_RAW_PART1.tar.gz
Total entries in archive: 2353
0001 OAS2_RAW_PART1
0002 OAS2_RAW_PART1/OAS2_0001_MR1
0003 OAS2_RAW_PART1/OAS2_0001_MR1/RAW
0004 OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-3.nifti.hdr
0005 OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-3.nifti.img
0006 OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-2.nifti.img
0007 OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-2.nifti.hdr
0008 OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-1.nifti.img
0009 OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-1.nifti.hdr
0010 OAS2_RAW_PART1/OAS2_0001_MR2
0011 OAS2_RAW_PART1/OAS2_0001_MR2/RAW
0012 OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-1.nifti.hdr
0013 OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-1.nifti.img
0014 OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-2.nifti.hdr
0015 OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-2.nifti.img
0016 OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-3.nifti.hdr
0017 OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-3.nifti.img
0018 OAS2_RAW_PART1/OAS2_0002_MR1
0019 OAS2_RAW_PART1/OAS2_0002_MR1/RAW
0020 OAS2_RAW_PART1/OAS2_0002_M

### D — Count mpr-1 entries in the archive

In [None]:
from pathlib import Path
import tarfile, zipfile

ARCHIVE_PATH = Path('/content/drive/MyDrive/OAS2_RAW_PART1.tar.gz')  # adjust

def count_mpr_entries_in_archive(archive_path):
    """Find archive entries that look like mpr-1 image files.

    An entry qualifies when its name contains 'mpr-1' (case-sensitive) and ends,
    case-insensitively, with .nii, .nii.gz, .hdr or .img. Works for zip archives and
    anything tarfile can open with 'r:*'.

    Returns:
        (count, names): how many entries matched and their names in archive order.
    """
    image_suffixes = ('.nii', '.nii.gz', '.hdr', '.img')

    if zipfile.is_zipfile(archive_path):
        with zipfile.ZipFile(archive_path, 'r') as zf:
            all_names = zf.namelist()
    else:
        with tarfile.open(archive_path, 'r:*') as tf:
            all_names = [info.name for info in tf.getmembers()]

    matches = [
        name for name in all_names
        if 'mpr-1' in name and name.lower().endswith(image_suffixes)
    ]
    return len(matches), matches

# Count matching entries (the .hdr and .img halves of a pair are counted separately)
cnt, mpr_paths = count_mpr_entries_in_archive(ARCHIVE_PATH)
print("mpr-1 entries inside archive:", cnt)
print("Examples (up to 30):")
for p in mpr_paths[:30]:
    print(" ", p)


mpr-1 entries inside archive: 418
Examples (up to 30):
  OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-1.nifti.img
  OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-1.nifti.hdr
  OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-1.nifti.hdr
  OAS2_RAW_PART1/OAS2_0001_MR2/RAW/mpr-1.nifti.img
  OAS2_RAW_PART1/OAS2_0002_MR1/RAW/mpr-1.nifti.hdr
  OAS2_RAW_PART1/OAS2_0002_MR1/RAW/mpr-1.nifti.img
  OAS2_RAW_PART1/OAS2_0002_MR2/RAW/mpr-1.nifti.hdr
  OAS2_RAW_PART1/OAS2_0002_MR2/RAW/mpr-1.nifti.img
  OAS2_RAW_PART1/OAS2_0002_MR3/RAW/mpr-1.nifti.hdr
  OAS2_RAW_PART1/OAS2_0002_MR3/RAW/mpr-1.nifti.img
  OAS2_RAW_PART1/OAS2_0004_MR1/RAW/mpr-1.nifti.hdr
  OAS2_RAW_PART1/OAS2_0004_MR1/RAW/mpr-1.nifti.img
  OAS2_RAW_PART1/OAS2_0004_MR2/RAW/mpr-1.nifti.hdr
  OAS2_RAW_PART1/OAS2_0004_MR2/RAW/mpr-1.nifti.img
  OAS2_RAW_PART1/OAS2_0005_MR1/RAW/mpr-1.nifti.hdr
  OAS2_RAW_PART1/OAS2_0005_MR1/RAW/mpr-1.nifti.img
  OAS2_RAW_PART1/OAS2_0005_MR2/RAW/mpr-1.nifti.hdr
  OAS2_RAW_PART1/OAS2_0005_MR2/RAW/mpr-1.nifti.img
  OAS2_RAW_PART1/OAS2_0005_

### E — Full archive extraction, conversion to .nii.gz, and checks

In [None]:
from pathlib import Path
import tarfile, zipfile, shutil

ARCHIVE_PATH = Path('/content/drive/MyDrive/OAS2_RAW_PART1.tar.gz')  # adjust
# NOTE(review): '/MyDrive/...' is not under the Drive mount point (/content/drive) used for
# ARCHIVE_PATH, so this output presumably lands on the Colab VM's local disk — confirm intended.
OUT_BASE = Path('/MyDrive/oasis_project/data/raw/oasis2_full_extracted')  # new folder to avoid clobbering current
OUT_BASE.mkdir(parents=True, exist_ok=True)

def _is_within_dir(root: Path, target: Path) -> bool:
    """True when ``target`` resolves to a location inside ``root`` (``root`` already resolved)."""
    try:
        target.resolve().relative_to(root)
    except ValueError:
        return False
    return True


def extract_full_archive(archive_path: Path, out_base: Path, overwrite=False):
    """Extract every member of a zip or tar archive under ``out_base``.

    Args:
        archive_path: zip archive, or anything ``tarfile.open(..., 'r:*')`` can read.
        out_base: destination root; member paths are recreated beneath it.
        overwrite: when False, members whose target file already exists are left
            untouched (for both zip and tar archives).

    Members that would land outside ``out_base`` (absolute paths, '..' components)
    are refused: skipped for zip entries, and for tar archives handled by
    tarfile's 'data' extraction filter when the running Python provides it
    (the filter raises on such members).
    """
    out_base = Path(out_base)
    root = out_base.resolve()
    print("Extracting", archive_path, "->", out_base)
    if zipfile.is_zipfile(archive_path):
        with zipfile.ZipFile(archive_path, 'r') as zf:
            for member in zf.infolist():
                target = out_base / member.filename
                if not _is_within_dir(root, target):
                    print("Skipping unsafe member path:", member.filename)
                    continue
                if target.exists() and not overwrite:
                    continue
                # ensure parent dir exists
                target.parent.mkdir(parents=True, exist_ok=True)
                if member.is_dir():
                    continue
                with zf.open(member) as src, open(target, 'wb') as dst:
                    shutil.copyfileobj(src, dst)
    else:
        with tarfile.open(archive_path, 'r:*') as tf:
            # Honor `overwrite` for tar archives too: drop members whose file already exists.
            members = [
                m for m in tf.getmembers()
                if overwrite or m.isdir() or not (out_base / m.name).exists()
            ]
            if hasattr(tarfile, 'data_filter'):
                # 'data' filter blocks path traversal / unsafe links and silences the
                # DeprecationWarning emitted by unfiltered extractall on newer Pythons.
                tf.extractall(path=str(out_base), members=members, filter='data')
            else:
                # Older Python without extraction filters: drop links and escaping paths manually.
                members = [
                    m for m in members
                    if not (m.issym() or m.islnk()) and _is_within_dir(root, out_base / m.name)
                ]
                tf.extractall(path=str(out_base), members=members)
    print("Extraction complete.")

# Full extraction of ARCHIVE_PATH into OUT_BASE (both defined at the top of this cell)
extract_full_archive(ARCHIVE_PATH, OUT_BASE, overwrite=False)


Extracting /content/drive/MyDrive/OAS2_RAW_PART1.tar.gz -> /MyDrive/oasis_project/data/raw/oasis2_full_extracted


  tf.extractall(path=str(out_base))


Extraction complete.


In [None]:
# Convert .hdr/.img pairs to .nii.gz and build oasis2_manifest.csv
from pathlib import Path
import nibabel as nib
import pandas as pd
import shutil
import sys

EXTRACTED_BASE = Path('/MyDrive/oasis_project/data/raw/oasis2_full_extracted')  # where you extracted archive
CANONICAL_OUT = Path('/MyDrive/oasis_project/data/raw/oasis2')  # desired canonical layout
# NOTE(review): these '/MyDrive/...' paths are outside the Drive mount at /content/drive/MyDrive
# (Cell A), so they presumably live on the Colab VM's local disk — confirm that is intended.
CANONICAL_OUT.mkdir(parents=True, exist_ok=True)

# source archive name for manifest bookkeeping (optional)
SOURCE_ARCHIVE = 'OAS2_RAW_PART1'  # change if you want the exact path string

# Every mpr-1 header anywhere under the extracted tree, in sorted path order
hdr_paths = sorted(EXTRACTED_BASE.glob('**/*mpr-1*.hdr'))  # find hdr files
print("Found .hdr files:", len(hdr_paths))

manifest_rows = []  # (subject_id, filepath, source_archive, note) tuples appended by the loop below

def safe_load_and_save(hdr_path: Path, out_nii_path: Path):
    """Re-save the image behind a .hdr file to ``out_nii_path`` using nibabel.

    The header is loaded first (nibabel locates the matching image data itself); only
    after a successful load is the destination folder created and the file written,
    with the output format chosen by nibabel from the destination's extension.

    Returns:
        (True, None) on success, or (False, error message) if any step raised.
    """
    try:
        volume = nib.load(str(hdr_path))
        out_nii_path.parent.mkdir(parents=True, exist_ok=True)
        nib.save(volume, str(out_nii_path))
    except Exception as exc:
        return False, str(exc)
    return True, None

# Process hdrs: one canonical mpr-1.nii.gz per session folder (e.g. OAS2_0001_MR1)
skipped_existing = 0
converted = 0
copied = 0
failed = 0


def _session_id_for(hdr: Path) -> str:
    """Name of the OAS2 session folder holding ``hdr`` (e.g. 'OAS2_0001_MR1')."""
    folders = hdr.parent.parts
    # prefer an OAS2_<id>_MR<n> folder over the archive-part folder (OAS2_RAW_PART1)
    for part in reversed(folders):
        upper = part.upper()
        if upper.startswith('OAS2_') and '_MR' in upper:
            return part
    for part in reversed(folders):
        if part.upper().startswith('OAS2'):
            return part
    # fallback: folder two levels up (…/<session>/RAW/<file>.hdr)
    return hdr.parents[1].name if len(hdr.parents) >= 2 else hdr.parent.name


def _place_native(src: Path, dest: Path) -> None:
    """Put an existing NIfTI at ``dest`` (a .nii.gz path), compressing a plain .nii via nibabel."""
    if src.name.lower().endswith('.gz'):
        shutil.copy2(src, dest)  # already gzip-compressed: a byte copy is enough
    else:
        # a byte copy of a plain .nii under a .nii.gz name would not actually be gzip data
        nib.save(nib.load(str(src)), str(dest))


for hdr in hdr_paths:
    # Typical path: .../OAS2_RAW_PART1/OAS2_0001_MR1/RAW/mpr-1.nifti.hdr
    raw_dir = hdr.parent
    subj = _session_id_for(hdr)

    # Destination: keep the session structure under CANONICAL_OUT, normalized file name
    dest_raw_dir = CANONICAL_OUT / subj / 'RAW'
    dest_raw_dir.mkdir(parents=True, exist_ok=True)
    out_nii = dest_raw_dir / 'mpr-1.nii.gz'

    # Produced by a previous run: nothing to do
    if out_nii.exists():
        skipped_existing += 1
        manifest_rows.append((subj, str(out_nii), SOURCE_ARCHIVE, 'native-existing'))
        continue

    # Reuse a NIfTI already present (canonical folder first, then the extracted RAW folder)
    placed = False
    for src in sorted(dest_raw_dir.glob('*mpr-1*.nii*')) + sorted(raw_dir.glob('*mpr-1*.nii*')):
        try:
            _place_native(src, out_nii)
        except Exception as e:
            print("Could not reuse", src, "error:", e)
            continue
        copied += 1
        manifest_rows.append((subj, str(out_nii), SOURCE_ARCHIVE, 'native-moved'))
        placed = True
        break
    if placed:
        continue

    # Otherwise convert from hdr (nibabel will look for the .img in same folder)
    success, err = safe_load_and_save(hdr, out_nii)
    if success:
        converted += 1
        manifest_rows.append((subj, str(out_nii), SOURCE_ARCHIVE, 'converted'))
    else:
        failed += 1
        print("Conversion failed for", hdr, "error:", err)
        manifest_rows.append((subj, '', SOURCE_ARCHIVE, f'failed:{err}'))

print(f"Converted: {converted}, Reused native: {copied}, Skipped existing: {skipped_existing}, Failed: {failed}")
# subject_id holds the session folder name (e.g. OAS2_0001_MR1), so a subject with several
# visits (MR1, MR2, ...) contributes one row per visit.

# De-duplicate by filepath (keep first), but never collapse failed rows: they all share ''.
manifest_df = pd.DataFrame(manifest_rows, columns=['subject_id','filepath','source_archive','note'])
dup_path = manifest_df.duplicated(subset=['filepath']) & manifest_df['filepath'].ne('')
manifest_df = manifest_df[~dup_path].reset_index(drop=True)

# Save manifest
out_manifest = CANONICAL_OUT.parent / 'oasis2_manifest.csv'
manifest_df.to_csv(out_manifest, index=False)
print("Wrote manifest to", out_manifest)
print("Manifest rows:", len(manifest_df))
manifest_df.head(20)


Found .hdr files: 209
Converted: 209, Skipped existing: 0, Failed: 0
Wrote manifest to /MyDrive/oasis_project/data/raw/oasis2_manifest.csv
Manifest rows: 209


Unnamed: 0,subject_id,filepath,source_archive,note
0,OAS2_0001_MR1,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
1,OAS2_0001_MR2,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
2,OAS2_0002_MR1,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
3,OAS2_0002_MR2,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
4,OAS2_0002_MR3,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
5,OAS2_0004_MR1,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
6,OAS2_0004_MR2,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
7,OAS2_0005_MR1,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
8,OAS2_0005_MR2,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
9,OAS2_0005_MR3,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted


In [None]:
from pathlib import Path
OASIS2_CANON = Path('/MyDrive/oasis_project/data/raw/oasis2')
# NOTE(review): '/MyDrive/...' is outside the Drive mount (/content/drive/MyDrive) — confirm intended.
subs = sorted([p for p in OASIS2_CANON.iterdir() if p.is_dir()])
print("Subject folders found:", len(subs))
# sorted() so the listing is stable between runs (raw glob order depends on the filesystem)
mpr_files = sorted(OASIS2_CANON.glob('**/RAW/*mpr-1*.nii*'))
print("mpr-1 files found (glob):", len(mpr_files))
for p in mpr_files[:20]:
    print(" ", p)
# Folders that are not OAS2_<id>_MR<n> sessions (e.g. leftovers of an earlier extraction
# attempt such as 'OAS2_RAW_PART1') inflate both counts relative to the manifest.
unexpected = [
    s.name for s in subs
    if not (s.name.upper().startswith('OAS2_') and '_MR' in s.name.upper())
]
if unexpected:
    print("Folders that do not look like OAS2 sessions:", unexpected)


Subject folders found: 210
mpr-1 files found (glob): 210
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0081_MR2/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0068_MR2/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0055_MR1/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0037_MR1/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0048_MR4/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0047_MR1/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0051_MR2/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0017_MR5/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0005_MR2/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0031_MR2/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0017_MR3/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0034_MR4/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0061_MR2/RAW/mpr-1.nii.gz
  /MyDr

In [None]:
from pathlib import Path
import pandas as pd

CANON = Path('/MyDrive/oasis_project/data/raw/oasis2')
manifest = Path('/MyDrive/oasis_project/data/raw/oasis2_manifest.csv')

print("Canonical folder exists?", CANON.exists())
if manifest.exists():
    df = pd.read_csv(manifest)
    print("Manifest rows:", len(df))
    # subject_id holds the session folder (OAS2_<id>_MR<n>); the subject is its OAS2_<id> prefix
    print("Unique sessions:", df['subject_id'].nunique())
    print("Unique subjects:", df['subject_id'].str.extract(r'^(OAS2_\d+)', expand=False).nunique())
    # Failed conversions are saved with an empty filepath, which read_csv turns into NaN;
    # Path(NaN) would raise, so treat those rows as "missing on disk" instead.
    paths = df['filepath'].fillna('').astype(str)
    on_disk = paths.map(lambda x: bool(x) and Path(x).exists())
    print("Files exist on disk?:", on_disk.all())
    if not on_disk.all():
        print("Missing files for:", df.loc[~on_disk, 'subject_id'].tolist())
    display(df.head(8))
else:
    print("No manifest found at", manifest)

mprs = sorted(CANON.glob('**/RAW/*mpr-1*.nii*')) if CANON.exists() else []
print("mpr-1 files found (glob):", len(mprs))
if manifest.exists() and len(mprs) != len(df):
    print(f"NOTE: {len(mprs)} files on disk vs {len(df)} manifest rows — check for stray folders.")
for p in mprs[:10]:
    print(" ", p)


Canonical folder exists? True
Manifest rows: 209
Unique subjects: 209
Files exist on disk?: True


Unnamed: 0,subject_id,filepath,source_archive,note
0,OAS2_0001_MR1,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
1,OAS2_0001_MR2,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
2,OAS2_0002_MR1,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
3,OAS2_0002_MR2,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
4,OAS2_0002_MR3,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
5,OAS2_0004_MR1,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
6,OAS2_0004_MR2,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted
7,OAS2_0005_MR1,/MyDrive/oasis_project/data/raw/oasis2/OAS2_00...,OAS2_RAW_PART1,converted


mpr-1 files found (glob): 210
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0081_MR2/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0068_MR2/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0055_MR1/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0037_MR1/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0048_MR4/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0047_MR1/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0051_MR2/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0017_MR5/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0005_MR2/RAW/mpr-1.nii.gz
  /MyDrive/oasis_project/data/raw/oasis2/OAS2_0031_MR2/RAW/mpr-1.nii.gz


In [None]:
import torch
import numpy as np
from pathlib import Path

p = Path('/content/drive/MyDrive/oasis2_graph_dataset.pt')
assert p.exists(), f"Precomputed graphs not found at {p}"
# weights_only=False unpickles arbitrary Python objects — only load files you created yourself.
graphs = torch.load(str(p), weights_only=False)
print("Loaded graphs count:", len(graphs))
assert len(graphs) > 0, "Graph dataset is empty"

# show basic label distribution (best-effort): single-element y -> that value, longer y -> argmax
labels = []
for g in graphs:
    try:
        y = g.y.detach().cpu().numpy()
        labels.append(int(y.reshape(-1)[0]) if y.size == 1 else int(np.argmax(y)))
    except Exception:
        labels.append(None)  # graph has no usable .y
unique, counts = np.unique([str(x) for x in labels], return_counts=True)
print("Label distribution (str):")
for u, c in zip(unique, counts):
    print(u, c)

n_missing = sum(x is None for x in labels)
if n_missing:
    print(f"WARNING: {n_missing} graph(s) have no usable label (.y)")
if len({x for x in labels if x is not None}) < 2:
    # A single class means any classifier trained on these graphs learns nothing useful.
    print("WARNING: fewer than 2 distinct labels — check how .y was assigned when the graphs were built.")

# Inspect one graph
g0 = graphs[0]
print("Sample graph: n_nodes", g0.x.shape[0], "feat_dim", g0.x.shape[1])

Loaded graphs count: 209
Label distribution (str):
0 209
Sample graph: n_nodes 1575 feat_dim 50


In [None]:
# 02 quickstart: split + dataloader + 1-batch train (smoke test)
import torch, numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
# torch_geometric.data.DataLoader is deprecated (PyG 2.x); the graph loader lives in torch_geometric.loader
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv, global_mean_pool
import torch.nn.functional as F

SEED = 42
torch.manual_seed(SEED)  # makes weight init and DataLoader shuffling reproducible

# load precomputed graphs (weights_only=False unpickles arbitrary objects — trusted files only)
graphs = torch.load('/content/drive/MyDrive/oasis2_graph_dataset.pt', weights_only=False)
N = len(graphs)
indices = np.arange(N)


def _graph_label(g) -> int:
    """Integer class of one graph: single-element y as-is, longer y via argmax; -1 if unreadable."""
    try:
        y = g.y.detach().cpu().numpy()
        return int(y.reshape(-1)[0]) if y.size == 1 else int(np.argmax(y))
    except Exception:
        return -1


labels = np.array([_graph_label(g) for g in graphs])

# split: stratify only when there are >1 valid classes AND every stratum (including the -1
# "unreadable" bucket) has at least 2 graphs — train_test_split raises otherwise.
valid_labels = labels[labels >= 0]
unique_valid_labels = np.unique(valid_labels)
_, stratum_sizes = np.unique(labels, return_counts=True)
if len(unique_valid_labels) > 1 and stratum_sizes.min() >= 2:
    train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=SEED, stratify=labels)
else:
    if len(unique_valid_labels) < 2:
        print(f"WARNING: {len(unique_valid_labels)} distinct valid label(s) — the classifier has nothing to discriminate.")
    train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=SEED)

train_graphs = [graphs[i] for i in train_idx]
val_graphs = [graphs[i] for i in val_idx]
train_loader = DataLoader(train_graphs, batch_size=8, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=8, shuffle=False)
print("Train/val sizes:", len(train_graphs), len(val_graphs))


class TinySAGE(torch.nn.Module):
    """Two GraphSAGE layers + mean pooling + linear head: one logit vector per graph."""

    def __init__(self, in_ch, outc, hid=64):  # outc (number of classes) is required
        super().__init__()
        self.conv1 = SAGEConv(in_ch, hid)
        self.conv2 = SAGEConv(hid, hid)
        self.lin = torch.nn.Linear(hid, outc)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)  # node embeddings -> one embedding per graph
        return self.lin(x)


def _batch_targets(y: torch.Tensor, n_graphs: int) -> torch.Tensor:
    """Per-graph class indices from batch.y, mirroring _graph_label (argmax for vector y)."""
    y = y.view(n_graphs, -1)
    return (y.argmax(dim=1) if y.size(1) > 1 else y[:, 0]).long()


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sample = graphs[0]
in_ch = sample.x.shape[1]
# CrossEntropyLoss needs every target < n_classes, and labels need not be contiguous
# (e.g. {0, 2}), so size the head from the largest valid label; keep at least 2 outputs.
n_classes = max(2, int(unique_valid_labels.max()) + 1) if unique_valid_labels.size else 2
model = TinySAGE(in_ch, n_classes, hid=64).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# one training batch (smoke)
model.train()
for batch in train_loader:
    batch = batch.to(device)
    out = model(batch.x, batch.edge_index, batch.batch)
    y = _batch_targets(batch.y, out.size(0))
    loss = loss_fn(out, y)
    opt.zero_grad(); loss.backward(); opt.step()
    print("Smoke train loss:", loss.item())
    break

Train/val sizes: 167 42
Smoke train loss: 0.8018063902854919


