In [None]:
# Mount Google Drive at /content/drive; later cells read the SAM checkpoint from
# /content/drive/MyDrive and write their results there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# --- Kaggle API setup ---
!mkdir -p ~/.kaggle
!mv /content/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download LGG MRI dataset
!kaggle datasets download -d mateuszbuda/lgg-mri-segmentation
!unzip -q lgg-mri-segmentation.zip -d /content/dataset


# --- Install dependencies ---
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install segmentation-models-pytorch albumentations

In [None]:
import os
import pandas as pd
import numpy as np
from glob import glob
from PIL import Image
import torch
from sklearn.model_selection import train_test_split
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segmentation_models_pytorch.metrics import get_stats, iou_score, f1_score

In [None]:
# --- Build DataFrame ---
# Root of the extracted dataset; each sub-folder is globbed for .tif files whose
# name ends in a digit (the "_mask.tif" files do not match this pattern).
files_dir = '/content/dataset/lgg-mri-segmentation/kaggle_3m'
file_paths = glob(files_dir + '/*/*[0-9].tif')

def get_file_row(path):
    """Build one DataFrame row for an image file.

    Args:
        path: Path to an image file.

    Returns:
        A three-item list ``[patient_id, path, mask_path]`` where ``patient_id``
        is the first three underscore-separated pieces of the file name and
        ``mask_path`` is ``path`` with ``_mask`` inserted before the extension.
    """
    stem, suffix = os.path.splitext(path)
    name_parts = os.path.basename(path).split('_')
    mask_path = stem + '_mask' + suffix
    return ['_'.join(name_parts[:3]), path, mask_path]

# One row per image slice: patient id, image path, and the matching mask path.
train_df = pd.DataFrame(
    list(map(get_file_row, file_paths)),
    columns=['Patient', 'image_filename', 'mask_filename'],
)

# Add race and gender
# Left join keeps every slice row even when data.csv has no entry for the patient.
patient_info = pd.read_csv(os.path.join(files_dir, "data.csv"))
train_df = pd.merge(
    train_df,
    patient_info.loc[:, ['Patient', 'race', 'gender']],
    how='left',
    on='Patient',
)

In [None]:
import os
import cv2
import numpy as np
import nibabel as nib
import pandas as pd
from PIL import Image
from sklearn.metrics import f1_score, jaccard_score
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# --- Load SAM once ---
# The ViT-H checkpoint is read from the mounted Drive folder.
sam_checkpoint = "/content/drive/MyDrive/sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)

# Put the model on the GPU when one is available; without this the model stays
# on the CPU and automatic mask generation for every slice is far slower.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,            # controls mask density
    pred_iou_thresh=0.88,
    stability_score_thresh=0.95,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100       # ignore tiny masks
)

# --- Paths ---
# All per-patient predictions and the metrics CSV are written under this folder.
drive_root = "/content/drive/MyDrive/SAM_results_auto"
os.makedirs(drive_root, exist_ok=True)


