In [27]:
import torch
from PIL import Image
import pandas as pd
import numpy as np



In [28]:
from importlib.resources import path

from torch import FloatTensor, LongTensor
from torchvision import transforms

class MNISTDatasetASL(torch.utils.data.Dataset):
    def __init__(self, path):
        
        self.X, self.Y = self.get_data_mapping(path)
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(28, scale=(0.8, 1.2)),
            transforms.RandomHorizontalFlip(p = 0.5),
            transforms.ToTensor(),
            transforms.RandomResizedCrop(28, scale=(0.8, 1.2))]
        )

    def get_data_mapping(self, path):
        dataset = pd.read_csv(path)
        Y = torch.tensor(dataset['label'].values.astype(np.int8))
        X = torch.tensor(dataset.drop('label', axis = 1).values.astype(np.int8)) 
        return (X, Y)    

    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        label = self.Y[idx].type(LongTensor)
        input = self.transform(self.X[idx].reshape(1, 28, 28).type(FloatTensor))
        return (input, label)


In [37]:
train_path = "data/sign_mnist_train.csv"
test_path = "data/sign_mnist_test.csv"

trainLoader = torch.utils.data.DataLoader(MNISTDatasetASL(train_path), batch_size=32)
testLoader = torch.utils.data.DataLoader(MNISTDatasetASL(train_path), batch_size=16)


## Define Network Architectyre

In [30]:
import torch.nn as nn

class ASLClassifier(nn.Module):
    class ConvBlock(nn.Module):
        def __init__(self, input_c, output_c, kernel_size):
            super().__init__()
            self.block = nn.Sequential(
                nn.Conv2d(input_c, output_c, kernel_size),
                nn.ReLU(),
                nn.MaxPool2d(2,2)
            )
        def forward(self, x):
            return self.block(x)

        
    def __init__(self):
        print("HERE")
        super().__init__()

        self.conv = torch.nn.Sequential(
            self.ConvBlock(1, 6, 3),
            self.ConvBlock(6, 16, 3)
        )
        self.mlp = torch.nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 48),
            nn.ReLU(),
            nn.Linear(48, 25)
        )

    def forward(self, x):
        flat_out = self.conv(x).flatten(start_dim = 1)
        return self.mlp(flat_out)

    def predict(self, x):
        return torch.argmax(self.forward(x), dim=1)

net = ASLClassifier()


HERE


In [31]:
%reload_ext tensorboard
%tensorboard --logdir docs

Reusing TensorBoard on port 6006 (pid 33653), started 1:08:48 ago. (Use '!kill 33653' to kill it.)

## Set up a Training Loop

In [38]:
import torch.utils.tensorboard as tb
from os import path

def train(model, train_loader, test_loader, lr = 0.001, epochs = 2, log_dir = 'docs'):
    
    optimizer = torch.optim.AdamW(model.parameters())
    loss_fn = nn.CrossEntropyLoss()
    global_step = 0

    train_logger = tb.SummaryWriter(path.join(log_dir, 'train'))
    valid_logger = tb.SummaryWriter(path.join(log_dir, 'valid'))

    for epoch in range(epochs):
        total_loss = 0
        for X, Y in train_loader:
            X = X.type(torch.FloatTensor)
            output = model(X)
            loss = loss_fn(output, Y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_logger.add_scalar('loss', loss, global_step=global_step)

            global_step += 1

            total_loss += loss
        
        print(f"LOSS AT EPOCH {epoch} : {total_loss}")

        total_correct = 0
        count = 0
        for X, Y in test_loader:
            predictions = model.predict(X)

            total_correct += (predictions == Y).sum()
            count += len(Y)
        print(f"Valid Accuracy AT EPOCH {epoch} : {total_correct/ count}")
        valid_logger.add_scalar('accuracy', total_correct/count, global_step=global_step)

net = ASLClassifier()
train(net, trainLoader, testLoader, lr = 0.01, epochs=15)





HERE
LOSS AT EPOCH 0 : 2191.998779296875
Valid Accuracy AT EPOCH 0 : 0.3832453191280365
LOSS AT EPOCH 1 : 1387.6546630859375
Valid Accuracy AT EPOCH 1 : 0.5365142822265625
LOSS AT EPOCH 2 : 1092.827392578125
Valid Accuracy AT EPOCH 2 : 0.6077581644058228
LOSS AT EPOCH 3 : 917.8850708007812
Valid Accuracy AT EPOCH 3 : 0.6772900819778442
LOSS AT EPOCH 4 : 777.7047729492188
Valid Accuracy AT EPOCH 4 : 0.7247495651245117
LOSS AT EPOCH 5 : 685.1024780273438
Valid Accuracy AT EPOCH 5 : 0.756037175655365
LOSS AT EPOCH 6 : 600.2249755859375
Valid Accuracy AT EPOCH 6 : 0.7850300669670105
LOSS AT EPOCH 7 : 537.0792236328125
Valid Accuracy AT EPOCH 7 : 0.8025131821632385
LOSS AT EPOCH 8 : 494.1986999511719
Valid Accuracy AT EPOCH 8 : 0.8183937072753906
LOSS AT EPOCH 9 : 448.7412414550781
Valid Accuracy AT EPOCH 9 : 0.8404297828674316
LOSS AT EPOCH 10 : 416.04498291015625
Valid Accuracy AT EPOCH 10 : 0.8547805547714233
LOSS AT EPOCH 11 : 380.7231140136719
Valid Accuracy AT EPOCH 11 : 0.87259149551

In [39]:
torch.save(net.state_dict(), "checkpoint.pth")

## Export Code To Onnpix

In [None]:
import onnx
import onnxruntime as ort

testLoader = torch.utils.data.DataLoader(MNISTDatasetASL(test_path), batch_size=1)

net = ASLClassifier()
model = torch.load("checkpoint.pth")
net.load_state_dict(model)

acc_count = 0
count = 0
for X, Y in testLoader:
    predictions = net.predict(X)
    # print(type(predictions == Y))
    acc_count += (predictions == Y).type(torch.LongTensor)
    count += 1

print(f'Torch Model Accuracy (baseline):  {acc_count / count}')


# export to onnx
fname = "asl_model.onnx"
dummy = torch.randn(1, 1, 28, 28)
net(dummy)
torch.onnx.export(net, dummy, fname, input_names=['input'])

# check exported model
model = onnx.load(fname)
onnx.checker.check_model(model)  # check model is well-formed

# create runnable session with exported model
ort_session = ort.InferenceSession(fname)
net = lambda inp: ort_session.run(None, {'input': inp.data.numpy()})[0]


acc_count = 0
count = 0
for X, Y in testLoader:
    predictions = net(X).argmax()
    acc_count += int(predictions == Y)
    count += 1

print(f'Torch Model Accuracy (baseline):  {acc_count / count}')

HERE
