In [None]:
import tensorflow as tf
from tensorflow.keras.applications import Xception
from tensorflow.keras.models import Model
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import precision_recall_fscore_support
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import pandas as pd
from keras.callbacks import ModelCheckpoint
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical, Sequence
from keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras import regularizers
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from itertools import cycle
from sklearn.metrics import auc, roc_curve
from sklearn.metrics import RocCurveDisplay
from sklearn.manifold import TSNE

In [None]:
# Load the pre-split image arrays and labels from .npy files
files_path = '/data/'   # put the correct path


def _load_npy(name):
    """Return the array stored in ``<files_path>/<name>.npy``.

    The path is joined after stripping trailing slashes from ``files_path``,
    so a directory given as '/data/' does not produce '/data//name.npy'.
    """
    return np.load(f"{files_path.rstrip('/')}/{name}.npy")


X_train_images = _load_npy('X_train_images')
X_val_images = _load_npy('X_val_images')
X_test_images = _load_npy('X_test_images')

y_train = _load_npy('y_train')
y_val = _load_npy('y_val')
y_test = _load_npy('y_test')

data_splits = {
    'train': (X_train_images, y_train),
    'val': (X_val_images, y_val),
    'test': (X_test_images, y_test),
}

# Check the shapes
for split_name, (split_images, split_labels) in data_splits.items():
    print(f"Shape of the {split_name} images array:", split_images.shape)
for split_name, (split_images, split_labels) in data_splits.items():
    print(f"Shape of the {split_name} labels array:", split_labels.shape)

# Every image must have exactly one label row, otherwise training/evaluation silently misaligns
for split_name, (split_images, split_labels) in data_splits.items():
    assert len(split_images) == len(split_labels), (
        f"{split_name}: {len(split_images)} images vs {len(split_labels)} labels"
    )

In [None]:
num_classes = 5

subtypes = ["Benign", "HER2", "LumA", "LumB", "TN"]
assert len(subtypes) == num_classes, "need exactly one subtype name per class"

# Count the number of training images per class (one label row per image,
# one column per class, so argmax gives the integer class id)
unique_classes, counts = np.unique(np.argmax(y_train, axis=1), return_counts=True)

# Name each bar from its own class id, so names stay aligned with the counts
# even if some class has no training images (fixed xticks would then shift).
df = pd.DataFrame({'Class': [subtypes[c] for c in unique_classes], 'Count': counts})

# Plot the bar chart; explicit per-bar viridis colours avoid seaborn's
# deprecated "palette without hue" usage.
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(df['Class'], df['Count'], color=sns.color_palette('viridis', len(df)))
ax.set_xlabel('Class', fontsize=12)
ax.set_ylabel('Number of Images', fontsize=12)
ax.set_title('Distribution of Images by Class', fontsize=16)
ax.tick_params(axis='x', labelrotation=45)
plt.show()

In [None]:
def create_fusion_model(num_classes=5, input_shape=(224, 224, 3), n_trainable_layers=20,
                        dropout_rate=0.3, freeze_batchnorm=True):
    """Build an ImageNet-pretrained Xception classifier with a dense head.

    Despite the name, the model has a single (image) input branch; nothing
    is fused.

    Parameters
    ----------
    num_classes : int
        Number of units in the softmax output layer.
    input_shape : tuple
        Shape of one input image, (height, width, channels).
    n_trainable_layers : int
        Number of layers at the end of the Xception backbone left trainable;
        all earlier backbone layers are frozen. ``0`` freezes the whole backbone.
    dropout_rate : float
        Dropout rate applied after the first 512-unit dense layer.
    freeze_batchnorm : bool
        If True, BatchNormalization layers inside the trainable tail of the
        backbone are frozen too, so they keep their ImageNet moving statistics
        instead of re-estimating them from small fine-tuning batches (as the
        Keras transfer-learning guide recommends).

    Returns
    -------
    tf.keras.Model
        Uncompiled model mapping images to class probabilities.
    """
    # NOTE(review): ImageNet Xception weights expect inputs scaled with
    # tf.keras.applications.xception.preprocess_input (range [-1, 1]). Nothing
    # here rescales the input; confirm the .npy images are already preprocessed.
    image_input = tf.keras.Input(shape=input_shape)
    base_model = Xception(weights='imagenet', include_top=False, input_tensor=image_input)

    # Freeze the backbone except its last `n_trainable_layers` layers. An explicit
    # count is used because slicing with [:-0] would freeze nothing.
    n_frozen = max(len(base_model.layers) - n_trainable_layers, 0)
    for layer in base_model.layers[:n_frozen]:
        layer.trainable = False
    if freeze_batchnorm:
        for layer in base_model.layers[n_frozen:]:
            if isinstance(layer, tf.keras.layers.BatchNormalization):
                layer.trainable = False

    x = base_model.output
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(512, activation='relu')(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)

    # Classification head
    x = tf.keras.layers.Dense(256, activation='relu')(x)
    x = tf.keras.layers.Dense(128, activation='relu')(x)
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

    return tf.keras.Model(inputs=image_input, outputs=outputs)

