In [1]:
import torch
import numpy as np
import pandas as pd
import os
import sys

In [2]:
from sklearn.metrics import classification_report, f1_score, precision_score, recall_score
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
from torch import optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import shap

## Read Data

In [3]:
data = pd.read_csv('../data/gesture-recognition-and-biometrics-electromyogram-grabmyo-1.0.2/features_v2.csv')

In [49]:
def simple_model(input_shape, output_shape):
    model = nn.Sequential(
        nn.Linear(input_shape, 64),  # Input layer: Fully connected (linear) with 64 units
        nn.ReLU(),  # Activation function: ReLU,
        nn.Dropout(0.3),
        nn.Linear(64, 128),
        nn.ReLU(),  # Activation function: ReLU
        nn.Dropout(0.3),
        nn.Linear(128, 128),
        nn.ReLU(),  # Activation function: ReLU
        nn.Dropout(0.3),
        nn.Linear(128, output_shape)  # Output layer: Fully connected (linear) with 'output_shape' units
    )
    
    return model

In [50]:
data.head()

Unnamed: 0,session,trial,filename,crest_factor_F1,dasd_F1,diffvar_F1,form_factor_F1,iemg_F1,kurtosis_F1,kurtosis_f_F1,...,rms_W9,skew_W9,skew_f_W9,ssi_W9,sum_f_W9,var_f_W9,wflen_W9,willison_W9,gesture,participant
0,1,1,E:\DS5500-project\data\gesture-recognition-and...,7.091099,0.021834,0.000606,-6502.754507,365.456658,2.574033,85.547962,...,0.022588,-0.148877,5.529415,5.224663,0.109574,2.901826e-10,69.944068,15.0,10,1
1,1,2,E:\DS5500-project\data\gesture-recognition-and...,13.324514,0.023852,0.000689,1352.897349,379.664639,16.942218,329.11487,...,0.016172,-0.205427,9.332349,2.67794,0.077566,2.769772e-10,54.530919,23.0,10,1
2,1,3,E:\DS5500-project\data\gesture-recognition-and...,8.047954,0.016213,0.000326,1295.052394,296.823135,3.795827,134.751765,...,0.017039,-0.204604,9.07843,2.973002,0.064458,1.426359e-10,57.815909,13.0,10,1
3,1,4,E:\DS5500-project\data\gesture-recognition-and...,8.459555,0.026572,0.000864,14453.068135,479.688214,6.938354,98.83279,...,0.031982,-0.266818,5.067522,10.473811,0.131909,4.509337e-10,83.441138,0.0,10,1
4,1,5,E:\DS5500-project\data\gesture-recognition-and...,8.329011,0.026896,0.000901,11827.310277,464.350257,2.379207,116.412205,...,0.031134,-0.092205,5.414199,9.925787,0.177516,9.5771e-10,94.393506,3.0,10,1


In [51]:
feature_cols = [c for c in data.columns if "_" in c]

## split on Participants

In [52]:
rand_parts =  np.random.choice(data.participant.unique(), 5, replace = False)

In [53]:
rand_parts

array([ 1, 18, 24,  9, 20], dtype=int64)

In [54]:
train_df = data[~data['participant'].isin(rand_parts)]
test_df = data[data['participant'].isin(rand_parts)]

In [55]:
train_df

