## Python版

オリジナルのリポジトリ(oreilly-japan/deep-learning-from-scratch)をgit submoduleとして追加し、  
`neuralnet_mnist_int.py`をこのノートブックと同じディレクトリにコピーしておく。

```
git submodule add https://github.com/oreilly-japan/deep-learning-from-scratch.git
```

In [1]:
import neuralnet_mnist_int
import numpy as np

In [2]:
# neuralnet_mnist_int.main()

In [3]:
# x_test, y_test = neuralnet_mnist_int.get_data()

In [4]:
# Mode string of the copied neuralnet_mnist_int module (presumably selects its
# integer-arithmetic path); the same mode is passed to get_data, init_network
# and predict so all three agree.
mode = "INT_MODE"
x_test, y_test = neuralnet_mnist_int.get_data(mode=mode)
network = neuralnet_mnist_int.init_network(mode=mode)
# Predicted class per test image. Calling the np.ndarray constructor directly
# returns uninitialised memory, so pre-fill with 0xFF instead: it is not a valid
# MNIST class (0-9) and matches the sentinel used for the FPGA output_buf, so any
# slot the inference loop never writes is detectable.
py_y_test = np.full(len(x_test), 0xFF, dtype=np.uint8)


In [5]:
%%time
## Python版の結果
for i in range(len(x_test)):
    y,_ = neuralnet_mnist_int.predict(network, x_test[i], mode=mode)
    py_y_test[i] = np.argmax(y) # 最も確率の高い要素のインデックスを取得

CPU times: user 35.8 s, sys: 37.8 ms, total: 35.8 s
Wall time: 35.9 s


In [6]:
py_y_test

array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

## FPGA版 全データ

In [41]:
## FPGA Load: program the bitstream and grab the handles used below
from pynq import Overlay

BITSTREAM_PATH = "/home/xilinx/pynq/overlays/my_design/test_mnist_wrapper.bit"
OL = Overlay(BITSTREAM_PATH)
# NOTE(review): newer PYNQ releases already program the FPGA inside Overlay(...)
# (download=True by default), which would make this explicit call a second
# download -- confirm against the PYNQ version installed on the board.
OL.download()
# In this design the PS7 attribute yields the Xlnk allocator (see printed output);
# it is used below for CMA buffer allocation and copies.
XLNK = OL.processing_system7_0

## show IPs
print(OL.ip_dict.keys(), XLNK)

dict_keys(['axi_dma', 'axi_gpio_0']) <pynq.xlnk.Xlnk object at 0xafd6f950>


In [42]:
# Load the MNIST test set through the same helper module as the Python section
# (re-imported here so the FPGA section can be run on its own).
# Alternative data source (pip `mnist` package), kept for reference:
# import mnist
# x_test = mnist.test_images()
# y_test = mnist.test_labels()
# x_test = x_test.reshape(10000, 28*28)

import neuralnet_mnist_int
import numpy as np

# NOTE(review): called without mode=..., unlike the Python section -- confirm the
# default mode yields the same pixel data as mode="INT_MODE".
x_test, y_test = neuralnet_mnist_int.get_data()
# The FPGA consumes one byte per pixel. astype(np.uint8) silently truncates
# fractions and wraps out-of-range values, so verify the cast is lossless
# (it would fail e.g. if get_data() returned normalised floats in [0, 1]).
x_test_u8 = x_test.astype(np.uint8)
assert np.array_equal(x_test_u8, x_test), "x_test is not representable as uint8 pixels"
x_test = x_test_u8

In [43]:
# Sanity check before the DMA copy: expect an ndarray of shape (N, 28*28) with dtype uint8.
type(x_test), x_test.shape, x_test.dtype

(numpy.ndarray, (10000, 784), dtype('uint8'))

In [44]:
## Allocate Memory: CMA buffers whose physical addresses are handed to the DMA
PIXELS_PER_IMAGE = 28 * 28
IMAGE_NUM = x_test.shape[0]
# IMAGE_NUM = 30   # small subset for quick debugging
print(IMAGE_NUM)
# input: one uint8 per pixel for every image
input_buf = XLNK.cma_array([PIXELS_PER_IMAGE * IMAGE_NUM], np.uint8)
print(hex(input_buf.physical_address))
# output: one uint8 (predicted class) per image
output_buf = XLNK.cma_array([1 * IMAGE_NUM], np.uint8)
print(hex(output_buf.physical_address))

