In [1]:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from functools import partial
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense, Conv2D, MaxPool2D
from datetime import datetime
from keras.preprocessing import image
AUTOTUNE = tf.data.experimental.AUTOTUNE

Using TensorFlow backend.


In [2]:
# Working directory; to the tfrecord files
cwd = os.getcwd()
tfrecord_files_dir = (cwd + '/tfrecords')

In [3]:
# Prepare the lists of train and test tfrecords files
tfrecord_files = os.listdir(tfrecord_files_dir)
full_train_tfrecords = []
test_tfrecords = []
for i in tfrecord_files:
    if i[:4] == 'trai':
        full_train_tfrecords.append(tfrecord_files_dir + '/' + i)
    elif i[:4] == 'test':
        test_tfrecords.append(tfrecord_files_dir + '/' + i)

In [4]:
input_shape = [300,300,3]

In [5]:
# Define functions to create train and validation datasets

def preprocess(tfrecord):
    train_feature_descriptions = {
        "image": tf.io.VarLenFeature(tf.string),
        "image_name": tf.io.FixedLenFeature([], tf.string, default_value=""),
        "target": tf.io.FixedLenFeature([], tf.int64)
    }
    example = tf.io.parse_single_example(tfrecord, train_feature_descriptions)
    image = tf.io.decode_image(example["image"].values[0])
    image = tf.reshape(image, shape=[1024, 1024, 3])
    image = tf.image.resize(image, input_shape[:2], method='nearest')
    return image, example["target"]

def create_dataset(filepaths, batch_size=16):
    full_dataset = tf.data.TFRecordDataset(filepaths)
    
    train_size = int(0.8 * len(list(full_dataset)))
    valid_size = int(0.2 * len(list(full_dataset)))
    
    full_dataset = full_dataset.shuffle(len(list(full_dataset)))
    full_dataset = full_dataset.map(preprocess)
    
    train_dataset = full_dataset.take(train_size)
    valid_dataset = full_dataset.skip(valid_size)

    return train_dataset, valid_dataset, train_size, valid_size

In [6]:
train_set, valid_set, train_size, valid_size = create_dataset(full_train_tfrecords)

In [7]:
train_set

<TakeDataset shapes: ((300, 300, 3), ()), types: (tf.uint8, tf.int64)>

In [8]:
valid_set

<SkipDataset shapes: ((300, 300, 3), ()), types: (tf.uint8, tf.int64)>

In [9]:
train_size

26500

In [10]:
valid_size

6625

In [11]:
def convert(image, label):
    image = tf.image.convert_image_dtype(image, tf.float32) # Cast and normalize the image to [0,1]
    return image, label

def augment(image,label):
    image,label = convert(image, label)
    image = tf.image.resize_with_crop_or_pad(image, 1244, 1244) # Add 220 pixels of padding
    image = tf.image.random_crop(image, size=[1024, 1024, 3]) # Random crop back to 1024 x 1024
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.7) # Random brightness
    image = tf.image.random_contrast(image, 0.2, 0.7)
    image = tf.image.random_saturation(image, 0.2, 0.7)
    image = tf.image.resize(image, input_shape[:2], method='nearest')
    return image,label

In [12]:
BATCH_SIZE = 64

In [13]:
augmented_train_batches = (
    train_set.shuffle(train_size)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
) 

In [14]:
augmented_train_batches

<PrefetchDataset shapes: ((None, 300, 300, 3), (None,)), types: (tf.float32, tf.int64)>

In [15]:
validation_batches = (valid_set.map(convert, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE))

In [16]:
validation_batches

<BatchDataset shapes: ((None, 300, 300, 3), (None,)), types: (tf.float32, tf.int64)>

In [17]:
DefaultConv2D = partial(keras.layers.Conv2D, kernel_size=3, activation='relu', padding="SAME")

