In [1]:
import socket
import struct
import sys
import os
import time
import numpy as np
import matplotlib.pyplot as plt

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, RobustScaler
from scipy.signal import stft, welch, butter, filtfilt
from sklearn.preprocessing import FunctionTransformer
from sklearn.model_selection import GridSearchCV
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA

from classifiers_linear import LDA

In [2]:
class LogisticRegTune(LogisticRegression):
    def __init__(self,
                 penalty='l1',
                 *,
                 dual=False,
                 tol=0.0001,
                 C=1.0,
                 fit_intercept=True,
                 intercept_scaling=1,
                 class_weight=None,
                 random_state=None,
                 solver='liblinear',
                 max_iter=100,
                 multi_class='ovr',
                 verbose=0,
                 warm_start=False,
                 n_jobs=None,
                 l1_ratio=None,
                 scaler=None,
                 halfwin=7,
                 overlap=2,
                 num_channels=[0, 1, 2, 64, 65, 66]):
        self.scaler = scaler
        self.halfwin = halfwin
        self.overlap = overlap
        self.num_channels = num_channels
        self.sr_data = 500
        self.sample_rate = [0, 250]

        super().__init__(penalty=penalty, C=C, solver=solver, multi_class=multi_class)

    def fit_func(self, data, target):
        super().fit(data, target)

    def fit(self, data, target):
        if self.scaler is None:
            self.scaler_inst = FunctionTransformer(lambda x: x)
        else:
            self.scaler_inst = self.scaler()

        num_chn = len(self.num_channels)

        num_trials = data.shape[0]
        data = data[:, ::int(1000 / self.sr_data), self.num_channels]
        data = data[:, self.sample_rate[0]:self.sample_rate[1], :].reshape(-1, num_chn)

        self.scaler_inst.fit(data)
        
        data = self.scaler_inst.transform(data)
        data = data.reshape(num_trials, -1, num_chn)

        args = Args()
        args.sr_data = self.sr_data
        args.halfwin = self.halfwin
        args.overlap = self.overlap
        data = wavelet_transform(data, args)
        data = data.reshape(num_trials, -1)

        # fit model
        self.fit_func(data, target)

    def predict_func(self, data):
        return super().predict(data)

    def predict(self, data):
        num_chn = len(self.num_channels)

        num_trials = data.shape[0]
        data = data[:, ::int(1000 / self.sr_data), self.num_channels]
        data = data[:, self.sample_rate[0]:self.sample_rate[1], :].reshape(-1, num_chn)

        data = self.scaler_inst.transform(data)
        data = data.reshape(num_trials, -1, num_chn)

        args = Args()
        args.sr_data = self.sr_data
        args.halfwin = self.halfwin
        args.overlap = self.overlap
        data = wavelet_transform(data, args)
        data = data.reshape(num_trials, -1)

        # fit model
        return self.predict_func(data)

In [3]:
class LDATune(LinearDiscriminantAnalysis):
    def __init__(self,
                 solver='lsqr',
                 shrinkage='auto',
                 priors=None,
                 n_components=None,
                 store_covariance=False,
                 tol=0.0001,
                 covariance_estimator=None,
                 pca_comps=0,
                 scaler=None,
                 halfwin=7,
                 overlap=2,
                 num_channels=[0, 1, 2, 64, 65, 66],
                 wavelet=True):
        self.scaler = scaler
        self.halfwin = halfwin
        self.overlap = overlap
        self.num_channels = num_channels
        self.pca_comps = pca_comps
        self.sr_data = 500
        self.sample_rate = [0, 250]
        self.wavelet = wavelet

        super().__init__(solver=solver, shrinkage=shrinkage)

    def fit(self, data, target):
        if self.scaler is None:
            self.scaler_inst = FunctionTransformer(lambda x: x)
        else:
            self.scaler_inst = self.scaler()

        num_chn = len(self.num_channels)

        num_trials = data.shape[0]
        data = data[:, ::int(1000 / self.sr_data), self.num_channels]
        data = data[:, self.sample_rate[0]:self.sample_rate[1], :].reshape(-1, num_chn)

        self.scaler_inst.fit(data)
        
        data = self.scaler_inst.transform(data)
        data = data.reshape(num_trials, -1, num_chn)

        args = Args()
        args.sr_data = self.sr_data
        args.halfwin = self.halfwin
        args.overlap = self.overlap
        if self.wavelet:
            data = wavelet_transform(data, args)
        data = data.reshape(num_trials, -1)

        if self.pca_comps:
            self.pca = PCA(n_components=self.pca_comps)
            data = self.pca.fit_transform(data)

        # fit model
        super().fit(data, target)

    def predict(self, data):
        num_chn = len(self.num_channels)

        num_trials = data.shape[0]
        data = data[:, ::int(1000 / self.sr_data), self.num_channels]
        data = data[:, self.sample_rate[0]:self.sample_rate[1], :].reshape(-1, num_chn)

        data = self.scaler_inst.transform(data)
        data = data.reshape(num_trials, -1, num_chn)

        args = Args()
        args.sr_data = self.sr_data
        args.halfwin = self.halfwin
        args.overlap = self.overlap
        if self.wavelet:
            data = wavelet_transform(data, args)
        data = data.reshape(num_trials, -1)

        if self.pca_comps:
            data = self.pca.transform(data)

        # fit model
        return super().predict(data)

