In [0]:
import numpy as np
import os
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, Activation, MaxPool2D, Dropout, Flatten
from keras.layers.convolutional import Convolution3D, MaxPooling3D, ZeroPadding3D
from keras.optimizers import SGD

In [0]:
# Locations of the preprocessed arrays. The defaults are the original author's
# local directories; set CS5242_TRAIN_DATA_PATH / CS5242_TEST_DATA_PATH in the
# environment to run the notebook on another machine without editing this cell.
TRAIN_DATA_PATH = os.environ.get(
    'CS5242_TRAIN_DATA_PATH',
    r'C:\Users\Sanjay Saha\CS5242-project\processed_train_data',
)
TEST_DATA_PATH = os.environ.get(
    'CS5242_TEST_DATA_PATH',
    r'C:\Users\Sanjay Saha\CS5242-project\processed_test_data',
)

In [0]:
def get_model(input_shape=(49, 49, 49, 1), class_num=1):
    """Build and compile a small 3D convolutional classifier.

    The network is three 'same'-padded 3D convolutions (32, 64 and 128
    filters with cubic kernels of edge 24, 12 and 6), with 2x2x2 max pooling
    after the first two, then a 256-unit ReLU dense layer and a final dense
    layer with ``class_num`` units.

    Keyword Arguments:
        input_shape {tuple} -- shape of one input volume without the batch axis; four entries, e.g. (depth, height, width, channels) under Keras' default channels-last setting (default: {(49, 49, 49, 1)})
        class_num {int} -- number of output units. With 1 the output is a single sigmoid probability; with more than 1 it is a softmax over the units (default: {1})

    Returns:
        model -- keras.models.Model() object, compiled with Nesterov-momentum SGD, 'binary_crossentropy' loss and the 'accuracy' metric
    """

    im_input = Input(shape=input_shape)

    # NOTE(review): the convolutions are linear (no activation argument); the
    # only non-linearities before the dense head are the max-pooling layers.
    # Left as in the original architecture -- confirm whether that is intended.
    t = Convolution3D(32, (24, 24, 24), padding='same')(im_input)
    t = MaxPooling3D(pool_size=(2, 2, 2))(t)
    t = Convolution3D(64, (12, 12, 12), padding='same')(t)
    t = MaxPooling3D(pool_size=(2, 2, 2))(t)
    t = Convolution3D(128, (6, 6, 6), padding='same')(t)
    t = Flatten()(t)
    t = Dense(256)(t)
    t = Activation('relu')(t)
    t = Dense(class_num)(t)

    # Softmax over a single unit is identically 1.0, which makes the output
    # constant and the gradient zero. A one-unit head therefore needs a
    # sigmoid to produce a usable probability for binary cross-entropy.
    final_activation = 'sigmoid' if class_num == 1 else 'softmax'
    output = Activation(final_activation)(t)

    # Keras 2 names these keyword arguments `inputs` / `outputs`.
    model = Model(inputs=im_input, outputs=output)
    sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(sgd, 'binary_crossentropy', metrics=['accuracy'])

    return model


In [0]:
#########################################################################
#                               TRAINING                                #
#########################################################################
X_train = np.load(os.path.join(TRAIN_DATA_PATH, 'X.npy'))
y_train = np.load(os.path.join(TRAIN_DATA_PATH, 'y.npy'))
print(X_train.shape)
print(y_train.shape)

# Take the per-sample shape from the loaded array (everything after the batch
# axis) so the network input cannot drift from what the preprocessing wrote.
model = get_model(input_shape=X_train.shape[1:])
# Call model.summary() here to inspect the layer output shapes.

# NOTE: Keras carves the validation_split fraction off the END of the arrays
# before shuffling, so if X.npy / y.npy are stored grouped by label the
# validation set will be skewed -- TODO confirm the saved order is mixed.
history = model.fit(X_train, y_train, epochs=1, verbose=1, validation_split=0.10, shuffle=True)