In [None]:
# %matplotlib qt
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from IPython.display import clear_output
import tensorflow as tf
from tensorflow.keras import layers


class DFTModel(tf.keras.Model):
    """Keras model that renders the log-magnitude 2-D DFT of every image channel.

    The model expects a rank-4, channels-last batch ``(batch, height, width,
    channels)`` and returns a ``uint8`` tensor of the same shape in which the
    log-magnitude spectrum has been min-max scaled to ``[0, 255]`` over the
    whole batch.
    """

    def __init__(self):
        super(DFTModel, self).__init__()
        # fft2d transforms the two innermost axes of its input.
        self.channel_dft = layers.Lambda(lambda x: tf.signal.fft2d(x))
        # The small offset keeps log() finite where the magnitude is zero.
        self.log_scale = layers.Lambda(lambda x: tf.math.log(tf.abs(x) + 1e-9))

    def call(self, inputs):
        """Computes the displayable spectrum for a batch of images.

        Args:
            inputs: Rank-4 channels-last tensor or array; it is cast to
                ``complex64`` before the transform.
        Returns:
            A ``uint8`` tensor with the same shape as ``inputs``.
        """
        x = tf.cast(inputs, tf.complex64)
        # Put height/width innermost so fft2d runs once per (batch, channel).
        x = tf.transpose(x, [0, 3, 1, 2])
        x = self.channel_dft(x)

        # Magnitude of the complex spectrum from its real and imaginary parts.
        xreal = tf.math.real(x)
        ximag = tf.math.imag(x)
        x = tf.math.sqrt(xreal**2 + ximag**2)

        x = self.log_scale(x)
        x = tf.transpose(x, [0, 2, 3, 1])  # move channel axis back to the end

        x_min = tf.reduce_min(x)
        x_max = tf.reduce_max(x)
        span = x_max - x_min
        # A flat spectrum (e.g. an all-zero input) has zero span; dividing by it
        # would yield NaNs, so divide by 1 instead and emit an all-zero image.
        span = tf.where(span > 0, span, tf.ones_like(span))
        x = (x - x_min) / span
        x *= 255
        x = tf.cast(x, tf.uint8)
        return x
    
model = DFTModel()


cap = cv2.VideoCapture(0)
# Request a 640x480 capture resolution.
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
n = 0        # number of frames folded into the running statistics
mean = None  # running per-pixel mean of the spectrum (float64)
M2 = None    # running sum of squared deviations from the mean (float64)
std = None   # per-pixel sample standard deviation, available from 2 frames on
fig = plt.figure()
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        clear_output(wait=True)
        dft = model(frame[np.newaxis, ...]).numpy()[0]

        # Update mean and M2 using Welford's method. The statistics are kept in
        # float64: doing this arithmetic in uint8 wraps around on negative
        # differences and truncates the per-frame increments.
        sample = dft.astype(np.float64)
        n += 1
        if mean is None:
            mean = sample.copy()  # copy so the statistics never alias a frame
            M2 = np.zeros_like(mean)
        else:
            delta = sample - mean
            mean += delta / n
            M2 += delta * (sample - mean)

        # Sample standard deviation is undefined for fewer than two frames.
        std = np.sqrt(M2 / (n - 1)) if n >= 2 else None

        # Convert the mean back to uint8 only for display, so it can be stacked
        # next to the uint8 camera frame and spectrum.
        mean_img = np.clip(np.rint(mean), 0, 255).astype(np.uint8)
        final = np.concatenate([frame, dft, mean_img], axis=1)
        plt.imshow(cv2.cvtColor(final, cv2.COLOR_BGR2RGB))
        plt.show()
finally:
    # Free the camera even when the cell is interrupted or a frame fails.
    cap.release()

In [None]:
# Need to initialize the model if you didn't do inference above.
# model = DFTModel()
# x = np.zeros((1,480,640,3))
# y = model(x)

