### Instructions
1. Make sure the `lb.pkl` and `model_gray.pth` files are in the same directory as this Jupyter notebook.
2. Make sure you have changed the `IP_ADDRESS` variable to the address of the machine running the MQTT broker.
3. Make sure your Mosquitto broker is running.
4. Ideally, you should see "Connected." printed when running the 2nd block of code.

In [1]:
import paho.mqtt.client as mqtt
import numpy as np
from PIL import Image
import json
from os import listdir
from os.path import join
import os
import cv2
import torch
import joblib
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import time
import requests

# Phone
# Address passed to setup() as the MQTT broker host.
IP_ADDRESS = "192.168.43.172" # Laptop IP - obtained through ipconfig in cmd prompt

# Broker credentials are read from the environment; os.getenv returns None for
# a variable that is not set.
USERNAME = os.getenv('_USERNAME')
PASSWORD = os.getenv('_PASSWORD')

# State shared between successive on_message calls (rebound there via `global`):
#   alphabet - label predicted for the most recently processed frame
#   repeat   - number of consecutive frames that confirmed that label
# A `global` statement only has an effect inside a function, so none is
# needed at module level.
alphabet = ""
repeat = 0

### --------------------------Model from Aryan---------------------------------
# Load the label binarizer; lb.classes_ maps a class index to its label.
# SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
# Only load an lb.pkl that comes from a trusted source.
lb = joblib.load('lb.pkl')

class ASTCNN(nn.Module):
    """Five-stage convolutional classifier for single-channel images.

    Every stage is convolution -> ReLU -> 2x2 max-pool. The final feature map
    is averaged down to one value per channel and fed through two fully
    connected layers that produce one raw score (logit) per class.

    Args:
        num_classes: Width of the output layer. When omitted it defaults to
            ``len(lb.classes_)`` from the module-level label binarizer, which
            is the size the constructor previously always used.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Resolved here rather than in the signature so the global is only
            # needed when the caller does not supply a size.
            num_classes = len(lb.classes_)

        # The attribute names below are the keys of a saved state_dict, so
        # they must stay exactly as they are for existing checkpoints to load.
        self.conv1 = nn.Conv2d(1, 16, 5)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.conv3 = nn.Conv2d(32, 64, 3)
        self.conv4 = nn.Conv2d(64, 128, 3)
        self.conv5 = nn.Conv2d(128, 256, 3)

        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, num_classes)

        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return the class logits for a batch of images.

        Args:
            x: 4-D tensor whose channel dimension is 1 (conv1 takes one input
                channel); presumably laid out as (batch, 1, H, W).

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalised scores.
        """
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = self.pool(F.relu(conv(x)))

        # Global average pooling collapses each feature map to one value, so
        # the linear layers always see 256 features per sample.
        batch_size, _, _, _ = x.shape
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def setup(hostname):
    """Create an MQTT client, connect it to the broker and start its loop.

    The client gets on_connect / on_message as callbacks and authenticates
    with the module-level USERNAME / PASSWORD. Only the host is passed to
    connect(), so the library's default port and keep-alive apply.

    Args:
        hostname: Host name or IP address of the MQTT broker.

    Returns:
        The paho MQTT client, with its background network loop started.
    """
    # paho-mqtt 2.x added a callback API version argument to Client. The
    # callbacks in this file use the version-1 signatures, so ask for that
    # version explicitly when the installed library provides the enum;
    # paho-mqtt 1.x has no such enum and takes no such argument.
    api_versions = getattr(mqtt, "CallbackAPIVersion", None)
    if api_versions is None:
        client = mqtt.Client()
    else:
        client = mqtt.Client(api_versions.VERSION1)

    client.on_connect = on_connect
    client.on_message = on_message
    client.username_pw_set(USERNAME, PASSWORD)
    client.connect(hostname)
    # loop_start() services the connection on a background thread, so this
    # function returns immediately.
    client.loop_start()
    return client

def on_connect(client, userdata, flags, rc):
    """Connection callback: subscribe to the input topic once connected.

    Args:
        client: MQTT client the callback fired for; used to subscribe.
        userdata: Unused.
        flags: Unused.
        rc: Connection result code; 0 means the connection was accepted.
    """
    if rc != 0:
        # Any non-zero code is reported and no subscription is made.
        print("Failed to Connect. Error code: %d." % rc)
        return

    print("Connected.")
    client.subscribe("Group_25/CNN/Input")
        
def hand_area(img):
    """Return the fixed hand window of ``img`` as a 224x224 image.

    The window covers rows and columns 100..323 of the input; the crop is
    then passed through cv2.resize with a (224, 224) target size.
    """
    window = slice(100, 324)
    return cv2.resize(img[window, window], (224, 224))

def on_message(client, userdata, msg):
    """Classify the hand sign in one received frame and publish stable results.

    The payload is interpreted as a raw 400x400 single-channel uint8 image.
    The hand window is cropped out, run through the global ``model``, and the
    top label is compared with the previous frame's label. The ``repeat``
    counter goes up while the same label (other than 'nothing') is predicted
    with probability above 0.5; any other frame stores the new label and
    resets the counter. Once the counter exceeds 5, the lower-cased label is
    published to "Group_25/NLP/Input" and the counter is reset. The annotated
    frame is shown in an OpenCV window.

    Args:
        client: MQTT client the message arrived on; used to publish.
        userdata: Unused.
        msg: Received message; only ``msg.payload`` is read.
    """
    global alphabet, repeat

    frame_height, frame_width = 400, 400
    expected_bytes = frame_height * frame_width
    if len(msg.payload) != expected_bytes:
        # A payload of any other size would make reshape() raise inside the
        # MQTT callback, so a malformed message is dropped instead.
        print("Ignoring frame: expected %d bytes, got %d."
              % (expected_bytes, len(msg.payload)))
        return

    # Receiving Frame here for visualization
    # Ideally, we should just receive the hand part.
    # np.frombuffer returns a read-only view over the payload bytes; copy it
    # so the in-place OpenCV drawing calls below get a writable array.
    frame = np.frombuffer(msg.payload, dtype=np.uint8)
    frame = frame.reshape(frame_height, frame_width).copy()

    # NOTE(review): the box is drawn before the crop, so its border may end up
    # in the pixels fed to the model. Order kept as in the original - confirm
    # whether the model expects it.
    cv2.rectangle(frame, (100, 100), (324, 324), (20, 34, 255), 2)
    image = hand_area(frame)

    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W).
    image = torch.tensor(image, dtype=torch.float).unsqueeze(0).unsqueeze(0)

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        probabilities = F.softmax(model(image), dim=-1)
    prob, preds = torch.max(probabilities, 1)
    label = lb.classes_[preds.item()]

    if prob.item() > 0.5 and alphabet == label and alphabet != 'nothing':
        repeat += 1
    else:
        alphabet = label
        repeat = 0
    if repeat > 5:
        repeat = 0
        print(alphabet)
        client.publish("Group_25/NLP/Input", alphabet.lower())

    cv2.putText(frame, f"{label}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)
    cv2.imshow('image', frame)
    # waitKey gives the OpenCV window a chance to redraw.
    cv2.waitKey(1)



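Security note: `lb.pkl` is loaded with `joblib`, which relies on pickle, so only load files from a source you trust. See https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations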


In [2]:
# Build the network and restore the trained weights onto the CPU.
model = ASTCNN()
model.load_state_dict(torch.load('model_gray.pth', map_location='cpu'))
# The model is only used for inference, so put it in evaluation mode. ASTCNN
# has no dropout or batch-norm layers, so its outputs are unchanged; this
# guards against such layers being added later.
model.eval()
print("Model Loaded")
if __name__ == "__main__":
    # Connect to the broker; setup() starts the client's background loop,
    # which delivers incoming frames to on_message.
    client = setup(IP_ADDRESS)
    

Model Loaded
Connected.
