In [6]:
import os, sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from utils.data_utils import load_dataset ,load_file, split_ids
from utils.plot_utils import plot_targets
from utils.preprocessing import create_batch_sequences, create_sequences
from utils.evaluate import compare_events
from scipy.signal import find_peaks
np.random.seed(123)

For each valid trial, the relevant data and information are saved in **five** different data files:
- sub-\<label\>_task-\<label\>[_run-\<label\>]_events.tsv
- sub-\<label\>_task-\<label\>[_run-\<label\>]_tracksys-imu_channels.tsv
- sub-\<label\>_task-\<label\>[_run-\<label\>]_tracksys-imu_motion.tsv
- sub-\<label\>_task-\<label\>[_run-\<label\>]_tracksys-omc_channels.tsv
- sub-\<label\>_task-\<label\>[_run-\<label\>]_tracksys-omc_motion.tsv

As we only consinder **walk** trials, we can look for any **_events.tsv** file that contains **_task-walk** in the filename.

## Get data

In [2]:
# Set root directory
root_dir = "/mnt/neurogeriatrics_data/Keep Control/Data/lab dataset/rawdata" if sys.platform == "linux" else "Z:\\Keep Control\\Data\\lab dataset\\rawdata"

# Split subjects into a train, validation, and test set
train_ids, val_ids, test_ids = split_ids(root_dir, by=["gender", "participant_type"])
print(f"# of subjects in train set: {len(train_ids):d}\t{train_ids[:5]}")
print(f"# of subjects in val set: {len(val_ids):d}\t{val_ids[:5]}")
print(f"# of subjects in test set: {len(test_ids):d}\t{test_ids[:5]}")

# of subjects in train set: 62	['sub-pp147', 'sub-pp039', 'sub-pp135', 'sub-pp157', 'sub-pp106']
# of subjects in val set: 49	['sub-pp047', 'sub-pp153', 'sub-pp168', 'sub-pp125', 'sub-pp010']
# of subjects in test set: 49	['sub-pp156', 'sub-pp155', 'sub-pp167', 'sub-pp126', 'sub-pp165']


In [3]:
# Get datasets
ds_train = load_dataset(root_dir, sub_ids=train_ids, tracked_points=["left_ankle", "right_ankle"], normalize=True)
ds_val = load_dataset(root_dir, sub_ids=val_ids, tracked_points=["left_ankle", "right_ankle"], normalize=True)
ds_test = load_dataset(root_dir, sub_ids=test_ids, tracked_points=["left_ankle", "right_ankle"], normalize=True)

/mnt/neurogeriatrics_data/Keep Control/Data/lab dataset/rawdata/sub-pp164/motion/sub-pp164_task-walkFast_events.tsv contains no data for (at least) the left_ankle sensor. Skip file.
/mnt/neurogeriatrics_data/Keep Control/Data/lab dataset/rawdata/sub-pp164/motion/sub-pp164_task-walkPreferred_events.tsv contains no data for (at least) the left_ankle sensor. Skip file.
/mnt/neurogeriatrics_data/Keep Control/Data/lab dataset/rawdata/sub-pp164/motion/sub-pp164_task-walkSlow_events.tsv contains no data for (at least) the left_ankle sensor. Skip file.
/mnt/neurogeriatrics_data/Keep Control/Data/lab dataset/rawdata/sub-pp163/motion/sub-pp163_task-walkFast_events.tsv contains no data for (at least) the left_ankle sensor. Skip file.
/mnt/neurogeriatrics_data/Keep Control/Data/lab dataset/rawdata/sub-pp163/motion/sub-pp163_task-walkPreferred_events.tsv contains no data for (at least) the left_ankle sensor. Skip file.
/mnt/neurogeriatrics_data/Keep Control/Data/lab dataset/rawdata/sub-pp163/motion

In [4]:
# Create batches of input examples of equal length
WIN_LEN = 400
STEP_LEN = 200
train_data, train_targets, train_examples = create_batch_sequences(ds_train, win_len=WIN_LEN, step_len=STEP_LEN)
val_data, val_targets, val_examples = create_batch_sequences(ds_val, win_len=WIN_LEN, step_len=STEP_LEN)

