In [None]:
import pathlib

import numpy as np
import matplotlib.pyplot as plt

from numpy.linalg import norm

import warnings
warnings.filterwarnings(action='once')
import tensorflow as tf

In [None]:
# `tf.test.is_gpu_available()` is deprecated; listing the physical GPU devices is the
# supported way to ask whether TensorFlow can see a GPU.
print("GPU Available: ", len(tf.config.list_physical_devices('GPU')) > 0)
print("Eager execution enabled: ", tf.executing_eagerly())

# Load and rescale data

In [None]:
# MNIST pixels are integers in [0, 255]; dividing by 255 rescales them to [0, 1].
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

scale = tf.constant(255, dtype=tf.dtypes.float32)
x_train = train_images / scale
x_test = test_images / scale

# Labels gain a trailing axis: (N,) -> (N, 1).
y_train = tf.expand_dims(train_labels, 1)
y_test = tf.expand_dims(test_labels, 1)

# Optional (currently disabled): standardize both splits with the training mean/std.

# Define, compile, and train model

In [None]:
from tensorflow import keras

# Seed BEFORE building the model: because the first layer is given `input_shape`, the
# Sequential model is built (and the Dense weights initialized) at construction time,
# so seeding afterwards does not make the weight initialization reproducible.
tf.random.set_seed(123)
np.random.seed(123)

# A single dense softmax layer, i.e. multinomial logistic regression over the 10 digits.
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(10, activation='softmax')
])

training_params = {'optimizer': 'adam',
                   'loss': 'sparse_categorical_crossentropy',
                   'metrics': ['accuracy']}

model.compile(**training_params)

model.summary()

In [None]:
# Train for 10 epochs. The test split doubles as validation data, so the reported
# validation metrics are test-set metrics.
model.fit(x=x_train, y=y_train, epochs=10, validation_data=(x_test, y_test))

# Convert to TFLite and save to disk

In [None]:
# Output directory for the exported .tflite files; created (with parents) if missing.
models_dir = pathlib.Path("./mnist_models/")
models_dir.mkdir(parents=True, exist_ok=True)

### Float TFLite model

In [None]:
# Baseline conversion with no optimizations requested: the reference float TFLite model.
model_float_lite = tf.lite.TFLiteConverter.from_keras_model(model).convert()

In [None]:
model_float_file = models_dir / "model_float.tflite"
# Path.write_bytes returns the number of bytes written, which is the model size.
size_float = model_float_file.write_bytes(model_float_lite)
print(f'Float model size: {size_float / 1024:.0f} KB')

### Quantized TFLite model

In [None]:
# Post-training quantization using the default optimization set.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# NOTE: also restricting converter.target_spec.supported_ops to
# [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] was tried and reportedly made no visible difference.

# Calibration data for estimating activation ranges: the first 100 training images,
# fed one image per batch.
x_train_ds = tf.data.Dataset.from_tensor_slices(x_train).batch(1)


def representative_data_gen():
    """Yield one single-element input list per calibration batch."""
    yield from ([batch] for batch in x_train_ds.take(100))


converter.representative_dataset = representative_data_gen

model_quant_lite = converter.convert()

In [None]:
model_quant_file = models_dir / "model_quant.tflite"
# Path.write_bytes returns the number of bytes written, which is the model size.
size_quant = model_quant_file.write_bytes(model_quant_lite)
print(f'Quantized model size: {size_quant / 1024:.0f} KB')

# Build interpreters and run inference on test set

In [None]:
# One in-memory interpreter per converted model; tensors must be allocated before use.
interpreter_float = tf.lite.Interpreter(model_content=model_float_lite)
interpreter_quant = tf.lite.Interpreter(model_content=model_quant_lite)
for interp in (interpreter_float, interpreter_quant):
    interp.allocate_tensors()

In [None]:
# Output probabilities for every test image from the Keras model and both TFLite models.
# Rows start as NaN so any image the loop fails to fill is easy to spot.
# (`np.NaN` was removed in NumPy 2.0; `np.nan` works on all versions.)
num_test = y_test.shape[0]
probabilities_float = np.full((num_test, 10), np.nan)
probabilities_quant = np.full((num_test, 10), np.nan)
probabilities = model(x_test).numpy()

