In this notebook, we convert the TrayDataset to TFRecords and then train a U-Net model on a TPU to predict the segmentation masks.

In [None]:
import glob
import os

import numpy as np
import tensorflow as tf

In [None]:
# Spatial size [height, width] that every image and mask is resized to.
IMAGE_SIZE = [256] * 2
# Base number of training epochs.
EPOCHS = 5
# Global batch size; recomputed from the number of replicas once the
# distribution strategy has been set up.
BATCH_SIZE = 64

In [None]:
def decode_image(file_path, method="bilinear"):
    """Load an image file and return it as a flat float32 NumPy array.

    The file is decoded to 3 channels, scaled from [0, 255] to [0, 1], resized
    to ``IMAGE_SIZE`` and flattened, because a TFRecord ``FloatList`` only
    holds a 1-D list of values.

    Args:
        file_path: Path to a JPEG or PNG file (the inputs here are ``.jpg`` /
            ``.JPG`` photos, the masks are ``.png`` files).
        method: Interpolation passed to ``tf.image.resize``. The default keeps
            the previous bilinear behaviour; ``"nearest"`` avoids blending
            neighbouring label values when resizing masks.

    Returns:
        A 1-D ``numpy.ndarray`` of ``IMAGE_SIZE[0] * IMAGE_SIZE[1] * 3`` floats.
    """
    raw = tf.io.read_file(file_path)
    # decode_image detects the real format (JPEG or PNG) instead of feeding
    # PNG masks to decode_jpeg; expand_animations=False guarantees a 3-D
    # tensor, which tf.image.resize requires.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, IMAGE_SIZE, method=method)
    return image.numpy().flatten()  # TFRecord FloatList only accepts 1-D data

In [None]:
def serialize_example(image, mask):
    """Serialize an (image, mask) pair into a ``tf.train.Example`` byte string.

    Both arguments are expected to be flat (1-D) sequences of floats, such as
    the output of ``decode_image``; each is stored as a ``FloatList`` under the
    keys ``'image'`` and ``'mask'``.
    """
    def _float_feature(values):
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    features = tf.train.Features(feature={
        'image': _float_feature(image),
        'mask': _float_feature(mask),
    })
    return tf.train.Example(features=features).SerializeToString()

In [None]:
# Local output files for the serialized train and test splits.
file_path_train, file_path_test = 'train.tfrecords', 'test.tfrecords'

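Now we create the TFRecord files. To use them with a TPU, they have to be uploaded as a Kaggle dataset, because the TPU reads its input from the dataset's Google Cloud Storage path (see `get_gcs_path` below).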

In [None]:
def create_tfrecord(folder_path, subset, saveFile_path):
    """Write every (image, mask) pair of one dataset split to a TFRecord file.

    Each file in ``<folder_path>/X<subset>/`` is paired with the mask that has
    the same base name and a ``.png`` extension in ``<folder_path>/y<subset>/``.

    Args:
        folder_path: Root directory containing the ``X<subset>`` and
            ``y<subset>`` folders.
        subset: Split suffix, e.g. ``"Train"`` or ``"Test"``.
        saveFile_path: Destination path of the TFRecord file.

    Returns:
        The number of examples written.
    """
    input_dir = os.path.join(folder_path, "X" + subset)
    mask_dir = os.path.join(folder_path, "y" + subset)
    written = 0
    # The context manager flushes and closes the writer even if decoding fails.
    with tf.io.TFRecordWriter(saveFile_path) as writer:
        # glob order is arbitrary; sort it so the record order is reproducible.
        for file_path_input in sorted(glob.glob(os.path.join(input_dir, "*"))):
            # Build the mask path from the base name only, so directory names
            # containing "X<subset>" or ".jpg" are never rewritten by accident.
            stem = os.path.splitext(os.path.basename(file_path_input))[0]
            file_path_output = os.path.join(mask_dir, stem + ".png")
            print(file_path_input, file_path_output)
            image, mask = decode_image(file_path_input), decode_image(file_path_output)
            writer.write(serialize_example(image, mask))
            written += 1
    return written
# Build one TFRecord file per split (Train first, then Test) from the raw folders.
for _subset, _record_path in (("Train", file_path_train), ("Test", file_path_test)):
    create_tfrecord("/kaggle/input/tray-food-segmentation/TrayDataset/TrayDataset/", _subset, _record_path)

From here on, we use the TPU.
The created TFRecords are available on Kaggle as the dataset `traydataset-256x256-tfrecords-image-mask`.

In [None]:
# Detect a TPU; a ValueError from the resolver is treated as "no TPU available".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # tf.distribute.TPUStrategy is the stable API; the experimental alias is
    # deprecated. Fall back to it only on TensorFlow versions that lack it.
    _tpu_strategy_cls = getattr(tf.distribute, "TPUStrategy", None) or tf.distribute.experimental.TPUStrategy
    strategy = _tpu_strategy_cls(tpu)
else:
    strategy = tf.distribute.get_strategy()  # default distribution strategy in Tensorflow. Works on CPU and single GPU.

print("REPLICAS: ", strategy.num_replicas_in_sync)
# The global batch is split evenly across replicas, so scale it by their count.
PER_REPLICA_BATCH_SIZE = 16
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

In [None]:
from kaggle_secrets import UserSecretsClient
from kaggle_datasets import KaggleDatasets

# Hand the notebook's Google Cloud credential to TensorFlow via Kaggle's
# secrets client before resolving the dataset's GCS location.
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

# GCS path of the Kaggle dataset holding the pre-built TFRecords.
_tfrecord_dataset_slug = 'traydataset-256x256-tfrecords-image-mask'
GCS_DS_PATH = KaggleDatasets().get_gcs_path(_tfrecord_dataset_slug)
print(GCS_DS_PATH)

In [None]:
# TFRecord files inside the dataset's GCS bucket.
file_path_train = f"{GCS_DS_PATH}/train.tfrecords"
file_path_test = f"{GCS_DS_PATH}/test.tfrecords"

In [None]:
def extract_fn(data_record):
    """Parse one serialized Example back into an ``(image, mask)`` pair.

    Inverse of ``serialize_example``: both features are flat float lists of
    ``IMAGE_SIZE[0] * IMAGE_SIZE[1] * 3`` values, reshaped here to
    ``IMAGE_SIZE + [3]``.

    Args:
        data_record: A scalar string tensor holding one serialized
            ``tf.train.Example`` (one element of a ``TFRecordDataset``).

    Returns:
        ``(image, mask)``: two float32 tensors of shape ``IMAGE_SIZE + [3]``.
    """
    shape = IMAGE_SIZE + [3]
    # Derive the flat length from IMAGE_SIZE rather than hard-coding 256*256*3,
    # so parsing always agrees with the reshape below.
    flat_length = shape[0] * shape[1] * shape[2]
    features = {
        'image': tf.io.FixedLenFeature([flat_length], tf.float32),
        'mask': tf.io.FixedLenFeature([flat_length], tf.float32),
    }
    # Records are mapped before batching, so each call sees a single Example.
    sample = tf.io.parse_single_example(data_record, features)
    image = tf.reshape(sample['image'], shape)
    mask = tf.reshape(sample['mask'], shape)
    return image, mask

In [None]:
def count(file_path):
    """Count the records in a TFRecord file, print the total and return it.

    Every record is parsed with ``extract_fn``, so a malformed record raises
    here instead of later during training.

    Args:
        file_path: Path (local or GCS) of the TFRecord file.

    Returns:
        The number of records in the file.
    """
    dataset = tf.data.TFRecordDataset([file_path]).map(extract_fn)
    # A distinct local name avoids shadowing this function inside its own body.
    n_records = sum(1 for _ in dataset)
    print("count", n_records)
    return n_records
count(file_path_train)
count(file_path_test)

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
def load_dataset(file_path, batch_size=None):
    """Build a batched, prefetched ``(image, mask)`` pipeline from a TFRecord file.

    Args:
        file_path: Path (local or GCS) of the TFRecord file.
        batch_size: Global batch size; defaults to the module-level
            ``BATCH_SIZE`` as it is at call time.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, masks)`` batches.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    dataset = tf.data.TFRecordDataset([file_path])
    # Parse records in parallel; element order is still preserved.
    dataset = dataset.map(extract_fn, num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(batch_size)
    # Prefetch last so whole batches are prepared while the previous one is consumed.
    return dataset.prefetch(buffer_size=AUTOTUNE)
dataset_train = load_dataset(file_path_train)
dataset_test = load_dataset(file_path_test)

In [None]:
from tensorflow.keras.layers import Activation, Lambda, GlobalAveragePooling2D, concatenate
from tensorflow.keras.layers import UpSampling2D, Conv2D, Dropout, MaxPooling2D, Conv2DTranspose
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
def _conv_block(x, filters, dropout_rate=0.2):
    """Apply two same-padded 3x3 ReLU convolutions with dropout in between."""
    x = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = Dropout(dropout_rate)(x)
    x = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    return x


def get_model(IMG_HEIGHT, IMG_WIDTH):
    """Build and compile a small U-Net that predicts a 3-channel mask.

    Args:
        IMG_HEIGHT: Input height in pixels (must be divisible by 8, because of
            the three 2x2 poolings).
        IMG_WIDTH: Input width in pixels (same constraint).

    Returns:
        A compiled ``Model`` whose single output, named ``'seg'``, has shape
        ``(IMG_HEIGHT, IMG_WIDTH, 3)`` with sigmoid activations, trained with
        binary cross-entropy and tracking accuracy.
    """
    in1 = Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
    # Hook for input augmentation (e.g. RandomFlip / RandomRotation from
    # tf.keras.layers.experimental.preprocessing); currently the identity.
    preprop = in1

    # Encoder: 32 -> 64 -> 128 filters, halving the resolution after each level.
    conv1 = _conv_block(preprop, 32)
    pool1 = MaxPooling2D((2, 2))(conv1)
    conv2 = _conv_block(pool1, 64)
    pool2 = MaxPooling2D((2, 2))(conv2)
    conv3 = _conv_block(pool2, 128)
    pool3 = MaxPooling2D((2, 2))(conv3)

    # Bottleneck at 1/8 resolution.
    conv4 = _conv_block(pool3, 128)

    # Decoder: upsample, concatenate the same-resolution encoder features
    # (skip connection), then convolve.
    up1 = concatenate([UpSampling2D((2, 2))(conv4), conv3], axis=-1)
    conv5 = _conv_block(up1, 64)
    up2 = concatenate([UpSampling2D((2, 2))(conv5), conv2], axis=-1)
    conv6 = _conv_block(up2, 64)
    up3 = concatenate([UpSampling2D((2, 2))(conv6), conv1], axis=-1)
    conv7 = _conv_block(up3, 32)
    # 1x1 convolution to one sigmoid output per mask channel.
    segmentation = Conv2D(3, (1, 1), activation='sigmoid', name='seg')(conv7)

    model = Model(inputs=[in1], outputs=[segmentation])

    losses = {'seg': 'binary_crossentropy'}
    metrics = {'seg': ['acc']}
    model.compile(optimizer="adam", loss=losses, metrics=metrics)

    return model

In [None]:
# Create the model inside the distribution strategy's scope so its variables
# are managed by `strategy` (TPU or default).
with strategy.scope():
    # Match the [height, width] the input pipeline resizes to, instead of
    # repeating the literal 256, 256.
    model = get_model(*IMAGE_SIZE)

In [None]:
# NOTE(review): this trains for EPOCHS + 100 epochs, not EPOCHS -- confirm intended.
# batch_size is not passed: the datasets are already batched by load_dataset, and
# Keras does not use batch_size for tf.data inputs.
history = model.fit(dataset_train, epochs=EPOCHS + 100, validation_data=dataset_test)

In [None]:
# Access the filesystem from the local host ('/job:localhost') so the files end
# up in the notebook's working directory when running in a distributed (TPU) setting.
save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
# Save both as a SavedModel directory and as a single HDF5 file.
for target in ("model", "model.h5"):
    model.save(target, options=save_locally)

In [None]:
import matplotlib.pyplot as plt

metrics = history.history

# Plot every curve explicitly against the epoch axis. Previously acc was passed
# as a bare third positional argument, which matplotlib plots against 0..n-1
# rather than history.epoch.
for key in ('loss', 'acc', 'val_loss', 'val_acc'):
    if key in metrics:
        plt.plot(history.epoch, metrics[key], label=key)
plt.xlabel('epoch')
plt.legend()
plt.savefig("fit-history.png")
plt.show()
plt.close()

In [None]:
import matplotlib.pyplot as plt

# Visual check: for every training record, show the input image, the
# ground-truth mask and the model's prediction side by side.
dataset = tf.data.TFRecordDataset([file_path_train]).map(extract_fn)

# A separate counter name keeps the count() helper defined earlier callable.
n_shown = 0
# Predict a whole batch per call instead of calling model.predict once per image.
for images, masks in dataset.batch(BATCH_SIZE).as_numpy_iterator():
    predictions = model.predict(images)
    for image, mask, predicted in zip(images, masks, predictions):
        fig, axarr = plt.subplots(1, 3)
        for ax, picture, title in zip(axarr, (image, mask, predicted), ("image", "mask", "predicted")):
            ax.imshow(picture)
            ax.set_title(title)
        plt.show()
        # Release the figure so hundreds of previews do not accumulate in memory.
        plt.close(fig)
        n_shown += 1
print("count", n_shown)