## Python版

元のリポジトリ(deep-learning-from-scratch)を git submodule として追加する。  
neuralnet_mnist_int.py はコピーして使う。

```
git submodule add https://github.com/oreilly-japan/deep-learning-from-scratch.git
```

In [46]:
import neuralnet_mnist_int
import numpy as np
import time

In [2]:
# neuralnet_mnist_int.main()

In [3]:
# x_test, y_test = neuralnet_mnist_int.get_data()

In [4]:
# Load the test data and the network for the Python reference run.
# Both helpers receive the same mode string.
mode = "INT_MODE"
x_test, y_test = neuralnet_mnist_int.get_data(mode=mode)
network = neuralnet_mnist_int.init_network(mode=mode)
# One uint8 slot per test sample; every slot is overwritten by the
# prediction loop, so the array is left uninitialised here.
py_y_test = np.empty(len(x_test), dtype=np.uint8)


In [5]:
%%time
## Python版の結果
for i in range(len(x_test)):
    y,_ = neuralnet_mnist_int.predict(network, x_test[i], mode=mode)
    py_y_test[i] = np.argmax(y) # 最も確率の高い要素のインデックスを取得

CPU times: user 35.7 s, sys: 24.2 ms, total: 35.7 s
Wall time: 35.8 s


In [6]:
# Display the predictions written by the Python inference loop (numpy abbreviates the repr).
py_y_test

array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

## HLS版 全データ

In [50]:
## FPGA Load
from pynq import Overlay
# NOTE(review): absolute, board-specific bitstream path - adjust it for your own setup.
OL = Overlay("/home/xilinx/pynq/overlays/my_design/test_mnist_wrapper.bit")
# NOTE(review): depending on the PYNQ version, Overlay() may already program the
# FPGA on construction - confirm whether this explicit download is still needed.
OL.download()
# The recorded output of this cell reports this attribute as a pynq.xlnk.Xlnk
# object; later cells use it to allocate and fill the DMA buffers.
XLNK = OL.processing_system7_0

## show IPs
print(OL.ip_dict.keys(), XLNK)

dict_keys(['axi_dma', 'axi_dma_0', 'axi_gpio_0', 'nn_0']) <pynq.xlnk.Xlnk object at 0xafd9e990>


In [51]:
# Alternative loader based on the `mnist` package, kept for reference (disabled):
# import mnist
# import numpy as np
# x_test = mnist.test_images()
# y_test = mnist.test_labels()
# x_test = x_test.reshape(10000, 28*28)

# Both modules are already imported in the first cell; repeated here,
# presumably so that the HLS section can be run on its own.
import neuralnet_mnist_int
import numpy as np
# NOTE(review): get_data() is called without mode= here, unlike the Python
# section (mode="INT_MODE"). This rebinds x_test / y_test for all later cells.
x_test, y_test = neuralnet_mnist_int.get_data()
# Cast to uint8 to match the dtype of the DMA input buffer allocated below.
x_test = x_test.astype(np.uint8)

In [52]:
# Sanity check of the container type, shape and dtype before copying into the DMA buffer.
type(x_test), x_test.shape, x_test.dtype

(numpy.ndarray, (10000, 784), dtype('uint8'))

In [53]:
## Allocate Memory
# Number of images streamed through the accelerator (the whole test set).
IMAGE_NUM = x_test.shape[0]
# Smaller value for a quick trial run (disabled):
# IMAGE_NUM = 30
print(IMAGE_NUM)
# Input buffer: one flat uint8 array holding 28*28 bytes per image.
input_buf = XLNK.cma_array([28*28*IMAGE_NUM], np.uint8)
# The physical address is what gets programmed into the DMA address register later.
print(hex(input_buf.physical_address))
# Output buffer: one uint8 result per image.
output_buf = XLNK.cma_array([1*IMAGE_NUM], np.uint8)
print(hex(output_buf.physical_address))

10000
0x18900000
0x18050000


In [54]:
## Write output_buf
# Pre-fill every result slot with 0xFF so that entries the DMA never wrote
# are easy to spot afterwards (no digit label has that value).
# A single slice assignment replaces the element-by-element Python loop.
output_buf[:IMAGE_NUM] = 0xFF
output_buf[0]

