In [None]:
# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# CropNet: Cassava Disease Detection

<table class="tfo-notebook-buttons" align="left">
  <td><a target="_blank" href="https://tensorflow.google.cn/hub/tutorials/cropnet_cassava"><img src="https://tensorflow.google.cn/images/tf_logo_32px.png">在 TensorFlow.org 查看</a></td>
  <td><a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/hub/tutorials/cropnet_cassava.ipynb"><img src="https://tensorflow.google.cn/images/colab_logo_32px.png">在 Google Colab 中运行 </a></td>
  <td>     <a target="_blank" href="https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/hub/tutorials/cropnet_cassava.ipynb"><img src="https://tensorflow.google.cn/images/GitHub-Mark-32px.png">在 GitHub 上查看</a> </td>
  <td><a href="https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/hub/tutorials/cropnet_cassava.ipynb"><img src="https://tensorflow.google.cn/images/download_logo_32px.png">下载笔记本</a></td>
  <td><a href="https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2"><img src="https://tensorflow.google.cn/images/hub_logo_32px.png">查看 TF Hub 模型</a></td>
</table>

此笔记本演示如何使用 **TensorFlow Hub** 中的 CropNet [木薯病虫害分类器](https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2)模型。该模型可将木薯叶的图像分为 6 类：*细菌性枯萎病、褐条病毒病、绿螨、花叶病、健康或未知*。

此 Colab 演示了如何执行以下操作：

- 从 **TensorFlow Hub** 加载 https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 模型
- 从 **TensorFlow Datasets (TFDS)** 加载[木薯](https://tensorflow.google.cn/datasets/catalog/cassava)数据集
- 将木薯叶图像分类为 4 种不同的木薯病虫害类别之一，或分类为健康、未知。
- 评估分类器的*准确率*，并查看模型在应用于域外图像时的*鲁棒性*。

## 导入和设置

In [None]:
!pip install matplotlib==3.2.2

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

In [None]:
#@title Helper function for displaying examples
def plot(examples, predictions=None):
  """Displays a grid of images captioned with their labels (and predictions).

  Args:
    examples: A mapping with an 'image' entry (a sequence of images that
      `imshow` can render) and a 'label' entry (class indices into the
      notebook-level `class_names` list).
    predictions: Optional sequence of predicted class indices, one per image.
      When given, each caption also shows the prediction and is colored green
      when it equals the label and red otherwise.

  Note: relies on the notebook-level globals `class_names` and `name_map`,
  which must be defined before this function is called.
  """
  # Get the images, labels, and optionally predictions
  images = examples['image']
  labels = examples['label']
  batch_size = len(images)
  if batch_size == 0:
    # Nothing to draw; the grid computation below would divide by zero.
    return
  if predictions is None:
    predictions = batch_size * [None]

  # Configure the layout of the grid: a roughly square arrangement with
  # `ncols` columns and just enough rows to hold every image. Newer matplotlib
  # releases reject non-integer grid sizes, so the np.ceil results are cast.
  ncols = int(np.ceil(np.sqrt(batch_size)))
  nrows = int(np.ceil(batch_size / ncols))
  fig = plt.figure(figsize=(ncols * 6, nrows * 7))

  for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
    # Render the image; add_subplot takes (nrows, ncols, 1-based index).
    ax = fig.add_subplot(nrows, ncols, i + 1)
    ax.imshow(image, aspect='auto')
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # Display the label and optionally prediction
    x_label = 'Label: ' + name_map[class_names[label]]
    if prediction is not None:
      x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
      ax.xaxis.label.set_color('green' if label == prediction else 'red')
    ax.set_xlabel(x_label)

  plt.show()

## 数据集

让我们从 TFDS 中加载*木薯*数据集。

In [None]:
# Load every split of the cassava dataset together with its DatasetInfo.
dataset, info = tfds.load(name='cassava', with_info=True)

我们来查看数据集信息以了解更多内容，例如描述、引用以及可用样本数量等信息。

In [None]:
# Display the DatasetInfo returned by tfds.load above (description,
# citation, features and split sizes).
info

*木薯*数据集包含涉及 4 种不同病虫害的木薯叶图像以及健康的木薯叶图像。模型可以预测上述五个类，当模型不确定其预测结果时，会将图像分为第六个类，即“未知”类。

In [None]:
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])

