In [1]:
from copy import deepcopy
from threading import Event, Lock
from time import time

import numpy as np
import socketio
import tensorflow as tf
from tensorflow.keras.backend import clear_session
# Enable Keras mixed precision globally before any model is built or loaded.
tf.keras.mixed_precision.set_global_policy(policy='mixed_float16')

# --- signal / request geometry ---
CHANNEL_NUMBER = 3          # number of input channels
SAMPLE_TIME = 1 / 500       # presumably seconds per sample (500 Hz) — TODO confirm
WINDOW_SIZE = 100           # rows per inference window
NUM_IMF = 3                 # components per channel; model input width is CHANNEL_NUMBER * NUM_IMF
ID_LEN = 6                  # client id length; only used to pad the status line

# --- socket.io event names ---
RECEIVE_CHANNEL = 'inference'
EMIT_CHANNEL = 'inferenceResult'

# --- classification ---
CLASS_NUMBER = 5            # presumably the number of model outputs (labels 1..5 via argmax + 1)
BELIEF_THRESHOLD = 0.9      # top score must exceed this, otherwise label 0 is reported
KEY_CLASS = dict(enumerate(['undefined action', 'up', 'down', 'left', 'right', 'quick touch']))
MAXCHARLEN = max(map(len, KEY_CLASS.values()))  # longest label, for status-line padding

# --- locations ---
model_path = './model/LickingPark0222'
server_url = 'http://localhost:3000'

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA GeForce RTX 2080 Ti, compute capability 7.5


In [8]:
class inference():
    """Socket.IO client that classifies incoming signal windows with a Keras model.

    Requests arriving on RECEIVE_CHANNEL are read as dicts with 'uid', 'data'
    and 'serial_num' keys. 'data' is a comma-separated string of numbers that
    is reshaped to (WINDOW_SIZE, CHANNEL_NUMBER * NUM_IMF). For clients in the
    whitelist (maintained by 'whiteList' / 'rmWhiteList' server events), the
    predicted KEY_CLASS label is emitted on EMIT_CHANNEL as
    {'uid': ..., 'action': ...}. Only the most recent unprocessed request is
    kept; a newer request overwrites an older one that was not handled yet.
    """

    def __init__(self, url, modelPath) -> None:
        """Connect to the Socket.IO server at `url` and load the model from `modelPath`.

        Raises:
            Exception: if connecting to the server fails (the original error is chained).
        """
        # Handler state must exist before connecting: the server may push
        # whitelist events as soon as the connection is established.
        self.__lock = Lock()          # guards self.__req and self.__new_req
        self.__new_req = Event()      # set while self.__req holds an unprocessed request
        self.__white_list = {}        # uid -> stamp, maintained by server events
        self.__req = {}

        self.__sio = socketio.Client()
        # Registered before connect() so whitelist updates sent while the
        # model is still loading are not missed.
        self.__sio.on('whiteList', self.__newClient)
        self.__sio.on('rmWhiteList', self.__clientLeave)

        # connect() raises on failure, so the error has to be caught around it.
        try:
            self.__sio.connect(url)
        except Exception as err:
            raise Exception("Maybe server is not online.") from err
        print("Connected to Socket server.")

        try:
            self.__model = tf.keras.models.load_model(modelPath)
            # One dummy prediction so the predictor is initialized before real requests.
            self.__model.predict(np.zeros((1, WINDOW_SIZE, CHANNEL_NUMBER * NUM_IMF)), verbose = False)
        except Exception:
            # Don't leave a live connection behind if the model cannot be loaded.
            self.__sio.disconnect()
            raise

    def __receiveSignal(self, req):
        """Store a copy of the latest request, replacing any unprocessed one."""
        snapshot = deepcopy(req)
        with self.__lock:
            self.__req = snapshot
            # Set inside the lock so "event set" always means "__req is pending".
            self.__new_req.set()

    def __newClient(self, info):
        """Whitelist info['uid'], storing info['stamp'] as its value."""
        self.__white_list[info['uid']] = info['stamp']

    def __clientLeave(self, uid):
        """Remove `uid` from the whitelist; unknown ids are ignored."""
        self.__white_list.pop(uid, None)

    def _take_request(self, timeout):
        """Wait up to `timeout` seconds for a pending request and take ownership of it.

        Returns the request dict, or None when nothing arrived in time.
        """
        if not self.__new_req.wait(timeout):
            return None
        with self.__lock:
            # Swap atomically so every field comes from one request and a request
            # arriving right after this point is kept for the next iteration.
            req, self.__req = self.__req, {}
            self.__new_req.clear()
        return req or None

    @staticmethod
    def _select_class(scores, threshold):
        """Map model scores to a KEY_CLASS index.

        Returns argmax + 1 when the best score exceeds `threshold`, else 0.
        """
        best = int(np.argmax(scores))
        return best + 1 if scores[best] > threshold else 0

    def _handle_request(self, req, clock):
        """Classify one request and emit the result to its client.

        `clock` is the time() at which handling started; it is used for the status line.

        Raises:
            PermissionError: if the client is not whitelisted.
            KeyError / AttributeError / ValueError: if the request is malformed.
        """
        clientID = req['uid']
        # Explicit check instead of assert so it also holds under `python -O`.
        if clientID not in self.__white_list:
            raise PermissionError("Client {} not in whitelist.".format(clientID))
        data = np.array(req['data'].split(",")).astype(np.float32).reshape(WINDOW_SIZE, CHANNEL_NUMBER * NUM_IMF)
        ser = req['serial_num']

        res = self.__model.predict(data[np.newaxis, :], verbose = False).flatten()
        candidateIdx = self._select_class(res, BELIEF_THRESHOLD)
        self.__sio.emit(EMIT_CHANNEL, {'uid': clientID, 'action': KEY_CLASS[candidateIdx]})
        print("ID: {}-{: 5d}, Res: {}, Spend time: {:.3f}".format(clientID, ser, KEY_CLASS[candidateIdx], time() - clock).ljust(MAXCHARLEN + ID_LEN + 36), end='\r')

    def run(self):
        """Serve inference requests until interrupted (KeyboardInterrupt).

        On exit the socket is disconnected and the Keras session is cleared, exactly once.
        """
        print("Listening requests.")
        self.__sio.on(RECEIVE_CHANNEL, self.__receiveSignal)
        try:
            while True:
                # Block instead of busy-spinning; the timeout keeps the loop
                # responsive to KeyboardInterrupt while idle.
                req = self._take_request(0.1)
                if req is None:
                    continue
                clock = time()
                try:
                    self._handle_request(req, clock)
                except Exception as err:
                    # Best effort: a bad or unauthorized request must not stop the
                    # server loop, but it should not disappear silently either.
                    print("Skipped request: {!r}".format(err).ljust(MAXCHARLEN + ID_LEN + 36), end='\r')
        except KeyboardInterrupt:
            pass
        finally:
            self.__sio.disconnect()
            clear_session()

In [9]:
# Connect and load the model, then serve requests until the kernel is interrupted.
emdCNN = inference(url=server_url, modelPath=model_path)
emdCNN.run()

Connected to Socket server.
Listening requests.
ID: 5cac08-  503, Res: undefined action, Spend time: 0.060