# Real-time semantic segmentation

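#### 10 classes vs 20 classes

Each row pairs a label of the 10-class model (left) with the label(s) of the 20-class model that it groups together (right).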
<table>
    <tr>
        <td>1 Hat</td>
        <td>1 Hat</td>
    </tr>
    <tr>
        <td>2 Hair</td>
        <td>2 Hair</td>
    </tr>
    <tr>
        <td>3 Arms</td>
        <td>3 Gloves, 14 Left-arm, 15 Right-arm</td>
    </tr>
    <tr>
        <td>4 Sunglasses</td>
        <td>4 Sunglasses</td>
    </tr>
    <tr>
        <td>5 UpperClothes</td>
        <td>5 UpperClothes, 6 Dress, 7 Coat, 10 Jumpsuits, 11 Scarf</td>
    </tr>
    <tr>
        <td>6 LowerClothes</td>
        <td>8 Socks, 9 Pants, 16 Left-leg, 17 Right-leg</td>
    </tr>    
    <tr>
        <td>7 Skirt</td>
        <td>12 Skirt</td>
    </tr>
    <tr>
        <td>8 Face</td>
        <td>13 Face</td>
    </tr>
    <tr>
        <td>9 Shoes</td>
        <td>18 Left-shoe, 19 Right-shoe</td>
    </tr>    
</table>

In this project, the segmentation is visualized as a GIF file that is generated as soon as the whole video has been processed. Each frame is fed into the U-Net; the predicted label map is colorized with a palette, converted into an RGB image, and appended to a list. Once every frame has been processed, this list of per-frame segmentations is used to generate the GIF file.

In [1]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import transforms
from torchvision.transforms import ToPILImage
import torchvision.transforms.functional as tf
import matplotlib.pyplot as plt
from PIL import Image
from math import sqrt
import time
import math
import os
import copy
import cv2
import random
import imageio
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Number of segmentation classes: 10 or 20. Later cells branch on this value
# to pick the color palette and the pretrained checkpoint.
num_classes: int = 10

In [3]:
# Segmentation palettes in the flat [R0, G0, B0, R1, G1, B1, ...] layout that
# PIL's ``Image.putpalette`` expects: triplet i colors the pixels whose
# predicted class index is i. Triplet 0 (black) has no label in either list.
# Every triplet must be unique, otherwise two classes render identically.
colors_20 = [
    0,   0,   0,
    102, 0,   102,     # Hat 1
    0,   0,   153,     # Hair 2
    0,   128, 128,     # Gloves 3
    0,   255, 255,     # Sunglasses 4
    255, 51,  0,       # UpperClothes 5
    128, 128, 0,       # Dress 6 (must differ from Gloves)
    255, 153, 108,     # Coat 7
    64,  0,   0,       # Socks 8
    255, 153, 51,      # Pants 9
    204, 51,  0,       # Jumpsuits 10
    0,   153, 0,       # Scarf 11
    0,   255, 0,       # Skirt 12
    255, 255, 102,     # Face 13
    204, 236, 255,     # Left-arm 14
    255, 217, 179,     # Right-arm 15
    0,   102, 153,     # Left-leg 16
    102, 204, 255,     # Right-leg 17
    255, 0,   102,     # Left-shoe 18
    205, 102, 153,     # Right-shoe 19
]

colors_10 = [
    0,   0,   0,
    9,   13,  172,     # Hat 1
    85,  26,  139,     # Hair 2
    255, 193, 193,     # Arms 3
    0,   255, 255,     # Sunglasses 4
    178, 34,  34,      # UpperClothes 5
    237, 104, 37,      # LowerClothes 6
    32,  178, 170,     # Skirt 7
    247, 209, 60,      # Face 8
    139, 10,  80,      # Shoes 9
]

# The 10-class palette is used only for the 10-class model; any other value
# falls back to the 20-class palette.
colors = colors_10 if num_classes == 10 else colors_20

In [4]:
def pad_to_square(img):
    """Pad *img* into a square whose side is ``max(width, height)``.

    The original content stays centred. When the padding needed on an axis
    is odd, the extra pixel goes on the left/top edge. Padding uses the
    default fill of ``tf.pad``.

    Args:
        img: image whose ``.size`` is ``(width, height)``. This is presumably
            a PIL image (TODO confirm with callers).

    Returns:
        The padded image returned by ``tf.pad``.
    """
    width, height = img.size
    side = max(width, height)
    # Ceil-halving puts any odd leftover pixel on the left/top side.
    left = (side - width + 1) // 2
    top = (side - height + 1) // 2
    right = side - width - left
    bottom = side - height - top
    return tf.pad(img, (left, top, right, bottom))


def my_transfrom(img, size=448):
    """Convert an image into a normalized single-image batch for the U-Net.

    The steps are: pad to a square, resize to ``size`` x ``size``, convert to
    a float tensor, normalize each channel with the standard ImageNet
    mean/std, and add a leading batch dimension.

    Args:
        img: image accepted by ``pad_to_square``, presumably a PIL image.
        size: side length, in pixels, of the square model input.

    Returns:
        Tensor of shape ``(1, C, size, size)``.
    """
    square = tf.resize(pad_to_square(img), (size, size))
    tensor = tf.to_tensor(square).float()
    tensor = tf.normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return tensor.unsqueeze(0)

## UNet

In [5]:
class conv_block(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Spatial size is preserved (stride 1, padding 1). Channels go
    ``ch_in -> ch_out -> ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = []
        for c_in in (ch_in, ch_out):
            layers += [
                nn.Conv2d(c_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        # Keep the whole stack in one Sequential named ``conv``: saved
        # state_dict keys such as ``conv.0.weight`` depend on this name and
        # on the layer order.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

    
class up_conv(nn.Module):
    """Decoder upsampling stage: 2x bilinear upsample, then 1x1 conv + BN + ReLU.

    Doubles the spatial size and maps ``ch_in`` channels to ``ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # nn.UpsamplingBilinear2d is deprecated. nn.Upsample with
        # mode='bilinear' and align_corners=True is its documented equivalent.
        # The upsampler has no parameters, so the state_dict keys
        # (up.1.*, up.2.*) stay the same and existing checkpoints still load.
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.up(x)

    
class UNet(nn.Module):
    """U-Net producing one score map per output channel for every pixel.

    Encoder: five ``conv_block`` stages (3 -> 64 -> ... -> 1024 channels)
    separated by 2x2 max-pooling.

    Decoder: four ``up_conv`` + ``conv_block`` stages. Each stage
    concatenates the same-resolution encoder output (skip connection) in
    front of the upsampled features before its conv block.

    Head: a final 1x1 convolution maps 64 channels to ``dim``.

    Args:
        dim: number of output channels. Defaults to the module-level
            ``num_classes`` as evaluated when this class statement runs.

    Note:
        Input height and width must be divisible by 16 (four poolings),
        otherwise the skip connections do not line up in ``torch.cat``.
        Attribute names and their registration order form the state_dict
        keys that the saved checkpoints are loaded against; keep them.
    """

    def __init__(self, dim=num_classes):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder stages.
        self.Conv1 = conv_block(ch_in=3, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # Decoder stages. Each Up_convN receives the upsampled features
        # concatenated with an encoder output, hence ch_in == 2 * ch_out.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, dim, kernel_size=1, stride=1, padding=0)

        # He-normal initialisation for convolutions: std = sqrt(2 / fan_in).
        # BatchNorm starts as the identity (scale 1, shift 0).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_in = module.in_channels * module.kernel_size[0] * module.kernel_size[1]
                nn.init.normal_(module.weight, mean=0, std=sqrt(2. / fan_in))
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        # Encoder: remember the input of every pooling step for the skips.
        skips = []
        feat = self.Conv1(x)
        for encode in (self.Conv2, self.Conv3, self.Conv4, self.Conv5):
            skips.append(feat)
            feat = encode(self.Maxpool(feat))

        # Decoder: consume the skips deepest-first. The encoder features go
        # first in the concatenation, preserving the original channel order.
        for upsample, decode in (
            (self.Up5, self.Up_conv5),
            (self.Up4, self.Up_conv4),
            (self.Up3, self.Up_conv3),
            (self.Up2, self.Up_conv2),
        ):
            feat = decode(torch.cat((skips.pop(), upsample(feat)), dim=1))

        return self.Conv_1x1(feat)

In [6]:
# Build the U-Net for ``num_classes`` and load its pretrained weights.
# The model runs on the GPU when one is available and on the CPU otherwise.
# ``map_location`` lets a checkpoint saved from a GPU load on either device.
_CHECKPOINTS = {
    10: 'UNm_10000III_20e_FewerClasses.pth',
    20: 'UNm_10000_30e_MoreAug.pth',
}
if num_classes not in _CHECKPOINTS:
    raise ValueError(
        f'no pretrained checkpoint for num_classes={num_classes}; '
        f'expected one of {sorted(_CHECKPOINTS)}'
    )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
unet = UNet(dim=num_classes).to(device)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
unet.load_state_dict(torch.load(_CHECKPOINTS[num_classes], map_location=device))
unet.eval();  # trailing ';' stops the notebook from printing the model repr

In [7]:
# Segment every frame of the video and save the colorized label maps as a GIF.
video_path = 'exercise.mp4'
video_capture = cv2.VideoCapture(video_path)
if not video_capture.isOpened():
    raise OSError(f'could not open video {video_path!r}')

frames = []
# Send inputs to whichever device holds the model's weights (GPU or CPU).
model_device = next(unet.parameters()).device

try:
    # Inference only: no_grad avoids building an autograd graph for each frame.
    with torch.no_grad():
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break
            # OpenCV returns frames in BGR channel order; reverse it to RGB.
            rgb_frame = frame[:, :, ::-1]
            batch = my_transfrom(Image.fromarray(rgb_frame)).to(model_device)
            logits = unet(batch)
            # Take the per-pixel argmax over the class channel. Class ids
            # (< num_classes <= 20) fit in uint8, which ToPILImage maps to an
            # 'L'-mode image.
            labels = logits.argmax(dim=1).to(torch.uint8).cpu()
            label_img = ToPILImage()(labels)
            label_img.putpalette(colors)
            frames.append(label_img.convert('RGB'))
finally:
    # Release the capture even if inference fails part-way through.
    video_capture.release()

if not frames:
    raise RuntimeError(f'no frames could be read from {video_path!r}')

# NOTE(review): whether ``duration`` is read as seconds or milliseconds depends
# on the installed imageio version / GIF plugin. Confirm the intended 0.05 s per frame.
imageio.mimsave('exercise-' + str(num_classes) + 'cl.gif', frames, 'GIF', duration=0.05)