In [58]:
# Работа с табличными данными
import pandas as pd
import numpy as np

# Пайплайн
from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator, TransformerMixin

# Преобразование признаков
from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler

# Модели
from sklearn.linear_model import LogisticRegression

# Валидация
from sklearn.model_selection import cross_val_predict, cross_val_score
from sklearn.metrics import f1_score, accuracy_score, classification_report
from sklearn.model_selection import train_test_split

# Визуализация
import plotly.express as px
import plotly.io as pio
pio.templates.default = 'plotly_dark'

from collections import deque
from pathlib import Path
from typing import List, Tuple

from motorica.utils import *

## Базовое решение с использованием *Logistic Regression*

In [59]:
# CSV file with per-montage meta information (rendered as the table below).
METAINFO_PATH = 'marked/selected_montages.csv'
# Bare last expression so the notebook displays the returned table.
# NOTE(review): read_meta_info is not defined in this notebook; presumably it
# comes from the wildcard import of motorica.utils — confirm.
read_meta_info(METAINFO_PATH)

Unnamed: 0_level_0,pilote_id,last_train_idx,len(train),len(test),ts_delta,ticks_per_gest,n_gestures,ACC,GYR,hi_val_sensors,mark_sensors
montage,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
2023-05-15_16-16-08.palm,1,23337,23337,5810,33.0,46.0,271.0,True,True,"[3, 4, 6, 12, 13, 16, 17, 21, 22, 27, 28, 30, ...","[3, 4, 6, 12, 13, 16, 17, 21, 22, 27, 28, 30, ..."
2023-05-15_17-12-24.palm,1,23336,23336,5803,33.0,46.0,271.0,True,True,"[3, 4, 6, 12, 13, 16, 17, 21, 22, 27, 28, 30, ...","[3, 4, 6, 12, 13, 16, 17, 21, 22, 27, 28, 30, ..."
2023-06-05_16-12-38.palm,1,17939,17939,4431,33.0,30.0,361.0,True,True,"[3, 4, 5, 6, 12, 13, 16, 17, 21, 22, 27, 28, 3...","[3, 4, 5, 6, 12, 13, 16, 17, 21, 22, 27, 28, 3..."
2023-06-05_17-53-01.palm,1,17771,17771,4435,33.0,31.0,361.0,True,True,"[3, 4, 5, 6, 12, 13, 16, 17, 21, 22, 27, 28, 3...","[3, 4, 5, 6, 12, 13, 16, 17, 21, 22, 27, 28, 3..."
2023-06-20_14-43-11.palm,1,17936,17936,4441,33.0,31.0,361.0,True,True,"[3, 4, 5, 6, 12, 13, 16, 17, 21, 27, 28, 30, 3...","[3, 5, 6, 12, 13, 16, 17, 21, 27, 28, 30, 31, ..."
2023-06-20_13-30-15.palm,1,17928,17928,4435,33.0,31.0,361.0,True,True,"[3, 4, 5, 6, 12, 13, 16, 17, 21, 22, 27, 28, 3...","[3, 4, 5, 6, 12, 13, 16, 17, 21, 22, 27, 28, 3..."
2023-06-20_12-34-17.palm,1,17758,17758,4444,33.0,31.0,361.0,True,True,"[3, 4, 5, 6, 12, 13, 16, 17, 21, 22, 27, 28, 3...","[3, 4, 5, 6, 12, 13, 16, 17, 21, 22, 27, 28, 3..."
2023-09-30_08-06-44.palm,2,5693,5693,5509,33.0,31.0,181.0,True,True,"[7, 9, 10, 18, 20, 23, 26, 28, 31, 34, 37, 39]","[7, 9, 10, 18, 20, 23, 26, 28, 31, 34, 37, 39]"
2023-09-29_11-03-50.palm,2,5694,5694,5511,33.0,31.0,181.0,True,True,"[7, 9, 10, 18, 20, 23, 26, 28, 31, 34, 37, 39]","[7, 9, 10, 18, 20, 23, 26, 28, 34, 37, 39]"
2023-09-29_09-20-47.palm,2,5690,5690,5507,33.0,31.0,181.0,True,True,"[7, 9, 10, 18, 20, 23, 26, 28, 31, 34, 37, 39]","[7, 9, 10, 18, 20, 23, 26, 28, 31, 34, 37, 39]"


