In [1]:
import serial  # 引用pySerial模組
import time
import numpy as np
from time import sleep
from keyboard import is_pressed, read_key
import socketio

# --- Serial link to the Arduino ---
COM_PORT = 'COM3'     # serial port name
BAUD_RATE = 115200    # transmission rate (baud)
CHANNEL_NUMBER = 4    # number of channels per sample (presumably one value per channel on each line)

# --- Offline playback (replaces the Arduino with a recording) ---
NOARDUINO = False                  # True: replay `record_path` instead of opening COM_PORT
record_path = './data/lick2.npy'   # recording replayed when NOARDUINO is True
SAMPLE_TIME = 1 / 500              # seconds slept between replayed samples (500 per second)

# --- Socket.IO server that receives the forwarded signal ---
server_url = 'http://localhost:3000'

In [24]:
class reciever():
    """Forward lines from an Arduino serial port (or a recorded `loader`) to a Socket.IO server.

    Every line read is decoded and emitted as an ``"ArduinoSignal"`` event, then echoed to stdout.
    """

    def __init__(self, url, port, baud_rate, record = None) -> None:
        """Open the data source and connect to the Socket.IO server.

        Args:
            url: Socket.IO server URL, e.g. ``'http://localhost:3000'``.
            port: serial port name, e.g. ``'COM3'``; unused when `record` is a `loader`.
            baud_rate: serial baud rate; unused when `record` is a `loader`.
            record: optional `loader` used to replay a recording instead of opening `port`.

        Raises:
            Exception: "Maybe server is not online." when connecting to `url` fails
                (the underlying socketio error is chained as ``__cause__``).
        """
        if isinstance(record, loader):
            self._serial = record
        else:
            self._serial = serial.Serial(port, baud_rate)  # open the serial port
        self._sio = socketio.Client()
        # connect() itself raises when the server is unreachable, so the
        # friendly message has to be produced here, around the call.
        try:
            self._sio.connect(url)
        except socketio.exceptions.ConnectionError as err:
            self._serial.close()  # don't keep the port locked after a failed start
            raise Exception("Maybe server is not online.") from err
        print("Connected to Socket server.")

    def run(self):
        """Stream lines to the server until a KeyboardInterrupt (Ctrl-C / interrupt kernel).

        The data source and the Socket.IO connection are released however the loop ends.
        """
        try:
            while True:
                while self._serial.in_waiting:          # data available
                    data_raw = self._serial.readline()  # read one line
                    # Decode as UTF-8; undecodable bytes become surrogates instead of raising.
                    rcv = data_raw.decode(errors='surrogateescape').rstrip()
                    try:
                        self._sio.emit("ArduinoSignal", rcv)
                        print(rcv)
                    except Exception as err:
                        # Best effort: one failed emit must not stop the stream, but it
                        # is reported. KeyboardInterrupt is not an Exception subclass, so
                        # it propagates to the outer handler and ends run() immediately.
                        print("Emit failed: {}".format(err))
                    finally:
                        if isinstance(self._serial, loader):
                            sleep(SAMPLE_TIME)  # pace replayed samples
        except KeyboardInterrupt:
            pass
        finally:
            self._serial.close()    # release the serial port
            self._sio.disconnect()
            print('Serial disconnected.')
            
