In [1]:
#!/usr/bin/env python3
# Copyright 2022 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import argparse
import pathlib
import hjson
import random
import os

In [2]:
np.random.seed(42)
torch.manual_seed(42)
global verbose

In [3]:
def array_to_cstr(a):
    out = '{'
    if isinstance(a, np.ndarray):
        a = a.flat
    if isinstance(a, torch.Tensor):
        a = a.numpy().flat
    for el in a:
        out += '{}, '.format(el)
    out = out[:-2] + '}'
    return out

In [4]:
def emit_mnist_data(name='mnist', **kwargs):
    
    # constants
    IN_CH = kwargs['IN_CH']
    OUT_CH = kwargs['OUT_CH']
    DATASET_SIZE = kwargs['DATASET_SIZE']
    
    # data
    MAT_INPUT = kwargs['INPUT']
    MAT_LABELS = kwargs['LABELS']

    # network init parameters from golden model
    MAT_WEIGHTS = kwargs['WEIGHTS']
    MAT_BIASES = kwargs['BIASES'] 
    
    layer_str = ''
    layer_str += '#include "network.h"\n\n'
    layer_str += f'network_benchmark_t {name}_t = {{\n'
    layer_str += f'\t.IN_CH = {IN_CH},\n'
    layer_str += f'\t.OUT_CH = {OUT_CH},\n'
    layer_str += f'\t.dtype = FP{kwargs["prec"]}\n'
    layer_str += '};\n\n\n'

    ctypes = {
        '64': 'double',
        '32': 'float',
        '16': '__fp16',
        'B16': '__bf16',
        '8': 'char'
    }

    dtype = ctypes[str(kwargs['prec'])]

    # network initialization
    layer_str += f'static {dtype} {name}_weights_dram [{OUT_CH}][{IN_CH}] = ' + array_to_cstr(MAT_WEIGHTS) + ';\n\n\n'
    layer_str += f'static {dtype} {name}_biases_dram [{OUT_CH}][{1}] = ' + array_to_cstr(MAT_BIASES) + ';\n\n\n'


    # input data
    layer_str += f'static {dtype} {name}_images_dram [{DATASET_SIZE*IN_CH}][{1}] = ' + array_to_cstr(MAT_INPUT) + ';\n\n\n'
    layer_str += f'static uint32_t {name}_labels_dram [{DATASET_SIZE}][{1}] = ' + array_to_cstr(MAT_LABELS) + ';\n\n\n'

    return layer_str


In [6]:
def emit_mnist_header_file(layer_type: str, **kwargs):

    file_path = '/scratch/msc22f11/msc22f11/snitch/sw/applications/data/'
    emit_str = "// Copyright 2022 ETH Zurich and University of Bologna.\n" + \
               "// Licensed under the Apache License, Version 2.0, see LICENSE for details.\n" + \
               "// SPDX-License-Identifier: Apache-2.0\n\n"

    if(layer_type == 'mnist'):
        file = file_path + 'data_fp8_benchmark.h'
        emit_str += emit_mnist_data(**kwargs)

    with open(file, 'w') as f:
        f.write(emit_str)


In [5]:
def Linear(input, weights, bias, **kwargs):
    out = torch.mul(input, weights)
    return out

In [6]:
# download MNIST dataset using DataLoader

transform = transforms.Compose(
    [
        transforms.ToTensor()
    ]
)

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
mnist_dataset = MNIST(PATH_DATASETS, train=True, transform=transform, download=True)

# set seeds for reproducability 
g = torch.Generator()
g.manual_seed(42)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

mnist_dl = DataLoader(mnist_dataset, worker_init_fn=seed_worker, generator=g)

In [7]:
first_im, first_label = next(iter(mnist_dl))

In [8]:
np.random.seed(42)
torch.manual_seed(42)
# get input channels
IN_CH = 1 * 32 * 32 # 1 channels, 32x32 pixels
OUT_CH = 16 # 16 classes
r1 = 0
r2 = 0.5

# get random input data with shape (IN_CH, 1)
# input = first_im.to(torch.float64).view(first_im.to(torch.float64).size(0), -1) #torch.randn(IN_CH)
input = torch.randn(IN_CH, dtype=torch.float64)
print(input.shape)