In [4]:
class LogisticRegL1(LDA):
    '''
    Logistic Regression model using the functionalities of the LDA class.
    Uses L1 regularization.
    '''
    def __init__(self, args):
        self.model = LogisticRegression(multi_class='ovr',
                                        penalty='l1',
                                        solver='liblinear',
                                        C=args.C_reg)
        self.fit_pca = False

In [10]:
class Args:
    def __init__(self):
        n = 1  # can be used to do multiple runs, e.g. over subjects

        self.model = LogisticRegL1
        self.scaler = StandardScaler
        self.C_reg = 1
        self.sample_rate = [0, 250]
        self.num_channels = [2, 64]  # 2, 64
        self.sr_data = 500  # sampling rate used for downsampling
        self.streaming_SR = 5
        self.halfwin = 15  # 10
        self.overlap = 2 # 2
        self.result_dir = os.path.join(
            '..',  # path(s) to save model and others
            'results',
            'stream2_jaw_6chan')

        # experiment arguments
        self.name = 'args.py'  # name of this file, don't change
        self.fix_seed = False
        self.common_dataset = False
        self.load_dataset = True  # whether to load self.dataset
        self.learning_rate = 0.0001  # learning rate for Adam
        self.max_trials = 1  # ratio of training data (1=max)
        self.val_max_trials = False
        self.batch_size = 20  # batch size for training and validation data
        self.epochs = 5000  # number of loops over training data
        self.val_freq = 20  # how often to validate (in epochs)
        self.print_freq = 5  # how often to print metrics (in epochs)
        self.save_curves = True  # whether to save loss curves to file
        self.load_model = False  # class of model to use
        self.dataset = None  # dataset class for loading and handling data

        # wavenet arguments
        self.activation = None  # activation function for models
        self.subjects = 1  # number of subjects used for training
        self.embedding_dim = 0  # subject embedding size
        self.p_drop = 0.6  # dropout probability
        self.ch_mult = 2  # channel multiplier for hidden channels in wavenet
        self.kernel_size = 2  # convolutional kernel size
        self.timesteps = 1  # how many timesteps in the future to forecast
        self.rf = 256  # receptive field of wavenet
        rf = 128
        ks = self.kernel_size
        nl = int(np.log(rf) / np.log(ks))
        dilations = [ks**i for i in range(nl)]
        self.dilations = dilations + dilations   # dilation: 2^num_layers
        #self.dilations = [1] + [2] + [4] * 7  # costum dilations

        # classifier arguments
        self.wavenet_class = None  # class of wavenet model
        self.load_conv = 'y'  # where to load neural nerwork weights from
        self.pred = False  # whether to use wavenet in prediction mode
        self.init_model = True  # whether to reinitialize classifier
        self.reg_semb = True  # whether to regularize subject embedding
        self.fixed_wavenet = False  # whether to fix weights of wavenet
        self.alpha_norm = 0.0  # regularization multiplier on weights
        self.num_classes = 118  # number of classes for classification
        self.units = [2200, 2000]  # hidden layer sizes of fully-connected block
        self.dim_red = 80  # number of pca components for channel reduction
        self.stft_freq = 0  # STFT frequency index for LDA_wavelet_freq model
        self.decode_peak = 0.1
        self.trial_average = False

        # quantized wavenet arguments
        self.mu = 255
        self.residual_channels = 1024
        self.dilation_channels = 1024
        self.skip_channels = 1024
        self.class_emb = 10
        self.channel_emb = 30
        self.cond_channels = self.class_emb + self.channel_emb
        self.head_channels = int(self.skip_channels/2)
        self.conv_bias = False

        # dataset arguments
        data_path = os.path.join('/', 'gpfs2', 'well', 'woolrich', 'projects',
                                 'cichy118_cont', 'preproc_data_onepass', 'epoched')
        self.data_path = [os.path.join(data_path, f'subj{i}') for i in range(n)]  # path(s) to data directory
        self.numpy = True  # whether data is saved in numpy format
        self.crop = 1  # cropping ratio for trials
        self.shuffle = True
        self.whiten = False  # pca components used in whitening
        self.group_whiten = False  # whether to perform whitening at the GL
        self.split = np.array([0, 0.2])  # validation split (start, end)
        self.original_sr = 1000
        self.save_data = True  # whether to save the created data
        self.save_whiten = False
        self.subjects_data = False  # list of subject inds to use in group data
        self.num_clip = 25
        self.dump_data = [os.path.join(data_path, f'subj{i}', 'train_data_trialnorm', 'c') for i in range(n)]  # path(s) for dumping data
        self.load_data = self.dump_data  # path(s) for loading data files

        # analysis arguments
        self.kernelPFI = False
        self.closest_chs = 'notebooks/closest1'  # channel neighbourhood size for spatial PFI
        self.PFI_inverse = False  # invert which channels/timesteps to shuffle
        self.pfich_timesteps = [[0, 50]]  # time window for spatiotemporal PFI
        self.PFI_perms = 10  # number of PFI permutations
        self.halfwin_uneven = False  # whether to use even or uneven window
        self.generate_noise = 1  # noise used for wavenet generation
        self.generate_length = self.sr_data * 1000  # generated timeseries len
        self.generate_mode = 'IIR'  # IIR or FIR mode for wavenet generation
        self.generate_input = 'gaussian_noise'  # input type for generation
        self.individual = True  # whether to analyse individual kernels
        self.anal_lr = 0.001  # learning rate for input backpropagation
        self.anal_epochs = 200  # number of epochs for input backpropagation
        self.norm_coeff = 0.0001  # L2 of input for input backpropagation
        self.kernel_limit = 300  # max number of kernels to analyse

        # simulation arguments
        self.nonlinear_prenoise = True
        self.nonlinear_data = True
        self.seconds = 3000
        self.events = 8
        self.sim_num_channels = 1
        self.sim_ar_order = 2
        self.gamma_shape = 14
        self.gamma_scale = 14
        self.noise_std = 2.5
        self.lambda_exp = 0.005
        self.ar_shrink = 1.0
        self.freqs = []
        self.ar_noise_std = np.random.rand(self.events) / 5 + 0.8
        self.max_len = 1000

        # AR model arguments
        self.order = 64
        self.uni = False
        self.save_AR = False
        self.do_anal = False
        self.AR_load_path = os.path.join(
            'results',
            'mrc',
            '60subjects_notch_sensors_multiAR64')

        # unused
        self.num_plot = 1
        self.plot_ch = 1
        self.linear = False
        self.num_samples_CPC = 20
        self.dropout2d_bad = False
        self.k_CPC = 1
        self.groups = 1
        self.conv1x1_groups = 1
        self.pos_enc_type = 'cat'
        self.pos_enc_d = 128
        self.l1_loss = False
        self.norm_alpha = self.alpha_norm
        self.num_components = 0
        self.resample = 7
        self.save_norm = True
        self.norm_path = os.path.join(data_path, 'norm_coeff')
        self.pca_path = os.path.join(data_path, 'pca128_model')
        self.load_pca = False
        self.compare_model = False
        self.channel_idx = 0


