## Phase-3: Baseline

Step-0a: Data Ingestion and Initial Preprocessing

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Go to your project folder
%cd /content/drive/MyDrive/multimodal_mammography


Mounted at /content/drive
/content/drive/MyDrive/multimodal_mammography


In [None]:
import importlib.util

def load_module_from_path(name, path):
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


In [None]:
# Load environment setup
env = load_module_from_path("env", "setup/environment.py")
install = load_module_from_path("install", "setup/install_colab.py")
_ = load_module_from_path("imports", "setup/imports.py")  # No functions to call

# Run setup
install.install_dependencies()
env.suppress_warnings()
env.set_seed(42)
device = env.get_device()


🔄 Detected Google Colab environment.
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Google Drive mounted.
📦 Installing required packages...
✅ Dependencies installed.
🔁 Seed set to 42
 Using device: cuda


Step-0b: Loading Required csvs' and extracting/exploring images


In [None]:
# ✅ Load the dynamic module
data_loader = load_module_from_path("data_loader", "data/load_data.py")

# ✅ Correct CSV paths
metadata_path    = "/content/drive/MyDrive/multimodal_mammography/dataset/csv/metadata.csv"
breast_anno_path = "/content/drive/MyDrive/multimodal_mammography/dataset/csv/breast-level_annotations.csv"
finding_anno_path = "/content/drive/MyDrive/multimodal_mammography/dataset/csv/finding_annotations.csv"

# ✅ Load and view data
metadata_df, breast_df, finding_df = data_loader.load_mammo_data(
    metadata_path,
    breast_anno_path,
    finding_anno_path,
    verbose=False
)


In [None]:
import pandas as pd

image_df=pd.read_csv("/content/drive/MyDrive/multimodal_mammography/dataset/csv/image_df_upsampled_studywise.csv")

In [None]:
print(image_df.columns)

Index(['image_id', 'study_id', 'filename', 'birads', 'birads_dir', 'density',
       'laterality', 'view_position', 'split', 'finding_categories',
       'finding_birads_clean', 'xmin', 'ymin', 'xmax', 'ymax', 'has_bbox',
       'age', 'birads_binary', 'birads_cleaned', 'birads_study_level',
       'finding_mass', 'finding_suspicious_calcification',
       'finding_focal_asymmetry', 'finding_asymmetry',
       'finding_global_asymmetry', 'finding_architectural_distortion',
       'finding_skin_thickening', 'finding_skin_retraction',
       'finding_nipple_retraction', 'finding_suspicious_lymph_node',
       'finding_no_finding', 'image_path', 'case_category', 'upsampled'],
      dtype='object')


In [None]:
import zipfile
import os

# Path to your zip file
zip_path = "/content/drive/MyDrive/multimodal_mammography/dataset/zipped_folder/birads_preprocessed_dataset.zip"

# Destination folder to extract files
extract_dir = "/content/birads_preprocessed_dataset"

# Make sure the directory exists
os.makedirs(extract_dir, exist_ok=True)

# Unzip the dataset
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

print("Extraction complete.")
print("Extracted to:", extract_dir)


Extraction complete.
Extracted to: /content/birads_preprocessed_dataset


In [None]:
import os

# List a few extracted files/folders
for root, dirs, files in os.walk(extract_dir):
    print("Root:", root)
    print("Subdirs:", dirs[:5])   # show first 5 dirs
    print("Files:", files[:5])   # show first 5 files
    break


Root: /content/birads_preprocessed_dataset
Subdirs: ['training', 'test']
Files: ['image_df_upsampled_preprocessed.csv']


Step-0c: Validating existing dataset


In [None]:
import os
from collections import defaultdict

base_dir = "/content/birads_preprocessed_dataset"

def find_small_studies(base_dir, min_images=4):
    small_studies = defaultdict(list)

    for split in ["training", "test"]:
        for case in ["normal", "abnormal"]:
            case_path = os.path.join(base_dir, split, case)
            if not os.path.exists(case_path):
                continue

            for study in os.listdir(case_path):
                study_path = os.path.join(case_path, study)
                if not os.path.isdir(study_path):
                    continue

                imgs = [f for f in os.listdir(study_path) if f.endswith(".png")]
                if len(imgs) < min_images:
                    small_studies[(split, case, study)] = imgs

    return small_studies

small_studies = find_small_studies(base_dir)

if small_studies:
    print("⚠️ Studies with fewer than 4 images:")
    for (split, case, study), imgs in small_studies.items():
        print(f"- {split}/{case}/{study} -> {len(imgs)} images: {imgs}")
else:
    print("✅ All studies have at least 4 images.")


✅ All studies have at least 4 images.


In [None]:
import os
from collections import Counter

base_dir = "/content/birads_preprocessed_dataset"
splits = ["training", "test"]
classes = ["normal", "abnormal"]

# Dictionary to store study -> image count
study_image_counts = {}

for split in splits:
    split_path = os.path.join(base_dir, split)
    for cls in classes:
        cls_path = os.path.join(split_path, cls)
        if not os.path.exists(cls_path):
            continue
        for study in os.listdir(cls_path):
            study_path = os.path.join(cls_path, study)
            if os.path.isdir(study_path):
                images = [f for f in os.listdir(study_path) if f.endswith(".png")]
                study_image_counts[study] = len(images)