# get random weights with shape (OUT_CH, IN_CH)
weights = torch.FloatTensor(OUT_CH, IN_CH).uniform_(r1, r2).to(torch.float64) #torch.randn(OUT_CH, IN_CH).to(torch.float32)
print(weights.shape)

# get random bias with shape (OUT_CH, 1)
bias = torch.FloatTensor(OUT_CH).uniform_(r1, r2).to(torch.float64)#torch.randn(OUT_CH).to(torch.float32)
print(bias.shape)

# calculate the activations of the linear layer
activations = input @ weights.t() + bias
print(activations.shape)
# get a random integer between 0 and 16
label = torch.randint(0, 16, (1,))
print(label)

torch.Size([1024])
torch.Size([16, 1024])
torch.Size([16])
torch.Size([16])
tensor([5])


In [9]:
import pathlib
import ctypes

from ctypes import c_uint8, c_double, c_float
from ctypes import byref, Structure


class flexfloat_desc_t(Structure):
    _fields_ = [("exp_bits", c_uint8), ("frac_bits", c_uint8)]


class flexfloat_t(Structure):
    _fields_ = [("value", c_double), ("desc", flexfloat_desc_t)]


fp64_desc = flexfloat_desc_t(11, 52)
fp32_desc = flexfloat_desc_t(8, 23)
fp16_desc = flexfloat_desc_t(5, 11)
fp16alt_desc = flexfloat_desc_t(8, 7)
fp8_desc = flexfloat_desc_t(5, 2)
fp8alt_desc = flexfloat_desc_t(4, 3)

lib_path = "/scratch/msc22f11/msc22f11/PlayGround/flexfloat/src/libflexfloat.so"
ff_lib = ctypes.CDLL(lib_path)

ff_get_float = ff_lib.ff_get_float
ff_get_float.restype = c_float


class ff:
    def __init__(self, value: float, desc: flexfloat_desc_t = fp64_desc):
        self.desc = desc
        self.value = value
        self.a = flexfloat_t(value, desc)
        ff_lib.ff_init_float(byref(self.a), c_float(value), desc)

    def __add__(self, b):
        ff_res = flexfloat_t(0.0, self.desc)
        ff_lib.ff_add(byref(ff_res), byref(self.a), byref(b.a))
        return 0


if __name__ == "__main__":

    a = ff(3.0, fp32_desc)
    b = ff(4.0, fp32_desc)

    ff_a = flexfloat_t(0.0, fp8_desc)
    ff_b = flexfloat_t(0.0, fp8_desc)
    ff_c = flexfloat_t(0.0, fp8_desc)

    ff_lib.ff_init_float(byref(ff_a), c_float(1.0), fp8_desc)
    ff_lib.ff_init_float(byref(ff_b), c_float(2.0), fp8_desc)
    ff_lib.ff_init_float(byref(ff_c), c_float(3.0), fp8_desc)

    ff_lib.ff_add(byref(ff_c), byref(ff_b), byref(ff_c))

    c = ff_lib.ff_get_float(byref(ff_c))
    print(c)

5.0


In [10]:
# function to convert float32 to binary representation
import struct

def float32_to_bin(value):
    return ''.join(f'{c:0>8b}' for c in struct.pack('!f', value))

In [11]:
"""
We have to handle denormalized numbers:
    
    +INF will be represented in FP8 as 0 11111 00 
    -INF will be represented in FP8 as 1 11111 00
    +0 will be represented in FP8 as 0 00000 00
    -0 will be represented in FP8 as 1 00000 00
    NaN will be represented in FP8 as X 11111 MM (at least one of the MMM bits is set, sign bit is don't care)

According to https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8556098 denormalized transprecision numbers
will be represented by their high precision counterparts. In these cases we have to make sure that we do not adjust
the exponent. In the other cases we adjust the exponent and cut the mantissa.
"""

from numpy import binary_repr

