In [2]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import sys
sys.path.append('../')

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score, precision_score, precision_recall_curve

import tensorflow as tf
import keras 

from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, InputLayer, Flatten, Conv2D, MaxPooling2D

# Single seed shared by NumPy's global RNG and TensorFlow's global seed, so
# weight initialisation and shuffling repeat across runs (GPU kernels may
# still introduce some non-determinism).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

from preprocessing.getdata import *

In [11]:
# Load both cohorts: df_a from the ADNI loader, df_o from get_csvdata called
# with drop_young=True and drop_contradictions=False.
df_a_raw = get_csvdata_ADNI()
df_o = get_csvdata(drop_young=True, drop_contradictions=False)

# ADNI subject IDs removed before splitting.
# NOTE(review): the reason for excluding these subjects is not recorded here — document it.
EXCLUDED_ADNI_IDS = ['013_S_1275', '099_S_0533', '123_S_0050', '131_S_0457', '137_S_0796']

# Keep ADNI rows that are neither excluded nor MCI (one vectorised mask instead
# of filtering once per ID); row order and index are preserved.
df_a = df_a_raw[~df_a_raw['ID'].isin(EXCLUDED_ADNI_IDS) & (df_a_raw['Group'] != 'MCI')]
# Two-class label: 0 = CN, 1 = any other group that remains.
y_a = (df_a['Group'] != 'CN').astype(int)
# Stratified split of subject IDs (train_test_split's default test share).
df_a_train, df_a_test, y_a_train, y_a_test = train_test_split(df_a['ID'], y_a, stratify=y_a, random_state=42)

# NOTE(review): CDR is used directly as the target of a single-sigmoid model and
# of binary recall/precision later on, so this assumes CDR is already 0/1 in
# get_csvdata's output — confirm.
y_o = df_o['CDR']
df_o_train, df_o_test, y_o_train, y_o_test = train_test_split(df_o['ID'], y_o, stratify=y_o, random_state=42)

In [12]:
# Slice-sampling configuration forwarded to get_slices / get_slices_ADNI2.
N = 5  # training getters return 1 + 2*N slices per subject (labels are repeated to match)
d = 1  # spacing argument forwarded to the slice getters
# Per-axis slice index `m` forwarded to the slice getters.
# NOTE(review): presumably the central slice along each axis — confirm in preprocessing.getdata.
mdict = {0: 95, 1: 110, 2: 90}

SLICES_PER_SUBJECT = 1 + 2 * N

# One label per training slice. Series.repeat puts the copies of each label next
# to each other, which assumes the slice getters emit each subject's slices
# contiguously and in df_*_train order — TODO confirm in preprocessing.getdata.
# New names (instead of overwriting y_*_train) keep this cell idempotent:
# re-running it no longer multiplies the labels a second time.
y_o_train_rep = y_o_train.repeat(SLICES_PER_SUBJECT)
y_a_train_rep = y_a_train.repeat(SLICES_PER_SUBJECT)

# Labels do not depend on the axis, so build them once. Order (df_o first, then
# df_a) must match the image concatenation inside the loop.
y_train = np.concatenate((y_o_train_rep, y_a_train_rep))
y_test = np.concatenate((y_o_test, y_a_test))


def build_vgg16_classifier(input_shape):
    """Build and compile a binary classifier on an ImageNet-initialised VGG16 base.

    Args:
        input_shape: (height, width, 3) shape of a single input image.

    Returns:
        A compiled Sequential model ending in one sigmoid unit, trained with
        binary cross-entropy and tracking accuracy.
    """
    # include_top=False drops VGG16's ImageNet classifier. The base is not
    # frozen, so it is fine-tuned together with the new dense head.
    base = tf.keras.applications.VGG16(include_top=False, weights='imagenet', input_shape=input_shape)

    clf = Sequential()
    clf.add(InputLayer(input_shape=input_shape))
    clf.add(base)
    clf.add(Flatten())
    clf.add(Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(l=0.1)))
    clf.add(Dropout(0.3))
    clf.add(Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(l=0.1)))
    clf.add(Dense(1, activation="sigmoid"))

    # Inverse-time decay: lr = 1e-5 / (1 + step / 10000).
    lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        0.00001,
        decay_steps=10000,
        decay_rate=1,
        staircase=False)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, name='Adam')
    clf.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    return clf


