In [1]:
import tensorflow as tf
import keras

import numpy as np
import pandas as pd
import os
from PIL import Image

In [2]:
# Side length, in pixels, of the square images handed to the model:
# `_preprocess_fn` resizes every image to (image_size, image_size).
# Not to be confused with IMAGE_SIZE, which describes the stored images.
image_size = 350

In [3]:
# This notebook only runs inference (`model.predict`), so the saved training
# configuration (optimizer, loss, metrics) is not needed. `compile=False`
# skips restoring it, which is faster and avoids load failures when the model
# was trained with custom losses/metrics that are not defined here.
model = tf.keras.models.load_model('../input/effnetb3/effnet_(1).h5', compile=False)

In [4]:
# Directory holding the test TFRecord files; the trailing slash matters
# because file names are appended to it by plain string concatenation.
TEST_PATH = '../input/cassava-leaf-disease-classification/test_tfrecords/'

In [5]:
# Let tf.data pick the parallelism / prefetch buffer sizes at runtime.
AUTO = tf.data.experimental.AUTOTUNE
# Number of examples per batch in the dataset fed to the model.
BATCH_SIZE = 32
# Spatial size the decoded images are reshaped to in `_parse_function`
# (assumes the stored JPEGs really are 512x512 - the reshape fails otherwise).
IMAGE_SIZE = [512, 512]

In [6]:
def _parse_function(proto):
    """Parse one serialized example into an `(image, target, image_id)` triple.

    Args:
        proto: scalar string tensor holding a serialized example with the
            features `image`, `image_name` and, optionally, `target`.

    Returns:
        A tuple of
        - the decoded RGB image as float32 with shape `[*IMAGE_SIZE, 3]`,
          pixel values left in the range [0.0, 255.0];
        - the target one-hot encoded with depth 5 (a record without a
          `target` feature falls back to -1, which one-hot encodes to an
          all-zero vector);
        - the `image_name` string tensor, unchanged.
    """
    # The feature spec has to be written out explicitly: dataset functions are
    # traced in graph mode, and the spec is what gives the parsed tensors
    # their shape and dtype.
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'image_name': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'target': tf.io.FixedLenFeature([], tf.int64, default_value=-1),
    }
    example = tf.io.parse_single_example(proto, feature_spec)

    pixels = tf.image.decode_jpeg(example['image'], channels=3)
    # Cast first (values stay in [0.0, 255.0]), then pin the static shape.
    pixels = tf.reshape(tf.cast(pixels, tf.float32), [*IMAGE_SIZE, 3])

    one_hot_target = tf.one_hot(example['target'], depth=5)
    return pixels, one_hot_target, example['image_name']

In [7]:
def _preprocess_fn(image, label, image_id):
    """Scale, resize and re-label a single parsed example.

    Args:
        image: image tensor; divided by 255.0, so pixel values are presumably
            expected in [0, 255] on input.
        label: 1-D label vector; one extra trailing 0 is appended to it.
        image_id: passed through untouched.

    Returns:
        `(image, label, image_id)` where the image has been rescaled and
        resized to `(image_size, image_size)` and the label is one element
        longer than on input.
    """
    scaled = image / 255.0
    resized = tf.image.resize(scaled, (image_size, image_size))
    # NOTE(review): the extra label slot presumably matches a model head with
    # one more output than there are real classes - confirm against training.
    padded_label = tf.concat([label, [0]], axis=0)
    return resized, padded_label, image_id

In [8]:
def load_dataset(tfrecords_fnames):
    """Build an unbatched dataset of preprocessed examples from TFRecord files.

    Args:
        tfrecords_fnames: TFRecord file name(s) to read.

    Returns:
        A `tf.data.Dataset` whose elements are the `(image, label, image_id)`
        tuples produced by `_parse_function` followed by `_preprocess_fn`.
    """
    return (
        tf.data.TFRecordDataset(tfrecords_fnames, num_parallel_reads=AUTO)
        .map(_parse_function, num_parallel_calls=AUTO)
        .map(_preprocess_fn, num_parallel_calls=AUTO)
    )

In [9]:
def build_valid_ds(valid_fnames):
    """Return a batched, prefetching dataset over the given TFRecord files.

    Args:
        valid_fnames: TFRecord file name(s), forwarded to `load_dataset`.

    Returns:
        The dataset from `load_dataset`, grouped into batches of `BATCH_SIZE`
        (the last batch may be smaller) with prefetching enabled.
    """
    batched = load_dataset(valid_fnames).batch(BATCH_SIZE)
    return batched.prefetch(AUTO)

In [10]:
# Reuse the TEST_PATH constant instead of repeating the literal, so the
# directory that is listed and the directory that is prefixed cannot drift
# apart. `test_path` is kept as a name for any later cell that refers to it.
test_path = TEST_PATH
# os.listdir returns entries in arbitrary order; sorting makes the file order
# (and therefore the row order of the predictions) reproducible across runs.
valid_fnames = [os.path.join(test_path, fname) for fname in sorted(os.listdir(test_path))]

In [11]:
# Batched dataset over the test TFRecords. It is iterated twice below: once
# for the predictions and once for the image ids, so both passes must yield
# the examples in the same order.
test_ds = build_valid_ds(valid_fnames)

In [12]:
# Score every test image, then keep the index of the highest-scoring output
# along the last axis as the predicted class (converted to a NumPy array).
preds = model.predict(test_ds)
labels = tf.argmax(preds, axis=-1).numpy()

In [13]:
# Second pass over the dataset: take only the third tuple element (the image
# ids) from every batch, join the per-batch arrays, and decode bytes -> str.
id_batches = [batch[2].numpy() for batch in test_ds]
names = [raw_id.decode() for raw_id in np.concatenate(id_batches)]

In [14]:
# One row per test image: its id and the predicted class index. Written
# without the DataFrame index so the CSV has exactly these two columns.
submission_columns = {'image_id': names, 'label': labels}
submission_df = pd.DataFrame(submission_columns)
submission_df.to_csv("submission.csv", index=False)

In [15]:
# Quick visual sanity check of the first rows of the submission.
submission_df.head()