In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import onnx
from onnx_tf.backend import prepare

import tensorflow as tf

from tinynn.converter import TFLiteConverter

In [None]:
# For saving to TFLite with proper Conv2D. 
!pip install git+https://github.com/alibaba/TinyNeuralNetwork.git

In [2]:

# Taken from https://github.com/vidursatija/BlazeFace-CoreML/blob/master/ML/blazeface.py
class ResModule(nn.Module):
	def __init__(self, in_channels, out_channels, stride=1):
		super(ResModule, self).__init__()
		self.stride = stride
		self.channel_pad = out_channels - in_channels
		# kernel size is always 3
		kernel_size = 3

		if stride == 2:
			self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
			padding = 0
		else:
			padding = (kernel_size - 1) // 2

		self.convs = nn.Sequential(
			nn.Conv2d(in_channels=in_channels, out_channels=in_channels, 
						kernel_size=kernel_size, stride=stride, padding=padding, 
						groups=in_channels, bias=True),
			nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
						kernel_size=1, stride=1, padding=0, bias=True),
		)

		self.act = nn.ReLU(inplace=True)

	def forward(self, x):
		if self.stride == 2:
			h = F.pad(x, (0, 2, 0, 2), "constant", 0)
			x = self.max_pool(x)
		else:
			h = x

		if self.channel_pad > 0:
			x = F.pad(x, (0, 0, 0, 0, 0, self.channel_pad), "constant", 0)

		return self.act(self.convs(h) + x)


class ResBlock(nn.Module):
	def __init__(self, in_channels):
		super(ResBlock, self).__init__()
		layers = [ResModule(in_channels, in_channels) for _ in range(7)]

		self.f = nn.Sequential(*layers)

	def forward(self, x):
		return self.f(x)


