# AI for Improved Medical Diagnostics using Multi-Modal Data

**Copyright (c) 2026 Shrikara Kaudambady. All rights reserved.**

This notebook demonstrates how to build a **multi-modal neural network** for medical diagnostics. The model learns from two different types of data simultaneously—structured **patient data** (like age and blood pressure) and **image data** (like medical scans)—to make a more accurate diagnostic prediction than either data type could alone.

### 1. Setup and Library Imports

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, confusion_matrix
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's "whitegrid" style to the plots drawn later in this notebook.
sns.set_theme(style="whitegrid")

### 2. Data Simulation
We will generate a synthetic dataset containing both tabular patient data and corresponding medical scan images. We will create two distinct groups: 'healthy' and 'diseased'.

In [None]:
def generate_data(n_samples=1000, img_size=64, seed=None):
    """Simulate a labelled multi-modal dataset of patient records and scans.

    The first ``n_samples // 2`` samples are "healthy" (label 0) and the
    remainder are "diseased" (label 1). Each sample has three tabular
    features drawn from group-specific normal distributions and one
    synthetic square scan: a filled disc for healthy patients, and a noisy
    disc with a flat central lesion for diseased patients. All three outputs
    are shuffled with one shared permutation so they stay aligned.

    Args:
        n_samples: Total number of samples. Odd values are supported; the
            extra sample is assigned to the diseased group.
        img_size: Side length, in pixels, of each square scan.
        seed: Optional seed. When given, a dedicated generator is used so
            the output is reproducible. When ``None``, NumPy's global random
            state is used, so an earlier ``np.random.seed`` call still
            applies.

    Returns:
        A tuple ``(tabular_data, images, labels)`` where ``tabular_data`` is
        a DataFrame with columns ``age``, ``blood_pressure`` and
        ``cholesterol``; ``images`` is an array of shape
        ``(n_samples, img_size, img_size, 1)`` with values in [0, 1]; and
        ``labels`` is an integer array of shape ``(n_samples,)``.

    Raises:
        ValueError: If ``n_samples`` is negative or ``img_size`` is not
            positive.
    """
    if n_samples < 0:
        raise ValueError("n_samples must be non-negative")
    if img_size < 1:
        raise ValueError("img_size must be a positive integer")

    # Both the np.random module and a Generator expose normal() and shuffle().
    rng = np.random if seed is None else np.random.default_rng(seed)

    # Size the groups so they always add up to n_samples; the original
    # 2 * (n_samples // 2) rows fell one short for odd n_samples and the
    # final shuffle then indexed out of bounds.
    n_healthy = n_samples // 2
    n_diseased = n_samples - n_healthy

    # Generate Tabular Data
    # Healthy group
    healthy_age = rng.normal(40, 10, n_healthy)
    healthy_bp = rng.normal(120, 10, n_healthy)
    healthy_chol = rng.normal(200, 20, n_healthy)

    # Diseased group (statistically different)
    diseased_age = rng.normal(60, 12, n_diseased)
    diseased_bp = rng.normal(145, 15, n_diseased)
    diseased_chol = rng.normal(240, 25, n_diseased)

    tabular_data = pd.DataFrame({
        'age': np.concatenate([healthy_age, diseased_age]),
        'blood_pressure': np.concatenate([healthy_bp, diseased_bp]),
        'cholesterol': np.concatenate([healthy_chol, diseased_chol]),
    })
    labels = np.repeat([0, 1], [n_healthy, n_diseased])

    # Generate Image Data
    # The geometry is identical for every scan, so compute it once.
    center = img_size // 2
    radius = img_size // 4
    Y, X = np.ogrid[:img_size, :img_size]
    dist_from_center = np.sqrt((X - center) ** 2 + (Y - center) ** 2)
    lesion_mask = dist_from_center <= radius / 2
    healthy_scan = (dist_from_center <= radius).astype(float)

    def create_scan(is_diseased):
        """Return one scan: a clean disc, or a noisy disc with a lesion."""
        img = healthy_scan.copy()
        if is_diseased:
            # Add noise and irregularities for diseased scans
            img += rng.normal(0, 0.3, img.shape)
            img[lesion_mask] = 0.5  # Add a central lesion
        return np.clip(img, 0, 1)

    images = np.array([create_scan(label) for label in labels])
    # Add the trailing channel axis expected by the CNN.
    images = images.reshape(n_samples, img_size, img_size, 1)

    # Shuffle all data consistently
    indices = np.arange(n_samples)
    rng.shuffle(indices)
    return tabular_data.iloc[indices], images[indices], labels[indices]