In [None]:
def main(epochs=10, batch_size=64, learning_rate=1e-3,
         early_stopping_patience=20, lr_patience=5):
    """Build, train and validate the Xception classifier.

    Reads the notebook-level arrays ``X_train_images``, ``y_train``,
    ``X_val_images`` and ``y_val`` loaded earlier. Labels are one row per
    image and one column per class (the code takes ``argmax(axis=1)``).

    Parameters
    ----------
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Mini-batch size passed to ``model.fit``.
    learning_rate : float
        Initial Adam learning rate (halved by ReduceLROnPlateau).
    early_stopping_patience : int
        Epochs without ``val_loss`` improvement before training stops.
    lr_patience : int
        Epochs without ``val_loss`` improvement before the learning rate is halved.

    Returns
    -------
    model : tf.keras.Model
        The trained model.
    history : tf.keras.callbacks.History
        Per-epoch training/validation metrics from ``model.fit``.
    """
    n_classes = y_train.shape[1]

    # Create and compile model
    model = create_fusion_model(num_classes=n_classes)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=[keras.metrics.CategoricalAccuracy(name="accuracy"),
                 keras.metrics.AUC(name="AUC")],
    )

    if early_stopping_patience >= epochs:
        # EarlyStopping cannot stop a run this short; whether restore_best_weights
        # is still applied at the end of training depends on the Keras version.
        print(f"Note: early_stopping_patience={early_stopping_patience} >= epochs={epochs}; "
              "EarlyStopping will never stop training early.")

    # Callbacks
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=early_stopping_patience,
            restore_best_weights=True
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=lr_patience
        )
    ]

    # Calculate 'balanced' class weights. Key them by the actual class id:
    # dict(enumerate(...)) would shift ids if a class were missing from the
    # training set. Keras expects keys 0..n_classes-1, so absent classes get a
    # neutral weight (they contribute no samples anyway).
    train_labels = np.argmax(y_train, axis=1)
    present_classes = np.unique(train_labels)
    balanced_weights = compute_class_weight(
        class_weight='balanced',
        classes=present_classes,
        y=train_labels
    )
    class_weight_dict = {c: 1.0 for c in range(n_classes)}
    class_weight_dict.update(
        {int(c): float(w) for c, w in zip(present_classes, balanced_weights)}
    )

    # Train the model
    history = model.fit(
        X_train_images,
        y_train,
        validation_data=(X_val_images, y_val),
        epochs=epochs,
        batch_size=batch_size,
        class_weight=class_weight_dict,
        callbacks=callbacks,
        verbose=1
    )

    # Evaluate the model; look metrics up by name rather than by output order
    val_results = model.evaluate(X_val_images, y_val, verbose=1, return_dict=True)

    print(f"\nValidation Accuracy: {val_results['accuracy']:.4f}")
    print(f"Validation AUC: {val_results['AUC']:.4f}")

    return model, history

In [None]:
# Train and validate once; the evaluation and plotting cells below read `model` and `history`.
model, history = main()  

In [None]:
# Plot training and validation curves for each tracked metric
def _plot_train_val(ax, hist, key, title, ylabel):
    """Plot ``hist[key]`` (train) and ``hist['val_' + key]`` (validation) per epoch on ``ax``."""
    ax.plot(hist[key], label='Train')
    ax.plot(hist[f'val_{key}'], label='Validation')
    ax.set(title=title, xlabel='Epoch', ylabel=ylabel)
    ax.legend(loc='best')


history_panels = [
    ('accuracy', 'Model accuracy', 'Accuracy'),
    ('loss', 'Model loss', 'Loss'),
    # 'val_AUC' is computed on the validation set, not the test set
    ('AUC', 'Model AUC', 'AUC'),
]

fig, axes = plt.subplots(1, len(history_panels), figsize=(18, 4))
for ax, (key, title, ylabel) in zip(axes, history_panels):
    _plot_train_val(ax, history.history, key, title, ylabel)