# From https://github.com/google/mediapipe/blob/master/mediapipe/models/palm_detection.tflite
class PalmDetector(nn.Module):
	def __init__(self):
		super(PalmDetector, self).__init__()

		self.backbone1 = nn.Sequential(
			nn.ConstantPad2d((0, 1, 0, 1), value=0.0),
			nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=0, bias=True),
			nn.ReLU(inplace=True),

			ResBlock(32),
			ResModule(32, 64, stride=2),
			ResBlock(64),
			ResModule(64, 128, stride=2),
			ResBlock(128)
		)

		self.backbone2 = nn.Sequential(
			ResModule(128, 256, stride=2),
			ResBlock(256)
		)

		self.backbone3 = nn.Sequential(
			ResModule(256, 256, stride=2),
			ResBlock(256)
		)

		self.upscale8to16 = nn.Sequential(
			nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2, padding=0, bias=True),
			nn.ReLU(inplace=True)
		)
		self.scaled16add = ResModule(256, 256)

		self.upscale16to32 = nn.Sequential(
			nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2, padding=0, bias=True),
			nn.ReLU(inplace=True),
		)
		self.scaled32add = ResModule(128, 128)

		self.class_32 = nn.Conv2d(in_channels=128, out_channels=2, kernel_size=1, stride=1, padding=0, bias=True)
		self.class_16 = nn.Conv2d(in_channels=256, out_channels=2, kernel_size=1, stride=1, padding=0, bias=True)
		self.class_8 = nn.Conv2d(in_channels=256, out_channels=6, kernel_size=1, stride=1, padding=0, bias=True)

		self.reg_32 = nn.Conv2d(in_channels=128, out_channels=36, kernel_size=1, stride=1, padding=0, bias=True)
		self.reg_16 = nn.Conv2d(in_channels=256, out_channels=36, kernel_size=1, stride=1, padding=0, bias=True)
		self.reg_8 = nn.Conv2d(in_channels=256, out_channels=108, kernel_size=1, stride=1, padding=0, bias=True)


	def forward(self, x):
		b1 = self.backbone1(x.permute(0, 3, 1, 2)) # 32x32 # 
		# print(b1.size())

		b2 = self.backbone2(b1) # 16x16
		# print(b2.size())

		b3 = self.backbone3(b2) # 8x8
		# print(b3.size())

		b2 = self.upscale8to16(b3) + b2 # 16x16
		b2 = self.scaled16add(b2) # 16x16
		# print(b2.size())

		b1 = self.upscale16to32(b2) + b1 # 32x32
		b1 = self.scaled32add(b1)
		# print(b1.size())

		c8 = self.class_8(b3).permute(0, 2, 3, 1).reshape(-1, 384, 2)
		c16 = self.class_16(b2).permute(0, 2, 3, 1).reshape(-1, 512, 2)
		c32 = self.class_32(b1).permute(0, 2, 3, 1).reshape(-1, 2048, 2)

		r8 = self.reg_8(b3).permute(0, 2, 3, 1).reshape(-1, 384, 18)
		r16 = self.reg_16(b2).permute(0, 2, 3, 1).reshape(-1, 512, 18)
		r32 = self.reg_32(b1).permute(0, 2, 3, 1).reshape(-1, 2048, 18)

		c = torch.cat([c32, c16, c8], dim=1)
		r = torch.cat([r32, r16, r8], dim=1) # needs to be anchored

		return r, c

	def modify_finetune_layers(self):
		for param in self.parameters():
			param.requires_grad = False
		self.class_32 = nn.Conv2d(in_channels=128, out_channels=4, kernel_size=1, stride=1, padding=0, bias=True)
		self.class_16 = nn.Conv2d(in_channels=256, out_channels=4, kernel_size=1, stride=1, padding=0, bias=True)
		self.class_8 = nn.Conv2d(in_channels=256, out_channels=12, kernel_size=1, stride=1, padding=0, bias=True)


	def load_weights(self, path):
	    self.load_state_dict(torch.load(path))
	    self.eval()
	
	def save_weights(self, path = "palmdetector_finetuned.pth"):
		torch.save(self.state_dict(), path)

	def save_to_tflite(self, path = 'palm_detection_finetuned.tflite'):
		# Uses the converter from https://github.com/alibaba/TinyNeuralNetwork. 
		# pip install git+https://github.com/alibaba/TinyNeuralNetwork.git
		# The reason to use this over converting to onnx and then to tensorflow pb and then to tflite is
		# that this converter takes care of converting pytorch's NCWH to NWHC. 
		# Pytorch uses NCWH for optimzied GPU training, while TfLite only supports NWHC for Conv2D operation. 
		# Converting NCWH will cause a transpose before convolution and another transpose after convolution. 
		# This causes the model to inference very very slow. 
		x = torch.rand(1, 3, 256, 256)
		converter = TFLiteConverter(m, x, path, input_transpose=True)
		converter.convert()

	def save_to_onnx(self, path = 'palm_detection_finetuned'):
		onnx_path = path + '.onnx'
		x = torch.autograd.Variable(torch.FloatTensor(torch.rand(1, 256, 256, 3)))
		torch.onnx.export(
			self, 
			x, 
			onnx_path, 
			export_params=True, 
			input_names=['input'], 
			output_names=['classifier', 'regressor'], 
			# opset_version=9, 
			do_constant_folding=False
		)
		print('Saved onnx at ' + onnx_path)

	def load_anchors(self, path):
	    self.anchors = torch.tensor(np.load(path), dtype=torch.float32)
	    assert(self.anchors.ndimension() == 2)
	    assert(self.anchors.shape[0] == 2944)
	    assert(self.anchors.shape[1] == 4)

	def _preprocess(self, x):
	    """Converts the image pixels to the range [-1, 1]."""
	    return x.float() / 127.5 - 1.0

	def _tensors_to_detections(self, raw_box_tensor, raw_score_tensor, anchors):
	    detection_boxes = self._decode_boxes(raw_box_tensor, anchors)
	    
	    thresh = 100
	    raw_score_tensor = raw_score_tensor.clamp(-thresh, thresh)
	    detection_scores = raw_score_tensor.sigmoid().squeeze(dim=-1)
	    
	    # Note: we stripped off the last dimension from the scores tensor
	    # because there is only has one class. Now we can simply use a mask
	    # to filter out the boxes with too low confidence.
	    mask = detection_scores >= 0.7

	    # Because each image from the batch can have a different number of
	    # detections, process them one at a time using a loop.
	    output_detections = []
	    for i in range(raw_box_tensor.shape[0]):
	        boxes = detection_boxes[i, mask[i]]
	        scores = detection_scores[i, mask[i]].unsqueeze(dim=-1)
	        output_detections.append(torch.cat((boxes, scores), dim=-1))

	    return output_detections

	def predict_on_image(self, img):
	    """Makes a prediction on a single image.
	    Arguments:
	        img: a NumPy array of shape (H, W, 3) or a PyTorch tensor of
	             shape (3, H, W). The image's height and width should be 
	             128 pixels.
	    Returns:
	        A tensor with face detections.
	    """
	    if isinstance(img, np.ndarray):
	        img = torch.from_numpy(img).permute((2, 0, 1))

	    return self.predict_on_batch(img.unsqueeze(0))

	def predict_on_batch(self, x):
	    """Makes a prediction on a batch of images.
	    Arguments:
	        x: a NumPy array of shape (b, H, W, 3) or a PyTorch tensor of
	           shape (b, 3, H, W). The height and width should be 128 pixels.
	    Returns:
	        A list containing a tensor of face detections for each image in 
	        the batch. If no faces are found for an image, returns a tensor
	        of shape (0, 17).
	    Each face detection is a PyTorch tensor consisting of 17 numbers:
	        - ymin, xmin, ymax, xmax
	        - x,y-coordinates for the 6 keypoints
	        - confidence score
	    """
	    if isinstance(x, np.ndarray):
	        x = torch.from_numpy(x).permute((0, 3, 1, 2))

	    assert x.shape[1] == 3
	    assert x.shape[2] == 256
	    assert x.shape[3] == 256

	    # 1. Preprocess the images into tensors:
	    # x = x.to(self._device())
	    x = self._preprocess(x)

	    # 2. Run the neural network:
	    with torch.no_grad():
	        out = self.__call__(x)

	    # 3. Postprocess the raw predictions:
	    detections = self._tensors_to_detections(out[1], out[0], self.anchors)

	    # 4. Non-maximum suppression to remove overlapping detections:
	    filtered_detections = []
	    for i in range(len(detections)):
	        faces = self._weighted_non_max_suppression(detections[i])
	        if len(faces) > 0:
		        faces = torch.stack(faces)
		        filtered_detections.append(faces)

	    return filtered_detections

	def _decode_boxes(self, raw_boxes, anchors):
	    """Converts the predictions into actual coordinates using
	    the anchor boxes. Processes the entire batch at once.
	    """
	    boxes = torch.zeros_like(raw_boxes)

	    x_center = raw_boxes[..., 0] / 256 * anchors[:, 2] + anchors[:, 0]
	    y_center = raw_boxes[..., 1] / 256 * anchors[:, 3] + anchors[:, 1]

	    w = raw_boxes[..., 2] / 256 * anchors[:, 2] * 2.6
	    h = raw_boxes[..., 3] / 256 * anchors[:, 3] * 2.6

	    y_center = y_center - h / 5.2

	    boxes[..., 0] = x_center - w / 2.  # ymin
	    boxes[..., 1] = y_center - h / 2.  # xmin
	    boxes[..., 2] = x_center + w / 2.  # ymax
	    boxes[..., 3] = y_center + h / 2.  # xmax

	    for k in range(7):
	        offset = 4 + k*2
	        keypoint_x = raw_boxes[..., offset    ] / 256 * anchors[:, 2] + anchors[:, 0]
	        keypoint_y = raw_boxes[..., offset + 1] / 256 * anchors[:, 3] + anchors[:, 1]
	        boxes[..., offset    ] = keypoint_x
	        boxes[..., offset + 1] = keypoint_y

	    return boxes

	def _weighted_non_max_suppression(self, detections):
	    """The alternative NMS method as mentioned in the BlazeFace paper:
	    "We replace the suppression algorithm with a blending strategy that
	    estimates the regression parameters of a bounding box as a weighted
	    mean between the overlapping predictions."
	    The original MediaPipe code assigns the score of the most confident
	    detection to the weighted detection, but we take the average score
	    of the overlapping detections.
	    The input detections should be a Tensor of shape (count, 17).
	    Returns a list of PyTorch tensors, one for each detected face.
	    
	    This is based on the source code from:
	    mediapipe/calculators/util/non_max_suppression_calculator.cc
	    mediapipe/calculators/util/non_max_suppression_calculator.proto
	    """
	    if len(detections) == 0: return []

	    output_detections = []

	    # Sort the detections from highest to lowest score.
	    remaining = torch.argsort(detections[:, 18], descending=True)

	    while len(remaining) > 0:
	        detection = detections[remaining[0]]

	        # Compute the overlap between the first box and the other 
	        # remaining boxes. (Note that the other_boxes also include
	        # the first_box.)
	        first_box = detection[:4]
	        other_boxes = detections[remaining, :4]
	        ious = overlap_similarity(first_box, other_boxes)

	        # If two detections don't overlap enough, they are considered
	        # to be from different faces.
	        mask = ious >= 0.3
	        overlapping = remaining[mask]
	        remaining = remaining[~mask]

	        # Take an average of the coordinates from the overlapping
	        # detections, weighted by their confidence scores.
	        weighted_detection = detection.clone()
	        if len(overlapping) > 1:
	            coordinates = detections[overlapping, :18]
	            scores = detections[overlapping, 18:19]
	            total_score = scores.sum()
	            weighted = (coordinates * scores).sum(dim=0) / total_score
	            weighted_detection[:18] = weighted
	            weighted_detection[18] = total_score / len(overlapping)

	        output_detections.append(weighted_detection)

	    return output_detections