255

In [56]:
## Write input_buf (DDR)
# Copy all 28*28*IMAGE_NUM image bytes from x_test into the DMA input buffer.
# NOTE(review): confirm that cma_memcopy accepts numpy arrays in the PYNQ
# version in use; `input_buf[:] = x_test.reshape(-1)` would be the plain-numpy way.
XLNK.cma_memcopy(input_buf, x_test, 28*28*IMAGE_NUM)

In [75]:
def wait_dma(dma=None, max_polls=100, interval=0.1):
    """Poll the DMA's S2MM status register until the channel has stopped working.

    The loop ends as soon as the status register reports ``Idle`` (transfer
    finished) or ``Halted`` (channel stopped, e.g. after an error). The
    previous version looked at ``Halted`` only, although its progress message
    says "Wait for Idle"; per the AXI DMA register description a transfer
    that finishes normally raises ``Idle`` while ``Halted`` stays 0, so a
    clean completion would never have been detected.

    Args:
        dma: Object exposing ``register_map.S2MM_DMASR``. Defaults to the
            notebook-level ``OL.axi_dma_0`` (looked up at call time).
        max_polls: Maximum number of status polls before giving up.
        interval: Seconds to sleep between two polls.

    Returns:
        None. Progress, a halt report, or ``TimeOut`` is written to stdout
        on a single, carriage-return-updated line.
    """
    if dma is None:
        dma = OL.axi_dma_0

    for polls in range(1, max_polls + 1):
        status = dma.register_map.S2MM_DMASR
        if status.Halted:
            # A halted channel did not end in the normal Idle state; show the
            # error bits so the reason is visible right away.
            print(
                f"\rS2MM halted (DMAIntErr={int(status.DMAIntErr)}, "
                f"DMASlvErr={int(status.DMASlvErr)}, "
                f"DMADecErr={int(status.DMADecErr)})",
                end='',
            )
            break
        if status.Idle:
            break
        print(f"\rWait for Idle: {polls}", end='')
        time.sleep(interval)
    else:
        # The loop ran out of polls without seeing Idle or Halted.
        print(f"\rTimeOut       ", end='')
    print()


In [76]:
## nn START
OL.nn_0.register_map.CTRL.AUTO_RESTART = 1
OL.nn_0.register_map.CTRL.AP_START = 1
OL.nn_0.register_map.CTRL

Register(AP_START=1, AP_DONE=0, AP_IDLE=0, AP_READY=0, RESERVED_1=0, AUTO_RESTART=1, RESERVED_2=0)

In [77]:
%%time
## DMA Control
## Reset both channels
# Earlier variant that cleared the control registers instead (disabled):
# OL.axi_dma_0.register_map.MM2S_DMACR = 0x0
# OL.axi_dma_0.register_map.S2MM_DMACR = 0x0
# NOTE(review): 0x04 looks like the DMACR soft-reset bit - confirm against the
# AXI DMA register map.
OL.axi_dma_0.register_map.MM2S_DMACR = 0x04
OL.axi_dma_0.register_map.S2MM_DMACR = 0x04

## Run
# NOTE(review): 0x1 looks like the DMACR run/stop bit - confirm as above.
OL.axi_dma_0.register_map.MM2S_DMACR = 0x1
OL.axi_dma_0.register_map.S2MM_DMACR = 0x1

## Address
# MM2S reads the images from input_buf; S2MM writes the results to output_buf.
OL.axi_dma_0.register_map.MM2S_SA = input_buf.physical_address
OL.axi_dma_0.register_map.S2MM_DA = output_buf.physical_address

## Size
# Byte counts: 28*28 per input image, 1 per result. The length registers are
# written last - presumably this write is what starts each transfer (confirm).
OL.axi_dma_0.register_map.MM2S_LENGTH = 28*28*IMAGE_NUM
OL.axi_dma_0.register_map.S2MM_LENGTH = 1*IMAGE_NUM

# Block until the S2MM side has stopped; the %%time figure includes this wait.
wait_dma()

Wait for Idle: 54
CPU times: user 971 ms, sys: 98.9 ms, total: 1.07 s
Wall time: 6.26 s


