# In [19]:
import struct
import math
import numpy as np

def decimal_to_binary(decimal_number, max_iterations=128):
    """Expand a real number into binary digit lists plus a sign flag.

    Args:
        decimal_number: Number to expand (int, float or numpy scalar).
        max_iterations: Upper bound on the number of fraction bits produced.
            Expansion stops earlier once the remaining fraction is exactly 0.

    Returns:
        A tuple ``(integer_bits, fraction_bits, sign)``:

        - ``integer_bits``: '0'/'1' characters of ``int(abs(decimal_number))``,
          most significant first (``['0']`` when the integer part is zero).
        - ``fraction_bits``: bits after the binary point, most significant
          first, truncated (not rounded) after ``max_iterations`` bits.
        - ``sign``: 1 if ``decimal_number < 0`` else 0 (so -0.0 yields 0).

    Infinite input raises OverflowError and NaN raises ValueError, both
    coming from ``int()``.
    """
    sign = 1 if decimal_number < 0 else 0
    magnitude = abs(decimal_number)
    whole = int(magnitude)
    remainder = magnitude - whole

    fraction_bits = []
    produced = 0
    # Doubling and removing the integer digit are exact on binary floats,
    # so every pass peels off one true bit of the fraction.
    while remainder != 0 and produced < max_iterations:
        remainder *= 2
        digit = int(remainder)
        fraction_bits.append(str(digit))
        remainder -= digit
        produced += 1

    return list(bin(whole)[2:]), fraction_bits, sign

def convert_to_scientific_notation(integer_part, fractional_part):
    """Normalise binary digits to ``1.xxx * 2**exponent`` form.

    Args:
        integer_part: '0'/'1' digits left of the binary point, MSB first.
        fractional_part: '0'/'1' digits right of the binary point, MSB first.

    Returns:
        ``(exponent, mantissa_bits)``. ``exponent`` is the power of two of the
        most significant set bit. ``mantissa_bits`` lists every digit after
        that bit (the implicit leading 1 is dropped).

        When no digit is '1' the value is zero, and ``(-17000, ['0'])`` is
        returned. -17000 is a sentinel far below the 8- and 128-bit exponent
        ranges, so the encoders take their underflow path.
    """
    digits = ''.join(integer_part) + ''.join(fractional_part)
    leading_one = digits.find('1')
    if leading_one == -1:
        return -17000, ['0']

    # The first integer digit has weight 2**(len(integer_part) - 1); every
    # position further right halves it.
    exponent = len(integer_part) - 1 - leading_one
    return exponent, list(digits[leading_one + 1:])

def get_bits_from_list_8bit(integer_part,fractional_part, sign):
    """Encode a binary-expanded number into the custom 8-bit float layout.

    Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits with an
    implicit leading 1. There is no zero, subnormal, infinity or NaN
    encoding:

    - A biased exponent below 0 (including exact zero) yields an all-zero
      exponent/mantissa field.
    - A biased exponent above 15 saturates to an all-ones field.

    The 3-bit mantissa is rounded to nearest, ties to even. The decision uses
    the guard bit plus every remaining (sticky) bit. A round-up that carries
    out of the mantissa bumps the exponent, which may then saturate.

    Args:
        integer_part: '0'/'1' digits of the integer part, MSB first.
        fractional_part: '0'/'1' digits of the fraction, MSB first.
        sign: 0 for a non-negative value, anything else for negative.

    Returns:
        A list of 8 '0'/'1' characters, sign bit first.
    """
    exponent_part, mantissa_bits = convert_to_scientific_notation(integer_part, fractional_part)
    sign_bit = '0' if sign == 0 else '1'
    saturated = [sign_bit] + ['1'] * 7
    exponent = exponent_part + 7

    if exponent > 15:
        return saturated
    if exponent < 0:
        # Underflow; also the exact-zero case, whose exponent is -17000.
        return [sign_bit] + ['0'] * 7

    # Pad so the three kept bits and the guard bit always exist.
    padded = list(mantissa_bits) + ['0'] * 4
    mantissa = int(''.join(padded[:3]), 2)
    guard = padded[3] == '1'
    sticky = '1' in padded[4:]

    # Round to nearest, ties to even: round up when above the halfway point
    # (guard and sticky) or exactly halfway with an odd kept mantissa.
    if guard and (sticky or mantissa & 1):
        mantissa += 1
        if mantissa == 8:
            # 1.111 rounded up to 10.000: renormalise into the next binade.
            mantissa = 0
            exponent += 1
            if exponent > 15:
                return saturated

    return [sign_bit] + list(format(exponent, '04b')) + list(format(mantissa, '03b'))