Unnamed: 0,session,trial,filename,crest_factor_F1,dasd_F1,diffvar_F1,form_factor_F1,iemg_F1,kurtosis_F1,kurtosis_f_F1,...,rms_W9,skew_W9,skew_f_W9,ssi_W9,sum_f_W9,var_f_W9,wflen_W9,willison_W9,gesture,participant
119,1,1,E:\DS5500-project\data\gesture-recognition-and...,7.625899,0.009812,0.000123,766.381922,166.919123,2.725185,346.343487,...,0.019608,-0.358002,5.409344,3.937070,0.032382,1.737857e-11,65.788911,5.0,10,10
120,1,2,E:\DS5500-project\data\gesture-recognition-and...,6.547451,0.026094,0.000890,-329.780444,430.964512,1.810296,129.466103,...,0.034974,-0.362530,2.995080,12.525026,0.108438,1.443714e-10,105.478357,1.0,10,10
121,1,3,E:\DS5500-project\data\gesture-recognition-and...,6.708311,0.018371,0.000432,-2879.507872,314.306619,1.676854,57.008656,...,0.030377,-0.654511,5.118171,9.449012,0.083889,1.418762e-10,89.575479,21.0,10,10
122,1,4,E:\DS5500-project\data\gesture-recognition-and...,6.726042,0.030643,0.001212,4814.696544,492.184779,1.744178,84.161226,...,0.053153,-0.696739,3.650792,28.930671,0.204543,8.759613e-10,148.160301,12.0,10,10
123,1,5,E:\DS5500-project\data\gesture-recognition-and...,4.634019,0.030248,0.001168,-7947.193479,534.330959,1.286783,38.664931,...,0.047429,-0.484191,3.522640,23.035254,0.198556,7.833930e-10,127.081209,5.0,10,10
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15227,3,3,E:\DS5500-project\data\gesture-recognition-and...,5.325903,0.009598,0.000111,-5043.147127,203.031442,1.156750,50.794422,...,0.056875,0.140382,5.270861,33.124231,0.053407,5.943447e-11,114.449822,0.0,9,8
15228,3,4,E:\DS5500-project\data\gesture-recognition-and...,5.924260,0.014301,0.000243,-781.397471,327.500843,1.051674,30.325437,...,0.077714,-0.027485,3.734465,61.843595,0.106743,1.824836e-10,148.791420,0.0,9,8
15229,3,5,E:\DS5500-project\data\gesture-recognition-and...,4.312347,0.012119,0.000179,-823.176614,254.643543,0.914980,30.567064,...,0.053871,0.053229,3.034506,29.716898,0.055936,5.490889e-11,113.669801,0.0,9,8
15230,3,6,E:\DS5500-project\data\gesture-recognition-and...,7.250760,0.009071,0.000098,-1693.957900,200.158878,1.707411,29.401532,...,0.053679,-0.122013,3.496624,29.505784,0.042985,3.658016e-11,101.096876,0.0,9,8


In [56]:
test_df

Unnamed: 0,session,trial,filename,crest_factor_F1,dasd_F1,diffvar_F1,form_factor_F1,iemg_F1,kurtosis_F1,kurtosis_f_F1,...,rms_W9,skew_W9,skew_f_W9,ssi_W9,sum_f_W9,var_f_W9,wflen_W9,willison_W9,gesture,participant
0,1,1,E:\DS5500-project\data\gesture-recognition-and...,7.091099,0.021834,0.000606,-6502.754507,365.456658,2.574033,85.547962,...,0.022588,-0.148877,5.529415,5.224663,0.109574,2.901826e-10,69.944068,15.0,10,1
1,1,2,E:\DS5500-project\data\gesture-recognition-and...,13.324514,0.023852,0.000689,1352.897349,379.664639,16.942218,329.114870,...,0.016172,-0.205427,9.332349,2.677940,0.077566,2.769772e-10,54.530919,23.0,10,1
2,1,3,E:\DS5500-project\data\gesture-recognition-and...,8.047954,0.016213,0.000326,1295.052394,296.823135,3.795827,134.751765,...,0.017039,-0.204604,9.078430,2.973002,0.064458,1.426359e-10,57.815909,13.0,10,1
3,1,4,E:\DS5500-project\data\gesture-recognition-and...,8.459555,0.026572,0.000864,14453.068135,479.688214,6.938354,98.832790,...,0.031982,-0.266818,5.067522,10.473811,0.131909,4.509337e-10,83.441138,0.0,10,1
4,1,5,E:\DS5500-project\data\gesture-recognition-and...,8.329011,0.026896,0.000901,11827.310277,464.350257,2.379207,116.412205,...,0.031134,-0.092205,5.414199,9.925787,0.177516,9.577100e-10,94.393506,3.0,10,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15346,3,3,E:\DS5500-project\data\gesture-recognition-and...,7.771453,0.013208,0.000208,2599.892725,274.024844,2.176596,28.888640,...,0.020526,-0.293608,3.830298,4.314177,0.035283,1.966555e-11,40.970398,2.0,9,9
15347,3,4,E:\DS5500-project\data\gesture-recognition-and...,9.172642,0.019135,0.000429,1255.242161,385.695263,4.615531,133.237150,...,0.025825,-0.389927,4.200155,6.829604,0.076162,1.003953e-10,52.342457,0.0,9,9
15348,3,5,E:\DS5500-project\data\gesture-recognition-and...,5.923115,0.018533,0.000408,1713.228290,389.039490,1.912226,22.613382,...,0.027460,-0.321020,4.375813,7.721463,0.081006,1.125954e-10,59.843660,0.0,9,9
15349,3,6,E:\DS5500-project\data\gesture-recognition-and...,6.466362,0.015749,0.000292,-992.810528,319.503197,2.189195,21.864578,...,0.013359,-0.493375,3.832319,1.827333,0.062739,6.332731e-11,35.905471,0.0,9,9