model = keras.models.Sequential([
    DefaultConv2D(filters=64, kernel_size=3, input_shape=input_shape),
    keras.layers.MaxPooling2D(pool_size=2),
    DefaultConv2D(filters=128),
    DefaultConv2D(filters=128),
    keras.layers.MaxPooling2D(pool_size=2),
    DefaultConv2D(filters=256),
    DefaultConv2D(filters=256),
    keras.layers.MaxPooling2D(pool_size=2),
    keras.layers.Flatten(),
    keras.layers.Dense(units=128, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(units=64, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(units=1, activation='sigmoid'),
])

model.compile(
    optimizer=keras.optimizers.SGD(lr=0.0001, momentum=0.9, nesterov=True),
    loss='binary_crossentropy',
    metrics=[tf.keras.metrics.AUC()])

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 300, 300, 64)      1792      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 150, 150, 64)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 150, 150, 128)     73856     
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 150, 150, 128)     147584    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 75, 75, 128)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 75, 75, 256)       295168    
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 75, 75, 256)       5

In [18]:
logs = os.path.join(os.curdir, "my_logs", "run_" + datetime.now().strftime("%Y%m%d_%H%M%S"))
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=logs, histogram_freq=1, profile_batch=10)
early_stopping_cb = keras.callbacks.EarlyStopping(patience=10)
checkpoint_cb = keras.callbacks.ModelCheckpoint("my_model_tfrecords_data_aug.h5", save_best_only=True)

In [19]:
history = model.fit(
    augmented_train_batches, 
    epochs=80, 
    validation_data=validation_batches, 
    callbacks=[early_stopping_cb, checkpoint_cb, tensorboard_cb])

Epoch 1/80
Epoch 2/80
Epoch 3/80
Epoch 4/80
Epoch 5/80
Epoch 6/80
Epoch 7/80
Epoch 8/80
Epoch 9/80
Epoch 10/80
Epoch 11/80
Epoch 12/80
Epoch 13/80
Epoch 14/80
Epoch 15/80
Epoch 16/80
Epoch 17/80
Epoch 18/80
Epoch 19/80
Epoch 20/80
Epoch 21/80
Epoch 22/80


In [24]:
# Define functions to create test dataset

def get_test_images(tfrecord):
    test_feature_descriptions = {
        "image": tf.io.VarLenFeature(tf.string),
        "image_name": tf.io.FixedLenFeature([], tf.string, default_value=""),
    }
    example = tf.io.parse_single_example(tfrecord, test_feature_descriptions)
    image = tf.io.decode_image(example["image"].values[0])
    image = tf.reshape(image, shape=[1024, 1024, 3])
    image = tf.image.resize(image, input_shape[:2], method='nearest')
    return image/255, example["image_name"]

def create_test_dataset(filepaths, n_read_threads=5, n_parse_threads=5, batch_size=1):
    dataset = tf.data.TFRecordDataset(filepaths, num_parallel_reads=n_read_threads)
    dataset = dataset.map(get_test_images, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(1)

In [25]:
test_set = create_test_dataset(test_tfrecords)

In [26]:
test_set

<PrefetchDataset shapes: ((None, 300, 300, 3), (None,)), types: (tf.float32, tf.string)>

In [27]:
test_images = []
predictions = []

def get_predictions(tfrecords):
    
    for item in tfrecords:
        img_no = item[1].numpy()[0]
        test_images.append(img_no.decode('utf-8'))
        img = item[0] 
        pred = model.predict_classes(img)
        predictions.append(pred[0][0])
        
    predictions_d = pd.DataFrame(list(zip(test_images, predictions)), columns =['image_name', 'target'])
    predictions_df = predictions_d.sort_values(by=['image_name'])
    
    predictions_df.to_csv(os.path.join(cwd, 'predictions.csv'), index = False, header=True)
    
    return predictions_df

In [28]:
predictions_tfrecords = get_predictions(test_set)

In [29]:
np.array(predictions_tfrecords['target']).sum()

10982