10000
0x18100000
0x1804c000


In [45]:
## Pre-fill output_buf with the 0xFF sentinel (no valid class is 255) so any slot
## the DMA never writes is detectable afterwards; then show the first entry.
output_buf[:IMAGE_NUM] = 0xFF
output_buf[0]

255

In [46]:
# %%time
# adr  = 0
# for i in range(0, IMAGE_NUM):
#     for t in test_x[i]:
#         input_buf[adr] = t
#         adr += 1
#     print(f"\r i={i}", end='')
# print()

In [47]:
## Write input_buf (DDR): bulk-copy all flattened images into the CMA buffer
n_input_bytes = 28 * 28 * IMAGE_NUM  # one byte per uint8 pixel
XLNK.cma_memcopy(input_buf, x_test, n_input_bytes)

In [48]:
# N = 0
# for i in input_buf[(28*28)*N:(28*28)*(N+1)]:
#     print(f"{i:02X}")

In [49]:
def wait_dma(max_polls=100, dma=None):
    """Poll the AXI DMA until its S2MM (stream-to-memory) channel reports Idle.

    Parameters
    ----------
    max_polls : int
        Maximum number of status reads before giving up (default 100, the
        limit that used to be hard-coded).
    dma : object, optional
        DMA driver exposing ``register_map.S2MM_DMASR.Idle``. Defaults to
        ``OL.axi_dma`` from the overlay loaded earlier in the notebook.

    Returns
    -------
    bool
        True once Idle was observed, False if the poll budget ran out.
        Previously a timeout was silently ignored, so the caller could read
        output_buf while it still held the 0xFF sentinel; now a warning is
        printed and False is returned so the caller can check.
    """
    if dma is None:
        dma = OL.axi_dma
    polls = 0
    while polls < max_polls:
        if dma.register_map.S2MM_DMASR.Idle:
            print()
            return True
        polls += 1
        # '\r' keeps the progress counter on a single output line
        print(f"\rWait for Idle: {polls}", end='')
    print()
    print(f"WARNING: S2MM channel not Idle after {max_polls} polls")
    return False


In [60]:
%%time
## DMA Control: launch one MM2S (DDR -> stream, presumably into the MNIST IP) and
## one S2MM (stream -> DDR) transfer covering all IMAGE_NUM images, in simple
## (non scatter-gather) mode -- the status readback shows SGIncld=0.
## Register write order matters: per Xilinx PG021, writing a channel's LENGTH
## register is what starts its transfer, so LENGTH is written last.

## Stop: clear DMACR (Run/Stop is bit 0) on both channels
OL.axi_dma.register_map.MM2S_DMACR = 0x0
OL.axi_dma.register_map.S2MM_DMACR = 0x0

## Run: set the Run/Stop bit on both channels
OL.axi_dma.register_map.MM2S_DMACR = 0x1
OL.axi_dma.register_map.S2MM_DMACR = 0x1

## Address: physical addresses of the CMA buffers allocated earlier
OL.axi_dma.register_map.MM2S_SA = input_buf.physical_address
OL.axi_dma.register_map.S2MM_DA = output_buf.physical_address

## Size (bytes): 28*28 uint8 pixels per image in, 1 uint8 class per image out.
## NOTE(review): 28*28*10000 = 0x77A100 bytes needs a buffer-length register of
## at least 23 bits (the AXI DMA IP default is 14 bits) -- confirm the block
## design was built with a wide enough length register.
## Smaller-batch debug values:
# OL.axi_dma.register_map.MM2S_LENGTH = 28*28*20
# OL.axi_dma.register_map.S2MM_LENGTH = 1*20
OL.axi_dma.register_map.MM2S_LENGTH = 28*28*IMAGE_NUM
OL.axi_dma.register_map.S2MM_LENGTH = 1*IMAGE_NUM

