In [17]:
import numpy as np
import h5py
from superai.nn.layer.fc import FullyConnected
from superai.nn.layer.conv import Conv
from superai.nn.layer.pool import PoolingLayer
from superai.nn.layer.activator import Activator
from superai.nn.layer.flatten import Flatten
from superai.nn.model.nnet import Sequence
from superai.nn.optimizer.optimizer import Adam
import os
import struct
'''
from deepnet.layers import *
from deepnet.nnet import CNN
'''
def convert_to_one_hot(Y, C):
    """Encode integer class labels as one-hot column vectors.

    Args:
        Y: NumPy integer array of class indices. It is flattened before use,
            so its original shape does not matter.
        C: Number of classes, i.e. the size of the identity matrix used for
            the lookup.

    Returns:
        A float array of shape ``(C, Y.size)`` whose i-th column is the
        one-hot vector of the i-th flattened label.
    """
    identity = np.eye(C)
    flat_labels = Y.reshape(-1)
    # Row lookup yields (n_labels, C); transpose so that each column is a sample.
    one_hot_rows = identity[flat_labels]
    return one_hot_rows.T

def load_dataset():
    """Load the train and test sets from the HDF5 files under ``datasets/``.

    Reads ``datasets/train_signs.h5`` and ``datasets/test_signs.h5`` (paths are
    relative to the current working directory) and copies their contents into
    in-memory NumPy arrays, so nothing returned depends on an open file.

    Returns:
        Tuple ``(train_set_x_orig, train_set_y_orig, test_set_x_orig,
        test_set_y_orig, classes)``. Both label arrays are reshaped to a single
        row of shape ``(1, n_samples)``.
    """
    # Context managers guarantee the HDF5 handles are closed once the data has
    # been copied out (the `[:]` slice plus np.array materialises it in memory).
    with h5py.File('datasets/train_signs.h5', "r") as train_dataset:
        train_set_x_orig = np.array(train_dataset["train_set_x"][:])  # train set features
        train_set_y_orig = np.array(train_dataset["train_set_y"][:])  # train set labels

    with h5py.File('datasets/test_signs.h5', "r") as test_dataset:
        test_set_x_orig = np.array(test_dataset["test_set_x"][:])  # test set features
        test_set_y_orig = np.array(test_dataset["test_set_y"][:])  # test set labels
        classes = np.array(test_dataset["list_classes"][:])  # the list of classes

    # Labels become row vectors: (n_samples,) -> (1, n_samples).
    train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))
    test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))

    return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes


NUM_CLASSES = 6  # width of the one-hot label encoding

X_train_orig, Y_train_orig, X_test_orig, Y_test_orig, classes = load_dataset()

# Scale by 1/255 (presumably 8-bit pixel intensities -> [0, 1]), then reorder
# the axes (0, 1, 2, 3) -> (3, 1, 2, 0).
# NOTE(review): assumes the raw images are (samples, height, width, channels),
# which would make the result (channels, height, width, samples) -- confirm.
X_train = (X_train_orig / 255.).transpose(3, 1, 2, 0)
X_test = (X_test_orig / 255.).transpose(3, 1, 2, 0)

# convert_to_one_hot already returns (NUM_CLASSES, samples), so no further
# transposition is needed (a `.T` followed by `.transpose(1, 0)` is a no-op).
Y_train = convert_to_one_hot(Y_train_orig, NUM_CLASSES)
Y_test = convert_to_one_hot(Y_test_orig, NUM_CLASSES)

In [18]:
import math
def _samemode_total_pad(in_len, filter_len, stride_len):
    """Return the total "SAME"-mode padding needed along a single axis."""
    remainder = in_len % stride_len
    # When the input is an exact multiple of the stride, the last window
    # advances by a full stride; otherwise only by the remainder.
    covered = stride_len if remainder == 0 else remainder
    return max(filter_len - covered, 0)


