In [None]:
import os
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import pandas as pd

In [None]:
!cp -R ../input/cassava-layer/ /kaggle/working/cassava-layer/

In [None]:
# Point TF Hub's module cache at the copied directory. This must run before
# hub.KerasLayer is constructed so the module is looked up in this cache.
os.environ.update(TFHUB_CACHE_DIR="/kaggle/working/cassava-layer/")

In [None]:
# CropNet cassava-disease classifier from TF Hub, wrapped as a Keras layer and
# fronted by a fixed 224x224 RGB input (the size _preprocess_fn resizes to).
cassava = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(224, 224, 3)))
model.add(cassava)

In [None]:
# Restore previously saved weights into the Sequential(Input, KerasLayer) model
# built above; the file must have been saved from the same layer layout.
model.load_weights(filepath="../input/cassava-model/cassava_model.h5")

In [None]:
# Let the tf.data runtime choose parallelism / prefetch depth.
AUTO = tf.data.experimental.AUTOTUNE
# Number of images per inference batch.
BATCH_SIZE = 32
# Spatial size [height, width] of the JPEGs stored in the TFRecords;
# _parse_function reshapes each decoded image to [*IMAGE_SIZE, 3].
IMAGE_SIZE = [512, 512]

In [None]:
def _parse_function(proto):
    """Decode one serialized tf.train.Example from the Cassava TFRecords.

    Args:
        proto: scalar string tensor holding a serialized Example.

    Returns:
        A tuple ``(image, target, image_id)``:
          image    -- float32 tensor of shape [*IMAGE_SIZE, 3], values in [0.0, 255.0]
          target   -- float32 one-hot vector of length 5; all zeros when the
                      record has no 'target' (its default of -1 is out of range)
          image_id -- scalar string tensor taken from the record's 'image_name'
    """
    # The schema must be spelled out explicitly: tf.data traces this function
    # in graph mode and uses it to derive the dataset's shape/type signature.
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'image_name': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'target': tf.io.FixedLenFeature([], tf.int64, default_value=-1),
    }
    example = tf.io.parse_single_example(proto, schema)

    pixels = tf.cast(tf.image.decode_jpeg(example['image'], channels=3), tf.float32)
    # Pin a static shape; errors if the decoded image does not hold exactly
    # IMAGE_SIZE[0] * IMAGE_SIZE[1] * 3 values.
    pixels = tf.reshape(pixels, [*IMAGE_SIZE, 3])

    one_hot_target = tf.one_hot(example['target'], depth=5)
    return pixels, one_hot_target, example['image_name']

In [None]:
def _preprocess_fn(image, label, image_id):
    """Adapt a parsed example to the classifier's input/label layout.

    Scales pixels from [0, 255] to [0, 1] and resizes to 224x224. Appends one
    zero to the length-5 one-hot label, giving a length-6 vector (presumably to
    match the 6-way output of the TF Hub CropNet head). image_id is passed
    through unchanged.
    """
    scaled = tf.image.resize(image / 255.0, (224, 224))
    # Extra, always-zero slot for the model's 6th output class.
    padded_label = tf.concat([label, [0]], axis=0)
    return scaled, padded_label, image_id

In [None]:
def load_dataset(tfrecords_fnames):
    """Build an unbatched dataset of (image, padded_label, image_id) triples.

    Records are read from ``tfrecords_fnames`` with parallel reads. They are
    parsed by _parse_function and then preprocessed by _preprocess_fn, both
    mapped with autotuned parallelism. Element order is deterministic under
    tf.data's default options.
    """
    return (
        tf.data.TFRecordDataset(tfrecords_fnames, num_parallel_reads=AUTO)
        .map(_parse_function, num_parallel_calls=AUTO)
        .map(_preprocess_fn, num_parallel_calls=AUTO)
    )

In [None]:
def build_valid_ds(valid_fnames):
    """Return the evaluation pipeline over ``valid_fnames``.

    This is load_dataset(...) batched by BATCH_SIZE, with the final partial
    batch kept, and followed by autotuned prefetching.
    """
    batched = load_dataset(valid_fnames).batch(BATCH_SIZE)
    return batched.prefetch(AUTO)

In [None]:
# Test TFRecord shards. os.listdir returns entries in arbitrary order, so sort
# them to make the shard order (and therefore submission row order) reproducible.
TEST_PATH = '../input/cassava-leaf-disease-classification/test_tfrecords/'
valid_fnames = [os.path.join(TEST_PATH, fname) for fname in sorted(os.listdir(TEST_PATH))]
# Fail here with a clear message rather than later inside model.predict.
assert valid_fnames, f"no files found in {TEST_PATH}"

In [None]:
# Batched pipeline of (image, padded_label, image_id) over the test shards.
test_ds = build_valid_ds(valid_fnames)

In [None]:
# The classifier emits 6 scores per image: _preprocess_fn pads labels from 5
# to 6 to match it (per the TF Hub model page the 6th class is "unknown").
# Label 5 is not a valid competition label, so take the argmax over the first
# 5 columns only.
N_COMPETITION_CLASSES = 5
# Feed only the images to the model; label / image_id are not model inputs.
preds = model.predict(test_ds.map(lambda image, label, image_id: image))
labels = np.argmax(preds[:, :N_COMPETITION_CLASSES], axis=-1)

In [None]:
# Rebuild the test pipeline before collecting image ids.
# NOTE(review): tf.data datasets can be iterated more than once, so this rebuild
# is likely redundant. The id pass below assumes elements arrive in the same
# order as in the prediction pass (tf.data's default deterministic ordering);
# re-check if the pipeline's determinism options are ever changed.
test_ds = build_valid_ds(valid_fnames)

In [None]:
# Collect the image ids batch by batch in pipeline order, flatten them, and
# decode each from bytes to str.
id_batches = [image_ids.numpy() for _, _, image_ids in test_ds]
names = [raw_id.decode() for raw_id in np.concatenate(id_batches)]

In [None]:
# Validate the submission before writing it: one label per image id, no
# duplicate ids, and every label within the 5 competition classes (0-4).
assert len(names) == len(labels), f"{len(names)} image ids vs {len(labels)} predictions"
submission_df = pd.DataFrame({'image_id':names, 'label':labels})
assert submission_df['image_id'].is_unique, "duplicate image_id in submission"
assert submission_df['label'].between(0, 4).all(), "label outside the 5 competition classes"
submission_df.to_csv("submission.csv", index=False)

In [None]:
# Preview the first rows of the submission frame.
submission_df.head()