class recorder():
    """Record serial samples labelled with the WASD key held at the time.

    Each comma-separated line from the serial port becomes one row of ``X``. The key
    held when that line is processed becomes the matching entry of ``y`` (see
    ``keyclass``). When recording stops, the arrays are saved as
    ``<record_path>/<record_name>_X.npy`` and ``..._y.npy``.
    """

    # Keys checked in priority order, with the label each one maps to in `keyclass`.
    _KEY_LABELS = (('w', 1), ('s', 2), ('a', 3), ('d', 4))

    def __init__(self, port, baud_rate, record_name, record_path = './data') -> None:
        """
        Args:
            port: serial port name, e.g. ``'COM3'``.
            baud_rate: serial baud rate.
            record_name: file-name prefix for the saved arrays.
            record_path: directory the arrays are saved into.
        """
        self._serial = serial.Serial(port, baud_rate)  # open the serial port
        self.X = None   # float32 array with one row per sample, filled in when run() ends
        self.y = []     # one int label per row of X
        self._record_path = record_path + '/' + record_name
        self.keyclass = {0:'undefined action', 1:'up', 2:'down', 3:'left', 4:'right'}

    def _current_label(self):
        """Return the label of the first WASD key currently held, or 0 if none is."""
        for key, label in self._KEY_LABELS:
            if is_pressed(key):
                return label
        return 0

    def _save(self):
        """Write X and y to ``<path>_X.npy`` / ``<path>_y.npy``."""
        # Save an empty numeric array rather than a pickled `None` when nothing was recorded.
        X = self.X if self.X is not None else np.empty((0, CHANNEL_NUMBER), dtype=np.float32)
        np.save(self._record_path + '_X', X)
        print("Saved file {}.".format(self._record_path + '_X'))
        np.save(self._record_path + '_y', np.array(self.y).astype(np.int32))
        print("Saved file {}.".format(self._record_path + '_y'))

    def run(self):
        """Record until a KeyboardInterrupt (Ctrl-C / interrupt kernel), then save and close.

        Rows are buffered in a list and stacked once at the end, because calling
        np.concatenate per sample would copy the whole recording every time. A
        line is skipped if it does not parse as floats, has fewer than
        CHANNEL_NUMBER values, or differs in length from the first kept row.
        A label is appended only together with its row, so X and y stay aligned.
        """
        rows = [] if self.X is None else list(self.X)
        width = len(rows[0]) if rows else None
        try:
            while True:
                while self._serial.in_waiting:          # data available
                    data_raw = self._serial.readline()  # read one line
                    # Decode as UTF-8; undecodable bytes become surrogates instead of raising.
                    rcv = data_raw.decode(errors='surrogateescape').rstrip()
                    try:
                        data = np.array(rcv.split(',')).astype(np.float32)
                    except ValueError:
                        continue  # partial or garbled line
                    if data.shape[0] < CHANNEL_NUMBER:
                        continue
                    if width is None:
                        width = data.shape[0]
                    elif data.shape[0] != width:
                        continue  # would make X ragged
                    label = self._current_label()
                    rows.append(data)
                    self.y.append(label)
                    print(self.keyclass[label].ljust(len('undefined action')), end='\r')
        except KeyboardInterrupt:
            pass
        finally:
            # Save on any exit (not only Ctrl-C) so an unplugged device doesn't lose data.
            try:
                if rows:
                    self.X = np.stack(rows)
                self._save()
            finally:
                self._serial.close()    # release the serial port
                print('Serial disconnected.')
            
class loader():
    """Replay a recorded ``.npy`` array as if it were a serial port.

    Provides the subset of the ``serial.Serial`` interface that ``reciever`` uses:
    ``in_waiting``, ``readline()`` and ``close()``. Each ``readline()`` returns the
    next row as a UTF-8 encoded, comma-separated line. After the last row it starts
    again from the first.
    """

    def __init__(self, path) -> None:
        """
        Args:
            path: ``.npy`` file holding a 2-D array with one sample per row.

        Raises:
            ValueError: if the recording contains no rows.
        """
        self.data = np.load(path)
        if len(self.data) == 0:
            raise ValueError("Recording {} contains no samples.".format(path))
        self.in_waiting = True  # always reports data: playback loops forever
        self.__i = 0            # index of the next row to return

    def readline(self):
        """Return the next recorded row as ``b'v0,v1,...'``, wrapping around at the end."""
        row = self.data[self.__i, :]
        self.__i = (self.__i + 1) % len(self.data)
        return ",".join(row.astype(np.str_).tolist()).encode(encoding='UTF-8')

    def close(self):
        """No-op; present so a ``loader`` can stand in for ``serial.Serial``."""
        pass

In [None]:
# Forward live Arduino data to the server, or replay `record_path` when NOARDUINO is set.
arduino = reciever(server_url, COM_PORT, BAUD_RATE,
                   record=loader(record_path) if NOARDUINO else None)
arduino.run()

In [27]:
# Record a labelled session; interrupt (Ctrl-C / interrupt kernel) to stop and save.
record_name = 'test'
arduino = recorder(port=COM_PORT, baud_rate=BAUD_RATE, record_name=record_name)
arduino.run()

Saved file ./data/test_X.
Saved file ./data/test_y.
Serial disconnected.


In [19]:
# Reload the arrays written by the recording cell: samples (X) and key labels (y).
a, b = np.load('./data/test_X.npy'), np.load('./data/test_y.npy')