# Summarize the distribution of images per study
count_distribution = Counter(study_image_counts.values())
print("Image count per study distribution:")
for n_images, n_studies in sorted(count_distribution.items()):
    print(f"{n_images} images: {n_studies} studies")

# Optional: total studies
print(f"\nTotal studies counted: {len(study_image_counts)}")


Image count per study distribution:
4 images: 7999 studies

Total studies counted: 7999


In [None]:
import os
import pandas as pd
from tqdm import tqdm  # import tqdm

# Paths
base_dir = "/content/birads_preprocessed_dataset"
original_csv = os.path.join(base_dir, "image_df_upsampled_preprocessed.csv")
fixed_csv = os.path.join(base_dir, "image_df_preprocessed_fixed.csv")

# Load original CSV
df_orig = pd.read_csv(original_csv)

# Ensure string types for safe matching
df_orig["study_id"] = df_orig["study_id"].astype(str)
df_orig["filename"] = df_orig["filename"].astype(str)

# Prepare list for final rows
rows = []

splits = ["training", "test"]
classes = ["normal", "abnormal"]

for split in splits:
    split_path = os.path.join(base_dir, split)
    if not os.path.exists(split_path):
        continue

    for cls in classes:
        cls_path = os.path.join(split_path, cls)
        if not os.path.exists(cls_path):
            continue

        study_list = [s for s in os.listdir(cls_path) if os.path.isdir(os.path.join(cls_path, s))]
        for study in tqdm(study_list, desc=f"{split}/{cls} studies"):
            study_path = os.path.join(cls_path, study)

            images = sorted([f for f in os.listdir(study_path) if f.endswith(".png")])
            if len(images) != 4:
                continue  # only keep studies with exactly 4 images

            for img in images:
                # Try to get metadata from original CSV
                match = df_orig[(df_orig["study_id"] == study) &
                                (df_orig["filename"] == img)]
                if not match.empty:
                    row = match.iloc[0].copy()
                    row["image_path"] = os.path.join(study_path, img)  # update path
                else:
                    # If missing in original CSV, create minimal row with placeholders
                    row = {col: -1 for col in df_orig.columns}  # -1 as placeholder
                    row["study_id"] = study
                    row["filename"] = img
                    row["image_path"] = os.path.join(study_path, img)
                    row["split"] = split
                    row["case_category"] = cls

                rows.append(row)

# Build DataFrame
df_fixed = pd.DataFrame(rows)

# Save CSV
df_fixed.to_csv(fixed_csv, index=False)

# Summary
print(f"✅ Fixed CSV saved at: {fixed_csv}")
print(f"Total studies included: {df_fixed['study_id'].nunique()}")
print(f"Total images included: {len(df_fixed)}")


training/normal studies: 100%|██████████| 5065/5065 [02:01<00:00, 41.81it/s]
training/abnormal studies: 100%|██████████| 1934/1934 [00:45<00:00, 42.11it/s]
test/normal studies: 100%|██████████| 916/916 [00:21<00:00, 43.33it/s]
test/abnormal studies: 100%|██████████| 84/84 [00:02<00:00, 40.34it/s]


✅ Fixed CSV saved at: /content/birads_preprocessed_dataset/image_df_preprocessed_fixed.csv
Total studies included: 7999
Total images included: 31996


In [None]:
print(df_fixed.columns)

Index(['image_id', 'study_id', 'filename', 'birads', 'birads_dir', 'density',
       'laterality', 'view_position', 'split', 'finding_categories',
       'finding_birads_clean', 'xmin', 'ymin', 'xmax', 'ymax', 'has_bbox',
       'age', 'birads_binary', 'birads_cleaned', 'birads_study_level',
       'finding_mass', 'finding_suspicious_calcification',
       'finding_focal_asymmetry', 'finding_asymmetry',
       'finding_global_asymmetry', 'finding_architectural_distortion',
       'finding_skin_thickening', 'finding_skin_retraction',
       'finding_nipple_retraction', 'finding_suspicious_lymph_node',
       'finding_no_finding', 'image_path', 'case_category', 'upsampled',
       'preprocessed_path'],
      dtype='object')


In [None]:
import os
import pandas as pd

# Paths
base_dir = "/content/birads_preprocessed_dataset"
csv_path = os.path.join(base_dir, "image_df_preprocessed_fixed.csv")

# Load CSV metadata
df = pd.read_csv(csv_path)

# Build sets for quick lookup
expected_study_ids = set(df["study_id"].astype(str).unique())
expected_image_ids = set(df["image_id"].astype(str).unique())

issues = []

# Iterate over splits and classes
for split in ["training", "test"]:
    for cls in ["normal", "abnormal"]:
        cls_path = os.path.join(base_dir, split, cls)
        if not os.path.exists(cls_path):
            issues.append(f"Missing folder: {cls_path}")
            continue

        # Iterate over studies
        for study in os.listdir(cls_path):
            study_path = os.path.join(cls_path, study)
            if not os.path.isdir(study_path):
                continue

            # Validate study ID
            if study not in expected_study_ids:
                issues.append(f"Study folder '{study}' not found in CSV")

            # Validate image files
            for img in os.listdir(study_path):
                if img.endswith(".png"):
                    img_id = os.path.splitext(img)[0]  # remove extension
                    if img_id not in expected_image_ids:
                        issues.append(f"Image '{img}' in '{study_path}' not found in CSV")

