# Breast Cancer Multimodal Classifier

Ce notebook construit un modèle de classification du cancer du sein en TensorFlow en combinant :
- des images 4 canaux (mammographies Gauche/Droite, LE/SUB),
- des métadonnées cliniques (Excel),
- des rapports médicaux (texte).

**Variable cible : `Pathology`**

In [116]:
# 📦 Imports
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, GlobalAveragePooling2D, Concatenate, Embedding, Bidirectional, GRU
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
import json
from glob import glob
from PIL import Image

## 🔄 Charger et préparer les données (images + Excel + JSON)

In [117]:
# 📁 Dossiers (à adapter selon ton organisation)
image_dir_le = 'CDD-CESM/PKG - CDD-CESM/CDD-CESM/Low energy images of CDD-CESM'
image_dir_sub = 'CDD-CESM/PKG - CDD-CESM/CDD-CESM/Subtracted images of CDD-CESM'
json_dir = 'CDD-CESM/json_output'
excel_path = 'processed_metadata.csv'

# 📂 Exemple: structure d'un échantillon
# images/patient123_Left_LE.png, images/patient123_Left_SUB.png, ...
# json_reports/patient123.json

def load_image_stack(patient_id):
    paths = [
        f"{image_dir_le}/P{patient_id}_L_DM_CC.jpg",
        f"{image_dir_le}/P{patient_id}_L_DM_MLO.jpg",
        f"{image_dir_sub}/P{patient_id}_L_CM_CC.jpg",
        f"{image_dir_sub}/P{patient_id}_L_CM_MLO.jpg",
        f"{image_dir_le}/P{patient_id}_R_DM_CC.jpg",
        f"{image_dir_le}/P{patient_id}_R_DM_MLO.jpg",
        f"{image_dir_sub}/P{patient_id}_R_CM_CC.jpg",
        f"{image_dir_sub}/P{patient_id}_R_CM_MLO.jpg",
    ]
    imgs = [np.array(Image.open(p).resize((224, 224))) for p in paths]
    return np.stack(imgs, axis=-1) / 255.0

In [118]:
# 🔢 Charger Excel (métadonnées)
meta_df = pd.read_csv(excel_path)

# 🧹 Nettoyage minimal
meta_df = meta_df.dropna(subset=['Patient_ID', 'Pathology Classification/ Follow up'])
meta_df['Patient_ID'] = meta_df['Patient_ID'].astype(str)

# ⚙️ Encodage et normalisation
numerical = meta_df.select_dtypes(include=['float', 'int']).columns.tolist()
categorical = meta_df.select_dtypes(include=['object']).drop(columns=['Patient_ID', 'Pathology Classification/ Follow up']).columns.tolist()

scaler = StandardScaler()
encoder = OneHotEncoder(sparse_output=False)

meta_num = scaler.fit_transform(meta_df[numerical])
meta_cat = encoder.fit_transform(meta_df[categorical])
meta_features = np.concatenate([meta_num, meta_cat], axis=1)

In [119]:
# --- Étape 1 : charger les rapports texte pour TOUS les patients (ordre = meta_df)
texts = []
for pid in meta_df['Patient_ID']:
    path = os.path.join(json_dir, f"P{pid}.json")
    if os.path.exists(path):
        with open(path, encoding="utf-8") as f:
            try:
                d = json.load(f)
                flat_text = []
                for v in d.values():
                    flat_text.extend(map(str, v) if isinstance(v, list) else [str(v)])
                texts.append(" ".join(flat_text))
            except Exception as e:
                print(f"Erreur JSON {pid}: {e}")
                texts.append("")
    else:
        texts.append("")




In [120]:
# --- Étape 2 : échantillonner 10 % des patients
sampled_df = meta_df.sample(frac=0.1, random_state=42)

# get sampled patient IDs
sampled_patient_ids = sampled_df['Patient_ID'].astype(str).tolist()


In [121]:

# --- Étape 3 : charger les images valides et enregistrer les PID correspondants
images = []
valid_labels = []
valid_pids = []

for pid, label in zip(sampled_df['Patient_ID'], sampled_df['Pathology Classification/ Follow up']):
    try:
        img = load_image_stack(pid)
        images.append(img)
        valid_labels.append(label)
        valid_pids.append(pid)
    except:
        continue  # Ignore les erreurs (images manquantes etc.)

images = np.array(images)


In [None]:
# Assuming you have lists or arrays for images, labels, texts, and metadata features
# and a DataFrame meta_df with a 'Patient_ID' column