def get_bits_from_list_128bit(integer_part,fractional_part, sign):
    """Encode a binary-expanded number into the custom 128-bit float layout.

    Layout: 1 sign bit, 15 exponent bits (bias 16383), 112 fraction bits
    with an implicit leading 1. Fraction bits beyond the 112th are
    truncated, not rounded. Subnormals, infinities and NaNs get no special
    treatment:

    - A biased exponent above 32767 saturates to an all-ones
      exponent/fraction.
    - A biased exponent below 0 (including exact zero) yields an all-zero
      exponent/fraction.

    Args:
        integer_part: '0'/'1' digits of the integer part, MSB first.
        fractional_part: '0'/'1' digits of the fraction, MSB first.
        sign: 0 for a non-negative value, anything else for negative.

    Returns:
        A list of 128 '0'/'1' characters, sign bit first.
    """
    exponent_part, mantissa_bits = convert_to_scientific_notation(integer_part, fractional_part)
    biased_exponent = exponent_part + 16383
    sign_bit = '0' if sign == 0 else '1'

    if biased_exponent > 32767:
        return [sign_bit] + ['1'] * 127
    if biased_exponent < 0:
        return [sign_bit] + ['0'] * 127

    exponent_bits = list(format(biased_exponent, '015b'))
    # Zero-pad short mantissas, then keep exactly 112 bits.
    fraction_bits = (list(mantissa_bits) + ['0'] * 112)[:112]
    return [sign_bit] + exponent_bits + fraction_bits

def get_bits_from_float(value, num_bits):
    """Return the bit pattern of ``value`` in a ``num_bits``-wide float format.

    Formats by width:

    - 16, 32, 64 bits: IEEE 754 half/single/double precision, packed
      big-endian with ``struct``.
    - 8, 128 bits: the custom layouts produced by get_bits_from_list_8bit
      and get_bits_from_list_128bit.

    Args:
        value: Number to encode.
        num_bits: Target width; one of 8, 16, 32, 64, 128.

    Returns:
        A list of '0'/'1' characters, most significant (sign) bit first.

    Raises:
        ValueError: If ``num_bits`` is not a supported width.
    """
    if num_bits == 8:
        integer_part, fractional_part, sign = decimal_to_binary(value)
        return get_bits_from_list_8bit(integer_part, fractional_part, sign)
    if num_bits == 128:
        # A binary64 value's lowest set bit is at or above 2**-1074, so 1074
        # fraction bits capture it exactly. The default limit of 128 would
        # flush magnitudes below 2**-128 to zero and drop low mantissa bits
        # of other small values.
        integer_part, fractional_part, sign = decimal_to_binary(value, max_iterations=1074)
        return get_bits_from_list_128bit(integer_part, fractional_part, sign)

    struct_formats = {16: '!e', 32: '!f', 64: '!d'}
    if num_bits not in struct_formats:
        raise ValueError("Unsupported number of bits")

    packed = struct.pack(struct_formats[num_bits], value)
    # Each byte becomes 8 zero-padded binary digits, most significant first.
    return [bit for byte in packed for bit in format(byte, '08b')]

def float_from_bits_8(bits):
    """Decode 8 '0'/'1' characters in the custom 1-4-3 float layout.

    Layout:

    - bit 0: sign;
    - bits 1-4: exponent with bias 7;
    - bits 5-7: mantissa, always read with an implicit leading 1.

    NOTE: the all-zero exponent is not treated as zero/subnormal, so all-zero
    bits decode to 2**-7 (0.0078125) rather than 0.

    Returns:
        ``(1.mmm)_2 * 2**(exponent - 7)`` with the sign applied. The result
        is an int when the scale ``2**(biased_exponent - 10)`` is a
        non-negative power, otherwise a float.
    """
    biased_exponent = int(''.join(bits[1:5]), 2)
    significand = int('1' + ''.join(bits[5:8]), 2)
    # The significand carries 3 fraction bits, hence the extra -3 in the scale.
    value = significand * (2 ** (biased_exponent - 7 - 3))
    return value * -1 if bits[0] == '1' else value

