In [None]:
# import necessary packages
!pip install -q tfds-nightly tensorflow matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
# Download the dataset and write its TFRecord files (a no-op if they already exist),
# so the `tfds.load` calls below read from the local cache.
builder = tfds.builder('pneumonia_mnist')
builder.download_and_prepare()

# Build a `tf.data.Dataset` view of the training split and print it to inspect
# the structure of each element.
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)

In [None]:
# Load the train and test splits with one call. Both splits come from the same
# dataset, so they share a single DatasetInfo object.
(ds_train, ds_test), info_train = tfds.load(
    'pneumonia_mnist',
    split=['train', 'test'],
    with_info=True,
)
info_test = info_train

In [None]:
#  Prepare training and testing data
train_images = []
train_labels = []
for example in ds_train.take(10):
    train_images.append(example['image'].numpy())
    train_labels.append(example['label'].numpy())

test_images = []
test_labels = []
for example in ds_test.take(10):
    test_images.append(example['image'].numpy())
    test_labels.append(example['label'].numpy())

In [None]:
train_images = np.array(train_images)
train_labels = np.array(train_labels)
test_images = np.array(test_images)
test_labels = np.array(test_labels)

In [None]:
train_images, test_images = train_images / 255.0, test_images / 255.0

# Plot some training images to verify
class_names = ['Normal', 'Pneumonia']
plt.figure(figsize=(10,10))
for i in range(train_images.shape[0]):
    plt.subplot(5,5,i+1)
    plt.xticks([]), plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()

In [None]:
# Create the CNN network.
# Binary classification: Normal (0), Pneumonia (1). The final layer outputs
# per-class probabilities via softmax, so the compiled loss should treat its
# inputs as probabilities (from_logits=False), not logits.
model = tf.keras.Sequential([
    # tf.keras.Input works on both Keras 2 and Keras 3, whereas
    # InputLayer(input_shape=...) is deprecated in Keras 3 in favour of `shape=`.
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(2, activation='softmax')
])

In [None]:
# Display the full CNN architecture: layers, output shapes and parameter counts.
model.summary()

In [None]:
# Compile the model.
# The final Dense layer already applies softmax, so the loss receives probabilities,
# not logits. from_logits must therefore be False; True would treat the softmax
# output as logits and apply softmax a second time, distorting the loss and gradients.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

In [None]:
# Train the CNN. Keras holds out the last VALIDATION_SPLIT fraction of the training
# arrays (taken before shuffling) for validation. This keeps the test set unseen
# until the final evaluation, instead of using it to monitor training.
EPOCHS = 10
VALIDATION_SPLIT = 0.1
history = model.fit(train_images, train_labels,
                    epochs=EPOCHS,
                    validation_split=VALIDATION_SPLIT)

In [None]:
# Visualise the training and validation loss for each epoch.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='train_loss')
ax.plot(history.history['val_loss'], label='val_loss')
ax.set(title='Training vs validation loss', xlabel='Epoch', ylabel='Loss')
# Pin only the lower bound: a fixed [0, 1] range would hide any loss above 1.
ax.set_ylim(bottom=0)
ax.legend(loc='upper right')
plt.show()

In [None]:
# Score the trained model on the test set. Keras returns the loss followed by
# the compiled metrics, here [loss, accuracy].
loss_and_acc = model.evaluate(test_images, test_labels, verbose=2)
test_loss, test_acc = loss_and_acc
accuracy_pct = test_acc * 100
print(f'Test accuracy = {accuracy_pct:.2f}%')

In [None]:
# Predict class probabilities for every test image and keep the index of the most
# probable class (0 = Normal, 1 = Pneumonia) as the predicted label.
y_pred = model.predict(test_images).argmax(axis=1)

In [None]:
# Calculate the evaluation metrics on the test-set predictions.
from sklearn.metrics import confusion_matrix, classification_report

# Pin the label set so the confusion matrix is always 2x2 (rows = true class,
# columns = predicted class) and the report always lists both classes, even when
# a class is missing from a small test sample or never predicted.
label_ids = [0, 1]

print("Confusion Matrix:")
print(confusion_matrix(test_labels, y_pred, labels=label_ids))

print("\nClassification Report:")
# zero_division=0 reports 0 instead of emitting warnings for classes with no predictions.
print(classification_report(test_labels, y_pred, labels=label_ids,
                            target_names=class_names, zero_division=0))

In [None]:
# Show one raw test image together with its true label and the model's prediction.
sample_idx = 0
raw_image = test_images[sample_idx]
true_label = int(test_labels[sample_idx])

# Keras predicts on batches, so add a leading batch axis of size one.
prediction = model.predict(raw_image[np.newaxis, ...])
predicted_label = np.argmax(prediction, axis=1)[0]

# Map predicted and true labels to class names (class_names is defined in the data-preview cell).
predicted_class = class_names[predicted_label]
true_class = class_names[true_label]

# Create an explicit two-panel figure: the input on the left, the prediction on the right.
fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(6, 3))

# Drop any singleton channel axis so imshow receives a 2-D grayscale image.
ax_in.imshow(raw_image.squeeze(), cmap=plt.cm.binary)
ax_in.set_title(f"Input\nTrue: {true_class}")
ax_in.axis('off')

ax_out.imshow(raw_image.squeeze(), cmap=plt.cm.binary)
ax_out.set_title(f"Output\n{predicted_class}")
ax_out.axis('off')

fig.tight_layout()
plt.show()