In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import numpy as np
from sklearn.model_selection import train_test_split
%matplotlib inline

In [None]:
# Download training data from open datasets.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [4]:
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


## define network


In [9]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64,32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=64, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=64, out_features=32, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=32, out_features=10, bias=True)
  )
)


In [10]:
loss_fn = nn.CrossEntropyLoss()

# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

In [11]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [12]:
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.320206  [    0/60000]
loss: 0.523348  [ 6400/60000]
loss: 0.370266  [12800/60000]
loss: 0.394129  [19200/60000]
loss: 0.264072  [25600/60000]
loss: 0.353106  [32000/60000]
loss: 0.187043  [38400/60000]
loss: 0.357616  [44800/60000]
loss: 0.305204  [51200/60000]
loss: 0.312075  [57600/60000]
Test Error: 
 Accuracy: 93.5%, Avg loss: 0.212408 

Epoch 2
-------------------------------
loss: 0.172859  [    0/60000]
loss: 0.187457  [ 6400/60000]
loss: 0.131854  [12800/60000]
loss: 0.291529  [19200/60000]
loss: 0.174217  [25600/60000]
loss: 0.270751  [32000/60000]
loss: 0.102506  [38400/60000]
loss: 0.284239  [44800/60000]
loss: 0.239804  [51200/60000]
loss: 0.247922  [57600/60000]
Test Error: 
 Accuracy: 95.2%, Avg loss: 0.155048 

