<a href="https://colab.research.google.com/github/sharmaishaa/Brain_tumor_survivaldays_prediction/blob/main/braintumor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Cell 1 - setup
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os, random, math, glob, h5py, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, losses, callbacks
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, mean_absolute_error

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

print("TF version:", tf.__version__)
print("GPU available:", tf.config.list_physical_devices('GPU'))

# Paths - change if needed
DRIVE_DATA_PATH = "/content/drive/MyDrive/data"  # where your .h5 files are
CSV_PATH = "/content"  # where your name_mapping.csv, meta_data.csv, survival_info.csv are uploaded

Mounted at /content/drive
TF version: 2.19.0
GPU available: []


In [2]:
# CELL 2 – Load CSVs from Drive

import pandas as pd
import numpy as np
import os
import re # Import re for regex operations

CSV_PATH = "/content/drive/MyDrive/data"

map_path = os.path.join(CSV_PATH, "name_mapping.csv")
meta_path = os.path.join(CSV_PATH, "meta_data.csv")
surv_path = os.path.join(CSV_PATH, "survival_info.csv")

print("Looking for CSVs at:")
print(map_path)
print(meta_path)
print(surv_path)

# load CSVs
df_map = pd.read_csv(map_path) if os.path.exists(map_path) else pd.DataFrame()
df_meta = pd.read_csv(meta_path) if os.path.exists(meta_path) else pd.DataFrame()
df_surv = pd.read_csv(surv_path) if os.path.exists(surv_path) else pd.DataFrame()

print("Loaded CSV shapes:")
print("name_mapping.csv ->", df_map.shape)
print("meta_data.csv ->", df_meta.shape)
print("survival_info.csv ->", df_surv.shape)

if not df_map.empty:
    display(df_map.head())
if not df_surv.empty:
    display(df_surv.head())

# -----------------------------
#  BUILD PATIENT → SURVIVAL DAYS MAP
# -----------------------------

patient_to_surv = {}

if not df_surv.empty and 'Brats20ID' in df_surv.columns:

    temp = df_surv.copy()

    # Extract numeric patient id (last digits from Brats20ID), pad 3 digits (001, 045, etc.)
    # Example: 'BraTS20_Training_001' -> '001'
    temp["patient_id"] = (
        temp["Brats20ID"]
        .astype(str)
        .str.extract(r"(\d+)$") # Extract digits at the end
        [0] # Get the first (and only) capturing group
        .fillna('') # Replace NaN with empty string to prevent zfill from failing
        .apply(lambda x: x.zfill(3) if x else None) # Pad to 3 digits, or None if empty
    )

    # clean survival days
    if "Survival_days" in temp.columns:
        temp["survival_days"] = (
            temp["Survival_days"]
            .astype(str)
            .str.extract(r"(\d+)")[0]
        )
        temp["survival_days"] = pd.to_numeric(temp["survival_days"], errors="coerce")
    else:
        temp["survival_days"] = np.nan

    # -----------------------------
    #  OUTLIER REMOVAL FOR SURVIVAL DAYS (re-added section)
    # -----------------------------
    print("\n--- Outlier Removal for Survival Days ---")
    initial_rows_surv = temp.shape[0]

    # Filter out NaNs for outlier calculation
    numeric_survival_days = temp['survival_days'].dropna()

    if not numeric_survival_days.empty:
        Q1 = numeric_survival_days.quantile(0.25)
        Q3 = numeric_survival_days.quantile(0.75)
        IQR = Q3 - Q1
        lower_bound = Q1 - 1.5 * IQR
        upper_bound = Q3 + 1.5 * IQR

        # Filter out outliers from temp DataFrame
        # Keep NaNs as they are not outliers in the numeric sense and will be handled by mapping
        temp = temp[(temp['survival_days'] >= lower_bound) | temp['survival_days'].isna()]
        temp = temp[(temp['survival_days'] <= upper_bound) | temp['survival_days'].isna()]

        print(f"Original entries with survival info: {initial_rows_surv}")
        print(f"Q1: {Q1:.2f}, Q3: {Q3:.2f}, IQR: {IQR:.2f}")
        print(f"Lower bound: {lower_bound:.2f}, Upper bound: {upper_bound:.2f}")
        print(f"Entries after outlier removal: {temp.shape[0]}")
    else:
        print("No numeric survival days found to perform outlier removal.")

    # -----------------------------
    #  END OUTLIER REMOVAL SECTION
    # -----------------------------

    # Debugging prints for patient_id and survival_days before mapping
    print("\n--- Debugging patient_id and survival_days in temp before mapping ---")
    print("temp['patient_id'] head:\n", temp['patient_id'].head())
    print("temp['patient_id'] null count:", temp['patient_id'].isnull().sum())
    print("temp['survival_days'] head:\n", temp['survival_days'].head())
    print("temp['survival_days'] null count:", temp['survival_days'].isnull().sum())
    print("Total rows in temp:", temp.shape[0])

    # finalize mapping
    for _, row in temp.iterrows():
        pid = row["patient_id"]
        # Ensure pid is a valid numeric string before using as key
        if pd.notnull(pid):
            patient_to_surv[pid] = (
                row["survival_days"] if not pd.isna(row["survival_days"]) else None
            )

