In [None]:
!pip install pyserial
!pip install torch
!pip install torchvision

In [None]:
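# Instructions: run the cells in order. Set serial_port (in the cell that opens the
# port) to your Arduino's port, chosen from the list printed by list_all_ports().
# hiddenlayermodel.pkl must be in the working directory. The last two cells read the
# serial port forever; interrupt the kernel to stop them.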

In [1]:
import serial
import csv
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

In [None]:
import serial

import serial.tools.list_ports

def list_all_ports():
    """Print each serial port pyserial can detect, with its description.

    Use the printed device names to pick the port to open.
    Returns None; output goes to stdout only.
    """
    available = serial.tools.list_ports.comports()
    if not available:
        print("No serial ports found.")
        return
    print("Available serial ports:")
    for info in available:
        print(f"Port: {info.device}, Description: {info.description}")


list_all_ports()

In [None]:
# Serial connection settings: set serial_port to the device list_all_ports() reported
# for your board (e.g. 'COM6' on Windows).
serial_port = 'COM6'
baud_rate = 9600  # must match Serial.begin(baud_rate) in the Arduino sketch

# If this cell is re-run, the previous `ser` still holds the port and opening it again
# typically fails (PermissionError / "Access is denied" on Windows) -- release it first.
if 'ser' in globals() and ser.is_open:
    ser.close()

ser = serial.Serial(serial_port, baud_rate)

In [None]:
# Send a single newline byte to the Arduino.
# NOTE(review): presumably the sketch waits for input before it starts streaming --
# confirm against the Arduino code.
ser.write(b'\n')

In [None]:
def processLine(line):
    """Parse one comma-separated sensor line into a 6-value feature row.

    Only the first four comma-separated fields are used (extra fields are
    ignored). Each is converted to int, and two derived features are appended:
    the difference of fields 0 and 1, and of fields 2 and 3.

    Args:
        line: a decoded, stripped line read from the serial port, e.g. "12,30,7,9".

    Returns:
        ``[a, b, c, d, a - b, c - d]``, or ``None`` when the line has fewer than
        four fields, an empty field among the first four, or a field that is not
        an integer.
    """
    fields = line.split(',')[:4]
    # Fewer than four fields or an empty one (e.g. "" or "1,,3,4") means a
    # partial/garbled line -- skip it rather than guess.
    if len(fields) != 4 or '' in fields:
        return None
    try:
        a, b, c, d = map(int, fields)
    except ValueError:
        # Non-numeric field (serial noise): report "unparseable" instead of
        # raising, so one bad line cannot crash the reading loop.
        return None
    return [a, b, c, d, a - b, c - d]

class SimpleNetWithHiddenLayer(nn.Module):
    """Two-layer MLP: 90 inputs -> 5 hidden units (ReLU) -> 5 class scores.

    The input is a window of 15 rows x 6 features; ``forward`` flattens it with
    ``view(-1, 90)`` so the result has shape (N, 5).

    Keep the attribute names ``hidden``, ``relu`` and ``output`` unchanged:
    pickled instances are restored with these attributes (without running
    ``__init__``), and ``forward`` looks them up by name.
    """

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(15 * 6, 5)  # 90 -> 5 hidden units
        self.relu = nn.ReLU()
        self.output = nn.Linear(5, 5)       # 5 hidden -> 5 class scores

    def forward(self, x):
        flat = x.view(-1, 15 * 6)
        return self.output(self.relu(self.hidden(flat)))
class SimpleNet(nn.Module):
    """Single linear layer (no hidden layer): 90 inputs -> 4 class scores.

    The input is a window of 15 rows x 6 features, flattened to shape (N, 90).
    Note it has 4 outputs, whereas SimpleNetWithHiddenLayer has 5.

    Keep the attribute name ``linear`` unchanged: pickled instances are restored
    with it and ``forward`` looks it up by name.
    """

    def __init__(self):
        super().__init__()
        # 6 * 15 = 90 inputs, 4 outputs.
        self.linear = nn.Linear(6 * 15, 4)

    def forward(self, x):
        # Flatten each (15, 6) window into one 90-vector -> shape (N, 90).
        x = x.view(-1, 6 * 15)
        return self.linear(x)


In [None]:
# Load the trained network from disk.
# WARNING: pickle.load can execute arbitrary code -- only load model files you trust.
# Unpickling needs the model class (defined in the cell above) to exist in this session.
with open('hiddenlayermodel.pkl', mode='rb') as model_file:
    model = pickle.load(model_file)

# Switch to inference mode before running predictions.
model.eval()

In [None]:
def gesture(tensor):
    """Return the display label for a predicted class index.

    ``tensor`` may be a plain int or a one-element tensor (such as the index
    returned by ``torch.max(output, 1)``). It is compared with ``==`` against
    0..4 in order. Any other value yields ``None``.
    """
    labels = (
        "Neural               UP",     # 0
        "Neural               LEFT",   # 1
        "Neural               DOWN",   # 2
        "Neural               RIGHT",  # 3
        "Neural            nothing",   # 4
    )
    for index, label in enumerate(labels):
        if tensor == index:
            return label
    return None

In [None]:
# Live gesture recognition. Reads the serial port forever (interrupt the kernel to stop):
#  * lines longer than 20 characters are messages from the Arduino gesture library
#    and are echoed unchanged;
#  * shorter lines are parsed by processLine() into 6-value rows and pushed onto a
#    sliding window holding the last WINDOW_SIZE rows;
#  * once the window is full, the model scores it after every new row. If the top
#    score exceeds SCORE_THRESHOLD, the gesture and score are printed and the window
#    is emptied, so the next prediction needs WINDOW_SIZE fresh rows.
WINDOW_SIZE = 15     # must match the model input: 15 rows x 6 features
SCORE_THRESHOLD = 8  # compared with the raw model output (a logit), not a probability

data = []
while True:
    # errors="replace": the first bytes read after opening a port are often a partial
    # line that is not valid UTF-8; that must not stop the loop.
    line = ser.readline().decode("utf-8", errors="replace").strip()
    if len(line) > 20:
        print(line)
        continue

    try:
        row = processLine(line)
    except ValueError:
        continue  # non-numeric field (corrupted line) -- skip it
    if row is None or len(row) != 6:
        continue  # the model expects exactly 6 features per row

    if len(data) >= WINDOW_SIZE:
        del data[0]
    data.append(row)

    if len(data) == WINDOW_SIZE:
        # Maps 0..255 to -1..1; presumably matches the training preprocessing -- confirm.
        processed = (np.array(data) / 255.0 - 0.5) / 0.5
        # Shape (1, 1, 15, 6); the model flattens it to (1, 90).
        tensor = torch.tensor(processed, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            output = model(tensor)
        ma, predicted_label = torch.max(output, 1)
        score = float(ma)
        if score > SCORE_THRESHOLD:
            print(gesture(predicted_label) + "                 confidence: " + str(score))
            data = []

In [None]:
# Echo only the gesture messages from the Arduino gesture library (lines longer than
# 20 characters); sensor-data lines are ignored. Runs until the kernel is interrupted.
while True:
    # errors="replace" so a partial/garbled line cannot raise UnicodeDecodeError.
    line = ser.readline().decode("utf-8", errors="replace").strip()
    if len(line) > 20:
        print(line)