<a href="https://colab.research.google.com/github/sunshine2285/tensorflow_cnn/blob/master/tensorflow_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [62]:
from PIL import Image
import numpy as np
import tensorflow.compat.v1 as tf
# import tensorflow as tf
import os

# Switch to TensorFlow 1.x graph-mode semantics so the placeholders,
# tf.layers and tf.Session used below work through the compat.v1 API.
tf.disable_v2_behavior()

# The next two settings select what this cell does.
# Mode switch: True -> train a new model and save it;
#              False -> restore the saved model and evaluate it.
train = False
# CIFAR-10 label ids that take part in classification.
label_classify_list = [0, 1, 2, 3, 4]

# Image folders: training images sit in one sub-folder per label id,
# test images sit directly inside the "test/" sub-folder.
train_data_dir = "/content/drive/My Drive/dataset/cifar-10-jpg/"
test_data_dir = train_data_dir + "test/"

# Checkpoint path prefix handed to tf.train.Saver.
model_path = "/content/drive/My Drive/dataset/model/cifar-10-cnn/"

# Label id -> human-readable CIFAR-10 class name (the id is the position).
label_name_dict = dict(enumerate((
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)))

# Number of output classes: one per selected label id.
class_count = len(label_classify_list)

# Read images and labels from a folder into numpy arrays.
# The label is encoded in the file-name prefix: 0_99.jpg has label 0.
def read_data(data_dir):
    """Load the images under *data_dir* together with their label ids.

    ``data_dir`` selects the folder layout:

    * ``train_data_dir``: images are read from the per-label sub-folders
      ``<data_dir>/<id>/`` for every id in ``label_classify_list``;
    * ``test_data_dir``: every file directly inside ``data_dir`` is read,
      skipping files whose label is not in ``label_classify_list``;
    * any other value: nothing is read and empty arrays are returned.

    Pixel values are divided by 255.0. The label of each image is the integer
    before the first ``_`` in its file name.

    Returns:
        ``(fpaths, datas, labels)``: the list of image paths, a numpy array of
        the scaled images and a numpy array of the integer label ids, all in
        the same order.
    """
    datas = []
    labels = []
    fpaths = []

    def _append(fpath, label):
        # Open the image in a context manager so the file handle is closed
        # right after np.array() has copied the pixels.
        with Image.open(fpath) as image:
            datas.append(np.array(image) / 255.0)
        labels.append(label)
        fpaths.append(fpath)

    if data_dir == train_data_dir:
        for i in label_classify_list:
            # Images of label i live in their own sub-folder <data_dir>/<i>/;
            # os.path.join keeps the separator correct on every platform.
            class_dir = os.path.join(data_dir, str(i))
            for fname in os.listdir(class_dir):
                _append(os.path.join(class_dir, fname), int(fname.split("_")[0]))
    elif data_dir == test_data_dir:
        for fname in os.listdir(data_dir):
            label = int(fname.split("_")[0])
            if label not in label_classify_list:
                continue
            _append(os.path.join(data_dir, fname), label)

    datas = np.array(datas)
    labels = np.array(labels)
    print("shape of datas: {}\tshape of labels: {}".format(datas.shape, labels.shape))
    return fpaths, datas, labels


# Pick the image folder for the selected mode: the training folder when
# training, the test folder otherwise.
data_dir = train_data_dir if train else test_data_dir

# Label ids double as class indices (training targets and argmax
# predictions), so the selected ids must be exactly 0..class_count-1.
if sorted(label_classify_list) != list(range(class_count)):
    raise ValueError(
        "label_classify_list must contain exactly the ids 0..{}, got {}".format(
            class_count - 1, label_classify_list))

fpaths, datas, labels = read_data(data_dir)

# Number of output classes. It is taken from the configuration rather than
# from the loaded labels, so the training and test graphs always have the same
# shape even if some class is absent from the loaded data.
num_classes = class_count

# Placeholders for the input images and their integer label ids.
# The graph is rebuilt from scratch so re-running this cell does not add
# duplicate ops to the default graph.
tf.reset_default_graph()
datas_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 3])
labels_placeholder = tf.placeholder(tf.int32, [None])

# Dropout rate: fed 0.25 during training and 0 during testing.
dropout_placeholder = tf.placeholder(tf.float32)

# Convolution layer: 20 kernels of size 5, ReLU activation.
conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)
# Max-pooling with a 2x2 window and a 2x2 stride.
pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])

# Convolution layer: 40 kernels of size 4, ReLU activation.
conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)
# Max-pooling with a 2x2 window and a 2x2 stride.
pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])

# Flatten the feature maps into one vector per example.
flatten = tf.layers.flatten(pool1)