Epoch 3
-------------------------------
loss: 0.112985  [    0/60000]
loss: 0.117993  [ 6400/60000]
loss: 0.105089  [12800/60000]
loss: 0.189570  [19200/60000]
loss: 0.127416  [25600/60000]
loss: 0.206800  [32000/600

## save model


In [13]:
torch.save(model.state_dict(), "model.pth")

## load model


In [5]:
model = torch.load( "model.pth")

FileNotFoundError: [Errno 2] No such file or directory: 'model.pth'

In [5]:
lrsw0=model["linear_relu_stack.0.weight"]
lrsb0=model["linear_relu_stack.0.bias"]
lrsw2=model["linear_relu_stack.2.weight"]
lrsb2=model["linear_relu_stack.2.bias"]
lrsw4=model["linear_relu_stack.4.weight"]
lrsb4=model["linear_relu_stack.4.bias"]

In [6]:
lrsw0=lrsw0.reshape(-1)
lrsb0=lrsb0.reshape(-1)
lrsw2=lrsw2.reshape(-1)
lrsb2=lrsb2.reshape(-1)

lrsw4=lrsw4.reshape(-1)


print(lrsw0.shape)
print(lrsb0.shape)
print(lrsw2.shape)
print(lrsb2.shape)

print(lrsw4.shape)

torch.Size([50176])
torch.Size([64])
torch.Size([2048])
torch.Size([32])
torch.Size([320])


In [7]:
np_lrsw0=np.asarray(lrsw0.cpu())
np_lrsb0=np.asarray(lrsb0.cpu())
np_lrsw2=np.asarray(lrsw2.cpu())
np_lrsb2=np.asarray(lrsb2.cpu())
np_lrsw4=np.asarray(lrsw4.cpu())
np_lrsb4=np.asarray(lrsb4.cpu())

np.save("lrsw0",np_lrsw0)
np.save("lrsb0",np_lrsb0)
np.save("lrsw2",np_lrsw2)
np.save("lrsb2",np_lrsb2)
np.save("lrsw4",np_lrsw4)
np.save("lrsb4",np_lrsb4)


In [8]:
img_np=np.zeros((1000,784))
lable_np=np.zeros(1000)
for i in range(1000):
    img,label=test_data[i]    
    img_np[i]=np.asarray(img).reshape(-1)
    lable_np[i]=label

np.save("img",img_np)
np.save("label",lable_np)

In [9]:
lrsw0=lrsw0.reshape(-1)
lrsb0=lrsb0.reshape(-1)
lrsw2=lrsw2.reshape(-1)
lrsb2=lrsb2.reshape(-1)
lrsw4=lrsw4.reshape(-1)



In [10]:
print(lrsw0.shape)
print(lrsb0.shape)
print(lrsw2.shape)
print(lrsb2.shape)
print(lrsw4.shape)

torch.Size([50176])
torch.Size([64])
torch.Size([2048])
torch.Size([32])
torch.Size([320])


In [11]:
np_lrsw0=np.asarray(lrsw0.cpu())
np_lrsb0=np.asarray(lrsb0.cpu())
np_lrsw2=np.asarray(lrsw2.cpu())
np_lrsb2=np.asarray(lrsb2.cpu())
np_lrsw4=np.asarray(lrsw4.cpu())
np_lrsb4=np.asarray(lrsb4.cpu())


In [12]:
lrsw0_len=lrsw0.size(dim=0)
lrsb0_len=lrsb0.size(dim=0)
lrsw2_len=lrsw2.size(dim=0)
lrsb2_len=lrsb2.size(dim=0)
lrsw4_len=lrsw4.size(dim=0)
lrsb4_len=lrsb4.size(dim=0)

# C++ program validation

In [22]:
hit=0
for r in range(100):
    img,label=test_data[r]
    img=img.reshape(-1)
    #img=np.ones(784)*0.5
    out1=np.zeros(lrsb0_len)
    out2=np.zeros(lrsb2_len)
    out3=np.zeros(lrsb4_len)
    
    for i in range(lrsb0_len):
        for j in range(784):
            out1[i]+=lrsw0[i*784+j]*img[j]
        out1[i]+=lrsb0[i]
    for i in range(lrsb0_len):
        if(out1[i]<0):
            out1[i]=0

    for i in range(lrsb2_len):
        for j in range(lrsb0_len):
            out2[i]+=lrsw2[i*lrsb0_len+j]*out1[j]
        out2[i]+=lrsb2[i]
    for i in range(lrsb2_len):
        if(out2[i]<0):
            out2[i]=0

    for i in range(lrsb4_len):
        for j in range(lrsb2_len):
            out3[i]+=lrsw4[i*lrsb2_len+j]*out2[j]
        out3[i]+=lrsb4[i]
    if(out3.argmax()==label):
        hit+=1
    print(out3)
    #print(out1)
    #print(out2)
    print("now:",r,"hit:",hit)
print(hit)

[ -4.92369461  -7.66485691  -0.98552698   1.12712586 -18.26786041
  -3.78585649 -23.27987862  10.22897148  -3.02274179  -1.70486081]
now: 0 hit: 1
[-12.83986282   6.93760252  13.18755817   6.52630424 -25.29696465
  -6.38394022 -15.25874615 -10.16630554  -2.74610853 -20.17750931]
now: 1 hit: 2
[-7.39804792  9.94595909 -1.10424376 -5.06233072 -2.85030746 -5.60874224
 -3.17454672  0.58857054 -1.30212307 -5.22882318]
now: 2 hit: 3
[ 11.11396503  -7.5778079   -1.61776733  -4.81179857 -15.71409988
  -2.94385266  -4.53554869  -3.4479239  -12.57793713  -2.34860802]
now: 3 hit: 4
[-3.94347453 -9.67495155 -3.94474721 -6.39585495  9.06172466 -4.07384443
 -6.29758787 -1.76892364 -3.47377467  2.11102986]
now: 4 hit: 5
[-9.41569233 13.15397263 -3.75986767 -5.50849581 -3.9629488  -9.46448898
 -8.59140682  3.56472158 -3.70649743 -4.90757608]
now: 5 hit: 6
[-11.71393108 -12.06459713 -11.61721897  -5.48554516  11.11864376
  -3.66878796 -10.74733162  -4.9389534    1.59216356  -2.54225016]
now: 6 hit: 7
[

KeyboardInterrupt: 

In [13]:
for idx ,ele in enumerate(lrsw0):
    lrsw0[idx]=int((ele+1)/2*255-127)
for idx ,ele in enumerate(lrsb0):
    lrsb0[idx]=int((ele+1)/2*255-127)
for idx ,ele in enumerate(lrsw2):
    lrsw2[idx]=int((ele+1)/2*255-127)
for idx ,ele in enumerate(lrsb2):
    lrsb2[idx]=int((ele+1)/2*255-127)
for idx ,ele in enumerate(lrsw4):
    lrsw4[idx]=int((ele+1)/2*255-127)
for idx ,ele in enumerate(lrsb4):
    lrsb4[idx]=int((ele+1)/2*255-127)

In [6]:
#turn img from float to int
def int_conver(np_a):
    for idx,ele in enumerate(np_a):
        np_a[idx]=int((ele+1)/2*255-127)

    return np_a

## Test overflow underflow

In [36]:
hit=0
for r in range(100):
    img,label=test_data[r]
    img=img.reshape(-1)
    img=int_conver(img)
    #img=np.ones(784)
    out1=np.zeros(lrsb0_len)
    out2=np.zeros(lrsb2_len)
    out3=np.zeros(lrsb4_len)
    
    for i in range(lrsb0_len):
        for j in range(784):
            out1[i]+=lrsw0[i*784+j]*img[j]
        out1[i]+=lrsb0[i]
    for i in range(lrsb0_len):
        if(out1[i]<0):
            out1[i]=0
##########debug begin #################
    for ele in out1:
        if(ele>0 and ele>2147483647):
            print("overflow out1:",out1)
    for ele in out1:
        if(ele<0 and ele<-2147483648):
            print("underflow out1:",out1) 
    for idx,ele in enumerate(out1):
        out1[idx]=int((out1[idx]/2**14))
    
##########debug end#################
    for i in range(lrsb2_len):
        for j in range(lrsb0_len):
            out2[i]+=lrsw2[i*lrsb0_len+j]*out1[j]
        out2[i]+=lrsb2[i]
    for i in range(lrsb2_len):
        if(out2[i]<0):
            out2[i]=0
##########debug begin #################

    for ele in out2:
        if(ele>0 and ele>2147483647):
            print("overflow out2:",out2)
    for ele in out2:
        if(ele<0 and ele<-2147483648):
            print("underflow out2:",out2) 



##########debug end#################
    for i in range(lrsb4_len):
        for j in range(lrsb2_len):
            out3[i]+=lrsw4[i*lrsb2_len+j]*out2[j]
        out3[i]+=lrsb4[i]
##########debug begin #################
    for ele in out3:
        if(ele>0 and ele>2147483647):
            print("overflow out3:",out3)
    for ele in out3:
        if(ele<0 and ele<-2147483648):
            print("underflow out3:",out3) 

##########debug end##########
    if(out3.argmax()==label):
        hit+=1
    print(out3)
    #print(out1)
    #print(out2)
    print("now:",r,"hit:",hit)
print(hit)

[ -70624.  -98242.  -10629.   34992. -274669.  -57607. -335702.  152348.
  -34606.  -30507.]
now: 0 hit: 1
[-217811.  122613.  199276.  119135. -416930.  -86434. -239267. -172297.
  -33127. -306899.]
now: 1 hit: 2
[-88264. 137708.  -7069. -66690. -41208. -64176. -30171.   7647. -18517.
 -78782.]
now: 2 hit: 3
[ 194821. -126154.  -23870.  -58429. -236840.  -62416.  -80430.  -59965.
 -205584.  -45231.]
now: 3 hit: 4
[ -64593. -129611.  -56257.  -69673.  135720.  -57840.  -91932.  -27020.
  -45389.   23976.]
now: 4 hit: 5
[-117884.  191689.  -48317.  -68117.  -42777. -131649. -100324.   47537.
  -57112.  -69832.]
now: 5 hit: 6
[-168924. -166428. -176691.  -69623.  148713.  -26253. -134392.  -76397.
   30688.  -53954.]
now: 6 hit: 7
[-187435.  -91829.  -95538.   54594.  -14310.  -92280. -321222.   -8418.
  -64570.  106749.]
now: 7 hit: 8
[ -91873.  -72935. -100869.  -86448.  -49622.   83270.   40498. -152648.
  -11853.  -22223.]
now: 8 hit: 9
[-192363. -198178. -235024.   -1136.   44732. -

: 

## Generate cpp file for c sim in vitis hls & cpp file for ROM

In [28]:
np_lrsw0=np.asarray(lrsw0.cpu())
np_lrsb0=np.asarray(lrsb0.cpu())
np_lrsw2=np.asarray(lrsw2.cpu())
np_lrsb2=np.asarray(lrsb2.cpu())
np_lrsw4=np.asarray(lrsw4.cpu())
np_lrsb4=np.asarray(lrsb4.cpu())


In [29]:
f=open('lrsw0.cpp','w')
f.write("int lrsw0[]={")
for i in range(np_lrsw0.shape[0]):
    f.write(str(int(np_lrsw0[i])))
    if(i==np_lrsw0.shape[0]-1):
        pass
    else:
        f.write(',')
f.write('};')
    

2

In [30]:
f=open('lrsb0.cpp','w')
f.write("int lrsb0[]={")
for i in range(np_lrsb0.shape[0]):
    f.write(str(int(np_lrsb0[i])))
    if(i==np_lrsb0.shape[0]-1):
        pass
    else:
        f.write(',')
f.write('};')

2

In [31]:
f=open('lrsw2.cpp','w')
f.write("int lrsw2[]={")
for i in range(np_lrsw2.shape[0]):
    f.write(str(int(np_lrsw2[i])))
    if(i==np_lrsw2.shape[0]-1):
        pass
    else:
        f.write(',')
f.write('};')
    

2

In [32]:
f=open('lrsb2.cpp','w')
f.write("int lrsb2[]={")
for i in range(np_lrsb2.shape[0]):
    f.write(str(int(np_lrsb2[i])))
    if(i==np_lrsb2.shape[0]-1):
        pass
    else:
        f.write(',')
f.write('};')


2

In [33]:
f=open('lrsb4.cpp','w')
f.write("int lrsb4[]={")
for i in range(np_lrsb4.shape[0]):
    f.write(str(int(np_lrsb4[i])))
    if(i==np_lrsb4.shape[0]-1):
        pass
    else:
        f.write(',')
f.write('};')




2

In [35]:
f=open('lrsw4.cpp','w')
f.write("int lrsw4[]={")
for i in range(np_lrsw4.shape[0]):
    f.write(str(int(np_lrsw4[i])))
    if(i==np_lrsw4.shape[0]-1):
        pass
    else:
        f.write(',')
f.write('};')

2

In [7]:
for i in range(1000):
    img,label=test_data[i]
    img=img.reshape(-1)
    img=int_conver(img)
    