# Summary
if not issues:
    print("✅ Dataset structure matches CSV metadata and is valid.")
else:
    print("⚠️ Issues found:")
    for issue in issues:
        print("-", issue)


✅ Dataset structure matches CSV metadata and is valid.


In [None]:
import pandas as pd
import os

# Paths
base_dir = "/content/birads_preprocessed_dataset"
csv_path = os.path.join(base_dir, "image_df_preprocessed_fixed.csv")

# Load CSV
df = pd.read_csv(csv_path)

# Iterate over all columns and print unique values and counts
for col in df.columns:
    unique_vals = df[col].nunique()
    print(f"{col}: {unique_vals} unique values")
    # Optionally, show top 10 most frequent values
    print(df[col].value_counts().head(10))
    print("-" * 50)


image_id: 19996 unique values
image_id
7dbf6830cc06730cfe74cd58937f89a8    24
2973bcf878fad1e9edade25be62602ce    24
6266ffa44d75d2edc9d3c725b20b6d49    24
2bd9c72b886e97da1aff1361962c6acc    24
3cd51ee99070c4d625d52b848d5e9bfc    22
10e0f362333df810ac84a9db8fb3fd42    22
85a6579cbdc403cfc4dde0a8149ed855    22
a7acc2e02a4944c4fc72e32507b17fa7    22
136a7d195b654c4bf862fdd076c77574    21
360a2637d08a1b3f8fef4f3a3e14e717    21
Name: count, dtype: int64
--------------------------------------------------
study_id: 7999 unique values
study_id
2df4404b9c286ec49851d4261faec924            4
06638221ad71bf397e1109b67e425597            4
0b666e8ea8e2a0772fcd18b528553565            4
bbe408b621d15ba6fc1de8ee82dbb920            4
b2b49d880f06ca420bae27d840082b2b            4
094f43a5547f064caa95e2f160609904            4
736680ada4d51ca3d582a2453a3060ed            4
44c63ecd23616c4fce7e817470401d5d_dup1857    4
8a12da7b03eed18bcae0bf177a71b560            4
1b0eacc4e131185f50c3d366cda14307          

Step-1: Feature Engineering

In [None]:
import pandas as pd
import os
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

# Paths
base_dir = "/content/birads_preprocessed_dataset"
csv_path = os.path.join(base_dir, "image_df_preprocessed_fixed.csv")
output_csv_path = os.path.join(base_dir, "study_level_metadata.csv")

# Load CSV
df = pd.read_csv(csv_path)

# Keep only study-level rows (drop image duplicates)
df_study = df.drop_duplicates(subset=["study_id"]).reset_index(drop=True)

# Target variable
df_study["birads_binary_num"] = df_study["birads_binary"].map({"normal": 0, "abnormal": 1})

# Features: age + density only
num_features = ["age"]
cat_features = ["density"]

X = df_study[num_features + cat_features]

# Preprocessing: scale numeric + one-hot encode categorical (keep all categories)
preprocessor = ColumnTransformer(
    transformers=[
        ("num", StandardScaler(), num_features),
        ("cat", OneHotEncoder(drop=None, sparse_output=False), cat_features),
    ]
)

# Apply preprocessing
X_processed = preprocessor.fit_transform(X)

# Get one-hot density column names
density_cols = preprocessor.named_transformers_["cat"].get_feature_names_out(cat_features)

# Build final DataFrame
import numpy as np
df_final = pd.DataFrame(
    np.hstack([X_processed, df_study["birads_binary_num"].values.reshape(-1, 1)]),
    columns=list(num_features) + list(density_cols) + ["birads_binary"]
)

# Add study_id for reference
df_final["study_id"] = df_study["study_id"].values

# Reorder columns: study_id first
df_final = df_final[["study_id"] + list(num_features) + list(density_cols) + ["birads_binary"]]

# Save to CSV
df_final.to_csv(output_csv_path, index=False)

print("Final study-level CSV saved to:", output_csv_path)
print("Shape:", df_final.shape)
print("Columns:", df_final.columns.tolist())


Final study-level CSV saved to: /content/birads_preprocessed_dataset/study_level_metadata.csv
Shape: (7999, 7)
Columns: ['study_id', 'age', 'density_A', 'density_B', 'density_C', 'density_D', 'birads_binary']


Step-2: Model Result based on only metadata

In [None]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    roc_auc_score, average_precision_score,
    accuracy_score, f1_score, brier_score_loss
)
import xgboost as xgb

# Paths
base_dir = "/content/birads_preprocessed_dataset"
csv_path = os.path.join(base_dir, "study_level_metadata.csv")

# Load study-level dataset
df = pd.read_csv(csv_path)

# Features and target
feature_cols = [col for col in df.columns if col not in ["study_id", "birads_binary"]]
X = df[feature_cols].values
y = df["birads_binary"].values
study_ids = df["study_id"].values

print("Raw features shape:", X.shape)

