<a href="https://colab.research.google.com/github/sklationd/mobilenet-v2-facemask/blob/main/WebCamMask_ID.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Mask Evaluator with Own Webcam ~(MEOW 🐈)~**

1. Getting started is easy: choose `Runtime` → `Run all` from the menu bar, or press `Ctrl/Command + F9`.
2. Allow Colab to access your webcam when prompted.
3. We do not save your webcam image stream, so there is nothing to worry about.

<font color="red">**Our model is not robust to varied environments and mask colors. It performs better if you wear a `blue/white` mask and the background is `clear`.** </font>



# 0. Initialize

In [None]:
import numpy as np
import torchvision
from torchvision import transforms, datasets, models
import torch
from PIL import Image
from pathlib import Path
import os
import io
import cv2
import torch.nn.functional as F
import base64
import html
import time
from IPython.display import display, Javascript
from google.colab.output import eval_js

In [None]:
# Check whether GPU is enabled or not

# Pick the compute device. Without CUDA, ask the user to switch the Colab
# runtime to GPU and call exit().
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
if not has_cuda:
  print("GPU is not available. Please enable GPU in Runtime - Change runtime type GPU ")
  exit()
print("GPU is available")

In [None]:
# Download the TorchScript weights to ./scripted_best.pt, the file the
# model-loading cell passes to torch.jit.load.
'''
It may fail due to download limit of google drive.
If you encounter this issue, please let me know by below contact.
sklationd@gmail.com
''' 
# IPython shell escape (not Python). -nc skips the download when
# scripted_best.pt already exists, so re-running the notebook does not re-fetch it.
# NOTE(review): --no-check-certificate turns off TLS certificate verification
# for this download; consider dropping it.
!wget -nc --no-check-certificate 'https://docs.google.com/uc?export=download&id=1QHrOBzc4yZWCfC1P2AirWtjNvDKjAbe8' -O scripted_best.pt

In [None]:
# Helper function 

# Class names. getLabel maps score index i to labels[i], so this order must
# match the order of the model's outputs.
labels = ["No Mask", "Correct Mask", "Only Chin is Covered", "Nose is not covered", "Chin is not covered"]

def getLabel(scores):
  """Return the label whose score is highest.

  Args:
    scores: Indexable, iterable per-class scores (list, NumPy array or 1-D
      tensor) aligned with ``labels``.

  Returns:
    ``labels[i]`` for the first index ``i`` holding the maximum score.

  Raises:
    ValueError: If ``scores`` is empty.
  """
  # len() rather than truthiness: multi-element arrays/tensors have no
  # unambiguous truth value.
  if len(scores) == 0:
    raise ValueError("getLabel() needs at least one score")
  # float('-inf') instead of np.Inf, which NumPy 2.0 removed.
  maxScore = float('-inf')
  maxScoreIdx = 0
  for i, score in enumerate(scores):
    # Strict '>' keeps the first index when scores tie.
    if score > maxScore:
      maxScore = score
      maxScoreIdx = i
  return labels[maxScoreIdx]

def getColor(label):
  """Return the overlay colour for ``label``.

  Green ``[63, 141, 60]`` for a correctly worn mask (``labels[1]``), red
  ``[255, 0, 0]`` for any other label. The values are in RGB order, because
  the overlay is encoded as an RGBA image. A new list is returned on each call.
  """
  if label == labels[1]:
    return [63, 141, 60]
  return [255, 0, 0]

# 1. Define JavaScript and Webcam Functions

In [None]:

