In [1]:
# Imports and runtime setup for fine-tuning a pretrained (pretext-task) HTNet model.
import numpy as np

import os
# Choose GPU 0 as a default: enumerate GPUs in PCI bus order and expose only device 0.
# These variables are set before TensorFlow is imported so they take effect when
# TensorFlow initialises CUDA.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import tensorflow as tf
from tensorflow.keras.layers import Flatten, Dense, Activation
from tensorflow.keras.constraints import max_norm
from tensorflow.keras.models import Model
# Aliased as `np_utils` to mirror the old `keras.utils.np_utils` name; used for to_categorical.
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint,EarlyStopping

import sys
# NOTE(review): hard-coded absolute path to the project's utility module; consider
# making this configurable (env var or relative path) so the notebook runs elsewhere.
sys.path.append('/home/zsteineh/cnn_hilbert/cnn_hilbert_workspace')
import hilbert_DL_utils
# Only `load_data` is referenced directly in this notebook.
from hilbert_DL_utils import load_data

In [2]:
# --- Experiment configuration ------------------------------------------------------
# Everything a reader might tune lives here; downstream cells only read these names.

# Patient / data settings
pats_ids_in = ['a0f66459']
wrist_lp = '/data1/users/stepeter/cnn_hilbert/ecog_data/xarray/'
test_day = 'last'
n_chans_all=64
tlim=[-1,1]
n_folds = 1

# Pretrained (pretext-task) model to fine-tune. The patient ID is taken from
# `pats_ids_in` rather than hard-coded, so switching patients also switches models.
fold = 2
pretask_type = 'rel_pos'
model_dir = ('/data1/users/gsquist/state_decoder/accuracy_outputs/'
             + pats_ids_in[0] + '/class_ssl/')
model_name = pretask_type+'_model_htnet_fold_'+str(fold)+'.h5'

model_fname = model_dir + model_name

# Max-norm constraint applied to the kernel of the new Dense classification head.
norm_rate = 0.25

# Training settings
optimizer='adam'
# The head is a softmax over one-hot targets, so categorical cross-entropy is the
# matching loss. For 2 classes it is numerically identical to binary_crossentropy on
# the two one-hot columns; for >2 classes it is the correct objective.
loss='categorical_crossentropy'
patience = 15
early_stop_monitor='val_loss'
epochs=64
sp = '/home/zsteineh/ez_ssl_results/'
chckpt_path = sp+'checkpoint_gen_tl_'+pats_ids_in[0]+'_fold'+str(fold)+'.h5'

In [3]:
# Load the ECoG recordings for the selected patient(s) via the project's loader:
# a training pool (X, y), a held-out test set taken from `test_day`, and the two
# subject-order arrays that load_data also returns.
_loaded = load_data(
    pats_ids_in,
    wrist_lp,
    tlim=tlim,
    test_day=test_day,
    n_chans_all=n_chans_all,
)
X, y, x_test, y_test, sbj_order_all, sbj_order_test_last = _loaded
del _loaded  # drop the temporary tuple so it does not pin the arrays in memory

100%|██████████| 1/1 [00:00<00:00,  2.58it/s]

Data loaded!





In [4]:
# Shuffle the training pool, one-hot encode labels, and hold out the first `val_frac`
# of the shuffled trials as a validation set. The test set keeps its original order.
shuffle_seed = 0  # fixed for a reproducible split; set to None for a fresh split each run
val_frac = 0.2
rng = np.random.default_rng(shuffle_seed)

nb_classes = len(np.unique(y))
# `y - 1` below assumes labels are 1-based and contiguous (1..nb_classes). A violation
# would silently mislabel or mis-size the one-hot targets, so check it up front.
class_labels = np.arange(1, nb_classes + 1)
assert np.array_equal(np.unique(y), class_labels), 'expected training labels 1..nb_classes'
assert np.isin(y_test, class_labels).all(), 'test set contains labels absent from training'

# Shuffle into new names so re-running this cell does not re-shuffle already-shuffled data.
order_inds = rng.permutation(len(y))
X_shuf = X[order_inds, ...]
y_shuf = y[order_inds]

# Pin num_classes so the train and test one-hot widths always match the Dense head,
# even if the test set happens to lack the highest class.
y2 = np_utils.to_categorical(y_shuf - 1, num_classes=nb_classes)
y_test2 = np_utils.to_categorical(y_test - 1, num_classes=nb_classes)
# Insert a singleton axis at position 1 to match the model's (None, 1, chans, time) input.
X2 = np.expand_dims(X_shuf, 1)
X_test2 = np.expand_dims(x_test, 1)