将数据馈送至模型之前，我们需要进行一些预处理。模型接受大小为 224 x 224，且 RGB 通道值范围为 [0, 1] 的图像。让我们归一化图像并调整图像大小。

In [None]:
def preprocess_fn(data, image_size=(224, 224)):
  """Normalizes and resizes the 'image' entry of a dataset example.

  Args:
    data: A dict-like example with an 'image' entry holding pixel values in
      [0, 255]. All other entries are passed through unchanged.
    image_size: (height, width) to resize to. Defaults to the 224 x 224
      input size the classifier expects.

  Returns:
    A new dict with the same entries as `data`, where 'image' is float32,
    scaled to [0, 1] and resized to `image_size`. `data` itself is not
    modified.
  """
  # Normalize [0, 255] to [0, 1]
  image = tf.cast(data['image'], tf.float32) / 255.

  # Resize the images to the requested size (224 x 224 by default)
  image = tf.image.resize(image, image_size)

  # Return a copy rather than mutating the caller's dict.
  return {**data, 'image': image}

我们看一下数据集中的一些样本

In [None]:
# Pull a single batch of 25 preprocessed validation examples as NumPy arrays
# and show them with their ground-truth labels.
batch = (
    dataset['validation']
    .map(preprocess_fn)
    .batch(25)
    .as_numpy_iterator()
)
examples = next(batch)
plot(examples)

## 模型

让我们从 TF-Hub 中加载分类器并获取一些预测结果，然后查看模型对一些样本的预测。

In [None]:
# Load the CropNet cassava classifier from TF Hub as a Keras layer.
classifier = hub.KerasLayer(
    'https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')

# Score the sample batch and take the highest-scoring class as the prediction.
probabilities = classifier(examples['image'])
predictions = tf.math.argmax(probabilities, axis=-1)

In [None]:
# Show the same sample batch, now captioned with the model's predictions.
plot(examples, predictions=predictions)

## 评估和鲁棒性

我们来衡量分类器在数据集某个拆分（split）上的*准确率*。我们还可以通过评估模型在非木薯数据集上的性能来评估其*鲁棒性*。对于 iNaturalist 或豆类（beans）等其他植物数据集中的图像，模型应几乎始终返回*未知*。

In [None]:
#@title Parameters {run: "auto"}

# DATASET: TFDS dataset to evaluate on. For anything other than 'cassava',
# the evaluation cell rewrites every label to 'unknown' before scoring.
# DATASET_SPLIT: which split of DATASET to load.
# MAX_EXAMPLES: upper bound on the number of examples evaluated.
DATASET = 'cassava'  #@param {type:"string"} ['cassava', 'beans', 'i_naturalist2017']
DATASET_SPLIT = 'test' #@param {type:"string"} ['train', 'test', 'validation']
BATCH_SIZE =  32 #@param {type:"integer"}
MAX_EXAMPLES = 1000 #@param {type:"integer"}


In [None]:
def label_to_unknown_fn(data, unknown_label=5):
  """Returns a copy of an example with its label replaced by 'unknown'.

  Used for non-cassava datasets, whose images the classifier is expected to
  assign to the 'unknown' class.

  Args:
    data: A dict-like example with a 'label' entry.
    unknown_label: Class index to assign. Defaults to 5, the position of
      'unknown', which is appended after the five cassava classes in
      `class_names`.

  Returns:
    A new dict with the same entries as `data` and 'label' set to
    `unknown_label`. `data` itself is not modified.
  """
  return {**data, 'label': unknown_label}

In [None]:
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))

## 了解更多

- 详细了解 TensorFlow Hub 上的模型：https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
- 了解如何使用[模型的 TensorFlow Lite 版本](https://tfhub.dev/google/lite-model/cropnet/classifier/cassava_disease_V1/1)通过 [ML Kit](https://developers.google.com/ml-kit/custom-models#tfhub) 构建在手机上运行的自定义图像分类器。