In [78]:
## DMA MM2S / S2MM Status
# NOTE(review): in the recorded output of this cell the S2MM channel reports
# DMAIntErr=1 and Err_Irq=1 together with Halted=1 - worth investigating.
OL.axi_dma_0.register_map.MM2S_DMASR, OL.axi_dma_0.register_map.S2MM_DMASR 

(Register(Halted=1, Idle=0, SGIncld=0, DMAIntErr=0, DMASlvErr=0, DMADecErr=0, SGIntErr=0, SGSlvErr=0, SGDecErr=0, IOC_Irq=1, Dly_Irq=0, Err_Irq=0, IRQThresholdSts=0, IRQDelaySts=0),
 Register(Halted=1, Idle=0, SGIncld=0, DMAIntErr=1, DMASlvErr=0, DMADecErr=0, SGIntErr=0, SGSlvErr=0, SGDecErr=0, IOC_Irq=1, Dly_Irq=0, Err_Irq=1, IRQThresholdSts=0, IRQDelaySts=0))

In [79]:
## The bit width of the DMA transfer size (Buffer Length Register) has been too
## small before, so check that the programmed length was not truncated.
hex(OL.axi_dma_0.register_map.MM2S_LENGTH), hex(28*28*IMAGE_NUM)

('0x77a100', '0x77a100')

In [80]:
N = IMAGE_NUM
# Count the positions where the HLS result equals the corresponding y_test
# entry, then print the matching fraction over all N images.
ok = sum(1 for exp, data in zip(y_test, output_buf) if exp == data)
print(ok/N)

0.9301


In [81]:
# Print "index, python_result hls_result" for every image where the HLS
# output differs from the Python reference prediction.
mismatches = (
    (idx, ref, hw)
    for idx, (ref, hw) in enumerate(zip(py_y_test, output_buf))
    if ref != hw
)
for idx, ref, hw in mismatches:
    print(f"{idx}, {ref} {hw}")

62, 4 9
92, 4 9
233, 7 8
359, 4 9
468, 7 2
550, 7 9
610, 4 2
624, 2 8
628, 3 9
689, 7 9
760, 4 9
882, 7 9
924, 7 2
951, 5 4
1096, 0 9
1204, 2 8
1325, 6 8
1559, 7 9
1790, 7 9
1813, 5 8
1955, 2 8
1969, 6 2
2125, 4 9
2488, 6 4
2586, 5 3
2698, 5 3
2805, 5 8
2919, 5 8
3030, 6 2
3033, 6 4
3240, 3 8
3288, 4 9
3437, 4 9
3475, 7 3
3503, 1 8
3614, 6 4
3688, 6 8
3929, 5 8
3951, 5 8
4059, 5 8
4254, 5 3
4302, 5 8
4391, 7 9
4427, 2 8
4497, 7 8
4575, 0 2
4583, 5 8
4605, 3 8
4761, 1 8
4808, 5 3
4990, 2 8
5009, 4 9
5020, 5 8
5046, 0 2
5183, 4 8
5642, 5 8
5736, 6 4
5862, 5 3
5926, 4 9
5957, 5 8
6002, 6 4
6045, 3 9
6046, 3 8
6081, 5 8
6093, 2 8
6160, 3 8
6432, 3 8
6555, 7 9
6571, 7 9
6577, 7 2
6625, 4 8
6641, 5 8
6651, 5 8
7372, 5 3
7454, 5 8
7542, 5 8
8072, 5 3
8095, 6 8
8183, 5 8
8279, 6 8
8294, 5 8
8406, 4 9
8410, 6 8
8426, 4 9
8509, 1 8
9031, 7 2
9214, 4 9
9426, 6 4
9446, 6 4
9534, 7 9
9544, 7 9
9655, 3 2
9738, 6 4
9764, 4 8
9792, 4 9
9832, 2 8
9901, 4 9


## 他

In [82]:
## まとめ用
import re

# Labels looked up in the text printed by %time / %%time.
TIME_GET = ["user", "sys", "total", "Wall time"]

