<a href="https://colab.research.google.com/github/HudsonHuang/mixbatch/blob/master/cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# `from __future__` imports must be the very first statement of the file/cell;
# anywhere else it is a SyntaxError when run as a regular script.
from __future__ import print_function

import os
import random

import numpy as np
import tensorflow as tf
import keras
import keras.backend as K
import keras.backend.tensorflow_backend as KTF
from keras.datasets import cifar10
from keras.datasets import mnist
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D  # convolution and pooling layers
from keras.layers import Layer
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils

# Fixed seed so every run starts from the same initial state (the value itself is arbitrary).
SEED = 2019
random.seed(SEED)
np.random.seed(SEED)
tf.set_random_seed(SEED)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand instead of reserving all of it
sess = tf.Session(config=config)

# Make Keras use the session configured above.
KTF.set_session(sess)



Using TensorFlow backend.


In [0]:

class MixBatch(Layer):
    """Input-only mixup of each sample with its mirror partner in the batch.

    In the training phase every sample ``x[i]`` is blended with ``x[n-1-i]``
    (the batch reversed along axis 0)::

        b      ~ Beta(alpha, alpha)          # one draw per sample
        lam[i] = max(b, 1 - b)               # always >= 0.5
        out[i] = lam[i] * x[i] + (1 - lam[i]) * x[n-1-i]

    Folding with ``max`` keeps each output closer to its own original sample.
    Only the layer input is mixed; labels are not touched by this layer.
    Outside the training phase the layer is the identity.

    # Arguments
        alpha: concentration of the symmetric Beta distribution; must be > 0.
    """

    def __init__(self, alpha=0.2, **kwargs) -> None:
        super().__init__(**kwargs)
        if alpha <= 0:
            raise ValueError('MixBatch: alpha must be > 0, got %r' % (alpha,))
        self.alpha = alpha
        # Axis along which samples are paired: the batch axis.
        self.axis = 0

    def get_beta(self, input_shape, sample_size):
        """Draw one mixing coefficient per sample, broadcastable over the input.

        Args:
            input_shape: shape tuple of the tensor being mixed (only its rank is used).
            sample_size: integer scalar (int32 tensor or Python int) giving
                the number of samples along the batch axis.

        Returns:
            Tensor of shape ``(sample_size, 1, ..., 1)`` with the same rank as
            ``input_shape``.
        """
        beta_shape = [1] * len(input_shape)
        # Sample shapes must be integers; a float tensor here would be rejected.
        beta_shape[0] = sample_size
        return tf.distributions.Beta(self.alpha, self.alpha).sample(beta_shape)

    def build(self, input_shape):
        # No trainable weights.
        super().build(input_shape)

    def mixup_tf(self, features, sample_size):
        """Blend ``features`` with its batch-reversed copy using Beta weights."""
        mix = self.get_beta(K.int_shape(features), sample_size)
        # Align the coefficient dtype with the input so the products type-check.
        mix = K.cast(mix, K.dtype(features))
        # Fold to [0.5, 1] so the result stays closer to the original sample.
        mix = tf.maximum(mix, 1 - mix)
        return features * mix + features[::-1] * (1 - mix)

    def call(self, x, training=None, **kwargs):
        """Mix the batch in the training phase; return ``x`` unchanged otherwise.

        When ``training`` is None (e.g. the model calls the layer without it),
        ``K.in_train_phase`` falls back to the Keras learning phase instead of
        treating None as False, which would silently disable mixing.
        """
        def _mixed():
            # Dynamic batch size as an int32 scalar (usable as a sample shape).
            sample_size = K.shape(x)[self.axis]
            return self.mixup_tf(x, sample_size)

        return K.in_train_phase(_mixed, x, training=training)

    def compute_output_shape(self, input_shape):
        # Mixing is element-wise, so the shape is unchanged.
        return input_shape

    def get_config(self):
        """Serialize ``alpha`` alongside the base Layer config."""
        config = {
            'alpha': self.alpha,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


In [0]:
def get_default_callbacks(log_dir=None, comment=''):
    """Build the default training callbacks: a best-weights checkpoint plus TensorBoard.

    Args:
        log_dir: directory that receives ``model.h5`` and the TensorBoard logs.
            When None, a timestamped directory ``./YYYYmmdd_HHMMSS<comment>``
            is used.
        comment: suffix appended to the generated directory name; ignored when
            ``log_dir`` is given.

    Returns:
        list: ``[ModelCheckpoint, TensorBoard]`` callbacks for ``model.fit``.

    Side effects:
        Creates ``log_dir`` if it does not exist yet.
    """
    import os
    import time

    if log_dir is None:
        # time.strftime without a time argument formats the current local time.
        log_dir = os.path.join('./', time.strftime('%Y%m%d_%H%M%S') + comment)

    # Create the directory up front instead of relying on a callback to do it
    # before the first checkpoint is written.
    os.makedirs(log_dir, exist_ok=True)

    weight_path = os.path.join(log_dir, 'model.h5')
    # NOTE(review): save_best_only uses ModelCheckpoint's default monitored
    # metric (val_loss in Keras 2), so fit() needs validation data.
    ckpt_cb = keras.callbacks.ModelCheckpoint(weight_path,
                                              save_weights_only=True,
                                              save_best_only=True)
    tb_cb = keras.callbacks.TensorBoard(log_dir=log_dir)
    return [ckpt_cb, tb_cb]


In [4]:
batch_size = 32
num_classes = 10
epochs = 100
data_augmentation = False
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'

# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
# `input_shape` only has an effect on the FIRST layer of a Sequential model.
# Passing it here (instead of to the Conv2D below) builds the model eagerly.
model.add(MixBatch(input_shape=x_train.shape[1:]))
model.add(Conv2D(32, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

W0716 03:24:50.020672 140624507115392 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.



x_train shape: (50000, 32, 32, 3)
50000 train samples
10000 test samples


In [5]:

# Initiate the RMSprop optimizer.
opt = keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)

# Let's train the model using RMSprop.
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# Scale pixels to [0, 1] into NEW arrays so re-running this cell never
# rescales data that has already been scaled.
x_train_scaled = x_train.astype('float32') / 255
x_test_scaled = x_test.astype('float32') / 255

if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(x_train_scaled, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test_scaled, y_test),
              shuffle=True)
else:
    print('Using real-time data augmentation.')
    # Only non-default options are set: random shifts of up to 10% of the
    # width/height (edges filled with the nearest pixel) and random horizontal flips.
    datagen = ImageDataGenerator(
        width_shift_range=0.1,
        height_shift_range=0.1,
        fill_mode='nearest',
        horizontal_flip=True)

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is enabled).
    datagen.fit(x_train_scaled)

    train_flow = datagen.flow(x_train_scaled, y_train, batch_size=batch_size)
    # Keras 2 measures an epoch in batches (steps), not samples; the iterator's
    # length is the number of batches needed to cover the training set once.
    model.fit_generator(train_flow,
                        steps_per_epoch=len(train_flow),
                        epochs=epochs,
                        validation_data=(x_test_scaled, y_test),
                        workers=4)

# Save model and weights.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Score trained model.
scores = model.evaluate(x_test_scaled, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])


W0716 03:24:50.629302 140624507115392 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

W0716 03:24:54.213476 140624507115392 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0716 03:24:54.228522 140624507115392 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.



Not using data augmentation.


W0716 03:24:54.426520 140624507115392 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3976: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.

W0716 03:24:54.448096 140624507115392 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:133: The name tf.placeholder_with_default is deprecated. Please use tf.compat.v1.placeholder_with_default instead.

W0716 03:24:54.518524 140624507115392 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
W0716 03:24:55.823170 140624507115392 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tens

Train on 50000 samples, validate on 10000 samples
Epoch 1/100
11136/50000 [=====>........................] - ETA: 2:12 - loss: 2.1715 - acc: 0.1836

KeyboardInterrupt: ignored