# Fully connected layer producing a 400-long feature vector.
fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)

# Dropout against overfitting. tf.layers.dropout only drops units when
# training=True (its default is False, which would disable dropout
# entirely), so keep it enabled and control it through the fed rate: a rate
# of 0 keeps every unit and scales by 1, i.e. leaves the activations unchanged.
dropout_fc = tf.layers.dropout(fc, rate=dropout_placeholder, training=True)

# Output layer: raw (un-activated) logits, one per class.
logits = tf.layers.dense(dropout_fc, num_classes)

# Predicted label id for every example.
predicted_labels = tf.argmax(logits, axis=1)

# Per-example cross-entropy. The sparse variant takes the integer label ids
# directly; for ids in [0, num_classes) it equals the one-hot formulation.
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels_placeholder,
    logits=logits
)

# Mean loss over the batch.
mean_loss = tf.reduce_mean(losses)

# Adam optimiser minimising the scalar mean loss (not the per-example vector).
optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(mean_loss)

# Used to save and restore the model variables.
saver = tf.train.Saver()

with tf.Session() as sess:
    if train:
        print("训练模式")
        # Training: start from freshly initialised variables.
        sess.run(tf.global_variables_initializer())
        # The whole loaded dataset is fed as one batch at every step, with a
        # dropout rate of 0.25.
        train_feed_dict = {
            datas_placeholder: datas,
            labels_placeholder: labels,
            dropout_placeholder: 0.25
        }
        for step in range(500):
            _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
            if step % 5 == 0:
                print("step = {}\tmean_loss = {}".format(step, mean_loss_val))
        # NOTE(review): model_path ends with "/", so the checkpoint files are
        # written inside that folder with names that start at the extension
        # (".index", ".meta", ...). Changing the prefix would orphan existing
        # checkpoints, so it is left unchanged here.
        saver.save(sess, model_path)
        print("训练结束，保存模型到{}".format(model_path))
    else:
        print("测试模式")
        # Testing: restore the trained variables from the checkpoint.
        saver.restore(sess, model_path)
        print("从{}载入模型".format(model_path))

        # Labels are not needed for inference; a dropout rate of 0 leaves the
        # activations unchanged.
        test_feed_dict = {
            datas_placeholder: datas,
            dropout_placeholder: 0
        }
        predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
        tested_img = 0
        correct_img = 0
        uncorrect_img = 0
        # Compare the true label with the predicted label for every image.
        for fpath, real_label, predicted_label in zip(fpaths, labels, predicted_labels_val):
            # Translate label ids into human-readable names.
            real_label_name = label_name_dict[real_label]
            predicted_label_name = label_name_dict[predicted_label]
            tested_img += 1
            if real_label_name == predicted_label_name:
                correct_img += 1
            else:
                uncorrect_img += 1
            print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))
        # Guard against ZeroDivisionError when no test image matched the
        # selected labels.
        accuracy = correct_img / tested_img if tested_img else "n/a"
        print("tested img count: {}, correct img: {}, uncorrect img: {}, accuracy: {}".format(
            tested_img, correct_img, uncorrect_img, accuracy
        ))

shape of datas: (5000, 32, 32, 3)	shape of labels: (5000,)
测试模式
INFO:tensorflow:Restoring parameters from /content/drive/My Drive/dataset/model/cifar-10-cnn/
从/content/drive/My Drive/dataset/model/cifar-10-cnn/载入模型
/content/drive/My Drive/dataset/cifar-10-jpg/test/3_0.jpg	cat => cat
/content/drive/My Drive/dataset/cifar-10-jpg/test/0_3.jpg	airplane => airplane
/content/drive/My Drive/dataset/cifar-10-jpg/test/1_6.jpg	automobile => automobile
/content/drive/My Drive/dataset/cifar-10-jpg/test/3_8.jpg	cat => cat
/content/drive/My Drive/dataset/cifar-10-jpg/test/1_9.jpg	automobile => automobile
/content/drive/My Drive/dataset/cifar-10-jpg/test/0_10.jpg	airplane => airplane
/content/drive/My Drive/dataset/cifar-10-jpg/test/0_21.jpg	airplane => airplane
/content/drive/My Drive/dataset/cifar-10-jpg/test/4_22.jpg	deer => deer
/content/drive/My Drive/dataset/cifar-10-jpg/test/2_25.jpg	bird => bird
/content/drive/My Drive/dataset/cifar-10-jpg/test/4_26.jpg	deer => deer
/content/drive/My Drive/da

In [0]:
# Show the current directory and list its contents, including dot-files.
!pwd
!ls -la
# NOTE(review): the shell expands `.*` to every dot-entry in the current
# directory (possibly including `.` and `..`), and mv treats the last match as
# the destination. That is why this fails with "target '.meta' is not a
# directory". Name the source files and the target directory explicitly
# instead.
!mv .*