# this function returns an 8 character string representing the binary representation of the FP8 number
def float32_to_fp8(value):
    max_exp_fp32 = int('11111111', 2)
    min_exp_fp32 = int('00000000', 2)
    exp_bias_fp32 = 2 ** (8 - 1) - 1
    exp_bias_fp8 = 2 ** (5 - 1) - 1
    # get the binary representation of the number
    binstr = float32_to_bin(value)
    # extract sign, exponent and mantissa bits
    sign = binstr[0]
    exponent = binstr[1:9]
    mantissa = binstr[9:]
    # check if the number is denormalized
    # we start by checking if all exponent bits are asserted
    if(int(exponent) == max_exp_fp32):
        # if so, we check if the mantissa is all zeros (will result in +/-INF)
        if(int(mantissa) == 0):
            return sign + exponent[:5] + mantissa[:2]
        # if not, we have to return a NaN
        else:
            return sign + exponent[:5] + '01'
    # if both exponent and mantissa are zero we will return +/-0
    elif (int(exponent) == min_exp_fp32 and int(mantissa) == 0): 
        return sign + exponent[:5] + mantissa[:2]
    else :
        # if not, we adjust the exponent and cut the mantissa
        exponent_fp8 = binary_repr(int(exponent, 2) - exp_bias_fp32 + exp_bias_fp8, width=5)
        mantissa_fp8 = mantissa[:2]
        return sign + exponent_fp8 + mantissa_fp8


In [12]:
# exponent is stored in two's complement
def twos_comp(val, bits):
    """compute the 2's complement of int value val"""
    if (val & (1 << (bits - 1))) != 0: # if sign bit is set e.g., 8bit: 128-255
        val = val - (1 << bits)        # compute negative value
    return val  

In [13]:
def convert_to_fp8_decimal(binstr):

    # extract sign, exponent and mantissa bits
    sign = binstr[0]
    num_sign_bits = 1
    # print(f'Sign:     ({num_sign_bits} bit)  = {sign}')
    exponent = binstr[1:6]
    num_exp_bits = len(exponent)
    # print(f'Exponent: ({num_exp_bits} bit)  = {exponent}')
    mantissa = binstr[6:]
    num_mant_bits = len(mantissa)
    # print(f'Mantissa: ({num_mant_bits} bit) = {mantissa}')

    exp_bias_fp8 = 2 ** (5 - 1) - 1
    dec_val_fp8 = (-1)**(int(sign, 2)) * (1 + (int(mantissa, 2))/(2**num_mant_bits)) * 2**(twos_comp(int(exponent, 2), num_exp_bits) - exp_bias_fp8)
    if(int(sign, 2) == 0 and int(exponent, 2) == 0 and int(mantissa, 2) == 0):
        dec_val_fp8 = 0
    # print("\nBinary to floating point number (FP8) conversion using formula: ", dec_val_fp8)
    return dec_val_fp8
    

In [34]:
# max FP8 value:
# 0 11110 11 = 0.99609375
fp8_max = '0 11110 11'
fp8_max_dec = convert_to_fp8_decimal(fp8_max.replace(' ', ''))
fp8_max_dec
# get NaN value
fp8_nan = '01111101'
fp8_nan2 = '11111111'
fp8_nan3 = '11111110' 
fp8_nan4 = '01111110'
# get INF value
fp8_inf = '01111100'
fp8_inf2 = '11111100'
# get max FP8 value
fp8_max = '01111011'
# get decimal from string
print("nan = ", int(fp8_nan, 2))
# print(twos_comp(int(fp8_nan, 2), 8))
print("nan2 = ", int(fp8_nan2, 2))
# print(twos_comp(int(fp8_nan2, 2), 8))
print("nan3 = ", int(fp8_nan3, 2))
print("nan4 = ", int(fp8_nan4, 2))
print("inf = ", int(fp8_inf, 2))
# print(twos_comp(int(fp8_inf, 2), 8))
print("-inf = ", int(fp8_inf2, 2))
# print(twos_comp(int(fp8_inf2, 2), 8))
print("max = ", int(fp8_max, 2))
# bin(126)

nan =  125
nan2 =  255
nan3 =  254
nan4 =  126
inf =  124
-inf =  252
max =  123


