In [10]:
import time
import numpy as np
from time import sleep
from emd.sift import sift
import socketio
import tensorflow as tf
# Select the 'mixed_float16' global Keras dtype policy. Per the TensorFlow
# mixed-precision guide this means float16 compute with float32 variables.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# --- Serial / acquisition settings ---------------------------------------
# None of the constants in this group are referenced by the cells visible here.
COM_PORT = 'COM3'        # name of the serial (COM) port
BAUD_RATE = 115200       # serial transfer rate
NO_ARDUINO = True
REAL_DATA = True
CHANNEL_NUMBER = 4
THRESHOLD = 0.5

# --- Windowing / decomposition settings ----------------------------------
SAMPLE_TIME = 1 / 500    # presumably seconds per sample (500 Hz) -- TODO confirm
WINDOW_SIZE = 300        # number of samples kept in the sliding window
NUM_IMF = 6              # IMF columns per channel requested from sift()

# --- Output labels -------------------------------------------------------
# The index of the model's largest output selects the label.
ACTIONS = [
    'undefined action',
    'right',
    'left',
]
# Width of the longest label; used to pad the label printed on the console.
MAXCHARLEN = max(map(len, ACTIONS))

# --- Locations -----------------------------------------------------------
model_path = './model/LickenPark'
server_url = 'http://localhost:3000'

In [18]:
class inference():
    """Real-time action classifier fed by a Socket.IO signal stream.

    Samples arriving on the 'SignalInput' event are kept in a sliding window
    of at most WINDOW_SIZE rows. Once the window is full, every channel is
    decomposed with ``sift`` into NUM_IMF columns, the result is passed to the
    loaded Keras model, and the winning ACTIONS label is emitted on the
    'InferenceResults' event.
    """

    def __init__(self, url, modelPath) -> None:
        """Load the model and connect to the Socket.IO server.

        Args:
            url: Server address handed to ``socketio.Client.connect``.
            modelPath: Location handed to ``tf.keras.models.load_model``.

        Raises:
            socketio.exceptions.ConnectionError: if the server cannot be
                reached; the original connection error is chained as the cause.
        """
        self._model = tf.keras.models.load_model(modelPath)
        self._sio = socketio.Client()
        # Sliding window of received samples; stays None until the first sample.
        self._container = None
        try:
            # connect() itself is what fails when the server is down, so it has
            # to live inside the try for the hint below to ever be shown.
            self._sio.connect(url)
        except socketio.exceptions.ConnectionError as err:
            raise socketio.exceptions.ConnectionError(
                "Maybe server is not online.") from err
        print("Connected to Socket server.")

    def _revieveSignal(self, raw):
        """Append one received sample to the sliding window.

        Args:
            raw: Comma-separated numeric string, one value per channel.

        The window grows until it holds WINDOW_SIZE rows; after that the
        oldest row is dropped for every new one. A new array is assigned on
        each call, so a reader that grabbed the previous array keeps a
        consistent snapshot.
        """
        sample = np.array(raw.split(",")).astype(np.float32)[np.newaxis, :]
        if not isinstance(self._container, np.ndarray):
            self._container = sample
            return
        history = self._container
        if history.shape[0] >= WINDOW_SIZE:
            history = history[1:]
        self._container = np.concatenate([history, sample], axis=0)

    def run(self):
        """Classify the newest full window in a loop until interrupted.

        Registers the 'SignalInput' handler, then repeatedly decomposes the
        current window, runs the model, prints the chosen label and emits it on
        'InferenceResults'. The loop ends on KeyboardInterrupt; the socket is
        disconnected on every exit path.
        """
        print("Starting inference.")
        self._sio.on('SignalInput', self._revieveSignal)
        lastError = None
        try:
            while True:
                # Take a single reference: the handler swaps in a new array
                # whenever a sample arrives.
                window = self._container
                if not isinstance(window, np.ndarray) or window.shape[0] != WINDOW_SIZE:
                    # Window not filled yet. Wait about one sample period
                    # instead of spinning at full speed.
                    sleep(SAMPLE_TIME)
                    continue
                try:
                    inputData = self.__emdSignal(window.copy()[np.newaxis, :])
                    res = self._model.predict(inputData,
                                              verbose = False)
                    action = ACTIONS[np.argmax(res)]
                    print(action.ljust(MAXCHARLEN), end='\r')
                    self._sio.emit('InferenceResults', action)
                except Exception as err:
                    # Best effort: one bad window must not stop the stream, but
                    # the failure should be visible. Report each distinct error
                    # once to avoid flooding the cell output.
                    description = repr(err)
                    if description != lastError:
                        print("Inference step failed: {}".format(description))
                        lastError = description
                    sleep(SAMPLE_TIME)
        except KeyboardInterrupt:
            pass
        finally:
            self._sio.disconnect()

        print("Inference finished.")

    def __emdSignal(self, sig):
        """Decompose every channel of one window into NUM_IMF columns.

        Args:
            sig: Array indexed as ``sig[0, :, c]``; the last axis is iterated
                as channels.

        Returns:
            Array with a leading axis of length 1 whose last axis holds the
            per-channel ``sift`` outputs side by side. Each channel contributes
            NUM_IMF columns, zero-padded when ``sift`` yields fewer.

        Raises:
            ValueError: if ``sig`` has no channels.
        """
        perChannel = []
        for c in range(sig.shape[-1]):
            imf = sift(sig[0, :, c], max_imfs=NUM_IMF, imf_opts={'sd_thresh': 0.1})

            missing = NUM_IMF - imf.shape[-1]
            if missing > 0:
                # Pad using the IMF's own row count so any window length works.
                padding = np.zeros((imf.shape[0], missing))
                imf = np.concatenate([imf, padding], axis = 1)

            perChannel.append(imf)

        return np.concatenate(perChannel, axis = 1)[np.newaxis, :]

In [None]:
# Build the classifier (loads the model and connects to the server), then
# stay in the inference loop until the kernel is interrupted.
emdCNN = inference(url=server_url, modelPath=model_path)
emdCNN.run()