# Get weight values from .tflite file

- [https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)

In [0]:
import tensorflow as tf
import numpy as np

In [2]:
print('TensorFlow version: ', tf.__version__)

TensorFlow version:  1.13.1


## Load .tflite file

In [0]:
tfl_path = 'mnist_float.tflite'

In [0]:
interp = tf.lite.Interpreter(model_path=tfl_path, model_content=None)

## Show tensor info

- __tensor_list__ is list of dict ('name', 'index', 'shape', 'dtype', 'quantization')

In [5]:
tensor_list = interp.get_tensor_details()

for i in range(len(tensor_list)):
  print(i, ':', tensor_list[i]['name'], tensor_list[i]['shape'])

0 : activation/Relu [  1 256]
1 : activation_1/Softmax [ 1 10]
2 : batch_normalization_v1/FusedBatchNorm [ 1 28 28  1]
3 : batch_normalization_v1/FusedBatchNorm_add_param [1]
4 : batch_normalization_v1/FusedBatchNorm_mul_0 [ 1 28 28  1]
5 : batch_normalization_v1/FusedBatchNorm_mul_0_param [1]
6 : batch_normalization_v1_1/FusedBatchNorm [ 1 14 14 64]
7 : batch_normalization_v1_1/FusedBatchNorm_add_param [64]
8 : batch_normalization_v1_1/FusedBatchNorm_mul_0 [ 1 14 14 64]
9 : batch_normalization_v1_1/FusedBatchNorm_mul_0_param [64]
10 : batch_normalization_v1_2/FusedBatchNorm [  1   7   7 128]
11 : batch_normalization_v1_2/FusedBatchNorm_add_param [128]
12 : batch_normalization_v1_2/FusedBatchNorm_mul_0 [  1   7   7 128]
13 : batch_normalization_v1_2/FusedBatchNorm_mul_0_param [128]
14 : batch_normalization_v1_input [ 1 28 28  1]
15 : conv2d/Conv2D_bias [64]
16 : conv2d/Relu [ 1 28 28 64]
17 : conv2d/kernel [ 1  5  5 64]
18 : conv2d_1/Conv2D_bias [128]
19 : conv2d_1/Relu [  1  14  14 12

## Show some weight

In [12]:
print(interp.get_tensor(6))

[[[[ 0.23503728 -0.23302487 -0.22928101 ... -0.31781524 -0.19408007
    -0.20346361]
   [ 0.23503728 -0.23302487 -0.20962898 ... -0.31781524 -0.19408007
    -0.20346361]
   [ 0.23503728 -0.23302487 -0.20962898 ... -0.31781524 -0.19408007
    -0.20346361]
   ...
   [ 0.23503728 -0.13387382 -0.22928101 ...  0.99410504  0.3374843
    -0.20346361]
   [ 0.23503728 -0.23302487 -0.20962898 ... -0.31781524 -0.19408007
    -0.20346361]
   [ 0.23503728 -0.23302487 -0.22928101 ... -0.31781524 -0.19408007
    -0.20346361]]

  [[ 0.23503728 -0.23302487 -0.22928101 ... -0.31781524 -0.19408007
    -0.20346361]
   [ 0.23503728 -0.23302487 -0.06280945 ... -0.31781524 -0.19408007
    -0.20346361]
   [ 0.23503728 -0.23302487 -0.06280945 ... -0.31781524 -0.19408007
    -0.20346361]
   ...
   [ 0.23503728  0.1988411  -0.22928101 ... -0.02760372  1.5867637
    -0.20346361]
   [ 0.23503728 -0.23302487 -0.06280945 ... -0.31781524 -0.19408007
    -0.20346361]
   [ 0.23503728 -0.23302487 -0.22928101 ... -0.3178

***

## Inference

In [0]:
# Get MNIST test data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_test = x_test.astype(np.float32)
y_test = y_test.astype(np.float32)

In [0]:
interp.allocate_tensors()
input = interp.tensor(interp.get_input_details()[0]["index"])
output = interp.tensor(interp.get_output_details()[0]["index"])

In [10]:
corr = 0
for i in range(len(x_test)):
  x = np.expand_dims(np.expand_dims(x_test[i], axis=0), axis=3)
  interp.set_tensor(interp.get_input_details()[0]['index'], x)
  interp.invoke()
  if np.argmax(output()) == y_test[i]:
    corr = corr + 1
  else:
    print("fail: %s vs. %d" % (np.argmax(output()), y_test[i]) )

fail: 0 vs. 6
fail: 4 vs. 8
fail: 0 vs. 6
fail: 1 vs. 2
fail: 9 vs. 8
fail: 5 vs. 6
fail: 1 vs. 7
fail: 6 vs. 4
fail: 2 vs. 7
fail: 4 vs. 9
fail: 1 vs. 7
fail: 3 vs. 5
fail: 3 vs. 8
fail: 4 vs. 9
fail: 3 vs. 5
fail: 9 vs. 4
fail: 0 vs. 2
fail: 0 vs. 6
fail: 9 vs. 4
fail: 1 vs. 6
fail: 3 vs. 1
fail: 0 vs. 9
fail: 3 vs. 5
fail: 0 vs. 9
fail: 0 vs. 2
fail: 3 vs. 5
fail: 1 vs. 6
fail: 4 vs. 9
fail: 0 vs. 8
fail: 1 vs. 9
fail: 1 vs. 7
fail: 0 vs. 6
fail: 0 vs. 5
fail: 9 vs. 8
fail: 8 vs. 6
fail: 2 vs. 7
fail: 0 vs. 6
fail: 1 vs. 7
fail: 1 vs. 3
fail: 4 vs. 9
fail: 0 vs. 6
fail: 1 vs. 6
fail: 5 vs. 3
fail: 4 vs. 9
fail: 0 vs. 8
fail: 0 vs. 6
fail: 4 vs. 9
fail: 0 vs. 9
fail: 4 vs. 8
fail: 2 vs. 7
fail: 3 vs. 5
fail: 9 vs. 3
fail: 9 vs. 5
fail: 9 vs. 8
fail: 7 vs. 9
fail: 1 vs. 7
fail: 7 vs. 0
fail: 1 vs. 8
fail: 8 vs. 2
fail: 9 vs. 4
fail: 2 vs. 7
fail: 7 vs. 9
fail: 1 vs. 6
fail: 6 vs. 5


In [9]:
print('Accuracy: %d / %d' % (corr, len(x_test)))

Accuracy: 9936 / 10000
