In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import numpy as np
from sklearn.model_selection import train_test_split
%matplotlib inline

In [3]:
# Build the MNIST train and test splits under ./data. download=True lets
# torchvision fetch the files when they are not already present in `root`;
# ToTensor() converts each image to a tensor. The two splits differ only in
# the `train` flag.
training_data, test_data = (
    datasets.MNIST(
        root="data",
        train=is_train,
        download=True,
        transform=ToTensor(),
    )
    for is_train in (True, False)
)

In [4]:
batch_size = 64

# Wrap both splits in DataLoaders that yield mini-batches of `batch_size` samples.
train_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size)
    for split in (training_data, test_data)
)

# Sanity check: report the shape/dtype of the first test batch only.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


## define network


In [106]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    """Fully connected classifier: 784 -> 64 -> 32 -> 10 with ReLU between layers.

    The submodule names (``flatten``, ``linear_relu_stack``) and the position of
    each layer inside the ``Sequential`` fix the ``state_dict`` keys
    (e.g. ``linear_relu_stack.0.weight``), so they must stay as they are for
    saved checkpoints to remain readable.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        layers = [
            nn.Linear(28 * 28, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample and return the 10 raw class scores (logits)."""
        return self.linear_relu_stack(self.flatten(x))

# Build the network, move its parameters to the selected device, then show
# the layer summary.
model = NeuralNetwork()
model = model.to(device)
print(model)

Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=32, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=32, out_features=10, bias=True)
  )
)


In [107]:
# Cross-entropy over the 10 logits the network returns.
loss_fn = nn.CrossEntropyLoss()