# Tensor indices do not change between invocations, so look them up once.
in_float = interpreter_float.get_input_details()[0]["index"]
out_float = interpreter_float.get_output_details()[0]["index"]
in_quant = interpreter_quant.get_input_details()[0]["index"]
out_quant = interpreter_quant.get_output_details()[0]["index"]

for j, img in enumerate(x_test):
    img = tf.expand_dims(img, 0)  # the interpreters take a batch of one image
    interpreter_float.set_tensor(in_float, img)
    interpreter_float.invoke()
    probabilities_float[j] = interpreter_float.get_tensor(out_float)

    interpreter_quant.set_tensor(in_quant, img)
    interpreter_quant.invoke()
    probabilities_quant[j] = interpreter_quant.get_tensor(out_quant)

# Evaluate models

In [None]:
# Per-image L2 error of each TFLite model's output vector, relative to the L2 norm of
# the Keras model's output for the same image.
denom = norm(probabilities, axis=1)
prob_abs_err_float = norm(probabilities_float - probabilities, axis=1)
prob_abs_err_quant = norm(probabilities_quant - probabilities, axis=1)
prob_rel_err_float = prob_abs_err_float / denom
prob_rel_err_quant = prob_abs_err_quant / denom

print('Mean relative error of output activations compared to original model output:')
print(f'# Float TFLite model:     {np.mean(prob_rel_err_float):.5e}')
print(f'# Quantized TFLite model: {np.mean(prob_rel_err_quant):.5e}')

In [None]:
# Predicted class = index of the largest output probability.
predictions_float = np.argmax(probabilities_float, axis=1)
predictions_quant = np.argmax(probabilities_quant, axis=1)
predictions = np.argmax(probabilities, axis=1)

# NOTE: tf.metrics.Accuracy is a stateful, streaming metric. Reusing one instance
# accumulates across calls, so the 2nd and 3rd numbers would be running averages
# rather than per-model accuracies. Compute each accuracy independently instead.
print('Accuracy of models:')
print('# Original keras model:   {:.2%}'.format(np.mean(test_labels == predictions)))
print('# Float TFLite model:     {:.2%}'.format(np.mean(test_labels == predictions_float)))
print('# Quantized TFLite model: {:.2%}'.format(np.mean(test_labels == predictions_quant)))

# Interpreter surgery

In [None]:
# Run both interpreters on test image 10 so the cells below can inspect their
# intermediate tensors for that one sample.
img = tf.expand_dims(x_test[10], 0)
for interp in (interpreter_float, interpreter_quant):
    interp.set_tensor(interp.get_input_details()[0]["index"], img)
    interp.invoke()

### Float interpreter components

In [None]:
# List every tensor in the float interpreter (one dict per tensor, including its
# 'quantization' parameters). The hard-coded tensor indices used in later cells are
# presumably read off this listing; re-check them if the converter version changes.
interpreter_float.get_tensor_details()

### Quantized interpreter components

In [None]:
# List every tensor in the quantized interpreter (one dict per tensor, including its
# (scale, zero_point) 'quantization' parameters). The hard-coded tensor indices used in
# later cells are presumably read off this listing; re-check them if the converter changes.
interpreter_quant.get_tensor_details()

### Retrieve input image and its quantization, compare

In [None]:
# Tensor indices are hard-coded from the get_tensor_details() listings above
# (NOTE(review): they depend on the converter's graph layout, so verify after upgrades).
img_float = interpreter_float.get_tensor(1)[0].copy()          # float model's input
img_quant_float = interpreter_quant.get_tensor(5)[0].copy()    # quant model's float input
img_quant_int8 = interpreter_quant.get_tensor(1)[0].copy()     # quant model's int8 input
img_quantization = interpreter_quant.get_tensor_details()[1]['quantization']

in_scale, in_zero_point = img_quantization[0], img_quantization[1]

# Dequantize the interpreter's int8 input: x = (q - zero_point) * scale.
img_quant_int8_float = (np.float32(img_quant_int8) - in_zero_point) * in_scale