## Poll the S2MM (result) channel for Idle; %%time thus covers the register
## setup plus this wait.
wait_dma()

Wait for Idle: 1Wait for Idle: 2Wait for Idle: 3
CPU times: user 94.8 ms, sys: 28.4 ms, total: 123 ms
Wall time: 111 ms


In [51]:
## DMA MM2S / S2MM Status: after a completed transfer both should show Idle=1,
## IOC_Irq=1 and no *Err flags set.
OL.axi_dma.register_map.MM2S_DMASR, OL.axi_dma.register_map.S2MM_DMASR 

(Register(Halted=0, Idle=1, SGIncld=0, DMAIntErr=0, DMASlvErr=0, DMADecErr=0, SGIntErr=0, SGSlvErr=0, SGDecErr=0, IOC_Irq=1, Dly_Irq=0, Err_Irq=0, IRQThresholdSts=0, IRQDelaySts=0),
 Register(Halted=0, Idle=1, SGIncld=0, DMAIntErr=0, DMASlvErr=0, DMADecErr=0, SGIntErr=0, SGSlvErr=0, SGDecErr=0, IOC_Irq=1, Dly_Irq=0, Err_Irq=0, IRQThresholdSts=0, IRQDelaySts=0))

In [52]:
# Debug check: read MM2S_LENGTH back and compare it with the expected byte count
# for IMAGE_NUM and for the full 10000-image set (the 28*28*21 value is presumably
# left over from an earlier small-batch experiment).
hex(OL.axi_dma.register_map.MM2S_LENGTH), hex(28*28*IMAGE_NUM), hex(28*28*21), hex(28*28*10000)

('0x77a100', '0x77a100', '0x4050', '0x77a100')

In [53]:
## Accuracy of the FPGA predictions against the MNIST test labels
N = IMAGE_NUM
# Compare only as many entries as both sides have (same as zip() truncation).
n_compared = min(len(y_test), len(output_buf))
ok = int(np.count_nonzero(
    np.asarray(y_test[:n_compared]) == np.asarray(output_buf[:n_compared])
))
print(ok / N)

0.9325


In [54]:
## List every image where the Python reference and the FPGA disagree
## (prints nothing when both agree on all compared images).
n_compared = min(len(py_y_test), len(output_buf))
mismatch_idx = np.flatnonzero(
    np.asarray(py_y_test[:n_compared]) != np.asarray(output_buf[:n_compared])
)
for i in mismatch_idx:
    print(f"{i}, {py_y_test[i]} {output_buf[i]}")

## その他 (実行時間のまとめ用)

In [37]:
## Helpers for the summary table: parse IPython %%time output into seconds
import re

TIME_GET = ["user", "sys", "total", "Wall time"]

# Seconds per unit for every suffix IPython's %time / %%time prints:
# s / ms / µs ("us" when the console cannot encode µ) / ns below one minute,
# and "1min 5s"-style d / h / min / s combinations from one minute upwards.
_UNIT_TO_SEC = {
    "d": 86400.0,
    "h": 3600.0,
    "min": 60.0,
    "s": 1.0,
    "ms": 1e-3,
    "µs": 1e-6,   # U+00B5 MICRO SIGN
    "μs": 1e-6,   # U+03BC GREEK SMALL LETTER MU
    "us": 1e-6,
    "ns": 1e-9,
}

# One number followed by its unit, e.g. "37.8 ms", "1min", "5s", "1e+03 ms".
# "min" and "ms" are listed before the bare "s" so the longest suffix wins.
_QUANTITY_RE = re.compile(r"([0-9.]+(?:e[+-]?[0-9]+)?)\s*(min|ms|µs|μs|us|ns|d|h|s)")


