This notebook reuses parts of the code from Martin Goerner's notebook [Getting started with 100+ flowers on TPU](https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu).

In [None]:
# Install the `efficientnet` package (imported below as `efn`); -q keeps pip's output quiet.
!pip install -q efficientnet

In [None]:
import numpy as np
import pandas as pd
import re
from tqdm.auto import tqdm
from kaggle_datasets import KaggleDatasets
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow_addons as tfa
import efficientnet.tfkeras as efn

In [None]:
# Sentinel passed to tf.data calls (parallel reads, map parallelism, prefetch)
# so that tf.data tunes those values itself at runtime.
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
class CFG:
    """Static settings shared by the model definition and the input pipeline."""

    # Number of units in the classifier's output layer.
    N_CLASSES = 64_500
    # Number of examples per batch produced by the datasets.
    BATCH_SIZE = 16
    # Spatial size of the images; decoded images are reshaped to [*IMAGE_SIZE, 3].
    IMAGE_SIZE = [256] * 2

In [None]:
# Test TFRecord shards. The order in which glob returns matches is not
# documented, so sort them to keep the file order (and therefore ordered reads)
# stable between runs.
TEST_FILENAMES = sorted(tf.io.gfile.glob('../input/herb2021-test-256/*.tfrec'))
# Fail early with a readable message instead of much later with an empty dataset.
assert TEST_FILENAMES, 'No test TFRecord files found in ../input/herb2021-test-256/'

In [None]:
def get_model():
    """Build and compile the EfficientNetB0-based classifier.

    The backbone is created without pretrained weights and without its top,
    using global average pooling, and is followed by a single sigmoid-activated
    dense layer with ``CFG.N_CLASSES`` units.

    Returns:
        The compiled ``tf.keras.Sequential`` model (Adam optimizer, binary
        cross-entropy loss, weighted F1 metric).
    """
    input_shape = (*CFG.IMAGE_SIZE, 3)
    backbone = efn.EfficientNetB0(
        weights=None,
        include_top=False,
        pooling='avg',
        input_shape=input_shape,
    )
    classifier_head = L.Dense(CFG.N_CLASSES, activation='sigmoid')

    model = tf.keras.Sequential([backbone, classifier_head])

    weighted_f1 = tfa.metrics.F1Score(CFG.N_CLASSES, average='weighted')
    model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=[weighted_f1],
    )
    return model

