# 各模型的预测提交Kernel

* [训练Kernel](https://www.kaggle.com/tianyu5/tpus-cassava-leaf-disease)

In [None]:
# Install these packages from the attached input datasets first; otherwise EfficientNet will fail to load.
!pip install --quiet /kaggle/input/kerasapplications
!pip install --quiet /kaggle/input/efficientnet-git

In [None]:
import math, re, os, random, warnings
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
from tensorflow import keras
from functools import partial
from sklearn.model_selection import train_test_split
import efficientnet.tfkeras as efn

# Report the TensorFlow version in use.
print("Tensorflow version " + tf.__version__)

## Set up variables

In [None]:
# Currently active tf.distribute strategy.
strategy = tf.distribute.get_strategy()
# Lets tf.data choose parallelism / prefetch buffer sizes on its own.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Root directory of the competition data (local input folder).
GCS_PATH = "../input/cassava-leaf-disease-classification"
# Height and width that decoded images are reshaped to.
IMAGE_SIZE = [512, 512]
RESIZE_IMAGE_SIZE = [512, 512]  # image size after augmentation/resizing: 512 on TPU, 300 on GPU (larger sizes run out of memory)
# Class labels as strings.
CLASSES = ['0', '1', '2', '3', '4']
# File handed to tf.keras.models.load_model in the model-loading cell.
WEIGHTS_PATH = "../input/cassava-leaf-disease-resnet-weights/EfficientNetB4-best-08-0.8890.h5"
# Number of images per batch in the test dataset.
BATCH_SIZE = 16

In [None]:
# seed everything
def seed_everything(seed=0):
    """Seed every random number source used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and
    TensorFlow, then exports ``PYTHONHASHSEED`` (set to ``seed``) and
    ``TF_DETERMINISTIC_OPS`` (set to ``'1'``) as environment variables.

    Args:
        seed: Integer seed applied to all generators. Defaults to 0.
    """
    for apply_seed in (random.seed, np.random.seed, tf.random.set_seed):
        apply_seed(seed)
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_DETERMINISTIC_OPS': '1',
    })

# Fix the seed once for the whole notebook.
seed = 0
seed_everything(seed)
# Silence all Python warnings from here on.
warnings.filterwarnings('ignore')

In [None]:
def decode_image(image):
    """Decode JPEG bytes into a float32 image tensor.

    Args:
        image: Scalar string tensor holding the encoded JPEG data.

    Returns:
        A float32 tensor reshaped to ``[*IMAGE_SIZE, 3]``.
    """
    # The cast to float32 is not needed by the model-zoo training script.
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    return tf.reshape(pixels, [*IMAGE_SIZE, 3])