print("Survival Mapping Loaded:", len(patient_to_surv))

# -----------------------------
#  BUILD PATIENT → GRADE MAP  (Tumor Present = Yes/No)
# -----------------------------

patient_to_grade = {}

if not df_map.empty:

    # find ID column
    id_col = None
    for opt in ["BraTS_2020_subject_ID", "Subject_ID", "ID", "brats_id"]:
        if opt in df_map.columns:
            id_col = opt
            break
    if id_col is None:
        id_col = df_map.columns[0]   # fallback

    # find grade / tumor-type column
    grade_col = None
    for opt in ["Grade", "grade", "TumorType", "tumor_type"]:
        if opt in df_map.columns:
            grade_col = opt
            break

    # build mapping
    if id_col is not None and grade_col is not None:
        for _, rr in df_map.iterrows():
            full_pid = str(rr[id_col])
            # Extract 3-digit numeric part from full_pid for consistency
            match = re.search(r"(\d+)$", full_pid)
            if match:
                pid = match.group(1).zfill(3) # Ensure 3-digit like '001'
            else:
                pid = None # Or handle cases where ID is not found / not numeric
            grade = rr[grade_col]

            if pid is not None and pd.notnull(grade):
                patient_to_grade[pid] = str(grade).strip()

print("Tumor Grade Mapping Loaded:", len(patient_to_grade))

Looking for CSVs at:
/content/drive/MyDrive/data/name_mapping.csv
/content/drive/MyDrive/data/meta_data.csv
/content/drive/MyDrive/data/survival_info.csv
Loaded CSV shapes:
name_mapping.csv -> (369, 6)
meta_data.csv -> (57195, 4)
survival_info.csv -> (236, 4)


Unnamed: 0,Grade,BraTS_2017_subject_ID,BraTS_2018_subject_ID,TCGA_TCIA_subject_ID,BraTS_2019_subject_ID,BraTS_2020_subject_ID
0,HGG,Brats17_CBICA_AAB_1,Brats18_CBICA_AAB_1,,BraTS19_CBICA_AAB_1,BraTS20_Training_001
1,HGG,Brats17_CBICA_AAG_1,Brats18_CBICA_AAG_1,,BraTS19_CBICA_AAG_1,BraTS20_Training_002
2,HGG,Brats17_CBICA_AAL_1,Brats18_CBICA_AAL_1,,BraTS19_CBICA_AAL_1,BraTS20_Training_003
3,HGG,Brats17_CBICA_AAP_1,Brats18_CBICA_AAP_1,,BraTS19_CBICA_AAP_1,BraTS20_Training_004
4,HGG,Brats17_CBICA_ABB_1,Brats18_CBICA_ABB_1,,BraTS19_CBICA_ABB_1,BraTS20_Training_005


Unnamed: 0,Brats20ID,Age,Survival_days,Extent_of_Resection
0,BraTS20_Training_001,60.463,289,GTR
1,BraTS20_Training_002,52.263,616,GTR
2,BraTS20_Training_003,54.301,464,GTR
3,BraTS20_Training_004,39.068,788,GTR
4,BraTS20_Training_005,68.493,465,GTR



