In [None]:
"""
   Copyright 2022 shts

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
 """
 print()

In [None]:
import time

import numpy as np
from tqdm import tqdm

from pynq import Overlay
import pynq.lib.dma
from pynq import allocate

# Load the bitstream; the printed keys are the IP names the overlay reports.
OL = Overlay("./design_1.bit")
print(OL.ip_dict.keys())

# One DMA engine per stream: w, x and b are sent to the FMA unit (sendchannel),
# y is received back from it (recvchannel) -- axi_dma_0..3 in that order.
dma_w, dma_x, dma_b, dma_y = (getattr(OL, "axi_dma_{}".format(idx)) for idx in range(4))
# Register interface of the FMA core (its length and start registers are
# written by the benchmark cell).
FMA_UNIT = OL.FMA_Unit_0

# bus_width = 1024
# Width in bits of the IP's data bus; must match the bitstream loaded above.
bus_width = 64
# Element dtype written to / read back from the DMA buffers.
array_type = np.float32
data_bit_width = np.dtype(array_type).itemsize * 8

# Fixed-point format used when enable_fp2fix is True: `fraction` fractional
# bits, the remaining `integer` bits hold the sign and integer part.
fraction = 12
integer = data_bit_width - fraction
enable_fp2fix = False

# Value written to the FMA unit's loop-count register (offset 0x10) per run.
loops = int(1e+5)
# loops = int(1e+0)
# Number of timed repetitions of the benchmark.
test_times = 5

# Validate the configuration before any DMA buffers are allocated.
if bus_width % data_bit_width != 0:
    raise ValueError(
        "bus_width ({}) must be a multiple of data_bit_width ({})".format(
            bus_width, data_bit_width))
if enable_fp2fix:
    # Fixed-point data is stored as scaled integers; with a float dtype the
    # integer reference computation in the benchmark cell is meaningless.
    if not np.issubdtype(array_type, np.integer):
        raise ValueError(
            "enable_fp2fix requires an integer array_type, got {}".format(
                np.dtype(array_type)))
    if not 0 <= fraction < data_bit_width:
        raise ValueError(
            "fraction ({}) must be in [0, {})".format(fraction, data_bit_width))

# Each bus word carries bus_width // data_bit_width elements, `loops` words per run.
data_length = bus_width // data_bit_width * loops
print("data_length :", data_length)
print("data_bit_width :", data_bit_width)
print("data_bit_width :", data_bit_width)

In [None]:
# Buffers from pynq's allocate; these are the ones handed to the DMA
# engines in the benchmark cell (x, w, b are sent, y is received).
data_x, data_w, data_b, data_y = (
    allocate((data_length,), dtype=array_type) for _buf_idx in range(4)
)

# Report the size of each buffer.
for _buf_name, _buf in (("data_w", data_w), ("data_x", data_x),
                        ("data_b", data_b), ("data_y", data_y)):
    print("size of " + _buf_name + " :", _buf.nbytes, "Byte")

In [21]:
def fp2fix(array, enable=False, frac_bits=None, dtype=None):
    """Quantise a floating-point array to fixed point (scaled integers).

    Each element is multiplied by ``2 ** frac_bits`` and cast to ``dtype``
    (NumPy's float-to-integer cast truncates toward zero).

    Args:
        array: NumPy array to convert.
        enable: When falsy, ``array`` is returned unchanged (same object).
            Any truthy value (e.g. ``True`` or ``np.True_``) enables conversion.
        frac_bits: Number of fractional bits; defaults to the notebook's
            global ``fraction``.
        dtype: Target dtype; defaults to the notebook's global ``array_type``.

    Returns:
        The scaled and cast array, or ``array`` itself when disabled.
    """
    if not enable:
        return array
    # Globals are only consulted when the caller did not pass explicit values.
    if frac_bits is None:
        frac_bits = fraction
    if dtype is None:
        dtype = array_type
    return (array * (2 ** frac_bits)).astype(dtype)