# Per-axis training curves (plain dicts from History.history, so earlier models
# are not kept alive). `model`, `model_history` and `X_test` are left bound to
# the last axis for the plotting/evaluation cells below.
histories = {}
for dim in range(3):
    m = mdict[dim]
    X_train_o = get_slices(df_o_train, dim=dim, m=m, N=N, d=d)
    X_train_a = get_slices_ADNI2(df_a_train, dim=dim, m=m, N=N, d=d)

    # Test slices are taken with the getters' default N/d (no augmentation args).
    X_test_o = get_slices(df_o_test, dim=dim, m=m)
    X_test_a = get_slices_ADNI2(df_a_test, dim=dim, m=m)

    X_train = np.concatenate((X_train_o, X_train_a), axis=0)
    X_test = np.concatenate((X_test_o, X_test_a), axis=0)

    # Fail fast with a clear message if slices and labels are misaligned,
    # rather than failing later inside model.fit.
    assert len(X_train) == len(y_train), f"axis {dim}: {len(X_train)} train slices vs {len(y_train)} labels"
    assert len(X_test) == len(y_test), f"axis {dim}: {len(X_test)} test slices vs {len(y_test)} labels"

    # ImageNet VGG16 weights require 3 input channels, so copy the single
    # slice channel into all three.
    X_train = np.repeat(X_train[..., np.newaxis], 3, -1)
    X_test = np.repeat(X_test[..., np.newaxis], 3, -1)

    HEIGHT = X_train.shape[1]
    WIDTH = X_train.shape[2]
    INPUT_SHAPE = (HEIGHT, WIDTH, 3)
    # NOTE(review): slices are not passed through vgg16.preprocess_input, which
    # the ImageNet weights were trained with — confirm the intensity scaling
    # produced by the slice getters is intended.
    model = build_vgg16_classifier(INPUT_SHAPE)

    # validation_data is the held-out test set, so val_* metrics are test-set
    # metrics (used for monitoring only; nothing here selects on them).
    with tf.device('/device:GPU:0'):
        model_history = model.fit(X_train, y_train, epochs=20, shuffle=True, validation_data=(X_test, y_test))
    histories[dim] = model_history.history

    # One saved model per axis: 'axis0', 'axis1', 'axis2'.
    model.save('axis' + str(dim))

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [None]:
# Train vs. validation accuracy per epoch for the last fit run by the training
# loop (model_history is overwritten on every axis).
curves = model_history.history
fig, ax = plt.subplots()
for key in ('accuracy', 'val_accuracy'):
    ax.plot(curves[key])
ax.set_title('model accuracy')
ax.set_ylabel('accuracy')
ax.set_xlabel('epoch')
ax.legend(['train', 'validation'], loc='upper left')
ax.set_ylim([0.5, 1.1])
plt.show()

In [None]:
# Train vs. validation loss per epoch for the last fit run by the training loop.
curves = model_history.history
fig, ax = plt.subplots()
for key in ('loss', 'val_loss'):
    ax.plot(curves[key])
ax.set_title('model loss')
ax.set_ylabel('loss')
ax.set_xlabel('epoch')
ax.legend(['train', 'validation'], loc='upper left')
plt.show()

In [None]:
# Evaluate `model` — the last model fitted by the training loop — on its
# matching X_test / y_test. Models saved for the other axes are not evaluated here.
DECISION_THRESHOLD = 0.5

# The model ends in a single sigmoid unit, so predict returns an (n, 1) array;
# flatten it so predictions line up with the 1-D y_test.
y_predd = model.predict(X_test)
y_pred = (y_predd.ravel() > DECISION_THRESHOLD).astype(int)

# Plotting the confusion matrix. confusion_matrix puts true labels on rows, so
# the transpose shows true labels on the x-axis and predictions on the y-axis.
mat = confusion_matrix(y_test, y_pred)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False)
plt.xlabel('true label')
plt.ylabel('predicted label')

# y_pred is already 0/1 ints, so no extra rounding of predictions is needed.
# Builtin round() works whichever scalar type (NumPy or Python float) sklearn
# returns, unlike the `.round()` method.
print('accuracy: ' , round(accuracy_score(y_test, y_pred), 2))
print('recall: ' , round(recall_score(y_test, y_pred), 2))
print('precision: ' , round(precision_score(y_test, y_pred), 2))