# Quantize the float input ourselves with the affine scheme
#   q = clip(round(x / scale) + zero_point, -128, 127).
# A bare np.int8(...) cast truncates toward zero (e.g. 254.99998 -> 254) and wraps
# out-of-range values, so round and saturate explicitly.
img_quant_float_int8 = np.int8(np.clip(np.round(img_quant_float / in_scale) + in_zero_point,
                                       -128, 127))

# Absolute error between the dequantized int8 input and the float input.
img_quant_diff = np.abs(img_quant_int8_float - img_quant_float)

In [None]:
im_dict = {"float input": img_float,
           "quant float input": img_quant_float,
           "quant int8 input": img_quant_int8,
           "quant inputs' diff": img_quant_diff,
           "float from quant int8": img_quant_int8_float,
           "int8 from quant float": img_quant_float_int8}

# One panel per image. The diff panel uses a fixed [0, 1] color range so its magnitude
# is shown against the full input range instead of being auto-stretched.
fig, axes = plt.subplots(1, len(im_dict), figsize=(16, 4))
for ax, (title, im) in zip(axes, im_dict.items()):
    kwargs = {'vmin': 0, 'vmax': 1} if title == "quant inputs' diff" else dict()
    ax.imshow(im, cmap='gray', **kwargs)
    ax.set_title(title)
    ax.grid(False)  # turn the grid off on every panel, not only the last one
plt.show()

### Demonstrate that the bug corrupts the internal state

In [None]:
# TODO: file a bug report

# Feed the quantized interpreter two versions of the same image and print its outputs:
# first the dequantized int8 input tensor ("corrupted"), then its own float input
# tensor ("uncorrupted"). The later cells read intermediate tensors left over from the
# second (uncorrupted) run.
quant_input_index = interpreter_quant.get_input_details()[0]["index"]
quant_output_index = interpreter_quant.get_output_details()[0]["index"]
for label, image in (('Output with corrupted image:', img_quant_int8_float),
                     ('Output with uncorrupted image:', img_quant_float)):
    interpreter_quant.set_tensor(quant_input_index, tf.expand_dims(image, 0))
    interpreter_quant.invoke()
    print(label)
    print(interpreter_quant.get_tensor(quant_output_index).flatten())

### Retrieve weights and quantizations, compare

In [None]:
# Tensor 3 holds the dense-layer weights in both interpreters (see the listings above).
weights_quant = interpreter_quant.get_tensor(3)
weights_float = interpreter_float.get_tensor(3)
weights_quantization = interpreter_quant.get_tensor_details()[3]['quantization']

# Express the float weights in units of the int8 quantization step and compare them
# with the stored int8 values. Unlike the bias/preactivation cells, no zero point is
# subtracted here; presumably the weight zero point is 0 (verify weights_quantization[1]).
weights_scaled_float = weights_float / weights_quantization[0]
weights_quant_diff = np.abs(np.float32(weights_quant) - weights_scaled_float)
weights_rel_err = norm(weights_quant_diff) / norm(np.float32(weights_quant))
print('Mean relative error between quantized and float weights: {:.4%}'.format(weights_rel_err))

### Weight visualization

In [None]:
# Each of the 10 output units has one 28x28 weight image, shown on the full int8 range.
w = weights_quant.reshape(-1, 28, 28)
fig, axes = plt.subplots(2, 5, figsize=(16, 7))
for digit, ax in enumerate(axes.flat):
    ax.imshow(w[digit], vmin=-128, vmax=127)
    ax.set_title('Digit {}'.format(digit))
plt.show()

### Distribution of weight quantization errors

In [None]:
# Histogram of per-weight quantization error (in int8 steps), one panel per output unit.
w = weights_quant_diff.reshape(-1, 28, 28)
fig, axes = plt.subplots(1, 10, figsize=(16, 1))
for digit, ax in enumerate(axes):
    ax.hist(w[digit].ravel())
    ax.set_title('Digit {}'.format(digit))
fig.subplots_adjust(wspace=.5)
plt.show()

### Retrieve biases and quantizations, compare

