
# Exercise 2: Live training

This exercise is mainly meant to let you play with deep learning more directly. You will use a web camera to collect images into a dataset, and then you can train your network and visualize the results iteratively.

The first code block below only sets up a small GUI for interacting with the notebook/website:

*   You can add new objects by typing the label name in the "new object" field and pressing enter
*   After you have added some objects, they can be selected in the "selected" field
*   You can add an image from your webcam with "Add image to..." button.

**TODO:**

Add some labels and add a few images to each label.

You may try to classify different objects you have lying around, or you can try to classify different hand gestures etc.
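Behind the GUI, each new object name is given the next free integer index, and that index is the class id the network is trained on. The sketch below mirrors the `label_map`/`name_map` bookkeeping of the `LiveDataset` class at the bottom of this notebook, without the webcam or TensorFlow; the class name `LabelRegistry` is hypothetical.

```python
class LabelRegistry:
    """Hypothetical stand-in for the label bookkeeping in LiveDataset."""

    def __init__(self):
        self.label_map = {}  # label name -> class index
        self.name_map = {}   # class index -> label name

    def add_label(self, label):
        # Adding an existing label again changes nothing
        if label in self.label_map:
            return self.label_map[label]
        index = len(self.label_map)
        self.label_map[label] = index
        self.name_map[index] = label
        return index


registry = LabelRegistry()
for name in ["cup", "phone", "cup", "thumbs_up"]:
    registry.add_label(name)
print(registry.label_map)  # {'cup': 0, 'phone': 1, 'thumbs_up': 2}
print(registry.name_map)   # {0: 'cup', 1: 'phone', 2: 'thumbs_up'}
```

The integer indices are what `SparseCategoricalCrossentropy` expects as targets, while `name_map` turns a predicted index back into a readable label.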

In [0]:
try:
  from webcam_in_notebook import WebCamera, LiveDataset, DatasetGUI
except ImportError:
  !pip install git+https://github.com/sigmunjr/webcam-in-notebook.git
  from webcam_in_notebook import WebCamera, LiveDataset, DatasetGUI
import tensorflow as tf

dataset = LiveDataset()
gui = DatasetGUI(dataset, WebCamera())
gui.display()

## Create a network
Create a network for classifying the images. It is wise to start from a pretrained network to get the best possible results from only a few images.

You should also consider setting `trainable` to *False* for most of the layers, so that the network can be trained with as few images as possible.
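Why freezing helps: with the pretrained layers fixed, only a small classification head is fitted, so a handful of images per label can be enough. The NumPy sketch below illustrates the idea under a clear simplification: a fixed random projection (hypothetical) stands in for the frozen pretrained layers, and only a softmax head is updated by plain gradient descent.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 3 classes, 10 "images" each, flattened to 64 values
n_classes, per_class, n_inputs, n_features = 3, 10, 64, 16
centers = rng.normal(size=(n_classes, n_inputs))
x = np.concatenate([c + 0.3 * rng.normal(size=(per_class, n_inputs)) for c in centers])
y = np.repeat(np.arange(n_classes), per_class)

# Frozen "backbone": a fixed projection that is never updated
w_frozen = rng.normal(size=(n_inputs, n_features)) / np.sqrt(n_inputs)
w_frozen_before = w_frozen.copy()


def features(batch):
    return np.maximum(batch @ w_frozen, 0.0)  # ReLU features from the frozen part


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def loss_fn(p, labels):
    return -np.mean(np.log(p[np.arange(len(labels)), labels] + 1e-12))


# Trainable head: a single dense layer with softmax output
w_head = np.zeros((n_features, n_classes))
b_head = np.zeros(n_classes)

f = features(x)  # computed once, since the backbone does not change
initial_loss = loss_fn(softmax(f @ w_head + b_head), y)
for step in range(300):
    p = softmax(f @ w_head + b_head)
    grad = p.copy()
    grad[np.arange(len(y)), y] -= 1.0  # d(loss)/d(logits) for softmax + cross-entropy
    grad /= len(y)
    w_head -= 0.1 * f.T @ grad  # only the head parameters are updated
    b_head -= 0.1 * grad.sum(axis=0)

final_loss = loss_fn(softmax(f @ w_head + b_head), y)
accuracy = np.mean(np.argmax(f @ w_head + b_head, axis=1) == y)
print(f"loss {initial_loss:.3f} -> {final_loss:.3f}, accuracy {accuracy:.2f}")
```

In Keras the same effect comes from setting `trainable = False` on the pretrained layers: the optimizer then only updates the weights of the new output layer.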

In [0]:
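from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2

# Pretrained MobileNetV2 without its 1000-class ImageNet head
base_model = MobileNetV2(weights='imagenet', include_top=False,
                         input_shape=(224, 224, 3), pooling='avg')
base_model.trainable = False  # freeze all pretrained layers

# New trainable head with one output per label (add all labels before running this cell)
n_classes = len(dataset.label_map)
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Dense(n_classes, activation='softmax'),
])
criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(0.001)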
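`SparseCategoricalCrossentropy(from_logits=False)` takes integer class indices (as returned by `get_batch`) and the probabilities from a softmax output. For each example the loss is the negative log of the probability given to the true class, averaged over the batch. A NumPy sketch of that computation (the function name is hypothetical, and the small `eps` clip is only there to avoid `log(0)`):

```python
import numpy as np


def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean of -log(probs[i, labels[i]]) over the batch."""
    labels = np.asarray(labels)
    probs = np.clip(np.asarray(probs, dtype=np.float64), eps, 1.0)
    picked = probs[np.arange(len(labels)), labels]  # probability of the true class
    return float(-np.mean(np.log(picked)))


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
labels = np.array([0, 1])
print(sparse_categorical_crossentropy(labels, probs))  # about 0.2899
```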

Now it is time to train your network. The dataset has a method *get_batch* that fetches a batch of images and corresponding labels (you can check out the class at the bottom of this notebook):

    images, labels = dataset.get_batch(16)

This gives you a batch of 16 images and labels as NumPy arrays. If you use tf.GradientTape to train your network, you can train as normal. If you use *fit* to train your model, you can either use [tf.data.Dataset.from_generator](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator) or simply:

- get a large batch from the dataset
- provide *x*, *y* and *batch_size* to model.fit
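A NumPy sketch of the second option, assuming a toy stand-in for `dataset.train_set` (the dictionary below and the helper names are hypothetical). `get_batch` draws a label uniformly and then one image of that label, and `minibatches` roughly mimics what `model.fit(x, y, batch_size=16)` does within one epoch (Keras shuffles the data by default):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for dataset.train_set: label index -> list of images.
# Each image is filled with its label value so the pairing is easy to check.
train_set = {0: [np.full((4, 4, 3), 0.0)] * 5,
             1: [np.full((4, 4, 3), 1.0)] * 2}


def get_batch(batch_size):
    # Like LiveDataset.get_batch: pick a label uniformly, then one of its images
    labels = rng.choice(list(train_set.keys()), size=batch_size)
    images = [train_set[l][rng.integers(len(train_set[l]))] for l in labels]
    return np.stack(images), labels


def minibatches(x, y, batch_size):
    # Shuffle once, then cut the large batch into slices of batch_size
    order = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]


x, y = get_batch(16 * 10)
batches = list(minibatches(x, y, 16))
print(x.shape, len(batches))  # (160, 4, 4, 3) 10
```

Because labels are drawn uniformly, a label with only 2 images is sampled as often as one with 5, so its few images are repeated many times; this is one reason extra augmentation helps.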

To get the best possible performance, you may want to add some extra data augmentation, e.g. with the help of [tf.keras.preprocessing.image.ImageDataGenerator](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator).
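A minimal NumPy sketch of two common augmentations, a random horizontal flip and a random translation, applied per image while keeping the shape unchanged. It is a hand-written illustration, not the `ImageDataGenerator` implementation (which offers similar options such as `horizontal_flip`, `width_shift_range` and `height_shift_range`); the function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def augment(image, max_shift=16, flip=True):
    """Random horizontal flip and random translation (edge-padded); shape is preserved."""
    out = image
    if flip and rng.random() < 0.5:
        out = out[:, ::-1]  # mirror left-right
    h, w = out.shape[:2]
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    # Pad with edge pixels, then crop a window shifted by (dy, dx)
    padded = np.pad(out, ((max_shift, max_shift), (max_shift, max_shift), (0, 0)), mode="edge")
    top, left = max_shift + dy, max_shift + dx
    return padded[top:top + h, left:left + w]


def augment_batch(images, **kwargs):
    return np.stack([augment(img, **kwargs) for img in images])


batch = rng.uniform(-1.0, 1.0, size=(8, 224, 224, 3)).astype(np.float32)
augmented = augment_batch(batch)
print(augmented.shape, augmented.dtype)  # (8, 224, 224, 3) float32
```

Pick augmentations that do not change the meaning of a label: for hand gestures where left and right matter, a horizontal flip could turn one class into another, so `flip=False` is the safer choice there.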

**TODO:**
Train your network.

In [0]:
# TODO: train your network
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
    run_eagerly=False,
    )
# Draw one large batch (160 images) and let fit split it into mini-batches of 16
images, labels = dataset.get_batch(16*10)
model.fit(images, labels, batch_size=16, epochs=10)


Finally, you can test your network live. I have made a class, *LiveJavascriptTextField*, for writing text on top of the web camera image.

**TODO:**
Run your network on images from the web camera to set *predicted_label* and *output_probability*.
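The core of this step is turning one softmax output vector into a label name and a probability. A NumPy sketch of just that part, assuming the output is a 1-D array of class probabilities and `name_map` maps class indices to label names as in `LiveDataset` (the function name `format_prediction` is hypothetical):

```python
import numpy as np


def format_prediction(probs, name_map):
    """Return the overlay text '<label>: <probability>' for one softmax output."""
    probs = np.asarray(probs)
    predicted_label = int(np.argmax(probs))            # index of the most likely class
    output_probability = float(probs[predicted_label])
    return "{}: {:.2f}".format(name_map[predicted_label], output_probability)


name_map = {0: "cup", 1: "phone", 2: "thumbs_up"}
print(format_prediction([0.1, 0.75, 0.15], name_map))  # phone: 0.75
```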

In [0]:
import numpy as np
from webcam_in_notebook import LiveJavascriptTextField

def run_inference(camera):
  js_textfield = LiveJavascriptTextField(
      'position: absolute;color: lightgreen; font-size: xx-large;'
      )

  for img in camera:
    # Preprocess the frame, add a batch dimension and run the trained model
    out = model(LiveDataset.convert_to_tf_image(img)[tf.newaxis], training=False)[0]
    predicted_label = int(np.argmax(out.numpy()))
    output_probability = float(out[predicted_label])

    # Overlay "<label name>: <probability>" on the webcam image
    js_textfield.updateText("{}: {:.2f}".format(
        dataset.name_map[predicted_label],
        output_probability
    ))

webcam = WebCamera()
run_inference(webcam)

## Run visualization

Finally, you can try to run the visualization techniques from *Exercise 1* on images from your web camera.

## Code for GUI and dataset
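Every frame goes through `LiveDataset.convert_to_tf_image` below: OpenCV delivers BGR images, so the channels are swapped to RGB, the image is resized to 224×224, and MobileNetV2's `preprocess_input` scales pixel values from [0, 255] to [-1, 1]. A NumPy sketch of the colour swap and scaling only (resizing is left out, and the helper name is hypothetical):

```python
import numpy as np


def bgr_to_model_input(frame_bgr):
    """Swap BGR -> RGB and scale [0, 255] to [-1, 1], as MobileNetV2 expects."""
    rgb = frame_bgr[..., ::-1].astype(np.float32)  # reverse the channel axis
    return rgb / 127.5 - 1.0


frame = np.zeros((2, 2, 3), dtype=np.uint8)
frame[..., 0] = 255  # pure blue in OpenCV's BGR order
x = bgr_to_model_input(frame)
print(x[0, 0])  # blue now sits in the last (RGB) channel: -1, -1, 1
```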

In [0]:
from webcam_in_notebook import webcam_in_notebook
import random
import ipywidgets as widgets
import time
import threading
from google.colab.patches import cv2_imshow
from IPython.display import display
from IPython.display import HTML
from IPython.display import clear_output
import tensorflow as tf
import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

import cv2
import numpy as np
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input


class LiveDataset:
    def __init__(self):
        self.train_set = {}
        self.label_map = {}
        self.name_map = {}

    def add_image(self, image, label):
        train_image = self.convert_to_tf_image(image)
        self.add_label(label)
        self.train_set[label] += [train_image]

    def add_label(self, label):
        if label in self.train_set:
            return
        n_labels = len(self.label_map)
        self.train_set[label] = []
        self.label_map[label] = n_labels
        self.name_map[n_labels] = label

    def get_batch(self, batch_size=16):
        labels = []
        images = []
        for i in range(batch_size):
            labels += [np.random.choice(list(self.name_map.keys()))]
            examples = self.train_set[self.name_map[labels[-1]]]
            images += random.sample(examples, 1)
        return np.stack(images), np.stack(labels)

    @staticmethod
    def convert_to_tf_image(image):
        train_image = cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), (224, 224))
        train_image = tf.cast(train_image, tf.float32)
        train_image = preprocess_input(train_image)
        return train_image
  
class DatasetGUI:
  def __init__(self, dataset, webcam):
    self.dataset = dataset
    self.webcam = webcam
    self.build_widgets()
    self.js_textfield = LiveJavascriptTextField()
  
  def build_widgets(self):
    self.select = widgets.Select(
        options=list(self.dataset.label_map.keys()),
        description='Selected:',
        disabled=False
    )
    self.button = widgets.Button(
        description='Add image',
        disabled = len(self.dataset.label_map) == 0,
        button_style='', # 'success', 'info', 'warning', 'danger' or ''
        tooltip='Add image',
        icon='check' # (FontAwesome names without the `fa-` prefix)
    )
    self.text_field = widgets.Text(
        placeholder='Object label',
        description='New object:',
        disabled=False
    )
    self.select.observe(self.onSelect, 'value')
    self.text_field.on_submit(self.addObject)
    self.button.on_click(self.addImage)

  def addImage(self, _):
    img = self.webcam.next()
    self.dataset.add_image(img, self.select.value)
    self.js_textfield.updateText('label : count')
    for key, value in self.dataset.train_set.items():
      self.js_textfield.addText(key + ':' + str(len(value)))

  def onSelect(self, selected):
    self.button.description = 'Add image to ' + selected['new']

  def addObject(self, label):
    if label.value in self.dataset.label_map:
        return
    self.dataset.add_label(label.value)
    label.value = ''
    self.select.options = list(self.dataset.label_map.keys())
    self.button.disabled = False
    self.select.disabled = False
  
  def display(self):
    display(self.text_field, self.select, self.button)

from IPython.display import display, Javascript
from google.colab.output import eval_js

class LiveJavascriptTextField:
  def __init__(self, style=''):
    self.initText(style)

  def initText(self, style):
    js = Javascript('''
      that = this;
      async function initText(style)
      {
        const video = document.querySelector("#output-area");
        const div = document.createElement('div');
        div.innerHTML = '';
        that.text_area = div;
        div.style = style;
        video.appendChild(
          div
          );
      }
      async function updateText(text)
      {
        that.text_area.innerHTML = text;
      }
      async function addText(text)
      {
        that.text_area.innerHTML += '<p>' + text + '</p>';
      }
    ''')
    display(js);
    eval_js('initText("{}")'.format(style));
  
  def addText(self, text):
    eval_js('addText("{}")'.format(text))
  
  def updateText(self, text):
    eval_js('updateText("{}")'.format(text))
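`addText` and `updateText` format the text straight into a JavaScript call such as `updateText("...")` before passing it to `eval_js`, so a label containing a double quote or a backslash would produce invalid JavaScript. A small sketch of one way to quote the text safely with `json.dumps`, whose output is a valid JavaScript string literal (the helper `js_call` is hypothetical):

```python
import json


def js_call(function_name, text):
    """Build a JavaScript call string with `text` quoted as a string literal."""
    # json.dumps adds the surrounding quotes and escapes quotes, backslashes and newlines
    return "{}({})".format(function_name, json.dumps(text))


print(js_call("updateText", 'my "favourite" mug: 0.93'))
# updateText("my \"favourite\" mug: 0.93")
```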