In [71]:
def control_code(code):
    code_dict = {'CTRL_FromServer': 1, 'CTRL_FromClient': 2}
    return code_dict.get(code, -1)

def data_type(code):
    data_dict = {'Data_Info': 1, 'Data_Eeg': 2, 'Data_Events': 3, 'Data_Impedance': 4}
    return data_dict.get(code, -1)

def request_type(code):
    request_dict = {'RequestVersion': 1,
                    'RequestChannelInfo': 3,
                    'RequestBasicInfoAcq': 6,
                    'RequestStreamingStart': 8,
                    'RequestStreamingStop': 9}
    return request_dict.get(code, -1)

def init_header(chanID, code, request, samples, size_body, sizeUn):
    # convert each character in chanID to uint8
    c_chID = struct.pack('4B', *map(ord, chanID))
    w_Code = struct.pack('>H', code)
    w_Request = struct.pack('>H', request)
    un_Sample = struct.pack('>I', samples)
    un_Size = struct.pack('>I', size_body)
    un_SizeUn = struct.pack('>I', sizeUn)

    return c_chID + w_Code + w_Request + un_Sample + un_Size + un_SizeUn

def block_type(code):
    block_dict = {'DataTypeFloat32bit': 1,
                  'DataTypeFloat32bitZIP': 2,
                  'DataTypeEventList': 3}
    return block_dict.get(code, -1)

def info_type(code):
    info_dict = {'InfoType_Version': 1,
                 'InfoType_BasicInfo': 2,
                 'InfoType_ChannelInfo': 4,
                 'InfoType_StatusAmp': 7,
                 'InfoType_Time': 9}
    return info_dict.get(code, -1)

def request_packet(con, packet_size):
    count = 0
    timeout = 20

    while True:
        data = con.recv(packet_size, 0)
        if data or count == timeout:
            break

        count += 1
        time.sleep(0.2)

    return data

