<a href="https://colab.research.google.com/github/rosan-international/video_classifier_bugs/blob/main/quantizing_movinet_stream.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transfer learning for video classification with MoViNet

MoViNets (Mobile Video Networks) are a family of efficient video classification models that support inference on streaming video. This notebook takes a pre-trained MoViNet model, that is, a saved network previously trained on a larger dataset, and adapts it to classify custom actions. You can find more details about MoViNets in the [MoViNets: Mobile Video Networks for Efficient Video Recognition](https://arxiv.org/abs/2103.11511) paper by Kondratyuk, D. et al. (2021). This notebook will: 

* Download a pre-trained MoViNet model
* Create a new model using a pre-trained model with a new classifier by freezing the convolutional base of the MoViNet model
* Replace the classifier head with the number of labels of a new dataset
* Perform transfer learning on the custom actions dataset
* Export the trained/tuned MoViNet-A2-Stream model, including an int8 quantized version
* Illustrate streaming action recognition with a sample video

The model downloaded in this notebook is from [official/projects/movinet](https://github.com/tensorflow/models/tree/master/official/projects/movinet). This repository contains a collection of MoViNet models that TF Hub uses in the TensorFlow 2 SavedModel format. The transfer learning strategy is adapted from https://github.com/tensorflow/models/blob/master/official/projects/movinet/movinet_streaming_model_training_and_inference.ipynb



# Setup

## Install Libraries

In [3]:
!pip install git+https://github.com/okankop/vidaug.git
!pip install apache_beam
import vidaug 
from vidaug import augmentors as va



In [4]:
!pip install -U -q "tf-models-official"
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy remotezip
!pip install -U -q git+https://github.com/tensorflow/docs
!pip install tensorflow --upgrade
!pip install --upgrade pandas_profiling



## Import Libraries

In [14]:
import os, shutil
import functools
import tqdm
import random
import pathlib
import imageio
import itertools
import collections
import cv2
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import PIL
import mediapy as media
import math
import keras
import tensorflow as tf
import tensorflow_hub as hub
from zipfile import ZipFile
from tensorflow_docs.vis import embed
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy

In [15]:
import numpy as np

In [16]:
%matplotlib inline

## Distribution strategy

In [17]:
# Detect hardware
try:
  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
except ValueError:
  tpu_resolver = None
  gpus = tf.config.experimental.list_logical_devices("GPU")

# Select appropriate distribution strategy
if tpu_resolver:
  tf.config.experimental_connect_to_cluster(tpu_resolver)
  tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
  distribution_strategy = tf.distribute.TPUStrategy(tpu_resolver)  # non-experimental API name
  print('Running on TPU ', tpu_resolver.cluster_spec().as_dict()['worker'])
elif len(gpus) > 1:
  distribution_strategy = tf.distribute.MirroredStrategy([gpu.name for gpu in gpus])
  print('Running on multiple GPUs ', [gpu.name for gpu in gpus])
elif len(gpus) == 1:
  distribution_strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
  print('Running on single GPU ', gpus[0].name)
else:
  distribution_strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
  print('Running on CPU')

print("Number of accelerators: ", distribution_strategy.num_replicas_in_sync)

Running on single GPU  /device:GPU:0
Number of accelerators:  1


## Tuning parameters

In [18]:
#See original paper for recommended parameters: https://arxiv.org/pdf/2103.11511.pdf

clean = True
random_seed = 789
model_id = 'a2'
resolution = 172
num_frames = 30
frame_steps = 5
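batch_size = 10
train_size = 10  # videos per class for training (assumed: 3 classes x 10 = the 30 files extracted in this run)
test_size = 5    # videos per class for testing (assumed: 3 classes x 5 = the 15 files extracted in this run)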

# Load data

The Data Builder cell below defines helper functions that extract the custom actions videos and load them into a `tf.data.Dataset`. The `FrameGenerator` class at the end of that cell is the most important utility here. It creates an iterable object that can feed data into the TensorFlow data pipeline. Specifically, this class contains a Python generator that loads the video frames along with their encoded label. The generator (`__call__`) function yields the frame array produced by `frames_from_video_file` and the integer class id of the label associated with that set of frames.
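
The generator pattern described here can be illustrated without any video files. The following is a minimal sketch, not the notebook's `FrameGenerator`: `ToyClipGenerator` and its `samples` argument are hypothetical stand-ins that replace video decoding with plain strings. It keeps only the two ideas that matter for `tf.data`: a callable object whose `__call__` is a generator, and class names mapped to integer ids in sorted order.

```python
class ToyClipGenerator:
    """Hypothetical stand-in for FrameGenerator: yields (clip, label_id) pairs."""

    def __init__(self, samples):
        # samples: list of (clip, class_name) pairs; the real class holds video paths.
        self.samples = samples
        self.class_names = sorted({name for _, name in samples})
        self.class_ids_for_name = {name: idx for idx, name in enumerate(self.class_names)}

    def __call__(self):
        # Calling the object returns a fresh generator each time, so the data
        # can be iterated again on every epoch.
        for clip, name in self.samples:
            yield clip, self.class_ids_for_name[name]


toy = ToyClipGenerator([("clip_a", "Action2"), ("clip_b", "Action1"), ("clip_c", "Action2")])
pairs = list(toy())
print(pairs)
```

The labels come out as integers (`Action1` is 0, `Action2` is 1), which is the form `SparseCategoricalCrossentropy` expects.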

## Download directory

In [19]:
download_dir = pathlib.Path('/content/custom_actions_subset/')

## Data Builder

In [20]:
# custom_action Data  Builder

def list_files_per_class(zip_path):
  """
    List the files in each class of the dataset given the zip path.

    Args:
      zip_path: Path to the zip file containing the dataset. 

    Return:
      files: List of files in each of the classes.
  """
  files = []
  with ZipFile(zip_path, 'r') as zip:
    for zip_info in zip.infolist():
      files.append(zip_info.filename)
  return files

def get_class(fname):
    # Get the directory path from the file path
    dir_path = os.path.dirname(fname)
   
    # Get the folder name from the directory path
    folder_name = os.path.basename(dir_path)
   
    # Return the folder name
    return folder_name

def get_files_per_class(files):
  """
    Retrieve the files that belong to each class. 

    Args:
      files: List of files in the dataset.

    Return:
      Dictionary of class names (key) and files (values).
  """
  files_for_class = collections.defaultdict(list)
  for fname in files:
    class_name = get_class(fname)
    files_for_class[class_name].append(fname)
  return files_for_class

def download_from_zip(zip_path, to_dir, file_names):
  """
    Download the contents of the zip file from the zip path.

    Args:
      zip_path: Zip path containing data.
      to_dir: Directory to download data to.
      file_names: Names of files to download.
  """
  with ZipFile(zip_path, 'r') as zip:
    for fn in tqdm.tqdm(file_names):
      class_name = get_class(fn)
      zip.extract(fn, str(to_dir / class_name))
      unzipped_file = to_dir / class_name / fn

      fn = pathlib.Path(fn).parts[-1]
      output_file = to_dir / class_name / fn
      unzipped_file.rename(output_file,)

def split_class_lists(files_for_class, count):
  """
    Returns the list of files belonging to a subset of data as well as the remainder of
    files that need to be downloaded.

    Args:
      files_for_class: Files belonging to a particular class of data.
      count: Number of files to download.

    Return:
      split_files: Files belonging to the subset of data.
      remainder: Dictionary of the remainder of files that need to be downloaded.
  """
  split_files = []
  remainder = {}
  for cls in files_for_class:
    split_files.extend(files_for_class[cls][:count])
    remainder[cls] = files_for_class[cls][count:]
  return split_files, remainder

def upload_custom_actions_subset(zip_path, num_classes, splits, download_dir):
  """
    Download a subset of the custom_actions dataset and split them into various parts, such as
    training, validation, and test. 

    Args:
      zip_path: Zip path containing data.
      num_classes: Number of labels.
      splits: Dictionary specifying the training, validation, test, etc. (key) division of data 
              (value is number of files per split).
      download_dir: Directory to download data to.

    Return:
      dir: Posix path of the resulting directories containing the splits of data.
  """
  files = list_files_per_class(zip_path)
  # Keep only entries that sit inside a class folder. Build a new list instead of
  # removing items while iterating, which would skip elements.
  files = [f for f in files if len(f.split('/')) > 1]

  files_for_class = get_files_per_class(files)

  classes = list(files_for_class.keys())[:num_classes]

  for cls in classes:
    new_files_for_class = files_for_class[cls]
    random.seed(random_seed) # Added random seed for reproducibility
    random.shuffle(new_files_for_class)
    files_for_class[cls] = new_files_for_class

  # Only use the number of classes you want in the dictionary
  files_for_class = {x: files_for_class[x] for x in list(files_for_class)[:num_classes]}

  dirs = {}
  for split_name, split_count in splits.items():
    print(split_name, ":")
    split_dir = download_dir / split_name
    split_files, files_for_class = split_class_lists(files_for_class, split_count)
    download_from_zip(zip_path, split_dir, split_files)
    dirs[split_name] = split_dir

  return dirs

def format_frames(frame, output_size):
  """
    Resize an image from a video to the requested output size.

    Args:
      frame: Image that needs to be resized. 
      output_size: Pixel size (height, width) of the output frame image.

    Return:
      Formatted frame of the specified output size.
  """
  frame = tf.image.convert_image_dtype(frame, tf.float32)
  # resize_with_pad left black padding around the frame, so use a plain resize instead.
  frame = tf.image.resize(frame, output_size, method='nearest')
  return frame

def frames_from_video_file(video_path, n_frames, output_size = (resolution, resolution), frame_step = frame_steps):
  """
    Creates frames from each video file present for each category.

    Args:
      video_path: File path to the video.
      n_frames: Number of frames to be created per video file.
      output_size: Pixel size of the output frame image.

    Return:
      An NumPy array of frames in the shape of (n_frames, height, width, channels).
  """
  # Read each video frame by frame
  result = []
  src = cv2.VideoCapture(str(video_path))  

  video_length = int(src.get(cv2.CAP_PROP_FRAME_COUNT))  # OpenCV returns a float
  need_length = 1 + (n_frames - 1) * frame_step

  if need_length > video_length:
    start = 0
  else:
    max_start = video_length - need_length
    start = random.randint(0, max_start)  # randint is inclusive at both ends

  src.set(cv2.CAP_PROP_POS_FRAMES, start)
  # ret is a boolean indicating whether read was successful, frame is the image itself
  ret, frame = src.read()
  result.append(format_frames(frame, output_size))

  for _ in range(n_frames - 1):
    for _ in range(frame_step):
      ret, frame = src.read()
    if ret:
      frame = format_frames(frame, output_size)
      result.append(frame)
    else:
      result.append(np.zeros_like(result[0]))
  src.release()
  result = np.array(result)[..., [2, 1, 0]]

  return result

def to_gif(images):
  converted_images = np.clip(images * 255, 0, 255).astype(np.uint8)
  imageio.mimsave('/content/animation.gif', converted_images, fps=10)
  # Embed the same file that was just written, regardless of the working directory.
  return embed.embed_file('/content/animation.gif')

class FrameGenerator:
  def __init__(self, path, n_frames, training = False):
    """ Returns a set of frames with their associated label. 

      Args:
        path: Video file paths.
        n_frames: Number of frames. 
        training: Boolean to determine if training dataset is being created.
    """
    self.path = path
    self.n_frames = n_frames
    self.training = training
    self.class_names = sorted(set(p.name for p in self.path.iterdir() if p.is_dir()))
    self.class_ids_for_name = dict((name, idx) for idx, name in enumerate(self.class_names))

  def get_files_and_class_names(self):
    video_paths = list(self.path.glob('*/*.avi'))
    classes = [p.parent.name for p in video_paths] 
    return video_paths, classes

  def __call__(self):
    video_paths, classes = self.get_files_and_class_names()

    pairs = list(zip(video_paths, classes))

    if self.training:
      random.shuffle(pairs)

    for path, name in pairs:
      video_frames = frames_from_video_file(path, self.n_frames) 
      label = self.class_ids_for_name[name] # Encode labels
      yield video_frames, label

## Clean custom_actions_subset directory (Optional)

Only executed if `clean = True` (see "Tuning parameters").


In [21]:
## If need to clean the whole directory, replace '/content/custom_actions_subset' with '/content' below

if os.path.exists(download_dir) and clean:
    for filename in os.listdir(download_dir):
        file_path = os.path.join(download_dir, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print('Failed to delete %s. Reason: %s' % (file_path, e))       

## Upload files

### Upload train/test files

In [22]:
# To upload from your local machine, uncomment the two lines below; a file picker will open.

#from google.colab import files
#uploaded = files.upload()

# To load from Google Drive, mount it below; an authorization prompt will open.

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [23]:
subset_paths = upload_custom_actions_subset(zip_path = '/content/drive/MyDrive/custom_action_vids/videos_trim.zip', 
                        num_classes = 3, # 'Action1', 'Action2', 'Action3'
                        splits = {"train": train_size, "test": test_size}, # Split per class (e.g. if "train": 10, "test": 5, 10 'Action3' videos go to training set, 5 to testing set, and same for the other classes)
                        download_dir = download_dir)


train :


100%|██████████| 30/30 [00:48<00:00,  1.63s/it]


test :


100%|██████████| 15/15 [00:16<00:00,  1.11s/it]
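
To make the per-class split concrete, here is a small sketch that re-runs the `split_class_lists` logic from the Data Builder cell on a made-up file listing. The file names and the split sizes (2 for train, 1 for test) are illustrative assumptions, not the notebook's data.

```python
def split_class_lists(files_for_class, count):
    # Same logic as the helper defined in the Data Builder cell.
    split_files, remainder = [], {}
    for cls in files_for_class:
        split_files.extend(files_for_class[cls][:count])
        remainder[cls] = files_for_class[cls][count:]
    return split_files, remainder


# Hypothetical listing: 3 videos for each of 2 classes.
files_for_class = {
    "Action1": ["Action1/a.avi", "Action1/b.avi", "Action1/c.avi"],
    "Action2": ["Action2/d.avi", "Action2/e.avi", "Action2/f.avi"],
}

splits = {"train": 2, "test": 1}
assigned = {}
for split_name, split_count in splits.items():
    # Each split takes `count` files per class; the rest carries over to the next split.
    assigned[split_name], files_for_class = split_class_lists(files_for_class, split_count)

print(assigned["train"])
print(assigned["test"])
```

Because each split draws from what the previous one left over, no video lands in two splits, and a split size of N yields N files for every class.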


### Upload validation file

In [24]:
subset_paths2 = upload_custom_actions_subset(zip_path = '/content/drive/MyDrive/custom_action_vids/sequence.zip', 
                        num_classes = 1,
                        splits = {"val": 1},
                        download_dir = download_dir)

val :


100%|██████████| 1/1 [00:00<00:00, 593.17it/s]


Check that the video was actually extracted to `custom_actions_subset/val`; if it was not, upload it manually from the unzipped video in Drive or a local folder.
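
One way to run that check is to list the files with the same `*/*.avi` glob that `FrameGenerator.get_files_and_class_names` uses. This is a sketch only: `list_videos` is a hypothetical helper, and the demo runs on a temporary directory with a made-up class folder name rather than on the real `val` directory.

```python
import pathlib
import tempfile


def list_videos(split_dir, pattern="*/*.avi"):
    """Return the video files found one level below split_dir, sorted by path."""
    return sorted(pathlib.Path(split_dir).glob(pattern))


# Self-contained demo: a temporary directory stands in for custom_actions_subset/val.
with tempfile.TemporaryDirectory() as tmp:
    class_dir = pathlib.Path(tmp) / "sequence"  # hypothetical class folder
    class_dir.mkdir()
    (class_dir / "clip.avi").touch()
    (class_dir / "notes.txt").touch()
    found = [p.name for p in list_videos(tmp)]

print(found)
```

In the notebook, an empty result from `list_videos(download_dir / 'val')` would mean the generator has nothing to yield for that split, for example because the video has a different extension or sits at a different folder depth.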

## Check number of frames per video

In [25]:
def calculate_video_stats(directory):
    total_length = 0
    num_videos = 0
    max_length = 0
    min_length = float('inf')
    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith('.avi'):
                path = os.path.join(root, file)
                src = cv2.VideoCapture(path)  
                video_length = src.get(cv2.CAP_PROP_FRAME_COUNT)
                total_length += video_length
                num_videos += 1
                if video_length > max_length:
                    max_length = video_length
                if video_length < min_length:
                    min_length = video_length
                src.release()

    if num_videos > 0:
        avg_length = total_length / num_videos
        print(f"Average video length: {avg_length:.2f} frames")
        print(f"Maximum video length: {max_length:.0f} frames")
        print(f"Minimum video length: {min_length:.0f} frames")
    else:
        print("No .avi files found in directory")

In [26]:
calculate_video_stats('/content/custom_actions_subset/train')

Average video length: 164.80 frames
Maximum video length: 178 frames
Minimum video length: 79 frames


In [27]:
calculate_video_stats('/content/custom_actions_subset/test')

Average video length: 167.75 frames
Maximum video length: 215 frames
Minimum video length: 126 frames


In [28]:
calculate_video_stats('/content/custom_actions_subset/val')

No .avi files found in directory
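
These frame counts determine how much of each sampled clip is real footage. The following is a rough sketch, assuming the sampling logic of `frames_from_video_file` and the tuning parameters set earlier (`num_frames = 30`, `frame_steps = 5`). `clip_coverage` is a hypothetical helper, and it models only a clip that starts at frame 0.

```python
def clip_coverage(video_length, n_frames=30, frame_step=5):
    """Count real vs. zero-padded frames for a clip that starts at frame 0."""
    need_length = 1 + (n_frames - 1) * frame_step  # source frames spanned by one clip
    wanted = [i * frame_step for i in range(n_frames)]
    real = sum(1 for idx in wanted if idx < video_length)
    return need_length, real, n_frames - real


# Shortest and longest training videos reported by calculate_video_stats.
shortest = clip_coverage(79)
longest = clip_coverage(178)
print(shortest)
print(longest)
```

Under these assumptions one clip spans 146 source frames, so the 79-frame video supplies 16 real frames followed by 14 all-zero frames. Lowering `frame_steps` or `num_frames` may be worth considering if many videos are that short.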


## Prepare train, valid and test dataset

In [29]:
CLASSES = sorted(os.listdir('/content/custom_actions_subset/train'))

output_signature = (tf.TensorSpec(shape = (None, None, None, 3), dtype = tf.float32),
                    tf.TensorSpec(shape = (), dtype = tf.int16))

train_ds = tf.data.Dataset.from_generator(FrameGenerator(subset_paths['train'], num_frames, training = True),
                                          output_signature = output_signature)
test_ds = tf.data.Dataset.from_generator(FrameGenerator(subset_paths['test'], num_frames),
                                         output_signature = output_signature)
val_ds = tf.data.Dataset.from_generator(FrameGenerator(subset_paths2['val'], 150),
                                         output_signature = output_signature)

# Convert to TF Lite quantized to int8 for edge devices (Raspberry Pi 3 A+)
Source: https://github.com/tensorflow/models/blob/8d92444cc4aed9a2d9746cbde882c05f4a0e748c/official/projects/movinet/tools/quantize_movinet.py

In [None]:
!wget https://raw.githubusercontent.com/tensorflow/models/8d92444cc4aed9a2d9746cbde882c05f4a0e748c/official/projects/movinet/tools/quantize_movinet.py

In [9]:
sh = """
python3 quantize_movinet.py \
--saved_model_dir='/model' \
--saved_model_with_states_dir='/model/init_states' \
--output_dataset_dir='/model_quantized' \
--output_tflite='/model_quantized' \
--quantization_mode='int8' \
--save_dataset_to_tfrecords=True
"""
with open('script.sh', 'w') as file:
  file.write(sh)

!bash script.sh
# This run fails with: ValueError: `tfds_name` is , but `tfds_split` is not specified.
# The error comes from quantize_movinet.py attempting to load the kinetics600 dataset, which is not on tfds



TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 

Traceback (most recent call last):
  File "/content/quantize_movinet.py", line 331, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/quantize_movinet.py", line 321, in main
    tflite_buffer = quantize_movinet(dataset_fn=get_dataset)
  File "/content/quantize_movinet.py", line 268, in quantize_movinet
    valid_dataset = dataset_fn()
  File "/content/quantize_movinet.py", line 168,

In [1]:
import functools
from typing import Any, Callable, Mapping, Optional
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub

from official.vision.configs import video_classification as video_classification_configs
from official.vision.tasks import video_classification





In [6]:
def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy()  # BytesList won't unpack string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _build_tf_example(feature):
  return tf.train.Example(
      features=tf.train.Features(feature=feature)).SerializeToString()

def save_to_tfrecord(input_frame: tf.Tensor,
                     input_states: Mapping[str, tf.Tensor],
                     frame_index: int,
                     predictions: tf.Tensor,
                     output_states: Mapping[str, tf.Tensor],
                     groundtruth_label_id: tf.Tensor,
                     output_dataset_dir: str,
                     file_index: int):
  """Save results to tfrecord."""
  features = {}
  features['frame_id'] = _int64_feature([frame_index])
  features['groundtruth_label'] = _int64_feature(
      groundtruth_label_id.numpy().flatten().tolist())
  features['predictions'] = _float_feature(
      predictions.numpy().flatten().tolist())
  image_string = tf.io.encode_png(
      tf.squeeze(tf.cast(input_frame * 255., tf.uint8), axis=[0, 1]))
  features['image'] = _bytes_feature(image_string.numpy())

  # Input/Output states at time T
  for k, v in output_states.items():
    dtype = v[0].dtype
    if dtype == tf.int32:
      features['input/' + k] = _int64_feature(
          input_states[k].numpy().flatten().tolist())
      features['output/' + k] = _int64_feature(
          output_states[k].numpy().flatten().tolist())
    elif dtype == tf.float32:
      features['input/' + k] = _float_feature(
          input_states[k].numpy().flatten().tolist())
      features['output/' + k] = _float_feature(
          output_states[k].numpy().flatten().tolist())
    else:
      raise ValueError(f'Unrecognized dtype: {dtype}')

  tfe = _build_tf_example(features)
  record_file = '{}/movinet_stream_{:06d}.tfrecords'.format(
      output_dataset_dir, file_index)
  logging.info('Saving to %s.', record_file)
  with tf.io.TFRecordWriter(record_file) as writer:
    writer.write(tfe)


def stateful_representative_dataset_generator(
    model: tf.keras.Model,
    dataset_iter: Any,
    init_states: Mapping[str, tf.Tensor],
    save_dataset_to_tfrecords: bool = False,
    max_saved_files: int = 100,
    output_dataset_dir: Optional[str] = None,
    num_samples_per_video: int = num_frames,
    num_calibration_videos: int = 10):
  """Generates sample input data with states.

  Args:
    model: the inference keras model.
    dataset_iter: the dataset source.
    init_states: the initial states for the model.
    save_dataset_to_tfrecords: whether to save the representative dataset to
      tfrecords on disk.
    max_saved_files: the max number of saved tfrecords files.
    output_dataset_dir: the directory to store the saved tfrecords.
    num_samples_per_video: number of randomly sampled frames per video.
    num_calibration_videos: number of calibration videos to run.

  Yields:
    A dictionary of model inputs.
  """
  counter = 0
  for i in range(num_calibration_videos):
    if i % 100 == 0:
      logging.info('Reading representative dataset id %d.', i)

    example_input, example_label = next(dataset_iter)
    # The label is already a scalar class id, so no argmax is needed
    # (argmax over a one-element tensor would always return 0).
    groundtruth_label_id = tf.cast(tf.reshape(example_label, [1]), tf.int64)
    input_states = init_states
    # split video into frames along the temporal dimension.
    frames = tf.split(example_input, num_or_size_splits=num_frames, axis=0)
    frames = [tf.reshape(frame, shape=(1, 1, resolution, resolution, 3)) for frame in frames]

    random_indices = np.random.randint(
        low=1, high=len(frames), size=num_samples_per_video)
    # always include the first frame
    random_indices[0] = 0
    random_indices = set(random_indices)

    for frame_index, frame in enumerate(frames):
      predictions, output_states = model({'image': frame, **input_states})
      if frame_index in random_indices:
        if save_dataset_to_tfrecords and counter < max_saved_files:
          save_to_tfrecord(
              input_frame=frame,
              input_states=input_states,
              frame_index=frame_index,
              predictions=predictions,
              output_states=output_states,
              groundtruth_label_id=groundtruth_label_id,
              output_dataset_dir=output_dataset_dir,
              file_index=counter)
        yield {'image': frame, **input_states}
        counter += 1

      # update states for the next inference step
      input_states = output_states


def get_tflite_converter(
    saved_model_dir: str,
    quantization_mode: str,
    representative_dataset: Optional[Callable[..., Any]] = None
) -> tf.lite.TFLiteConverter:
  """Gets tflite converter."""
  converter = tf.lite.TFLiteConverter.from_saved_model(
      saved_model_dir=saved_model_dir)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]

  if quantization_mode == 'float16':
    logging.info('Using float16 quantization.')
    converter.target_spec.supported_types = [tf.float16]

  elif quantization_mode == 'int8':
    logging.info('Using full integer quantization.')
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

  elif quantization_mode == 'int_float_fallback':
    logging.info('Using integer quantization with floating-point fallback.')
    converter.representative_dataset = representative_dataset

  else:
    logging.info('Using dynamic range quantization.')
  return converter
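
The representative dataset generator in this cell relies on a streaming pattern: at every step the model receives one frame plus the states returned by the previous step. The toy sketch below shows only that control flow. `toy_stream_model` is a hypothetical stand-in (a running mean), not MoViNet, and its state keys are made up.

```python
def toy_stream_model(frame, states):
    """Hypothetical streaming 'model': returns a running mean and updated states."""
    count = states["count"] + 1
    total = states["sum"] + frame
    return total / count, {"count": count, "sum": total}


init_states = {"count": 0, "sum": 0.0}  # plays the role of init_states_fn(...)
frames = [1.0, 2.0, 3.0, 6.0]

states = init_states
predictions = []
for frame in frames:
    prediction, output_states = toy_stream_model(frame, states)
    predictions.append(prediction)
    states = output_states  # update states for the next inference step

print(predictions)
```

The prediction after the last frame reflects every frame seen so far even though the model only ever processed one frame at a time, which is why the calibration samples must include the states and not just the images.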

In [7]:
encoder = hub.KerasLayer('/model/init_states', trainable=False)



In [30]:
inputs = tf.keras.layers.Input(
    shape=[1, resolution, resolution, 3],
    dtype=tf.float32,
    name='image')

# Define the state inputs, which is a dict that maps state names to tensors.
init_states_fn = encoder.resolved_object.signatures['init_states']
state_shapes = {
    name: ([s if s > 0 else None for s in state.shape], state.dtype)
    for name, state in init_states_fn(
        tf.constant([1, 1, resolution, resolution, 3])).items()
}
states_input = {
    name: tf.keras.Input(shape[1:], dtype=dtype, name=name)
    for name, (shape, dtype) in state_shapes.items()
}

# The inputs to the model are the states and the video
inputs = {**states_input, 'image': inputs}
outputs = encoder(inputs)
model = tf.keras.Model(inputs, outputs, name='movinet_stream')
input_shape = tf.constant(
    [1, num_frames, resolution, resolution, 3])
init_states = init_states_fn(input_shape)


# Configure the representative dataset function
representative_dataset = functools.partial(
    stateful_representative_dataset_generator,
    model=model,
    dataset_iter=iter(test_ds),
    init_states=init_states,
    save_dataset_to_tfrecords=False,
    max_saved_files=100,
    output_dataset_dir='/model_quantized',
    num_samples_per_video=3,
    num_calibration_videos=10)


In [31]:
# From https://www.tensorflow.org/lite/performance/post_training_quantization

import tensorflow as tf
saved_model_dir = '/model'
quantized_dir = '/model.tflite.quantized'

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8  # or tf.uint8
converter.inference_output_type = tf.int8  # or tf.uint8
tflite_quant_model = converter.convert()
with open(quantized_dir, 'wb') as f:
  f.write(tflite_quant_model)

# Problems: the converted model has no quantize op at the start and no dequantize op at the end, and there is no quantization after the buffer layers.



### Show details of quantized model

In [32]:
# Load the Quantized TensorFlow Lite model
interpreter = tf.lite.Interpreter(model_path=quantized_dir)
# Allocate the tensors
interpreter.allocate_tensors()
# Get the details of the model's inputs and outputs
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Get the tensor details, including the weights
tensor_details = interpreter.get_tensor_details()
print(tensor_details)

[{'name': 'serving_default_state_block4_layer1_pool_frame_count:0', 'index': 0, 'shape': array([1], dtype=int32), 'shape_signature': array([1], dtype=int32), 'dtype': <class 'numpy.int32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'serving_default_state_block1_layer1_stream_buffer:0', 'index': 1, 'shape': array([  1,   2,  22,  22, 120], dtype=int32), 'shape_signature': array([  1,   2,  22,  22, 120], dtype=int32), 'dtype': <class 'numpy.int8'>, 'quantization': (0.1392422914505005, -125), 'quantization_parameters': {'scales': array([0.13924229], dtype=float32), 'zero_points': array([-125], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'serving_default_state_block1_layer2_stream_buffer:0', 'index': 2, 'shape': array([ 1,  2, 22, 22, 96], dtype=int32), 'shape_signature': array([ 1,  2, 22, 22, 96], dtype=i
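
Each `quantization` entry printed above is a `(scale, zero_point)` pair. TensorFlow Lite's int8 scheme relates real values to stored integers as `real = scale * (q - zero_point)`. The sketch below applies that mapping to the parameters shown for `state_block1_layer1_stream_buffer`. The helper names are hypothetical, and the code is plain Python rather than a TensorFlow Lite API.

```python
SCALE = 0.1392422914505005  # from the tensor details printed above
ZERO_POINT = -125


def quantize(real_value, scale=SCALE, zero_point=ZERO_POINT):
    """Map a float to int8 using real = scale * (q - zero_point)."""
    q = round(real_value / scale) + zero_point
    return max(-128, min(127, q))  # clamp to the int8 range


def dequantize(q, scale=SCALE, zero_point=ZERO_POINT):
    """Map an int8 value back to an approximate float."""
    return scale * (q - zero_point)


q_zero = quantize(0.0)
q_ten = quantize(10.0)
print(q_zero, q_ten, dequantize(q_ten))
```

A model converted with `inference_input_type = tf.int8` expects its inputs to be quantized this way already, which is one thing to check before feeding it float32 frames or states.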

### Check inference with quantized model

In [34]:
# Create the interpreter and signature runner

runner = interpreter.get_signature_runner()

# To run on a video, pass in one frame at a time
#init_states_fn = model.init_states
init_states = init_states_fn(tf.shape(tf.ones(shape=[1, 1, resolution, resolution, 3])))
states = init_states
for frames, label in list(test_ds.take(1)):
  if label.numpy() == 0:
    label_str = "Action1"
  elif label.numpy() == 1:
    label_str = "Action2"
  elif label.numpy() == 2:
    label_str = "Action3"
  print("True label = ", label_str)

  for clip in frames:
    # Input shape: [1, 1, 172, 172, 3]
    outputs = runner(**states, image=clip)
    logits = outputs.pop('logits')[0]
    states = outputs

probs = tf.nn.softmax(logits)
top_k = get_top_k(probs)
print()
print("Model probabilities for each label:")
for label, prob in top_k:
  print(label, prob)

frames, label = list(test_ds.take(1))[0]
to_gif(frames.numpy())

True label =  Action3


ValueError: ignored
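
Note that the inference cell calls `get_top_k`, which is not defined anywhere in this notebook. A minimal stand-in might look like the sketch below. It is written in plain Python so it can be checked in isolation; the signature, the default `k`, and the example logits are assumptions, and `label_map` would be the `CLASSES` list in the notebook.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def get_top_k(probs, label_map, k=3):
    """Return up to k (label, probability) pairs, most probable first."""
    ranked = sorted(zip(label_map, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


labels = ["Action1", "Action2", "Action3"]  # illustrative label list
probs = softmax([0.5, 2.0, 1.0])            # made-up logits
top = get_top_k(probs, labels, k=2)
print([label for label, _ in top])
```

With this signature the notebook's call would need the label list passed explicitly, for example `get_top_k(probs.numpy().tolist(), CLASSES)`.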