# 利用tensorflow2.0卷积神经网络进行卫星图片分类实例操作详解
来源：https://blog.csdn.net/lys_828/article/details/101322246

In [1]:
import tensorflow as tf
from tensorflow import keras
print('Tensorflow version: {}, GPU is {}'.format(tf.__version__, tf.test.is_gpu_available()))
if tf.test.is_gpu_available():
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    #下面的方式是设置Tensorflow固定消耗GPU:0的2GB显存（该设置对代码缺陷（GPU需求内存超物理GPU内存时异常退出）不起作用）。
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)]
    )

Tensorflow version: 2.0.0, GPU is False


In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

In [3]:
import pathlib
import random
import IPython.display as display
import datetime
import json
import pprint

class MyCnn:
    def __init__(self, width = 256 , height = 256):
        #设定图片预处理的尺寸(width,height)
        super(MyCnn, self).__init__()
        self.time_start = datetime.datetime.now()
        self.time_keep = 0
        self.IMAGE_WIDTH = width
        self.IMAGE_HEIGHT = height
        self.model = tf.keras.Sequential()   #顺序模型
        if width == 256 and height == 256:
            #正常模式（有充足的GPU内存或CPU模式）
            self.model.add(tf.keras.layers.Conv2D(64, (3, 3), input_shape=(width, height, 3), activation='relu'))
            self.model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.MaxPooling2D())
            self.model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.MaxPooling2D())
            self.model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.MaxPooling2D())
            self.model.add(tf.keras.layers.Conv2D(512, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.MaxPooling2D())
            self.model.add(tf.keras.layers.Conv2D(512, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.MaxPooling2D())
            self.model.add(tf.keras.layers.Conv2D(1024, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.GlobalAveragePooling2D())
            self.model.add(tf.keras.layers.Dense(1024, activation='relu'))
            self.model.add(tf.keras.layers.Dense(256, activation='relu'))
            self.model.add(tf.keras.layers.Dense(10, activation='softmax'))
        else:
            #适应小内存，input_shape=(128,128,3)
            self.model.add(tf.keras.layers.Conv2D(64, (3, 3), input_shape=(width, height, 3), activation='relu'))
            #self.model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.MaxPooling2D())
            self.model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
            #self.model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.MaxPooling2D())
            self.model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu'))
            #self.model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.MaxPooling2D())
            self.model.add(tf.keras.layers.Conv2D(512, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.MaxPooling2D())
            #self.model.add(tf.keras.layers.Conv2D(512, (3, 3), activation='relu'))
            #self.model.add(tf.keras.layers.MaxPooling2D())
            self.model.add(tf.keras.layers.Conv2D(1024, (3, 3), activation='relu'))
            self.model.add(tf.keras.layers.GlobalAveragePooling2D())
            self.model.add(tf.keras.layers.Dense(1024, activation='relu'))
            self.model.add(tf.keras.layers.Dense(256, activation='relu'))
            self.model.add(tf.keras.layers.Dense(10, activation='softmax'))
    def loaddata(self, data_dir, batch_size = 32, num_parallel_calls=tf.data.experimental.AUTOTUNE, test_rate = 0.2):
        self.all_image_paths, self.all_image_labels, self.image_count, self.label_names = self.loadfiles(data_dir)
        self.train_data , self.train_count , self.test_data , self.test_count = self.datasplit(self.all_image_paths , batch_size = batch_size , num_parallel_calls = num_parallel_calls , test_rate = test_rate)
        self.steps_per_epoch = self.train_count//batch_size
        self.validation_steps = self.test_count//batch_size
    def compile(self):
        self.model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['acc']
                     )
    def fit(self , epochs = 30):
        self.history = self.model.fit(self.train_data, epochs=epochs, steps_per_epoch=self.steps_per_epoch, validation_data=self.test_data, validation_steps=self.validation_steps)
        self.time_stop = datetime.datetime.now()
        self.time_keep = self.time_stop - self.time_start
    def loadmodel(self , path):
        #模型的加载
        try:
            model_exist = tf.keras.models.load_model(path)
        except:
            return False;
        self.model = model_exist
        return True;
    def savemodel(self , path):
        #模型的保存
        self.model.save(path)
        #保存tensorflow格式的参数值
        #model.save_weights(SAVE_PATH + 'model')
        #保存keras格式的参数值（权重）
        #model.save_weights(SAVE_PATH + 'model',save_format='HDF5')
    def from_json(self, json_str):
        self.model = tf.keras.models.model_from_json(json_str) 
    def to_json(self):
        # 序列化成json
        json_str = self.model.to_json()
        pprint.pprint(json.loads(json_str))
    def showhistory(self):
        try:
            r=object.__getattribute__(self, 'history')
        except:
            r=None

        if r == None:
            print('no history')
        else:
            print(self.history.history.keys())
            #显示正确率
            plt.plot(self.history.epoch, self.history.history.get('acc'), label='acc')
            plt.plot(self.history.epoch, self.history.history.get('val_acc'), label='val_acc')
            plt.plot(self.history.epoch, np.array(self.history.history.get('acc')) * np.array(self.history.history.get('val_acc')), label='cc')
            plt.legend()
            plt.show()

            #显示错误率
            plt.plot(self.history.epoch, self.history.history.get('loss'), label='loss')
            plt.plot(self.history.epoch, self.history.history.get('val_loss'), label='val_loss')
            plt.plot(self.history.epoch, np.array(self.history.history.get('loss')) * np.array(self.history.history.get('val_loss')), label='ss')
            plt.legend()
            plt.show()
    def imagepred(self, image_path , label , withshow = True):
        #模型预测
        image = self.load_and_preprocess_image(image_path)
        if withshow:
            plt.imshow(image)
            plt.grid(False)
            plt.xlabel(self.caption_image(label))
            plt.show()

        image1 = tf.io.read_file(image_path)
        image1 = tf.image.decode_jpeg(image1, channels=3)
        image1 = tf.image.resize(image1, [self.IMAGE_WIDTH, self.IMAGE_HEIGHT])
        image1 = tf.cast(image1, tf.float32)
        image1 = image1/255.0  # normalize to [0,1] range

        image2 = np.array([image1.numpy()])

        result = self.model.predict(image2)

        result_value = tf.argmax(result,1).numpy()
        if label == -1:
            print('pred label={},{}({})'.format(result_value[0] , self.caption_image(result_value[0]) , image_path))
        elif result_value != label:
            print('pred label=' , result_value[0] , self.caption_image(result_value[0]), '(error)') 
        else:
            print('pred label=' , result_value[0] , self.caption_image(result_value[0])) 
    def imagepredi(self, index, withshow = True):
        #获取要预测的图片路径及标签，并调用imagepred进行预测校验
        image_path = self.all_image_paths[index]
        label = self.all_image_labels[index]
        return self.imagepred(image_path , label , withshow = withshow)
    def loadfiles(self,data_dir , randomflag = True):
        #根据相对目录路径，读取文件清单，标签清单，文件数（文件被random打乱）
        #一级目录为airplan、lake，分别表示机场、湖泊，二级目录为对应的图片文件
        data_root = pathlib.Path(data_dir)

        #print(data_root)
        #for item in data_root.iterdir():
        #    print(item)

        all_image_paths = list(data_root.glob('*/*'))
        image_count = len(all_image_paths)

        #print(all_image_paths[:3],all_image_paths[-3:])

        #打乱顺序
        all_image_paths = [str(path) for path in all_image_paths]
        if randomflag:
            random.shuffle(all_image_paths)
            #print(all_image_paths[:5])

        label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
        #print(label_names)
        #对排序后的label按目录名进行排序
        label_to_index = dict((name, index) for index,name in enumerate(label_names))
        #print(label_to_index)

        #根据文件的目录名生成label结果集
        all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
        #print(all_image_labels[:5])

        return all_image_paths, all_image_labels, image_count, label_names
    def checkfile(self):
        #随机抽查几张图片及其标签值
        for n in range(3):
            image_index = random.choice(range(len(self.all_image_paths)))
            display.display(display.Image(self.all_image_paths[image_index]))
            print(image_index, self.caption_image(self.all_image_labels[image_index]))
            #print()
        #加载和格式化图像

        img_path = self.all_image_paths[0]
        print(img_path)

        img_raw = tf.io.read_file(img_path)
        print(repr(img_raw)[:100]+"...")

        img_tensor = tf.image.decode_image(img_raw)
        print(img_tensor.shape)
        print(img_tensor.dtype)

        img_tensor = tf.cast(img_tensor, tf.float32)
        img_final = img_tensor/255.0
        print(img_final.shape)
        print(img_final.numpy().min())
        print(img_final.numpy().max())
    def loadcheckfiles(self,data_dir):
        data_root = pathlib.Path(data_dir)

        all_image_paths = list(data_root.glob('*'))
        all_image_paths = [str(path) for path in all_image_paths]
        return all_image_paths
    def caption_image(self,label):
        #根据标签值获取标签对应的名称
        return {0: 'airplane', 1: 'lake'}.get(label)
    def load_and_preprocess_image(self,path):
        #加载并预处理图像，使之值介于[0,1]范围内
        image = tf.io.read_file(path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [self.IMAGE_WIDTH, self.IMAGE_HEIGHT])
        image = tf.cast(image, tf.float32)
        image = image/255.0  # normalize to [0,1] range
        return image
    def imageshow(self,imageindex):
        #按文件索引进行显示图片
        image_path = self.all_image_paths[imageindex]
        label = self.all_image_labels[imageindex]

        plt.imshow(self.load_and_preprocess_image(image_path))
        plt.grid(False)
        plt.xlabel(self.caption_image(label))
        print()
    def datasplit(self, data , batch_size = 32, num_parallel_calls=tf.data.experimental.AUTOTUNE, test_rate = 0.2):
        path_ds = tf.data.Dataset.from_tensor_slices(data)
        image_ds = path_ds.map(self.load_and_preprocess_image, num_parallel_calls=num_parallel_calls)
        label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(self.all_image_labels, tf.int64))
        image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
        #for label in label_ds.take(10):
        #    print(label_names[label.numpy()])
        #print(image_label_ds)
        image_count = len(self.all_image_paths)
        test_count = int(image_count*test_rate)
        train_count = image_count - test_count
        train_data = image_label_ds.skip(test_count)
        test_data = image_label_ds.take(test_count)

        train_data = train_data.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=train_count))
        #可以改成以下写法（待定）
        #train_data = train_data.shuffle(buffer_size=train_count)
        #train_data = train_data.repeat(10)#epochs
        #train_data = train_data.apply(train_data)
        
        train_data = train_data.batch(batch_size)
        train_data = train_data.prefetch(buffer_size=num_parallel_calls)
        #print(train_data)

        test_data = test_data.batch(batch_size)

        return train_data , train_count , test_data , test_count
    def summary(self):
        self.model.summary()