def client_process_request(con, header, code, request, init):
    header_size = len(header)

    # send header if streaming start
    if not init:
        con.send(header)
    
    # get response header
    data = request_packet(con, 20)
    
    temp_packet_size = 0
    count = 0
    timeout = 10
    synch_packets = 5
    data_out = bytearray()
    message = {'code': None, 'request': None, 'start_sample': None, 'packet_size': None}
    
    message['code'] = struct.unpack('>H', data[4:6])[0]
    message['request'] = struct.unpack('>H', data[6:8])[0]
    message['start_sample'] = struct.unpack('>I', data[8:12])[0]
    message['packet_size'] = struct.unpack('>I', data[12:16])[0]
    
    if message['code'] in code and message['request'] in request:
        while temp_packet_size < message['packet_size'] and count < timeout:
            data = request_packet(con, message['packet_size'])
            temp_packet_size += len(data)
            data_out += data
            count += 1
    else:
        while count < synch_packets:
            request_packet(con, message['packet_size'])
            count += 1
    
    return data_out, message

def client_get_basic_info(con):
    basic_info = {}
    max_chans = 300
    header = init_header("CTRL",
                         control_code("CTRL_FromClient"),
                         request_type("RequestBasicInfoAcq"),
                         0, 0, 0)

    basic_info_raw, message = client_process_request(con,
                                                     header,
                                                     [data_type("Data_Info")],
                                                     [info_type("InfoType_BasicInfo")],
                                                     0)

    size = struct.unpack('<I', basic_info_raw[0:4])[0]
    eeg_chan = struct.unpack('<I', basic_info_raw[4:8])[0]
    sample_rate = struct.unpack('<I', basic_info_raw[8:12])[0]
    data_size = struct.unpack('<I', basic_info_raw[12:16])[0]
    allow_client_to_control_amp = struct.unpack('<I', basic_info_raw[16:20])[0]
    allow_client_to_control_rec = struct.unpack('<I', basic_info_raw[20:24])[0]
    
    basic_info = {
        'size': size,
        'eeg_chan': eeg_chan,
        'sample_rate': sample_rate,
        'data_size': data_size,
        'allow_client_to_control_amp': allow_client_to_control_amp,
        'allow_client_to_control_rec': allow_client_to_control_rec
    }

    return basic_info

def request_data_packet(con, basic_info, init=0):
    segments = []
    offset_event_type = 0
    offset_event_latency = offset_event_type + 4
    offset_event_start = offset_event_latency + 4
    offset_event_end = offset_event_start + 4
    offset_event_annotation = offset_event_end + 4

    # raw length
    event_struct_length = (offset_event_annotation + 520)//8*8

    # Protocol variable definitions
    data_types   = [data_type('Data_Eeg'), data_type('Data_Events'), data_type('Data_Impedance')]
        
    block_types  = [block_type('DataTypeFloat32bit'), block_type('DataTypeEventList')]

    header = init_header('CTRL',
                         control_code('CTRL_FromClient'),
                         request_type('RequestStreamingStart'),
                         0,0,0)

    # get data
    data, message = client_process_request(con, header, data_types, block_types, init=init)

    # if data packet
    if message['code'] == 2: 
        #receivedSamples = len(data) / (basic_info['data_size'] * basic_info['eeg_chan']) 
        #print(f"Received {len(data) / 1000} kBytes, EEG, {receivedSamples} samples, Start sample = {message['startSample']}")
        return data, message

    # if event packet
    elif message['code'] == 3: 
        if message['packet_size'] % event_struct_length == 0:
            num_events = message['packet_size'] // event_struct_length

            if num_events > 0:
                event_type = struct.unpack(
                    '<I', data[offset_event_type:offset_event_latency])[0]
                event_latency = struct.unpack(
                    '<I', data[offset_event_latency:offset_event_start])[0]
                event_annotation = struct.unpack(
                    '<H', data[offset_event_annotation:offset_event_annotation+2])[0]
                #print(f"Event type {eventType}, Latency: {eventLatency}, Annotation: {chr(eventAnnotation)}")

                return {'event_type': event_type,
                        'event_latency': event_latency,
                        'event_annotation': chr(event_annotation)}, None
        else:
            print("ClientRequestDataPacket failed: unmatching event structure size")

    return data, message

def stop_stream(con):
    header = init_header('CTRL',
                         control_code('CTRL_FromClient'),
                         request_type('RequestStreamingStop'),
                         0,0,0)
    con.send(header)

# decode data to numpy
def decode_data(data, num_samples, basic_info):
    dtype = np.float32 if basic_info['data_size'] == 4 else np.int16
    return np.frombuffer(data, dtype=dtype).reshape(num_samples, basic_info["eeg_chan"])

# notch filter function
def notch_filter(data, fs, f):
    # creater butter filter
    b, a = butter(5, [f - 1.5, f + 1.5], btype='bandstop', fs=fs, output='ba')
    return filtfilt(b, a, data)