In [None]:
# Rebuild the architecture with get_model(), then restore the weights stored
# in the best.h5 file.
model = get_model()
model.load_weights('../input/herb2021-effnet/best.h5')

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a float32 RGB image scaled to the [0, 1] range.

    Args:
        image_data: Scalar string tensor containing JPEG-encoded bytes.

    Returns:
        A float32 tensor reshaped to ``[*CFG.IMAGE_SIZE, 3]``.
    """
    pixels = tf.image.decode_jpeg(image_data, channels=3)  # uint8 values in [0, 255]
    scaled = tf.cast(pixels, tf.float32) / 255.0  # floats in [0, 1]
    # The reshape pins the static shape expected by the model.
    return tf.reshape(scaled, [*CFG.IMAGE_SIZE, 3])

def get_idx(image, idnum):
    """Turn the stored image path into a numeric image id.

    Args:
        image: Image tensor, passed through unchanged.
        idnum: Scalar string tensor holding a '/'-separated path whose
            component at index 6 is a file name of the form ``<number>.jpg``.

    Returns:
        Tuple ``(image, id)`` where ``id`` is a ``tf.int64`` scalar parsed from
        the file name.
    """
    # NOTE(review): index 6 hard-codes the directory depth of the stored paths;
    # confirm it against how the TFRecords were written.
    file_name = tf.strings.split(idnum, sep='/')[6]
    # The pattern is a regular expression: escape the dot and anchor it at the
    # end so only a literal ".jpg" extension is removed (an unescaped "." would
    # match any character followed by "jpg", anywhere in the name).
    stem = tf.strings.regex_replace(file_name, r"\.jpg$", "")
    return image, tf.strings.to_number(stem, out_type=tf.int64)
    
def data_augment(image, label):
    """Randomly mirror the image left/right; the second element is passed through.

    Args:
        image: Image tensor to augment.
        label: Value paired with the image; returned unchanged.

    Returns:
        Tuple ``(possibly flipped image, label)``.
    """
    return tf.image.random_flip_left_right(image), label

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into ``(image, idnum)``.

    Args:
        example: Serialized ``tf.train.Example`` with string features
            ``'image'`` and ``'image_idx'``.

    Returns:
        Tuple of the decoded image (see ``decode_image``) and the raw
        ``'image_idx'`` string tensor.
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'image_idx': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['image_idx']

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into ``(image, label)``.

    Args:
        example: Serialized ``tf.train.Example`` with string features
            ``'image'`` and ``'image_idx'`` and an int64 feature ``'label'``.

    Returns:
        Tuple of the decoded image (see ``decode_image``) and the int64 label.
        The ``'image_idx'`` feature is parsed but not returned.
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'image_idx': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['label']

def load_dataset(filenames, labeled=True, ordered=False):
    """Create a parsed dataset from TFRecord files.

    Several files are read at once, which interleaves their records.

    Args:
        filenames: TFRecord file paths to read.
        labeled: When true, parse with ``read_labeled_tfrecord``; otherwise
            with ``read_unlabeled_tfrecord``.
        ordered: When false, deterministic ordering is switched off so records
            are used as soon as they stream in (faster).

    Returns:
        An unbatched ``tf.data.Dataset`` of parsed examples.
    """
    options = tf.data.Options()
    if not ordered:
        # Trade a reproducible element order for throughput.
        options.experimental_deterministic = False

    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    records = records.with_options(options)
    return records.map(parse_fn, num_parallel_calls=AUTO)

def get_training_dataset():
    """Build the repeating, shuffled, batched training pipeline.

    NOTE(review): this relies on ``TRAINING_FILENAMES`` and ``onehot``, which
    are not defined in the visible part of this notebook; confirm they exist
    before calling.

    Returns:
        A ``tf.data.Dataset`` of augmented batches that repeats indefinitely.
    """
    return (
        load_dataset(TRAINING_FILENAMES)
        .map(onehot, num_parallel_calls=AUTO)
        .map(data_augment, num_parallel_calls=AUTO)
        .repeat()  # training data must be available for several epochs
        .shuffle(2048)
        .batch(CFG.BATCH_SIZE)
        .prefetch(AUTO)  # prepare the next batch while the current one is consumed
    )

def get_test_dataset(ordered=False, augmented=False):
    """Build the batched test pipeline yielding ``(images, ids)``.

    Args:
        ordered: Forwarded to ``load_dataset``; when true the element order is
            deterministic.
        augmented: Accepted but currently has no effect on the pipeline.

    Returns:
        A ``tf.data.Dataset`` of batches of decoded images and their int64 ids.
    """
    return (
        load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
        .map(get_idx, num_parallel_calls=AUTO)
        .batch(CFG.BATCH_SIZE)
        .prefetch(AUTO)  # prepare the next batch while the current one is consumed
    )

# The item count is embedded in the file name as "-<count>.", e.g. "flowers00-230.tfrec".
_ITEM_COUNT_PATTERN = re.compile(r"-([0-9]+)\.")


def count_data_items(filenames):
    """Sum the item counts encoded in TFRecord file names.

    The number of data items is written in the name of each .tfrec file,
    i.e. ``flowers00-230.tfrec`` holds 230 data items.

    Args:
        filenames: Iterable of file paths (strings).

    Returns:
        The total number of items as a Python ``int`` (0 for no files).

    Raises:
        ValueError: If a file name does not contain a ``-<count>.`` part.
    """
    total = 0
    for filename in filenames:
        # Only inspect the file name itself so that a directory such as
        # "data-1.5/" cannot be mistaken for the item count.
        base_name = filename.rsplit('/', 1)[-1]
        match = _ITEM_COUNT_PATTERN.search(base_name)
        if match is None:
            raise ValueError(
                'Cannot read the item count from file name: {!r}'.format(filename))
        total += int(match.group(1))
    return total

# Total number of test images, taken from the counts embedded in the TFRecord file names.
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
print('Dataset: {} unlabeled test images'.format(NUM_TEST_IMAGES))

In [None]:
print('Calculating predictions...')
# Iterate over images together with their ids. TFRecordDataset is created with
# num_parallel_reads, which interleaves records from different files, so the
# position of an element in the stream does not identify the image on its own.
test_ds = get_test_dataset(ordered=True)

# Ceiling division: the exact number of batches, also when NUM_TEST_IMAGES is
# a multiple of the batch size.
n_batches = -(-NUM_TEST_IMAGES // CFG.BATCH_SIZE)

predictions = np.zeros(NUM_TEST_IMAGES, dtype=np.int32)
image_ids = np.zeros(NUM_TEST_IMAGES, dtype=np.int64)
start = 0
for images, idnums in tqdm(test_ds, total=n_batches):
    batch_predictions = np.argmax(model.predict_on_batch(images), axis=-1)
    stop = start + len(batch_predictions)
    predictions[start:stop] = batch_predictions
    image_ids[start:stop] = idnums.numpy()
    start = stop

# The expected count comes from the file names; make sure the data agrees.
assert start == NUM_TEST_IMAGES, (
    'Expected {} test images but the dataset yielded {}'.format(NUM_TEST_IMAGES, start))

print('Generating submission file...')
sub = pd.read_csv('../input/herbarium-2021-fgvc8/sample_submission.csv')
# NOTE(review): assumes the first column of the sample submission holds the
# numeric image ids produced by get_idx -- confirm against the CSV.
id_column = sub.columns[0]
prediction_by_id = dict(zip(image_ids.tolist(), predictions.tolist()))
aligned = sub[id_column].map(prediction_by_id)
if aligned.notna().all():
    # Every submission row found its prediction by id.
    sub['Predicted'] = aligned.astype(np.int32)
else:
    # Ids could not be matched; keep the previous behaviour (row order) but say so.
    print('WARNING: could not match predictions to submission ids; '
          'falling back to positional order.')
    sub['Predicted'] = predictions
sub.to_csv('submission.csv', index=False)
print(sub.head(10))