In [None]:
import pathlib

import numpy as np
import matplotlib.pyplot as plt

from numpy.linalg import norm

import warnings

warnings.filterwarnings(action="once")
import tensorflow as tf
from tensorflow import keras

In [None]:
# Report the runtime environment: whether TensorFlow can see a GPU and whether eager mode is on.
# tf.test.is_gpu_available() is deprecated; listing the physical GPU devices is its replacement.
print("GPU Available: ", bool(tf.config.experimental.list_physical_devices("GPU")))
print("Eager execution enabled: ", tf.executing_eagerly())

# Load and rescale data

In [None]:
# MNIST images and their integer class labels.
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.mnist.load_data()
)

# Rescale pixel intensities by 1/255 (float32 tensors).
scale = tf.constant(255, dtype=tf.dtypes.float32)
x_train = train_images / scale
x_test = test_images / scale

# Labels get a trailing unit axis: shape (N,) -> (N, 1).
y_train = tf.expand_dims(train_labels, 1)
y_test = tf.expand_dims(test_labels, 1)

# Optional alternative: standardize with the training-set statistics instead:
#   mean, std = tf.math.reduce_mean(x_train), tf.math.reduce_std(x_train)
#   x_train, x_test = (x_train - mean) / std, (x_test - mean) / std

# Define, compile, and train model

In [None]:
# Reset Keras' global state before (re)building the model.
keras.backend.clear_session()


def build_model():
    """Return an uncompiled single-layer classifier.

    28x28 inputs are flattened and fed to one 10-unit softmax Dense layer with a
    small L1 kernel penalty, i.e. multinomial logistic regression.
    """
    layers = [
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(
            units=10,
            activation="softmax",
            kernel_regularizer=keras.regularizers.l1(1e-5),
        ),
    ]
    return keras.Sequential(layers)