def get_initial_data(sock, basic_info, SR):
    data, _ = request_data_packet(sock, basic_info)
    num_samples = len(data) // (basic_info["data_size"] * basic_info["eeg_chan"])

    segments = []
    events = []
    train_data = []
    train_target = []
    start_sample = 0
    last_sample_time = 0
    sample_count = SR + 1

    while True:
        # start streaming
        data, message = request_data_packet(sock, basic_info, 1)

        # check if data is dict (event packet)
        if isinstance(data, dict):
            sample_count = 0
            last_event = data['event_type']
            last_latency = data['event_latency'] - start_sample
            events.append(np.array([last_event, last_latency]))

            # if we are at the end of experiment exit loop
            if last_event == 7:
                break
        else:
            # set start sample index
            if not segments:
                start_sample = message['start_sample']
                last_sample_time = start_sample - num_samples
            
            if last_sample_time + num_samples != message['start_sample']:
                print('Wrong sample time')
            last_sample_time = message['start_sample']

            # shape: samples x channels
            packet = decode_data(data, num_samples, basic_info)
            segments.append(packet)
            sample_count += 1

        # if we have enough samples to make a trial
        if sample_count == SR:
            # concatenate last 1 second of samples
            trial = np.concatenate(segments[-SR-1:], axis=0)

            # calculate latency of last event compared to trial start
            latency = int(last_latency - (len(segments) - SR - 1) * 1000 / SR)
            train_data.append(trial[latency:latency+1000, :])
            train_target.append(last_event - 2)

            print(f"Trial {len(train_data)}: {last_event} at {last_latency} ms")

    stop_stream(sock)

    return np.array(segments), np.array(events), np.array(train_data), np.array(train_target)

def save_data(args, segments, events, train_data, train_target):
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)

    np.save(os.path.join(args.result_dir, 'events.npy'), events)
    np.save(os.path.join(args.result_dir, 'segments.npy'), segments)
    np.save(os.path.join(args.result_dir, 'train_data.npy'), train_data)
    np.save(os.path.join(args.result_dir, 'train_target.npy'), train_target)

def resegment_data(data, events):
    train_data = []
    train_target = []

    for e in events[:-1]:
        event = e[0]
        latency = e[1]

        # get trial data
        trial = data[:, latency:latency+1000]

        # append to train data
        train_data.append(trial)
        train_target.append(event - 2)

    train_data = np.array(train_data).transpose(0, 2, 1)
    return train_data, np.array(train_target)


def wavelet_transform(data, args):
    # trials, channels, samples
    data = data.transpose(0, 2, 1)

    f, t, data = stft(data,
                      fs=args.sr_data,
                      window='hamming',
                      nperseg=args.halfwin*2,
                      noverlap=args.overlap,
                      boundary=None)

    # concatenate wavelet coefficients
    data = np.concatenate((data.real, data.imag), axis=2)

    return data

def train_and_predict(args, data, target=None, lda=None, scaler=None, wavelet=False):
    init = False
    if lda is None:
        lda = args.model(args)

        if args.scaler is None:
            scaler = FunctionTransformer(lambda x: x)
        else:
            scaler = args.scaler()
        init = True

    num_chn = len(args.num_channels)

    num_trials = data.shape[0]
    data = data[:, ::int(1000 / args.sr_data), args.num_channels]
    data = data[:, args.sample_rate[0]:args.sample_rate[1], :].reshape(-1, num_chn)

    if init:
        scaler.fit(data)
    
    data = scaler.transform(data)
    data = data.reshape(num_trials, -1, num_chn)

    # compute wavelet transform
    if wavelet:
        data = wavelet_transform(data, args)
    data = data.reshape(num_trials, -1)

    if init:
        # fit lda model
        lda.model.fit(data, target)
        print(data.shape)

    if target is not None:
        # calculate accuracy
        pred = lda.model.score(data, target)
    else:
        # predict probability of test data
        pred = lda.model.predict_proba(data)

    return pred, lda, scaler

def real_time_predict(args, sock, basic_info, lda, scaler):
    # start streaming
    data, _ = request_data_packet(sock, basic_info)
    num_samples = len(data) // (basic_info["data_size"] * basic_info["eeg_chan"])
    SR = args.streaming_SR

    events = {0: 'hungry',
              1: 'tired',
              2: 'thirsty',
              3: 'toilet',
              4: 'pain'}

    segments = []
    latencies = []
    val_data = []
    val_target = []
    start_sample = 0
    num_trials = 0
    sample_count = SR + 1
    restart = True
    event_probs_list = []

    while True:
        data, message = request_data_packet(sock, basic_info, 1)

        # check if data is dict
        if isinstance(data, dict):
            sample_count = 0
            last_event = data['event_type']
            last_latency = data['event_latency'] - start_sample
        else:
            # decode data to numpy
            packet = decode_data(data, num_samples, basic_info)
        
            if restart:
                restart = False
                start_sample = message['start_sample']

            segments.append(packet)
            sample_count += 1

        # if we have enough samples to make a trial
        if sample_count == SR:
            num_trials += 1
            trial = np.concatenate(segments[-SR-1:], axis=0)

            latency = int(last_latency - (len(segments) - SR - 1) * 1000 / SR)
            val_data.append(trial[latency:latency+1000, :])
            val_target.append(last_event - 2)
            print(events[last_event-2])

        # make a prediction after 4 trials
        if num_trials == 4:
            num_trials = 0
            stop_stream(sock)
        
            # make a prediction
            probs, _, _ = train_and_predict(args, np.array(val_data[-4:]), lda=lda, scaler=scaler, wavelet=args.wavelet)
            event_probs_list.append(probs)

            for p in probs:
                # format to 2 decimals and sort by probability
                event_probs = {events[j]: round(p[j], 2)*100 for j in events.keys()}

                sorted_events = sorted(event_probs.items(),
                                       key=lambda item: item[1],
                                       reverse=True)
                event_probs = {k: int(v) for k, v in sorted_events}
                print(event_probs)

            # press enter to continue or q to quit
            key = input()
            if key == 'q':
                break
            
            # restart streaming
            restart = True
            segments = []
            latencies = []

            data, _ = request_data_packet(sock, basic_info)

    return event_probs_list, val_data, val_target