def float_from_bits_128(bits):
    """Decode 128 '0'/'1' characters in the custom 1-15-112 float layout.

    Layout:

    - bit 0: sign;
    - bits 1-15: exponent with bias 16383;
    - bits 16-127: fraction, always read with an implicit leading 1.

    The all-zero and all-ones exponent fields get no special meaning.

    Returns:
        ``(1.f)_2 * 2**(exponent - 16383)`` with the sign applied. The result
        is an int when the scale ``2**(biased_exponent - 16495)`` is a
        non-negative power, otherwise a float.
    """
    biased_exponent = int(''.join(bits[1:16]), 2)
    significand = int('1' + ''.join(bits[16:128]), 2)
    # The significand carries 112 fraction bits, hence the extra -112.
    value = significand * (2 ** (biased_exponent - 16383 - 112))
    return value * -1 if bits[0] == '1' else value

def float_from_bits(bits):
    """Decode a list of bits (most significant first) into a number.

    The list length selects the format:

    - 16, 32, 64 bits: IEEE 754 half/single/double precision, unpacked
      big-endian with ``struct``. Elements may be '0'/'1' strings or 0/1
      ints, since each one is passed through ``str``.
    - 8, 128 bits: the custom layouts handled by float_from_bits_8 and
      float_from_bits_128.

    Raises:
        ValueError: For any other length.
    """
    width = len(bits)
    if width == 8:
        return float_from_bits_8(bits)
    if width == 128:
        return float_from_bits_128(bits)

    struct_formats = {16: '!e', 32: '!f', 64: '!d'}
    if width not in struct_formats:
        raise ValueError("Unsupported number of bits")

    # Reassemble one byte per group of 8 bits, most significant bit first.
    packed = bytes(
        int(''.join(map(str, bits[offset:offset + 8])), 2)
        for offset in range(0, width, 8)
    )
    (value,) = struct.unpack(struct_formats[width], packed)
    return value

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))`` for a scalar ``x``.

    ``math.exp(-x)`` raises OverflowError for very negative ``x`` (below
    about -709.78). In that regime sigmoid(x) equals exp(x) to within double
    precision, so exp(x) is returned instead of raising. Every input that
    did not overflow is computed exactly as before.
    """
    try:
        return 1 / (1 + math.exp(-x))
    except OverflowError:
        return math.exp(x)

def tanh(z):
    """Hyperbolic tangent of a scalar or, element-wise, of an array.

    Delegates to ``np.tanh``. The textbook form
    ``(e**z - e**-z) / (e**z + e**-z)`` has two failures:

    - It cancels to 0 for tiny |z| (tanh(1e-20) came out as 0.0).
    - It yields NaN (inf/inf) once ``np.exp`` overflows, for |z| above
      about 709.78.
    """
    return np.tanh(z)

def swish(x):
    """Swish / SiLU activation: ``x`` gated by ``sigmoid(x)``.

    Scalar-only, because ``sigmoid`` uses ``math.exp``.
    """
    gate = sigmoid(x)
    return x * gate

def ELU(x, alpha = 0.1):
    """Exponential Linear Unit, element-wise.

    Returns ``x`` where ``x >= 0`` and ``alpha * (exp(x) - 1)`` elsewhere.
    The result is an ndarray (0-d for scalar input), as ``np.where`` returns.

    Implementation notes:

    - ``np.expm1`` avoids the cancellation of ``np.exp(x) - 1`` for tiny
      negative x (ELU(-1e-20) used to come out as -0.0).
    - Clamping the argument to <= 0 keeps the discarded positive branch from
      overflowing ``np.exp`` and emitting warnings.
    """
    negative_branch = alpha * np.expm1(np.minimum(x, 0.0))
    return np.where(x >= 0.0, x, negative_branch)

def softmax(x):
    """Compute the softmax of vector x.

    All elements are normalised together; no axis is selected. The maximum
    is subtracted before exponentiating. This is mathematically a no-op but
    prevents ``np.exp`` from overflowing: previously softmax([1000, 1000])
    returned NaN.

    Args:
        x: Array-like of real numbers.

    Returns:
        An ndarray of the same shape with non-negative entries summing to 1.
        An empty array is returned for empty input.
    """
    x = np.asarray(x)
    if x.size == 0:
        # np.max would raise on an empty array; exp of nothing is nothing.
        return np.exp(x)
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)

def GELU(x):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``.

    Scalar-only, because ``sigmoid`` uses ``math.exp``.
    """
    scaled = 1.702 * x
    return x * sigmoid(scaled)