def compute_samemode_pad(in_width, in_height, filter_size, strides):
    """Compute per-side padding for a "SAME"-mode convolution or pooling window.

    Along each axis the total padding is ``max(filter - stride, 0)`` when the
    input length is a multiple of the stride, and
    ``max(filter - (input % stride), 0)`` otherwise.

    Args:
        in_width: Input width.
        in_height: Input height.
        filter_size: ``(filter_width, filter_height)``.
        strides: ``(stride_width, stride_height)``.

    Returns:
        ``(pad_top, pad_bottom, pad_left, pad_right)``.

    Raises:
        ZeroDivisionError: If either stride is zero.
    """
    stride_width, stride_height = strides
    filter_width, filter_height = filter_size

    pad_along_height = _samemode_total_pad(in_height, filter_height, stride_height)
    pad_along_width = _samemode_total_pad(in_width, filter_width, stride_width)

    # Padding is split across both sides of each axis. When the total is odd,
    # the bottom gets one more than the top and the right one more than the left.
    # Note that this is different from existing libraries such as cuDNN and
    # Caffe, which explicitly specify the number of padded pixels and always
    # pad the same number of pixels on both sides.
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left
    return (pad_top, pad_bottom, pad_left, pad_right)

In [19]:

# Network: CONV -> RELU -> MAXPOOL -> CONV -> RELU -> MAXPOOL -> FLATTEN -> FC.
# NOTE(review): the meaning of the superai layer arguments is inferred from how
# they are used here -- confirm against the superai layer definitions.

# Stage 1: 4x4 filters over 3 input channels, 8 filters, stride 1, with
# SAME-mode padding computed for a 64x64 input.
pad = compute_samemode_pad(64, 64, (4, 4), (1, 1))
cnn_layer1 = Conv((3, 4, 4), 8, pad, (1, 1))
cnn_layer1.layer = 1
relu1 = Activator("relu")

# 8x8 max-pool with stride 8; 64 is a multiple of 8, so the padding is all zeros.
pad = compute_samemode_pad(64, 64, (8, 8), (8, 8))
pool_layer1 = PoolingLayer((8, 8), pad, 'max', (8, 8))

# Stage 2: 2x2 filters over 8 input channels, 16 filters, stride 1.
# The feature map entering this layer is presumably 8x8 (64 / 8 after the first
# pool). With stride 1 the SAME padding does not depend on the input size, so
# the computed padding is (0, 1, 0, 1) regardless.
pad = compute_samemode_pad(8, 8, (2, 2), (1, 1))
cnn_layer2 = Conv((8, 2, 2), 16, pad, (1, 1))
cnn_layer2.layer = 2
relu2 = Activator("relu")

# 4x4 max-pool with stride 4 on the 8x8 map; 8 is a multiple of 4, so no padding.
pad = compute_samemode_pad(8, 8, (4, 4), (4, 4))
pool_layer2 = PoolingLayer((4, 4), pad, 'max', (4, 4))

flatten = Flatten()
# 64 inputs presumably = 16 channels x 2 x 2 after the second pool; 6 outputs.
dense = FullyConnected(64, 6)

model = Sequence(
    [cnn_layer1, relu1, pool_layer1, cnn_layer2, relu2, pool_layer2, flatten, dense],
    learning_rate=0.01, iteration_count=100, lambd=0,
    use_mini_batch=True, mini_batch_size=64)

# Earlier variant, no longer used: 5x5 convolutions with 16 and 20 filters and
# padding (2, 2, 2, 2), 2x2 stride-2 max-pools without padding,
# FullyConnected(5120, 6), iteration_count=300 and use_mini_batch=False.

# Train with Adam on the first 1000 entries along the last axis of X and Y.
adamOpt = Adam()
adamOpt.run(model, X_train[:, :, :, 0:1000], Y_train[:, 0:1000])


iteration1  cost:8.640417176463735, accuracy:0.09375
iteration2  cost:5.607568888105896, accuracy:0.1875
iteration3  cost:4.305623834318485, accuracy:0.1875
iteration4  cost:3.7193326679371643, accuracy:0.125
iteration5  cost:3.098267457821398, accuracy:0.171875
iteration6  cost:2.4209705331905047, accuracy:0.125
iteration7  cost:2.5097486246504164, accuracy:0.15625
iteration8  cost:2.7401565848198977, accuracy:0.09375
iteration9  cost:2.5709734020426875, accuracy:0.21875
iteration10  cost:2.629469693383035, accuracy:0.21875
iteration11  cost:2.5183631453700013, accuracy:0.28125
iteration12  cost:2.5719689504269043, accuracy:0.09375
iteration13  cost:2.231446706989805, accuracy:0.171875
iteration14  cost:2.4198253820565605, accuracy:0.109375
iteration15  cost:2.407610826957165, accuracy:0.15625
iteration16  cost:2.2492846295818145, accuracy:0.17500000000000004
iteration17  cost:2.247276213282781, accuracy:0.109375
iteration18  cost:2.1490079530188604, accuracy:0.09375
iteration19  cost