def _slice_sort_key(path):
    """Sort key ordering slice files by the integer after the last underscore.

    A plain string sort puts ``..._10.tif`` before ``..._2.tif``; comparing the
    parsed integer keeps slices in their numeric order. Paths without a numeric
    suffix sort after the numbered ones, by path.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    tail = stem.rsplit("_", 1)[-1]
    if tail.isdecimal():
        return (0, int(tail), path)
    return (1, 0, path)


def _load_grayscale(path):
    """Read an image file as a grayscale ('L' mode) array, closing the file afterwards."""
    with Image.open(path) as img:
        return np.array(img.convert("L"))


def _overlap_scores(gt, pred):
    """Return ``(IoU, F1)`` of the foreground of two masks.

    Callers must ensure at least one of the masks has a non-zero pixel,
    otherwise both denominators are zero.
    """
    gt_fg = gt.astype(bool)
    pred_fg = pred.astype(bool)
    intersection = np.logical_and(gt_fg, pred_fg).sum()
    union = np.logical_or(gt_fg, pred_fg).sum()
    iou = intersection / union
    f1 = 2 * intersection / (gt_fg.sum() + pred_fg.sum())
    return float(iou), float(f1)


def run_all_auto(train_df):
    """Segment every patient with SAM's automatic mask generator and score the result.

    For each ``Patient`` group the slices are put in numeric slice order, each
    slice is min-max normalised, converted to RGB and passed to the module-level
    ``mask_generator``. The largest generated mask (by ``area``) is used as the
    prediction for that slice. Predictions are stacked into an ``(H, W, Z)``
    uint8 volume and saved as a NIfTI file under ``drive_root``. IoU and F1 are
    computed per slice against the binarised ground-truth mask, skipping slices
    where both ground truth and prediction are empty, and averaged per patient.

    Args:
        train_df: DataFrame with the columns ``Patient``, ``image_filename``,
            ``mask_filename``, ``race`` and ``gender``.

    Returns:
        Tuple ``(results_df, race_stats, gender_stats)``: the per-patient
        metrics table and its mean IoU/F1 grouped by race and by gender.

    Side effects:
        Writes one ``*_sam_auto_predictions.nii.gz`` per patient and
        ``sam_auto_patient_metrics.csv`` into ``drive_root`` and prints the
        grouped summaries.
    """
    results = []

    for patient_id, group in train_df.groupby("Patient"):
        # Order slices by their numeric index so the z axis of the saved volume
        # follows the slice numbering (a string sort would interleave 1, 10, 11, 2, ...).
        paths = group["image_filename"].tolist()
        order = sorted(range(len(paths)), key=lambda i: _slice_sort_key(paths[i]))
        group = group.iloc[order]

        race = group.iloc[0].race
        gender = group.iloc[0].gender

        # Prepare volume; its in-plane size is taken from the first slice.
        first_slice = _load_grayscale(group.iloc[0].image_filename)
        H, W = first_slice.shape
        Z = len(group)
        output_pred = np.zeros((H, W, Z), dtype=np.uint8)

        ious, f1s = [], []

        for z, row in enumerate(group.itertuples(index=False)):
            # Load MRI + mask slice; any non-zero mask pixel counts as foreground.
            mri_slice = _load_grayscale(row.image_filename)
            mask_slice = _load_grayscale(row.mask_filename)
            gt = (mask_slice > 0).astype(np.uint8)

            # Normalize MRI for SAM (grayscale -> RGB)
            mri_slice_norm = cv2.normalize(mri_slice, None, 0, 255, cv2.NORM_MINMAX)
            mri_slice_rgb = cv2.cvtColor(mri_slice_norm.astype(np.uint8), cv2.COLOR_GRAY2RGB)

            # --- AUTO MODE ---
            masks = mask_generator.generate(mri_slice_rgb)

            if len(masks) > 0:
                # Take largest mask by area.
                # NOTE(review): the largest automatic mask is not guaranteed to be
                # the lesion - confirm this selection rule is the intended baseline.
                largest = max(masks, key=lambda x: x['area'])
                pred = largest['segmentation'].astype(np.uint8)
            else:
                pred = np.zeros_like(gt)

            output_pred[:, :, z] = pred

            # Metrics: slices where both masks are empty are left out of the average.
            # Computed with NumPy so the result does not depend on which library's
            # f1_score happens to be bound in the notebook namespace.
            if gt.any() or pred.any():
                iou, f1 = _overlap_scores(gt, pred)
                ious.append(iou)
                f1s.append(f1)

        # Save NIfTI prediction with an identity affine (no spacing/orientation is written).
        affine = np.eye(4)
        out_path = os.path.join(drive_root, f"{patient_id}_sam_auto_predictions.nii.gz")
        nib.save(nib.Nifti1Image(output_pred, affine), out_path)

        # Store metrics; a patient with no scored slice is reported as 0.
        results.append({
            "Patient": patient_id,
            "Race": race,
            "Gender": gender,
            "IoU": np.mean(ious) if ious else 0,
            "F1": np.mean(f1s) if f1s else 0,
            "Prediction_Path": out_path
        })

    # Collect results
    results_df = pd.DataFrame(results)
    csv_path = os.path.join(drive_root, "sam_auto_patient_metrics.csv")
    results_df.to_csv(csv_path, index=False)

    # Summaries
    race_stats = results_df.groupby("Race")[["IoU", "F1"]].mean().reset_index()
    gender_stats = results_df.groupby("Gender")[["IoU", "F1"]].mean().reset_index()

    print(f"✅ Finished auto segmentation for all patients.\nCSV → {csv_path}")
    print("\nRace-level metrics:\n", race_stats)
    print("\nGender-level metrics:\n", gender_stats)

    return results_df, race_stats, gender_stats


# --- Run automatically ---
# Runs the full pipeline over every patient in train_df when this code is executed
# as the main program; the three returned tables are kept as globals for later cells.
if __name__ == "__main__":
    results_df, race_stats, gender_stats = run_all_auto(train_df)