training_params = dict(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Seed TF and NumPy before the model is built so weight initialization is reproducible.
tf.random.set_seed(1234)
np.random.seed(1234)

model = build_model()
model.compile(**training_params)
model.summary()

In [None]:
# run the training
# Train for 10 epochs; the test split is only used to monitor validation metrics.
model.fit(x=x_train, y=y_train, epochs=10, validation_data=(x_test, y_test))

# Convert to TFLite and save to disk

In [None]:
# Output directory for the converted TFLite models (created if missing).
models_dir = pathlib.Path("models")
models_dir.mkdir(parents=True, exist_ok=True)

### Float TFLite model

In [None]:
# Plain TFLite conversion of the trained Keras model, with no optimizations requested.
converter = tf.lite.TFLiteConverter.from_keras_model(model=model)
model_float_lite = converter.convert()  # serialized model, written to disk below

In [None]:
# Persist the float model and report its size.
model_float_file = models_dir.joinpath("model_float.tflite")
size_float = model_float_file.write_bytes(model_float_lite)  # number of bytes written
size_float_kb = size_float / 1024
print("Float model size: {:.0f} KB".format(size_float_kb))

### Quantized TFLite model

In [None]:
# Post-training quantization: DEFAULT optimizations plus a representative dataset that
# the converter uses to calibrate activation ranges.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]  # this doesn't seem to do anything

# Calibration data: every training image, one per batch.
x_train_ds = tf.data.Dataset.from_tensor_slices(x_train).batch(1)


def representative_data_gen():
    """Yield each single-image batch wrapped in a list (one entry per model input)."""
    yield from ([batch] for batch in x_train_ds)


converter.representative_dataset = representative_data_gen
model_quant_lite = converter.convert()

In [None]:
# Persist the quantized model and report its size.
model_quant_file = models_dir.joinpath("model_quant.tflite")
size_quant = model_quant_file.write_bytes(model_quant_lite)  # number of bytes written
size_quant_kb = size_quant / 1024
print("Quantized model size: {:.0f} KB".format(size_quant_kb))

# Build interpreters and run inference on test set

In [None]:
# One in-memory interpreter per converted model, with tensors allocated and ready to run.
interpreter_float, interpreter_quant = (
    tf.lite.Interpreter(model_content=content)
    for content in (model_float_lite, model_quant_lite)
)
for interpreter in (interpreter_float, interpreter_quant):
    interpreter.allocate_tensors()

In [None]:
# Reference output: the Keras model on the whole test set at once.
probabilities = model(x_test).numpy()

# TFLite outputs are filled one test image at a time; NaN marks rows not written.
# np.nan is used instead of the np.NaN alias, which NumPy 2.0 removed.
probabilities_float = np.full(probabilities.shape, np.nan)
probabilities_quant = np.full(probabilities.shape, np.nan)

# The input/output tensor indices are looked up once, outside the loop.
float_in = interpreter_float.get_input_details()[0]["index"]
float_out = interpreter_float.get_output_details()[0]["index"]
quant_in = interpreter_quant.get_input_details()[0]["index"]
quant_out = interpreter_quant.get_output_details()[0]["index"]

for j, img in enumerate(x_test):
    img = tf.expand_dims(img, 0)  # add a leading batch dimension of size 1
    interpreter_float.set_tensor(float_in, img)
    interpreter_float.invoke()
    probabilities_float[j] = interpreter_float.get_tensor(float_out)

    interpreter_quant.set_tensor(quant_in, img)
    interpreter_quant.invoke()
    probabilities_quant[j] = interpreter_quant.get_tensor(quant_out)

# Evaluate models

In [None]:
# Per-sample relative L2 error of each TFLite model's output vs. the Keras model's output.
denom = norm(probabilities, axis=1)
prob_abs_err_float, prob_abs_err_quant = (
    norm(candidate - probabilities, axis=1)
    for candidate in (probabilities_float, probabilities_quant)
)
prob_rel_err_float = prob_abs_err_float / denom
prob_rel_err_quant = prob_abs_err_quant / denom

print("Mean relative error of output activations compared to original model output:")
for row_format, rel_err in (
    ("# Float TFLite model:     {:.5e}", prob_rel_err_float),
    ("# Quantized TFLite model: {:.5e}", prob_rel_err_quant),
):
    print(row_format.format(np.mean(rel_err)))

In [None]:
# Predicted class = index of the largest output probability.
predictions_float = np.argmax(probabilities_float, axis=1)
predictions_quant = np.argmax(probabilities_quant, axis=1)
predictions = np.argmax(probabilities, axis=1)

# Fraction of test images whose prediction matches the label, per model.
# NOTE: computed as a plain NumPy mean. A single tf.metrics.Accuracy object reused across
# calls accumulates its running totals, so every call after the first would report an
# average over all previous calls instead of the accuracy of that model alone.
print("Accuracy of models:")
for row_format, preds in (
    ("# Original keras model:   {:.2%}", predictions),
    ("# Float TFLite model:     {:.2%}", predictions_float),
    ("# Quantized TFLite model: {:.2%}", predictions_quant),
):
    print(row_format.format(np.mean(np.asarray(preds) == test_labels)))

# Interpreter surgery

In [None]:
# run interpreters on a single sample
# Run both interpreters on test image #10 so their tensors can be inspected below.
img = tf.expand_dims(x_test[10], 0)
for interpreter in (interpreter_float, interpreter_quant):
    interpreter.set_tensor(interpreter.get_input_details()[0]["index"], img)
    interpreter.invoke()

### Float interpreter components

In [None]:
# Display every tensor of the float interpreter (per-tensor dicts with "index", "quantization", ...).
# The hard-coded tensor indices used further down were presumably read off this listing.
interpreter_float.get_tensor_details()

### Quantized interpreter components

In [None]:
# Display every tensor of the quantized interpreter (per-tensor dicts with "index", "quantization", ...).
# The hard-coded tensor indices used further down were presumably read off this listing.
interpreter_quant.get_tensor_details()

### Retrieve input image and its quantization, compare

In [None]:
# Hard-coded tensor indices (they depend on the converter's tensor ordering):
#   float interpreter: 1 = input image
#   quant interpreter: 5 = float input image, 1 = int8 quantized input image
img_float = interpreter_float.get_tensor(1)[0].copy()
img_quant_float = interpreter_quant.get_tensor(5)[0].copy()
img_quant_int8 = interpreter_quant.get_tensor(1)[0].copy()
img_quantization = interpreter_quant.get_tensor_details()[1]["quantization"]
img_scale, img_zero_point = img_quantization

# int8 -> float:  real = (q - zero_point) * scale
img_quant_int8_float = (np.float32(img_quant_int8) - img_zero_point) * img_scale

# float -> int8:  q = round(real / scale) + zero_point, saturated to the int8 range.
# Rounding (not truncation) matters: real / scale can land slightly below an integer
# (e.g. 254.99998), which truncation would turn into the next-lower code.
img_quant_float_int8 = np.clip(
    np.round(img_quant_float / img_scale) + img_zero_point, -128, 127
).astype(np.int8)

# Per-pixel difference between the dequantized int8 input and the float input.
img_quant_diff = np.abs(img_quant_int8_float - img_quant_float)

In [None]:
im_dict = {
    "float input": img_float,
    "quant float input": img_quant_float,
    "quant int8 input": img_quant_int8,
    "quant inputs' diff": img_quant_diff,
    "float from quant int8": img_quant_int8_float,
    "int8 from quant float": img_quant_float_int8,
}

fig, axes = plt.subplots(1, len(im_dict), figsize=(16, 4))
for ax, (title, im) in zip(axes, im_dict.items()):
    # The difference image is shown on a fixed [0, 1] range; the others are auto-scaled.
    kwargs = {"vmin": 0, "vmax": 1} if title == "quant inputs' diff" else {}
    ax.imshow(im, cmap="gray", **kwargs)
    ax.set_title(title)
    # Disable the grid per axes: a single plt.grid(False) only affects the current (last) axes.
    ax.grid(False)
plt.show()

### Demonstrate that the bug corrupts the internal state

In [None]:
# TODO: file a bug report
# NOTE(review): TFLite does not guarantee that intermediate tensors (such as the int8 input
# read above via get_tensor(1)) keep their values after invoke(); their buffers may be
# reused. Confirm this is a genuine bug before filing.

_quant_in = interpreter_quant.get_input_details()[0]["index"]
_quant_out = interpreter_quant.get_output_details()[0]["index"]

# Feed the dequantized int8 image, then the original float image, and print each output.
for caption, image in (
    ("Output with corrupted image:", img_quant_int8_float),
    ("Output with uncorrupted image:", img_quant_float),
):
    interpreter_quant.set_tensor(_quant_in, tf.expand_dims(image, 0))
    interpreter_quant.invoke()
    print(caption)
    print(interpreter_quant.get_tensor(_quant_out).flatten())

### Retrieve weights and quantizations, compare

In [None]:
# Fully connected weights: tensor index 3 in both interpreters (hard-coded; depends on the
# converter's tensor ordering).
weights_quant = interpreter_quant.get_tensor(3)
weights_float = interpreter_float.get_tensor(3)
weights_quantization = interpreter_quant.get_tensor_details()[3]["quantization"]
weights_step = weights_quantization[0]

# Absolute deviation in quantization steps (no zero-point term is subtracted here).
# The reported figure is a ratio of L2 norms.
weights_quant_diff = np.abs(np.float32(weights_quant) - weights_float / weights_step)
weights_rel_err = norm(weights_quant_diff) / norm(np.float32(weights_quant))
report = "Mean relative error between quantized and float weights: {:.4%}"
print(report.format(weights_rel_err))

### Weight visualization

In [None]:
# Quantized weights of each of the 10 output units, reshaped back to 28x28 images.
w = weights_quant.reshape(-1, 28, 28)
plt.figure(figsize=(16, 7))
for digit in range(10):
    plt.subplot(2, 5, digit + 1)
    plt.imshow(w[digit], vmin=-128, vmax=127)  # full int8 range
    plt.title("Digit {}".format(digit))
plt.show()

### Distribution of weight quantization errors

In [None]:
# Histogram of per-weight quantization error (in quantization steps), one panel per output unit.
w = weights_quant_diff.reshape(-1, 28, 28)
plt.figure(figsize=(16, 1))
for digit in range(10):
    plt.subplot(1, 10, digit + 1)
    plt.hist(w[digit].ravel())
    plt.title("Digit {}".format(digit))
plt.subplots_adjust(wspace=0.5)
plt.show()

### Retrieve biases and quantizations, compare

In [None]:
# Matmul bias: tensor index 4 in both interpreters (hard-coded; depends on the converter's
# tensor ordering).
bias_quant = interpreter_quant.get_tensor(4)
bias_float = interpreter_float.get_tensor(4)
bias_quantization = interpreter_quant.get_tensor_details()[4]["quantization"]

# Absolute deviation in quantization steps, after removing the zero point.
# The reported figure is a ratio of L2 norms.
bias_quant_diff = np.abs(
    np.float32(bias_quant) - bias_quantization[1] - bias_float / bias_quantization[0]
)
bias_rel_err = norm(bias_quant_diff) / norm(np.float32(bias_quant))
print(
    "Mean relative error between quantized and float matmul biases: {:.4%}".format(
        bias_rel_err
    )
)

### Retrieve preactivations and quantizations, compare

In [None]:
# NOTE: the tensor dense/BiasAdd is actually a (pre)activation, not a bias
# Tensor index 2 in both interpreters (hard-coded; depends on the converter's ordering).
preact_quant, preact_float = (
    interp.get_tensor(2) for interp in (interpreter_quant, interpreter_float)
)
preact_quantization = interpreter_quant.get_tensor_details()[2]["quantization"]
preact_scale, preact_zero_point = preact_quantization

# Absolute deviation in quantization steps, after removing the zero point.
preact_quant_diff = np.abs(
    (np.float32(preact_quant) - preact_zero_point) - preact_float / preact_scale
)
preact_rel_err = norm(preact_quant_diff) / norm(np.float32(preact_quant))
report = "Mean relative error between quantized and float preactivations: {:.4%}"
print(report.format(preact_rel_err))

### Retrieve outputs and quantizations, compare

In [None]:
# Model outputs: each interpreter's declared output, plus the quantized model's int8
# output tensor (index 0, hard-coded).
output_float, output_quant_float = (
    interp.get_tensor(interp.get_output_details()[0]["index"])
    for interp in (interpreter_float, interpreter_quant)
)
output_quant_int8 = interpreter_quant.get_tensor(0)
output_quantization = interpreter_quant.get_tensor_details()[0]["quantization"]
out_scale, out_zero_point = output_quantization

# Absolute deviation in quantization steps, after removing the zero point.
output_quant_diff = np.abs(
    (np.float32(output_quant_int8) - out_zero_point) - output_float / out_scale
)
output_rel_err = norm(output_quant_diff) / norm(np.float32(output_quant_int8))
report = "Mean relative error between quantized and float outputs: {:.4%}"
print(report.format(output_rel_err))

# Interpreter reconstruction

In [None]:
# float interpreter: redo the layer by hand, logits = W @ x + b, then softmax.
rec_preact_float = weights_float @ img_float.ravel() + bias_float
rec_out_float = tf.math.softmax(rec_preact_float).numpy()

ref_out_float = output_float.flatten()
with np.printoptions(formatter={"float": "{:.6e}".format}):
    print("Reconstructed output:\n{}".format(rec_out_float))
    print("Original output:\n{}".format(ref_out_float))
    rel = norm(rec_out_float - ref_out_float) / norm(ref_out_float)
    print("Relative error: {:.6e}".format(rel))

In [None]:
# int weights converted to float and float input (from quant model), compared against the
# float model's output.
w_dequant2 = np.float32(weights_quant) * weights_quantization[0]
b_dequant2 = bias_quant * bias_quantization[0]
rec_preact_float2 = np.matmul(w_dequant2, img_quant_float.flatten()) + b_dequant2
rec_out_float2 = tf.math.softmax(rec_preact_float2).numpy()

ref_out_float2 = output_float.flatten()
with np.printoptions(formatter={"float": "{:.6e}".format}):
    print("Reconstructed output:\n{}".format(rec_out_float2))
    print("Original output:\n{}".format(ref_out_float2))
    rel = norm(rec_out_float2 - ref_out_float2) / norm(ref_out_float2)
    print("Relative error: {:.6e}".format(rel))

In [None]:
# int weights converted to float and int input converted to float
# NOTE: because of the above bug, float->int8->float converted image is used
w_dequant3 = np.float32(weights_quant) * weights_quantization[0]
x_dequant3 = (
    np.float32(img_quant_float_int8) - img_quantization[1]
).flatten() * img_quantization[0]
rec_preact_float3 = np.matmul(w_dequant3, x_dequant3) + bias_quant * bias_quantization[0]
rec_out_float3 = tf.math.softmax(rec_preact_float3).numpy()

ref_out_float3 = output_float.flatten()
with np.printoptions(formatter={"float": "{:.6e}".format}):
    print("Reconstructed output:\n{}".format(rec_out_float3))
    print("Original output:\n{}".format(ref_out_float3))
    rel = norm(rec_out_float3 - ref_out_float3) / norm(ref_out_float3)
    print("Relative error: {:.6e}".format(rel))

In [None]:
# int weights and int input, using 32 bit accumulation and 32 bit bias:
#   W @ (x - zp) + b  ==  W @ x - W @ (zp * 1) + b
w_int32 = np.int32(weights_quant)
x_int32 = np.int32(img_quant_float_int8).flatten()
zero_point_vec = np.int32(img_quantization[1] * np.ones(img_quant_float_int8.size))
rec_preact_int32 = (
    np.matmul(w_int32, x_int32) - np.matmul(w_int32, zero_point_vec) + bias_quant
)
# The accumulator shares the bias scale, so multiply by it to get back to real units.
rec_out_int = tf.math.softmax(rec_preact_int32 * bias_quantization[0]).numpy()

ref_out_int = output_float.flatten()
with np.printoptions(formatter={"float": "{:.6e}".format}):
    print("Reconstructed output:\n{}".format(rec_out_int))
    print("Original output:\n{}".format(ref_out_int))
    rel = norm(rec_out_int - ref_out_int) / norm(ref_out_int)
    print("Relative error: {:.6e}".format(rel))

# XS3 emulation and scaling

In [None]:
from XS3VPU import XS3VPU


def compute_chunk(vpu, W, x, W_start, W_step, x_start):
    """Accumulate one ``vpu.ve``-wide input chunk against ``vpu.acc_period`` weight rows.

    Loads ``x[x_start : x_start + vpu.ve]`` with VLDC, then issues ``vpu.acc_period``
    VLMACCR calls on weight slices of length ``vpu.ve``, the first starting at
    ``W_start`` and each subsequent one ``W_step`` elements further on.

    NOTE(review): assumes VLDC loads the input vector into the VPU's C register and
    VLMACCR multiply-accumulates a weight vector against it into a rotating accumulator;
    confirm against XS3VPU.
    """
    # ~ 17 instructions
    vpu.VLDC(x[x_start : x_start + vpu.ve])
    rw = W_start
    for _ in range(vpu.acc_period):  # unroll in asm
        vpu.VLMACCR(W[rw : rw + vpu.ve])
        rw += W_step


def compute_tile(vpu, W, x, N_chunks, W_start, W_step, W_chunk_step, x_start, x_step):
    """Run ``N_chunks`` consecutive ``compute_chunk`` calls for one band of outputs.

    After each chunk, the input offset advances by ``x_step`` and the weight offset by
    ``W_chunk_step``. Within a chunk, successive weight rows are ``W_step`` apart.
    """
    # ~ N_chunks * (17 + 2) + 5
    rx = x_start
    rw = W_start
    for _ in range(N_chunks):
        compute_chunk(vpu, W, x, W_start=rw, W_step=W_step, x_start=rx)
        rx += x_step
        rw += W_chunk_step


def XS3_matmul(vpu, W, x, y, N_bands, N_chunks):
    """Emulate ``y = W @ x`` on the XS3 VPU, ``vpu.acc_period`` outputs (one band) at a time.

    Args:
        vpu: XS3VPU instance whose registers are used for accumulation.
        W: flat weight array. Each band occupies ``vpu.acc_period * N_chunks * vpu.ve``
            consecutive elements, and rows within a band are ``N_chunks * vpu.ve`` apart.
        x: flat input vector; the first ``N_chunks * vpu.ve`` elements are used.
        y: output array, written in place, ``vpu.acc_period`` entries per band.
        N_bands: number of output bands.
        N_chunks: number of ``vpu.ve``-wide input chunks per weight row.
    """
    # ~ N_bands * (N_chunks * (17 + 2) + 5 + 8) + 5
    rw = 0
    ry = 0
    for _ in range(N_bands):
        vpu.VCLRDR()  # TODO add bias loading
        compute_tile(
            vpu,
            W,
            x,
            N_chunks,
            W_start=rw,
            W_step=N_chunks * vpu.ve,
            W_chunk_step=vpu.ve,
            x_start=0,
            x_step=vpu.ve,
        )
        # Presumably merges the vD/vR accumulator halves into full-width values; confirm in XS3VPU.
        y[ry : ry + vpu.acc_period] = vpu._combine_vD_vR()  # VLSAT, VPOS, VSTRPV
        rw += vpu.acc_period * N_chunks * vpu.ve
        ry += vpu.acc_period


def XS3_fc_forward_int32(W, b, x):
    """Emulate one band of a fully connected layer on the XS3 VPU (8-bit elements).

    Args:
        W: flattened, padded weight matrix in the layout expected by ``XS3_matmul``.
        b: int32 bias, one entry per accumulator of the band (16 here).
        x: padded input vector. Its length should be a multiple of ``vpu.ve``;
            any trailing remainder is ignored.

    Returns:
        int32 array of 16 accumulator values with the bias added.
    """
    vpu = XS3VPU(bpe=8)
    y = np.zeros(16, dtype=np.int32)
    # Derive the chunk count from the input length instead of hard-coding 800
    # (784 MNIST pixels padded to a multiple of 32).
    XS3_matmul(vpu, W, x, y, N_bands=1, N_chunks=len(x) // vpu.ve)
    return y + b

In [None]:
# Lay out the quantized FC layer for the XS3 emulation:
#   - pad the output dimension up to 16 rows (one accumulator band),
#   - pad the input dimension up to the next multiple of 32.
n_out, n_in = weights_quant.shape
pad0 = 16 - n_out
pad1 = -n_in % 32  # padding to the next multiple of 32 (0 if already aligned)
weights_xs3 = np.pad(weights_quant, pad_width=[(0, pad0), (0, pad1)])
# Rows are stored in reverse order, presumably matching the order in which the emulated
# VLMACCR fills the accumulators (TODO confirm against XS3VPU).
weights_xs3 = np.flipud(weights_xs3).flatten()

# Padded input entries meet zero weights, so they do not affect the result.
data_xs3 = np.pad(img_quant_float_int8.flatten(), pad_width=[(0, pad1)])

# Fold the input zero point into the bias:  W @ (x - zp) + b == W @ x + (b - W @ zp).
bias_xs3 = bias_quant - np.matmul(
    np.int32(weights_quant),
    np.int32(img_quantization[1] * np.ones(img_quant_float_int8.size)),
)
bias_xs3 = np.pad(bias_xs3, pad_width=[(0, pad0)])

y_xs3 = XS3_fc_forward_int32(weights_xs3, bias_xs3, data_xs3)

# Drop the padded output rows. Slicing by n_out also works when pad0 == 0,
# unlike y_xs3[:-pad0], which would be empty.
preact_xs3_int32 = y_xs3[:n_out]

In [None]:
# Side-by-side dump of the two int32 pre-activation computations.
for caption, values in (
    ("int32 preactivation values produced by XS3 emulation (without offset):", preact_xs3_int32),
    ("int32 preactivation values produced by int32 accumulation:", rec_preact_int32),
):
    print(caption)
    print(values)

In [None]:
# compare to float preactivation
rec_preact_xs3 = preact_xs3_int32 * bias_quantization[0]  # accumulator units -> real units

preact_ref = preact_float.flatten()
with np.printoptions(formatter={"float": "{:.6e}".format}):
    print("Reconstructed preactivation (xs3):\n{}".format(rec_preact_xs3))
    print("Original preactivation:\n{}".format(preact_ref))
    rel = norm(rec_preact_xs3 - preact_ref) / norm(preact_ref)
    print("Relative error: {:.6e}".format(rel))

### Calculating final bias

In [None]:
def sat_to_16(a):
    """Clamp ``a`` to the symmetric int16 range [-(2**15 - 1), 2**15 - 1], round it
    to the nearest integer (ties to even) and return it as int16."""
    limit = 2**15 - 1
    clamped = np.clip(a, -limit, limit)
    return np.round(clamped).astype(np.int16)


def sat_to_8(a):
    """Clamp ``a`` to the symmetric int8 range [-(2**7 - 1), 2**7 - 1], round it
    to the nearest integer (ties to even) and return it as int8."""
    limit = 2**7 - 1
    clamped = np.clip(a, -limit, limit)
    return np.round(clamped).astype(np.int8)


def XS3_fc_forward_int8(W, b, x, rshift, scale):
    """Emulate the int8-output FC layer: int32 accumulation, then requantization.

    The requantization chain, each step saturating (see sat_to_16 / sat_to_8):
      1. divide by 2**rshift and round to 16 bits (the "vlsat" stage),
      2. multiply by scale / 2**14 (the "vlmul" stage),
      3. divide by 2**7 and round to int8 (what VDEPTH8 would do).

    The padded output rows (global ``pad0``) are dropped before requantizing.
    """
    y = XS3_fc_forward_int32(W, b, x)
    # y[: len(y) - pad0] is also correct for pad0 == 0, where y[:-pad0] would be empty.
    preact_xs3_int32_offset = y[: len(y) - pad0]
    preact_xs3_int32_vlsat = sat_to_16(preact_xs3_int32_offset / 2**rshift)
    preact_xs3_int32_vlmul = sat_to_16(preact_xs3_int32_vlsat * scale / 2**14)
    return sat_to_8(preact_xs3_int32_vlmul / 2**7)  # this is what VDEPTH8 would do

In [None]:
# the final bias is calculated here
# this includes the output offset, so that fused activations are already applied
# NOTE: this cell rebinds the global `scale` (earlier the 255 pixel-rescaling constant).
bias_scale = bias_quantization[0]
output_scale, output_zero_point = preact_quantization
# Real factor converting int32 accumulator units (bias scale) to int8 output units.
multiplier = bias_scale / output_scale

# Split the multiplier into a power-of-two shift and a fixed-point factor:
# with this rshift, multiplier * 2**rshift lies in (1, 2], so scale lies in (2**14, 2**15].
rshift = -np.ceil(np.log2(multiplier)) + 1
scale = np.round(2**14 * (multiplier * 2**rshift))
# 2**15 exceeds the int16 maximum (presumably the width of the VLMUL factor), so halve
# the factor and compensate in the shift.
if scale == 2**15:
    rshift -= 1
    scale /= 2
# Move 7 bits of the shift to the final /2**7 step in XS3_fc_forward_int8. The net factor
# y / 2**rshift * scale / 2**14 / 2**7 then still approximates y * multiplier.
rshift -= 7

# Fold the output zero point into the bias, expressed in accumulator units.
# NOTE(review): np.int32 truncates toward zero; rounding may track TFLite more closely. Confirm.
bias_xs3_offset = bias_xs3 + np.int32(output_zero_point / multiplier)

preact_xs3_int8 = XS3_fc_forward_int8(
    weights_xs3, bias_xs3_offset, data_xs3, rshift, scale
)

print("int8 preactivation values using xs3 emulation:")
print(preact_xs3_int8)
print("int8 preactivation values produced by tflite:")
print(preact_quant[0])

### Evaluate performance of the XS3 emulation

In [None]:
import sys
from multiprocessing import Pool


def eval_pred(args):
    """Predict the digit for one raw test image with the emulated XS3 int8 FC layer.

    Args:
        args: ``(j, im)`` pair from ``enumerate(test_images)``: the sample index and
            the raw uint8 MNIST image.

    Returns:
        Index of the largest int8 pre-activation (the predicted class).

    Reads the globals ``img_quantization``, ``pad1``, ``weights_xs3``,
    ``bias_xs3_offset``, ``rshift`` and ``scale`` computed in earlier cells.
    """
    j, im = args
    # Quantize by shifting raw pixels by the input zero point. This equals
    # round(pixel / 255 / input_scale) + zero_point only if the input scale is 1/255.
    # NOTE(review): confirm img_quantization[0] == 1/255 for this model.
    # Widen to int16 first: under NumPy >= 2 (NEP 50), uint8 + a negative Python int
    # raises OverflowError instead of promoting.
    im_int8 = (im.astype(np.int16) + img_quantization[1]).astype(np.int8)
    data_xs3 = np.pad(im_int8.flatten(), pad_width=[(0, pad1)])
    preact_xs3_int8 = XS3_fc_forward_int8(
        weights_xs3, bias_xs3_offset, data_xs3, rshift, scale
    )
    if (j + 1) % 10 == 0:
        print("{:6d}/{}".format(j + 1, len(test_images)))
        sys.stdout.flush()

    return np.argmax(preact_xs3_int8)


# NOTE(review): worker processes must be able to look up eval_pred by name. That works
# with the "fork" start method (Linux default) but not with "spawn" (macOS/Windows
# default), where functions defined in a notebook are not importable by workers.
# The context manager shuts the pool down once all predictions are collected.
with Pool(10) as p:
    predictions_xs3 = np.asarray(p.map(eval_pred, enumerate(test_images)))

In [None]:
# Accuracy of all four models on the test set.
# NOTE: computed as a plain NumPy mean. A single tf.metrics.Accuracy object reused across
# calls accumulates its running totals, so every call after the first would report an
# average over all previous calls instead of the accuracy of that model alone.
print("Accuracy of models:")
for row_format, preds in (
    ("# Original keras model:   {:.2%}", predictions),
    ("# Float TFLite model:     {:.2%}", predictions_float),
    ("# Quantized TFLite model: {:.2%}", predictions_quant),
    ("# Emulated XS3 model:     {:.2%}", predictions_xs3),
):
    print(row_format.format(np.mean(np.asarray(preds) == test_labels)))

# Convert tflite model graph

In [None]:
from tflite_utils import load_tflite_as_json, save_json_as_tflite
from tflite2xcore_utils import (
    clean_unused_opcodes,
    clean_unused_tensors,
    clean_unused_buffers,
)
from tflite2xcore_graph_conv import remove_float_inputs_outputs

# Load the quantized model as JSON, run the graph passes in order, and save the result.
model_quant_stripped_file = "models/model_quant_stripped.tflite"

json_model = load_tflite_as_json(model_quant_file)
for graph_pass in (
    remove_float_inputs_outputs,
    clean_unused_opcodes,
    clean_unused_tensors,
    clean_unused_buffers,
):
    # Return values are ignored, so these passes presumably modify json_model in place.
    graph_pass(json_model)
save_json_as_tflite(json_model, model_quant_stripped_file)