## Imports

In [None]:
import re
import os
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

from functools import partial
from kaggle_datasets import KaggleDatasets
from sklearn.model_selection import train_test_split
# Log the TensorFlow version so outputs can be tied to the runtime that produced them.
print("Tensorflow version " + tf.__version__)

## Configuration constants

In [None]:
# --- input pipeline -------------------------------------------------------
# Let tf.data choose parallelism / prefetch depth at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 16
# decode_image *reshapes* (does not resize) every JPEG to [*IMAGE_SIZE, 3],
# so the stored images must already be exactly this size.
IMAGE_SIZE = [512, 512]

# --- labels ---------------------------------------------------------------
# CLASS_NAMES[i] is the readable name of class index i; CLASSES[i] is that
# index as a string ('0' .. '4').
CLASS_NAMES = [
    'Cassava Bacterial Blight',
    'Cassava Brown Streak Disease',
    'Cassava Green Mottle',
    'Cassava Mosaic Disease',
    'Healthy',
]
CLASSES = [str(class_idx) for class_idx in range(len(CLASS_NAMES))]

# --- training -------------------------------------------------------------
# Appears unused by the inference cells in this notebook.
EPOCHS = 7

## Data decoding

In [None]:
def decode_image(image):
    """Decode a JPEG byte string into a float32 RGB tensor with values in [0, 1].

    Args:
        image: scalar string tensor holding JPEG-encoded bytes.

    Returns:
        float32 tensor of shape ``[*IMAGE_SIZE, 3]``. This is a reshape, not a
        resize, so the encoded image must already have exactly that many pixels.
    """
    rgb_pixels = tf.image.decode_jpeg(image, channels = 3)
    # uint8 [0, 255] -> float32 [0, 1]
    unit_range = tf.cast(rgb_pixels, tf.float32) / 255.0
    # Pin a static shape so downstream batching/model input sees a known size.
    target_shape = list(IMAGE_SIZE) + [3]
    return tf.reshape(unit_range, target_shape)

## Parse a single TFRecord example

In [None]:
def read_tfrecord(example, labeled):
    """Parse one serialized TFRecord example into an ``(image, second)`` pair.

    Args:
        example: serialized ``tf.train.Example`` string.
        labeled: if True, the record must carry an int64 ``target`` feature,
            returned as int32; otherwise it must carry a string
            ``image_name`` feature, returned as-is.

    Returns:
        ``(image, label)`` when ``labeled`` else ``(image, image_name)``;
        ``image`` is the float32 tensor produced by ``decode_image``.
    """
    feature_spec = {'image': tf.io.FixedLenFeature([], tf.string)}
    if labeled:
        feature_spec['target'] = tf.io.FixedLenFeature([], tf.int64)
    else:
        feature_spec['image_name'] = tf.io.FixedLenFeature([], tf.string)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    if not labeled:
        return image, parsed['image_name']
    return image, tf.cast(parsed['target'], tf.int32)

## Read the TFRecord dataset

In [None]:
def load_dataset(filenames, labeled = True, ordered = False):
    """Build an unbatched tf.data pipeline of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord paths; they are read in parallel.
        labeled: forwarded to ``read_tfrecord``. Yields ``(image, label)`` when
            True, ``(image, image_name)`` otherwise.
        ordered: when False, tf.data may emit records out of order for
            throughput. Pass True whenever the output has to be lined up with
            another pass over the same data (e.g. predictions vs. image names).

    Returns:
        ``tf.data.Dataset`` of parsed, decoded examples (not batched).
    """
    read_options = tf.data.Options()
    if not ordered:
        # Parallel reads then hand over records as soon as they are ready
        # instead of waiting to preserve file order.
        read_options.experimental_deterministic = False

    parse_fn = partial(read_tfrecord, labeled = labeled)
    return (
        tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTOTUNE)
        .with_options(read_options)
        .map(parse_fn, num_parallel_calls = AUTOTUNE)
    )

## Test data

In [None]:
# All test shard paths. count_data_items expects each name to contain "-<record count>."
# (e.g. "...-1.tfrec") to size the test set without opening the files.
TEST_FILENAMES = tf.io.gfile.glob('../input/cassava-leaf-disease-classification/test_tfrecords/ld_test*.tfrec')

