In [1]:
from pynq import Overlay
from pynq import allocate
import numpy as np
import time

In [2]:
ol = Overlay('./sha_26.bit')
ol.ip_dict

{'axi_dma_end_len': {'type': 'xilinx.com:ip:axi_dma:7.1',
  'mem_id': 'S_AXI_LITE',
  'memtype': 'REGISTER',
  'gpio': {},
  'interrupts': {},
  'parameters': {'C_S_AXI_LITE_ADDR_WIDTH': '10',
   'C_S_AXI_LITE_DATA_WIDTH': '32',
   'C_DLYTMR_RESOLUTION': '125',
   'C_PRMRY_IS_ACLK_ASYNC': '1',
   'C_ENABLE_MULTI_CHANNEL': '0',
   'C_NUM_MM2S_CHANNELS': '1',
   'C_NUM_S2MM_CHANNELS': '1',
   'C_INCLUDE_SG': '0',
   'C_SG_INCLUDE_STSCNTRL_STRM': '0',
   'C_SG_USE_STSAPP_LENGTH': '0',
   'C_SG_LENGTH_WIDTH': '26',
   'C_M_AXI_SG_ADDR_WIDTH': '32',
   'C_M_AXI_SG_DATA_WIDTH': '32',
   'C_M_AXIS_MM2S_CNTRL_TDATA_WIDTH': '32',
   'C_S_AXIS_S2MM_STS_TDATA_WIDTH': '32',
   'C_MICRO_DMA': '0',
   'C_INCLUDE_MM2S': '1',
   'C_INCLUDE_MM2S_SF': '1',
   'C_MM2S_BURST_SIZE': '16',
   'C_M_AXI_MM2S_ADDR_WIDTH': '32',
   'C_M_AXI_MM2S_DATA_WIDTH': '32',
   'C_M_AXIS_MM2S_TDATA_WIDTH': '8',
   'C_INCLUDE_MM2S_DRE': '0',
   'C_INCLUDE_S2MM': '0',
   'C_INCLUDE_S2MM_SF': '1',
   'C_S2MM_BURST_SIZE': '16

In [3]:
# input streams
dma_len = ol.axi_dma_len
dma_msg = ol.axi_dma_msg
dma_end_len = ol.axi_dma_end_len

# output streams
dma_hash = ol.axi_dma_hash

In [4]:
hasher = ol.DUT_FUNC_0
hasher_reg = hasher.register_map

In [5]:
def sha_arr(arr):
    #start_time = time.time()
    arr.append(b'1')
    
    el = len(arr)
    offset = []
    
    end_len_buf = allocate(shape=(el + 1,), dtype=np.uint8)
    len_buf = allocate(shape=(el,), dtype=np.uint64)
    hash_buf = allocate(shape=(el * 32,), dtype=np.uint8)
    
    msg_len = 0
    
    for (k, a) in enumerate(arr):
        l = len(a)
        l_buf = int(l / 8)
        if l / 8 != l_buf:
            l_buf = l_buf + 1

        l_buf = l_buf * 8
        
        offset.append(msg_len)
        msg_len += l_buf
        
        len_buf[k] = l
        end_len_buf[k] = 0
    
    end_len_buf[k] = 1
    
    msg_buf = allocate(shape=(msg_len,), dtype=np.uint8)
    
    for (k, a) in enumerate(arr):
        msg_buf[offset[k]:offset[k] + len(a)] = bytearray(a)
    #print("pre: --- %s seconds ---" % (time.time() - start_time))    
    
    # start transfer
    #start_time = time.time()
    # send buffers
    dma_len.sendchannel.transfer(len_buf)
    dma_msg.sendchannel.transfer(msg_buf)
    dma_end_len.sendchannel.transfer(end_len_buf)
    
    # get results
    dma_hash.recvchannel.transfer(hash_buf)
    
    # start sha
    hasher_reg.CTRL.AP_START = 1
    
    # wait results
    dma_hash.recvchannel.wait()
    #print("hw: --- %s seconds ---" % (time.time() - start_time))
    
    #start_time = time.time()
    ret = []
    
    for i in range(el - 1):
        ret.append(bytearray(hash_buf[32 * (i + 1) : 32 * (i + 1) + 32]).hex())
    #print("post: --- %s seconds ---" % (time.time() - start_time))
            
    # clean up
    del msg_buf
    del end_len_buf
    del len_buf
    del hash_buf
    
    return ret

In [6]:
def sha_arr_print(arr):
    start_time = time.time()
    arr.append(b'1')
    
    el = len(arr)
    offset = []
    
    end_len_buf = allocate(shape=(el + 1,), dtype=np.uint8)
    len_buf = allocate(shape=(el,), dtype=np.uint64)
    hash_buf = allocate(shape=(el * 32,), dtype=np.uint8)
    
    msg_len = 0
    
    for (k, a) in enumerate(arr):
        l = len(a)
        l_buf = int(l / 8)
        if l / 8 != l_buf:
            l_buf = l_buf + 1

        l_buf = l_buf * 8
        
        offset.append(msg_len)
        msg_len += l_buf
        
        len_buf[k] = l
        end_len_buf[k] = 0
    
    #print("msg buff size: ", float(msg_len)/1024.0/1024.0)
    print(float(msg_len)/1024.0/1024.0, end = ',')
    
    end_len_buf[k] = 1
    
    msg_buf = allocate(shape=(msg_len,), dtype=np.uint8)
    
    for (k, a) in enumerate(arr):
        msg_buf[offset[k]:offset[k] + len(a)] = bytearray(a)
    #print("pre: --- %s seconds ---" % (time.time() - start_time))
    print((time.time() - start_time), end = ',')
    
    # start transfer
    
    # send buffers
    dma_len.sendchannel.transfer(len_buf)
    dma_msg.sendchannel.transfer(msg_buf)
    dma_end_len.sendchannel.transfer(end_len_buf)
    
    # get results
    dma_hash.recvchannel.transfer(hash_buf)
    
    start_time = time.time()
    # start sha
    hasher_reg.CTRL.AP_START = 1
    
    # wait results
    dma_hash.recvchannel.wait()
    #print("hw: --- %s seconds ---" % (time.time() - start_time))
    print((time.time() - start_time), end = ',')
    
    start_time = time.time()
    ret = []
    
    for i in range(el - 1):
        ret.append(bytearray(hash_buf[32 * (i + 1) : 32 * (i + 1) + 32]).hex())
    #print("post: --- %s seconds ---" % (time.time() - start_time))
    print((time.time() - start_time), end = ',')
            
    # clean up
    del msg_buf
    del end_len_buf
    del len_buf
    del hash_buf
    
    return ret

In [7]:
import hashlib

def SHA256_hardware(bstr):
    return sha_arr([bstr])[0]

def SHA256_software(bstr):
    return hashlib.sha256(bstr).hexdigest()

In [8]:
print(sha_arr([b'abcdef']))

['0000000000000000000000000000000000000000000000000000000000000000']


In [9]:
import csv

tests = []

with open('sha256_test.csv') as csvfile:
    csvreader = csv.reader(csvfile, delimiter=',')
    
    for row in csvreader:
        tests.append(row)

In [10]:
# test hw sha
for test in tests:
    hw = SHA256_hardware(bytes(test[0], encoding='utf-8'))
    assert (hw == test[1]), (test[0], hw, test[1])

print("All golden model tests passed")

All golden model tests passed


In [11]:
import secrets

sw_time = %timeit -n 1000 -r 5 -o SHA256_software(secrets.token_bytes(16))
hw_time = %timeit -n 1000 -r 5 -o SHA256_hardware(secrets.token_bytes(16))
print('Performance gain:', sw_time.average / hw_time.average) 

15.6 µs ± 3.66 µs per loop (mean ± std. dev. of 5 runs, 1,000 loops each)
2.25 ms ± 9.81 µs per loop (mean ± std. dev. of 5 runs, 1,000 loops each)
Performance gain: 0.00690682837790949


In [12]:
arr = [bytes(test[0], encoding='utf-8') for test in tests]

print('size,sw_time,buf_size,pre,hw_sha,post,hw_time')

for i in range(15):
    print(len(arr), end = ',')
    #print("software")
    start_time = time.time()
    for i in arr:
        SHA256_software(i)
    #print("--- %s seconds ---" % (time.time() - start_time))
    print((time.time() - start_time), end = ',')
    #print("hardware")
    start_time = time.time()
    sha_arr_print(arr)
    #print("--- %s seconds ---" % (time.time() - start_time))
    print((time.time() - start_time))
    arr = arr + arr

size,sw_time,buf_size,pre,hw_sha,post,hw_time
100,0.0007460117340087891,0.00373077392578125,0.00942373275756836,0.0002155303955078125,0.004403829574584961,0.01546788215637207
202,0.001438140869140625,0.00746917724609375,0.012952566146850586,0.00023865699768066406,0.008584260940551758,0.022915124893188477
406,0.002862691879272461,0.01494598388671875,0.022758960723876953,0.0004181861877441406,0.018067121505737305,0.0425112247467041
814,0.00562286376953125,0.02989959716796875,0.04359269142150879,0.0004477500915527344,0.03489422798156738,0.08034706115722656
1630,0.012357473373413086,0.05980682373046875,0.10672950744628906,0.0007140636444091797,0.07037472724914551,0.18120121955871582
3262,0.022431373596191406,0.11962127685546875,0.18264102935791016,0.0012962818145751953,0.14246487617492676,0.3308577537536621
6526,0.04581904411315918,0.23925018310546875,0.3552978038787842,0.002301454544067383,0.27952051162719727,0.6413440704345703
13054,0.09078764915466309,0.47850799560546875,0.6977498531341