# Polyphony をインストール

In [1]:
!pip3 install polyphony 



## Python の bfloat16 用の mul を定義

In [2]:
from polyphony import testbench
from polyphony.typing import List, bit16, bit9, bit8, bit
# Multiply two bfloat16 numbers given as raw 16-bit patterns
# (1 sign bit, 8 exponent bits with bias 127, 7 fraction bits).
#
# x, w: bfloat16 bit patterns.
# Returns the bfloat16 bit pattern of x * w.
#
# Notes / limitations:
# - The product fraction is truncated, not rounded.
# - A +0 or -0 operand gives +0.
# - Exponent underflow is flushed to +0; exponent overflow saturates to a
#   signed infinity (0x7F80 / 0xFF80).
# - Subnormal, infinity and NaN inputs are not given special treatment.
def mul(x:bit16, w:bit16):
    # Both 0x0000 (+0) and 0x8000 (-0) are zero.
    if (x & 0x7FFF) == 0:
        return 0
    if (w & 0x7FFF) == 0:
        return 0
    x_e = (x >> 7) & 0xFF
    w_e = (w >> 7) & 0xFF
    x_n = (x & 0x7F)
    w_n = (w & 0x7F)
    # Restore the hidden leading 1; an 8-bit x 8-bit product fits in 16 bits.
    new_n:bit16 = ((x_n | 0x80) * (w_n | 0x80))
    # Sum of the two biased exponents. The bias is removed only after the
    # range checks below, so no intermediate value ever goes negative.
    sum_e = x_e + w_e
    if new_n & 0x8000:
        # Significand product is in [2, 4): renormalize and bump the exponent.
        new_n >>= 8
        sum_e += 1
    else:
        new_n >>= 7
    new_n &= 0x7F
    s = (x & 0x8000) ^ (w & 0x8000)

    if sum_e <= 127:
        # Result exponent would be <= 0: flush to zero.
        return 0
    if sum_e >= 127 + 0xFF:
        # Result exponent would be >= 255: saturate to signed infinity.
        return s | 0x7F80
    e:bit8 = sum_e - 127
    return s | (e << 7) | new_n

## Python の bfloat16 用の add と、その補助関数 sub_add (仮数部の加減算と正規化) を定義

In [3]:
# Add or subtract two aligned bfloat16 significands that share exponent e,
# then pack the normalized result into a bfloat16 bit pattern.
#
# x_sign, b_sign: sign bits (0 = positive, 1 = negative) of the operands.
# x, b: 8-bit significands; x includes the leading 1, while b may already
#       have been shifted right to align with x, so it can be < 0x80.
# e: biased exponent shared by both significands.
# Returns the bfloat16 bit pattern of the signed sum. An exact cancellation,
# or a result that would need an exponent <= 0, gives +0; a result exponent
# >= 255 saturates to a signed infinity.
def sub_add(x_sign:bit, x:bit8, b_sign:bit, b:bit8, e:bit8):
    if x_sign == b_sign:
        # Same sign: add magnitudes. A carry into bit 8 means the sum is in
        # [2, 4), so shift right once and increment the exponent.
        rv_n:bit9 = (x + b)
        add_e = 1 if rv_n & 0x100 else 0
        if add_e:
            rv_n >>= 1
        rv_s = 0x8000 if x_sign else 0x0000
        if e + add_e >= 0xFF:
            # Exponent overflow: saturate instead of spilling into the sign bit.
            return rv_s | 0x7F80
        return rv_s | ((e + add_e) << 7) | (rv_n & 0x7F)
    else:
        # Opposite signs: subtract the smaller magnitude from the larger one;
        # the result takes the sign of the larger one.
        if x < b:
            x_sign = b_sign
            x, b = b, x

        rv_n = x - b
        rv_sign = x_sign
        # Normalize by shifting left until the leading 1 reaches bit 7. A
        # non-zero 8-bit difference needs up to 7 shifts (a difference of 1
        # needs i == 7), so i must cover 0..7.
        for i in range(0, 8):
            if rv_n & 0x80:
                if e <= i:
                    # Exponent would drop to 0 or below: flush to zero.
                    return 0
                return (0x8000 if rv_sign else 0x0000) | ((e - i) << 7) | (rv_n & 0x7F)
            rv_n <<= 1

        # Only reached when x == b: exact cancellation.
        return 0