## Number of test examples

In [None]:
def count_data_items(filenames):
    """Total number of records across TFRecord shards, read from their filenames.

    Each shard's basename is expected to contain ``-<count>.``
    (e.g. ``ld_test00-1.tfrec`` -> 1). The counts are summed without opening
    any file.

    Args:
        filenames: iterable of shard paths.

    Returns:
        Summed record count as a Python ``int`` (0 for an empty list, rather
        than the float ``0.0`` that ``np.sum([])`` would give).

    Raises:
        ValueError: if a filename carries no ``-<count>.`` marker.
    """
    # '+' rather than '*' so an empty digit run cannot reach int('').
    count_pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        # Match on the basename only, so a '-<digits>.' inside a directory name
        # cannot be mistaken for the shard's record count.
        match = count_pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(
                f"cannot read record count from TFRecord filename: {filename!r}")
        total += int(match.group(1))
    return total

In [None]:
# Total number of test records, derived from the shard filenames (no file reads); shown as cell output.
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
NUM_TEST_IMAGES

## Data Augmentation

In [None]:
def data_augment(image, label):
    """Randomly mirror ``image`` left-to-right; ``label`` is passed through unchanged.

    Not applied by ``get_test_dataset``; kept for a training pipeline.
    """
    mirrored = tf.image.random_flip_left_right(image)
    return mirrored, label

## Prepare the test dataset

In [None]:
def get_test_dataset(ordered = False):
    """Batched, prefetched ``(image, image_name)`` dataset over ``TEST_FILENAMES``.

    Args:
        ordered: pass True when several passes over the result must line up
            element for element (see ``load_dataset``).
    """
    unbatched = load_dataset(TEST_FILENAMES, labeled = False, ordered = ordered)
    batched = unbatched.batch(BATCH_SIZE)
    return batched.prefetch(AUTOTUNE)

In [None]:
# Sanity check: print image / ID batch shapes for the first 3 test batches.
# The default (unordered) dataset is fine here because only shapes are inspected.
print("Test data shapes:")
for image, idnum in get_test_dataset().take(3):
    print(image.numpy().shape, idnum.numpy().shape)
# `idnum` still holds the last previewed batch after the loop.
# NOTE(review): raises NameError if the test dataset is empty (loop body never runs).
print("Test data IDs:", idnum.numpy().astype('U')) # U = unicode string

In [None]:
# Re-batch the (unordered) test data into groups of 20 and open an iterator over it.
# NOTE(review): `test_batch` does not appear to be used by later cells, and
# `testing_dataset` does not feed the model (predictions are computed from `test_ds`);
# this looks like leftover exploration code. Confirm before removing.
testing_dataset = get_test_dataset()
testing_dataset = testing_dataset.unbatch().batch(20)
test_batch = iter(testing_dataset)

## Load the trained model

In [None]:
# Load the trained Keras model from the notebook's input directory.
# NOTE(review): this path is spelled "casava-..." while the competition data above lives
# under "cassava-..."; confirm it matches the name of the attached model dataset.
model = tf.keras.models.load_model('../input/casava-leaf-disease-classification/model.h5')

In [None]:
def to_float32(image, label):
    """Return ``(image cast to float32, label)``; the second element is untouched.

    In this notebook it is mapped over ``(image, image_name)`` pairs whose images
    ``decode_image`` already produced as float32, so the cast is a no-op there.
    """
    as_float = tf.cast(image, tf.float32)
    return as_float, label

## Make predictions

In [None]:
# `ordered=True` keeps read order deterministic, so this pass over `test_ds` (images)
# and the later pass in the CSV cell (image names) line up row for row.
test_ds = get_test_dataset(ordered = True)
test_ds = test_ds.map(to_float32)

print('Computing predictions...')
# Feed only the images to the model; image names are recovered from `test_ds` later.
test_images_ds = test_ds.map(lambda image, idnum: image)
probabilities = model.predict(test_images_ds)
# Highest-scoring class index per image (presumably in CLASSES order; confirm against training).
predictions = np.argmax(probabilities, axis = -1)

print(predictions)

## Generate the submission CSV

In [None]:
print('Generating submission.csv file...')
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]),
           fmt=['%s', '%d'], delimiter=',', header='image_id,label', comments='')
!head submission.csv