In [1]:
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Switch matplotlib to its 'ggplot' style sheet; this changes the global
# plotting defaults, so it affects every figure drawn after this cell runs.
plt.style.use('ggplot')

%matplotlib inline

In [2]:
def save_model(classifier, directory, model_type, hz):
    '''
        Exports a model to Python source code and saves it to disk.

        The classifier is converted with m2cgen's `export_to_python` and the
        generated code is written to
        `models/<directory>/<model_type>/<model_type>_<hz>hz.py`,
        relative to the current working directory. Missing folders are
        created; an existing file at that path is overwritten.

        Args:
            classifier: model object handed to `m2cgen.export_to_python`
                (presumably an already fitted estimator - confirm with caller).
            directory: name of the sub-folder below `models/`.
            model_type: used as folder name and as file-name prefix.
            hz: value embedded in the file name
                (presumably the sampling frequency - confirm with caller).

        Returns:
            The relative path of the written file, as a string.
    '''
    import os
    import m2cgen as m2c

    code = m2c.export_to_python(classifier)
    path = f'models/{directory}/{model_type}/{model_type}_{hz}hz.py'

    # Create the target folder on demand so the first save for a new
    # directory / model type does not fail with FileNotFoundError.
    os.makedirs(os.path.dirname(path), exist_ok=True)

    with open(path, 'w') as f:
        # `code` is one string, so write() is the right call; writelines()
        # would iterate over it character by character.
        f.write(code)

    return path

In [None]:
def transform_data_for_inference(df):
    '''
        Transforms a dataset into a single wide row for inference.

        Each row of `df` is turned into a one-row frame whose columns are
        renamed to `<column>_<index label as int>`; these frames are then
        placed side by side, e.g. for columns acc_x, gyro_x and index
        labels 0, 10, ... n:
        acc_x_0, gyro_x_0, acc_x_10, gyro_x_10, .... acc_x_n, gyro_x_n

        The result has one row labelled 0 (assuming the index labels of
        `df` are unique).
    '''

    def _as_suffixed_row(timestamp):
        # Series for one index label -> one-row frame with the original columns.
        single_row = pd.DataFrame(df.loc[timestamp]).T
        suffix = f'_{int(timestamp)}'
        # Reset the row label to 0 so all pieces align when concatenated.
        return single_row.add_suffix(suffix).reset_index(drop=True)

    wide_parts = [_as_suffixed_row(timestamp) for timestamp in df.index]

    return pd.concat(wide_parts, axis=1)


def get_filter_string(start, step, window=1000):
    '''
        Creates a string to filter the dataset for defined timestamps.
        To be used with df.filter(regex='<string returned from this function>')

        The timestamps are start, start+step, ... up to and including
        start+window (when it falls on the step grid), joined with `|`.

        Args:
            start: first timestamp to keep.
            step: distance between kept timestamps; truncated with int().
            window: span covered after `start`, in the same unit as the
                timestamps (presumably milliseconds - confirm with caller).
                Defaults to 1000, the value that used to be hard-coded.

        Returns:
            A regex alternation string.
            Example: get_filter_string(0, 50, window=100) -> '0|50|100'
    '''

    # +1 makes the end point inclusive, because np.arange excludes `stop`.
    keep = np.arange(start, start + 1 + window, step=int(step))
    return '|'.join(map(str, keep.astype(int)))


def line_color(inf_result):
    '''
        Maps an inference result to the color used for it.
        1 -> blue, 2 -> red, 3 -> green; any other value raises KeyError.
    '''
    palette = dict(zip((1, 2, 3), ('blue', 'red', 'green')))
    return palette[inf_result]


def downsample_df(df, freq):
    '''
        Downsamples the dataset to the given frequency.

        Keeps the rows whose index labels are 0, period, 2*period, ...
        where period = int(1000 / freq), stopping before the last index
        label of `df`. All of those labels must exist in the index,
        because they are looked up with `df.loc`.
    '''

    # Distance between kept rows, truncated to a whole number.
    step_between_rows = int(1000 / freq)

    final_label = df.index[-1]

    # NOTE(review): np.arange excludes its stop value, so the row at the last
    # index label is never kept, even if it lies on the grid - confirm intended.
    wanted_labels = np.arange(final_label, step=step_between_rows)

    return df.loc[wanted_labels]


def run_inference(df, model, start, step):
    '''
        Runs inference on one window of the transformed dataset.

        Selects the columns whose names end in `_<timestamp>` for the
        timestamps produced by get_filter_string(start, step), takes the
        row labelled 0 as a flat list and returns `model.score` of it.
    '''
    timestamps = get_filter_string(start=start, step=step)
    # Anchor on the trailing `_<timestamp>` so e.g. `_50` does not match `_150`.
    column_pattern = f'_({timestamps})$'

    window_columns = df.filter(regex=column_pattern)
    features = list(window_columns.loc[0])

    return model.score(features)

def calculate_error(res, move_type):
    '''
        Calculates the inference error rate in validation data.

        Results equal to 0 (`no movement`) are ignored. Of the remaining
        results, those carrying one of the two labels configured as wrong
        for `move_type` are counted as errors.

        Args:
            res: object whose `res['result']` is a pandas Series of
                inference results.
            move_type: one of 'circle', 'x', 'y'; anything else raises KeyError.

        Returns:
            Percentage (0-100) of non-zero results that are wrong.
            NaN when there is no non-zero result at all.
    '''

    # Labels that count as a wrong prediction for each movement type.
    wrong_labels_by_move = {
        'circle': (1, 2),
        'x': (2, 3),
        'y': (1, 3),
    }
    wrong_labels = wrong_labels_by_move[move_type]

    # Dropping `no movement`; errors='ignore' keeps this from raising
    # KeyError when no result is 0.
    val_counts = res['result'].value_counts().drop(0, errors='ignore')

    total = val_counts.sum()
    if total == 0:
        # Nothing to rate: make the undefined 0/0 explicit instead of
        # relying on a numpy division warning.
        return float('nan')

    total_wrong = sum(val_counts.get(label, 0) for label in wrong_labels)

    return (total_wrong / total) * 100
