In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split

In [29]:
# contains popular datasets, model architectures, common image transformation
from torchvision import datasets, transforms

# chain together different transformation
transformer = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])

# download and load training datasets
trainset = datasets.MNIST('./mnist/',download=True,train=True,transform=transformer)
trainloader = torch.utils.data.DataLoader(dataset=trainset,shuffle=True,batch_size=64)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw
Processing...
Done!


In [47]:
NUM_CLASS = 10
prec=(3, 7, 5)

In [31]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        
        # Quad layer
        self.proj1 = nn.Linear(784, 40)
        self.diag1 = nn.Linear(40, NUM_CLASS, bias=False) # why bias false?
        
        # Layer that substitutes argmax function
        self.lin1 = nn.Linear(NUM_CLASS, 32)
        self.lin2 = nn.Linear(32, NUM_CLASS)
        
    def forward(self, x):
        # quad layer
        x = x.view(-1, 784) # flatten the image
        x = self.proj1(x)
        x = x * x # quadratic function
        x = self.diag1(x)
        
        # prediction 
        x = F.relu(x)
        x = self.lin1(x)
        x = F.relu(x)
        out = self.lin2(x)
        
        return F.log_softmax(out, dim=1)

In [32]:
model = SimpleNN()

In [33]:
import torch.optim as optim
opt = optim.Adam(model.parameters())

In [35]:
criterion = nn.CrossEntropyLoss()
for ep in range(2):
    running_loss = 0
    for data, target in iter(trainloader):
        opt.zero_grad()
        
        out = model.forward(data)
        loss = criterion(out,target)
        
        loss.backward()
        opt.step()
        
    print(loss.item())

0.0972975343465805
0.47948095202445984


In [39]:
correct = 0
total = 0
with torch.no_grad():
    for data in trainloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

In [41]:
# accuracy
(correct / total) % 100

0.95

In [48]:
print(model.proj1.weight.shape)

torch.Size([40, 784])


In [49]:
print(model.proj1.bias.shape)

torch.Size([40])


In [46]:
print(model.diag1.weight.shape)

torch.Size([10, 40])


In [None]:
data_prec = 3
proj_param = 