In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split

In [2]:
# contains popular datasets, model architectures, common image transformation
from torchvision import datasets, transforms

# chain together different transformation
transformer = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])

# download and load training datasets
trainset = datasets.MNIST('./mnist/',download=True,train=True,transform=transformer)
trainloader = torch.utils.data.DataLoader(dataset=trainset,shuffle=True,batch_size=64)

In [3]:
NUM_CLASS = 10
prec=(3, 7, 5)

In [4]:
class SimpleNN(nn.Module):
    def __init__(self, NUM_CLASS=10):
        super(SimpleNN, self).__init__()
        
        # Quad layer
        self.proj1 = nn.Linear(784, 40)
        self.diag1 = nn.Linear(40, NUM_CLASS, bias=False) # why bias false?
        
        # Layer that substitutes argmax function
        self.lin1 = nn.Linear(NUM_CLASS, 32)
        self.lin2 = nn.Linear(32, NUM_CLASS)
        
    def forward(self, x):
        # quad layer
        x = x.view(-1, 784) # flatten the image
        x = self.proj1(x)
        x = x * x # quadratic function
        x = self.diag1(x)
        
        # prediction 
        x = F.relu(x)
        x = self.lin1(x)
        x = F.relu(x)
        out = self.lin2(x)
        
        return F.log_softmax(out, dim=1)

In [5]:
model = SimpleNN()

In [6]:
import torch.optim as optim
opt = optim.Adam(model.parameters())

In [7]:
criterion = nn.CrossEntropyLoss()
for ep in range(2):
    running_loss = 0
    for data, target in iter(trainloader):
        opt.zero_grad()
        
        out = model.forward(data)
        loss = criterion(out,target)
        
        loss.backward()
        opt.step()
        
    print(loss.item())

0.47181209921836853
0.19812050461769104


In [8]:
correct = 0
total = 0
with torch.no_grad():
    for data in trainloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

In [9]:
# accuracy
(correct / total) % 100

0.9473

In [10]:
path = "/home/sukhad/Workspace/GithHub/reading-in-the-dark/mnist/objects/ml_models/simple_char.pt"
torch.save(model.state_dict(), path)

In [11]:
print(model.proj1.weight.shape)

torch.Size([40, 784])


In [12]:
print(model.proj1.bias.shape)

torch.Size([40])


In [13]:
print(model.diag1.weight.shape)

torch.Size([10, 40])