fig.tight_layout()
plt.show()

In [None]:
# Predicted class probabilities for the held-out test set
y_pred = model.predict(X_test_images, verbose=0)

# Look metrics up by name rather than relying on their output order
test_results = model.evaluate(X_test_images, y_test, verbose=1, return_dict=True)
test_loss, test_acc, test_auc = test_results['loss'], test_results['accuracy'], test_results['AUC']

print(f"\nTest Accuracy: {test_acc:.4f}")
print(f"\nTest AUC: {test_auc:.4f}")

predictions = y_pred
truth = y_test
target_names = subtypes

# Integer class ids, computed once and reused below
true_idx = np.argmax(truth, axis=1)
pred_idx = np.argmax(predictions, axis=1)
class_ids = np.arange(len(target_names))

print(classification_report(true_idx, pred_idx, labels=class_ids, target_names=target_names))
print('*****************')

# Confusion matrix over all classes (rows = actual, columns = predicted). Passing
# `labels` keeps it square with one row/column per entry of target_names even if
# a class never appears in the test set.
cmtx = confusion_matrix(true_idx, pred_idx, labels=class_ids)

# One-vs-rest sensitivity (TP / (TP + FN)) and specificity (TN / (TN + FP)) per class.
# The previous version stored recall of the *negative* label under 'sensitivity'
# and recall of the positive label under 'specificity', i.e. the columns were swapped.
tp = np.diag(cmtx)
fn = cmtx.sum(axis=1) - tp
fp = cmtx.sum(axis=0) - tp
tn = cmtx.sum() - tp - fn - fp
with np.errstate(divide='ignore', invalid='ignore'):
    sensitivity = tp / (tp + fn)   # NaN if the class has no test samples
    specificity = tn / (tn + fp)   # NaN if every test sample is of this class

sens_spec = pd.DataFrame({'class': target_names,
                          'sensitivity': sensitivity,
                          'specificity': specificity})

print(sens_spec)
print('*****************')

print(cmtx)

def plot_conf(cm, label: str = "", figsize=(7, 4), class_names=None):
    """Draw a confusion matrix as an annotated heatmap.

    Parameters
    ----------
    cm : array-like, shape (n_classes, n_classes)
        Confusion matrix to draw (rows = actual, columns = predicted, as
        returned by scikit-learn's ``confusion_matrix``).
    label : str, optional
        Appended to the title when non-empty. For backward compatibility, a
        non-string sequence passed here is treated as ``class_names``.
    figsize : tuple, optional
        Figure size in inches.
    class_names : sequence of str, optional
        Tick labels in class-index order (not alphabetical). Defaults to the
        notebook-level ``target_names``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the heatmap.
    """
    if class_names is None:
        if label is not None and not isinstance(label, str):
            class_names, label = list(label), ""
        else:
            class_names = target_names

    cm = np.asarray(cm)
    # Integer counts would otherwise use seaborn's default '.2g' (e.g. 123 -> '1.2e+02')
    fmt = "d" if np.issubdtype(cm.dtype, np.integer) else ".2f"

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(cm, annot=True, fmt=fmt, linewidths=.5, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)

    title = 'Confusion Matrix with labels'
    if label:
        title += f' ({label})'
    ax.set_title(title + '\n\n')
    ax.set_xlabel('\nPredicted Values')
    ax.set_ylabel('Actual Values ')
    return ax


plot_conf(cmtx, class_names=target_names)
plt.show()

In [None]:
# One-vs-rest ROC analysis: per-class curves plus micro- and macro-averages.
# fpr/tpr/roc_auc are keyed by class index or by "micro"/"macro".
fpr, tpr, roc_auc = dict(), dict(), dict()