--- Outlier Removal for Survival Days ---
Original entries with survival info: 236
Q1: 190.00, Q3: 579.25, IQR: 389.25
Lower bound: -393.88, Upper bound: 1163.12
Entries after outlier removal: 221

--- Debugging patient_id and survival_days in temp before mapping ---
temp['patient_id'] head:
 0    001
1    002
2    003
3    004
4    005
Name: patient_id, dtype: object
temp['patient_id'] null count: 0
temp['survival_days'] head:
 0    289
1    616
2    464
3    788
4    465
Name: survival_days, dtype: int64
temp['survival_days'] null count: 0
Total rows in temp: 221
Survival Mapping Loaded: 221
Tumor Grade Mapping Loaded: 369


In [3]:
# Cell 3 - helper functions

def get_volume_from_h5(path):
    """
    Attempts to read an h5 file and return (volume, mask).
    volume -> numpy array shape (S, H, W)
    mask -> same shape or None if not available
    """
    with h5py.File(path, 'r') as f:
        # Heuristics: common keys in BRaTS .h5: 'image', 'vol', 'data'; mask: 'mask', 'segmentation'
        vol = None
        msk = None
        for key in f.keys():
            k_low = key.lower()
            if k_low in ('image', 'images', 'vol', 'volume', 'data', 't1', 't1ce', 'flair'):
                try:
                    arr = np.array(f[key])
                    if arr.ndim == 3:
                        vol = arr
                        break
                except:
                    pass
        # If not found yet, pick first 3D dataset
        if vol is None:
            for key in f.keys():
                try:
                    arr = np.array(f[key])
                    if arr.ndim == 3:
                        vol = arr
                        break
                except:
                    pass

        # find mask-like dataset
        for key in f.keys():
            k_low = key.lower()
            if k_low in ('mask', 'seg', 'segmentation', 'label'):
                try:
                    arr = np.array(f[key])
                    if arr.ndim == 3:
                        msk = arr
                        break
                except:
                    pass

    # Standardize orientation: assume vol is (S,H,W) or (H,W,S)
    if vol is None:
        raise ValueError(f"No 3D dataset found in {path}")
    vol = vol.astype(np.float32)
    # If the last dim is smallest, could be (H,W,S)
    if vol.shape[2] < vol.shape[0] and vol.shape[2] < vol.shape[1]:
        vol = np.transpose(vol, (2,0,1))
    if msk is not None:
        msk = msk.astype(np.uint8)
        if msk.shape != vol.shape:
            # try transpose
            if msk.shape[2] < msk.shape[0] and msk.shape[2] < msk.shape[1]:
                msk = np.transpose(msk, (2,0,1))
    return vol, msk

def normalize_slice(slice2d):
    mn = slice2d.min()
    mx = slice2d.max()
    if mx - mn < 1e-8:
        return np.zeros_like(slice2d, dtype=np.float32)
    out = (slice2d - mn) / (mx - mn)
    return out.astype(np.float32)

def file_patient_id_from_name(fname):
    """Try to extract a 3-digit patient id from filename or Brats20ID portion."""
    bn = os.path.basename(fname)
    # Try to match 'volume_XXX_slice_YYY' format
    match = re.search(r"volume_(\d+)_", bn)
    if match:
        return match.group(1).zfill(3) # Extract number and pad to 3 digits
    # If not matched, try to match 'Brats20_Training_XXX' or similar if present in filename
    match = re.search(r"(\d+)$", os.path.splitext(bn)[0])
    if match:
        return match.group(1).zfill(3) # Extract number and pad to 3 digits
    # Fallback if no numeric ID found
    return os.path.splitext(bn)[0]

In [4]:

# Cell 4 - discover .h5 files and prepare sampling blocks
all_h5 = sorted([os.path.join(DRIVE_DATA_PATH, f) for f in os.listdir(DRIVE_DATA_PATH) if f.endswith('.h5')])
print("Total .h5 files in data folder:", len(all_h5))
if len(all_h5) == 0:
    raise SystemExit("No .h5 files found in DRIVE_DATA_PATH. Check the path or upload files.")