# Study-wise split
train_ids, test_ids, y_train_ids, y_test_ids = train_test_split(
    study_ids, y, test_size=0.125, random_state=42, stratify=y
)
train_ids, val_ids, y_train_ids, y_val_ids = train_test_split(
    train_ids, y_train_ids, test_size=0.2, random_state=42, stratify=y_train_ids
)

# Masks for indexing
train_mask = np.isin(study_ids, train_ids)
val_mask = np.isin(study_ids, val_ids)
test_mask = np.isin(study_ids, test_ids)

X_train, y_train = X[train_mask], y[train_mask]
X_val, y_val = X[val_mask], y[val_mask]
X_test, y_test = X[test_mask], y[test_mask]

print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")

# -------------------------------
# Logistic Regression
# -------------------------------
logreg = LogisticRegression(max_iter=500, class_weight="balanced", solver="liblinear")
logreg.fit(X_train, y_train)

y_val_pred = logreg.predict_proba(X_val)[:, 1]
y_test_pred = logreg.predict_proba(X_test)[:, 1]

print("\n=== Logistic Regression ===")
print(f"Val AUROC: {roc_auc_score(y_val, y_val_pred):.4f}")
print(f"Val AUPRC: {average_precision_score(y_val, y_val_pred):.4f}")
print(f"Test AUROC: {roc_auc_score(y_test, y_test_pred):.4f}")
print(f"Test AUPRC: {average_precision_score(y_test, y_test_pred):.4f}")
print(f"Test Accuracy: {accuracy_score(y_test, (y_test_pred > 0.5)):.4f}")
print(f"Test F1: {f1_score(y_test, (y_test_pred > 0.5)):.4f}")
print(f"Test Brier Score: {brier_score_loss(y_test, y_test_pred):.4f}")

# -------------------------------
# XGBoost
# -------------------------------
xgb_model = xgb.XGBClassifier(
    objective="binary:logistic",
    eval_metric="auc",
    scale_pos_weight=(y_train == 0).sum() / (y_train == 1).sum(),
    use_label_encoder=False,
    random_state=42,
    n_estimators=200,
    max_depth=4,
    learning_rate=0.05
)

xgb_model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)

y_val_pred = xgb_model.predict_proba(X_val)[:, 1]
y_test_pred = xgb_model.predict_proba(X_test)[:, 1]

print("\n=== XGBoost ===")
print(f"Val AUROC: {roc_auc_score(y_val, y_val_pred):.4f}")
print(f"Val AUPRC: {average_precision_score(y_val, y_val_pred):.4f}")
print(f"Test AUROC: {roc_auc_score(y_test, y_test_pred):.4f}")
print(f"Test AUPRC: {average_precision_score(y_test, y_test_pred):.4f}")
print(f"Test Accuracy: {accuracy_score(y_test, (y_test_pred > 0.5)):.4f}")
print(f"Test F1: {f1_score(y_test, (y_test_pred > 0.5)):.4f}")
print(f"Test Brier Score: {brier_score_loss(y_test, y_test_pred):.4f}")


Raw features shape: (7999, 5)
Train: (5599, 5), Val: (1400, 5), Test: (1000, 5)

=== Logistic Regression ===
Val AUROC: 0.5807
Val AUPRC: 0.3265
Test AUROC: 0.5698
Test AUPRC: 0.3177
Test Accuracy: 0.5800
Test F1: 0.3913
Test Brier Score: 0.2455

=== XGBoost ===
Val AUROC: 0.6988
Val AUPRC: 0.4504
Test AUROC: 0.6571
Test AUPRC: 0.4160
Test Accuracy: 0.6860
Test F1: 0.4332
Test Brier Score: 0.2244


Step-3: Setting up environment

In [None]:
import torch
import numpy as np
import random
import os

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


Step-4: Creating fusion level dataset

In [None]:
import pandas as pd
import os
from sklearn.preprocessing import OneHotEncoder

# ----------------------------
# Paths
# ----------------------------
base_dir = "/content/birads_preprocessed_dataset"
csv_path = os.path.join(base_dir, "image_df_preprocessed_fixed.csv")

# ----------------------------
# Load CSV and map labels
# ----------------------------
df = pd.read_csv(csv_path)
df["birads_binary"] = df["birads_binary"].map({"normal": 0, "abnormal": 1})

# ----------------------------
# Keep only complete studies and map images
# ----------------------------
study_groups = []
for study_id, group in df.groupby("study_id"):
    if len(group) == 4:  # L-CC, L-MLO, R-CC, R-MLO
        row = group.iloc[0].copy()  # representative row
        row["study_id"] = study_id

        # Assign split and case_category from first row
        row["split"] = group["split"].iloc[0]
        row["case_category"] = group["case_category"].iloc[0]

        # Map image paths and indices
        for idx, img_row in enumerate(group.itertuples()):
            pos_col_path = f"{img_row.laterality}_{img_row.view_position}_path"
            pos_col_idx  = f"{img_row.laterality}_{img_row.view_position}_idx"
            row[pos_col_path] = img_row.image_path
            row[pos_col_idx]  = idx  # 0-3 index within study

        study_groups.append(row)

fusion_df = pd.DataFrame(study_groups)