def fix2fp(array, enable=False, frac_bits=None):
    """Convert fixed-point values (scaled integers) back to float64.

    Inverse of ``fp2fix``: each element is divided by ``2 ** frac_bits``.

    Args:
        array: NumPy array of fixed-point values.
        enable: When falsy, ``array`` is returned unchanged (same object).
            Any truthy value (e.g. ``True`` or ``np.True_``) enables conversion.
        frac_bits: Number of fractional bits; defaults to the notebook's
            global ``fraction``.

    Returns:
        A float64 array, or ``array`` itself when disabled.
    """
    if not enable:
        return array
    if frac_bits is None:
        frac_bits = fraction
    # Widen to float64 first so the division is not done in integer arithmetic.
    return array.astype(np.float64) / (2 ** frac_bits)

In [None]:
# FMA benchmark. On each of `test_times` runs:
#   1. fill data_w / data_x / data_b with random operands,
#   2. stream them through the FMA IP via DMA and time the run,
#   3. compare data_y against a NumPy reference computed from the same buffers.

# Register offsets / values written to FMA_UNIT.
# NOTE(review): 0x00 / 0x01 look like an HLS-style control register and its
# start bit, and 0x10 the loop-count argument -- confirm against the IP's map.
_FMA_REG_CTRL = 0x00
_FMA_START = 0x01
_FMA_REG_LOOPS = 0x10

_mean = 0
_scale = 2.5
# Fixed seed so that any reported mismatch can be reproduced.
_rng = np.random.default_rng(0)


def _reference_fma(w, x, b, use_fixed, frac_bits, dtype):
    """Return the expected IP output ``w * x + b`` computed with NumPy.

    In fixed-point mode the operands are signed scaled integers with
    ``frac_bits`` fractional bits. The product is formed in int64 so it
    cannot overflow, rescaled with an arithmetic right shift (correct for
    negative values), cast back to ``dtype``, and ``b`` is then added in
    ``dtype`` arithmetic.
    """
    if use_fixed:
        product = (w.astype(np.int64) * x.astype(np.int64)) >> frac_bits
        return product.astype(dtype) + b
    # NOTE(review): this rounds after the multiply and again after the add;
    # a fused multiply-add rounds once, so the exact comparison below may flag
    # last-bit differences -- confirm what the IP implements.
    return w * x + b


time_record = []
# CPU-side reference output; no DMA engine touches it, so plain NumPy memory
# is enough. data_y is reused from the allocation cell instead of re-allocated.
true_y = np.empty((data_length,), dtype=array_type)

with tqdm(total=test_times) as bar:
    for test_idx in range(test_times):
        # Draw operands in float64 and quantise once via fp2fix. Casting to
        # array_type first would drop the fractional part before fixed-point
        # scaling; in float mode the buffer assignment performs the same cast.
        for buf in (data_w, data_x, data_b):
            samples = _rng.normal(_mean, _scale, data_length)
            buf[:] = fp2fix(samples, enable=enable_fp2fix)

        true_y[:] = _reference_fma(data_w, data_x, data_b,
                                   enable_fp2fix, fraction, array_type)
        # Clear the output so stale values from an earlier run cannot pass.
        data_y[:] = 0

        # Queue every DMA transfer and set the loop count before starting the IP.
        dma_w.sendchannel.transfer(data_w)
        dma_x.sendchannel.transfer(data_x)
        dma_y.recvchannel.transfer(data_y)
        dma_b.sendchannel.transfer(data_b)
        FMA_UNIT.write(_FMA_REG_LOOPS, loops)

        # Timed section: from writing the start register until the output
        # DMA reports completion.
        start_ip = time.perf_counter()
        FMA_UNIT.write(_FMA_REG_CTRL, _FMA_START)
        dma_w.sendchannel.wait()
        dma_x.sendchannel.wait()
        dma_b.sendchannel.wait()
        dma_y.recvchannel.wait()
        end_ip = time.perf_counter()

        time_record.append({"start_ip": start_ip,
                            "end_ip": end_ip})

        # Vectorised check instead of a Python loop over all data_length
        # elements; `!=` still flags NaNs like the element-wise comparison.
        for cnt in np.flatnonzero(np.asarray(data_y) != true_y):
            print("loops : {}, cnt : {}, data_y : {}, true_y : {}".format(
                test_idx, cnt, data_y[cnt], true_y[cnt]))
        bar.update(1)
    