## Build model

In [5]:
from tensorflow import keras
from tcn import TCN, tcn_full_summary

In [8]:
def build_model(input_dim, target_classes):
    # Define layers
    inputs = keras.layers.Input(shape=(None, input_dim), name="inputs")
    tcn = TCN(nb_filters=16, kernel_size=5, nb_stacks=1, dilations=[1, 2], padding="same", use_skip_connections=True, use_batch_norm=True, return_sequences=True, name="tcn")(inputs)
    outputs = []
    for i in range(len(target_classes)):
        outputs.append(keras.layers.Dense(units=1, activation="sigmoid", name=target_classes[i])(tcn))
    
    # Instantiate model
    model = keras.models.Model(inputs=inputs, outputs=outputs, name="tcn_model")

    # Compile
    model.compile(loss=keras.losses.MeanSquaredError(), optimizer=keras.optimizers.Adam(learning_rate=0.001))
    return model

In [9]:
model = build_model(train_data.shape[-1], target_classes=list(train_targets.keys()))
model.summary()

Model: "tcn_model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 inputs (InputLayer)            [(None, None, 9)]    0           []                               
                                                                                                  
 tcn (TCN)                      (None, None, 16)     5040        ['inputs[0][0]']                 
                                                                                                  
 initial_contact (Dense)        (None, None, 1)      17          ['tcn[0][0]']                    
                                                                                                  
 final_contact (Dense)          (None, None, 1)      17          ['tcn[0][0]']                    
                                                                                          

In [10]:
history = model.fit(x=train_data, y=train_targets, batch_size=8, epochs=5, validation_data=(val_data, val_targets))

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


### Evaluate on test set

In [37]:
def fancy_plot(ds_test, ix_sub_id, predictions, ix_IC_pred, ix_FC_pred):
    cm = 1/2.54

    fig, axs = plt.subplots(1+len(ds_test[ix_sub_id]["targets"].keys()), 1, figsize=(16*cm, 9*cm), sharex=True, gridspec_kw={"height_ratios": [3, 1, 1]})
    axs[0].plot(np.arange(ds_test[ix_sub_id]["data"].shape[0]), ds_test[ix_sub_id]["data"][:,5], ls="-", c=(0, 0, 0), alpha=0.3)
    axs[0].plot(np.argwhere(ds_test[ix_sub_id]["targets"]["initial_contact"]==1)[:,0], ds_test[ix_sub_id]["data"][np.argwhere(ds_test[ix_sub_id]["targets"]["initial_contact"]==1)[:,0],5], ls="none", marker="o", mfc="none", mec=(0, 0, 1))
    axs[0].plot(np.argwhere(ds_test[ix_sub_id]["targets"]["final_contact"]==1)[:,0], ds_test[ix_sub_id]["data"][np.argwhere(ds_test[ix_sub_id]["targets"]["final_contact"]==1)[:,0],5], ls="none", marker="o", mfc="none", mec=(0, 0.65, 0))
    axs[0].xaxis.set_minor_locator(plt.MultipleLocator(10))
    axs[0].grid(which="both", axis="both", c=(0, 0, 0), alpha=0.1, ls=":")
    axs[0].set_xlim([0, ds_test[ix_sub_id]["data"].shape[0]])

    axs[1].plot(np.arange(ds_test[ix_sub_id]["targets"]["initial_contact"].shape[0]), ds_test[ix_sub_id]["targets"]["initial_contact"][:,0], ls="-", c=(0, 0, 1), lw=2, alpha=0.2)
    axs[1].plot(np.arange(predictions[0][0].shape[0]), predictions[0][0][:,0], ls="-", lw=1, c=(0, 0, 1))
    axs[1].plot(ix_IC_pred, predictions[0][0][ix_IC_pred,0], ls="none", marker="o", mfc="none", mec=(0, 0, 1))
    axs[1].set_ylim([0.1, 1.1])
    axs[1].set_yticks([0, 1])
    axs[1].set_ylabel("Pr(IC)")
    axs[1].yaxis.set_minor_locator(plt.MultipleLocator(0.25))
    axs[1].grid(which="both", axis="both", c=(0, 0, 0), alpha=0.1, ls=":")

    axs[2].plot(np.arange(ds_test[ix_sub_id]["targets"]["final_contact"].shape[0]), ds_test[ix_sub_id]["targets"]["final_contact"][:,0], ls="-", c=(0, 0.65, 0), lw=2, alpha=0.2)
    axs[2].plot(np.arange(predictions[1][0].shape[0]), predictions[1][0][:,0], ls="-", lw=1, c=(0, 0.65, 0))
    axs[2].plot(ix_FC_pred, predictions[1][0][ix_FC_pred,0], ls="none", marker="o", mfc="none", mec=(0, 0.65, 0))
    axs[2].set_ylim([0.1, 1.1])
    axs[2].set_yticks([0, 1])
    axs[2].set_ylabel("Pr(FC)")
    axs[2].yaxis.set_minor_locator(plt.MultipleLocator(0.25))
    axs[2].grid(which="both", axis="both", c=(0, 0, 0), alpha=0.1, ls=":")
    axs[2].set_xlabel("sample")

    plt.tight_layout()
    plt.savefig(os.path.join("/home/robbin/Desktop/fig/20220222", ds_test[ix_sub_id]["filename_prefix"]+"_"+ds_test[ix_sub_id]["left_or_right"]+"_predictions.png"), dpi=300)
    plt.close()
    return