iteration150  cost:1.5702848434746224, accuracy:0.28125
iteration151  cost:1.6335385145338046, accuracy:0.25
iteration152  cost:1.5741910482895358, accuracy:0.375
iteration153  cost:1.536581577002699, accuracy:0.4375
iteration154  cost:1.5789143933818677, accuracy:0.296875
iteration155  cost:1.6573421707903029, accuracy:0.3125
iteration156  cost:1.5332025826957265, accuracy:0.421875
iteration157  cost:1.5859722426841163, accuracy:0.265625
iteration158  cost:1.606857890247228, accuracy:0.3125
iteration159  cost:1.7209187348783728, accuracy:0.3125
iteration160  cost:1.6442061823608476, accuracy:0.30000000000000004
iteration161  cost:1.555854214096289, accuracy:0.390625
iteration162  cost:1.6009349972828588, accuracy:0.265625
iteration163  cost:1.5555920138162156, accuracy:0.3125
iteration164  cost:1.570010084039568, accuracy:0.375
iteration165  cost:1.5498847513244318, accuracy:0.390625
iteration166  cost:1.5340282618383567, accuracy:0.3125
iteration167  cost:1.5347311189346633, accuracy

iteration297  cost:1.199510440572722, accuracy:0.53125
iteration298  cost:1.163999356975296, accuracy:0.546875
iteration299  cost:1.1418774563753848, accuracy:0.578125
iteration300  cost:1.1616673963991189, accuracy:0.5625
iteration301  cost:1.3835759746382743, accuracy:0.46875
iteration302  cost:1.2441336213736318, accuracy:0.546875
iteration303  cost:1.2090599801479434, accuracy:0.578125
iteration304  cost:1.1184210272388941, accuracy:0.475
iteration305  cost:1.117596910309207, accuracy:0.59375
iteration306  cost:1.196582886989387, accuracy:0.5
iteration307  cost:1.128433013634647, accuracy:0.578125
iteration308  cost:1.1158143658829616, accuracy:0.625
iteration309  cost:1.0815420812033825, accuracy:0.65625
iteration310  cost:1.1916629488203219, accuracy:0.515625
iteration311  cost:1.10681197397462, accuracy:0.59375
iteration312  cost:1.4252810982488864, accuracy:0.453125
iteration313  cost:1.2715222590305888, accuracy:0.4375
iteration314  cost:1.0985633349388102, accuracy:0.59375
it

iteration444  cost:0.9585545385453852, accuracy:0.625
iteration445  cost:0.7510854599683161, accuracy:0.71875
iteration446  cost:0.7935719164280457, accuracy:0.734375
iteration447  cost:0.7798478001117422, accuracy:0.78125
iteration448  cost:0.7753480811515212, accuracy:0.75
iteration449  cost:0.8606413739795513, accuracy:0.671875
iteration450  cost:0.7666692351405886, accuracy:0.734375
iteration451  cost:0.9939711070237889, accuracy:0.703125
iteration452  cost:0.8152256132319896, accuracy:0.734375
iteration453  cost:0.6749903219297704, accuracy:0.78125
iteration454  cost:0.975935274831565, accuracy:0.640625
iteration455  cost:0.9910189043920521, accuracy:0.640625
iteration456  cost:0.8584721095733961, accuracy:0.703125
iteration457  cost:0.8768398740343766, accuracy:0.65625
iteration458  cost:0.7727381233263, accuracy:0.75
iteration459  cost:0.7912634836491746, accuracy:0.703125
iteration460  cost:0.8675771011402025, accuracy:0.609375
iteration461  cost:0.752113614579137, accuracy:0.6

