In [292]:
%serialconnect to --port="/dev/tty.usbserial-0001" --baud=115200

[34mConnecting to --port=/dev/tty.usbserial-0001 --baud=115200 [0m
[34mReady.
[0m

In [293]:
from ulab import numpy as np
import gc

from lib.computation import solve_eig_qr, standardise, solve_gen_eig_prob
    
class CCA():
    
    def __init__(self, stim_freqs, fs, Nh=2):
        self.Nh = Nh
        self.stim_freqs = stim_freqs
        self.fs = fs
        
    def compute_corr(self, X_test):            
        result = {}
        Cxx = np.dot(X_test, X_test.transpose()) # precompute data auto correlation matrix
        for f in self.stim_freqs:
            Y = harmonic_reference(f, self.fs, np.max(X_test.shape), Nh=self.Nh, standardise_out=False)
            rho = self.cca_eig(X_test, Y, Cxx=Cxx) # canonical variable matrices. Xc = X^T.W_x
            result[f] = rho
        return result
    
    @staticmethod
    def cca_eig(X, Y, Cxx=None, eps=1e-6):
        if Cxx is None:
            Cxx = np.dot(X, X.transpose()) # auto correlation matrix
        Cyy = np.dot(Y, Y.transpose()) 
        Cxy = np.dot(X, Y.transpose()) # cross correlation matrix
        Cyx = np.dot(Y, X.transpose()) # same as Cxy.T

        M1 = np.dot(np.linalg.inv(Cxx+eps), Cxy) # intermediate result
        M2 = np.dot(np.linalg.inv(Cyy+eps), Cyx)

        lam, _ = solve_eig_qr(np.dot(M1, M2), 20)
        return np.sqrt(lam)
    
class UnivariateMsetCCA():
    """
    Multiset CCA algorithm for SSVEP decoding.
    
    Computes optimised reference signal set based on historical observations
    and uses ordinary CCA for final correlation computation given a new test
    signal.
    
    Note: this is a 1 channel implementation (Nc=1)
    """
    
    def __init__(self):
        self.Ns, self.Nt = None, None
        
    def fit(self, X, compress_ref=True):
        """
        Expects a training matrix X of shape Nt x Ns. If `compress_ref=True`, the `Nt` components
        in optimised reference signal Y will be averaged to form a single reference vector. This
        can be used for memory optimisation but will likely degrade performance slightly.         
        """
        if X.shape[0] > X.shape[1]:
            print("Warning: received more trials than samples. This is unusual behaviour: check orientation of X"
                 )
        R = np.dot(X, X.transpose()) # inter trial covariance matrix
        S = np.eye(len(R))*np.diag(R) # intra-trial diag covariance matrix

        lam, V = solve_gen_eig_prob((R-S), S) # solve generalised eig problem
        w = V[:, np.argmax(lam)] # find eigenvector corresp to largest eigenvalue
        Y = np.array([x*w[i] for i, x in enumerate(X)]) # store optimised reference vector Nt x Ns
        self.Y  = Y
        if compress_ref:
            self.Y = np.mean(Y, axis=0).reshape((1, max(Y.shape))) # this will average Nt components in Y: Nc x Nt -> 1 x Nt
        
    def compute_corr(self, X_test):
        if self.Y is None:
            raise ValueError("Reference matrix Y must be computed using `fit` before computing corr")
        if len(X_test.shape) == 1:
            X_test = X_test.reshape((1, len(X_test)))
            
        print(X_test.shape, self.Y.shape)
        return CCA.cca_eig(X_test, self.Y)[0]
          

def harmonic_reference(f0, fs, Ns, Nh=1, standardise_out=False):
    
    '''
    Generate reference signals for canonical correlation analysis (CCA)
    -based steady-state visual evoked potentials (SSVEPs) detection [1, 2].
    function [ y_ref ] = cca_reference(listFreq, fs,  Ns, Nh) 
    Input:
      f0        : stimulus frequency
      fs              : Sampling frequency
      Ns              : # of samples in trial
      Nh          : # of harmonics
    Output:
      y_ref           : Generated reference signals with shape (Nf, Ns, 2*Nh)
    '''  
    X = np.zeros((Nh*2, Ns))
    
    for harm_i in range(Nh):
        # Sin and Cos
        X[2*harm_i, :] = np.sin(np.arange(1,Ns+1)*(1/fs)*2*np.pi*(harm_i+1)*f0)
        gc.collect()
        X[2*harm_i+1, :] = np.cos(np.arange(1,Ns+1)*(1/fs)*2*np.pi*(harm_i+1)*f0)
        gc.collect()

    # print(micropython.mem_info(1))
    if standardise_out: # zero mean, unit std. dev
        return standardise(X)
    return X


In [239]:
from ulab import numpy as np

data1 = np.array([])
data2 = np.array([1,1,1])
data4 = np.array([1,1,1])

data3 = np.concatenate((data1,data2,data4))

print(data1)
print(data2)
print(data3)
data3 = data3.reshape((2,3))
print(data3)

array([], dtype=float32)
array([1.0, 1.0, 1.0], dtype=float32)
array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], dtype=float32)
array([[1.0, 1.0, 1.0],
       [1.0, 1.0, 1.0]], dtype=float32)


# For testing on Notebook

### Get the data for calibration

In [294]:
from ulab import numpy as np
import utime as time
from lib.runner import Runner

decode_period_s = 4 # decode every x seconds
number_of_calibrations = 4

runner = Runner() # initialise a base runner
runner.setup()

def calibration(number_of_calibrations, decode_period_s):

    global runner
    
    data = np.array([])
    runner.run() #depending on implementation, if already running then not needed, however may want to manually call this
    time.sleep(5)
    count=0

    while count < number_of_calibrations:
        time.sleep(decode_period_s)
        data = np.concatenate((data,np.array(np.array(runner.output_buffer))))
        count+=1
        gc.collect()
    runner.stop()

    print(data.shape)

    data = data.reshape((number_of_calibrations,256))

    print(data.shape)
    gc.collect()
    return data

# def preprocess_data(signal):
    
#     """Preprocess incoming signal before decoding algorithms.
#     This involves applying a bandpass filter to isolate the target SSVEP range
#     and then downsampling the signal to the Nyquist boundary.
    
#     Returns:
#         [np.ndarray]: filtered and downsampled signal
#     """
#     from lib.signal import sos_filter
#     downsample_freq = 64
#     ds_factor = 256//downsample_freq
#     return sos_filter(signal)[::ds_factor]





ADC initialised
SPI initialised
DigiPot set to 100 = gain of 10.62498


### Instantiate the MsetCCA Decoder, calibrate

In [295]:
mset7hz = UnivariateMsetCCA()
data7 = calibration(number_of_calibrations,decode_period_s)
gc.collect()
mset7hz.fit(data7, False)

....(1024,)
(4, 256)


In [296]:
mset10hz = UnivariateMsetCCA()
data10 = calibration(number_of_calibrations,decode_period_s)
gc.collect()
mset10hz.fit(data10, False)

....(1024,)
(4, 256)


In [167]:
print(mset10hz.compute_corr(data10))

1.0


In [207]:
mset12hz = UnivariateMsetCCA()
data12 = calibration(number_of_calibrations,decode_period_s)
mset12hz.fit(data12, False)

....(1024,)
(4, 256)


In [209]:
import json

data = {"12hz": data12list}

with open('data12hz.json', 'w') as jsonfile:
    json.dump(data, jsonfile)

In [176]:
print(mset7hz.compute_corr(data7))

0.9999975


In [247]:
print(mset7hz.Y.shape)
runner.run()


(2, 256)


In [250]:
print(np.array(runner.output_buffer).shape)

(256,)


In [251]:
runner.stop()

In [291]:
try:
    while True:
        runner.run()
        time.sleep(decode_period_s)
        signals = np.array(runner.output_buffer)
        seven = mset7hz.compute_corr(signals)
        ten = mset10hz.compute_corr(signals)
        if seven > ten:
            print("LEFT",seven, ten)
        else: 
            print("UP",seven, ten)
#         print("12:",mset12hz.compute_corr(signals))
        gc.collect()
except KeyboardInterrupt:
    runner.stop()
    print('received SIGINT - stopping')

(1, 256) (4, 256)
(1, 256) (4, 256)
LEFT 0.2957219 0.1061028
.Traceback (most recent call last):
  File "lib/runner.py", line 95, in decode
  File "lib/decoding.py", line 20, in compute_corr
  File "lib/decoding.py", line 57, in harmonic_reference
MemoryError: memory allocation failed, allocating 4096 bytes
(1, 256) (4, 256)
(1, 256) (4, 256)
UP 0.1349685 0.2012595
.(1, 256) (4, 256)
(1, 256) (4, 256)
UP 0.1371359 0.1429319
.(1, 256) (4, 256)
(1, 256) (4, 256)
UP 0.1037697 0.1103653
(1, 256) (4, 256)
(1, 256) (4, 256)
UP 0.1556868 0.1624804
.(1, 256) (4, 256)
(1, 256) (4, 256)
LEFT 0.1528386 0.114336
.(1, 256) (4, 256)
(1, 256) (4, 256)
UP 0.2112024 0.2341473
.(1, 256) (4, 256)
(1, 256) (4, 256)
LEFT 0.2002512 0.09901379
.(1, 256) (4, 256)
(1, 256) (4, 256)
UP 0.1278256 0.2234941
(1, 256) (4, 256)
(1, 256) (4, 256)
UP 0.1290136 0.1468532
.(1, 256) (4, 256)
(1, 256) (4, 256)
LEFT 0.1799493 0.1408222
.(1, 256) (4, 256)
(1, 256) (4, 256)
LEFT 0.1529382 0.1180655
Traceback (most recent cal

In [114]:
test = [0.1242559, 11.46469, 74.65332, 103.8912, -34.06412, -60.70904, 14.68118, -53.45461, 8.520441, 5.189057, -52.20088, 18.83386, -39.87533, -63.34948, -62.88006, -7.184746, 7.670803, -55.21675, -6.25758, -4.994353, 0.5614357, -31.81414, -39.97784, -13.29278, -29.46731, -27.87234, -32.13979, -1.18476, 15.23, 28.73038, 5.43701, -17.22902, -13.46323, 34.19746, 56.66936, -42.65666, -11.44159, 47.63587, 23.49923, 33.89868, 36.91066, 53.96599, 34.08094, 10.92797, 15.98665, -6.482882, 0.408886, 5.840509, 18.16956, 10.95813, -15.54138, 34.28454, 13.00534, -46.4482, -15.40177, -1.627297, 5.766224, 8.314998, -1.119732, 10.07932, 3.315748, 0.03746986, -1.078606, -27.79785, 0.01785966, 5.351636, 50.84169, 101.661, 0.7994409, -71.67712, 27.90154, -21.18307, -16.46765, 116.172, 18.59735, -76.45651, 2.352575, 37.39071, -17.10242, 5.487722, 29.96381, 22.48346, -15.15711, -54.20102, 10.04527, 2.682909, -15.34871, -5.530307, 0.474077, -0.7739229, -13.96313, -3.159647, -41.81649, -59.99507, -53.90737, -17.03086, -34.74346, -53.52586, 12.70059, 29.6495, 19.60801, 0.9641957, 32.63427, 17.31091, -9.930633, 7.225876, -0.9699779, 44.28185, 46.07536, 33.65012, 30.92655, -0.7912481, 26.36427, 51.49652, -5.184861, -37.20837, 20.43198, 42.41333, 3.472577, -17.00937, 8.404181, 9.773996, -18.75221, -23.85296, -67.33251, -98.90683, -43.48698, 34.19974, -0.1304363, -10.77824, -56.37276, -55.54525, 38.63149, 0.2614212, -60.9148, 11.9404, -26.60442, -14.86391, 44.23504, 15.14816, -14.136, -48.55381, 8.350963, -28.15352, -73.79556, -17.79183, 9.946865, 8.26569, -45.27136, 3.866302, 18.80053, -11.55937, -21.47431, -42.30291, -9.220036, -29.43005, -26.04054, 1.93306, 15.30578, 22.80381, 39.88576, 55.90287, 27.19797, 45.74711, 67.56344, 13.89118, -12.04704, 74.86397, 66.64375, 37.44535, 36.4254, -9.222485, 27.44955, -9.426237, -9.804419, -4.118249, -20.63764, 2.659115, -5.301051, 32.51813, -19.36118, -28.64934, 2.767703, -30.41734, -22.98153, 30.98759, 3.012625, -50.10794, 5.20744, -16.45278, 28.40638, 57.91431, -0.1297985, -11.91543, -74.49046, -123.9902, -38.90279, 33.8967, 15.32243, -21.15151, -52.58987, -2.684384, -13.47912, 22.47922, 122.6188, -10.70152, -74.81825, 42.68987, 25.29355, -20.02717, 3.962679, 10.90367, 28.50235, -35.6212, -50.40169, 11.51803, 3.826478, 2.460785, -9.174918, 18.92617, 5.485937, 18.22744, 44.68678, -40.20915, -54.41223, 7.552943, 58.81316, -6.156619, 1.746366, 39.31748, -17.60847, -19.35892, -16.52242, 11.28899, -4.493011, -8.112165, -10.51986, 1.142932, 40.35253, 28.09395, 34.61413, 1.831679, -2.851778, 34.99657, 59.66302, -9.836491, -34.39755, 52.23692, 24.66334, 13.21892, 6.807874, -12.62385, -4.796031, -30.46659, -29.79107, -28.90603]

In [145]:
print(data10)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name 'data10' isn't defined


In [198]:
import json

data = {"7hz": data7list(),
        "10hz": data10list(),
        "12hz": data12list()}

with open('data1.json', 'w') as jsonfile:
    json.dump(data, jsonfile)

Traceback (most recent call last):
  File "<stdin>", line 3, in <module>
AttributeError: 'ndarray' object has no attribute 'tolist'


In [143]:
%ls

Listing directory '/'.
       55    .env
      666    data.json
             lib/


In [180]:
print(data7list)

[leftinbuffer] ['>']
[leftinbuffer] ['MicroPython v1.16-141-g224ac355c-dirty on 2021-07-28; ESP32 module with ESP32']
[leftinbuffer] ['Type "help()" for more information.']
[leftinbuffer] ['>>> ']
[leftinbuffer] ['Brownout detector was triggered']
[leftinbuffer] ['ets Jul 29 2019 12:21:46']
[leftinbuffer] ['rst:0xc (SW_CPU_RESET),boot:0x13 (SPI_FAST_FLASH_BOOT)']
[leftinbuffer] ['configsip: 0, SPIWP:0xee']
[leftinbuffer] ['clk_drv:0x00,q_drv:0x00,d_drv:0x00,cs0_drv:0x00,hd_drv:0x00,wp_drv:0x00']
[leftinbuffer] ['mode:DIO, clock div:2']
[leftinbuffer] ['load:0x3fff0018,len:4']
[leftinbuffer] ['load:0x3fff001c,len:5204']
[leftinbuffer] ['load:0x40078000,len:12136']
[leftinbuffer] ['load:0x40080400,len:3496']
[leftinbuffer] ['entry 0x4008063c']
[leftinbuffer] ['MicroPython v1.16-141-g224ac355c-dirty on 2021-07-28; ESP32 module with ESP32']
[leftinbuffer] ['Type "help()" for more information.']
[leftinbuffer] ['>>> ']

[missing-OK]print(data7list)

[missing-OK]Traceback (most recent call l