# Adam with a 1e-3 learning rate over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [108]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass (one epoch) over ``dataloader``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches. Must expose ``.dataset``
            with a length; that length is used only for the progress printout.
        model: ``nn.Module`` that is updated in place.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer constructed over ``model``'s parameters.

    Returns:
        None. The batch loss is printed every 100 batches.
    """
    size = len(dataloader.dataset)

    # Take the target device from the model itself instead of a notebook-level
    # `device` variable: batches always land where the parameters live, and the
    # function keeps working after a kernel restart or when run out of order.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # `current` counts the samples consumed before this batch.
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [109]:
epochs = 10

# One epoch = a full training pass followed by an evaluation on the test split.
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(
        dataloader=train_dataloader,
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
    )
    test(dataloader=test_dataloader, model=model, loss_fn=loss_fn)

print("Done!")

Epoch 1
-------------------------------
loss: 2.308667  [    0/60000]
loss: 0.559440  [ 6400/60000]
loss: 0.364937  [12800/60000]
loss: 0.356125  [19200/60000]
loss: 0.264652  [25600/60000]
loss: 0.363730  [32000/60000]
loss: 0.195791  [38400/60000]
loss: 0.368445  [44800/60000]
loss: 0.351554  [51200/60000]
loss: 0.338255  [57600/60000]
Test Error: 
 Accuracy: 93.1%, Avg loss: 0.228729 

Epoch 2
-------------------------------
loss: 0.156071  [    0/60000]
loss: 0.225682  [ 6400/60000]
loss: 0.116495  [12800/60000]
loss: 0.307438  [19200/60000]
loss: 0.177098  [25600/60000]
loss: 0.323541  [32000/60000]
loss: 0.099816  [38400/60000]
loss: 0.271824  [44800/60000]
loss: 0.220665  [51200/60000]
loss: 0.261181  [57600/60000]
Test Error: 
 Accuracy: 95.3%, Avg loss: 0.158525 

Epoch 3
-------------------------------
loss: 0.098936  [    0/60000]
loss: 0.162150  [ 6400/60000]
loss: 0.092269  [12800/60000]
loss: 0.259521  [19200/60000]
loss: 0.123687  [25600/60000]
loss: 0.265576  [32000/600

## save model


In [110]:
# Persist only the parameter tensors (the state_dict), not the module object.
torch.save(obj=model.state_dict(), f="model.pth")

## load model


In [5]:
# "model.pth" was written with torch.save(model.state_dict(), ...), so despite
# its name `model` is now a mapping from parameter names to tensors, not an
# nn.Module. map_location="cpu" lets a checkpoint saved from a CUDA run load on
# a machine without a GPU; the cells below only use the tensors on the
# CPU / as NumPy arrays.
model = torch.load("model.pth", map_location="cpu")

In [6]:
# Pull each Linear layer's weight and bias out of the loaded state_dict.
# The numeric part of the key is the layer's index inside linear_relu_stack.
lrsw0, lrsb0 = model["linear_relu_stack.0.weight"], model["linear_relu_stack.0.bias"]
lrsw2, lrsb2 = model["linear_relu_stack.2.weight"], model["linear_relu_stack.2.bias"]
lrsw4, lrsb4 = model["linear_relu_stack.4.weight"], model["linear_relu_stack.4.bias"]

In [7]:
# Flatten the tensors to 1-D so an element of a weight matrix can be addressed
# by a single flat offset (row * n_cols + col). lrsb4 is left as loaded.
lrsw0, lrsb0, lrsw2, lrsb2, lrsw4 = (
    tensor.reshape(-1) for tensor in (lrsw0, lrsb0, lrsw2, lrsb2, lrsw4)
)

# One shape per line, in the same order as above.
print(lrsw0.shape, lrsb0.shape, lrsw2.shape, lrsb2.shape, lrsw4.shape, sep="\n")

torch.Size([50176])
torch.Size([64])
torch.Size([2048])
torch.Size([32])
torch.Size([320])


In [8]:
# Convert every parameter tensor to a NumPy array on the host.
np_lrsw0, np_lrsb0, np_lrsw2, np_lrsb2, np_lrsw4, np_lrsb4 = (
    np.asarray(tensor.cpu())
    for tensor in (lrsw0, lrsb0, lrsw2, lrsb2, lrsw4, lrsb4)
)

# Write one file per parameter; np.save appends the ".npy" extension.
np.save("lrsw0", np_lrsw0)
np.save("lrsb0", np_lrsb0)

np.save("lrsw2", np_lrsw2)
np.save("lrsb2", np_lrsb2)

np.save("lrsw4", np_lrsw4)
np.save("lrsb4", np_lrsb4)


In [9]:
# Export the first 1000 test samples: one flattened 784-value row per image and
# the matching label. Both buffers use np.zeros' default dtype (float64).
img_np = np.zeros((1000, 784))
lable_np = np.zeros(1000)

for i in range(1000):
    img, label = test_data[i]
    img_np[i, :] = np.asarray(img).ravel()
    lable_np[i] = label

# np.save appends the ".npy" extension to both file names.
np.save("img", img_np)
np.save("label", lable_np)

In [10]:
# Flatten the tensors to 1-D. reshape(-1) leaves an already-flat tensor
# unchanged, so re-running this cell is harmless. lrsb4 is left as loaded.
lrsw0, lrsb0, lrsw2, lrsb2, lrsw4 = (
    tensor.reshape(-1) for tensor in (lrsw0, lrsb0, lrsw2, lrsb2, lrsw4)
)



In [11]:
# Show the flattened sizes, one per line.
print(lrsw0.shape, lrsb0.shape, lrsw2.shape, lrsb2.shape, lrsw4.shape, sep="\n")

torch.Size([50176])
torch.Size([64])
torch.Size([2048])
torch.Size([32])
torch.Size([320])


In [12]:
# Host-side NumPy versions of every parameter tensor.
np_lrsw0, np_lrsb0, np_lrsw2, np_lrsb2, np_lrsw4, np_lrsb4 = (
    np.asarray(tensor.cpu())
    for tensor in (lrsw0, lrsb0, lrsw2, lrsb2, lrsw4, lrsb4)
)


In [13]:
# Length of each tensor along its first dimension. The bias lengths are used
# below as the layer widths.
lrsw0_len, lrsb0_len, lrsw2_len, lrsb2_len, lrsw4_len, lrsb4_len = (
    tensor.shape[0] for tensor in (lrsw0, lrsb0, lrsw2, lrsb2, lrsw4, lrsb4)
)

# C++ program validation

In [14]:
# Reference forward pass written with explicit scalar loops over the flattened
# weights, run on the first two test samples. Presumably this mirrors the C++
# implementation element by element so intermediate values can be compared;
# keep the loop structure and accumulation order as-is for that reason.
hit=0  # number of samples whose predicted class equals the label
for r in range(2):
    img,label=test_data[r]
    img=img.reshape(-1)  # flatten the image to a 1-D vector of pixels
    # Per-layer output buffers (float64 from np.zeros); each bias length is the layer width.
    out1=np.zeros(lrsb0_len)
    out2=np.zeros(lrsb2_len)
    out3=np.zeros(lrsb4_len)
    
    # Layer 0: out1 = W0 @ img + b0. The weights were flattened with
    # reshape(-1), so row i, column j of W0 is read at flat offset i*784+j.
    for i in range(lrsb0_len):
        for j in range(784):
            out1[i]+=lrsw0[i*784+j]*img[j]
        out1[i]+=lrsb0[i]
    # ReLU: clamp negative activations to zero.
    for i in range(lrsb0_len):
        if(out1[i]<0):
            out1[i]=0

    # Layer 2: out2 = W2 @ out1 + b2, row stride = number of layer-0 outputs.
    for i in range(lrsb2_len):
        for j in range(lrsb0_len):
            out2[i]+=lrsw2[i*lrsb0_len+j]*out1[j]
        out2[i]+=lrsb2[i]
    # ReLU on the second hidden layer.
    for i in range(lrsb2_len):
        if(out2[i]<0):
            out2[i]=0

    # Layer 4 (output): out3 = W4 @ out2 + b4; no activation, these are the logits.
    for i in range(lrsb4_len):
        for j in range(lrsb2_len):
            out3[i]+=lrsw4[i*lrsb2_len+j]*out2[j]
        out3[i]+=lrsb4[i]
    # Predicted class is the index of the largest logit.
    if(out3.argmax()==label):
        hit+=1
    # Dump the logits first, then both hidden activations, for side-by-side comparison.
    print(out3)
    print(out1)
    print(out2)
    print("now:",r,"hit:",hit)
print(hit)

[ -3.16065717  -7.64282131  -1.88874662   0.7816208  -10.37539387
  -8.21591949 -16.52468491  13.17329597  -3.39338565  -3.45833802]
[0.         0.         2.20556116 1.6693989  0.         1.25749528
 0.         0.83755994 1.37699878 0.         1.73201311 0.30123317
 2.56744552 2.18847084 4.4123888  0.         0.         3.07979155
 4.31433964 0.         0.68415785 3.63678288 0.41284668 0.
 1.01892638 0.         2.06896257 0.09673144 0.         1.53036249
 2.41919732 2.99665618 0.         0.         1.21433258 0.
 0.74359769 0.         0.         0.         0.55924273 0.
 0.         0.10054541 2.06047034 0.         0.         0.
 0.59952945 2.69664383 2.56162834 0.         2.6689043  2.90566063
 0.         1.25290632 0.         0.         0.94321156 3.20101738
 1.79623199 1.58700264 1.22447836 1.46508384]
[ 7.62313128  4.88743114  0.07896322  4.44060612  0.          4.97876453
  3.74316478  6.58791733  2.50065422  5.69034719  0.67534876  0.
  3.24253464  0.36259326  2.63701248  0.     