In [1]:
import itertools
from pathlib import Path
import pickle

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

from online_semg_posture_adaptation import dataset as ds
from online_semg_posture_adaptation.learning import mlpuniboinail as mui
from online_semg_posture_adaptation.learning import learning as learn
from online_semg_posture_adaptation.learning import quantization as quant
from online_semg_posture_adaptation.learning import goodness as good

In [2]:
DOWNSAMPLING_FACTOR = 1

NUM_TRAIN_REPETITIONS = 5

NUM_EPOCHS_FP = 4
QUANTIZE = True
NUM_EPOCHS_QAT = 8
INPUT_SCALE = 0.999

RESULTS_FILENAME = 'results_multitrain.pkl'
RESULTS_DIR_PATH = './results/'
RESULTS_FILE_FULLPATH = RESULTS_DIR_PATH + RESULTS_FILENAME

In [3]:
# structure for storing the results

results = {'subject': {}}

for idx_subject in range(ds.NUM_SUBJECTS):

    results['subject'][idx_subject] = {'day': {}}

    for idx_day in range(ds.NUM_DAYS):

        results['subject'][idx_subject]['day'][idx_day] = {'posture': {}}

        for idx_valid_posture in range(ds.NUM_POSTURES):

            results['subject'][idx_subject]['day'][idx_day]['posture'][idx_valid_posture] = {
                # just classification metrics, no models or labels
                'training': {},  # classification metrics dictionary
                'validation': {},  # classification metrics dictionary
            }

In [4]:
for idx_subject, idx_day in itertools.product(
    range(ds.NUM_SUBJECTS), range(ds.NUM_DAYS)
):
    
    # ----------------------------------------------------------------------- #

    # print a header
    print(
        f"\n"
        f"------------------------------------------------------------------\n"
        f"SUBJECT\t{idx_subject + 1 :d}/{ds.NUM_SUBJECTS:d}\n"
        f"DAY\t{idx_day + 1 :d}/{ds.NUM_DAYS:d}\n"
        f"(all indices are one-based)\n"
        f"------------------------------------------------------------------\n"
        f"\n"
    )

    # ----------------------------------------------------------------------- #

    # load training data

    xtrain_list = []
    ytrain_list = []

    for idx_train_posture in range(ds.NUM_POSTURES):
        
        train_session_data_dict = ds.load_session(
            idx_subject, idx_day, idx_train_posture)

        emg_train = train_session_data_dict['emg']
        relabel_train = train_session_data_dict['relabel']
        gesture_counter_train = train_session_data_dict['gesture_counter']
        del train_session_data_dict

        # "_p" stands for single posture
        xtrain_p, ytrain_p, _, _ = ds.split_into_calib_and_valid(
            emg=emg_train,
            relabel=relabel_train,
            gesture_counter=gesture_counter_train,
            num_calib_repetitions=NUM_TRAIN_REPETITIONS,
        )
        del emg_train, relabel_train, gesture_counter_train

        # add to the lists
        xtrain_list.append(xtrain_p)
        ytrain_list.append(ytrain_p)
        del xtrain_p, ytrain_p

    # concatenate into single arrays
    xtrain = np.concatenate(xtrain_list, axis=1)
    ytrain = np.concatenate(ytrain_list, axis=0)
    del xtrain_list, ytrain_list

    # ----------------------------------------------------------------------- #

    # downsampling
    xtrain = xtrain[:, ::DOWNSAMPLING_FACTOR]
    ytrain = ytrain[::DOWNSAMPLING_FACTOR]

    # standard scaling and de-correlation, as preprocessing before training
    stdscaler_train = StandardScaler()
    xtrain_stdscaled = stdscaler_train.fit_transform(xtrain.T).T
    del xtrain
    pca_train = PCA(n_components=ds.NUM_CHANNELS, whiten=False)
    xtrain_pc = pca_train.fit_transform(xtrain_stdscaled.T).T
    del xtrain_stdscaled

    # ----------------------------------------------------------------------- #

    # MLP training and validation

    mlp = mui.MLPUniboINAIL()
    mui.summarize(mlp)

    # full-precision training
    mlp, history, yout_train, yout_valid = learn.do_training(
        xtrain=xtrain_pc,
        ytrain=ytrain,
        model=mlp,
        xvalid=None,
        yvalid=None,
        num_epochs=NUM_EPOCHS_FP,
    )
    # Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT)
    if QUANTIZE:
        (
            mlp_q,
            output_scale,
            history_q,
            metrics_train_q,
            _,  # (in general, this is metrics_valid_q)
            yout_train_q,
            _,  # (in general, this is yout_valid_q)
        ) = quant.quantlib_flow(
            xtrain=xtrain_pc,
            ytrain=ytrain,
            model=mlp,
            xvalid=None,
            yvalid=None,
            do_qat=True,
            num_epochs_qat=NUM_EPOCHS_QAT,
            input_scale=INPUT_SCALE,
            export=False,
            onnx_filename=None,
        )
        mlp = mlp_q  # replace the model
        del mlp_q
    else:
        output_scale = 1.0
    
    del xtrain_pc, ytrain
    
    # ----------------------------------------------------------------------- #

    for idx_valid_posture in range(ds.NUM_POSTURES):

        # ------------------------------------------------------------------- #

        # print a header
        print(
            f"\n"
            f"--------------------------------------------------------------\n"
            f"VALIDATION ON POSTURE {idx_valid_posture + 1 :d}\n"
            f"--------------------------------------------------------------\n"
            f"\n"
        )

        # ------------------------------------------------------------------- #
        
        # load validation data

        valid_session_data_dict = ds.load_session(
            idx_subject, idx_day, idx_valid_posture)

        emg_valid = valid_session_data_dict['emg']
        relabel_valid = valid_session_data_dict['relabel']
        gesture_counter_valid = valid_session_data_dict['gesture_counter']
        del valid_session_data_dict

        # "_p" stands for single posture
        xtrain_p, ytrain_p, xvalid, yvalid = ds.split_into_calib_and_valid(
            emg=emg_valid,
            relabel=relabel_valid,
            gesture_counter=gesture_counter_valid,
            num_calib_repetitions=NUM_TRAIN_REPETITIONS,
        )
        del emg_valid, relabel_valid, gesture_counter_valid

        # ------------------------------------------------------------------- #

        # preprocessing

        xtrain_p = xtrain_p[:, ::DOWNSAMPLING_FACTOR]
        ytrain_p = ytrain_p[::DOWNSAMPLING_FACTOR]
        xvalid = xvalid[:, ::DOWNSAMPLING_FACTOR]
        yvalid = yvalid[::DOWNSAMPLING_FACTOR]

        xtrain_p_standardscaled = stdscaler_train.transform(xtrain_p.T).T
        xvalid_standardscaled = stdscaler_train.transform(xvalid.T).T
        del xtrain_p, xvalid
        xtrain_p_pc = pca_train.transform(xtrain_p_standardscaled.T).T
        xvalid_pc = pca_train.transform(xvalid_standardscaled.T).T
        del xtrain_p_standardscaled, xvalid_standardscaled

        # ------------------------------------------------------------------- #

        # MLP inference
        
        yout_train_p = learn.do_inference(xtrain_p_pc, mlp)
        yout_valid = learn.do_inference(xvalid_pc, mlp)
        del xtrain_p_pc, xvalid_pc

        metrics_train_p = good.compute_classification_metrics(ytrain_p, yout_train_p)
        metrics_valid = good.compute_classification_metrics(yvalid, yout_valid)

        print("\n\n")
        print("On training repetitions:")
        print(metrics_train_p)
        print("On validation repetitions:")
        print(metrics_valid)
        print("\n\n")
        
        # ------------------------------------------------------------------- #

        # store results
        results['subject'][idx_subject]['day'][idx_day]['posture'][idx_valid_posture]['training'] = metrics_train_p
        results['subject'][idx_subject]['day'][idx_day]['posture'][idx_valid_posture]['validation'] = metrics_valid
        
        # save to file
        # save the updated results dictionary after each validation
        results_outer_dict = {'results': results}
        Path(RESULTS_DIR_PATH).mkdir(parents=True, exist_ok=True)
        with open(RESULTS_FILE_FULLPATH, 'wb') as f:
            pickle.dump(results_outer_dict, f)