In [21]:
# convert input to FP8
input_fp8 = [float32_to_fp8(x) for x in input.flatten().tolist()]
# reshape input_fp8 back to tensor shape
input_fp8 = np.array(input_fp8).reshape(input.shape)
print(input_fp8.shape)
input = input_fp8
decimal_input = [convert_to_fp8_decimal(x) for x in input.flatten().tolist()]
print(decimal_input)

(1, 784)
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.01171875, 0.0625, 0.0625, 0.0625, 0.4375, 0.5, 0.625, 0.09375, 0.625, 1.0, 0.875, 0.4375, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.109375, 0.125, 0.3125, 0.5, 0.625, 0.875, 0.875, 0.875, 0.875, 0.875, 0.875, 0.625, 0.875, 0.875, 0.75, 0.25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1875, 0.875, 0.875, 0.875, 0.875, 0.875, 0.875, 0.875, 0.875, 0.875, 0.875, 0.3125, 0.3125, 0.3125, 0.21875, 0.125, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.0625, 0.75, 0.875, 0.875, 0.875, 0.875, 0.875, 0.75, 0.625, 0.875, 0.875, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [22]:
# for C header we need to add '0b' to the binary representation
print("Input before: ", input[0][0])
input_t = ['0b' + x for x in input.flatten().tolist()]
print("Input after: ", input_t[0])

Input before:  00000000
Input after:  0b00000000


In [24]:
# convert weights to FP8
weights_fp8 = [float32_to_fp8(x) for x in weights.flatten().tolist()]
# reshape weights_fp8 back to tensor shape
weights_fp8 = np.array(weights_fp8).reshape(weights.shape)
print(weights_fp8.shape)
weights = weights_fp8
decimal_weights = [convert_to_fp8_decimal(x) for x in weights.flatten().tolist()]
print(decimal_weights)