# Extract patient IDs from each dataset
image_patient_ids = set([pid for pid, img in zip(sampled_patient_ids, images)])
text_patient_ids = set([pid for pid, text in zip(sampled_patient_ids, texts)])
metadata_patient_ids = set(meta_df['Patient_ID'].unique())

# Find the intersection of all patient IDs
common_patient_ids = list(image_patient_ids.intersection(text_patient_ids).intersection(metadata_patient_ids))

# Create a dictionary to map patient IDs to their corresponding images, labels, and texts
image_dict = {pid: img for pid, img in zip(sampled_patient_ids, images) if pid in common_patient_ids}
label_dict = {pid: label for pid, label in zip(sampled_patient_ids, valid_labels) if pid in common_patient_ids}
text_dict = {pid: text for pid, text in zip(sampled_patient_ids, texts) if pid in common_patient_ids}
meta_dict = {pid: meta_features[i] for i, pid in enumerate(sampled_patient_ids) if pid in common_patient_ids}

# Filter metadata features for common patient IDs
meta_features_filtered = meta_features[meta_df['Patient_ID'].isin(common_patient_ids)]

# Ensure that each patient ID appears only once and corresponds correctly
images_filtered = [image_dict[pid] for pid in common_patient_ids]
labels_filtered = [label_dict[pid] for pid in common_patient_ids]
texts_filtered = [text_dict[pid] for pid in common_patient_ids]
meta_features_filtered = [meta_dict[pid] for pid in common_patient_ids]

# Check the number of validated samples
print(f"Validated Images: {len(images_filtered)} out of {len(sampled_patient_ids)} samples")
print(f"Validated Labels: {len(labels_filtered)} out of {len(sampled_patient_ids)} samples")
print(f"Validated Texts: {len(texts_filtered)} out of {len(sampled_patient_ids)} samples")
print(f"Validated Metadata Features: {len(meta_features_filtered)} out of {len(sampled_patient_ids)} samples")

# Ensure all filtered datasets have the same number of samples
assert len(images_filtered) == len(labels_filtered) == len(texts_filtered) == len(meta_features_filtered), \
    "Mismatch in the number of samples after filtering."


Validated Images: 122 out of 201 samples
Validated Labels: 122 out of 201 samples
Validated Texts: 122 out of 201 samples
Validated Metadata Features: 122 out of 201 samples


In [129]:

# --- Étape 5 : vectoriser les textes (fit sur corpus réduit)
vectorizer = TfidfVectorizer(max_features=1000)
text_features_filtered = vectorizer.fit_transform(texts_filtered).toarray()


In [131]:

# --- Étape 6 : scaler et encoder uniquement sur meta_df_filtered
numerical = meta_features_filtered.select_dtypes(include=['float', 'int']).columns.tolist()
categorical = meta_features_filtered.select_dtypes(include=['object']).drop(columns=['Patient_ID', 'Pathology Classification/ Follow up']).columns.tolist()

scaler = StandardScaler()
encoder = OneHotEncoder(sparse_output=False)

meta_num_filtered = scaler.fit_transform(meta_features_filtered[numerical])
meta_cat_filtered = encoder.fit_transform(meta_features_filtered[categorical])
print(f"Nombre de caractéristiques numériques : {meta_num_filtered.shape[1]}")
print(f"Nombre de caractéristiques catégorielles : {meta_cat_filtered.shape[1]}")
meta_features_filtered = np.concatenate([meta_num_filtered, meta_cat_filtered], axis=1)


AttributeError: 'numpy.ndarray' object has no attribute 'select_dtypes'

In [None]:

# --- Étape 7 : encoder les labels
valid_labels_encoded, label_names = pd.factorize(valid_labels)
labels_cat = to_categorical(valid_labels_encoded)


In [None]:

# # --- Étape 8 : vérification finale
# assert len(images) == meta_features_filtered.shape[0] == text_features_filtered.shape[0] == labels_cat.shape[0], \
#     f"Mismatch: images={len(images)}, meta={meta_features_filtered.shape[0]}, text={text_features_filtered.shape[0]}, labels={labels_cat.shape[0]}"

# print("✅ Données bien alignées :", len(images), "exemples valides")


## 🧠 Définir l’architecture multimodale

In [None]:
# Entrées
img_input = Input(shape=(224, 224, 4))
meta_input = Input(shape=(meta_features.shape[1],))
text_input = Input(shape=(text_features.shape[1],))