/content/drive/My Drive/dataset/model
total 9904
drwx------ 3 root root    4096 Feb 11 11:39 cifar-10-cnn
-rw------- 1 root root 9983128 Feb 11 09:58 .data-00000-of-00001
-rw------- 1 root root    1787 Feb 11 09:58 .index
drwx------ 2 root root    4096 Feb 11 07:55 .ipynb_checkpoints
-rw------- 1 root root  147886 Feb 11 09:58 .meta
mv: target '.meta' is not a directory


In [0]:
import tensorflow as tf
# Report which TensorFlow release the runtime provides.
print("{}".format(tf.__version__))

1.15.0


In [0]:
# 功能：将cifar10的数据集分类生成图片

import matplotlib.pyplot as plt
import numpy as np
import os

def unpickle(file):
    """Load and return the object stored in the pickle file *file*.

    ``encoding='bytes'`` decodes Python 2 ``str`` objects as ``bytes``, which
    is why callers index the result with byte keys such as ``b'data'``.

    Security note: unpickling can execute arbitrary code, so only use this on
    trusted files such as the official CIFAR-10 archive.
    """
    import pickle

    with open(file, 'rb') as stream:
        return pickle.load(stream, encoding='bytes')


# Folder expected to hold the CIFAR-10 python batches, and the output folder.
dataset_path = "cifar-10-batches-py/"
img_save_path = "cifar-10-jpg/"
# Index 0: metadata file, 1-5: training batches, 6: test batch.
file_name_list = (['batches.meta']
                  + ['data_batch_{}'.format(n) for n in range(1, 6)]
                  + ['test_batch'])
# Toggles for which split gets exported as images.
generate_train_img = True
generate_test_img = True

# Get the correspondence between label ids and label names.
def get_lables():
    """Return the ``b'label_names'`` entry of the ``batches.meta`` pickle.

    Presumably these are the class names ordered by label id; confirm against
    the dataset metadata.
    """
    meta = unpickle(dataset_path + file_name_list[0])
    return meta[b'label_names']


def _batch_images(file_data):
    """Return all images of an unpickled CIFAR-10 batch as HxWxC arrays."""
    # Each row stores the three colour planes one after another (3x32x32);
    # reshape the whole batch at once and move the channel axis last.
    return np.reshape(file_data[b'data'], (-1, 3, 32, 32)).transpose((0, 2, 3, 1))


def _save_image(img_data, save_dir, img_name):
    """Write one image to ``save_dir + img_name`` and report it."""
    plt.imsave(save_dir + img_name, img_data)
    print('\t', img_name, "saved ……")


if __name__ == "__main__":
    print(get_lables())
    if generate_train_img:
        # Running per-label counter used to number the training images.
        lable_count = np.zeros(10, dtype="int")
        for i in range(1, 6):
            file_data = unpickle(dataset_path + file_name_list[i])
            batch_labels = file_data[b'labels']
            # Create each label's output folder once per batch, not per image.
            for lable_id in set(batch_labels):
                os.makedirs(img_save_path + str(lable_id), exist_ok=True)
            # The image count comes from the batch itself instead of a
            # hard-coded 10000.
            for img_data, lable_id in zip(_batch_images(file_data), batch_labels):
                img_name = str(lable_id) + "_" + str(lable_count[lable_id]) + ".jpg"
                _save_image(img_data, img_save_path + str(lable_id) + "/", img_name)
                lable_count[lable_id] += 1
            print("【*】---", i, "st dataset_path saved")
    if generate_test_img:
        file_data = unpickle(dataset_path + file_name_list[6])
        test_dir = img_save_path + "test/"
        os.makedirs(test_dir, exist_ok=True)
        # Test images are numbered by their position k in the test batch.
        images_and_labels = zip(_batch_images(file_data), file_data[b'labels'])
        for k, (img_data, lable_id) in enumerate(images_and_labels):
            _save_image(img_data, test_dir, str(lable_id) + "_" + str(k) + ".jpg")
        print("【*】---", "test dataset_path saved")
    print("*******ALl Task Finished********")


In [0]:
import requests

if __name__ == '__main__':
    print("downloading with requests")
    url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    # Stream the archive to disk in chunks instead of buffering it all in
    # memory. Fail on an HTTP error status instead of saving the error page as
    # the archive, and bound how long a stalled connection can block.
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open("cifar-10-python.tar.gz", "wb") as code:
            for chunk in r.iter_content(chunk_size=1 << 20):
                code.write(chunk)
    print("downloading finished")