# Micro-average: pool every (sample, class) score into one binary problem
fpr["micro"], tpr["micro"], _ = roc_curve(truth.ravel(), predictions.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

for i in range(num_classes):
    fpr[i], tpr[i], _ = roc_curve(truth[:, i], predictions[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Macro-average: mean of the per-class TPRs, linearly interpolated on a common FPR grid
fpr_grid = np.linspace(0.0, 1.0, 1000)
mean_tpr = np.mean(
    [np.interp(fpr_grid, fpr[i], tpr[i]) for i in range(num_classes)], axis=0
)

fpr["macro"] = fpr_grid
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

fig, ax = plt.subplots(figsize=(8, 8))

# Show the averaged curves computed above next to the per-class curves
ax.plot(fpr["micro"], tpr["micro"],
        label=f"micro-average ROC curve (AUC = {roc_auc['micro']:.2f})",
        color="deeppink", linestyle=":", linewidth=4)
ax.plot(fpr["macro"], tpr["macro"],
        label=f"macro-average ROC curve (AUC = {roc_auc['macro']:.2f})",
        color="navy", linestyle=":", linewidth=4)

colors = cycle(["aqua", "darkorange", "cornflowerblue", "red", "green", "yellow"])
for class_id, color in zip(range(num_classes), colors):
    RocCurveDisplay.from_predictions(
        truth[:, class_id],
        predictions[:, class_id],
        name=f"ROC curve for {target_names[class_id]}",
        color=color,
        ax=ax,
        # Draw the chance diagonal exactly once, on the last class
        plot_chance_level=(class_id == num_classes - 1),
    )

_ = ax.set(
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title="Extension of Receiver Operating Characteristic\nto One-vs-Rest multiclass",
)

In [None]:
# Take a random subset of up to `samples_per_class` test images per class (used for t-SNE)
samples_per_class = 20
SUBSET_SEED = 42  # fixed so the subset, and hence the t-SNE plot, is reproducible
subset_rng = np.random.default_rng(SUBSET_SEED)

y_test_enc = np.argmax(y_test, axis=1)

selected_indices = []
for class_label in range(y_test.shape[1]):
    # Indices of test samples belonging to the current class
    class_indices = np.flatnonzero(y_test_enc == class_label)

    # Randomly select `samples_per_class` of them without replacement
    if len(class_indices) >= samples_per_class:
        selected_class_indices = subset_rng.choice(class_indices, samples_per_class, replace=False)
    else:
        print(f"Warning: Class {class_label} has fewer than {samples_per_class} samples. Using all available samples.")
        selected_class_indices = class_indices

    selected_indices.extend(selected_class_indices)

# Extract the corresponding images and integer labels
new_images = X_test_images[selected_indices]
new_labels = y_test_enc[selected_indices]

# Print shapes to verify
print("New images shape:", new_images.shape)
print("New labels shape:", new_labels.shape)

In [None]:
# Extract one layer's activations for the sampled test images and embed them with t-SNE.
# The layer is looked up by position instead of its auto-generated name "dense_3",
# which only exists if this model was the first one built in the kernel session.
# -1 is the softmax output layer (what "dense_3" referred to); -2 would select the
# 128-unit dense layer feeding it.
FEATURE_LAYER_INDEX = -1
feature_layer = model.layers[FEATURE_LAYER_INDEX]
features_extractor = tf.keras.Model(inputs=model.input, outputs=feature_layer.output)

features = features_extractor.predict(new_images, verbose=0)

# Perform t-SNE dimensionality reduction. scikit-learn requires
# perplexity < n_samples, so cap it for small subsets; seed for reproducibility.
tsne_perplexity = min(50, len(features) - 1)
tsne = TSNE(n_components=2, perplexity=tsne_perplexity, random_state=42).fit_transform(features)

def scale_to_01_range(x):
    """Min-max rescale an array to the range [0, 1].

    Parameters
    ----------
    x : array-like
        Values to rescale (lists are accepted and converted with ``np.asarray``).

    Returns
    -------
    numpy.ndarray
        ``(x - min(x)) / (max(x) - min(x))``. A constant input (zero range)
        maps to all zeros instead of NaN; an empty input returns an empty
        float array.
    """
    x = np.asarray(x)
    if x.size == 0:
        return x.astype(float)
    x_min = np.min(x)
    value_range = np.max(x) - x_min
    if value_range == 0:
        return np.zeros_like(x, dtype=float)
    return (x - x_min) / value_range


# Extract x and y t-SNE coordinates, each min-max scaled to [0, 1]
tx = scale_to_01_range(tsne[:, 0])
ty = scale_to_01_range(tsne[:, 1])

# Integer class labels present in the sampled subset
unique_labels = np.unique(new_labels)

# Fixed colour per class; classes beyond the list fall back to matplotlib's colour cycle
colors_dict = {label: color for label, color in zip(unique_labels, ["red", "green", "blue", "purple", "orange"])}

fig, ax = plt.subplots()

for label in unique_labels:
    class_mask = new_labels == label
    ax.scatter(tx[class_mask], ty[class_mask],
               c=colors_dict.get(label), marker='.', label=target_names[label])

ax.set(title='t-SNE projection of sampled test images',
       xlabel='t-SNE dimension 1 (scaled)',
       ylabel='t-SNE dimension 2 (scaled)')
ax.legend(loc='best')
plt.show()

# The end! 