In [4]:
# Add two bfloat16 numbers given as raw 16-bit patterns and return the
# bfloat16 bit pattern of x + b.
#
# The operand with the larger exponent is kept as-is. The other operand's
# significand is shifted right by the exponent difference d and rounded
# half-up on the most significant bit shifted out. sub_add() then combines
# the two significands. If d > 8, the smaller operand is below half an ulp of
# the larger one, and the larger operand is returned unchanged.
# +0 and -0 are both treated as zero. Subnormal, infinity and NaN inputs are
# not given special treatment.
def add(x:bit16, b:bit16):
    # Both 0x0000 (+0) and 0x8000 (-0) are zero.
    if (x & 0x7FFF) == 0:
        return b
    if (b & 0x7FFF) == 0:
        return x
    x_sign = 1 if x & 0x8000 else 0
    b_sign = 1 if b & 0x8000 else 0
    x_e = (x >> 7) & 0xFF
    b_e = (b >> 7) & 0xFF

    # Make x the operand with the larger (or equal) exponent.
    if x_e < b_e:
        x, b = b, x
        x_e, b_e = b_e, x_e
        x_sign, b_sign = b_sign, x_sign

    x_n = (x & 0x7F)
    b_n = (b & 0x7F)

    d = x_e - b_e
    e = x_e
    if d > 8:
        return x

    if d == 0:
        rv = sub_add(x_sign, 0x80 | x_n, b_sign, 0x80 | b_n, e)
    else:
        # Align b to x's exponent. The rounding bit is bit (d-1) of b's FULL
        # significand, including the hidden leading 1. This matters for d == 8,
        # where the rounding bit is the hidden 1 itself. For 1 <= d <= 8 the
        # aligned value is therefore always >= 1.
        b_full = 0x80 | b_n
        new_b_n = (b_full >> d) + ((b_full >> (d-1)) & 1)
        rv = sub_add(x_sign, (0x80 | x_n), b_sign, new_b_n, e)

    return rv

## テスト(ベンチ) プログラムの定義

In [5]:
# Testbench: runs mul/add on fixed bfloat16 bit patterns, prints the operands
# and results, and asserts the results equal the expected bit patterns.
@testbench
def test():
    # Case 1: x * w + b, computed as mul followed by add.
    op_x = 16209
    op_w = 48824
    op_b = 16036
    product = mul(op_x, op_w)
    result = add(product, op_b)
    print(result, product, '<=', op_x, '*', op_w, '+', op_b)
    assert result == 15584
    assert product == 48790

    # Case 2: the same addition, starting from the product's bit pattern.
    lhs = 48790
    rhs = 16036
    total = add(lhs, rhs)
    print(total, '<=', lhs, '+', rhs)
    assert total == 15584

    # Case 3: one operand pair through both mul and add.
    p = 16457
    q = 16042
    pq_mul = mul(p, q)
    print(pq_mul)
    assert pq_mul == 16261
    pq_add = add(p, q)
    print(pq_add)
    assert pq_add == 16478

## テスト(ベンチ) プログラムの実行

In [6]:
# Execute the testbench in this notebook; its printed results appear below.
test()

15584 48790 <= 16209 * 48824 + 16036
15584 <= 48790 + 16036
16261
16478


上の出力は bfloat16 のビットパターンをそのまま 10 進整数で表示したものなので、値としてはいくら何でもわかりづらいですね。`float2bfloat` モジュールの `bfloat2float` を使って浮動小数点数で表示してみます。

In [7]:
from float2bfloat import float2bfloat, bfloat2float

# Recompute the first testbench case (x * w + b) and print each bfloat16 bit
# pattern as the real value returned by bfloat2float().
x, w, b = 16209, 48824, 16036
rv = mul(x, w)
rv2 = add(rv, b)
print(
    bfloat2float(rv2), bfloat2float(rv),
    '<=', bfloat2float(x), '*', bfloat2float(w), '+', bfloat2float(b),
)
assert rv2 == 15584
assert rv == 48790

0.02734375 -0.29296875 <= 0.81640625 * -0.359375 + 0.3203125
