# Face mask detection using TensorFlow
<b>Welcome to the face mask detection inference walkthrough! This notebook walks you step by step through using a pre-trained model to detect, in real time, whether people are wearing a mask.</b>
### Make sure to follow the [installation instructions](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md) before you start.

# Imports

In [None]:
import numpy as np
import os, sys, cv2, tarfile, zipfile
from itertools import combinations 
import six.moves.urllib as urllib
import tensorflow as tf
from distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops

# The notebook requires TensorFlow >= 1.14.
# Compare only the (major, minor) integers: StrictVersion raises ValueError on
# TensorFlow's pre-release / nightly version strings (e.g. '2.1.0-rc0',
# '2.5.0-dev20201028'), which would abort this cell on a perfectly new install.
# The first two dot-separated fields of tf.__version__ are always plain ints.
_tf_major_minor = tuple(int(part) for part in tf.__version__.split('.')[:2])
if _tf_major_minor < (1, 14):
  raise ImportError('Please upgrade your TensorFlow installation to v1.14.*.')


Here are the imports from the object detection module.

In [None]:
from utils import label_map_util

from utils import visualization_utils as vis_util

## Model preparation 

Any model exported using the `export_inference_graph.py` tool can be loaded here simply by changing `PATH_TO_FROZEN_GRAPH` to point to a new `.pb` file (and `PATH_TO_LABELS` to its matching label map).  

[Train a model](https://github.com/rajatvisitme/faceMaskDetection/blob/master/README.md) on your own dataset.

In [None]:
# Directory holding the exported inference graph; swap this to load another model.
MODEL_NAME = 'inference_graph'
# Frozen graph (.pb) written by export_inference_graph.py inside MODEL_NAME.
PATH_TO_FROZEN_GRAPH = '{}/frozen_inference_graph.pb'.format(MODEL_NAME)
# Label map translating class ids predicted by the model into names.
PATH_TO_LABELS = 'training/labelmap.pbtxt'

## Load a (frozen) TensorFlow model into memory

In [None]:
# Deserialize the frozen GraphDef (.pb) and import it into a dedicated graph.
# The tf.compat.v1 / tf.io.gfile spellings exist in TF >= 1.14 and in TF 2.x,
# whereas tf.GraphDef / tf.gfile are TF1-only top-level names.
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.compat.v1.GraphDef()
  with tf.io.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
    serialized_graph = fid.read()
  # Parse and import after the file is closed; only the bytes are needed.
  od_graph_def.ParseFromString(serialized_graph)
  # name='' imports without a name-scope prefix, so tensors can later be
  # fetched by their original names such as 'image_tensor:0'.
  tf.compat.v1.import_graph_def(od_graph_def, name='')

## Loading label map
Label maps map indices to category names, so that when our convolutional network predicts `2`, we know that this corresponds to `without_mask`. Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine.

In [None]:
# Build the id -> category lookup used to label boxes, preferring each
# label map entry's display name.
category_index = label_map_util.create_category_index_from_labelmap(
    PATH_TO_LABELS,
    use_display_name=True,
)

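## (Optional) Save the output video
To record the annotated stream, create a `cv2.VideoWriter` (for example named `out`) in a cell here, call `out.write(image_np)` on each annotated frame inside the detection loop, and call `out.release()` when you are done.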

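## Start detection
The cell below reads frames from the default webcam, runs the model on each frame, and shows the annotated result. Press `q` in the preview window to stop.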

In [None]:
# Open the default webcam (device index 0). Skip this line if `cap` was
# already opened by an earlier cell.
cap = cv2.VideoCapture(0)

DISPLAY_SIZE = (800, 600)  # (width, height) passed to cv2.resize for the preview
WAIT_KEY_MS = 25           # how long cv2.waitKey waits for a key on each frame
QUIT_KEY = 'q'             # press this key in the preview window to stop

# Tensor handles are fixed properties of the graph, so look them up once
# instead of on every frame.
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
output_tensors = [
    detection_graph.get_tensor_by_name('detection_boxes:0'),
    detection_graph.get_tensor_by_name('detection_scores:0'),
    detection_graph.get_tensor_by_name('detection_classes:0'),
    detection_graph.get_tensor_by_name('num_detections:0'),
]

try:
    # tf.compat.v1.Session exists in TF >= 1.14 and in TF 2.x.
    with tf.compat.v1.Session(graph=detection_graph) as sess:
        while True:
            ret, image_np = cap.read()
            if not ret:
                # cap.read() reports ret=False (and gives no frame) when the
                # camera cannot be read or the stream has ended; feeding that
                # to the model would only fail later with an obscure error.
                print('No frame received from the video source; stopping.')
                break

            # NOTE(review): OpenCV delivers frames in BGR channel order. If the
            # model was trained on RGB images, convert with
            # cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) before inference --
            # confirm against the training pipeline.

            # Add a leading batch axis: (H, W, C) -> (1, H, W, C).
            image_np_expanded = np.expand_dims(image_np, axis=0)
            boxes, scores, classes, num_detections = sess.run(
                output_tensors,
                feed_dict={image_tensor: image_np_expanded})

            # Overlay boxes and labels on image_np. The return value is not
            # used, so this relies on the helper drawing into image_np itself.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=6)

            cv2.imshow('Mask Detection System', cv2.resize(image_np, DISPLAY_SIZE))

            if cv2.waitKey(WAIT_KEY_MS) & 0xFF == ord(QUIT_KEY):
                break
finally:
    # Always free the camera and close the preview window, even when the loop
    # is interrupted (e.g. a kernel interrupt) or raises an exception.
    cap.release()
    cv2.destroyAllWindows()