TFLITE_MODEL_PATH = "dft_model.tflite"

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
tflite_model = converter.convert()
# Write through a context manager so the file is flushed and closed before the
# interpreter opens the same path below.
with open(TFLITE_MODEL_PATH, "wb") as model_file:
    model_file.write(tflite_model)

interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL_PATH)

interpreter.allocate_tensors()

# help(interpreter)
interpreter.get_input_details()

In [None]:
def output_tensor(interpreter, i):
  """Fetches the tensor behind the model's i-th output.

  Args:
    interpreter: The ``tf.lite.Interpreter`` holding the model.
    i (int): Position of the wanted entry in the interpreter's output details.
  Returns:
    The result of calling the accessor that ``interpreter.tensor`` returns for
    that output's tensor index.
  """
  details = interpreter.get_output_details()[i]
  accessor = interpreter.tensor(details['index'])
  return accessor()


def input_details(interpreter, key):
  """Looks up one field in the description of the model's first input tensor.

  Args:
    interpreter: The ``tf.lite.Interpreter`` holding the model.
    key: Name of the field to read from the first entry of
      ``interpreter.get_input_details()`` (this file uses ``'shape'`` and
      ``'index'``).
  Returns:
    The value stored under ``key`` for the first input tensor.
  """
  first_input = interpreter.get_input_details()[0]
  return first_input[key]


def input_size(interpreter):
  """Reports the model's input size, width first.

  Args:
    interpreter: The ``tf.lite.Interpreter`` holding the model.
  Returns:
    A ``(width, height)`` tuple taken from the third and second entries of the
    four-element input shape.
  """
  _lead, rows, cols, _trail = input_details(interpreter, 'shape')
  return (cols, rows)


def input_tensor(interpreter):
  """Returns a view of the first element of the model's input tensor.

  Args:
    interpreter: The ``tf.lite.Interpreter`` holding the model.
  Returns:
    Element 0 of the input tensor obtained through ``interpreter.tensor``;
    per the original documentation, a :obj:`numpy.array` view of shape
    (height, width, 3).
  """
  index = input_details(interpreter, 'index')
  accessor = interpreter.tensor(index)
  return accessor()[0]


def set_input(interpreter, data):
  """Writes ``data`` into the model's input tensor.

  Args:
    interpreter: The ``tf.lite.Interpreter`` whose input is overwritten.
    data: The values to copy into the input tensor view.
  """
  target = input_tensor(interpreter)
  # Slice assignment copies into the interpreter's buffer rather than
  # rebinding the local name.
  target[:, :] = data

In [None]:
# set_input(interpreter, frame)

In [None]:
# interpreter.invoke()

In [None]:
# output_details = interpreter.get_output_details()[0]
# output_data = interpreter.tensor(output_details['index'])().copy()
# output_data.shape

In [None]:
cap = cv2.VideoCapture(0)
# Request a 640x480 capture resolution.
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

# These lookups do not change between frames, so do them once up front.
output_details = interpreter.get_output_details()[0]
input_width, input_height = (int(side) for side in input_size(interpreter))
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        clear_output(wait=True)
        # The camera may deliver a different resolution than requested; the
        # interpreter's input buffer has a fixed size, so resize to match it.
        if frame.shape[:2] != (input_height, input_width):
            frame = cv2.resize(frame, (input_width, input_height))
        # set_input copies straight into the interpreter's input buffer, so the
        # numpy frame can be passed as-is.
        set_input(interpreter, frame)
        interpreter.invoke()
        # Copy the result out so no view into interpreter memory is kept alive.
        output_data = interpreter.tensor(output_details['index'])()[0].copy()
        final = np.concatenate([frame, output_data], axis=1)
        plt.imshow(cv2.cvtColor(final, cv2.COLOR_BGR2RGB))
        plt.show()
finally:
    # Free the camera even when the cell is interrupted or a frame fails.
    cap.release()