# ----------------------------
# One-hot encode density
# ----------------------------
ohe = OneHotEncoder(drop=None, sparse_output=False)
density_encoded = ohe.fit_transform(fusion_df[["density"]])
density_cols = ohe.get_feature_names_out(["density"])
density_df = pd.DataFrame(density_encoded, columns=density_cols)

# ----------------------------
# Build final study-level DataFrame
# ----------------------------
final_cols = ["study_id", "age", "birads_binary", "split", "case_category"]
# Collect all image path and index columns
image_path_cols = [col for col in fusion_df.columns if col.endswith("_path")]
image_idx_cols  = [col for col in fusion_df.columns if col.endswith("_idx")]

final_df = pd.concat([
    fusion_df[final_cols].reset_index(drop=True),
    density_df.reset_index(drop=True),
    fusion_df[image_path_cols + image_idx_cols].reset_index(drop=True)
], axis=1)

# ----------------------------
# Save CSV
# ----------------------------
final_csv_path = os.path.join(base_dir, "study_level_metadata.csv")
final_df.to_csv(final_csv_path, index=False)

print("Final study-level CSV saved to:", final_csv_path)
print("Shape:", final_df.shape)
print("Columns:", final_df.columns.tolist())


Final study-level CSV saved to: /content/birads_preprocessed_dataset/study_level_metadata.csv
Shape: (7999, 19)
Columns: ['study_id', 'age', 'birads_binary', 'split', 'case_category', 'density_A', 'density_B', 'density_C', 'density_D', 'image_path', 'preprocessed_path', 'L_MLO_path', 'L_CC_path', 'R_MLO_path', 'R_CC_path', 'L_MLO_idx', 'L_CC_idx', 'R_MLO_idx', 'R_CC_idx']


In [None]:
print(final_df.head())

                                    study_id  age  birads_binary     split  \
0           0025a5dc99fd5c742026f0b2b030d3e9   44              0      test   
1           0028fb2c7f0b3a5cb9a80cb0e1cdbb91   51              0  training   
2           0034765af074f93ed33d5e8399355caf   37              0  training   
3           003700f3c960e0b9bca2b8437c3dbf05   44              0  training   
4  003700f3c960e0b9bca2b8437c3dbf05_dup11243   44              0  training   

   case_category  density_A  density_B  density_C  density_D  \
0              0        0.0        0.0        1.0        0.0   
1              0        0.0        0.0        1.0        0.0   
2              0        0.0        0.0        1.0        0.0   
3              1        0.0        0.0        1.0        0.0   
4              1        0.0        0.0        1.0        0.0   

                                          image_path  \
0  /content/birads_preprocessed_dataset/test/norm...   
1  /content/birads_preprocessed_da

In [None]:
import pandas as pd
import os

# ----------------------------
# Load original study-level CSV (fusion CSV)
# ----------------------------
study_csv_path = "/content/birads_preprocessed_dataset/study_level_metadata.csv"
study_df = pd.read_csv(study_csv_path)

# ----------------------------
# Load image_df for comparison
# ----------------------------
image_df_path = "/content/drive/MyDrive/multimodal_mammography/dataset/csv/image_df_upsampled_studywise.csv"
image_df = pd.read_csv(image_df_path)

# ----------------------------
# 1️⃣ Check study counts
# ----------------------------
print("Total studies in study_level_metadata.csv:", len(study_df))
print("Unique studies in image_df:", image_df['study_id'].nunique())

# ----------------------------
# 2️⃣ Validate that each study has 4 images
# ----------------------------
# Count images per study in image_df
study_counts = image_df.groupby("study_id")["image_id"].count()
incomplete_studies = study_counts[study_counts != 4]
print("Incomplete studies in image_df (should be 0):", len(incomplete_studies))

# ----------------------------
# 3️⃣ Validate that paths exist in filesystem
# ----------------------------
missing_paths = []
for col in [c for c in study_df.columns if c.endswith("_path")]:
    for p in study_df[col]:
        if not os.path.exists(p):
            missing_paths.append(p)
print("Missing image files in study_level_metadata.csv:", len(missing_paths))

# ----------------------------
# 4️⃣ Optional: Check all studies match
# ----------------------------
merged = pd.merge(study_df[['study_id']], image_df[['study_id']].drop_duplicates(), on='study_id', how='outer', indicator=True)
print("Studies only in study_level_metadata.csv:", merged[merged['_merge']=='left_only'].shape[0])
print("Studies only in image_df:", merged[merged['_merge']=='right_only'].shape[0])


Total studies in study_level_metadata.csv: 7999
Unique studies in image_df: 7999
Incomplete studies in image_df (should be 0): 0
Missing image files in study_level_metadata.csv: 0
Studies only in study_level_metadata.csv: 0
Studies only in image_df: 0


Step-5: Multimodal Class and Dataloader

In [None]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import torchvision.transforms as T
from sklearn.model_selection import train_test_split
from collections import defaultdict

# ----------------------------
# Paths
# ----------------------------
BASE_DIR = "/content/birads_preprocessed_dataset"
CSV_PATH = os.path.join(BASE_DIR, "study_level_metadata.csv")

# ----------------------------
# Load CSV
# ----------------------------
df = pd.read_csv(CSV_PATH)
df["birads_binary"] = df["birads_binary"].astype(int)
df["case_category"] = df["case_category"].astype(int)