In [38]:
def add_metrics(ds_test, ix_sub_id, visualize=True):
    # Get data array
    X = ds_test[ix_sub_id]["data"]
    X = np.expand_dims(X, axis=0)
    
    # Predict
    predictions = model.predict(X)
    
    # Get true and predicted events
    ix_IC_true, _ = find_peaks(ds_test[ix_sub_id]["targets"]["initial_contact"][:,0], height=0.4, distance=100)
    ix_FC_true, _ = find_peaks(ds_test[ix_sub_id]["targets"]["final_contact"][:,0], height=0.4, distance=100)
    ix_IC_pred, _ = find_peaks(predictions[0][0][:,0], height=0.4, distance=100)
    ix_FC_pred, _ = find_peaks(predictions[1][0][:,0], height=0.4, distance=100)
    
    event_type, ix_ref, ix_pred = [], [], []
    # Initial contacts
    ann2pred, pred2ann, dt = compare_events(ix_IC_true, ix_IC_pred)
    for i in range(len(ix_IC_true)):
        event_type.append("IC")
        ix_ref.append(ix_IC_true[i])
        if ann2pred[i] > -999:
            # print(f"{ix_IC_true[i]:>6d} : {ix_IC_pred[ann2pred[i]]}")
            ix_pred.append(ix_IC_pred[ann2pred[i]])
        else:
            # print(f"{ix_IC_true[i]:>6d} : {'n/a':s}")
            ix_pred.append("n/a")
    for j in range(len(ix_IC_pred)):
        if pred2ann[j] == -999:
            event_type.append("IC")
            ix_ref.append("n/a")
            ix_pred.append(ix_IC_pred[j])
            # print(f"{'n/a':>6s} : {ix_IC_pred[j]}")
            
    # Final contacts
    ann2pred, pred2ann, dt = compare_events(ix_FC_true, ix_FC_pred)
    for i in range(len(ix_FC_true)):
        event_type.append("FC")
        ix_ref.append(ix_FC_true[i])
        if ann2pred[i] > -999:
            # print(f"{ix_FC_true[i]:>6d} : {ix_FC_pred[ann2pred[i]]}")
            ix_pred.append(ix_FC_pred[ann2pred[i]])
        else:
            # print(f"{ix_FC_true[i]:>6d} : {'n/a':s}")
            ix_pred.append("n/a")
    for j in range(len(ix_FC_pred)):
        if pred2ann[j] == -999:
            event_type.append("FC")
            ix_ref.append("n/a")
            ix_pred.append(ix_FC_pred[j])
            # print(f"{'n/a':>6s} : {ix_FC_pred[j]}")
    
    sub_id = [ds_test[ix_sub_id]["filename_prefix"][4:9] for _ in range(len(event_type))]
    filenames = [ds_test[ix_sub_id]["filename_prefix"] for _ in range(len(event_type))]
    side = [ds_test[ix_sub_id]["left_or_right"] for _ in range(len(event_type))]
    
    # Visualize
    if visualize:
        fancy_plot(ds_test=ds_test, ix_sub_id=ix_sub_id, predictions=predictions, ix_IC_pred=ix_IC_pred, ix_FC_pred=ix_FC_pred)
    
    return sub_id, filenames, side, event_type, ix_ref, ix_pred