In [None]:
# Per-run report: wall-clock IP time and the resulting throughput, counting
# each element's multiply-add as 2 floating-point operations.
for run_idx, record in enumerate(time_record):
    elapsed = record["end_ip"] - record["start_ip"]
    print(f"iter : {run_idx}")
    print(f"ip_run_time : {elapsed * 1000}[ms]")
    print(f"{1.0 / elapsed * data_length * 2}[flops]")
    print()

In [None]:
"""
iter : 0
ip_run_time : 2.778314000352111[ms]
143972207.58679754[flops]

iter : 1
ip_run_time : 2.83615100033785[ms]
141036214.20451555[flops]

iter : 2
ip_run_time : 2.7944519997618045[ms]
143140766.07295296[flops]

iter : 3
ip_run_time : 2.7934870004173717[ms]
143190213.50027275[flops]

iter : 4
ip_run_time : 2.83070699970267[ms]
141307454.3009979[flops]

fp32_bus64_loops_1e+5
"""

"""
iter : 0
ip_run_time : 2.8353319994494086[ms]
282153906.5461651[flops]

iter : 1
ip_run_time : 2.8268429996387567[ms]
283001213.7576202[flops]

iter : 2
ip_run_time : 2.8600340001503355[ms]
279716954.3991256[flops]

iter : 3
ip_run_time : 2.8105040000809822[ms]
284646454.862526[flops]

iter : 4
ip_run_time : 2.879313999983424[ms]
277843958.6667538[flops]
fix16_bus64_loops_1e+5
"""
"""
iter : 0
ip_run_time : 4.550431000097888[ms]
175807522.404535[flops]

iter : 1
ip_run_time : 4.5517319995269645[ms]
175757272.19509837[flops]

iter : 2
ip_run_time : 4.520722999586724[ms]
176962844.23379502[flops]

iter : 3
ip_run_time : 4.629034000572574[ms]
172822234.59604025[flops]

iter : 4
ip_run_time : 4.514643999755208[ms]
177201125.94556236[flops]

fp32_bus128_loops_1e+5
"""
"""
iter : 0
ip_run_time : 4.60764000126801[ms]
347249350.98221314[flops]

iter : 1
ip_run_time : 4.545516998405219[ms]
351995163.7099487[flops]

iter : 2
ip_run_time : 4.60070699955395[ms]
347772635.84816945[flops]

iter : 3
ip_run_time : 4.548297998553608[ms]
351779940.6522642[flops]

iter : 4
ip_run_time : 4.618377999577206[ms]
346441975.98084736[flops]
fix_bus128_loops_1e+5
"""
"""
iter : 0
ip_run_time : 7.795691999490373[ms]
205241561.63488716[flops]

iter : 1
ip_run_time : 7.826987999578705[ms]
204420908.7948163[flops]

iter : 2
ip_run_time : 7.836375001716078[ms]
204176037.9830748[flops]

iter : 3
ip_run_time : 7.789170000251033[ms]
205413413.74606463[flops]

iter : 4
ip_run_time : 7.835348000298836[ms]
204202799.918903[flops]
fp32_bus256_loops_1e+5
"""
"""
iter : 0
ip_run_time : 14.90842199996223[ms]
214643776.51827317[flops]

iter : 1
ip_run_time : 14.908306999927845[ms]
214645432.2422719[flops]

iter : 2
ip_run_time : 14.875104999987343[ms]
215124531.89424363[flops]

iter : 3
ip_run_time : 14.865406000012626[ms]
215264890.84773615[flops]

iter : 4
ip_run_time : 14.92853299998842[ms]
214354618.769472[flops]
fp32_bus512_loops_1e+5
"""

# The strings above are recorded benchmark output kept as notes; ending the
# cell with print() keeps Jupyter from echoing the last of them as the output.
print()