In [None]:
# Tensor 4 holds the dense-layer biases in both interpreters (see the listings above).
bias_quant = interpreter_quant.get_tensor(4)
bias_float = interpreter_float.get_tensor(4)
bias_quantization = interpreter_quant.get_tensor_details()[4]['quantization']

# Compare in quantized units: (q - zero_point) versus float / scale.
bias_quant_diff = np.abs(np.float32(bias_quant) - bias_quantization[1]
                         - bias_float / bias_quantization[0])
bias_rel_err = norm(bias_quant_diff) / norm(np.float32(bias_quant))
print('Mean relative error between quantized and float matmul biases: {:.4%}'.format(bias_rel_err))

### Retrieve preactivations and quantizations, compare

In [None]:
# NOTE: the tensor dense/BiasAdd is actually a preactivation, not a bias
preact_quant = interpreter_quant.get_tensor(2)
preact_float = interpreter_float.get_tensor(2)
preact_quantization = interpreter_quant.get_tensor_details()[2]['quantization']

# Compare in quantized units: (q - zero_point) versus float / scale.
preact_scale, preact_zero_point = preact_quantization[0], preact_quantization[1]
preact_quant_diff = np.abs(np.float32(preact_quant) - preact_zero_point - preact_float / preact_scale)
preact_rel_err = norm(preact_quant_diff) / norm(np.float32(preact_quant))
print('Mean relative error between quantized and float preactivations: {:.4%}'.format(preact_rel_err))

### Retrieve outputs and quantizations, compare

In [None]:
# Model outputs as seen through each interpreter's output tensor, plus the quantized
# model's internal int8 output (tensor 0) and its quantization parameters.
float_out_index = interpreter_float.get_output_details()[0]["index"]
quant_out_index = interpreter_quant.get_output_details()[0]["index"]
output_float = interpreter_float.get_tensor(float_out_index)
output_quant_float = interpreter_quant.get_tensor(quant_out_index)
output_quant_int8 = interpreter_quant.get_tensor(0)
output_quantization = interpreter_quant.get_tensor_details()[0]['quantization']

# Compare in quantized units: (q - zero_point) versus float / scale.
output_quant_diff = np.abs(
    np.float32(output_quant_int8) - output_quantization[1] - output_float / output_quantization[0]
)
output_rel_err = norm(output_quant_diff) / norm(np.float32(output_quant_int8))
print('Mean relative error between quantized and float outputs: {:.4%}'.format(output_rel_err))

# Interpreter reconstruction

In [None]:
# Float interpreter, reconstructed by hand: softmax(W x + b).
rec_preact_float = weights_float @ img_float.flatten() + bias_float
rec_out_float = tf.math.softmax(rec_preact_float).numpy()

output_float_flat = output_float.flatten()
with np.printoptions(formatter={'float': '{:.6e}'.format}):
    print("Reconstructed output:\n{}".format(rec_out_float))
    print("Original output:\n{}".format(output_float_flat))
    rel_err = norm(rec_out_float - output_float_flat) / norm(output_float_flat)
    print("Relative error: {:.6e}".format(rel_err))

In [None]:
# Dequantized int8 weights and int32 biases applied to the quant model's float input,
# compared against the float model's output.
weights_dequant = np.float32(weights_quant) * weights_quantization[0]
bias_dequant = bias_quant * bias_quantization[0]
rec_preact_float2 = weights_dequant @ img_quant_float.flatten() + bias_dequant

rec_out_float2 = tf.math.softmax(rec_preact_float2).numpy()
output_float_flat = output_float.flatten()
with np.printoptions(formatter={'float': '{:.6e}'.format}):
    print("Reconstructed output:\n{}".format(rec_out_float2))
    print("Original output:\n{}".format(output_float_flat))
    rel_err = norm(rec_out_float2 - output_float_flat) / norm(output_float_flat)
    print("Relative error: {:.6e}".format(rel_err))