# We'll partition the file list into blocks of 50 (1-50, 51-100...) and from each block sample up to N_VOLUMES_PER_BLOCK volumes
BLOCK_SIZE = 50
N_VOLUMES_PER_BLOCK = 10   # from each block of 50 volumes, pick 10 volumes randomly
MAX_BLOCKS = math.ceil(len(all_h5) / BLOCK_SIZE)
print(f"Blocks: {MAX_BLOCKS}, BLOCK_SIZE={BLOCK_SIZE}, pick {N_VOLUMES_PER_BLOCK} per block")

selected_volume_paths = []
for b in range(MAX_BLOCKS):
    start = b * BLOCK_SIZE
    end = min((b + 1) * BLOCK_SIZE, len(all_h5))
    block_files = all_h5[start:end]
    pick = min(N_VOLUMES_PER_BLOCK, len(block_files))
    sampled = random.sample(block_files, pick)
    selected_volume_paths.extend(sampled)

print("Selected volumes count:", len(selected_volume_paths))
# show example
selected_volume_paths[:5]


Total .h5 files in data folder: 32422
Blocks: 649, BLOCK_SIZE=50, pick 10 per block
Selected volumes count: 6490


['/content/drive/MyDrive/data/volume_100_slice_134.h5',
 '/content/drive/MyDrive/data/volume_100_slice_104.h5',
 '/content/drive/MyDrive/data/volume_100_slice_1.h5',
 '/content/drive/MyDrive/data/volume_100_slice_113.h5',
 '/content/drive/MyDrive/data/volume_100_slice_111.h5']

In [None]:
# Cell 5 - create dataset (small)
# For each selected volume, pick up to N_SLICES_PER_VOL random slices to represent the volume
N_SLICES_PER_VOL = 10            # number of slices to sample per selected volume
TARGET_HW = (128, 128)           # resize slices to this spatial size
CHANNELS = 3                     # stack to RGB-like for CNN

from tensorflow.keras.preprocessing.image import img_to_array, array_to_img
import cv2

def prepare_volume_sample(h5_path, n_slices=N_SLICES_PER_VOL, target_hw=TARGET_HW):
    vol, mask = get_volume_from_h5(h5_path)    # vol shape (S,H,W)
    S = vol.shape[0]
    # choose slice indices: if S < n_slices then sample with replacement or pad later
    if S <= n_slices:
        idxs = list(range(S))
    else:
        idxs = random.sample(list(range(S)), n_slices)

    slices = []
    mask_presence = 0
    for i in idxs:
        sl = vol[i]
        sln = normalize_slice(sl)
        # resize to target_hw
        slr = cv2.resize(sln, (target_hw[1], target_hw[0]), interpolation=cv2.INTER_AREA)
        # convert single channel to CHANNELS (RGB style)
        if CHANNELS == 3:
            sl_rgb = np.stack([slr, slr, slr], axis=-1)
        else:
            sl_rgb = slr[...,None]
        slices.append(sl_rgb.astype(np.float32))
        # check mask presence for that slice
        if mask is not None:
            if mask.shape[0] == vol.shape[0]:
                if np.sum(mask[i]) > 50:
                    mask_presence += 1
            else:
                # fallback: if shapes mismatched, try to check any nonzero in whole mask
                if np.sum(mask) > 0:
                    mask_presence = 1

    slices = np.stack(slices, axis=0)  # (n_slices, H, W, C)
    # decide binary label: tumor present if any selected slice has mask pixels OR if patient mapping says HGG
    binary_label = 1 if mask_presence > 0 else 0

    # try also to check mapping from file name
    pid = file_patient_id_from_name(h5_path)
    if pid in patient_to_grade:
        g = patient_to_grade[pid]
        # treat HGG as tumor present
        if g.upper().find('HGG') >= 0:
            binary_label = 1
        elif g.upper().find('LGG') >= 0:
            binary_label = binary_label or 0

    # survival days from patient_to_surv mapping if available (else NaN)
    surv = patient_to_surv.get(pid, np.nan)

    return slices, int(binary_label), surv, pid

# Build arrays (volume-level)
vol_embeddings_slices = []   # list of (n_slices,H,W,C) arrays per volume
labels_binary = []
labels_surv = []
patient_ids = []
paths_used = []