In [None]:
# Collect every test TFRecord file under the competition data directory.
TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test_tfrecords/*.tfrec')

print(TEST_FILENAMES)

测试集的处理方式与验证集相同：使用 `data_val_augment` 进行预处理。

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: Forwarded to ``read_tfrecord``; selects which fields are
            parsed from each record.
        ordered: When False, deterministic ordering is switched off so
            records are used as soon as they stream in (faster). When True,
            the dataset options are left at their defaults.

    Returns:
        A ``tf.data.Dataset`` whose elements are whatever ``read_tfrecord``
        returns for each record.
    """
    options = tf.data.Options()
    if not ordered:
        # Trade ordering for speed.
        options.experimental_deterministic = False

    parse_record = partial(read_tfrecord, labeled=labeled)
    return (
        # Reads from multiple files are interleaved automatically.
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
        .with_options(options)
        .map(parse_record, num_parallel_calls=AUTOTUNE)
    )


def data_val_augment(image, label):
    """Preprocess one validation/test image.

    Args:
        image: Image tensor that can be reshaped to ``[*IMAGE_SIZE, 3]``.
        label: Second element of the dataset item; passed through untouched.

    Returns:
        Tuple ``(image, label)`` where the image has been reshaped and, if
        ``RESIZE_IMAGE_SIZE`` differs from ``IMAGE_SIZE``, resized.
    """
    # The preprocessing layer was removed from the model, so the reshape is done here.
    reshaped = tf.reshape(image, [*IMAGE_SIZE, 3])
    if IMAGE_SIZE != RESIZE_IMAGE_SIZE:
        reshaped = tf.image.resize(reshaped, RESIZE_IMAGE_SIZE)
    return reshaped, label

def get_test_dataset(ordered=False):
    """Build the batched, prefetched test dataset.

    Args:
        ordered: Forwarded to ``load_dataset``; pass True to keep
            deterministic ordering enabled.

    Returns:
        A ``tf.data.Dataset`` built from ``TEST_FILENAMES`` (loaded with
        ``labeled=False``), preprocessed with ``data_val_augment`` and
        grouped into batches of ``BATCH_SIZE``.
    """
    unlabeled_records = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    return (
        unlabeled_records
        .map(data_val_augment, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )

In [None]:
def count_data_items(filenames):
    """Sum the record counts encoded in TFRecord file names.

    Each file's base name must contain ``-<digits>.``, for example
    ``ld_test00-1.tfrec`` encodes a count of 1.

    Args:
        filenames: Iterable of file paths.

    Returns:
        The total count as a NumPy int64 (0 for an empty input).

    Raises:
        ValueError: If a file name carries no ``-<digits>.`` count.
    """
    # Require at least one digit so an empty capture can never reach int().
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        # Only inspect the base name: a "-<digits>." fragment in a directory
        # name must not be mistaken for the record count.
        match = count_pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(
                "Cannot read the record count from file name: %r" % (filename,)
            )
        counts.append(int(match.group(1)))
    # An explicit integer dtype keeps the result integral for an empty list.
    return np.sum(counts, dtype=np.int64)

In [None]:
def read_tfrecord(example, labeled):
    """Parse one serialized TFRecord example.

    Args:
        example: Serialized example to parse.
        labeled: If True the record is parsed with an int64 ``target``
            field; otherwise with a string ``image_name`` field.

    Returns:
        ``(image, label)`` with the label cast to int32 when ``labeled`` is
        True, else ``(image, image_name)``. The image is produced by
        ``decode_image``.
    """
    feature_spec = {"image": tf.io.FixedLenFeature([], tf.string)}
    if labeled:
        feature_spec["target"] = tf.io.FixedLenFeature([], tf.int64)
    else:
        feature_spec["image_name"] = tf.io.FixedLenFeature([], tf.string)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])

    if not labeled:
        return image, parsed['image_name']
    return image, tf.cast(parsed['target'], tf.int32)

In [None]:
# def to_float32(image, label):
#     return tf.cast(image, tf.float32), label

In [None]:
# Load the model
# WEIGHTS_PATH is passed to load_model, so the .h5 file is read as a complete saved model.
trained_model = tf.keras.models.load_model(WEIGHTS_PATH)
trained_model.summary()

## 进行预测

In [None]:
# ordered=True leaves deterministic ordering enabled; the ids extracted from
# test_ds in the submission cell are paired with these predictions by position.
test_ds = get_test_dataset(ordered=True) 

print('Computing predictions...')
# Keep only the image element of each (image, image_name) pair.
test_images_ds = test_ds.map(lambda image, idnum: image)
probabilities = trained_model.predict(test_images_ds)
# Index of the largest value along the last axis becomes the predicted label.
predictions = np.argmax(probabilities, axis=-1)
print(predictions)

In [None]:
print('Generating submission.csv file...')
# Total number of test records, as encoded in the TFRecord file names.
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
# Second pass over test_ds, this time keeping only the image names.
# NOTE(review): this assumes the pass yields records in the same order as the
# prediction pass above -- confirm.
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
# Write one "<image_id>,<label>" row per image under an "image_id,label" header.
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='image_id,label', comments='')
!head submission.csv