In [None]:
# Dequantized int8 weights applied to the dequantized int8 input.
# NOTE: because of the input-tensor issue demonstrated above, the image used here is our
# own float -> int8 quantization of the float input, not the interpreter's int8 tensor.
weights_dequant = np.float32(weights_quant) * weights_quantization[0]
input_dequant = (np.float32(img_quant_float_int8) - img_quantization[1]).flatten() * img_quantization[0]
bias_dequant = bias_quant * bias_quantization[0]
rec_preact_float3 = weights_dequant @ input_dequant + bias_dequant
rec_out_float3 = tf.math.softmax(rec_preact_float3).numpy()

output_float_flat = output_float.flatten()
with np.printoptions(formatter={'float': '{:.6e}'.format}):
    print("Reconstructed output:\n{}".format(rec_out_float3))
    print("Original output:\n{}".format(output_float_flat))
    rel_err = norm(rec_out_float3 - output_float_flat) / norm(output_float_flat)
    print("Relative error: {:.6e}".format(rel_err))

In [None]:
# Pure integer path: int32 accumulation of W . (q - zero_point) plus the int32 bias,
# with the input zero point removed via a second matmul against a constant vector.
# Only the final result is rescaled to float (by the bias scale) before softmax.
w_int32 = np.int32(weights_quant)
x_int32 = np.int32(img_quant_float_int8).flatten()
zero_point_vec = np.int32(img_quantization[1] * np.ones(img_quant_float_int8.size))
rec_preact_int = np.matmul(w_int32, x_int32) - np.matmul(w_int32, zero_point_vec) + bias_quant
rec_out_int = tf.math.softmax(rec_preact_int * bias_quantization[0]).numpy()

output_float_flat = output_float.flatten()
with np.printoptions(formatter={'float': '{:.6e}'.format}):
    print("Reconstructed output:\n{}".format(rec_out_int))
    print("Original output:\n{}".format(output_float_flat))
    rel_err = norm(rec_out_int - output_float_flat) / norm(output_float_flat)
    print("Relative error: {:.6e}".format(rel_err))

# XS3 emulation and scaling

In [None]:
# XS3 vector-unit parameters used by the emulation below:
#   bpv: bits per vector register (presumably; it yields ve = bpv // bpe lanes)
#   bpe: bits per element (int8 operands)
#   vac: extra accumulator bits; the VLMACCR emulation saturates at bpe + vac bits
bpv = 256
bpe = 8
vac = 8
# Number of elements (lanes) per vector.
ve = bpv // bpe

### Calculate int16 bias values for XS3

In [None]:
# NOTE: on XS3 the accumulator for int8 vector operations is 2 x int8 (vR/vD), hence the int16 bias.
# NOTE: to avoid saturation while computing the dot product, the bias is spread out across
#       the elements, i.e. vR/vD is initialized with the spread-out value in every element.
# TODO: there might be a better strategy to apply the bias
# TODO: storing the biases in int8 might be okay too, investigate

# Fold the input zero-point correction into the int32 bias, then express it in the units
# VLMACCR accumulates in (products are shifted down by 2**(bpe-2)).
zero_point_correction = np.matmul(np.int32(weights_quant),
                                  np.int32(img_quantization[1] * np.ones(img_quant_float_int8.size)))
unified_bias = (bias_quant - zero_point_correction) / np.float32(2**(bpe-2))

# Divide across the ve accumulator lanes and saturate to the int16 range.
int16_lo, int16_hi = -2**(2*bpe-1), 2**(2*bpe-1) - 1
unified_bias_int16_ve = np.int16(np.clip(np.round(unified_bias / ve), int16_lo, int16_hi))

print("These are int16 bias values that the XS3 implementation should store:")
print(unified_bias_int16_ve)

### Define functions that emulate the XS3 vector unit

In [None]:
# NOTE: vacc is equivalent to the vR/vD pair in XS3

def VLMACCR(a, b, vacc):
    """Emulate one VLMACCR step on int8 vectors ``a`` and ``b``.

    Multiplies ``a`` and ``b`` elementwise, scales each product down by 2**(bpe-2) with
    rounding, sums the results, and adds the LAST accumulator element. The sum is
    saturated symmetrically to +/-(2**(bpe+vac-1) - 1) and pushed onto the FRONT of the
    accumulator, whose last element is dropped (a rotate-and-accumulate).

    Returns the new accumulator (same length as the inputs).
    """
    assert len(a) == len(b) == len(vacc)
    products = np.int16(a) * np.int16(b)  # int8 * int8 fits in int16
    scaled = np.round(products / np.float32(2**(bpe-2)))
    total = sum(scaled) + np.float32(vacc[-1])
    limit = 2**(bpe+vac-1)
    total = np.clip(total, -limit + 1, limit - 1)
    return np.hstack([np.int16(total), vacc[:-1]])