In [4]:
m = PalmDetector()
m.modify_finetune_layers()
m.load_weights("./palmdetector_finetuned.pth")
m.load_anchors('./anchors.npy')
# m.eval()

In [5]:
m.eval()

PalmDetector(
  (backbone1): Sequential(
    (0): ConstantPad2d(padding=(0, 1, 0, 1), value=0.0)
    (1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
    (2): ReLU(inplace=True)
    (3): ResBlock(
      (f): Sequential(
        (0): ResModule(
          (convs): Sequential(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
            (1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
          )
          (act): ReLU(inplace=True)
        )
        (1): ResModule(
          (convs): Sequential(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
            (1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
          )
          (act): ReLU(inplace=True)
        )
        (2): ResModule(
          (convs): Sequential(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
            (1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
          )

In [6]:
m.save_weights()

In [7]:
m.save_to_tflite('../models/palm_detection_torch_finetuned.tflite')

RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[1, 256, 4, 257] to have 3 channels, but got 256 channels instead

In [8]:
m.save_to_onnx()

Saved onnx at palm_detection_finetuned.onnx


In [9]:
onnx_model = onnx.load('palm_detection_finetuned.onnx')
tf_rep = prepare(onnx_model)  # prepare tf representation
tf_rep.export_graph("palm_detection_finetuned")



INFO:tensorflow:Assets written to: palm_detection_finetuned\assets


INFO (tensorflow) Assets written to: palm_detection_finetuned\assets


In [10]:
converter = tf.lite.TFLiteConverter.from_saved_model('palm_detection_finetuned')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]

tflite_path = 'palm_detection_torch_finetuned_quantized_float16' + '.tflite'
tflite_model = converter.convert()
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)
print('Saved tflite at ' + tflite_path)

Saved tflite at palm_detection_torch_finetuned_quantized_float16.tflite


In [11]:
import onnx
from onnx2keras import onnx_to_keras

onnx_model = onnx.load('palm_detection_finetuned.onnx')
k_model = onnx_to_keras(onnx_model, input_names=['input'], change_ordering=True)

[0, 0, 0, 0, 0, 0, 1, 1]




[0, 0, 0, 0, 0, 0, 2, 2]
[0, 0, 0, 0, 0, 32, 0, 0]
Tensor("Placeholder:0", shape=(None, 64, 64, 64), dtype=float32) Tensor("Placeholder_1:0", shape=(None, 32, 64, 64), dtype=float32)


ValueError: Exception encountered when calling layer "220" (type Lambda).

Dimensions must be equal, but are 64 and 32 for '{{node 220/Add}} = AddV2[T=DT_FLOAT](Placeholder, Placeholder_1)' with input shapes: [?,64,64,64], [?,32,64,64].

Call arguments received:
  • inputs=['tf.Tensor(shape=(None, 64, 64, 64), dtype=float32)', 'tf.Tensor(shape=(None, 32, 64, 64), dtype=float32)']
  • mask=None
  • training=None