# Seconds per unit suffix. Covers every suffix %time can print: sub-minute
# values use s / ms / us (or a micro sign) / ns, longer ones "1h 2min 3s".
_UNIT_SECONDS = {
    "ns": 1e-9,
    "us": 1e-6,
    "\u00b5s": 1e-6,  # MICRO SIGN
    "\u03bcs": 1e-6,  # GREEK SMALL LETTER MU (looks identical)
    "ms": 1e-3,
    "s": 1.0,
    "min": 60.0,
    "h": 3600.0,
    "d": 86400.0,
}

# No suffix is a prefix of another one, so the alternation order is irrelevant.
_UNIT_ALTERNATION = "|".join(re.escape(unit) for unit in _UNIT_SECONDS)
# A single "<number><optional space><unit>" token ...
_TOKEN = r"[0-9.]+\s?(?:" + _UNIT_ALTERNATION + r")"
_TOKEN_RE = re.compile(r"([0-9.]+)\s?(" + _UNIT_ALTERNATION + r")")
# ... and a run of such tokens on one line, e.g. "36.4 s" or "1min 5s".
_DURATION = r"((?:" + _TOKEN + r"[ \t]*)+)"


def get_sec(value, unit):
    """Convert a number and its unit suffix to seconds.

    Args:
        value: Number (or numeric string) as printed by %time.
        unit: One of the keys of ``_UNIT_SECONDS`` (e.g. "s", "ms", "min").

    Returns:
        The duration in seconds as a float.

    Raises:
        ValueError: If ``value`` is not numeric or ``unit`` is unknown
            (unknown units used to be silently treated as seconds).
    """
    value = float(value)
    try:
        scale = _UNIT_SECONDS[unit]
    except KeyError:
        raise ValueError(f"unknown time unit: {unit!r}") from None
    return value * scale


def time_formatter_sec(s):
    """Extract the timings named in ``TIME_GET`` from %time output.

    Args:
        s: Text as printed by %time / %%time, e.g.
            "CPU times: user 971 ms, sys: 98.9 ms, total: 1.07 s\\nWall time: 6.26 s".

    Returns:
        Dict mapping each ``TIME_GET`` label to its duration in seconds.

    Raises:
        ValueError: If a label has no duration after it on the same line
            (this used to surface as an AttributeError on ``None``).
    """
    res = {}
    for tg in TIME_GET:
        # First duration that follows the label on the same line.
        m = re.search(re.escape(tg) + r".*?" + _DURATION, s)
        if m is None:
            raise ValueError(f"no {tg!r} timing found in {s!r}")
        # Multi-part durations such as "1min 5s" are summed.
        res[tg] = sum(get_sec(num, unit) for num, unit in _TOKEN_RE.findall(m.group(1)))
    return res

In [85]:
# %%time output pasted by hand from the two timed cells.
# NOTE(review): the Python figures differ from the recorded output of the
# Python inference cell (35.7 s), so they come from a different run.
python_result = """
CPU times: user 36.4 s, sys: 15.4 ms, total: 36.4 s
Wall time: 36.5 s
"""

rtl_result = """
CPU times: user 971 ms, sys: 98.9 ms, total: 1.07 s
Wall time: 6.26 s
"""

python_time, rtl_time = (
    time_formatter_sec(text) for text in (python_result, rtl_result)
)

# One markdown-table row per label: Python seconds, HLS seconds, speed-up ratio.
for tg in TIME_GET:
    py_sec, hls_sec = python_time[tg], rtl_time[tg]
    print(f"{tg} | {py_sec:.4f} | {hls_sec:.4f} |  {py_sec / hls_sec:.1f}")

user | 36.4000 | 0.9710 |  37.5
sys | 0.0154 | 0.0989 |  0.2
total | 36.4000 | 1.0700 |  34.0
Wall time | 36.5000 | 6.2600 |  5.8


## まとめ

| | Python版 | HLS版 | 倍率
-- | -- | -- | --
user | 36.4000 | 0.9710 |  37.5
sys | 0.0154 | 0.0989 |  0.2
total | 36.4000 | 1.0700 |  34.0
Wall time | 36.5000 | 6.2600 |  5.8

CPU時間(total)ではHLS版が約30倍高速。ただし実時間(Wall time)では約6倍にとどまる。(測定ごとにばらつきアリ)