# Initialise socket

In [16]:
# Create a TCP/IP socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# Bind the socket to the port
server_address = ('192.168.0.1', 4455)
sock.connect(server_address)

In [17]:
basic_info = client_get_basic_info(sock)
basic_info

{'size': 24,
 'eeg_chan': 69,
 'sample_rate': 1000,
 'data_size': 4,
 'allow_client_to_control_amp': 0,
 'allow_client_to_control_rec': 0}

# Initial data collection

In [14]:
args = Args()

In [24]:
segments, events, train_data, train_target = get_initial_data(sock,
                                                              basic_info,
                                                              args.streaming_SR)

Trial 1: 2 at 7285 ms
Trial 2: 2 at 8375 ms
Trial 3: 2 at 9464 ms
Trial 4: 2 at 10554 ms
Trial 5: 5 at 14914 ms
Trial 6: 5 at 16004 ms
Trial 7: 5 at 17094 ms
Trial 8: 5 at 18184 ms
Trial 9: 6 at 22613 ms
Trial 10: 6 at 23703 ms
Trial 11: 6 at 24793 ms
Trial 12: 6 at 25883 ms
Trial 13: 3 at 30393 ms
Trial 14: 3 at 31483 ms
Trial 15: 3 at 32572 ms
Trial 16: 3 at 33662 ms
Trial 17: 4 at 38092 ms
Trial 18: 4 at 39182 ms
Trial 19: 4 at 40272 ms
Trial 20: 4 at 41362 ms
Trial 21: 2 at 45531 ms
Trial 22: 2 at 46621 ms
Trial 23: 2 at 47711 ms
Trial 24: 2 at 48801 ms
Trial 25: 4 at 53101 ms
Trial 26: 4 at 54191 ms
Trial 27: 4 at 55281 ms
Trial 28: 4 at 56370 ms
Trial 29: 6 at 60770 ms
Trial 30: 6 at 61860 ms
Trial 31: 6 at 62950 ms
Trial 32: 6 at 64040 ms
Trial 33: 5 at 68449 ms
Trial 34: 5 at 69539 ms
Trial 35: 5 at 70629 ms
Trial 36: 5 at 71719 ms
Trial 37: 3 at 76109 ms
Trial 38: 3 at 77199 ms
Trial 39: 3 at 78289 ms
Trial 40: 3 at 79378 ms
Trial 41: 2 at 83578 ms
Trial 42: 2 at 84668 ms
Tria

In [27]:
# save data
save_data(args, segments, events, train_data, train_target)

In [25]:
# checking data
cont_segments = segments.transpose(2, 0, 1).reshape(69, -1)

In [119]:
# notch filter data at 50Hz and harmonics
cont_segments = notch_filter(cont_segments, 1000, 49)
cont_segments = notch_filter(cont_segments, 1000, 99)
cont_segments = notch_filter(cont_segments, 1000, 149)
cont_segments = notch_filter(cont_segments, 1000, 199)
cont_segments = notch_filter(cont_segments, 1000, 249)

In [28]:
events[:10]

array([[    2,  7285],
       [    2,  8375],
       [    2,  9464],
       [    2, 10554],
       [    5, 14914],
       [    5, 16004],
       [    5, 17094],
       [    5, 18184],
       [    6, 22613],
       [    6, 23703]])

In [29]:
%matplotlib widget
plt.plot(cont_segments[-1, :30000])
e=0

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [30]:
%matplotlib widget
# compute PSD for each channel
f, psd = welch(cont_segments[[0, 1, 2, 64, 65, 66]], fs=1000, nperseg=8000, noverlap=4000)

# plot PSD
plt.figure(figsize=(10, 5))
plt.plot(f, psd.T)

plt.xlim(3, 250)
plt.ylim(0, 20)
plt.xlabel('Frequency (Hz)')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 0, 'Frequency (Hz)')

In [43]:
cont_segments = cont_segments[:, ::2]