def VLSAT(v, s=0):
    """Emulate VLSAT: round-shift ``v`` right by ``s`` bits and saturate to int8.

    The saturation range is the signed bpe-bit range [-2**(bpe-1), 2**(bpe-1) - 1].
    """
    low, high = -2**(bpe-1), 2**(bpe-1) - 1
    shifted = np.round(np.float32(v) / 2**s)
    return np.int8(np.clip(shifted, low, high))

def VLREDSUM(v, s=0):
    """Emulate a lane reduction-sum of ``v`` with an ``s``-bit right shift.

    NOTE: this instruction does not exist in the XS3 ISA (yet?); it is built from a
    single VLMACCR on a zeroed accumulator, so the sum ends up at the front of the result.
    """
    # Multiplying by 2**(bpe-2-s) cancels VLMACCR's built-in 2**(bpe-2) down-shift,
    # leaving a net right shift of s.
    multipliers = np.int8(np.ones(v.shape) * 2**(bpe-2-s))
    return VLMACCR(v, multipliers, np.zeros(ve, dtype=np.int16))

In [None]:
from math import ceil

def XS3_dot_prod(v, w, bias_vacc=None, scale1=4, scale2=1):
    """Emulate an int8 dot product ``v . w`` on the XS3 vector unit.

    Args:
        v, w: equal-length int8 vectors; both are zero-padded to a multiple of ``ve``.
        bias_vacc: initial accumulator contents (length ``ve``), i.e. the bias spread
            across all lanes; ``None`` starts from zeros.
        scale1: right shift used when saturating the accumulator to int8.
        scale2: right shift used when saturating the reduced sum to int8.

    Returns:
        The int8 result, taken from the front of the final saturated register.
    """
    assert len(v) == len(w)
    # Integer ceil(len(v) / ve): the number of ve-wide VLMACCR passes needed.
    n_passes = -(-len(v) // ve)
    pad = n_passes * ve - len(v)
    v_padded = np.pad(v, (0, pad))
    w_padded = np.pad(w, (0, pad))

    vacc = bias_vacc if bias_vacc is not None else np.zeros(ve, dtype=np.int16)
    for start in range(0, n_passes * ve, ve):
        stop = start + ve
        vacc = VLMACCR(v_padded[start:stop], w_padded[start:stop], vacc)

    # Saturate the accumulator. Both shifts are a free choice (TODO: optimize them):
    # if too small, saturation occurs often; if too large, we lose less significant
    # bits. Either way accuracy suffers.
    vR = VLSAT(vacc, scale1)

    # Sum the lanes (the sum lands at the front of the accumulator), then saturate again.
    # A cheaper but less accurate variant skips the second VLSAT:
    #     vR = np.int8(VLREDSUM(vR, scale2))
    vR = VLSAT(VLREDSUM(vR), scale2)

    return vR[0]

# TODO: there is probably a more efficient way to do matrix-vector multiplication
def XS3_fcc_forward(input_int8, weights_int8, bias_int16_ve,
                    scale1, scale2):
    """Emulate a fully connected layer on XS3, one output feature at a time.

    Args:
        input_int8: flattened int8 input vector.
        weights_int8: int8 weight matrix with one row per output feature.
        bias_int16_ve: per-feature int16 bias, already divided across the ``ve`` lanes.
        scale1, scale2: shifts forwarded to ``XS3_dot_prod``.

    Returns:
        int8 array of preactivations, one per row of ``weights_int8``.
    """
    # The output size comes from the weights argument rather than a global or a
    # hard-coded 10, so the function works for any layer width.
    num_features = weights_int8.shape[0]
    # Row k holds feature k's per-lane bias replicated across all ve lanes.
    bias_int16_vacc = np.tile(bias_int16_ve, (ve, 1)).T
    output_int8 = np.zeros(num_features, dtype=np.int8)
    for feature_num in range(num_features):
        output_int8[feature_num] = XS3_dot_prod(weights_int8[feature_num], input_int8,
                                                bias_int16_vacc[feature_num],
                                                scale1, scale2)

    return output_int8

### Calculate preactivation on xs3, compare

If the only goal is classification, the argmax of the preactivation is sufficient: softmax is monotonic, so it never changes which class has the largest value.

In [None]:
scale1, scale2 = 4, 1  # TODO: find these scales by optimization on the training set
preact_xs3 = XS3_fcc_forward(input_int8=img_quant_float_int8.flatten(),
                             weights_int8=weights_quant, bias_int16_ve=unified_bias_int16_ve,
                             scale1=scale1, scale2=scale2)
print("int8 preactivation values produced by XS3 emulation:")
print(preact_xs3)

# Reference: exact int32 accumulation with the same total down-shift as the XS3 path
# (2**(bpe-2) inside VLMACCR, then scale1 + scale2 bits). Saturate to the int8 range as
# VLSAT does; a bare np.int8 cast of an out-of-range float is undefined and may wrap.
preact_int32_ref = np.round(rec_preact_int / np.float32(2**(bpe-2)) / 2**(scale1+scale2))
print("int8 preactivation values produced by int32 accumulation:")
print(np.int8(np.clip(preact_int32_ref, -2**(bpe-1), 2**(bpe-1)-1)))

In [None]:
# Map the int8 XS3 preactivation back to float: undo the 2**(bpe-2) VLMACCR shift and the
# scale1 + scale2 saturation shifts, then apply the int32 bias scale.
rec_preact_xs3 = preact_xs3 * np.float32(2**(bpe-2)) * 2**(scale1+scale2) * bias_quantization[0]

preact_float_flat = preact_float.flatten()
with np.printoptions(formatter={'float': '{:.6e}'.format}):
    print("Reconstructed preactivation (xs3):\n{}".format(rec_preact_xs3))
    print("Original preactivation:\n{}".format(preact_float_flat))
    rel_err = norm(rec_preact_xs3 - preact_float_flat) / norm(preact_float_flat)
    print("Relative error: {:.6e}".format(rel_err))

### Evaluate performance of the XS3 emulation

In [None]:
# This takes a while because the XS3 emulation is very inefficient.
# Raw uint8 pixels p are fed as quantized values p + zero_point. This presumes the input
# scale is 1/255 (calibration data spans [0, 1]); TODO confirm against img_quantization[0].
input_zero_point = int(img_quantization[1])
num_images = len(test_images)
predictions_xs3 = np.zeros(predictions.shape, dtype=np.int64)
for j, im in enumerate(test_images):
    # Widen to int16 before adding the (negative) zero point. uint8 + a negative Python
    # int raises OverflowError under NumPy 2 promotion rules (NEP 50).
    im_int8 = np.int8(np.int16(im) + input_zero_point).flatten()
    preact_xs3 = XS3_fcc_forward(input_int8=im_int8,
                                 weights_int8=weights_quant, bias_int16_ve=unified_bias_int16_ve,
                                 scale1=scale1, scale2=scale2)
    predictions_xs3[j] = np.argmax(preact_xs3)
    if (j+1) % 10 == 0:
        print('{:6d}/{}'.format(j+1, num_images), end='\r')
print()

In [None]:
# NOTE: tf.metrics.Accuracy is stateful. One shared instance accumulates over every call,
# so each row after the first would be a running average. Compute each accuracy on its own.
print('Accuracy of models:')
print('# Original keras model:   {:.2%}'.format(np.mean(test_labels == predictions)))
print('# Float TFLite model:     {:.2%}'.format(np.mean(test_labels == predictions_float)))
print('# Quantized TFLite model: {:.2%}'.format(np.mean(test_labels == predictions_quant)))
print('# Emulated XS3 model:     {:.2%}'.format(np.mean(test_labels == predictions_xs3)))