def start_input():
  """Inject the webcam JavaScript helpers into the notebook output.

  Displays a script that defines ``takePhoto(label, imgData)`` and its helpers
  in the page. Each ``takePhoto`` call (made from Python through ``eval_js``):
    * creates the webcam widget on first use;
    * updates the status label and the overlay image;
    * resolves with a dict whose ``'img'`` is the next 512x512 frame as a JPEG
      data URL, plus timing fields.
  Once the user clicks the video, overlay or instruction text, the next call
  stops the video track, removes the widget and resolves with ``''``.
  """
  js = Javascript('''
    var video;
    var div = null;
    var stream;
    var captureCanvas;
    var imgElement;
    var labelElement;
    
    var pendingResolve = null;
    var shutdown = false;
    
    function removeDom() {
       stream.getVideoTracks()[0].stop();
       video.remove();
       div.remove();
       video = null;
       div = null;
       stream = null;
       imgElement = null;
       captureCanvas = null;
       labelElement = null;
    }
    
    function onAnimationFrame() {
      if (!shutdown) {
        window.requestAnimationFrame(onAnimationFrame);
      }
      if (pendingResolve) {
        var result = "";
        if (!shutdown) {
          captureCanvas.getContext('2d').drawImage(video, 0, 0, 512, 512);
          result = captureCanvas.toDataURL('image/jpeg', 0.8)
        }
        var lp = pendingResolve;
        pendingResolve = null;
        lp(result);
      }
    }
    
    async function createDom() {
      if (div !== null) {
        return stream;
      }

      div = document.createElement('div');
      div.style.border = '2px solid black';
      div.style.padding = '3px';
      div.style.width = '100%';
      div.style.maxWidth = '600px';
      document.body.appendChild(div);
      
      const modelOut = document.createElement('div');
      modelOut.innerHTML = "<span>Status:</span>";
      labelElement = document.createElement('span');
      labelElement.innerText = 'No data';
      labelElement.style.fontWeight = 'bold';
      modelOut.appendChild(labelElement);
      div.appendChild(modelOut);
           
      video = document.createElement('video');
      video.style.display = 'block';
      video.width = div.clientWidth - 6;
      video.setAttribute('playsinline', '');
      video.onclick = () => { shutdown = true; };
      stream = await navigator.mediaDevices.getUserMedia(
          {video: { facingMode: "environment"}});
      div.appendChild(video);

      imgElement = document.createElement('img');
      imgElement.style.position = 'absolute';
      imgElement.style.zIndex = 1;
      imgElement.onclick = () => { shutdown = true; };
      div.appendChild(imgElement);
      
      const instruction = document.createElement('div');
      instruction.innerHTML = 
          '<span style="color: red; font-weight: bold;">' +
          'When finished, click here or on the video to stop this demo</span>';
      div.appendChild(instruction);
      instruction.onclick = () => { shutdown = true; };
      
      video.srcObject = stream;
      await video.play();

      captureCanvas = document.createElement('canvas');
      captureCanvas.width = 512; //video.videoWidth;
      captureCanvas.height = 512; //video.videoHeight;
      window.requestAnimationFrame(onAnimationFrame);
      
      return stream;
    }
    async function takePhoto(label, imgData) {
      if (shutdown) {
        removeDom();
        shutdown = false;
        return '';
      }

      var preCreate = Date.now();
      stream = await createDom();
      
      var preShow = Date.now();
      if (label != "") {
        labelElement.innerHTML = label;
      }
            
      if (imgData != "") {
        var videoRect = video.getClientRects()[0];
        imgElement.style.top = videoRect.top + "px";
        imgElement.style.left = videoRect.left + "px";
        imgElement.style.width = videoRect.width + "px";
        imgElement.style.height = videoRect.height + "px";
        imgElement.src = imgData;
      }
      
      var preCapture = Date.now();
      var result = await new Promise(function(resolve, reject) {
        pendingResolve = resolve;
      });
      if (result === "") {
        // Stop was clicked while this call waited for a frame, and
        // onAnimationFrame has already stopped rescheduling itself.
        // Release the camera and DOM now and report the shutdown.
        // Resetting `shutdown` and returning an empty image would hand
        // Python an undecodable frame and leave later calls waiting forever.
        removeDom();
        shutdown = false;
        return '';
      }
      
      return {'create': preShow - preCreate, 
              'show': preCapture - preShow, 
              'capture': Date.now() - preCapture,
              'img': result};
    }
    ''')

  display(js)
  
def take_photo(label, img_data):
  """Update the webcam widget and wait for the next captured frame.

  Args:
    label: HTML for the widget's status line. ``''`` leaves it unchanged.
    img_data: Data URL of the overlay image drawn over the video. ``''``
      leaves the current overlay in place.

  Returns:
    The value the JS ``takePhoto`` resolves with. This is either a dict whose
    ``'img'`` holds the frame as a JPEG data URL (plus timing fields), or
    ``''`` once the user has stopped the demo.
  """
  import json  # local import: the notebook's import cell does not import json
  # json.dumps produces correctly escaped JS string literals, so quotes,
  # backslashes or newlines in the arguments cannot break (or inject code
  # into) the generated call.
  data = eval_js('takePhoto({}, {})'.format(json.dumps(label), json.dumps(img_data)))
  return data