def tanh_table(n):
    """Build an exhaustive tanh lookup table over every n-bit input pattern.

    Each row is built as follows:

    1. Write an integer code in [0, 2**n) as an n-element '0'/'1' list
       (MSB first).
    2. Decode it with float_from_bits.
    3. Apply tanh and re-encode with get_bits_from_float.

    Inputs below -650 map directly to the encoding of -1, and inputs above
    650 to the encoding of 1. NOTE(review): the +/-650 cut-offs presumably
    keep exp() inside double range; confirm. The table has 2**n rows, so
    only the narrow widths are tractable in practice.

    Args:
        n: Bit width; one of 8, 16, 32, 64, 128.

    Returns:
        A list of ``[input_bits, output_bits]`` pairs in code order. If ``n``
        is unsupported, prints "n error" and returns 0.
    """
    if n not in (8, 16, 32, 64, 128):
        print("n error")
        return 0

    table = []
    for code in range(2 ** n):
        input_bits, _, _ = decimal_to_binary(code)
        input_bits = ['0'] * (n - len(input_bits)) + input_bits
        x = float_from_bits(input_bits)
        if x < -650:
            output_bits = get_bits_from_float(-1, n)
        elif x > 650:
            output_bits = get_bits_from_float(1, n)
        else:
            output_bits = get_bits_from_float(tanh(x), n)
        table.append([input_bits, output_bits])
    return table

def sigmoid_table(n):
    """Build an exhaustive sigmoid lookup table over every n-bit input pattern.

    Each row is built as follows:

    1. Write an integer code in [0, 2**n) as an n-element '0'/'1' list
       (MSB first).
    2. Decode it with float_from_bits.
    3. Apply sigmoid and re-encode with get_bits_from_float.

    Inputs below -650 map directly to the encoding of 0, and inputs above
    650 to the encoding of 1. NOTE(review): the +/-650 cut-offs presumably
    keep math.exp() inside double range; confirm. The table has 2**n rows.

    Args:
        n: Bit width; one of 8, 16, 32, 64, 128.

    Returns:
        A list of ``[input_bits, output_bits]`` pairs in code order. If ``n``
        is unsupported, prints "n error" and returns 0.
    """
    if n not in (8, 16, 32, 64, 128):
        print("n error")
        return 0

    table = []
    for code in range(2 ** n):
        input_bits, _, _ = decimal_to_binary(code)
        input_bits = ['0'] * (n - len(input_bits)) + input_bits
        x = float_from_bits(input_bits)
        if x < -650:
            output_bits = get_bits_from_float(0, n)
        elif x > 650:
            output_bits = get_bits_from_float(1, n)
        else:
            output_bits = get_bits_from_float(sigmoid(x), n)
        table.append([input_bits, output_bits])
    return table

def swish_table(n):
    """Build an exhaustive swish lookup table over every n-bit input pattern.

    Each row is built as follows:

    1. Write an integer code in [0, 2**n) as an n-element '0'/'1' list
       (MSB first).
    2. Decode it with float_from_bits.
    3. Apply swish and re-encode with get_bits_from_float.

    Inputs below -650 map to the encoding of 0. Inputs above 650 pass
    through unchanged: the output entry is the very same list object as the
    input entry. NOTE(review): the +/-650 cut-offs presumably keep
    math.exp() inside double range; confirm. The table has 2**n rows.

    Args:
        n: Bit width; one of 8, 16, 32, 64, 128.

    Returns:
        A list of ``[input_bits, output_bits]`` pairs in code order. If ``n``
        is unsupported, prints "n error" and returns 0.
    """
    if n not in (8, 16, 32, 64, 128):
        print("n error")
        return 0

    table = []
    for code in range(2 ** n):
        input_bits, _, _ = decimal_to_binary(code)
        input_bits = ['0'] * (n - len(input_bits)) + input_bits
        x = float_from_bits(input_bits)
        if x < -650:
            output_bits = get_bits_from_float(0, n)
        elif x > 650:
            output_bits = input_bits
        else:
            output_bits = get_bits_from_float(swish(x), n)
        table.append([input_bits, output_bits])
    return table