In [126]:
X = train_df.loc[:, feature_cols].values
Y = (train_df.loc[:, 'gesture']-1).values

In [127]:
scaler = StandardScaler()
X = scaler.fit_transform(X)

In [128]:
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=56)

In [129]:
train_dataset = TensorDataset(torch.tensor(x_train).type(torch.float32), torch.tensor(y_train).type(torch.LongTensor))
test_dataset = TensorDataset(torch.tensor(x_test).type(torch.float32), torch.tensor(y_test).type(torch.LongTensor))

In [130]:
# Define batch size and whether to shuffle the data
batch_size = 256
shuffle = True

# Create data loaders for training and testing
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [131]:
model = simple_model(x_train.shape[1], len(np.unique(y_train)))

In [132]:
criterion = nn.CrossEntropyLoss()  
optimizer = optim.Adam(model.parameters(), lr=5e-4)

In [133]:
num_epochs = 40

In [134]:
# Training loop
for epoch in range(num_epochs):
    # Training phase
    model.train()
    total_train_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # Validation phase
    model.eval()
    total_val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Calculate and log metrics
    train_loss = total_train_loss / len(train_loader)
    val_loss = total_val_loss / len(test_loader)
    val_accuracy = 100 * correct / total

    if (epoch+1)%20==0:
        print(f"Epoch [{epoch + 1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

Epoch [20/40] - Train Loss: 0.6222, Val Loss: 0.5813, Val Accuracy: 84.19%
Epoch [40/40] - Train Loss: 0.4422, Val Loss: 0.6062, Val Accuracy: 87.03%


In [135]:
def predict(model, test_loader):
    model.eval()  # Set the model to evaluation mode
    y_pred = []

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader):
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            y_pred.extend(predicted.tolist())

    return y_pred

In [136]:
tune_participant = rand_parts[2]

In [137]:
tune_data = test_df[test_df['participant'] == tune_participant]

In [138]:
X_tune = tune_data.loc[:, feature_cols].values
Y_tune = (tune_data.loc[:, 'gesture']-1).values

In [139]:
X_tune = scaler.transform(X_tune)

In [140]:
test_tune_ds = TensorDataset(torch.tensor(X_tune).type(torch.float32), torch.tensor(Y_tune).type(torch.LongTensor))

In [141]:
test_tune_loader = DataLoader(test_tune_ds, batch_size=batch_size, shuffle=False)

In [142]:
x_tune_pred = predict(model, test_tune_loader)

In [143]:
print(classification_report(Y_tune, x_tune_pred))

              precision    recall  f1-score   support

           0       0.39      0.76      0.52        21
           1       1.00      0.14      0.25        21
           2       0.75      0.57      0.65        21
           3       0.00      0.00      0.00        21
           4       0.83      0.95      0.89        21
           5       1.00      0.57      0.73        21
           6       0.42      0.81      0.56        21
           7       1.00      0.48      0.65        21
           8       0.71      0.81      0.76        21
           9       0.44      0.81      0.57        21
          10       1.00      0.67      0.80        21
          11       0.84      1.00      0.91        21
          12       0.90      0.90      0.90        21
          13       0.86      0.86      0.86        21
          14       0.74      0.67      0.70        21
          15       0.70      0.76      0.73        21
          16       1.00      1.00      1.00        21

    accuracy              

In [144]:
tuning_df = test_df[(test_df['participant'] == tune_participant)&(test_df['trial'].isin([1,2,]))]
tuning_test_df = test_df[(test_df['participant'] == tune_participant)&(~test_df['trial'].isin([1,2]))]