------------------------------------------------------------------
SUBJECT	1/7
DAY	1/8
(all indices are one-based)
------------------------------------------------------------------



		TRAINING		VALIDATION

EPOCH		Loss	Bal.acc.	Loss	Bal.acc.	Time (s)

1/4		0.2819	0.8869		none	none		8.1
2/4		0.2469	0.9044		none	none		7.6
3/4		0.2407	0.9075		none	none		7.8
4/4		0.2368	0.9083		none	none		7.6

		TRAINING		VALIDATION

EPOCH		Loss	Bal.acc.	Loss	Bal.acc.	Time (s)

1/8		0.2615	0.9020		none	none		42.5
2/8		0.2457	0.9056		none	none		53.8
3/8		0.2440	0.9067		none	none		52.5
4/8		0.2452	0.9069		none	none		47.6
5/8		0.2418	0.9066		none	none		52.9
6/8		0.2410	0.9076		none	none		52.7
7/8		0.2385	0.9104		none	none		54.2
8/8		0.2440	0.9073		none	none		44.1
[FinalEpsTunnelRemover] output: removing EpsTunnel with scaling factor tensor([[0.0001524924737168, 0.0002130363136530, 0.0002445808495395,
         0.0002233606064692, 0.0001955179614015, 0.0001720564323477]])
[FinalEpsTunnelRemover] output: outp




--------------------------------------------------------------
VALIDATION ON POSTURE 1
--------------------------------------------------------------





On training repetitions:
{'balanced_crossentropy': 354.7293701171875, 'balanced_accuracy': 0.927797578504841, 'accuracy': 0.9407439282143445}
On validation repetitions:
{'balanced_crossentropy': 450.1875305175781, 'balanced_accuracy': 0.9233168117105911, 'accuracy': 0.9360813852596104}




--------------------------------------------------------------
VALIDATION ON POSTURE 2
--------------------------------------------------------------





On training repetitions:
{'balanced_crossentropy': 419.3710021972656, 'balanced_accuracy': 0.9295673531195936, 'accuracy': 0.9336306061781052}
On validation repetitions:
{'balanced_crossentropy': 1284.4581298828125, 'balanced_accuracy': 0.8938587482459145, 'accuracy': 0.9116027306879233}




--------------------------------------------------------------
VALIDATION ON POSTURE 3
------------------

KeyboardInterrupt: 