# Branch image
x_img = Conv2D(32, (3, 3), activation='relu')(img_input)
x_img = MaxPooling2D()(x_img)
x_img = Conv2D(64, (3, 3), activation='relu')(x_img)
x_img = GlobalAveragePooling2D()(x_img)

# Branch metadata
x_meta = Dense(64, activation='relu')(meta_input)

# Branch texte (TF-IDF direct)
x_text = Dense(64, activation='relu')(text_input)

# Fusion
x = Concatenate()([x_img, x_meta, x_text])
x = Dense(128, activation='relu')(x)
output = Dense(len(label_names), activation='softmax')(x)

model = Model(inputs=[img_input, meta_input, text_input], outputs=output)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

## 🚂 Entraînement

In [None]:
# 🔀 Split
X_img_train, X_img_val, X_meta_train, X_meta_val, X_txt_train, X_txt_val, y_train, y_val = train_test_split(
    images, meta_features_filtered, text_features_filtered, labels_cat, test_size=0.2, random_state=42
)


# 🚀 Fit
history = model.fit(
    [X_img_train, X_meta_train, X_txt_train], y_train,
    validation_data=([X_img_val, X_meta_val, X_txt_val], y_val),
    epochs=10, batch_size=16
)

In [None]:

# --- Étape 5 : vectoriser les textes (fit sur corpus réduit)
vectorizer = TfidfVectorizer(max_features=1000)
text_features_filtered = vectorizer.fit_transform(texts_filtered).toarray()


In [None]:

# --- Étape 6 : scaler et encoder uniquement sur meta_df_filtered
numerical = meta_df_filtered.select_dtypes(include=['float', 'int']).columns.tolist()
categorical = meta_df_filtered.select_dtypes(include=['object']).drop(columns=['Patient_ID', 'Pathology Classification/ Follow up']).columns.tolist()

scaler = StandardScaler()
encoder = OneHotEncoder(sparse_output=False)

meta_num_filtered = scaler.fit_transform(meta_df_filtered[numerical])
meta_cat_filtered = encoder.fit_transform(meta_df_filtered[categorical])
print(f"Nombre de caractéristiques numériques : {meta_num_filtered.shape[1]}")
print(f"Nombre de caractéristiques catégorielles : {meta_cat_filtered.shape[1]}")
meta_features_filtered = np.concatenate([meta_num_filtered, meta_cat_filtered], axis=1)


In [None]:

# --- Étape 7 : encoder les labels
valid_labels_encoded, label_names = pd.factorize(valid_labels)
labels_cat = to_categorical(valid_labels_encoded)


In [None]:

# # --- Étape 8 : vérification finale
# assert len(images) == meta_features_filtered.shape[0] == text_features_filtered.shape[0] == labels_cat.shape[0], \
#     f"Mismatch: images={len(images)}, meta={meta_features_filtered.shape[0]}, text={text_features_filtered.shape[0]}, labels={labels_cat.shape[0]}"

# print("✅ Données bien alignées :", len(images), "exemples valides")


## 🧠 Définir l’architecture multimodale

In [None]:
# Entrées
img_input = Input(shape=(224, 224, 4))
meta_input = Input(shape=(meta_features.shape[1],))
text_input = Input(shape=(text_features.shape[1],))

# Branch image
x_img = Conv2D(32, (3, 3), activation='relu')(img_input)
x_img = MaxPooling2D()(x_img)
x_img = Conv2D(64, (3, 3), activation='relu')(x_img)
x_img = GlobalAveragePooling2D()(x_img)

# Branch metadata
x_meta = Dense(64, activation='relu')(meta_input)

# Branch texte (TF-IDF direct)
x_text = Dense(64, activation='relu')(text_input)

# Fusion
x = Concatenate()([x_img, x_meta, x_text])
x = Dense(128, activation='relu')(x)
output = Dense(len(label_names), activation='softmax')(x)

model = Model(inputs=[img_input, meta_input, text_input], outputs=output)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

## 🚂 Entraînement

In [None]:
# 🔀 Split
X_img_train, X_img_val, X_meta_train, X_meta_val, X_txt_train, X_txt_val, y_train, y_val = train_test_split(
    images, meta_features_filtered, text_features_filtered, labels_cat, test_size=0.2, random_state=42
)


# 🚀 Fit
history = model.fit(
    [X_img_train, X_meta_train, X_txt_train], y_train,
    validation_data=([X_img_val, X_meta_val, X_txt_val], y_val),
    epochs=10, batch_size=16
)