# 2. Model Inference and Drawing Result

In [None]:
# Preprocess each captured frame for the model: resize to 224x224, convert to
# a tensor, then normalise per channel.
# NOTE(review): these mean/std values are the commonly used ImageNet
# statistics; presumably the model was trained with the same ones — confirm.
preprocess = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the TorchScript model fetched by the download cell, move it to the
# selected device, and put it in evaluation mode.
model = torch.jit.load('./scripted_best.pt')
model.to(device)
model.eval()

# Clean output
# The bare `!` below is an empty IPython shell escape. Presumably it is there
# so that model.eval() is not the cell's last expression and the notebook does
# not print the model's repr.
!

In [None]:
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """Draw one bounding box, optionally with a text label, onto ``img`` in place.

    Args:
        x: Box corners ``[x1, y1, x2, y2]``. Each value is truncated with ``int``.
        img: Image array to draw on. OpenCV modifies it in place.
        color: Colour for the box outline and the label background. If it is
            falsy, a random 3-component colour is used.
        label: Optional text drawn on a filled background whose lower-left
            corner is the box's top-left corner.
        line_thickness: Outline thickness. If it is falsy, it is derived from
            the image's first two dimensions.
    """
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    if not color:
        # The notebook's import cell never imports `random`; without this
        # local import the fallback raised NameError.
        import random
        color = [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl)
    if label:
        # Measure and draw with the same font so the background fits the text.
        font = cv2.FONT_HERSHEY_DUPLEX
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, font, fontScale=tl / 3, thickness=tf)[0]
        # Background spans the text width and extends upward from the box corner.
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), font, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


def js_reply_to_image(js_reply):
    """Decode the frame returned by the browser into a PIL image.

    ``js_reply['img']`` is the JPEG data URL produced by the capture canvas.
    The base64 payload is the text after the first comma.
    """
    data_url = js_reply['img']
    encoded_jpeg = data_url.split(',')[1]
    jpeg_stream = io.BytesIO(base64.b64decode(encoded_jpeg))
    return Image.open(jpeg_stream)

def get_drawing_array(image_array):
    """Classify one webcam frame and render the result as an RGBA overlay.

    Args:
        image_array: The captured frame, passed straight to ``preprocess``.
            Despite the name, the caller passes the PIL image from
            ``js_reply_to_image``.

    Returns:
        A 512x512x4 ``uint8`` array with a box and the predicted label drawn
        on it. Alpha is 255 on every drawn pixel and 0 elsewhere.
    """
    # 512x512 matches the capture canvas in the webcam JavaScript, so the
    # overlay lines up with the video it is stretched over.
    drawing_array = np.zeros([512, 512, 4], dtype=np.uint8)
    img = preprocess(image_array).to(device).unsqueeze(0)
    # Pure inference: skip autograd bookkeeping. Move the few scores to the
    # CPU once instead of syncing with the device per element in getLabel.
    with torch.no_grad():
        pred = model(img).squeeze().cpu()

    label = getLabel(pred)
    plot_one_box([30, 30, 482, 482], drawing_array, label=label, color=getColor(label), line_thickness=3)

    # Make every pixel that was drawn on opaque and leave the rest transparent.
    drawing_array[:, :, 3] = (drawing_array.max(axis=2) > 0).astype(int) * 255

    return drawing_array

def drawing_array_to_bytes(drawing_array):
    """Encode an RGBA overlay array as a base64 PNG ``data:`` URL string."""
    png_buffer = io.BytesIO()
    Image.fromarray(drawing_array, 'RGBA').save(png_buffer, format='png')
    encoded = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
    return 'data:image/png;base64,{}'.format(encoded)


# 3. Start application

In [None]:
# Capture loop: send the latest overlay to the browser, receive the next
# frame, classify it and build the overlay for the following round trip.
start_input()
label_html = 'Capturing...'
img_data = ''
count = 0
while True:
    js_reply = take_photo(label_html, img_data)
    # '' means the user stopped the demo. Also stop on a reply without image
    # data instead of crashing while decoding an empty data URL.
    if not js_reply or not js_reply.get('img'):
        break
    image = js_reply_to_image(js_reply)
    drawing_array = get_drawing_array(image)
    drawing_bytes = drawing_array_to_bytes(drawing_array)
    img_data = drawing_bytes