In [145]:
tuning_df

Unnamed: 0,session,trial,filename,crest_factor_F1,dasd_F1,diffvar_F1,form_factor_F1,iemg_F1,kurtosis_F1,kurtosis_f_F1,...,rms_W9,skew_W9,skew_f_W9,ssi_W9,sum_f_W9,var_f_W9,wflen_W9,willison_W9,gesture,participant
1904,1,1,E:\DS5500-project\data\gesture-recognition-and...,5.657295,0.019440,0.000476,615.222613,373.519028,0.937125,23.409817,...,0.085353,-0.395261,8.002219,74.599399,0.207871,1.242995e-09,187.894726,1.0,10,24
1905,1,2,E:\DS5500-project\data\gesture-recognition-and...,5.879687,0.019960,0.000490,-1619.585715,386.153115,1.760145,19.179348,...,0.066937,-0.262174,4.673090,45.880620,0.385234,4.712145e-09,176.605563,10.0,10,24
1911,1,1,E:\DS5500-project\data\gesture-recognition-and...,6.183032,0.101422,0.011329,-7198.532527,2802.269411,3.252001,37.567211,...,0.068405,0.343500,5.390635,47.914872,1.088297,3.172551e-08,216.011435,24.0,11,24
1912,1,2,E:\DS5500-project\data\gesture-recognition-and...,8.924115,0.062760,0.004445,759.451426,1506.058946,6.750759,118.905049,...,0.064258,-0.228396,6.243593,42.282452,0.462765,6.222937e-09,206.824507,41.0,11,24
1918,1,1,E:\DS5500-project\data\gesture-recognition-and...,9.740596,0.071788,0.006638,2650.627232,1015.645945,11.913725,78.960091,...,0.031930,-0.469716,10.539341,10.439724,0.187835,2.078825e-09,95.044154,29.0,12,24
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12237,3,2,E:\DS5500-project\data\gesture-recognition-and...,5.143129,0.017524,0.000360,1418.890065,431.360638,0.917474,55.212639,...,0.047056,-1.212110,4.080170,22.673867,0.066180,8.173359e-11,103.554640,1.0,7,24
12243,3,1,E:\DS5500-project\data\gesture-recognition-and...,5.368640,0.020457,0.000490,8018.334128,509.312348,1.026564,33.345746,...,0.040888,-0.782017,5.758825,17.119608,0.170947,6.421941e-10,96.977743,0.0,8,24
12244,3,2,E:\DS5500-project\data\gesture-recognition-and...,5.060406,0.020450,0.000504,8454.555607,462.519508,0.829732,25.096814,...,0.021840,-0.643022,4.426234,4.884216,0.103046,2.294377e-10,53.436072,3.0,8,24
12250,3,1,E:\DS5500-project\data\gesture-recognition-and...,5.302365,0.010217,0.000131,-4565.559388,191.734896,0.804880,40.881267,...,0.094060,-0.665880,4.270417,90.596547,0.069347,1.071566e-10,202.318488,2.0,9,24


In [146]:
tuning_test_df