In [51]:
sub_ids, filenames, sides, event_types, ix_refs, ix_preds = [], [], [], [], [], []
for i in range(98, len(ds_test)):
    print(f"{i:>3d}: {ds_test[i]['filename_prefix']+'_'+ds_test[i]['left_or_right']:s}")
    sub_id, filename, side, event_type, ix_ref, ix_pred = add_metrics(ds_test, ix_sub_id=i)
    sub_ids += sub_id
    filenames += filename
    sides += side
    event_types += event_type
    ix_refs += ix_ref
    ix_preds += ix_pred    

 98: sub-pp013_task-walkSlow_left
 99: sub-pp013_task-walkSlow_right
100: sub-pp045_task-walkFast_left
101: sub-pp045_task-walkFast_right
102: sub-pp045_task-walkPreferred_left
103: sub-pp045_task-walkPreferred_right
104: sub-pp045_task-walkSlow_left
105: sub-pp045_task-walkSlow_right
106: sub-pp031_task-walkFast_left
107: sub-pp031_task-walkFast_right
108: sub-pp031_task-walkPreferred_left
109: sub-pp031_task-walkPreferred_right
110: sub-pp031_task-walkSlow_left
111: sub-pp031_task-walkSlow_right
112: sub-pp050_task-walkFast_left
113: sub-pp050_task-walkFast_right
114: sub-pp050_task-walkPreferred_left
115: sub-pp050_task-walkPreferred_right
116: sub-pp050_task-walkSlow_left
117: sub-pp050_task-walkSlow_right
118: sub-pp044_task-walkFast_left
119: sub-pp044_task-walkFast_right
120: sub-pp044_task-walkPreferred_left
121: sub-pp044_task-walkPreferred_right
122: sub-pp044_task-walkSlow_left
123: sub-pp044_task-walkSlow_right
124: sub-pp123_task-walkFast_left
125: sub-pp123_task-walkFast_

ValueError: operands could not be broadcast together with shapes (3,) (4,) 

In [55]:
# Get data array
X = ds_test[192]["data"]
X = np.expand_dims(X, axis=0)

# Predict
predictions = model.predict(X)

# Get true and predicted events
ix_IC_true, _ = find_peaks(ds_test[192]["targets"]["initial_contact"][:,0], height=0.4, distance=100)
ix_FC_true, _ = find_peaks(ds_test[192]["targets"]["final_contact"][:,0], height=0.4, distance=100)
ix_IC_pred, _ = find_peaks(predictions[0][0][:,0], height=0.4, distance=100)
ix_FC_pred, _ = find_peaks(predictions[1][0][:,0], height=0.4, distance=100)

event_type, ix_ref, ix_pred = [], [], []
# Initial contacts
ann2pred, pred2ann, dt = compare_events(ix_IC_true, ix_IC_pred)
for i in range(len(ix_IC_true)):
    # event_type.append("IC")
    # ix_ref.append(ix_IC_true[i])
    if ann2pred[i] > -999:
        print(f"{ix_IC_true[i]:>6d} : {ix_IC_pred[ann2pred[i]]}")
        # ix_pred.append(ix_IC_pred[ann2pred[i]])
    else:
        print(f"{ix_IC_true[i]:>6d} : {'n/a':s}")
        # ix_pred.append("n/a")
for j in range(len(ix_IC_pred)):
    if pred2ann[j] == -999:
        # event_type.append("IC")
        # ix_ref.append("n/a")
        # ix_pred.append(ix_IC_pred[j])
        print(f"{'n/a':>6s} : {ix_IC_pred[j]}")
        