iteration591  cost:0.7519923281761942, accuracy:0.765625
iteration592  cost:0.6984755118840923, accuracy:0.75
iteration593  cost:0.7821017433629918, accuracy:0.671875
iteration594  cost:0.7026639493186455, accuracy:0.765625
iteration595  cost:0.591330161149674, accuracy:0.828125
iteration596  cost:0.6404378144306764, accuracy:0.71875
iteration597  cost:0.7249203126943874, accuracy:0.71875
iteration598  cost:0.5095539373829212, accuracy:0.796875
iteration599  cost:0.7091145037501834, accuracy:0.765625
iteration600  cost:0.6784012913370545, accuracy:0.765625
iteration601  cost:0.6772585861997399, accuracy:0.734375
iteration602  cost:0.7836023220865685, accuracy:0.703125
iteration603  cost:0.5639520609429414, accuracy:0.8125
iteration604  cost:0.7020459375523509, accuracy:0.71875
iteration605  cost:0.4950311269629408, accuracy:0.875
iteration606  cost:0.6375819500544706, accuracy:0.796875
iteration607  cost:0.6203224160285219, accuracy:0.75
iteration608  cost:0.6751887064126156, accuracy:

iteration738  cost:0.5078373836796077, accuracy:0.84375
iteration739  cost:0.377142566878498, accuracy:0.890625
iteration740  cost:0.40696218151928, accuracy:0.828125
iteration741  cost:0.49504838479743246, accuracy:0.796875
iteration742  cost:0.39468506700863243, accuracy:0.875
iteration743  cost:0.2539085282476316, accuracy:0.96875
iteration744  cost:0.4437403281592525, accuracy:0.84375
iteration745  cost:0.35330115071879, accuracy:0.890625
iteration746  cost:0.41352159643906966, accuracy:0.828125
iteration747  cost:0.5719304057571906, accuracy:0.78125
iteration748  cost:0.48889248761063536, accuracy:0.796875
iteration749  cost:0.5092667515494472, accuracy:0.828125
iteration750  cost:0.37390949419082714, accuracy:0.90625
iteration751  cost:0.45079223998197204, accuracy:0.84375
iteration752  cost:0.4340881612282767, accuracy:0.825
iteration753  cost:0.4010233371212562, accuracy:0.859375
iteration754  cost:0.5083972275695889, accuracy:0.84375
iteration755  cost:0.3434797607052624, accu

iteration884  cost:0.3914656233306689, accuracy:0.84375
iteration885  cost:0.3441714696265643, accuracy:0.9375
iteration886  cost:0.4642925467428393, accuracy:0.859375
iteration887  cost:0.38492986319796374, accuracy:0.84375
iteration888  cost:0.3154348115723625, accuracy:0.890625
iteration889  cost:0.24000420545563583, accuracy:0.953125
iteration890  cost:0.304864818704542, accuracy:0.875
iteration891  cost:0.4056426309543393, accuracy:0.828125
iteration892  cost:0.33226397026830684, accuracy:0.890625
iteration893  cost:0.2376444068318434, accuracy:0.921875
iteration894  cost:0.3764301573098228, accuracy:0.875
iteration895  cost:0.4457859707314548, accuracy:0.890625
iteration896  cost:0.3155422052361291, accuracy:0.875
iteration897  cost:0.22198277011165812, accuracy:0.90625
iteration898  cost:0.3743111415019612, accuracy:0.84375
iteration899  cost:0.394516274889035, accuracy:0.8125
iteration900  cost:0.307826766592044, accuracy:0.875
iteration901  cost:0.24932974941789288, accuracy:0

iteration1029  cost:0.16410892805377425, accuracy:0.96875
iteration1030  cost:0.17132343005623138, accuracy:0.96875
iteration1031  cost:0.3237855527466922, accuracy:0.890625
iteration1032  cost:0.33400420802057995, accuracy:0.84375
iteration1033  cost:0.2514690379854757, accuracy:0.9375
iteration1034  cost:0.20439064930589274, accuracy:0.96875
iteration1035  cost:0.3631673979065365, accuracy:0.828125
iteration1036  cost:0.24294790461418161, accuracy:0.953125
iteration1037  cost:0.21677700493656762, accuracy:0.953125
iteration1038  cost:0.3794608689182041, accuracy:0.859375
iteration1039  cost:0.28976433833301873, accuracy:0.875
iteration1040  cost:0.18929237746706404, accuracy:0.975
iteration1041  cost:0.20998244206132727, accuracy:0.921875
iteration1042  cost:0.2843743778725148, accuracy:0.859375
iteration1043  cost:0.326736954357706, accuracy:0.875
iteration1044  cost:0.2293282113302878, accuracy:0.90625
iteration1045  cost:0.255471513905442, accuracy:0.90625
iteration1046  cost:0.16

