<a href="https://colab.research.google.com/github/Machine-Learning-Tokyo/CNN-Architectures/blob/master/Implementations/ShuffleNet/ShuffleNet_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
from tensorflow.keras.layers import Input, Conv2D, DepthwiseConv2D, \
     Dense, Concatenate, Add, ReLU, BatchNormalization, AvgPool2D, \
     MaxPool2D, GlobalAvgPool2D, Reshape, Permute, Lambda, Dropout

import food_mnist

import numpy as np
# Take to_categorical from tf.keras like every other Keras symbol in this
# notebook: mixing the standalone `keras` package with `tensorflow.keras`
# is fragile, and the private `keras.utils.np_utils` path is not available
# in newer Keras releases.
from tensorflow.keras.utils import to_categorical

from tensorflow.keras import Model

import warnings
# NOTE(review): this silences *all* Python warnings for the whole session.
warnings.filterwarnings("ignore")

In [7]:
# Fetch both splits from the local `food_mnist` helper module, then unpack
# each one into its (inputs, labels) pair.
train_split, test_split = food_mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

# Label names as exposed by the helper module.
labels_names = food_mnist.labels()

module google.colab.patches not imported.


In [8]:
# Scale the inputs by 1/255 (presumably 8-bit pixel values -> [0, 1]; confirm
# against food_mnist). NOTE: this cell rebinds its own inputs, so it must be
# run exactly once per kernel session - a second run would rescale the images
# and one-hot encode the already encoded labels.
x_train = x_train/255.
x_test = x_test/255.

y_train = np.array(y_train)
y_test = np.array(y_test)

# One-hot encode both splits with the SAME width. Without an explicit
# num_classes, to_categorical infers the width from the largest label present
# in the array it is given, so a split that happens to lack the highest class
# would end up with fewer columns than the other split.
num_classes = int(max(y_train.max(), y_test.max())) + 1
y_train = to_categorical(y_train, num_classes=num_classes)
y_test = to_categorical(y_test, num_classes=num_classes)

In [17]:
def stage(x, channels, repetitions, groups):
    """Chain one stride-2 ShuffleNet block with `repetitions` stride-1 blocks.

    Args:
        x: Input tensor for the stage.
        channels: Channel count forwarded to every block of the stage.
        repetitions: Number of stride-1 blocks appended after the first
            (stride-2) block.
        groups: Group count forwarded to every block.

    Returns:
        The tensor produced by the last block of the stage.
    """
    # The first block uses stride 2; every following block uses stride 1.
    stride_plan = [2] + [1] * repetitions
    for block_stride in stride_plan:
        x = shufflenet_block(x, channels=channels, strides=block_stride, groups=groups)
    return x