split_len = int(X2.shape[0] * val_frac)
last_epochs = np.zeros([n_folds, 2])  # NOTE(review): not used elsewhere in this notebook

val_inds = np.arange(split_len)
train_inds = np.setdiff1d(np.arange(X2.shape[0]), val_inds)  # all trials not in the val set

x_train = X2[train_inds, ...]
y_train = y2[train_inds]
x_val = X2[val_inds, ...]
y_val = y2[val_inds]

In [5]:
# Load the pretext-task model for this fold from disk and print its architecture.
# For pretask_type == 'rel_pos' this object is only inspected here: the next cell builds
# its graph from the 'sig_tran' model and copies this file's weights in by layer name.
# NOTE(review): the summary reports 1002 time samples for this model versus 501 for the
# downstream transfer model, presumably why its graph is not reused directly — confirm.
pretask_model = tf.keras.models.load_model(model_fname)
pretask_model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 1, 64, 1002)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 8, 64, 1002)       512       
_________________________________________________________________
lambda (Lambda)              (None, 8, 64, 1002)       0         
_________________________________________________________________
batch_normalization (BatchNo (None, 8, 64, 1002)       32        
_________________________________________________________________
depthwise_conv2d (DepthwiseC (None, 16, 1, 1002)       1024      
_________________________________________________________________
batch_normalization_1 (Batch (None, 16, 1, 1002)       64        
_________________________________________________________________
activation (Activation)      (None, 16, 1, 1002)       0     

In [6]:
# Assemble the fine-tuning network: a pretrained feature extractor followed by a freshly
# initialised Dense(nb_classes, max-norm constrained) + softmax classification head.
if pretask_type != 'rel_pos':
    # Reuse the pretext model's own graph, tapping the output three layers from the end.
    # NOTE(review): presumably the flattened features feeding its original classifier,
    # which is why no Flatten is inserted here — verify against the saved model.
    features = pretask_model.layers[-3].output
    head = Dense(nb_classes, name='dense',
                 kernel_constraint=max_norm(norm_rate))(features)
    softmax = Activation('softmax', name='softmax')(head)
    transfer_model = Model(inputs=pretask_model.input, outputs=softmax)
else:
    # Build the graph from the sig_tran model of the same fold...
    sig_tran_model_fname = (model_dir + 'sig_tran_model_htnet_fold_' + str(fold) + '.h5')
    sig_tran_pretask_model = tf.keras.models.load_model(sig_tran_model_fname)
    features = Flatten(name='flatten2')(sig_tran_pretask_model.layers[-4].output)
    head = Dense(nb_classes, name='dense2',
                 kernel_constraint=max_norm(norm_rate))(features)
    softmax = Activation('softmax', name='softmax2')(head)
    transfer_model = Model(inputs=sig_tran_pretask_model.input, outputs=softmax)
    # ...then overwrite every layer whose name matches with the rel_pos model's weights.
    transfer_model.load_weights(model_fname, by_name=True)

# Freeze the network except for the final three layers (the new head) and the
# 'depthwise_conv2d' layer, which is explicitly unfrozen as well.
n_trainable_tail = 3
all_layers = transfer_model.layers
for layer_idx, layer in enumerate(all_layers):
    layer.trainable = layer_idx >= len(all_layers) - n_trainable_tail
transfer_model.get_layer('depthwise_conv2d').trainable = True

transfer_model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 1, 64, 501)]      0         
_________________________________________________________________
conv2d (Conv2D)              (None, 8, 64, 501)        512       
_________________________________________________________________
lambda (Lambda)              (None, 8, 64, 501)        0         
_________________________________________________________________
batch_normalization (BatchNo (None, 8, 64, 501)        32        
_________________________________________________________________
depthwise_conv2d (DepthwiseC (None, 16, 1, 501)        1024      
_________________________________________________________________
batch_normalization_1 (Batch (None, 16, 1, 501)        64        
_________________________________________________________________
activation (Activation)      (None, 16, 1, 501)        0     

In [7]:
# Compile with the configured loss/optimizer and set up checkpointing + early stopping.
# CategoricalAccuracy matches the softmax head and one-hot targets regardless of the
# configured loss (the bare string 'accuracy' resolves to binary accuracy under
# binary_crossentropy). Naming it 'accuracy' keeps the 'accuracy'/'val_accuracy' log keys.
transfer_model.compile(loss=loss, optimizer=optimizer,
                       metrics=[tf.keras.metrics.CategoricalAccuracy(name='accuracy')])
# Checkpoint on the same quantity early stopping watches, so the "best" weights reloaded
# later agree with the stopping criterion. mode='auto' infers min/max from the metric
# name (min for 'val_loss', max for accuracies) instead of hard-coding 'min'.
checkpointer = ModelCheckpoint(filepath=chckpt_path, monitor=early_stop_monitor,
                               mode='auto', verbose=1, save_best_only=True)
early_stop = EarlyStopping(monitor=early_stop_monitor, mode='auto',
                           patience=patience, verbose=0)

In [8]:
# Fine-tune on the training split, validating on the held-out split each epoch.
# The checkpoint callback keeps the best epoch on disk; early stopping ends training
# once the monitored validation quantity stops improving for `patience` epochs.
fit_callbacks = [checkpointer, early_stop]
h = transfer_model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    batch_size=16,
    epochs=epochs,
    verbose=2,
    callbacks=fit_callbacks,
)

Epoch 1/64

Epoch 00001: val_loss improved from inf to 0.53987, saving model to /home/zsteineh/ez_ssl_results/checkpoint_gen_tl_a0f66459_fold2.h5
48/48 - 1s - loss: 0.6128 - accuracy: 0.6777 - val_loss: 0.5399 - val_accuracy: 0.7407
Epoch 2/64

Epoch 00002: val_loss improved from 0.53987 to 0.44262, saving model to /home/zsteineh/ez_ssl_results/checkpoint_gen_tl_a0f66459_fold2.h5
48/48 - 0s - loss: 0.4413 - accuracy: 0.8177 - val_loss: 0.4426 - val_accuracy: 0.7989
Epoch 3/64

Epoch 00003: val_loss improved from 0.44262 to 0.39390, saving model to /home/zsteineh/ez_ssl_results/checkpoint_gen_tl_a0f66459_fold2.h5
48/48 - 0s - loss: 0.3715 - accuracy: 0.8441 - val_loss: 0.3939 - val_accuracy: 0.8413
Epoch 4/64

Epoch 00004: val_loss improved from 0.39390 to 0.37866, saving model to /home/zsteineh/ez_ssl_results/checkpoint_gen_tl_a0f66459_fold2.h5
48/48 - 0s - loss: 0.3290 - accuracy: 0.8719 - val_loss: 0.3787 - val_accuracy: 0.8730
Epoch 5/64

Epoch 00005: val_loss improved from 0.37866 

Epoch 44/64

Epoch 00044: val_loss did not improve from 0.30434
48/48 - 0s - loss: 0.2014 - accuracy: 0.9366 - val_loss: 0.3123 - val_accuracy: 0.8677
Epoch 45/64

Epoch 00045: val_loss did not improve from 0.30434
48/48 - 0s - loss: 0.2035 - accuracy: 0.9339 - val_loss: 0.3104 - val_accuracy: 0.8624
Epoch 46/64

Epoch 00046: val_loss did not improve from 0.30434
48/48 - 0s - loss: 0.1977 - accuracy: 0.9392 - val_loss: 0.3097 - val_accuracy: 0.8624
Epoch 47/64

Epoch 00047: val_loss did not improve from 0.30434
48/48 - 0s - loss: 0.1927 - accuracy: 0.9392 - val_loss: 0.3090 - val_accuracy: 0.8677
Epoch 48/64

Epoch 00048: val_loss did not improve from 0.30434
48/48 - 0s - loss: 0.1870 - accuracy: 0.9406 - val_loss: 0.3172 - val_accuracy: 0.8624
Epoch 49/64

Epoch 00049: val_loss did not improve from 0.30434
48/48 - 0s - loss: 0.1797 - accuracy: 0.9406 - val_loss: 0.3134 - val_accuracy: 0.8624
Epoch 50/64

Epoch 00050: val_loss did not improve from 0.30434
48/48 - 0s - loss: 0.1896 - ac

In [9]:
# Restore the best checkpointed weights, then report argmax accuracy on the
# train, validation and test splits (in that order).
transfer_model.load_weights(chckpt_path)


def _split_accuracy(model, inputs, onehot_labels):
    """Fraction of samples whose argmax prediction equals the argmax of the one-hot label."""
    predicted = model.predict(inputs).argmax(axis=-1)
    return np.mean(predicted == onehot_labels.argmax(axis=-1))


eval_splits = ((x_train, y_train), (x_val, y_val), (X_test2, y_test2))
acc_lst = [_split_accuracy(transfer_model, inputs, labels) for inputs, labels in eval_splits]

print(np.asarray(acc_lst))

[0.94848085 0.86243386 0.84408602]
