In [1]:
import os
import warnings
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import pandas as pd
import numpy as np
from gtda.time_series import SlidingWindow
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.keras.backend import set_session
config = tf.compat.v1.ConfigProto() 
config.gpu_options.allow_growth = True  
config.log_device_placement = True  
sess2 = tf.compat.v1.Session(config=config)
set_session(sess2) 
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.layers import Activation 
from tensorflow.keras.backend import sigmoid
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, Activation, BatchNormalization, Flatten, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
import tensorflow_datasets as tfds
from tensorflow.keras.models import load_model
import tensorflow_hub as hub

import get_dataset as kws_data
import kws_util
import argparse
from tqdm import tqdm

Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:21:00.0, compute capability: 8.6



## Import Dataset

In [2]:
Flags, unparsed = kws_util.parse_command()
Flags.batch_size = 1
Flags.window_size_ms = 40.0
Flags.data_dir = '/home/nesl/209as_sec/audio_ks/data'
print('We will download data to {:}'.format(Flags.data_dir))
ds_train, ds_test, ds_val = kws_data.get_training_data(Flags)
print("Done getting data")
train_shuffle_buffer_size = 85511
val_shuffle_buffer_size = 10102
test_shuffle_buffer_size = 4890

ds_train = ds_train.shuffle(train_shuffle_buffer_size)
ds_val = ds_val.shuffle(val_shuffle_buffer_size)
ds_test = ds_test.shuffle(test_shuffle_buffer_size)

We will download data to /home/nesl/209as_sec/audio_ks/data
Done getting data


## Load Model and Perform Inference

In [3]:
output_data = []
labels = []
interpreter = tf.lite.Interpreter('kws_micronet_s.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
input_scale, input_zero_point = input_details[0]["quantization"]

for dat, label in tqdm(ds_test):
    if input_details[0]['dtype'] == np.float32:
        interpreter.set_tensor(input_details[0]['index'], dat)
    elif input_details[0]['dtype'] == np.int8:
        dat_q = np.array(dat/input_scale + input_zero_point, dtype=np.int8) 
        interpreter.set_tensor(input_details[0]['index'], dat_q)
    else:
        raise ValueError("TFLite file has input dtype {:}.  Only np.int8 and np.float32 are supported".format(
            input_details[0]['dtype']))
    interpreter.invoke()
    output_data.append(np.argmax(interpreter.get_tensor(output_details[0]['index'])))
    labels.append(label[0])

100%|███████████████████████████████████████| 4890/4890 [01:13<00:00, 66.26it/s]


In [4]:
num_correct = np.sum(np.array(labels) == output_data)
acc = num_correct / len(labels)
print('Accuracy:', acc)

Accuracy: 0.8439672801635992


In [5]:
output_data = []
labels = []
interpreter = tf.lite.Interpreter('kws_micronet_l.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
input_scale, input_zero_point = input_details[0]["quantization"]

for dat, label in tqdm(ds_test):
    if input_details[0]['dtype'] == np.float32:
        interpreter.set_tensor(input_details[0]['index'], dat)
    elif input_details[0]['dtype'] == np.int8:
        dat_q = np.array(dat/input_scale + input_zero_point, dtype=np.int8) 
        interpreter.set_tensor(input_details[0]['index'], dat_q)
    else:
        raise ValueError("TFLite file has input dtype {:}.  Only np.int8 and np.float32 are supported".format(
            input_details[0]['dtype']))
    interpreter.invoke()
    output_data.append(np.argmax(interpreter.get_tensor(output_details[0]['index'])))
    labels.append(label[0])

100%|███████████████████████████████████████| 4890/4890 [08:44<00:00,  9.32it/s]


In [6]:
num_correct = np.sum(np.array(labels) == output_data)
acc = num_correct / len(labels)
print('Accuracy:', acc)

Accuracy: 0.8877300613496932