# ----------------------------
# Image transforms
# ----------------------------
IMAGE_TRANSFORMS = {
    "train": T.Compose([
        T.Grayscale(num_output_channels=3),
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406],
                    [0.229, 0.224, 0.225])
    ]),
    "val": T.Compose([
        T.Grayscale(num_output_channels=3),
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406],
                    [0.229, 0.224, 0.225])
    ]),
    "test": T.Compose([
        T.Grayscale(num_output_channels=3),
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406],
                    [0.229, 0.224, 0.225])
    ])
}

# ----------------------------
# Study-level dataset
# ----------------------------
class MammogramStudyDataset(Dataset):
    def __init__(self, df, split="training", transform=None, metadata_cols=None):
        self.df_split = df[df["split"] == split].copy()
        self.transform = transform
        self.metadata_cols = metadata_cols or ["age", "density_A", "density_B", "density_C", "density_D", "case_category"]

        # Prepare study dictionary
        self.study_groups = {}
        for _, row in self.df_split.iterrows():
            images = [row["L_CC_path"], row["L_MLO_path"], row["R_CC_path"], row["R_MLO_path"]]
            if all(os.path.exists(p) for p in images):
                self.study_groups[row["study_id"]] = {
                    "images": images,
                    "metadata": row[self.metadata_cols].values.astype(np.float32),
                    "label": row["birads_binary"],
                    "case_category": row["case_category"]
                }

        self.study_ids = list(self.study_groups.keys())

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        study = self.study_groups[study_id]

        images = [Image.open(p).convert("RGB") for p in study["images"]]
        if self.transform:
            images = [self.transform(img) for img in images]
        images = torch.stack(images)  # (4, C, H, W)

        metadata = torch.tensor(study["metadata"], dtype=torch.float32)
        label = torch.tensor(study["label"], dtype=torch.long)
        return images, metadata, label, study["case_category"], study_id

# ----------------------------
# Stratified batch sampler by case_category
# ----------------------------
class RepresentativeBatchSampler(Sampler):
    def __init__(self, dataset, batch_size=4, seed=42):
        self.dataset = dataset
        self.batch_size = batch_size
        random.seed(seed)

        # Group indices by case_category
        self.cat_indices = defaultdict(list)
        for idx, study_id in enumerate(dataset.study_ids):
            cat = dataset.study_groups[study_id]["case_category"]
            self.cat_indices[cat].append(idx)

        for k in self.cat_indices:
            random.shuffle(self.cat_indices[k])

    def __iter__(self):
        cat_lists = {k: list(v) for k, v in self.cat_indices.items()}
        batches = []

        while any(len(lst) > 0 for lst in cat_lists.values()):
            batch = []
            # pick one from each category if possible
            for cat in sorted(cat_lists.keys()):
                if cat_lists[cat]:
                    batch.append(cat_lists[cat].pop())
                if len(batch) == self.batch_size:
                    break
            # fill remaining slots if batch not full
            remaining = [i for lst in cat_lists.values() for i in lst]
            random.shuffle(remaining)
            while len(batch) < self.batch_size and remaining:
                batch.append(remaining.pop())
            if batch:
                batches.append(batch)

        random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

# ----------------------------
# Metadata columns
# ----------------------------
metadata_cols = ["age", "density_A", "density_B", "density_C", "density_D", "case_category"]

# ----------------------------
# Create datasets
# ----------------------------
train_dataset = MammogramStudyDataset(df, split="training", transform=IMAGE_TRANSFORMS["train"], metadata_cols=metadata_cols)
val_dataset   = MammogramStudyDataset(df, split="training", transform=IMAGE_TRANSFORMS["val"], metadata_cols=metadata_cols)
test_dataset  = MammogramStudyDataset(df, split="test", transform=IMAGE_TRANSFORMS["test"], metadata_cols=metadata_cols)

# Manual train/val split
train_ids, val_ids = train_test_split(train_dataset.study_ids, test_size=0.2, random_state=42)
train_dataset.study_ids = train_ids
val_dataset.study_ids = val_ids

