<a href="https://colab.research.google.com/github/wakewakame/easyaiortc/blob/main/examples/semantic_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# セマンティックセグメンテーションのデモ
セマンティックセグメンテーションの学習済みモデルを動かしてみるサンプルプログラムです。  
TensorFlow用、TensorFlow Lite用、PyTorch用の3種類の実装を用意しました。  
  
動作させる際には処理速度向上のため、 `ランタイム` > `ランタイムのタイプを変更` から `ハードウェア アクセラレータ` を `GPU` に変更することをお勧めします。  

## 1. easyaiortcのインストール

In [None]:
# Native FFmpeg (libavdevice/libavfilter), Opus and VPX development packages;
# presumably required to build easyaiortc's media/WebRTC dependencies — confirm against its README.
!apt install libavdevice-dev libavfilter-dev libopus-dev libvpx-dev pkg-config
# Install easyaiortc directly from its GitHub repository.
!pip install git+https://github.com/wakewakame/easyaiortc.git

## 2. セマンティックセグメンテーションを行う抽象クラスの作成

In [None]:
from abc import ABC, ABCMeta, abstractmethod

class SemanticSegmentation(metaclass=ABCMeta):
    """Common interface for the semantic segmentation backends in this notebook.

    Subclasses implement ``estimate``; ``colorful`` turns the class-index map it
    returns into a color visualization, optionally blended over the source frame.
    """

    @abstractmethod
    def estimate(self, input_image):
        """Return a 2-D uint8 class-index map for ``input_image``.

        The concrete backends in this notebook are called with an RGB uint8 image.
        """

    def colorful(self, estimated_image, original_image=None, alpha=0.5, num_classes=21):
        """Colorize a class-index map and optionally overlay it on an image.

        Args:
            estimated_image: 2-D uint8 class-index map as returned by ``estimate``.
            original_image: Optional BGR uint8 image. When given, the map is
                resized to its size with nearest-neighbor interpolation (so class
                labels are never interpolated into non-existent classes) and
                alpha-blended over it. When ``None`` the colorized map is returned
                at its own resolution.
            alpha: Weight of ``original_image`` in the blend; the color map gets
                ``1 - alpha``.
            num_classes: Size of the label space used to spread labels over the
                color map. Defaults to 21, the value this code always assumed.

        Returns:
            BGR uint8 image.

        Raises:
            ValueError: If ``num_classes`` is smaller than 1.
        """
        if num_classes < 1:
            raise ValueError("num_classes must be >= 1")
        # This cell does not import OpenCV itself; import it here so the method
        # does not depend on a later cell having been executed.
        import cv2

        if original_image is not None:
            estimated_image = cv2.resize(
                estimated_image,
                dsize=(original_image.shape[1], original_image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
        # Spread labels 0..num_classes-1 over the 0..255 input range of the colormap.
        output_image = estimated_image * (255 // num_classes)
        output_image = cv2.applyColorMap(output_image, cv2.COLORMAP_HSV)
        if original_image is not None:
            output_image = cv2.addWeighted(
                src1=original_image, alpha=alpha,
                src2=output_image, beta=1.0 - alpha,
                gamma=0.0,
            )
        return output_image

## 3. TensorFlow版の実装
参考元 : [https://github.com/tensorflow/models/tree/master/research/deeplab](https://github.com/tensorflow/models/tree/master/research/deeplab)  

In [None]:
import os
import tarfile
import urllib
import tensorflow as tf
import numpy as np
import cv2


class TensorFlowSegm(SemanticSegmentation):
    """DeepLab v3 segmentation backed by a frozen TensorFlow graph.

    The model archive is downloaded once and cached on disk; the frozen
    inference graph inside it is imported into a dedicated ``tf.Graph`` and
    executed through a ``tf.compat.v1.Session``.
    """

    # Graph tensor names fed / fetched in ``estimate``.
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    # The longer side of each input image is resized to this many pixels.
    INPUT_SIZE = 513
    # Archive members whose base name contains this string hold the graph.
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, download_path=None):
        """Download the model archive if needed and load its frozen graph.

        Args:
            download_path: Where to cache the ``.tar.gz`` archive. Defaults to
                the archive's file name in the current working directory.

        Raises:
            RuntimeError: If the archive contains no frozen inference graph.
        """
        # The cell-level `import urllib` does not by itself load the
        # `urllib.request` submodule; import it explicitly.
        import urllib.request

        url = 'http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz'
        # Alternative pretrained checkpoints:
        #url = 'http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz'
        #url = 'http://download.tensorflow.org/models/deeplabv3_pascal_train_aug_2018_01_04.tar.gz'
        #url = 'http://download.tensorflow.org/models/deeplabv3_pascal_trainval_2018_01_04.tar.gz'
        if download_path is None:
            download_path = os.path.join(os.getcwd(), url.split("/")[-1])
        if not os.path.isfile(download_path):
            # Close the HTTP response deterministically instead of leaking it.
            with urllib.request.urlopen(url) as response:
                data = response.read()
            with open(download_path, mode="wb") as f:
                f.write(data)
        self.graph = tf.Graph()
        graph_def = None
        with tarfile.open(download_path) as tar_file:
            for tar_info in tar_file.getmembers():
                if self.FROZEN_GRAPH_NAME not in os.path.basename(tar_info.name):
                    continue
                file_handle = tar_file.extractfile(tar_info)
                if file_handle is None:  # not a regular file (e.g. a directory)
                    continue
                with file_handle:
                    graph_def = tf.compat.v1.GraphDef.FromString(file_handle.read())
                break
        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        self.sess = tf.compat.v1.Session(graph=self.graph)

    def estimate(self, input_image):
        """Return the uint8 class-index map for an RGB uint8 ``input_image``.

        The image is scaled (aspect ratio preserved) so its longer side is
        ``INPUT_SIZE``; the returned map has that reduced resolution, and
        ``colorful`` resizes it back to the original frame.
        """
        height, width, _ = input_image.shape
        resize_ratio = self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        input_image = cv2.resize(input_image, dsize=target_size, interpolation=cv2.INTER_LINEAR)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [input_image]},
        )
        return batch_seg_map[0].astype(np.uint8)

## 4. TensorFlowLite版の実装
参考元1 : [https://www.tensorflow.org/lite/guide/python](https://www.tensorflow.org/lite/guide/python)  
参考元2 : [https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/default/1](https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/default/1)  

In [None]:
import os
import urllib
import tensorflow as tf
import numpy as np
import cv2


class TensorFlowLiteSegm(SemanticSegmentation):
    """DeepLab v3 segmentation running on the TensorFlow Lite interpreter.

    The ``.tflite`` model is downloaded once and cached on disk.
    """

    def __init__(self, download_path=None):
        """Download the model if needed and prepare the TFLite interpreter.

        Args:
            download_path: Where to cache the ``.tflite`` file. Defaults to the
                model's file name in the current working directory.
        """
        # The cell-level `import urllib` does not by itself load the
        # `urllib.request` submodule; import it explicitly.
        import urllib.request

        url = 'https://storage.googleapis.com/tfhub-lite-models/tensorflow/lite-model/deeplabv3/1/default/1.tflite'
        # Alternative model variant:
        #url = 'https://storage.googleapis.com/tfhub-lite-models/tensorflow/lite-model/deeplabv3/1/metadata/2.tflite'
        if download_path is None:
            download_path = os.path.join(os.getcwd(), url.split("/")[-1])
        if not os.path.isfile(download_path):
            # Close the HTTP response deterministically instead of leaking it.
            with urllib.request.urlopen(url) as response:
                data = response.read()
            with open(download_path, mode="wb") as f:
                f.write(data)
        self.interpreter = tf.lite.Interpreter(model_path=download_path, num_threads=None)
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # Input tensor shape is read as (batch, height, width, ...).
        self.height, self.width = self.input_details[0]['shape'].tolist()[1:3]
        self.output_index = self.output_details[0]['index']

    def estimate(self, input_image):
        """Return the uint8 class-index map for an RGB uint8 ``input_image``.

        The image is resized to the model's fixed input size, so the returned
        map has that resolution; ``colorful`` resizes it back.
        """
        input_image = cv2.resize(input_image, dsize=(self.width, self.height), interpolation=cv2.INTER_LINEAR)
        input_data = np.expand_dims(input_image, axis=0)
        if self.input_details[0]['dtype'] == np.float32:
            # Float-input models: scale pixel values from [0, 255] to [-1, 1].
            input_data = (np.float32(input_data) - 127.5) / 127.5
        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        self.interpreter.invoke()
        output_data = np.squeeze(self.interpreter.get_tensor(self.output_index))
        # Assumes the squeezed output is (height, width, num_classes) scores —
        # pick the best class per pixel.
        return output_data.argmax(2).astype(np.uint8)

## 5. PyTorch版の実装
参考元1 : [https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/](https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/)  
参考元2 : [https://pytorch.org/vision/stable/models.html](https://pytorch.org/vision/stable/models.html)  

In [None]:
import torch
import torchvision
import numpy as np
import cv2


class PyTorchSegm(SemanticSegmentation):
    """Semantic segmentation with a pretrained torchvision model.

    Uses CUDA when available, otherwise the CPU.
    """

    def __init__(self):
        """Load the pretrained model, pick the device and build preprocessing."""
        # Alternative torchvision segmentation models:
        #self.model = torchvision.models.segmentation.fcn_resnet50(pretrained=True)
        #self.model = torchvision.models.segmentation.fcn_resnet101(pretrained=True)
        #self.model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
        #self.model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
        self.model = torchvision.models.segmentation.deeplabv3_mobilenet_v3_large(pretrained=True)
        #self.model = torchvision.models.segmentation.lraspp_mobilenet_v3_large(pretrained=True)
        self.model.eval()
        # Resolve the device once; `estimate` runs per video frame and reuses it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        # HWC uint8 -> CHW float in [0, 1], then ImageNet mean/std normalization.
        self.preprocess = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def estimate(self, input_image):
        """Return the uint8 class-index map for an RGB uint8 ``input_image``.

        The map has the same height and width as ``input_image``.
        """
        input_batch = self.preprocess(input_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            output_tensor = self.model(input_batch)['out'][0]
        # Per-pixel argmax over the class-score channel axis.
        return output_tensor.argmax(0).byte().cpu().numpy()

## 6. 画像での実行

In [None]:
import easyaiortc
import urllib.request
import numpy as np
import cv2
from google.colab.patches import cv2_imshow

segm_type = "TensorFlow" #@param ["TensorFlow", "TensorFlowLite", "PyTorch"]
img_url = "https://github.com/wakewakame/openpose_ext/blob/main/media/human.jpg?raw=true" #@param {type:"string"}

# Select the backend; any other value falls back to PyTorch.
if segm_type == "TensorFlow":
    segm = TensorFlowSegm()
elif segm_type == "TensorFlowLite":
    segm = TensorFlowLiteSegm()
else:
    segm = PyTorchSegm()

# Download the test image (closing the response) and decode it to BGR uint8.
with urllib.request.urlopen(img_url) as response:
    jpeg = response.read()
img = cv2.imdecode(np.frombuffer(jpeg, dtype=np.uint8), cv2.IMREAD_COLOR)
if img is None:
    raise RuntimeError(f"Could not decode an image from {img_url}")

# The backends are fed RGB, while OpenCV decodes to BGR.
segm_img = segm.estimate(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
# Blend the colorized label map over the original BGR image for display.
segm_img = segm.colorful(segm_img, img)

cv2_imshow(segm_img)

## 7. AppRTCでの実行

In [None]:
import easyaiortc
# OpenCV is used below; import it here so this cell does not rely on earlier cells.
import cv2

segm_type = "TensorFlow" #@param ["TensorFlow", "TensorFlowLite", "PyTorch"]
# Select the backend; any other value falls back to PyTorch.
if segm_type == "TensorFlow":
    segm = TensorFlowSegm()
elif segm_type == "TensorFlowLite":
    segm = TensorFlowLiteSegm()
else:
    segm = PyTorchSegm()

# Start the connection.
rtc = easyaiortc.EasyAppRTC(preview=True, width=1280, height=720)

try:
    # Keep processing while the connection is alive.
    while rtc.is_alive():
        # Receive a frame; skip iterations where none is returned.
        img = rtc.get()
        if img is None:
            continue

        # Frames are treated as BGR (as in the image demo); the backends get RGB.
        segm_img = segm.estimate(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        segm_img = segm.colorful(segm_img, img)

        # Send the segmentation overlay back to the peer.
        rtc.put(segm_img)

# Interrupting the cell (Ctrl+C / Colab's stop button raises KeyboardInterrupt) ends the loop.
except KeyboardInterrupt:
    pass