In [1]:
import tensorflow as tf 
from tensorflow.keras import layers, optimizers, datasets, Sequential, losses
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from sklearn.utils import shuffle

def load_dataset(batch_size=128, test_batch_size=64, shuffle_buffer=1000):
    """Download CIFAR-10 and wrap both splits in batched ``tf.data`` pipelines.

    Every element goes through ``preprocess`` (images scaled to float32 in
    [-1, 1], labels cast to int32). Labels are integer class ids, not one-hot.

    Args:
        batch_size: batch size of the training pipeline (previously fixed at 128).
        test_batch_size: batch size of the test pipeline (previously fixed at 64).
        shuffle_buffer: number of elements held in the training shuffle buffer
            (previously fixed at 1000). A buffer much smaller than the training
            set only shuffles locally; raise it for a closer-to-uniform shuffle.

    Returns:
        (train_db, test_db): the training dataset (shuffled, preprocessed,
        batched) and the test dataset (preprocessed, batched, not shuffled).
    """
    # Downloads on first use, then loads from the local Keras cache.
    (X_train, y_train), (X_test, y_test) = datasets.cifar10.load_data()
    # Drop the trailing label axis: [b, 1] => [b].
    y_train = tf.squeeze(y_train, axis=1)
    y_test = tf.squeeze(y_test, axis=1)

    # from_tensor_slices accepts numpy arrays or tensors (numpy is converted);
    # shuffle() takes the buffer size in elements; map() applies preprocess
    # to each (image, label) element before batching.
    train_db = (
        tf.data.Dataset.from_tensor_slices((X_train, y_train))
        .shuffle(shuffle_buffer)
        .map(preprocess)
        .batch(batch_size)
    )
    # The test pipeline is deliberately not shuffled.
    test_db = (
        tf.data.Dataset.from_tensor_slices((X_test, y_test))
        .map(preprocess)
        .batch(test_batch_size)
    )
    return train_db, test_db

def preprocess(x, y):
    """Normalise an (images, labels) pair for training.

    Pixel values are cast to float32 and mapped linearly from [0, 255] onto
    [-1, 1]; labels are cast to int32. Used both as a per-element ``tf.data``
    ``map`` function and on whole arrays at once.
    """
    # Same operation order as ((2 * x) / 255.) - 1, so results are bit-identical.
    doubled = 2 * tf.cast(x, dtype=tf.float32)
    images = doubled / 255. - 1
    labels = tf.cast(y, dtype=tf.int32)
    return images, labels

def load_dataset1(num_classes=10):
    """Download CIFAR-10 and return whole splits as in-memory tensors.

    Unlike ``load_dataset`` this builds no ``tf.data`` pipeline. Each split is
    passed through ``preprocess`` in one call (images scaled to float32 in
    [-1, 1]). The labels are then one-hot encoded, as required by the
    ``categorical_crossentropy`` loss used when compiling the model.

    Args:
        num_classes: width of the one-hot label vectors (CIFAR-10 has 10).

    Returns:
        (X_train, y_train, X_test, y_test): image tensors of shape
        (N, 32, 32, 3) and one-hot label arrays of shape (N, num_classes).
    """
    (X_train, y_train), (X_test, y_test) = datasets.cifar10.load_data()
    # Labels keep their [b, 1] shape here; to_categorical drops the trailing
    # axis itself. No explicit shuffle: model.fit shuffles each epoch by default.
    X_train, y_train = preprocess(X_train, y_train)
    X_test, y_test = preprocess(X_test, y_test)
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)
    return X_train, y_train, X_test, y_test

# load_dataset() is the tf.data alternative; this notebook trains on the
# in-memory tensors returned by load_dataset1().
X_train, y_train, X_test, y_test = load_dataset1()
# Show the shapes of the image tensors first, then the one-hot labels, one per line.
print(*(split.shape for split in (X_train, X_test, y_train, y_test)), sep="\n")



(50000, 32, 32, 3)
(10000, 32, 32, 3)
(50000, 10)
(10000, 10)


In [2]:
def network_model(input_shape=(32, 32, 3), num_classes=10,
                  stage_filters=(64, 128, 256, 512, 512)):
    """Build a VGG-style CNN classifier (uncompiled).

    Each entry of ``stage_filters`` adds one stage: two 3x3 'same'-padded
    ReLU convolutions with that many filters, then a 2x2 stride-2 max-pool.
    A Flatten layer and a 256-128-``num_classes`` dense head follow; the
    head ends in softmax. With the defaults, the five pools shrink a 32x32
    image to 1x1x512 before Flatten.

    Args:
        input_shape: shape of a single input image (default (32, 32, 3)).
        num_classes: number of softmax output units (default 10).
        stage_filters: filter count of each conv stage, in order.

    Returns:
        A ``Sequential`` model with the same layer stack as before when
        called with the defaults.

    Raises:
        ValueError: if ``stage_filters`` is empty. In that case no layer
            would carry the input shape.
    """
    if not stage_filters:
        raise ValueError("stage_filters must contain at least one stage")

    conv_kwargs = dict(kernel_size=(3, 3), padding="same", activation="relu")
    model = Sequential()
    for stage, filters in enumerate(stage_filters):
        if stage == 0:
            # Only the very first layer declares the input shape.
            model.add(Conv2D(filters, input_shape=tuple(input_shape), **conv_kwargs))
        else:
            model.add(Conv2D(filters, **conv_kwargs))
        model.add(Conv2D(filters, **conv_kwargs))
        model.add(MaxPooling2D(pool_size=(2, 2), strides=2, padding="same"))

    model.add(Flatten())  # flatten the last feature map for the dense head
    model.add(Dense(256, activation="relu"))
    model.add(Dense(128, activation="relu"))
    model.add(Dense(num_classes, activation="softmax"))
    return model


In [3]:
# Fix TensorFlow's global seed before the model is built, so weight
# initialisation (and other TF random ops) repeat across Restart & Run All.
# GPU kernels may still introduce some nondeterminism.
SEED = 42
tf.random.set_seed(SEED)

model = network_model()
# Labels from load_dataset1 are one-hot, hence categorical_crossentropy
# paired with the softmax output layer.
model.compile(loss=tf.keras.losses.categorical_crossentropy,
              optimizer=optimizers.Adam(),
              metrics=['accuracy'])

In [4]:
# Keep the History object so the per-epoch loss/accuracy curves can be
# plotted afterwards (history.history maps metric names to per-epoch lists).
# NOTE: validation_data is the CIFAR-10 *test* split, so the val_* numbers are
# test-set numbers. Do not use them to pick epochs or hyper-parameters;
# hold out part of the training set for that instead.
history = model.fit(X_train, y_train, batch_size=128, epochs=15,
                    validation_data=(X_test, y_test))


Train on 50000 samples, validate on 10000 samples
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


<tensorflow.python.keras.callbacks.History at 0x243b83bba48>