# ----------------------------
# Dataloaders
# ----------------------------
batch_size = 4
train_loader = DataLoader(
    train_dataset,
    batch_sampler=RepresentativeBatchSampler(train_dataset, batch_size=batch_size),
    num_workers=2,
    pin_memory=True
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# ----------------------------
# Summary
# ----------------------------
print("Train studies:", len(train_dataset))
print("Val studies:", len(val_dataset))
print("Test studies:", len(test_dataset))

# ----------------------------
# Test first batch
# ----------------------------
for imgs, metadata, labels, case_cat, study_ids in train_loader:
    print("Images:", imgs.shape)
    print("Metadata:", metadata.shape)
    print("Labels:", labels)
    print("Case categories:", case_cat)
    print("Study IDs:", study_ids)
    break


Train studies: 5599
Val studies: 1400
Test studies: 1000
Images: torch.Size([4, 4, 3, 224, 224])
Metadata: torch.Size([4, 6])
Labels: tensor([0, 1, 0, 0])
Case categories: tensor([0, 1, 0, 0])
Study IDs: ['40fd03a5bb87d107c5807c0f89076b0d', 'e4c3b57bc18e8474140b62bc2eb6c0c7_dup1215', '857e8fcdea4691ce1c5eb7ea7c849b9d', 'b2deb4a18f3ebd2ea976864f1117b00c']


Step-6: Model Defintion


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class StudyLevelMultimodalNet(nn.Module):
    def __init__(self, backbone="resnet50", pretrained=True, num_metadata_features=6, num_classes=1):
        super(StudyLevelMultimodalNet, self).__init__()

        # ----------------------------
        # Image backbone
        # ----------------------------
        if backbone == "resnet50":
            self.cnn = models.resnet50(pretrained=pretrained)
            in_features = self.cnn.fc.in_features
            self.cnn.fc = nn.Identity()  # remove final classification layer
        elif backbone == "efficientnet_b0":
            self.cnn = models.efficientnet_b0(pretrained=pretrained)
            in_features = self.cnn.classifier[1].in_features
            self.cnn.classifier = nn.Identity()
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        self.in_features = in_features
        self.num_metadata_features = num_metadata_features

        # ----------------------------
        # Study-level fusion + multimodal head
        # ----------------------------
        self.fc = nn.Sequential(
            nn.Linear(in_features + num_metadata_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)  # binary classification
        )

    def forward(self, images, metadata):
        """
        images: (batch_size, num_views, C, H, W)
        metadata: (batch_size, num_metadata_features)
        """
        batch_size, num_views, C, H, W = images.shape
        x = images.view(batch_size * num_views, C, H, W)

        # Extract features per view
        feats = self.cnn(x)  # (batch_size*num_views, in_features)
        feats = feats.view(batch_size, num_views, -1)

        # Fuse across views (mean pooling)
        fused_img_feats = feats.mean(dim=1)  # (batch_size, in_features)

        # Concatenate metadata
        fused = torch.cat([fused_img_feats, metadata], dim=1)  # (batch_size, in_features + num_metadata_features)

        # Classification head
        out = self.fc(fused).squeeze(1)  # (batch_size,)
        return out


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# ----------------------------
# Model setup
# ----------------------------
num_metadata_features = 6  # e.g., age + density one-hot
model = StudyLevelMultimodalNet(
    backbone="resnet50",
    pretrained=True,
    num_metadata_features=num_metadata_features,
    num_classes=1
)
model = model.to(device)

# ----------------------------
# Loss function
# ----------------------------
criterion = nn.BCEWithLogitsLoss()  # for binary classification

# ----------------------------
# Optimizer
# ----------------------------
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-4
)

# ----------------------------
# LR scheduler
# ----------------------------
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",     # maximize monitored metric (val AUROC)
    factor=0.5,     # reduce LR by 50%
    patience=2      # wait 2 epochs before reducing
)

print(f"Multimodal model ready on device: {device}")


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 164MB/s]


Multimodal model ready on device: cuda


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, f1_score, brier_score_loss

# -----------------------------
# Metrics helper
# -----------------------------
def compute_metrics(y_true, y_pred_probs):
    y_pred = (y_pred_probs >= 0.5).astype(int)
    return {
        "AUROC": roc_auc_score(y_true, y_pred_probs),
        "AUPRC": average_precision_score(y_true, y_pred_probs),
        "Accuracy": accuracy_score(y_true, y_pred),
        "F1": f1_score(y_true, y_pred),
        "Brier": brier_score_loss(y_true, y_pred_probs)
    }

# -----------------------------
# Training one epoch
# -----------------------------
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    all_labels, all_preds = [], []

    loop = tqdm(loader, desc="Training", leave=False)
    for images, metadata, labels, _, _ in loop:  # unpack multimodal batch
        images, metadata, labels = images.to(device), metadata.to(device), labels.float().to(device)

        optimizer.zero_grad()
        outputs = model(images, metadata)  # pass both modalities
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(torch.sigmoid(outputs).detach().cpu().numpy())

        loop.set_postfix(loss=loss.item())

    metrics = compute_metrics(np.array(all_labels), np.array(all_preds))
    return total_loss / len(loader.dataset), metrics

# -----------------------------
# Validation
# -----------------------------
def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    all_labels, all_preds = []

    loop = tqdm(loader, desc="Validation", leave=False)
    with torch.no_grad():
        for images, metadata, labels, _, _ in loop:
            images, metadata, labels = images.to(device), metadata.to(device), labels.float().to(device)
            outputs = model(images, metadata)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(torch.sigmoid(outputs).cpu().numpy())

            loop.set_postfix(loss=loss.item())

    metrics = compute_metrics(np.array(all_labels), np.array(all_preds))
    return total_loss / len(loader.dataset), metrics