def get_sec(value, unit):
    """Convert a number printed by %time with the given unit suffix to seconds.

    Parameters
    ----------
    value : str or float
        Numeric part, e.g. "37.8".
    unit : str
        Unit suffix, e.g. "ms"; one of the keys of _UNIT_TO_SEC.

    Returns
    -------
    float
        The duration in seconds.

    Raises
    ------
    ValueError
        If the unit is unknown. Previously any unit other than "µs"/"ms"
        (including "ns" and "us") was silently treated as seconds.
    """
    try:
        factor = _UNIT_TO_SEC[unit]
    except KeyError:
        raise ValueError(f"unknown time unit: {unit!r}") from None
    return float(value) * factor


def time_formatter_sec(s):
    """Parse the text printed by %%time into a dict of durations in seconds.

    Parameters
    ----------
    s : str
        Output such as
        "CPU times: user 191 ms, sys: 50.7 ms, total: 242 ms\\nWall time: 246 ms".

    Returns
    -------
    dict
        Maps each entry of TIME_GET ("user", "sys", "total", "Wall time") to
        seconds. Compound values such as "1min 5s" are summed.

    Raises
    ------
    ValueError
        If a field is missing or its value cannot be parsed (previously this
        surfaced as an AttributeError on a None match).
    """
    res = {}
    for tg in TIME_GET:
        # Field name, optional colon, then its value up to the next comma or
        # end of line. A raw f-string avoids the invalid "\s" escape warning.
        m = re.search(rf"{re.escape(tg)}:?\s+([^,\n]+)", s)
        if m is None:
            raise ValueError(f"{tg!r} not found in timing text")
        parts = _QUANTITY_RE.findall(m.group(1))
        if not parts:
            raise ValueError(f"cannot parse {tg!r} value: {m.group(1)!r}")
        res[tg] = sum(get_sec(v, u) for v, u in parts)
    return res

In [66]:
## Timings copied from the %%time outputs of each implementation. They are
## single runs and vary between runs: the Python cell earlier in this notebook
## reported 35.9 s wall time, while the string below records 36.5 s.
## rtl_para1_result presumably comes from a separate build/run not shown here.
python_result = """
CPU times: user 36.4 s, sys: 15.4 ms, total: 36.4 s
Wall time: 36.5 s
"""

rtl_para1_result = """
CPU times: user 191 ms, sys: 50.7 ms, total: 242 ms
Wall time: 246 ms
"""

rtl_result = """
CPU times: user 94.8 ms, sys: 28.4 ms, total: 123 ms
Wall time: 111 ms
"""

python_time = time_formatter_sec(python_result)
rtl_para1_time = time_formatter_sec(rtl_para1_result)
rtl_time = time_formatter_sec(rtl_result)

# Print a complete Markdown table (header, separator and rows with leading and
# trailing pipes) that can be pasted directly into the summary cell; the bare
# rows printed previously produced a malformed table. Values are in seconds;
# 倍率 (ratio) = Python time / RTL time.
print("| | Python版 | RTL(para1)版 | RTL版 | 倍率 (Python版 / RTL版) |")
print("| -- | --: | --: | --: | --: |")
for tg in TIME_GET:
    print(
        f"| {tg} | {python_time[tg]:.4f} | {rtl_para1_time[tg]:.4f} | "
        f"{rtl_time[tg]:.4f} | {python_time[tg] / rtl_time[tg]:.1f} |"
    )

user | 36.4000 | 0.1910 | 0.0948 |  384.0
sys | 0.0154 | 0.0507 | 0.0284 |  0.5
total | 36.4000 | 0.2420 | 0.1230 |  295.9
Wall time | 36.5000 | 0.2460 | 0.1110 |  328.8


## まとめ

| | Python版 | RTL(para1)版 | RTL版 | 倍率 (Python版 / RTL版) |
| -- | --: | --: | --: | --: |
| user | 36.4000 | 0.1910 | 0.0948 | 384.0 |
| sys | 0.0154 | 0.0507 | 0.0284 | 0.5 |
| total | 36.4000 | 0.2420 | 0.1230 | 295.9 |
| Wall time | 36.5000 | 0.2460 | 0.1110 | 328.8 |

表の値の単位は秒、倍率は Python版 / RTL版。実行時間(total・Wall time)で約300倍の高速化となった(測定ごとにばらつきあり)。なお sys 時間だけは RTL版の方が大きい(倍率0.5)。