In [44]:
# plot spectogram of cont_segments
plt.figure(figsize=(10, 5))
plt.specgram(cont_segments[2, 5000:5500], Fs=500, NFFT=32, noverlap=31)
e=0

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [71]:
%matplotlib widget
plt.plot(cont_segments[[3], 20000:40000].T, linewidth=0.5)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7fd18145cd60>]

In [34]:
# check evoked responses
%matplotlib widget
plt.plot(train_data[train_target==2, :, 64].mean(axis=0), linewidth=0.5)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7f9c1a70e460>]

In [26]:
# resegment data using events
train_data, train_target = resegment_data(cont_segments, events)

In [36]:
train_data.shape

(280, 1000, 69)

In [178]:
# plot spectogram of cont_segments
%matplotlib widget
# set log y scale
plt.specgram(train_data[50, ::2, 2], Fs=500, NFFT=25, noverlap=24, scale='dB')
e=0

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train LDA

In [104]:
# load data
train_data = np.load(os.path.join(args.result_dir, 'train_data.npy'))
train_target = np.load(os.path.join(args.result_dir, 'train_target.npy'))

segments = np.load(os.path.join(args.result_dir, 'segments.npy'))
events = np.load(os.path.join(args.result_dir, 'events.npy'))

In [105]:
args.num_channels = [0, 1, 2, 64, 65, 66]  # 2, 64, 65 / 0, 1, 2, 64, 65, 66
args.halfwin = 8  # 10 / 10
args.overlap = 4 # 5 / 2
args.C_reg = 0.05 # 0.5
args.model = LDA # LDA / LogisticRegL1
args.scaler = StandardScaler

In [106]:
# train on training split
ntrials = 280
acc, lda, scaler = train_and_predict(
    args, train_data[:ntrials], train_target[:ntrials], wavelet=True)
print(acc)

Couldn't load conv model for lda.
(280, 2268)
1.0


In [59]:
# test on validation split
acc, _, _ = train_and_predict(
    args, train_data[ntrials:], train_target[ntrials:], lda=lda, scaler=scaler, wavelet=False)
print(acc)

0.5


In [141]:
# optimize hyperparams
C_reg = [0.05, 0.1, 1, 2, 10]
halfwin = [5, 7, 9, 11, 13, 15]
overlap = [1, 3, 5, 7]
num_channels = [[0, 1, 2, 64, 65, 66]]
scalers = [StandardScaler]

In [194]:
# optimize hyperparams
C_reg = [0.1, 1, 10]
halfwin = [8]
overlap = [4]
num_channels = [[2, 64]]
pca_comps = [0]
scalers = [RobustScaler]
wavelet = [False]

In [196]:
gridCV = GridSearchCV(estimator=LDATune(),
                      param_grid={
                                  'halfwin': halfwin,
                                  'overlap': overlap,
                                  'num_channels': num_channels,
                                  'scaler': scalers},
                      scoring='accuracy',
                      n_jobs=-1,
                      verbose=3)
ntrials = 260
gridCV.fit(train_data[:ntrials], train_target[:ntrials])

Fitting 5 folds for each of 1 candidates, totalling 5 fits
[CV 1/5] END halfwin=8, num_channels=[2, 64], overlap=4, scaler=<class 'sklearn.preprocessing._data.RobustScaler'>;, score=0.500 total time=   0.7s
[CV 3/5] END halfwin=8, num_channels=[2, 64], overlap=4, scaler=<class 'sklearn.preprocessing._data.RobustScaler'>;, score=0.423 total time=   0.7s
[CV 4/5] END halfwin=8, num_channels=[2, 64], overlap=4, scaler=<class 'sklearn.preprocessing._data.RobustScaler'>;, score=0.404 total time=   0.7s
[CV 5/5] END halfwin=8, num_channels=[2, 64], overlap=4, scaler=<class 'sklearn.preprocessing._data.RobustScaler'>;, score=0.462 total time=   0.7s
[CV 2/5] END halfwin=8, num_channels=[2, 64], overlap=4, scaler=<class 'sklearn.preprocessing._data.RobustScaler'>;, score=0.404 total time=   0.7s


GridSearchCV(estimator=LDATune(), n_jobs=-1,
             param_grid={'halfwin': [8], 'num_channels': [[2, 64]],
                         'overlap': [4],
                         'scaler': [<class 'sklearn.preprocessing._data.RobustScaler'>]},
             scoring='accuracy', verbose=3)

In [143]:
print(gridCV.best_estimator_)
print(gridCV.best_score_)

LogisticRegTune(C=1, halfwin=15, overlap=1,
                scaler=<class 'sklearn.preprocessing._data.StandardScaler'>)
0.6807692307692308


In [145]:
ntrials = 260
best_model = LogisticRegTune(C=1, halfwin=15, overlap=1, num_channels=[0, 1, 2, 64, 65, 66], scaler=StandardScaler)
best_model.fit(train_data[:ntrials], train_target[:ntrials])
print(best_model.score(train_data[:ntrials], train_target[:ntrials]))
print(best_model.score(train_data[ntrials:], train_target[ntrials:]))