def ELU_table(n,alpha = 0.1):
    """Build an exhaustive ELU lookup table over every n-bit input pattern.

    Each row is built as follows:

    1. Write an integer code in [0, 2**n) as an n-element '0'/'1' list
       (MSB first).
    2. Decode it with float_from_bits.
    3. Apply ELU(x, alpha) and re-encode with get_bits_from_float.

    Inputs below -650 map to the encoding of -alpha. Inputs above 650 pass
    through unchanged: the output entry is the very same list object as the
    input entry. The table has 2**n rows.

    Args:
        n: Bit width; one of 8, 16, 32, 64, 128.
        alpha: ELU scale for the negative branch.

    Returns:
        A list of ``[input_bits, output_bits]`` pairs in code order. If ``n``
        is unsupported, prints "n error" and returns 0.
    """
    if n not in (8, 16, 32, 64, 128):
        print("n error")
        return 0

    table = []
    for code in range(2 ** n):
        input_bits, _, _ = decimal_to_binary(code)
        input_bits = ['0'] * (n - len(input_bits)) + input_bits
        x = float_from_bits(input_bits)
        if x < -650:
            output_bits = get_bits_from_float(-1 * alpha, n)
        elif x > 650:
            output_bits = input_bits
        else:
            output_bits = get_bits_from_float(ELU(x, alpha), n)
        table.append([input_bits, output_bits])
    return table

def softmax_table(n,length):
    """Build an exhaustive softmax lookup table for vectors of n-bit floats.

    Each row is built as follows:

    1. Write an integer code in [0, 2**(n*length)) as an n*length-element
       '0'/'1' list (MSB first).
    2. Split it into ``length`` consecutive n-bit fields and decode each
       field with float_from_bits.
    3. Run softmax on the vector, re-encode each output with
       get_bits_from_float, and concatenate the results in order.

    The table has 2**(n*length) rows.

    Args:
        n: Bit width of each element; one of 8, 16, 32, 64, 128.
        length: Number of elements per vector; must be at least 1.

    Returns:
        A list of ``[input_bits, output_bits]`` pairs in code order. Prints
        "n error" or "length error" and returns 0 on invalid arguments.
    """
    if n not in (8, 16, 32, 64, 128):
        print("n error")
        return 0
    if length < 1:
        print("length error")
        return 0

    width = n * length
    table = []
    for code in range(2 ** width):
        input_bits, _, _ = decimal_to_binary(code)
        input_bits = ['0'] * (width - len(input_bits)) + input_bits
        inputs = [
            float_from_bits(input_bits[start:start + n])
            for start in range(0, width, n)
        ]
        outputs = softmax(inputs)
        output_bits = []
        for value in outputs:
            output_bits.extend(get_bits_from_float(value, n))
        table.append([input_bits, output_bits])
    return table

def GELU_table(n):
    """Build an exhaustive GELU lookup table over every n-bit input pattern.

    Each row is built as follows:

    1. Write an integer code in [0, 2**n) as an n-element '0'/'1' list
       (MSB first).
    2. Decode it with float_from_bits.
    3. Apply GELU and re-encode with get_bits_from_float.

    Inputs below -400 map to the encoding of 0. Inputs above 400 pass
    through unchanged: the output entry is the very same list object as the
    input entry. NOTE(review): +/-400 presumably keeps
    math.exp(1.702 * 400) inside double range; confirm. The table has
    2**n rows.

    Args:
        n: Bit width; one of 8, 16, 32, 64, 128.

    Returns:
        A list of ``[input_bits, output_bits]`` pairs in code order. If ``n``
        is unsupported, prints "n error" and returns 0.
    """
    if n not in (8, 16, 32, 64, 128):
        print("n error")
        return 0

    table = []
    for code in range(2 ** n):
        input_bits, _, _ = decimal_to_binary(code)
        input_bits = ['0'] * (n - len(input_bits)) + input_bits
        x = float_from_bits(input_bits)
        if x < -400:
            output_bits = get_bits_from_float(0, n)
        elif x > 400:
            output_bits = input_bits
        else:
            output_bits = get_bits_from_float(GELU(x), n)
        table.append([input_bits, output_bits])
    return table