def shufflenet_block(tensor, channels, strides, groups):
    """Build one ShuffleNet unit on top of `tensor`.

    Main branch: grouped 1x1 conv to `channels // 4` -> BN -> ReLU ->
    channel shuffle -> 3x3 depthwise conv (with `strides`) -> BN ->
    grouped 1x1 conv -> BN.

    Shortcut: with `strides == 1` the input is added to the branch output;
    otherwise the input goes through a 3x3 stride-2 average pool and is
    concatenated with the branch output.

    Args:
        tensor: Input tensor, channels on the last axis.
        channels: Target channel count of the unit. When `strides == 2` the
            last grouped conv is asked for `channels` minus the input's
            channel count, because the pooled input is concatenated to it.
        strides: Stride of the depthwise convolution.
        groups: Number of groups for both grouped convs and for the shuffle.

    Returns:
        The ReLU-activated output tensor of the unit.
    """
    # Bottleneck: grouped pointwise conv down to a quarter of `channels`.
    branch = gconv(tensor, channels=channels // 4, groups=groups)
    branch = BatchNormalization()(branch)
    branch = ReLU()(branch)

    # Mix channels across groups, then filter spatially per channel.
    branch = channel_shuffle(branch, groups)
    depthwise = DepthwiseConv2D(kernel_size=3, strides=strides, padding='same')
    branch = depthwise(branch)
    branch = BatchNormalization()(branch)

    # The concatenating variant only has to produce the channels that the
    # pooled shortcut does not already contribute.
    expand_channels = (
        channels - tensor.get_shape().as_list()[-1] if strides == 2 else channels
    )
    branch = gconv(branch, channels=expand_channels, groups=groups)
    branch = BatchNormalization()(branch)

    if strides != 1:
        shortcut = AvgPool2D(pool_size=3, strides=2, padding='same')(tensor)
        merged = Concatenate()([shortcut, branch])
    else:
        merged = Add()([tensor, branch])

    return ReLU()(merged)


def gconv(tensor, channels, groups):
    """Grouped 1x1 convolution assembled from one `Conv2D` per group.

    The last (channel) axis of `tensor` is cut into `groups` contiguous
    slices of `input_ch // groups` channels. Each slice is fed to its own
    1x1 `Conv2D` with `channels // groups` filters, and the per-group results
    are concatenated back together.

    Args:
        tensor: 4-D input tensor with channels on the last axis.
        channels: Requested total number of output channels.
        groups: Number of groups to split the channels into.

    Returns:
        The concatenation of all group outputs. Both per-group sizes use
        floor division, so the result has `groups * (channels // groups)`
        channels and only the first `groups * (input_ch // groups)` input
        channels are read.
    """
    input_ch = tensor.get_shape().as_list()[-1]
    group_ch = input_ch // groups
    output_ch = channels // groups
    groups_list = []

    for i in range(groups):
        # Hand the slice bounds to the layer through `arguments` instead of
        # letting the lambda close over `i`. A closure looks `i` up when the
        # function is *called*, so any call made after this loop has finished
        # (Keras re-invokes Lambda functions when the model is run or
        # re-traced) would see the final value of `i` and every group would
        # slice the same, last channel range.
        group_tensor = Lambda(
            lambda x, start, stop: x[:, :, :, start:stop],
            arguments={'start': i * group_ch, 'stop': (i + 1) * group_ch},
        )(tensor)
        group_tensor = Conv2D(output_ch, 1)(group_tensor)
        groups_list.append(group_tensor)

    output = Concatenate()(groups_list)
    return output


def channel_shuffle(x, groups):
    """Permute the channel axis of `x` so neighbouring channels are interleaved.

    The channel axis is viewed as a `(channels // groups, groups)` grid, the
    two grid axes are swapped, and the grid is flattened back to `channels`.

    Args:
        x: 4-D tensor with static spatial and channel dimensions.
        groups: Size of the second grid axis before the swap.

    Returns:
        A tensor with the same shape as `x` and reordered channels.
    """
    # Axes 1 and 2 are simply passed through; only their sizes are needed.
    _, dim1, dim2, n_channels = x.get_shape().as_list()
    per_group = n_channels // groups

    grid = Reshape((dim1, dim2, per_group, groups))(x)
    swapped = Permute((1, 2, 4, 3))(grid)
    return Reshape((dim1, dim2, n_channels))(swapped)


# ---- Stem: 3x3 stride-2 conv + BN + ReLU, then 3x3 stride-2 max pooling ----
# NOTE: the name `input` shadows the Python builtin; it is kept because the
# cell that builds the Model refers to it.
input = Input(shape=[224, 224, 3])
stem_conv = Conv2D(filters=24, kernel_size=3, strides=2, padding='same')
x = stem_conv(input)
x = BatchNormalization()(x)
x = ReLU()(x)
x = MaxPool2D(pool_size=3, strides=2, padding='same')(x)


# ---- Stages: the channel count doubles from one stage to the next ----
repetitions = 3, 7, 3        # stride-1 blocks per stage (after the stride-2 one)
initial_channels = 384       # channel count of the first stage
groups = 8                   # group count shared by all grouped convs

for i, reps in enumerate(repetitions):
    channels = initial_channels << i   # same as initial_channels * 2**i
    x = stage(x, channels, reps, groups)


# ---- Head: global average pooling followed by a 10-way softmax ----
x = GlobalAvgPool2D()(x)

output = Dense(units=10, activation='softmax')(x)

In [18]:
# Wrap the graph built above into a Model, configure training, and print
# the layer table.
model = Model(inputs=input, outputs=output)
model.compile(
    optimizer='Adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()

Model: "functional_7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv2d_1028 (Conv2D)            (None, 112, 112, 24) 672         input_5[0][0]                    
__________________________________________________________________________________________________
batch_normalization_196 (BatchN (None, 112, 112, 24) 96          conv2d_1028[0][0]                
__________________________________________________________________________________________________
re_lu_132 (ReLU)                (None, 112, 112, 24) 0           batch_normalization_196[0][0]    
_______________________________________________________________________________________

In [19]:
# Print every layer next to its position in model.layers.
for i, layer in enumerate(model.layers):
    line = "{}: {}".format(i, layer)
    print(line)

0: <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7fd469d9b0d0>
1: <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fd7ed013a30>
2: <tensorflow.python.keras.layers.normalization_v2.BatchNormalization object at 0x7fd7f1cfd400>
3: <tensorflow.python.keras.layers.advanced_activations.ReLU object at 0x7fd7ed013940>
4: <tensorflow.python.keras.layers.pooling.MaxPooling2D object at 0x7fd7ed036b80>
5: <tensorflow.python.keras.layers.core.Lambda object at 0x7fd7ed006c40>
6: <tensorflow.python.keras.layers.core.Lambda object at 0x7fd4684f6e80>
7: <tensorflow.python.keras.layers.core.Lambda object at 0x7fd7ecff2970>
8: <tensorflow.python.keras.layers.core.Lambda object at 0x7fd7ecfe9c40>
9: <tensorflow.python.keras.layers.core.Lambda object at 0x7fd7ecfdec10>
10: <tensorflow.python.keras.layers.core.Lambda object at 0x7fd7ecfe5220>
11: <tensorflow.python.keras.layers.core.Lambda object at 0x7fd7ecfa0be0>
12: <tensorflow.python.keras.layers.core.Lambda object 

In [None]:
# Layers with an index below FREEZE_UNTIL are frozen; the rest stay trainable.
# NOTE(review): 712 is presumably taken from the layer listing printed by this
# notebook - confirm it still matches if the architecture changes.
FREEZE_UNTIL = 712

for idx, layer in enumerate(model.layers):
    layer.trainable = idx >= FREEZE_UNTIL

# The model was already compiled before these flags were changed. Keras
# documents that `trainable` changes must be followed by another compile()
# to be taken into account, so recompile with the same settings as before.
model.compile(optimizer='Adam',loss='categorical_crossentropy',metrics=['accuracy'])

In [10]:
# Train on the prepared training split; `history` keeps the object returned
# by fit().
history = model.fit(
    x_train,
    y_train,
    batch_size=224,
    epochs=5,
)

In [None]:
ypred = model.predict(x_test)

# Turn the per-class scores (model output) and the one-hot targets into class
# indices in one vectorised step instead of a Python loop over every sample.
pred_labels = np.argmax(ypred, axis=1)
true_labels = np.argmax(y_test, axis=1)
is_correct = pred_labels == true_labels

total = len(ypred)
accurate = int(np.count_nonzero(is_correct))
# Sample positions of the hits and misses, as plain Python int lists.
accurateindex = np.flatnonzero(is_correct).tolist()
wrongindex = np.flatnonzero(~is_correct).tolist()

print('Total-test-data:', total, '\nAccurately-predicted-data:', accurate, '\nWrongly-predicted-data: ', total - accurate)
if total:
    print('Accuracy:', round(accurate/total*100, 3), '%')
else:
    # Avoid a ZeroDivisionError when there is nothing to evaluate.
    print('Accuracy: n/a (no test samples)')