0.9192307692307692
0.55


In [136]:
ntrials = 260
best_model = LDATune(halfwin=8, overlap=4, num_channels=[0, 1, 2, 64, 65, 66], scaler=StandardScaler, wavelet=True)
best_model.fit(train_data[:ntrials], train_target[:ntrials])
print(best_model.score(train_data[:ntrials], train_target[:ntrials]))
print(best_model.score(train_data[ntrials:], train_target[ntrials:]))

0.9961538461538462
0.65


# Closed-loop prediction

In [129]:
overall_data = [train_data]
overall_target = [train_target]

In [163]:
# extend overall lists
overall_data.append(val_data)
overall_target.append(val_target)

In [164]:
# retrain on overall data
acc, lda, scaler = train_and_predict(
    args, np.concatenate(overall_data, axis=0), np.concatenate(overall_target, axis=0), wavelet=True)
print(acc)

Couldn't load conv model for lda.
(596, 2268)
0.9630872483221476


In [165]:
args.wavelet = True
probs_list, val_data, val_target = real_time_predict(args, sock, basic_info, lda, scaler)

thirsty
thirsty
thirsty
thirsty
{'thirsty': 90, 'pain': 10, 'hungry': 0, 'tired': 0, 'toilet': 0}
{'toilet': 100, 'hungry': 0, 'tired': 0, 'thirsty': 0, 'pain': 0}
{'thirsty': 100, 'hungry': 0, 'tired': 0, 'toilet': 0, 'pain': 0}
{'toilet': 100, 'hungry': 0, 'tired': 0, 'thirsty': 0, 'pain': 0}
tired
tired
tired
tired
{'pain': 100, 'hungry': 0, 'tired': 0, 'thirsty': 0, 'toilet': 0}
{'pain': 83, 'toilet': 13, 'tired': 5, 'hungry': 0, 'thirsty': 0}
{'tired': 100, 'hungry': 0, 'thirsty': 0, 'toilet': 0, 'pain': 0}
{'pain': 100, 'hungry': 0, 'tired': 0, 'thirsty': 0, 'toilet': 0}
toilet
toilet
toilet
toilet
{'pain': 100, 'hungry': 0, 'tired': 0, 'thirsty': 0, 'toilet': 0}
{'pain': 48, 'toilet': 37, 'tired': 15, 'hungry': 0, 'thirsty': 0}
{'toilet': 100, 'hungry': 0, 'tired': 0, 'thirsty': 0, 'pain': 0}
{'toilet': 100, 'hungry': 0, 'tired': 0, 'thirsty': 0, 'pain': 0}
thirsty
thirsty
thirsty
thirsty
{'thirsty': 100, 'hungry': 0, 'tired': 0, 'toilet': 0, 'pain': 0}
{'thirsty': 100, 'hungry'

In [160]:
# convert to array
probs = np.array(probs_list)
val_data = np.array(val_data)
val_target = np.array(val_target)

In [161]:
# save probs list and val_data and val_target
np.save(os.path.join(args.result_dir, 'probs_list_silent7.npy'), probs)
np.save(os.path.join(args.result_dir, 'val_data_silent7.npy'), val_data)
np.save(os.path.join(args.result_dir, 'val_target_silent7.npy'), val_target)

In [132]:
# load probs list and val_data and val_target
probs = np.load(os.path.join(args.result_dir, 'probs_list_loud4.npy'))
val_data = np.load(os.path.join(args.result_dir, 'val_data_loud4.npy'))
val_target = np.load(os.path.join(args.result_dir, 'val_target_loud4.npy'))

In [113]:
val_target

array([1, 1, 1, 1, 3, 3, 3, 3, 2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1, 0, 0,
       0, 0, 4, 4, 4, 4, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3, 4, 4, 4, 4])

In [114]:
np.argmax(probs, axis=1)

array([1, 3, 3, 1, 0, 1, 3, 3, 1, 4, 2, 2, 1, 1, 3, 3, 0, 1, 3, 2, 0, 1,
       1, 1, 0, 4, 2, 2, 2, 1, 2, 1, 4, 1, 1, 1, 1, 3, 3, 3, 1, 2, 0, 3])

In [162]:
# calculate accuracy
probs = probs.reshape(-1, probs.shape[2])
acc = np.mean(np.argmax(probs, axis=1) == val_target)
acc

0.625

In [121]:
acc, _, _ = train_and_predict(
    args, val_data, val_target, lda=lda, scaler=scaler, wavelet=True)
print(acc)

0.45454545454545453


In [None]:
# extend overall lists
overall_data.append(val_data)
overall_target.append(val_target)

In [134]:
# retrain on overall data
acc, lda, scaler = train_and_predict(
    args, np.concatenate(overall_data, axis=0), np.concatenate(overall_target, axis=0), wavelet=True)
print(acc)

Couldn't load conv model for lda.
(404, 2268)
0.9876237623762376