# -----------------------------
# Full training loop
# -----------------------------
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs=10):
    best_val_auroc = 0.0

    for epoch in range(1, epochs + 1):
        print(f"\n=== Epoch {epoch}/{epochs} ===")

        train_loss, train_metrics = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_metrics = validate(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train AUROC: {train_metrics['AUROC']:.4f} | Val AUROC: {val_metrics['AUROC']:.4f}")
        print(f"Train Accuracy: {train_metrics['Accuracy']:.4f} | Val Accuracy: {val_metrics['Accuracy']:.4f}")

        # Step scheduler based on validation AUROC
        scheduler.step(val_metrics['AUROC'])

        # Save best model
        if val_metrics['AUROC'] > best_val_auroc:
            best_val_auroc = val_metrics['AUROC']
            torch.save(model.state_dict(), "best_multimodal_model.pth")
            print("✅ Saved new best model.")

    print(f"\nTraining complete. Best Val AUROC: {best_val_auroc:.4f}")


In [None]:
# -----------------------------
# Test evaluation (multimodal)
# -----------------------------
def evaluate_test(model, test_loader, device):
    model.eval()
    all_labels, all_preds = [], []

    loop = tqdm(test_loader, desc="Testing", leave=False)
    with torch.no_grad():
        for images, metadata, labels, _, _ in loop:  # unpack multimodal batch
            images, metadata, labels = images.to(device), metadata.to(device), labels.float().to(device)
            outputs = model(images, metadata)
            preds = torch.sigmoid(outputs)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    metrics = compute_metrics(all_labels, all_preds)

    print("\n=== Test Metrics ===")
    print(f"AUROC   : {metrics['AUROC']:.4f}")
    print(f"AUPRC   : {metrics['AUPRC']:.4f}")
    print(f"Accuracy: {metrics['Accuracy']:.4f}")
    print(f"F1      : {metrics['F1']:.4f}")
    print(f"Brier   : {metrics['Brier']:.4f}")
    return metrics


In [None]:
def run_training_multimodal(model, train_loader, val_loader, test_loader, device,
                            epochs=10, lr=1e-4, weight_decay=1e-4,
                            checkpoint_path="best_model.pth"):

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=2)

    best_val_auroc = 0.0

    for epoch in range(1, epochs + 1):
        # -----------------------------
        # Training
        # -----------------------------
        model.train()
        train_loss = 0.0
        all_labels, all_preds = [], []

        loop = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [Train]", leave=False)
        for images, metadata, labels, _, _ in loop:  # unpack multimodal batch
            images, metadata, labels = images.to(device), metadata.to(device), labels.float().to(device)

            optimizer.zero_grad()
            outputs = model(images, metadata)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(torch.sigmoid(outputs).detach().cpu().numpy())

            loop.set_postfix(loss=loss.item())

        train_loss /= len(train_loader.dataset)
        train_metrics = compute_metrics(np.array(all_labels), np.array(all_preds))

        # -----------------------------
        # Validation
        # -----------------------------
        model.eval()
        val_loss = 0.0
        all_labels, all_preds = [], []

        loop = tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} [Val]", leave=False)
        with torch.no_grad():
            for images, metadata, labels, _, _ in loop:
                images, metadata, labels = images.to(device), metadata.to(device), labels.float().to(device)
                outputs = model(images, metadata)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(torch.sigmoid(outputs).cpu().numpy())

                loop.set_postfix(loss=loss.item())

        val_loss /= len(val_loader.dataset)
        val_metrics = compute_metrics(np.array(all_labels), np.array(all_preds))

        print(f"Epoch {epoch}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
              f"Val AUROC: {val_metrics['AUROC']:.4f}")

        # -----------------------------
        # Scheduler & Checkpoint
        # -----------------------------
        scheduler.step(val_metrics['AUROC'])

        if val_metrics['AUROC'] > best_val_auroc:
            best_val_auroc = val_metrics['AUROC']
            torch.save(model.state_dict(), checkpoint_path)
            print(f"✅ Saved best model at epoch {epoch} (Val AUROC: {best_val_auroc:.4f})")

    # -----------------------------
    # Load best model & test
    # -----------------------------
    model.load_state_dict(torch.load(checkpoint_path))
    model.to(device)
    print("\n=== Evaluating on Test Set ===")
    test_metrics = evaluate_test(model, test_loader, device)

    return model, train_metrics, val_metrics, test_metrics


In [None]:
model = StudyLevelMultimodalNet(backbone="resnet50", pretrained=True).to(device)
trained_model, train_metrics, val_metrics, test_metrics = run_training_multimodal(
    model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    device=device,
    epochs=5,
    lr=1e-4,
    weight_decay=1e-4,
    checkpoint_path="best_multimodal_model.pth"
)




Epoch 1/5 | Train Loss: 0.9956 | Val Loss: 0.5858 | Val AUROC: 0.7939
✅ Saved best model at epoch 1 (Val AUROC: 0.7939)




Epoch 2/5 | Train Loss: 0.8765 | Val Loss: 0.5407 | Val AUROC: 0.8017
✅ Saved best model at epoch 2 (Val AUROC: 0.8017)




Epoch 3/5 | Train Loss: 0.7790 | Val Loss: 0.4652 | Val AUROC: 0.8136
✅ Saved best model at epoch 3 (Val AUROC: 0.8136)




Epoch 4/5 | Train Loss: 0.7408 | Val Loss: 0.4542 | Val AUROC: 0.8164
✅ Saved best model at epoch 4 (Val AUROC: 0.8164)




Epoch 5/5 | Train Loss: 0.7108 | Val Loss: 0.4468 | Val AUROC: 0.8186
✅ Saved best model at epoch 5 (Val AUROC: 0.8186)

=== Evaluating on Test Set ===


                                                          


=== Test Metrics ===
AUROC   : 0.7181
AUPRC   : 0.3267
Accuracy: 0.9040
F1      : 0.1579
Brier   : 0.0673


