In [21]:
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
from tensorflow.keras.applications import Xception
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [22]:
# Dataset location (relative to the working directory) and the class labels.
# NOTE(review): this list is not passed to flow_from_directory, which infers
# class names from the sub-folder names under root_dir — confirm they match.
root_dir = "COVID-19_Radiography_Dataset/"
classes = ["COVID", "Normal", "Lung Opacity", "Viral Pneumonia"]
# Derived from the list so the softmax width cannot drift from the label set.
num_classes = len(classes)

In [23]:
# ImageNet-pretrained Xception without its classification top, built for
# 224x224 inputs with 3 channels.
base_model = Xception(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)

# Freeze every pretrained layer so only layers added later are trainable.
for pretrained_layer in base_model.layers:
    pretrained_layer.trainable = False

In [24]:
# Classification head stacked on the Xception feature maps:
# global average pooling -> 1024-unit ReLU layer -> softmax over num_classes.
pooled_features = GlobalAveragePooling2D()(base_model.output)
hidden = Dense(1024, activation='relu')(pooled_features)
predictions = Dense(num_classes, activation='softmax')(hidden)

# End-to-end model: Xception input through to the class probabilities.
model = Model(inputs=base_model.input, outputs=predictions)

In [25]:
# Training configuration: Adam optimizer, categorical cross-entropy loss,
# and accuracy as the reported metric.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [26]:
# Training pipeline: rescale pixels to [0, 1] and apply random shear, zoom and
# horizontal flips; 25% of the files are reserved for validation.
# NOTE(review): the Keras docs describe `xception.preprocess_input` (scaling to
# [-1, 1]) as the expected input preprocessing for Xception; rescale=1./255 is
# left unchanged here — confirm which scaling is intended.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    validation_split=0.25,
    horizontal_flip=True
)

# Validation pipeline: same rescaling and the same 25% split, but no random
# augmentation, so validation metrics are measured on unmodified images.
# The split is decided from the file listing and the split fraction, so using
# the same validation_split on the same directory keeps the two subsets disjoint.
validation_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.25
)

# NOTE(review): flow_from_directory also picks up images in nested folders
# under each class directory; if root_dir stores anything besides the X-rays
# there (e.g. mask images), those are loaded as samples too — check the
# "Found N images" counts against the expected dataset size.
train_generator = train_datagen.flow_from_directory(
    root_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset="training"
)


validation_generator = validation_datagen.flow_from_directory(
    root_dir,
    target_size=(224, 224),  # must equal the model's 224x224 input size
    batch_size=32,
    class_mode='categorical',
    subset="validation"
)

Found 31748 images belonging to 4 classes.
Found 10582 images belonging to 4 classes.


In [None]:
# Train for 10 epochs, validating after each one.
# len(generator) is the number of batches needed to cover every sample,
# including the final partial batch, so no images are skipped in an epoch and
# validation during fit covers the same images as a later model.evaluate().
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    epochs=10
)

In [None]:
# Write the trained model to an HDF5 file in the working directory.
model_path = "covid19.h5"
model.save(model_path)

In [None]:
# Tabulate the per-epoch metrics recorded by fit(): one column per metric,
# one row per epoch. The last expression displays the first rows.
history_df = pd.DataFrame.from_dict(history.history)
history_df.head()

In [None]:
# Score the trained model on the validation subset. This is the same data that
# fit() used for validation, not an independent test set, so the printed label
# says "Validation". The test_* names are kept so any cell reading them still works.
test_loss, test_acc = model.evaluate(validation_generator)
print('Validation accuracy:', test_acc)

In [None]:
# Training vs. validation accuracy for each epoch.
fig, ax = plt.subplots()
ax.plot(history.history["accuracy"])
ax.plot(history.history["val_accuracy"])
ax.set_xlabel("Epochs")
ax.set_ylabel("Accuracy")
ax.legend(["train", "val"])
plt.show()

In [None]:
# Training vs. validation loss for each epoch.
fig, ax = plt.subplots()
for curve in ("loss", "val_loss"):
    ax.plot(history.history[curve])
ax.set(xlabel="Epochs", ylabel="Loss")
ax.legend(["train", "val"])
plt.show()