# Build the synthetic dataset used by the rest of the notebook.
tabular_data, image_data, labels = generate_data()

# Report the size of each modality.
print("Data Simulation Complete.")
for caption, dataset in (
    ("Tabular Data Shape:", tabular_data),
    ("Image Data Shape:", image_data),
    ("Labels Shape:", labels),
):
    print(caption, dataset.shape)

# Show the first scan of each class side by side.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
panel_titles = ('Healthy Scan Example', 'Diseased Scan Example')
for class_id, (ax, panel_title) in enumerate(zip(axes, panel_titles)):
    first_scan = image_data[labels == class_id][0]
    ax.imshow(first_scan.squeeze(), cmap='bone')
    ax.set_title(panel_title)
plt.show()

### 3. Data Preprocessing and Splitting
We scale the tabular data and split both datasets into synchronized training and testing sets.

In [None]:
# Cast images to float32 for the network; pixel values are already in [0, 1],
# so no rescaling is applied.
images_normalized = image_data.astype('float32')

# Split first, so that nothing computed from the test rows can influence the
# preprocessing. Passing all three arrays in one call keeps them aligned.
X_train_tab_raw, X_test_tab_raw, X_train_img, X_test_img, y_train, y_test = train_test_split(
    tabular_data, images_normalized, labels, test_size=0.2, random_state=42
)

# Scale tabular data: fit the min/max on the training rows only, then apply
# the same transform to the test rows. Fitting on the full table would leak
# test-set statistics into the training features.
scaler = MinMaxScaler()
X_train_tab = scaler.fit_transform(X_train_tab_raw)
X_test_tab = scaler.transform(X_test_tab_raw)

# Full table in the train-fitted scale, kept so the name remains available.
tabular_scaled = scaler.transform(tabular_data)

### 4. Build the Multi-Modal Model
We define the two-branch architecture. One branch is a CNN for images, the other is an MLP for tabular data. Their outputs are then concatenated and fed into a final classifier.

In [None]:
# --- Tabular branch: a small MLP over the scaled patient features ---
tabular_input = Input(shape=(X_train_tab.shape[1],), name='tabular_input')
tab_features = tabular_input
for n_units in (16, 8):
    tab_features = layers.Dense(units=n_units, activation='relu')(tab_features)
tabular_output = Model(inputs=tabular_input, outputs=tab_features)

# --- Image branch: two conv/pool stages, then a dense embedding ---
image_input = Input(shape=X_train_img.shape[1:], name='image_input')
img_features = image_input
for n_filters in (16, 32):
    img_features = layers.Conv2D(
        filters=n_filters, kernel_size=(3, 3), activation='relu'
    )(img_features)
    img_features = layers.MaxPooling2D(pool_size=(2, 2))(img_features)
img_features = layers.Flatten()(img_features)
img_features = layers.Dense(units=16, activation='relu')(img_features)
image_output = Model(inputs=image_input, outputs=img_features)

# --- Fusion: join the two branch embeddings into one feature vector ---
combined = layers.concatenate([tabular_output.output, image_output.output])

# --- Classifier head: one hidden layer, then a sigmoid probability ---
head = layers.Dense(units=8, activation='relu')(combined)
head = layers.Dense(units=1, activation='sigmoid')(head)

# Assemble the end-to-end model that takes both inputs.
multimodal_model = Model(inputs=[tabular_input, image_input], outputs=head)

multimodal_model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
multimodal_model.summary()

### 5. Train and Evaluate the Model
We train the model using both sets of inputs and then evaluate its performance on the test set.

In [None]:
# Both modalities are passed as a list, in the order the model's inputs were declared.
train_inputs = [X_train_tab, X_train_img]
test_inputs = [X_test_tab, X_test_img]

print("Training the multi-modal model...")
fit_options = dict(validation_split=0.2, epochs=20, batch_size=32, verbose=1)
history = multimodal_model.fit(train_inputs, y_train, **fit_options)

# Score the trained model on the held-out test set.
loss, accuracy = multimodal_model.evaluate(test_inputs, y_test)
print(f"\nTest Accuracy: {accuracy*100:.2f}%")

# Threshold the predicted probabilities at 0.5 and tabulate the outcomes.
y_pred_prob = multimodal_model.predict(test_inputs)
y_pred = (y_pred_prob > 0.5).astype('int32')
cm = confusion_matrix(y_test, y_pred)

# Draw the confusion matrix as an annotated heatmap.
plt.figure(figsize=(6, 5))
heatmap_ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
heatmap_ax.set_title('Confusion Matrix')
heatmap_ax.set_xlabel('Predicted')
heatmap_ax.set_ylabel('Actual')
plt.show()