print("Preparing samples from selected volumes ...")
for pth in tqdm(selected_volume_paths):
    try:
        sls, lbl, surv, pid = prepare_volume_sample(pth)
        vol_embeddings_slices.append(sls)   # keep slices - we will pass through CNN later
        labels_binary.append(lbl)
        labels_surv.append(surv if not pd.isna(surv) else np.nan)
        patient_ids.append(pid)
        paths_used.append(pth)
    except Exception as e:
        print("Error reading", pth, e)

print("Prepared volumes:", len(vol_embeddings_slices))


Preparing samples from selected volumes ...


 31%|███       | 2008/6490 [10:22<01:56, 38.44it/s]

In [None]:
# Cell 6 - slice-level CNN encoder -> produces embedding per slice; then average across slices for volume embedding
from tensorflow.keras import Input, Model

EMBED_DIM = 256

def build_slice_encoder(input_shape=(TARGET_HW[0], TARGET_HW[1], CHANNELS), embedding_dim=EMBED_DIM):
    inp = Input(shape=input_shape)
    x = layers.Conv2D(32, 3, activation='relu', padding='same')(inp)
    x = layers.MaxPool2D()(x)
    x = layers.Conv2D(64, 3, activation='relu', padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.Conv2D(128, 3, activation='relu', padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(embedding_dim, activation='relu')(x)
    model = Model(inp, x, name='slice_encoder')
    return model

slice_encoder = build_slice_encoder()
slice_encoder.summary()

# Compute embeddings for each slice and average per volume
volume_embeddings = []  # shape (n_volumes, EMBED_DIM)
valid_indices = []      # keep indices where embedding computed successfully
print("Computing per-volume embeddings (average of slice embeddings)...")
for i, sls in enumerate(tqdm(vol_embeddings_slices)):
    try:
        # sls shape (n_slices, H, W, C)
        emb_slices = slice_encoder.predict(sls, verbose=0)   # (n_slices, EMBED_DIM)
        # average pooling across slices
        emb_vol = np.mean(emb_slices, axis=0)
        volume_embeddings.append(emb_vol)
        valid_indices.append(i)
    except Exception as e:
        print("Failed embedding for volume", i, "error:", e)

volume_embeddings = np.stack(volume_embeddings, axis=0).astype(np.float32)
labels_binary = np.array(labels_binary)[valid_indices].astype(int)
labels_surv = np.array(labels_surv)[valid_indices].astype(np.float32)
patient_ids = np.array(patient_ids)[valid_indices]
paths_used = np.array(paths_used)[valid_indices]

print("Final dataset volumes (after embedding):", volume_embeddings.shape)
print("Labels (binary) shape:", labels_binary.shape)
print("Survival labels shape:", labels_surv.shape)


In [None]:
# Cell 7 - train/test split at volume-level
# We'll keep a regression target for survival if available; for missing survival values, we'll exclude from regression training
X = volume_embeddings
y_class = labels_binary
y_surv_raw = labels_surv  # may contain NaN

# Train/test split (stratify by class to keep balance)
X_tr, X_te, ytr_class, yte_class, ytr_surv, yte_surv, tr_ids, te_ids, tr_paths, te_paths = train_test_split(
    X, y_class, y_surv_raw, patient_ids, paths_used, test_size=0.2, random_state=SEED, stratify=y_class
)

print("Train volumes:", X_tr.shape[0], "Test volumes:", X_te.shape[0])


In [None]:
# Cell 8 - multi-head model on top of embeddings
IN_DIM = X_tr.shape[1]

inp = layers.Input(shape=(IN_DIM,), name='embedding_input')
x = layers.Dense(128, activation='relu')(inp)
x = layers.Dropout(0.3)(x)
x = layers.Dense(64, activation='relu')(x)

# classification head
c = layers.Dense(32, activation='relu')(x)
c = layers.Dropout(0.2)(c)
out_class = layers.Dense(1, activation='sigmoid', name='tumor_present')(c)

# regression head (predict log1p of days)
r = layers.Dense(32, activation='relu')(x)
r = layers.Dropout(0.2)(r)
out_reg = layers.Dense(1, activation='linear', name='survival_log')(r)

multi_head = models.Model(inputs=inp, outputs=[out_class, out_reg], name='multi_head_model')
multi_head.compile(
    optimizer=optimizers.Adam(1e-3),
    loss={
        'tumor_present': losses.BinaryCrossentropy(),
        'survival_log': losses.MeanSquaredError()
    },
    loss_weights={'tumor_present': 1.0, 'survival_log': 0.5},  # downweight regression if noisy
    metrics={'tumor_present': ['accuracy'], 'survival_log': ['mae']}
)
multi_head.summary()


In [None]:
# ============================================================
# CELL 9 — TRAINING (with Checkpoints + EarlyStopping)
# ============================================================

# 1. Create a folder in Drive to store checkpoints
CHECKPOINT_DIR = "/content/drive/MyDrive/model_checkpoints_brats"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

checkpoint_path = os.path.join(CHECKPOINT_DIR, "best_model.keras")

# 2. Callbacks
early_stopping_callback = callbacks.EarlyStopping(
    monitor='val_tumor_present_accuracy', # Corrected output name
    patience=5,
    mode='max',
    restore_best_weights=True
)

model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_tumor_present_accuracy', # Corrected output name
    mode='max',
    save_best_only=True,
    save_weights_only=False,
    verbose=1
)

# 3. Prepare targets for training
# survival regression has NaN values → mask them out
train_mask = ~np.isnan(ytr_surv)
test_mask  = ~np.isnan(yte_surv)

ytr_surv_clean = np.nan_to_num(ytr_surv, nan=0)
yte_surv_clean = np.nan_to_num(yte_surv, nan=0)

# 4. Compile the model
multi_head.compile(
    optimizer='adam',
    loss={
        'tumor_present': 'binary_crossentropy', # Corrected output name
        'survival_log': 'mse' # Corrected output name
    },
    loss_weights={
        'tumor_present': 1.0, # Corrected output name
        'survival_log': 0.3 # Corrected output name
    },
    metrics={
        'tumor_present': ['accuracy'], # Corrected output name
        'survival_log': ['mae'] # Corrected output name
    }
)

# 5. Train the model
history = multi_head.fit(
    X_tr,
    {
        'tumor_present': ytr_class, # Corrected output name
        'survival_log': ytr_surv_clean # Corrected output name
    },
    validation_data=(
        X_te,
        {
            'tumor_present': yte_class, # Corrected output name
            'survival_log': yte_surv_clean # Corrected output name
        }
    ),
    epochs=50,
    batch_size=16,
    callbacks=[early_stopping_callback, model_checkpoint_callback],
    verbose=1
)

print("Training Completed!")
print(f"Best model saved at: {checkpoint_path}")

In [None]:
#cell new1
from tensorflow.keras.layers import Input, Dense, Dropout, Lambda
from tensorflow.keras.models import Model

# Input embedding (256-D vector from slice encoder)
inp = Input(shape=(256,))

# Shared dense layers
x = Dense(128, activation='relu')(inp)
x = Dropout(0.2)(x)
x = Dense(64, activation='relu')(x)

# Head 1 – Tumor Classification
class_output = Dense(1, activation='sigmoid', name='tumor_output')(x)

# Head 2 – Survival Regression
reg_output = Dense(1, activation='linear', name='survival_output')(x)

# ----------------------------------------------------------
# NEW HEAD 3 – Consistency loss: tumor_prob * survival_days
# Forces: high tumor prob → low survival days
# ----------------------------------------------------------
consistency_output = Lambda(
    lambda z: z[0] * z[1],
    name='consistency_output'
)([class_output, reg_output])

# Final Model
multi_head = Model(
    inputs=inp,
    outputs=[class_output, reg_output, consistency_output]
)

multi_head.summary()

In [None]:
#new cell 2
multi_head.compile(
    optimizer='adam',
    loss={
        'tumor_output': 'binary_crossentropy',
        'survival_output': 'mse',
        'consistency_output': 'mse'
    },
    loss_weights={
        'tumor_output': 1.0,
        'survival_output': 1.0,
        'consistency_output': 0.3   # small weight but enforces correlation
    },
    metrics={
        'tumor_output': ['accuracy'],
        'survival_output': ['mae']
    }
)


In [None]:
# Cell 9 - prepare targets for training
# For regression training we will only include volumes with known survival days (non-NaN)
# create masks
def safe_log1p(x):
    return np.log1p(x)

# create y_surv_train_log only for available values; we'll feed full arrays but set sample_weight=0 for missing rows in regression
ytr_surv_arr = np.array([v if not np.isnan(v) else 0.0 for v in ytr_surv], dtype=np.float32)
yte_surv_arr = np.array([v if not np.isnan(v) else 0.0 for v in yte_surv], dtype=np.float32)
ytr_surv_log = safe_log1p(ytr_surv_arr)
yte_surv_log = safe_log1p(yte_surv_arr)

# sample weights for regression: 1 if survival known else 0 (so loss doesn't include missing)
wtr_reg = np.array([0.0 if np.isnan(v) else 1.0 for v in ytr_surv], dtype=np.float32).reshape(-1, 1)
wte_reg = np.array([0.0 if np.isnan(v) else 1.0 for v in yte_surv], dtype=np.float32).reshape(-1, 1)

# For classification head weights we keep ones
wtr_class = np.ones_like(wtr_reg, dtype=np.float32)
wte_class = np.ones_like(wte_reg, dtype=np.float32)

# For consistency head weights we keep ones (always apply this regularization)
wtr_consistency = np.ones_like(wtr_class, dtype=np.float32)
wte_consistency = np.ones_like(wte_class, dtype=np.float32)

# Combine sample weights into a dict for fit
sample_weights_train = {
    'tumor_output': wtr_class,
    'survival_output': wtr_reg,
    'consistency_output': wtr_consistency
}
sample_weights_val = {
    'tumor_output': wte_class,
    'survival_output': wte_reg,
    'consistency_output': wte_consistency
}

# Prepare y dicts for fit
y_train_dict = {
    'tumor_output': ytr_class.reshape(-1,1),
    'survival_output': ytr_surv_log.reshape(-1,1),
    'consistency_output': np.zeros_like(ytr_class.reshape(-1,1), dtype=np.float32) # Target for consistency: typically 0 to minimize product
}
y_val_dict   = {
    'tumor_output': yte_class.reshape(-1,1),
    'survival_output': yte_surv_log.reshape(-1,1),
    'consistency_output': np.zeros_like(yte_class.reshape(-1,1), dtype=np.float32)
}

print("Training samples (reg known):", int(wtr_reg.sum()), "Validation samples (reg known):", int(wte_reg.sum()))

In [None]:
EPOCHS = 50
BATCH = 8

# --- Checkpoint directory ---
checkpoint_dir = os.path.join(DRIVE_DATA_PATH, 'model_checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_filepath = os.path.join(checkpoint_dir, 'multi_head_best.h5')

# --- Callback ---
model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    verbose=1
)

# --- TRAINING: EVERYTHING MUST BE A DICT ---
history = multi_head.fit(
    X_tr,
    {
        'tumor_output': y_train_tumor,
        'survival_output': y_train_survival,
        'consistency_output': np.zeros(len(y_train_tumor))
    },
    validation_data=(
        X_te,
        {
            'tumor_output': y_val_tumor,
            'survival_output': y_val_survival,
            'consistency_output': np.zeros(len(y_val_tumor))
        }
    ),
    sample_weight={
        'tumor_output': np.ones(len(y_train_tumor)),
        'survival_output': np.ones(len(y_train_survival)),
        'consistency_output': np.ones(len(y_train_tumor))
    },
    epochs=EPOCHS,
    batch_size=BATCH,
    verbose=2,
    callbacks=[model_checkpoint_callback]
)

print("Checkpoint saved at:", checkpoint_filepath)


In [None]:
#new cell3  Add third output → consistency target = zeros
y_train_dict = {
    'tumor_output': y_train_tumor,
    'survival_output': y_train_survival,
    'consistency_output': np.zeros(len(y_train_tumor))
}

y_val_dict = {
    'tumor_output': y_val_tumor,
    'survival_output': y_val_survival,
    'consistency_output': np.zeros(len(y_val_tumor))
}

# TRAINING
history = multi_head.fit(
    X_train,
    y_train_dict,
    validation_data=(X_val, y_val_dict),
    epochs=40,
    batch_size=8
)


In [None]:
# Cell 11 - evaluation
preds = multi_head.predict(X_te, verbose=0)
pred_class_probs = preds[0].reshape(-1)
pred_reg_log = preds[1].reshape(-1)

pred_class_labels = (pred_class_probs >= 0.5).astype(int)
true_class_labels = yte_class.astype(int)

# Classification metrics
acc = accuracy_score(true_class_labels, pred_class_labels)
print("Test classification accuracy:", acc)
print("\nClassification report:\n", classification_report(true_class_labels, pred_class_labels))

cm = confusion_matrix(true_class_labels, pred_class_labels)
plt.figure(figsize=(5,4))
sns.heatmap(cm, annot=True, fmt="d", cmap='Blues')
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix - Tumor Presence")
plt.show()

# Regression metrics: only on entries where ground truth survival exists (wte_reg == 1)
mask_reg = (wte_reg == 1).flatten() # Flatten the mask to be 1D
if mask_reg.sum() > 0:
    true_reg_log = yte_surv_log[mask_reg]
    pred_reg_log_masked = pred_reg_log[mask_reg]
    # convert back to days
    true_days = np.expm1(true_reg_log)
    pred_days = np.expm1(pred_reg_log_masked)
    mae_days = mean_absolute_error(true_days, pred_days)
    print(f"Regression MAE (days) on {mask_reg.sum()} samples: {mae_days:.2f}")
    # print some examples
    print("\nExamples: true_days vs pred_days")
    for t, p, pid in list(zip(true_days, pred_days, te_ids[mask_reg]))[:10]:
        print(f"pid {pid} \u2192 true {t:.1f}, pred {p:.1f}")
else:
    print("No survival ground-truth available in test to evaluate regression.")

In [None]:
# Cell 12 - optional save (uncomment to save)
# save the slice encoder and multi-head model (SavedModel format) if desired
# slice_encoder.save("/content/slice_encoder_saved")
# multi_head.save("/content/multi_head_saved")
print("Models can be saved by uncommenting the save lines in this cell.")


In [None]:
# Cell 13 - prediction function for input .h5 file (user input)
def predict_from_h5(h5_path, n_slices=N_SLICES_PER_VOL, target_hw=TARGET_HW):
    if not os.path.exists(h5_path):
        print("File not found:", h5_path); return None
    vol, mask = get_volume_from_h5(h5_path)
    S = vol.shape[0]
    # choose slices evenly if many slices, else random
    if S <= n_slices:
        idxs = list(range(S))
    else:
        # sample evenly spaced indices for reproducibility
        idxs = np.linspace(0, S-1, n_slices, dtype=int).tolist()

    slices = []
    for i in idxs:
        sl = normalize_slice(vol[i])
        slr = cv2.resize(sl, (target_hw[1], target_hw[0]), interpolation=cv2.INTER_AREA)
        if CHANNELS == 3:
            sl_rgb = np.stack([slr, slr, slr], axis=-1)
        else:
            sl_rgb = slr[...,None]
        slices.append(sl_rgb.astype(np.float32))
    slices = np.stack(slices, axis=0)

    # compute slice embeddings
    emb_slices = slice_encoder.predict(slices, verbose=0)
    emb_vol = np.mean(emb_slices, axis=0).reshape(1, -1)

    # multi-head prediction
    pred_class_prob, pred_reg_log = multi_head.predict(emb_vol, verbose=0)
    tumor_prob = float(pred_class_prob[0][0])
    tumor_label = int(tumor_prob >= 0.5)
    pred_days = float(np.expm1(pred_reg_log[0][0]))

    print(f"Predicted tumor presence (prob): {tumor_prob:.3f} → label {tumor_label}")
    print(f"Predicted survival days (expm1 of log): {pred_days:.1f}")
    return {'tumor_prob': tumor_prob, 'tumor_label': tumor_label, 'survival_days': pred_days, 'patient_id': file_patient_id_from_name(h5_path)}

# Example usage (uncomment and change path)
# res = predict_from_h5("/content/drive/MyDrive/data/BraTS20_Training_001.h5")
# print(res)


In [None]:
# Cell 14 - interactive input for .h5 path
h5_input = input("Enter path to a single .h5 file (e.g. /content/drive/MyDrive/data/BraTS20_Training_001.h5): ").strip()
if h5_input:
    predict_from_h5(h5_input)
else:
    print("No path entered.")