In [60]:
def read_train_and_test(
        montage: str,
        features: List[str],
        target_col: str = 'act_label',
        data_dir: str = 'marked'
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    """Load the train and test CSV files of one montage and split them into X / y.

    The files are read from ``<data_dir>/<montage>.train`` and
    ``<data_dir>/<montage>.test``; the first CSV column becomes the index.

    Parameters
    ----------
    montage : str
        Montage name, i.e. the file name without the ``.train`` / ``.test``
        suffix.
    features : list
        Columns to keep as model inputs, in the order they should appear.
    target_col : str, default 'act_label'
        Column holding the target labels.
    data_dir : str or path-like, default 'marked'
        Directory that contains the montage files.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)`` — the two feature frames
        followed by the two target series.

    Raises
    ------
    FileNotFoundError
        If either split file does not exist.
    KeyError
        If ``target_col`` or one of ``features`` is missing from a file, or if
        ``features`` contains ``target_col``.
    """
    base_dir = Path(data_dir)

    def _load_split(suffix):
        """Read one split and separate the feature columns from the target."""
        frame = pd.read_csv(base_dir / f"{montage}.{suffix}", index_col=0)
        # Dropping the target before selecting makes the selection fail with a
        # KeyError if `features` accidentally lists the target column.
        inputs = frame.drop(columns=target_col)[features]
        return inputs, frame[target_col]

    X_train, y_train = _load_split("train")
    X_test, y_test = _load_split("test")
    return X_train, X_test, y_train, y_test

In [61]:
# Montage to model: used as the row key of the meta-info table and as the
# file-name stem of the train/test CSV files.
montage = "2023-05-22_17-04-29.palm"
montage_info = read_meta_info(METAINFO_PATH).loc[montage]

print(montage)
display(montage_info)

# Model inputs: the montage's 'hi_val_sensors' columns plus `cols_gyr`.
# NOTE(review): cols_gyr is not defined in this notebook; presumably the
# gyroscope column names exported by motorica.utils — confirm.
features = montage_info['hi_val_sensors'] + cols_gyr

X_train, X_test, y_train, y_test = read_train_and_test(montage, features)

# Fit the scaler on the training split only and reuse it for the test split,
# so no test-set statistics leak into the preprocessing.
scaler = MinMaxScaler()
# The scaler returns plain ndarrays; pass the original index back explicitly so
# the scaled frames stay label-aligned with y_train / y_test (without `index=`
# they would silently get a fresh 0..n-1 RangeIndex).
X_train_scaled = pd.DataFrame(
    scaler.fit_transform(X_train),
    columns=X_train.columns,
    index=X_train.index
)
X_test_scaled = pd.DataFrame(
    scaler.transform(X_test),
    columns=X_test.columns,
    index=X_test.index
)

2023-05-22_17-04-29.palm


pilote_id                                                         2
last_train_idx                                                23289
len(train)                                                    23289
len(test)                                                      5796
ts_delta                                                       33.0
ticks_per_gest                                                 46.0
n_gestures                                                    271.0
ACC                                                            True
GYR                                                            True
hi_val_sensors    [5, 7, 9, 10, 15, 18, 20, 23, 26, 28, 34, 37, 39]
mark_sensors      [5, 7, 9, 10, 15, 18, 20, 23, 26, 28, 34, 37, 39]
Name: 2023-05-22_17-04-29.palm, dtype: object

In [62]:
# Baseline classifier. C is the inverse regularisation strength, so C=500 means
# weak regularisation; max_iter is raised to give the solver room to converge.
lr = LogisticRegression(C=500, max_iter=5000)
lr.fit(X_train_scaled, y_train)

y_pred = lr.predict(X_test_scaled)

# Probability the model assigns to the class it predicted, one value per row.
# The columns of predict_proba follow lr.classes_ (sorted unique labels), not
# the raw label values, so translate each predicted label to its column
# position instead of indexing with the label itself — the latter is only
# correct when the labels happen to be exactly 0..K-1.
proba = lr.predict_proba(X_test_scaled)
class_positions = np.searchsorted(lr.classes_, y_pred)
y_proba = proba[np.arange(len(y_pred)), class_positions]

In [63]:
# Plot frame: the raw (unscaled) test features plus the true labels, the
# predicted labels and the predicted-class probability, each multiplied by 100.
# NOTE(review): the factor 100 presumably lifts these series to the value range
# of the sensor signals so they are visible on the shared axis — confirm.
fig_data = X_test.assign(
    true=y_test * 100,
    pred=y_pred * 100,
    proba=y_proba * 100,
)

# Wide-form line chart: every column of fig_data becomes its own trace.
fig = px.line(fig_data, title=montage, width=1000, height=700)
fig.update_traces(line={'width': 1})
fig.show()

In [64]:
# Per-class precision / recall / F1 and support for the predictions on the test split.
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       0.93      0.95      0.94      4141
           1       0.96      0.89      0.92       352
           2       0.89      0.84      0.86       319
           3       0.96      0.87      0.91       347
           4       0.66      0.97      0.79       285
           5       0.95      0.51      0.67       349

    accuracy                           0.91      5793
   macro avg       0.89      0.84      0.85      5793
weighted avg       0.92      0.91      0.91      5793