Unnamed: 0,session,trial,filename,crest_factor_F1,dasd_F1,diffvar_F1,form_factor_F1,iemg_F1,kurtosis_F1,kurtosis_f_F1,...,rms_W9,skew_W9,skew_f_W9,ssi_W9,sum_f_W9,var_f_W9,wflen_W9,willison_W9,gesture,participant
1906,1,3,E:\DS5500-project\data\gesture-recognition-and...,4.836888,0.018603,0.000454,2761.294938,316.852032,1.005308,27.976724,...,0.061223,-0.403719,4.663788,38.381958,0.156943,5.416719e-10,161.054083,2.0,10,24
1907,1,4,E:\DS5500-project\data\gesture-recognition-and...,6.222052,0.014499,0.000274,5690.100624,237.152544,1.210195,29.938002,...,0.055122,-0.349981,4.769289,31.113667,0.102360,2.363456e-10,142.584850,0.0,10,24
1908,1,5,E:\DS5500-project\data\gesture-recognition-and...,4.985339,0.017247,0.000368,-1442.137054,321.836498,0.993494,67.222826,...,0.067645,-0.459568,11.115638,46.856641,0.265597,3.925881e-09,178.274663,0.0,10,24
1909,1,6,E:\DS5500-project\data\gesture-recognition-and...,5.463508,0.017228,0.000383,-1354.504985,303.957350,1.258326,40.022537,...,0.049097,-0.277050,5.800400,24.683513,0.245782,2.210135e-09,147.813925,10.0,10,24
1910,1,7,E:\DS5500-project\data\gesture-recognition-and...,5.673289,0.014815,0.000274,-1387.847457,291.234275,1.203975,20.102644,...,0.083890,-0.342049,3.426597,72.064699,0.141357,3.466063e-10,182.575130,1.0,10,24
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12252,3,3,E:\DS5500-project\data\gesture-recognition-and...,5.459043,0.013609,0.000242,-2754.422405,221.864659,0.980693,34.982506,...,0.080350,-0.595216,4.555141,66.109891,0.066243,8.855938e-11,169.299838,0.0,9,24
12253,3,4,E:\DS5500-project\data\gesture-recognition-and...,5.334749,0.013376,0.000231,4069.804780,213.921800,1.181005,61.382997,...,0.083872,-0.579908,5.416733,72.033427,0.067354,9.520471e-11,191.659081,0.0,9,24
12254,3,5,E:\DS5500-project\data\gesture-recognition-and...,6.036265,0.016688,0.000343,-676.047878,291.301194,2.164869,38.134939,...,0.078554,-0.755646,3.660117,63.187730,0.074643,1.024763e-10,176.605152,1.0,9,24
12255,3,6,E:\DS5500-project\data\gesture-recognition-and...,5.293482,0.010024,0.000128,2134.327086,180.379298,0.878106,59.125723,...,0.093419,-0.695921,4.309182,89.365893,0.064123,8.778631e-11,202.945848,0.0,9,24


In [147]:
X_tune = tuning_df.loc[:, feature_cols].values
Y_tune = (tuning_df.loc[:, 'gesture']-1).values

X_tune_test = tuning_test_df.loc[:, feature_cols].values
Y_tune_test = (tuning_test_df.loc[:, 'gesture']-1).values

In [148]:
X_tune, X_tune_test = scaler.transform(X_tune), scaler.transform(X_tune_test)

In [149]:
tune_ds = TensorDataset(torch.tensor(X_tune).type(torch.float32), torch.tensor(Y_tune).type(torch.LongTensor))
test_tune_ds = TensorDataset(torch.tensor(X_tune_test).type(torch.float32), torch.tensor(Y_tune_test).type(torch.LongTensor))

tune_loader = DataLoader(tune_ds, batch_size=batch_size, shuffle=False)
test_tune_loader = DataLoader(test_tune_ds, batch_size=batch_size, shuffle=False)

In [150]:
num_tune_epochs = 20

In [151]:
#tuning
# Training loop
for epoch in range(num_tune_epochs):
    # Training phase
    model.train()
    total_train_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(tune_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # Validation phase
    model.eval()
    total_val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_tune_loader):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Calculate and log metrics
    train_loss = total_train_loss / len(train_loader)
    val_loss = total_val_loss / len(test_loader)
    val_accuracy = 100 * correct / total

    if (epoch+1)%20==0:
        print(f"Epoch [{epoch + 1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

Epoch [20/40] - Train Loss: 0.0077, Val Loss: 0.0341, Val Accuracy: 83.53%


In [152]:
pred_tune = predict(model, test_tune_loader)

In [153]:
print(classification_report(Y_tune_test, pred_tune))

              precision    recall  f1-score   support

           0       0.75      0.40      0.52        15
           1       0.75      0.60      0.67        15
           2       1.00      0.67      0.80        15
           3       0.74      0.93      0.82        15
           4       0.88      1.00      0.94        15
           5       0.81      0.87      0.84        15
           6       0.52      0.80      0.63        15
           7       0.93      0.93      0.93        15
           8       0.88      1.00      0.94        15
           9       0.77      0.67      0.71        15
          10       1.00      1.00      1.00        15
          11       0.94      1.00      0.97        15
          12       0.94      1.00      0.97        15
          13       1.00      0.73      0.85        15
          14       0.87      0.87      0.87        15
          15       0.65      0.73      0.69        15
          16       1.00      1.00      1.00        15

    accuracy              