In [3]:
import tensorflow as tf
tf.keras.backend.clear_session()





In [4]:
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.efficientnet import EfficientNetB0, preprocess_input

BASE_DIR = r"C:/Users/raksh/x-ai_chest/data/chest_xray_multi"
TRAIN_DIR = os.path.join(BASE_DIR, "train")
VAL_DIR   = os.path.join(BASE_DIR, "val")
TEST_DIR  = os.path.join(BASE_DIR, "test")

IMG_SIZE = (224, 224)
BATCH_SIZE = 32

train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=10,
    width_shift_range=0.05,
    height_shift_range=0.05,
    zoom_range=0.1,
    horizontal_flip=True
)

val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input
)

test_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input
)

train_gen = train_datagen.flow_from_directory(
    TRAIN_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical"
)

val_gen = val_datagen.flow_from_directory(
    VAL_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical"
)

test_gen = test_datagen.flow_from_directory(
    TEST_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    shuffle=False
)

num_classes = train_gen.num_classes
print("Classes:", train_gen.class_indices)


Found 1622 images belonging to 4 classes.
Found 327 images belonging to 4 classes.
Found 326 images belonging to 4 classes.
Classes: {'COVID19': 0, 'NORMAL': 1, 'PNEUMONIA': 2, 'TURBERCULOSIS': 3}


In [5]:
from tensorflow.keras import layers, models

base_model = EfficientNetB0(
    include_top=False,
    weights="imagenet",
    input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3)
)

# Phase 1: use EfficientNet as frozen feature extractor
base_model.trainable = False

inputs = layers.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)

model = models.Model(inputs, outputs)
model.summary()


In [6]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),  # NOTE: 1e-3
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        "backend/saved_models/chest_multidisease_phase1.keras",
        monitor="val_accuracy",
        save_best_only=True,
        mode="max"
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=5,
        restore_best_weights=True
    )
]

history1 = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=15,
    callbacks=callbacks
)


Epoch 1/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m107s[0m 2s/step - accuracy: 0.7965 - loss: 0.5028 - val_accuracy: 0.9174 - val_loss: 0.2099
Epoch 2/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 2s/step - accuracy: 0.9143 - loss: 0.2394 - val_accuracy: 0.9419 - val_loss: 0.1598
Epoch 3/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m104s[0m 2s/step - accuracy: 0.9328 - loss: 0.1899 - val_accuracy: 0.9266 - val_loss: 0.1949
Epoch 4/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 2s/step - accuracy: 0.9383 - loss: 0.1656 - val_accuracy: 0.9786 - val_loss: 0.0895
Epoch 5/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 2s/step - accuracy: 0.9427 - loss: 0.1629 - val_accuracy: 0.9694 - val_loss: 0.0876
Epoch 6/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m108s[0m 2s/step - accuracy: 0.9371 - loss: 0.1615 - val_accuracy: 0.9174 - val_loss: 0.2037
Epoch 7/15
[1m51/51[0m [32m━━━━

In [7]:
# Unfreeze some of the base_model layers
base_model.trainable = True

for layer in base_model.layers[:-50]:   # freeze all but last 50
    layer.trainable = False

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # small LR
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

callbacks_ft = [
    tf.keras.callbacks.ModelCheckpoint(
        "backend/saved_models/chest_multidisease_ft.keras",
        monitor="val_accuracy",
        save_best_only=True,
        mode="max"
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=5,
        restore_best_weights=True
    )
]

history2 = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=15,
    callbacks=callbacks_ft
)


Epoch 1/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m124s[0m 2s/step - accuracy: 0.8724 - loss: 0.3423 - val_accuracy: 0.9755 - val_loss: 0.0830
Epoch 2/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 2s/step - accuracy: 0.9069 - loss: 0.2460 - val_accuracy: 0.9664 - val_loss: 0.1092
Epoch 3/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 2s/step - accuracy: 0.9303 - loss: 0.2083 - val_accuracy: 0.9511 - val_loss: 0.1219
Epoch 4/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 2s/step - accuracy: 0.9205 - loss: 0.1996 - val_accuracy: 0.9602 - val_loss: 0.1217
Epoch 5/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 2s/step - accuracy: 0.9390 - loss: 0.1781 - val_accuracy: 0.9572 - val_loss: 0.1188
Epoch 6/15
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 2s/step - accuracy: 0.9464 - loss: 0.1662 - val_accuracy: 0.9602 - val_loss: 0.1165


In [8]:
test_loss, test_acc = model.evaluate(test_gen)
print("Test accuracy:", test_acc)


[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 1s/step - accuracy: 0.9080 - loss: 0.2529
Test accuracy: 0.907975435256958


In [9]:
model.save("backend/saved_models/chest_multidisease_ft.keras")


In [14]:
import numpy as np

x_batch, y_batch = next(test_gen)
preds = model.predict(x_batch)
pred_labels = np.argmax(preds, axis=1)
true_labels = np.argmax(y_batch, axis=1)

print("Predicted:", pred_labels[:20])
print("True     :", true_labels[:20])
print("Class map:", train_gen.class_indices)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
Predicted: [0 2 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 2 1 1]
True     : [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
Class map: {'COVID19': 0, 'NORMAL': 1, 'PNEUMONIA': 2, 'TURBERCULOSIS': 3}
