Imports + initialize tpu/gpu

In [1]:
import tensorflow as tf
assert float(tf.__version__[:3]) >= 2.3

import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
    GCS_PATH = 'gs://kds-730dfb73e90bcfe94bba623cbc90984df476fa507c3e44b785ea223e/'
except:
    print('WARNING: Not connected to a TPU runtime')
    strategy = tf.distribute.get_strategy()
    GCS_PATH = ''



Import Pneumonia Dataset (from google store for tpu) + set variables (batch, epoch, image size, tune)

In [2]:
filenames = tf.io.gfile.glob(str(GCS_PATH + 'chest_xray/train/*/*'))
filenames.extend(tf.io.gfile.glob(str(GCS_PATH + 'chest_xray/val/*/*')))
train_filenames, val_filenames = train_test_split(filenames, test_size=0.2)

AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
IMG_SIZE = 224
EPOCHS = 15

Functions for processing images

In [3]:
def get_label(file_path):
    # convert the path to a list of path components
    parts = tf.strings.split(file_path, os.path.sep)
    # The second to last is the class-directory
    return parts[-2] == "PNEUMONIA"

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=3)
  # Use `convert_image_dtype` to convert to floats in the [0,1] range.
  img = tf.image.convert_image_dtype(img, tf.float32)
  # resize the image to the desired size.
  return tf.image.resize(img, [IMG_SIZE, IMG_SIZE])

def process_path(file_path):
    label = get_label(file_path)
    # load the raw data from the file as a string
    img = tf.io.read_file(file_path)
    img = decode_img(img)
    return img, label

Use buffered prefetching so we can yield data from disk without having I/O blocking

In [4]:
def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
    # This is a small dataset, only load it once, and keep it in memory.
    # use `.cache(filename)` to cache preprocessing work for datasets that don't
    # fit in memory.
    if cache:
        if isinstance(cache, str):
            ds = ds.cache(cache)
        else:
            ds = ds.cache()

    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
    ds = ds.repeat()
    ds = ds.batch(BATCH_SIZE)
    ds = ds.prefetch(buffer_size=AUTOTUNE)

    return ds

Load + Process train/test data

In [5]:
COUNT_NORMAL = len([filename for filename in train_filenames if "NORMAL" in filename])
COUNT_PNEUMONIA = len([filename for filename in train_filenames if "PNEUMONIA" in filename])
TRAIN_IMG_COUNT = tf.data.experimental.cardinality(tf.data.Dataset.from_tensor_slices(train_filenames)).numpy()
VAL_IMG_COUNT = tf.data.experimental.cardinality(tf.data.Dataset.from_tensor_slices(val_filenames)).numpy()

In [6]:
test_ds = tf.data.Dataset.list_files(str(GCS_PATH + 'chest_xray/test/*/*')).map(process_path, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE)

train_ds = prepare_for_training(tf.data.Dataset.from_tensor_slices(train_filenames).map(process_path, num_parallel_calls=AUTOTUNE))
val_ds = prepare_for_training(tf.data.Dataset.from_tensor_slices(val_filenames).map(process_path, num_parallel_calls=AUTOTUNE))

image_batch, label_batch = next(iter(train_ds))

2022-03-02 20:03:40.057232: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2022-03-02 20:03:42.276237: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


Data was imbalanced, with more images classified as pneumonia than normal. Each normal image will be weighted more to balance the data as the CNN works best when the training data is balanced.

In [7]:
weight_for_0 = (1 / COUNT_NORMAL)*(TRAIN_IMG_COUNT)/2.0 
weight_for_1 = (1 / COUNT_PNEUMONIA)*(TRAIN_IMG_COUNT)/2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

Retrain moblenetv2 model for use with dataset (transfer learning).

Since there are only two possible labels for the image, we will be using the `binary_crossentropy` loss. When we fit the model, identify the class weights. Because we are using a TPU, training will be relatively quick.

In [8]:
with strategy.scope():
    base_model = tf.keras.applications.MobileNetV2(weights='imagenet', input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False)
    base_model.trainable = False
    model = tf.keras.Sequential([
      base_model,
      tf.keras.layers.GlobalAveragePooling2D(),
      tf.keras.layers.Dropout(0.2),
      tf.keras.layers.Dense(1),
    ])
    model.compile(
        optimizer = 'adam',
        loss = 'binary_crossentropy',
        metrics = ['accuracy']
    )
    
history = model.fit(
    train_ds,
    steps_per_epoch=TRAIN_IMG_COUNT // BATCH_SIZE,
    epochs=int((EPOCHS/3)*2),
    validation_data=val_ds,
    validation_steps=VAL_IMG_COUNT // BATCH_SIZE,
    class_weight=class_weight,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [9]:
base_model.trainable = True
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False

model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss='binary_crossentropy',
              metrics=['accuracy'])

history_fine = model.fit(
    train_ds,
    steps_per_epoch=TRAIN_IMG_COUNT // BATCH_SIZE,
    epochs=EPOCHS,
    initial_epoch=history.epoch[-1],
    validation_data=val_ds,
    validation_steps=VAL_IMG_COUNT // BATCH_SIZE,
    class_weight=class_weight
)

Epoch 10/15
Epoch 11/15

KeyboardInterrupt: 

Plot Results

In [None]:
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

plt.figure(figsize=(8, 8), tight_layout=True)
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

plt.figure(figsize=(8, 8), tight_layout=True)
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.plot([int((EPOCHS/3)*2),int((EPOCHS/3)*2)],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

Convert model to TFLite

In [None]:
def representative_data_gen():
    dataset_list = tf.data.Dataset.list_files(GCS_PATH + "chest_xray/train/*/*")
    for i in range(100):
        path = next(iter(dataset_list))
        file_bytes = tf.io.read_file(path)
        img = tf.io.decode_jpeg(file_bytes, channels=3)
        img = tf.cast(img, tf.float32) / 255.0 
        img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
        img = tf.expand_dims(img, 0)
        yield [img]
        
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model_quant = converter.convert()

with open('xray_mobilenetv2_quant_model.tflite', 'wb') as f:
    f.write(tflite_model_quant)
    
interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
output_type = interpreter.get_output_details()[0]['dtype']
print('input: ', input_type)
print('output: ', output_type)

Google's edgeTPU compiler only works on linux, and only on x86_64 at that. Lets spin up a docker container to get that working run `docker build --platform linux/amd64 --tag edgetpu_compiler https://github.com/tomassams/docker-edgetpu-compiler.git` if you haven't already

In [None]:
!docker run -it --platform linux/amd64 --rm -v (pwd):/home/edgetpu edgetpu_compiler edgetpu_compiler xray_mobilenetv2_quant_model.tflite