(16, 784)
[0.4375, 0.4375, 0.1875, 0.4375, 0.1875, 0.25, 0.125, 0.375, 0.4375, 0.0625, 0.4375, 0.25, 0.375, 0.25, 0.3125, 0.1875, 0.4375, 0.25, 0.125, 0.3125, 0.125, 0.21875, 0.125, 0.375, 0.046875, 0.125, 0.15625, 0.09375, 0.25, 0.0029296875, 0.4375, 0.03125, 0.4375, 0.25, 0.15625, 0.375, 0.25, 0.4375, 0.25, 0.15625, 0.3125, 0.15625, 0.3125, 0.4375, 0.375, 0.125, 0.375, 0.25, 0.375, 0.09375, 0.00244140625, 0.125, 0.0546875, 0.4375, 0.3125, 0.3125, 0.3125, 0.21875, 0.4375, 0.0625, 0.25, 0.078125, 0.3125, 0.15625, 0.3125, 0.1875, 0.4375, 0.09375, 0.09375, 0.09375, 0.4375, 0.3125, 0.4375, 0.0390625, 0.001953125, 0.046875, 0.078125, 0.3125, 0.3125, 0.4375, 0.109375, 0.078125, 0.375, 0.125, 0.375, 0.1875, 0.375, 0.0546875, 0.109375, 0.3125, 0.25, 0.15625, 0.375, 0.375, 0.0625, 0.109375, 0.4375, 0.15625, 0.15625, 0.0078125, 0.09375, 0.25, 0.1875, 0.0625, 0.25, 0.078125, 0.03125, 0.109375, 0.02734375, 0.078125, 0.4375, 0.25, 0.3125, 0.015625, 0.078125, 0.15625, 0.25, 0.02734375, 0.125, 0.093

In [25]:
# for C header we need to add '0b' to the binary representation
print("Weights before: ", weights[0][0])
weights_t = ['0b' + x for x in weights.flatten().tolist()]
print("Weights after: ", weights_t[0])

Weights before:  00110111
Weights after:  0b00110111


In [26]:
# convert bias to FP8
bias_fp8 = [float32_to_fp8(x) for x in bias.flatten().tolist()]
# reshape bias_fp8 back to tensor shape
bias_fp8 = np.array(bias_fp8).reshape(bias.shape)
print(bias_fp8.shape)
bias = bias_fp8
decimal_bias = [convert_to_fp8_decimal(x) for x in bias.flatten().tolist()]
print(decimal_bias)

(16,)
[0.375, 0.4375, 0.3125, 0.4375, 0.125, 0.078125, 0.046875, 0.1875, 0.21875, 0.09375, 0.15625, 0.15625, 0.4375, 0.21875, 0.4375, 0.25]


In [47]:
float32_to_fp8(0.5)

'00111000'

In [46]:
convert_to_fp8_decimal('00100100')

0.015625

In [27]:
# for C header we need to add '0b' to the binary representation
print("Bias before: ", bias[0])
bias_t = ['0b' + x for x in bias.flatten().tolist()]
print("Bias after: ", bias_t[0])

Bias before:  00110110
Bias after:  0b00110110


In [115]:
type(bias_fp8[0].item())

str

In [97]:
print(activations)
# print data type of the activations
print(activations.dtype)

tensor([[26.3321, 26.9016, 27.8120, 30.3491, 27.2808, 25.2449, 28.9580, 26.0793,
         28.7988, 27.2119, 25.8306, 28.1869, 27.7081, 24.9478, 27.7745, 29.6327]])
torch.float32


In [25]:
np.max(activations.numpy())

30.31

In [38]:
activations.to(torch.float32)

tensor([[26.2812, 26.8750, 27.7812, 30.3125, 27.2812, 25.2812, 28.9375, 26.1094,
         28.7188, 27.1562, 25.7969, 28.0938, 27.6719, 25.0000, 27.7500, 29.6406]])

In [39]:
np.exp(activations.to(torch.float32)[0][0])

tensor(2.5930e+11)

In [43]:
# apply softmax to the activations
softmax = torch.nn.Softmax(dim=1)
# upcast the activations to float32 to use the softmax function
ff_out = softmax(activations.to(torch.float32))
# downcast the output to float16
ff_out = ff_out.to(torch.float16)
print(ff_out)
print(ff_out.shape)
print(ff_out.dtype)

tensor([[0.0072, 0.0130, 0.0321, 0.4031, 0.0194, 0.0026, 0.1019, 0.0060, 0.0818,
         0.0172, 0.0044, 0.0438, 0.0287, 0.0020, 0.0311, 0.2058]],
       dtype=torch.float16)
torch.Size([1, 16])
torch.float16


In [61]:
# transform softmax activations to list
ff_out_l = ff_out.tolist()[0]
# if index matches label, subtract 1 from value at index
ff_out_l[label] = ff_out_l[label] - 1
print(ff_out_l)
# print the bias gradient 
bias_gradients = torch.FloatTensor(ff_out_l).reshape(1, -1).to(torch.float16)
print(bias_gradients)
print(bias_gradients.shape)
print(bias_gradients.dtype)

[0.007152557373046875, 0.0129547119140625, 0.032073974609375, 0.403076171875, 0.019439697265625, 0.00263214111328125, 0.10186767578125, -0.9939765930175781, 0.08184814453125, 0.0171661376953125, 0.004405975341796875, 0.0438232421875, 0.0287322998046875, 0.0019855499267578125, 0.03106689453125, 0.205810546875]
tensor([[ 0.0072,  0.0130,  0.0321,  0.4031,  0.0194,  0.0026,  0.1019, -0.9941,
          0.0818,  0.0172,  0.0044,  0.0438,  0.0287,  0.0020,  0.0311,  0.2058]],
       dtype=torch.float16)
torch.Size([1, 16])
torch.float16


In [53]:
# compute the weight gradient matrix
weight_gradients = torch.mul(input.t(), bias_gradients).t()
print(weight_gradients.shape)
# compute the checksum for every column of the weight gradient matrix
weight_gradients_checksum = torch.sum(weight_gradients, dim=1)
print(weight_gradients_checksum)
print(weight_gradients_checksum.shape)
print(weight_gradients_checksum.dtype)

torch.Size([16, 784])
tensor([   0.7720,    1.3984,    3.4629,   43.5000,    2.0977,    0.2842,
          10.9922, -107.3125,    8.8359,    1.8525,    0.4756,    4.7305,
           3.1016,    0.2144,    3.3535,   22.2188], dtype=torch.float16)
torch.Size([16])
torch.float16


In [51]:
# compute the training step
bias_update = bias - torch.mul(bias_gradients, 0.5)
print("bias_update = ", bias_update)
print(bias_update.shape)
print(bias_update.dtype)
weight_update = weights - torch.mul(weight_gradients, 0.5)
weight_update_checksum = torch.sum(weight_update, dim=1)
print("\nweight_update_checksum = ", weight_update_checksum)
print(weight_update_checksum.shape)
print(weight_update_checksum.dtype)

bias_update =  tensor([[ 0.3965,  0.4460,  0.3086,  0.2477,  0.1212,  0.0851, -0.0017,  0.7075,
          0.1843,  0.0993,  0.1624,  0.1475,  0.4768,  0.2312,  0.4370,  0.1713]],
       dtype=torch.float16)
torch.Size([1, 16])
torch.float16

weight_update_checksum =  tensor([188.3750, 197.7500, 191.1250, 176.7500, 190.5000, 195.2500, 197.7500,
        248.8750, 197.3750, 191.1250, 198.1250, 197.1250, 193.0000, 189.7500,
        194.1250, 180.6250], dtype=torch.float16)
torch.Size([16])
torch.float16


In [119]:
torch.is_tensor(bias[0])

False

In [39]:
bias[0].item().dtype

AttributeError: 'float' object has no attribute 'dtype'

In [45]:
# calculate the memory requirements
if(torch.is_tensor(bias[0]) and bias[0].dtype == torch.float64):
    print(f'Input size: {IN_CH * 64 / 8 / 1024} KB')
    print(f'Weights size: {OUT_CH * IN_CH * 64 / 8 / 1024} KB')
    print(f'Bias size: {OUT_CH * 64 / 8 / 1024} KB')
    print(f'Output size: {OUT_CH * 64 / 8 / 1024} KB')
    print(f'\nTotal size: {(IN_CH + OUT_CH * IN_CH + OUT_CH) * 64 / 8 / 1024} KB')
elif(torch.is_tensor(bias[0]) and bias[0].item().dtype == torch.float32):
    print(f'Input size: {IN_CH * 32 / 8 / 1024} KB')
    print(f'Weights size: {OUT_CH * IN_CH * 32 / 8 / 1024} KB')
    print(f'Bias size: {OUT_CH * 32 / 8 / 1024} KB')
    print(f'Output size: {OUT_CH * 32 / 8 / 1024} KB')
    print(f'\nTotal size: {(IN_CH + OUT_CH * IN_CH + OUT_CH) * 32 / 8 / 1024} KB')
elif(torch.is_tensor(bias[0]) and bias[0].item().dtype == torch.float16):
    print(f'Input size: {IN_CH * 16 / 8 / 1024} KB')
    print(f'Weights size: {OUT_CH * IN_CH * 16 / 8 / 1024} KB')
    print(f'Bias size: {OUT_CH * 16 / 8 / 1024} KB')
    print(f'Output size: {OUT_CH * 16 / 8 / 1024} KB')
    print(f'\nTotal size: {(IN_CH + OUT_CH * IN_CH + OUT_CH) * 16 / 8 / 1024} KB')
else:
    print(f'Input size: {IN_CH * 8 / 8 / 1024} KB')
    print(f'Weights size: {OUT_CH * IN_CH * 8 / 8 / 1024} KB')
    print(f'Bias size: {OUT_CH * 8 / 8 / 1024} KB')
    print(f'Output size: {OUT_CH * 8 / 8 / 1024} KB')
    print(f'\nTotal size: {(IN_CH + OUT_CH * IN_CH + OUT_CH) * 8 / 8 / 1024} KB')

Input size: 8.0 KB
Weights size: 128.0 KB
Bias size: 0.125 KB
Output size: 0.125 KB

Total size: 136.125 KB


In [28]:
kwargs = {
            'IN_CH': IN_CH,
            'OUT_CH': OUT_CH,
            'DATASET_SIZE': 1,
            'INPUT': input_t, #input.to(torch.float16),
            'WEIGHTS': weights_t, #weights.detach().to(torch.float16),
            'BIASES': bias_t,#bias.detach().to(torch.float16),
            'LABELS': label,
            'prec': 8
}

In [29]:
emit_mnist_header_file('mnist', **kwargs)