In [4]:
data_dir = './2_class'
MODEL_PATH = './airplan-lake.model/'

In [5]:
IMAGE_WIDTH = 128
IMAGE_HEIGHT = 128

In [6]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 32

In [7]:
cnn = MyCnn(IMAGE_WIDTH , IMAGE_HEIGHT)
cnn.loaddata(data_dir , batch_size = BATCH_SIZE , num_parallel_calls = AUTOTUNE , test_rate = 0.2)
cnn.compile()

Instructions for updating:
Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by `tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take care of using the fused implementation.


In [8]:
cnn.loadmodel(MODEL_PATH)

True

In [10]:
cnn.fit(epochs = 10)

Train for 35 steps, validate for 8 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [11]:
cnn.savemodel(MODEL_PATH)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: ./airplan-lake.model/assets


In [9]:
#模型评估（直接采用验证集来评估，结果应该跟val_acc一致）
steps = cnn.validation_steps
cnn.model.evaluate(cnn.test_data, steps= steps) 



[0.041400633519515395, 0.9921875]

In [10]:
cnn.showhistory()

no history


In [11]:
for i in range(20):
    cnn.imagepredi(i, withshow = False)

pred label= 0 airplane
pred label= 0 airplane
pred label= 0 airplane
pred label= 1 lake
pred label= 1 lake
pred label= 1 lake
pred label= 1 lake
pred label= 1 lake
pred label= 1 lake
pred label= 0 airplane
pred label= 1 lake
pred label= 0 airplane
pred label= 1 lake
pred label= 1 lake
pred label= 0 airplane
pred label= 0 airplane
pred label= 0 airplane
pred label= 1 lake
pred label= 0 airplane
pred label= 1 lake


In [12]:
CHECK_PATH = './4_check'
check_image_paths = cnn.loadcheckfiles(CHECK_PATH)
print('check image count = ', len(check_image_paths))
for f in check_image_paths:
    cnn.imagepred(f , -1, withshow = False)

check image count =  20
pred label=0,airplane(4_check\机场21.jpg)
pred label=0,airplane(4_check\机场22.jpg)
pred label=1,lake(4_check\机场23.jpg)
pred label=0,airplane(4_check\机场24.jpg)
pred label=0,airplane(4_check\机场25.jpg)
pred label=0,airplane(4_check\机场26.jpg)
pred label=0,airplane(4_check\机场27.jpg)
pred label=0,airplane(4_check\机场28.jpg)
pred label=0,airplane(4_check\机场29.jpg)
pred label=1,lake(4_check\机场30.jpg)
pred label=1,lake(4_check\湖泊21.jpg)
pred label=0,airplane(4_check\湖泊22.jpg)
pred label=1,lake(4_check\湖泊23.jpg)
pred label=1,lake(4_check\湖泊24.jpg)
pred label=1,lake(4_check\湖泊25.jpg)
pred label=1,lake(4_check\湖泊26.jpg)
pred label=1,lake(4_check\湖泊27.jpg)
pred label=1,lake(4_check\湖泊28.jpg)
pred label=1,lake(4_check\湖泊29.jpg)
pred label=1,lake(4_check\湖泊30.jpg)