# Final contacts
ann2pred, pred2ann, dt = compare_events(ix_FC_true, ix_FC_pred)
for i in range(len(ix_FC_true)):
    # event_type.append("FC")
    # ix_ref.append(ix_FC_true[i])
    if ann2pred[i] > -999:
        print(f"{ix_FC_true[i]:>6d} : {ix_FC_pred[ann2pred[i]]}")
        # ix_pred.append(ix_FC_pred[ann2pred[i]])
    else:
        print(f"{ix_FC_true[i]:>6d} : {'n/a':s}")
        # ix_pred.append("n/a")
for j in range(len(ix_FC_pred)):
    if pred2ann[j] == -999:
        # event_type.append("FC")
        # ix_ref.append("n/a")
        # ix_pred.append(ix_FC_pred[j])
        print(f"{'n/a':>6s} : {ix_FC_pred[j]}")

   142 : 19
   394 : 273
   651 : 525
   907 : 781
  1175 : n/a
   n/a : 1034


ValueError: operands could not be broadcast together with shapes (3,) (4,) 

In [57]:
print(ix_FC_true)
print(ix_FC_pred)

[  56  311  560  818 1077]
[ 185  436  695  950 1220]


In [75]:
a2b = np.zeros_like(ix_FC_true)
print(f"a2b: {a2b}")
for i in range(len(ix_FC_true)):
    imin = np.argmin(np.abs(ix_FC_pred - ix_FC_true[i]))
    a2b[i] = imin
print(f"a2b: {a2b}")
a2b_unique = np.unique(a2b)
for i in range(len(a2b_unique)):
    indices = np.argwhere(a2b == a2b_unique[i])[:,0]
    if len(indices) > 1:
        a2b[np.setdiff1d(indices, b2a[a2b_unique[i]])] = -999
print(f"a2b: {a2b}")

b2a = np.zeros_like(ix_FC_true)
print(f"b2a: {b2a}")
for i in range(len(ix_FC_pred)):
    imin = np.argmin(np.abs(ix_FC_true - ix_FC_pred[i]))
    b2a[i] = imin
print(f"b2a: {b2a}")
b2a_unique = np.unique(b2a)
for i in range(len(b2a_unique)):
    indices = np.argwhere(b2a == b2a_unique[i])[:,0]
    if len(indices) > 1:
        b2a[np.setdiff1d(indices, a2b[b2a_unique[i]])] = -999
print(f"b2a: {b2a}")

# indices_a2b = np.argwhere(a2b > -999)[:,0]
# ii = np.argwhere(b2a[a2b[indices_a2b]]==-999)[:,0]
# a2b[indices_a2b[ii]] = -999
# print(f"a2b: {a2b}")

# indices_b2a = np.argwhere(b2a > -999)[:,0]
# ii = np.argwhere(a2b[b2a[indices_b2a]]==-999)[:,0]
# b2a[indices_b2a[ii]] = -999
# print(f"b2a: {b2a}")

time_diff = ix_IC_true[a2b > -999] - ix_IC_pred[b2a > -999]
print(f"time: {time_diff}")

a2b: [0 0 0 0 0]
a2b: [0 1 1 2 3]
a2b: [   0 -999    1    2    3]
b2a: [0 0 0 0 0]
b2a: [1 2 3 4 4]
b2a: [   1    2    3    4 -999]
time: [123 378 382 394]


In [76]:
df_out = pd.DataFrame({
    "sub": sub_ids,
    "filename": filenames,
    "side": sides,
    "event_type": event_types,
    "ix_ref": ix_refs,
    "ix_pred": ix_preds
})
df_out.head()

Unnamed: 0,sub,filename,side,event_type,ix_ref,ix_pred
0,pp013,sub-pp013_task-walkSlow,left,IC,108,114
1,pp013,sub-pp013_task-walkSlow,left,IC,366,369
2,pp013,sub-pp013_task-walkSlow,left,IC,632,636
3,pp013,sub-pp013_task-walkSlow,left,IC,894,898
4,pp013,sub-pp013_task-walkSlow,left,IC,1150,1155


In [77]:
df_out.to_csv("gait_event_detection.tsv", sep="\t")