# Feature Extraction using Convolutional Neural Network

In [28]:
import tensorflow as tf
from tensorflow.keras.layers import InputLayer,Conv2D,BatchNormalization,MaxPool2D,Dropout,Flatten,LSTM,Dense,Reshape
from tensorflow.keras import Sequential
import numpy as np
import csv

In [29]:
feature_extraction_layers = [
    InputLayer(input_shape=(15,1025,151),name="Input Layer",batch_size=1),
    Conv2D(filters=32,kernel_size=(5,5),activation="relu",name="Convolution_Layer_1"),
    MaxPool2D(pool_size=(2,2),name="Max_Pool_1"),
    BatchNormalization(),
    Dropout(rate=0.4),
    Conv2D(filters=64,kernel_size=(2,2),activation="relu",name="Convolution_Layer_2"),
    MaxPool2D(pool_size=(2,2),name="Max_Pool_2"),
    BatchNormalization(),
    Conv2D(filters=128,kernel_size=(1,1),activation="relu",name="Convolution_Layer_3"),
    Flatten(),
    Reshape((128,508)),
    LSTM(128,name="LSTM_1"),
    Reshape((128,1)),
    LSTM(127,name="LSTM_2"),
    Dropout(rate=0.4),
    Dense(30,activation="relu"),
    Dense(4),
    Reshape((2,2))
    
    
]

In [30]:
feature_extraction_layers = [
    InputLayer(input_shape=(15,1025,151),name="Input Layer",batch_size=1),
    Conv2D(filters=32,kernel_size=(5,5),activation="relu",name="Convolution_Layer_1"),
    MaxPool2D(pool_size=(2,2),name="Max_Pool_1"),
    BatchNormalization(),
    Reshape((64,1275)),
    LSTM(64,name="LSTM_1"),
    Reshape((64,1)),
    LSTM(64,name="LSTM_2"),
    Dropout(rate=0.4),
    Dense(30,activation="relu"),
    Dense(4),
    Reshape((2,2))
    
    
]

In [31]:
feature_extraction_model = Sequential(feature_extraction_layers)

In [32]:
feature_extraction_model.compile(optimizer=tf.keras.optimizers.Adam(0.00001)
                                 ,metrics=["accuracy"] ,loss=tf.keras.losses.binary_crossentropy)

In [33]:
feature_extraction_model.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Convolution_Layer_1 (Conv2D) (1, 11, 1021, 32)         120832    
_________________________________________________________________
Max_Pool_1 (MaxPooling2D)    (1, 5, 510, 32)           0         
_________________________________________________________________
batch_normalization_6 (Batch (1, 5, 510, 32)           128       
_________________________________________________________________
reshape_14 (Reshape)         (1, 64, 1275)             0         
_________________________________________________________________
LSTM_1 (LSTM)                (1, 64)                   343040    
_________________________________________________________________
reshape_15 (Reshape)         (1, 64, 1)                0         
_________________________________________________________________
LSTM_2 (LSTM)                (1, 64)                  

In [37]:
def create_dataset():
    features = []
    labels=[]
    training_folder="./training/"
    with open("./training/dataset.csv","r") as csv_file:
        file = csv.reader(csv_file)
        for row in file:
            features.append(np.reshape(np.load(training_folder+row[0]),(15,1025,151)))
            labels.append(np.array([int(row[1]),int(row[2])]))
    features = np.array(features)
    labels = np.array(labels)
    labels = tf.keras.utils.to_categorical(labels,num_classes=2)
    dataset = tf.data.Dataset.from_tensor_slices((features,labels))
    dataset = dataset.repeat(60)
    dataset = dataset.shuffle(4)
    dataset = dataset.batch(15)
    dataset = dataset.as_numpy_iterator()
    return features,labels,dataset

    

In [None]:
features,labels,dataset = create_dataset()
#print("Labels without one hot encoding",labels)
#labels = tf.keras.utils.to_categorical(labels,num_classes=2)
#print("Labels",labels)
feature_extraction_model.fit(dataset,epochs=30,batch_size=15,steps_per_epoch=23)
# data = np.reshape(data,(1,15,1025,151))
# prediction = feature_extraction_model.fit(data)
# prediction.shape

In [81]:
result=feature_extraction_model.predict(np.array(np.reshape(np.load("./training/"+"Patient_1_interictal_segment_0004_1.npy"),(1,15,1025,151))))

In [82]:
print("",result)

 [[[ 0.7400652  -0.62076694]
  [-0.02251145 -0.00233709]]]


In [None]:
p