iteration1172  cost:0.1473823351679855, accuracy:0.96875
iteration1173  cost:0.16260915150340596, accuracy:0.9375
iteration1174  cost:0.36158102927489577, accuracy:0.859375
iteration1175  cost:0.1585386768963499, accuracy:0.9375
iteration1176  cost:0.2423052176867586, accuracy:0.890625
iteration1177  cost:0.11703099724151778, accuracy:0.96875
iteration1178  cost:0.19967159952348795, accuracy:0.953125
iteration1179  cost:0.2382846581836893, accuracy:0.921875
iteration1180  cost:0.30308824849674376, accuracy:0.875
iteration1181  cost:0.2524774769662319, accuracy:0.90625
iteration1182  cost:0.17242171431889491, accuracy:0.890625
iteration1183  cost:0.314226852623303, accuracy:0.875
iteration1184  cost:0.22965953334549152, accuracy:0.925
iteration1185  cost:0.1725626027582809, accuracy:0.9375
iteration1186  cost:0.23600624336178122, accuracy:0.9375
iteration1187  cost:0.2321966971849496, accuracy:0.921875
iteration1188  cost:0.23743459442598858, accuracy:0.890625
iteration1189  cost:0.1738

iteration1315  cost:0.1550309253692816, accuracy:0.953125
iteration1316  cost:0.268137610762489, accuracy:0.859375
iteration1317  cost:0.1459200833800538, accuracy:0.96875
iteration1318  cost:0.17036519705007486, accuracy:0.921875
iteration1319  cost:0.20169893495105656, accuracy:0.921875
iteration1320  cost:0.1844426941063925, accuracy:0.9375
iteration1321  cost:0.12932455484171482, accuracy:0.96875
iteration1322  cost:0.09283502656758169, accuracy:1.0
iteration1323  cost:0.1310283636159914, accuracy:0.96875
iteration1324  cost:0.14672144535388365, accuracy:0.984375
iteration1325  cost:0.223612377215414, accuracy:0.921875
iteration1326  cost:0.1497796229147098, accuracy:0.96875
iteration1327  cost:0.1191878297404608, accuracy:0.96875
iteration1328  cost:0.14833372492018407, accuracy:0.975
iteration1329  cost:0.15796635063957348, accuracy:0.953125
iteration1330  cost:0.22209396897052675, accuracy:0.921875
iteration1331  cost:0.1512223092091362, accuracy:0.96875
iteration1332  cost:0.08

iteration1458  cost:0.10588206181003278, accuracy:0.96875
iteration1459  cost:0.10742328648305971, accuracy:1.0
iteration1460  cost:0.1820112344074351, accuracy:0.9375
iteration1461  cost:0.16619929237204187, accuracy:0.953125
iteration1462  cost:0.10348857727802385, accuracy:0.953125
iteration1463  cost:0.1520093290053272, accuracy:0.96875
iteration1464  cost:0.177562779894204, accuracy:0.953125
iteration1465  cost:0.12013441408245212, accuracy:0.984375
iteration1466  cost:0.15267151720378705, accuracy:0.9375
iteration1467  cost:0.1366317220560212, accuracy:0.953125
iteration1468  cost:0.2782377935737648, accuracy:0.90625
iteration1469  cost:0.1854683830891977, accuracy:0.921875
iteration1470  cost:0.12882202191503972, accuracy:0.96875
iteration1471  cost:0.09780505872449916, accuracy:0.96875
iteration1472  cost:0.2760492077492985, accuracy:0.875
iteration1473  cost:0.21730387403127585, accuracy:0.890625
iteration1474  cost:0.20892693267341128, accuracy:0.921875
iteration1475  cost:0.