<a href="https://colab.research.google.com/github/xenagarage/automatic-signal-classification-with-deep-learning/blob/main/segnet_enet_jaeoh2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%matplotlib inline
%tensorflow_version 1.x
import torch
import numpy as np
import time
import os
import argparse

import sys
sys.path.append("../")

# from segnet import SegNet
# from loss import DiscriminativeLoss
# from dataset import tuSimpleDataset
# from logger import Logger

TensorFlow 1.x selected.


# Segnet

In [2]:
# SEGNET

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class ConvBnRelu(nn.Module):
    def __init__(self, input_ch, output_ch, kernel_size=3, padding=1):
        super(ConvBnRelu, self).__init__()
        self.conv =  nn.Sequential(
            nn.Conv2d(input_ch, output_ch, kernel_size=kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(output_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class SegNet(nn.Module):
    # refer from : https://github.com/delta-onera/segnet_pytorch/blob/master/segnet.py
    def __init__(self, input_ch, output_ch):
        super(SegNet, self).__init__()

        self.vgg16 = models.vgg16(pretrained=True)

        # Shared Encoder
        self.enc11 = ConvBnRelu(input_ch, 64)
        self.enc12 = ConvBnRelu(64, 64)

        self.enc21 = ConvBnRelu(64, 128)
        self.enc22 = ConvBnRelu(128, 128)

        self.enc31 = ConvBnRelu(128, 256)
        self.enc32 = ConvBnRelu(256, 256)
        self.enc33 = ConvBnRelu(256, 256)

        self.enc41 = ConvBnRelu(256, 512)
        self.enc42 = ConvBnRelu(512, 512)
        self.enc43 = ConvBnRelu(512, 512)

        self.enc51 = ConvBnRelu(512, 512)
        self.enc52 = ConvBnRelu(512, 512)
        self.enc53 = ConvBnRelu(512, 512)

        self.init_vgg_weigts()

        # Binary Segmentation Decoder
        self.sem_dec53 = ConvBnRelu(512, 512)
        self.sem_dec52 = ConvBnRelu(512, 512)
        self.sem_dec51 = ConvBnRelu(512, 512)

        self.sem_dec43 = ConvBnRelu(512, 512)
        self.sem_dec42 = ConvBnRelu(512, 512)
        self.sem_dec41 = ConvBnRelu(512, 256)

        self.sem_dec33 = ConvBnRelu(256, 256)
        self.sem_dec32 = ConvBnRelu(256, 256)
        self.sem_dec31 = ConvBnRelu(256, 128)

        self.sem_dec22 = ConvBnRelu(128, 128)
        self.sem_dec21 = ConvBnRelu(128, 64)

        self.sem_dec12 = ConvBnRelu(64, 64)

        # Instance Segmentation Decoder
        self.ins_dec53 = ConvBnRelu(512, 512)
        self.ins_dec52 = ConvBnRelu(512, 512)
        self.ins_dec51 = ConvBnRelu(512, 512)

        self.ins_dec43 = ConvBnRelu(512, 512)
        self.ins_dec42 = ConvBnRelu(512, 512)
        self.ins_dec41 = ConvBnRelu(512, 256)

        self.ins_dec33 = ConvBnRelu(256, 256)
        self.ins_dec32 = ConvBnRelu(256, 256)
        self.ins_dec31 = ConvBnRelu(256, 128)

        self.ins_dec22 = ConvBnRelu(128, 128)
        self.ins_dec21 = ConvBnRelu(128, 64)

        self.ins_dec12 = ConvBnRelu(64, 64)

        self.sem_out = nn.Conv2d(64, output_ch, kernel_size=3, stride=1, padding=1)
        self.ins_out = nn.Conv2d(64, 5, kernel_size=3, stride=1, padding=1)
                
    def forward(self, x):
        # Shared Encoder
        x = self.enc11(x)
        x = self.enc12(x)
        x, ind_1 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

        x = self.enc21(x)
        x = self.enc22(x)
        x, ind_2 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

        x = self.enc31(x)
        x = self.enc32(x)
        x = self.enc33(x)
        x, ind_3 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

        x = self.enc41(x)
        x = self.enc42(x)
        x = self.enc43(x)
        x, ind_4 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

        x = self.enc51(x)
        x = self.enc52(x)
        x = self.enc53(x)
        x, ind_5 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)

        # Binary Segmentation Decoder
        x1 = F.max_unpool2d(x, ind_5, kernel_size=2, stride=2)
        x1 = self.sem_dec53(x1)
        x1 = self.sem_dec52(x1)
        x1 = self.sem_dec51(x1)

        x1 = F.max_unpool2d(x1, ind_4, kernel_size=2, stride=2)
        x1 = self.sem_dec43(x1)
        x1 = self.sem_dec42(x1)
        x1 = self.sem_dec41(x1)

        x1 = F.max_unpool2d(x1, ind_3, kernel_size=2, stride=2)
        x1 = self.sem_dec33(x1)
        x1 = self.sem_dec32(x1)
        x1 = self.sem_dec31(x1)

        x1 = F.max_unpool2d(x1, ind_2, kernel_size=2, stride=2)
        x1 = self.sem_dec22(x1)
        x1 = self.sem_dec21(x1)

        x1 = F.max_unpool2d(x1, ind_1, kernel_size=2, stride=2)
        x1 = self.sem_dec12(x1)

        # Instance Segmentation Decoder
        x2 = F.max_unpool2d(x, ind_5, kernel_size=2, stride=2)
        x2 = self.ins_dec53(x2)
        x2 = self.ins_dec52(x2)
        x2 = self.ins_dec51(x2)

        x2 = F.max_unpool2d(x2, ind_4, kernel_size=2, stride=2)
        x2 = self.ins_dec43(x2)
        x2 = self.ins_dec42(x2)
        x2 = self.ins_dec41(x2)

        x2 = F.max_unpool2d(x2, ind_3, kernel_size=2, stride=2)
        x2 = self.ins_dec33(x2)
        x2 = self.ins_dec32(x2)
        x2 = self.ins_dec31(x2)

        x2 = F.max_unpool2d(x2, ind_2, kernel_size=2, stride=2)
        x2 = self.ins_dec22(x2)
        x2 = self.ins_dec21(x2)

        x2 = F.max_unpool2d(x2, ind_1, kernel_size=2, stride=2)
        x2 = self.ins_dec12(x2)

        sem = self.sem_out(x1)
        ins = self.ins_out(x2)

        return sem, ins

    def init_vgg_weigts(self):
        self.enc11.conv[0].weight.data = self.vgg16.features[0].weight.data

        self.enc12.conv[0].weight.data = self.vgg16.features[2].weight.data

        self.enc21.conv[0].weight.data = self.vgg16.features[5].weight.data

        self.enc22.conv[0].weight.data = self.vgg16.features[7].weight.data

        self.enc31.conv[0].weight.data = self.vgg16.features[10].weight.data

        self.enc32.conv[0].weight.data = self.vgg16.features[12].weight.data

        self.enc33.conv[0].weight.data = self.vgg16.features[14].weight.data

        self.enc41.conv[0].weight.data = self.vgg16.features[17].weight.data

        self.enc42.conv[0].weight.data = self.vgg16.features[19].weight.data

        self.enc43.conv[0].weight.data = self.vgg16.features[21].weight.data

        self.enc51.conv[0].weight.data = self.vgg16.features[21].weight.data

        self.enc52.conv[0].weight.data = self.vgg16.features[24].weight.data

        self.enc53.conv[0].weight.data = self.vgg16.features[26].weight.data

# Enet

In [3]:
# ENET

import torch
import torch.nn as nn
import torch.nn.functional as F


class InitialBlock(nn.Module):
    def __init__(self,
                 input_ch,
                 output_ch,
                 bias=False):
        super(InitialBlock, self).__init__()

        self.main_branch = nn.Sequential(
            nn.Conv2d(in_channels=input_ch,
                      out_channels=output_ch - 3,
                      kernel_size=3,
                      stride=2,
                      padding=1,
                      bias=bias),
            nn.BatchNorm2d(output_ch - 3),
            nn.PReLU()
        )
        self.ext_branch = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        main = self.main_branch(x)
        ext = self.ext_branch(x)

        out = torch.cat((main, ext), dim=1) # N, C, H, W

        return out


class RegularBottleNeck(nn.Module):
    def __init__(self,
                 input_ch,
                 output_ch,
                 projection_ratio=4,
                 regularizer_prob=0,
                 dilation=0,
                 assymmetric=False,
                 bias=False):
        super(RegularBottleNeck, self).__init__()

        reduced_depth = input_ch // projection_ratio

        self.ext_branch_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_ch,
                      out_channels=reduced_depth,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=bias),
            nn.BatchNorm2d(reduced_depth),
            nn.PReLU()
        )

        if dilation:
            self.ext_branch_2 = nn.Sequential(
                nn.Conv2d(in_channels=reduced_depth,
                          out_channels=reduced_depth,
                          kernel_size=3,
                          stride=1,
                          padding=dilation,
                          bias=bias,
                          dilation=dilation),
                nn.BatchNorm2d(reduced_depth),
                nn.PReLU()
            )

        elif assymmetric:
            self.ext_branch_2 = nn.Sequential(
                nn.Conv2d(in_channels=reduced_depth,
                          out_channels=reduced_depth,
                          kernel_size=(5, 1),
                          stride=1,
                          padding=(2, 0),
                          bias=bias),
                nn.Conv2d(in_channels=reduced_depth,
                          out_channels=reduced_depth,
                          kernel_size=(1, 5),
                          stride=1,
                          padding=(0, 2),
                          bias=bias),
                nn.BatchNorm2d(reduced_depth),
                nn.PReLU()
            )

        else: # Regular Bottle Neck
            self.ext_branch_2 = nn.Sequential(
                nn.Conv2d(in_channels=reduced_depth,
                          out_channels=reduced_depth,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=bias),
                nn.BatchNorm2d(reduced_depth),
                nn.PReLU()
            )

        self.ext_branch_3 = nn.Sequential(
            nn.Conv2d(in_channels=reduced_depth,
                      out_channels=output_ch,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=bias),
            nn.BatchNorm2d(output_ch),
            nn.PReLU()
        )

        self.regularizer = nn.Dropout2d(p=regularizer_prob)

        self.prelu = nn.PReLU()

    def forward(self, x):
        main = x
        ext = self.ext_branch_1(x)
        ext = self.ext_branch_2(ext)
        ext = self.ext_branch_3(ext)
        ext = self.regularizer(ext)

        out = self.prelu(main + ext)

        return out


class DownSampleBottleNeck(nn.Module):
    def __init__(self,
                 input_ch,
                 output_ch,
                 projection_ratio=4,
                 regularizer_prob=0,
                 bias=False):
        super(DownSampleBottleNeck, self).__init__()

        reduced_depth = input_ch // projection_ratio

        self.ext_branch_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_ch,
                      out_channels=reduced_depth,
                      kernel_size=2,
                      stride=2,
                      padding=0,
                      bias=bias),
            nn.BatchNorm2d(reduced_depth),
            nn.PReLU()
        )

        self.ext_branch_2 = nn.Sequential(
            nn.Conv2d(in_channels=reduced_depth,
                      out_channels=reduced_depth,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=bias),
            nn.BatchNorm2d(reduced_depth),
            nn.PReLU()
        )

        self.ext_branch_3 = nn.Sequential(
            nn.Conv2d(in_channels=reduced_depth,
                      out_channels=output_ch,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=bias),
            nn.BatchNorm2d(output_ch),
            nn.PReLU()
        )

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, return_indices=True)

        self.regularizer = nn.Dropout2d(p=regularizer_prob)

        self.prelu = nn.PReLU()

    def forward(self, x):
        main, ind = self.max_pool(x)
        ext = self.ext_branch_1(x)
        ext = self.ext_branch_2(ext)
        ext = self.ext_branch_3(ext)
        ext = self.regularizer(ext)

        # Feature map padding
        n, ch_ext, h, w = ext.size()
        ch_main = main.size()[1]
        padding = torch.autograd.Variable(torch.zeros(n, ch_ext-ch_main, h, w))
        if main.is_cuda:
            padding = padding.cuda()
        main = torch.cat((main, padding), dim=1)

        out = self.prelu(main + ext)

        return out, ind


class UpSampleBottleNeck(nn.Module):
    def __init__(self,
                 input_ch,
                 output_ch,
                 projection_ratio=4,
                 regularizer_prob=0,
                 bias=False):
        super(UpSampleBottleNeck, self).__init__()

        reduced_depth = input_ch // projection_ratio

        self.main_branch = nn.Sequential(
            nn.Conv2d(in_channels=input_ch,
                      out_channels=output_ch,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=bias),
            nn.BatchNorm2d(output_ch)
        )

        self.ext_branch_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_ch,
                      out_channels=reduced_depth,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=bias),
            nn.BatchNorm2d(reduced_depth),
            nn.PReLU()
        )

        self.ext_branch_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=reduced_depth,
                      out_channels=reduced_depth,
                      kernel_size=3,
                      stride=2,
                      padding=1,
                      output_padding=1,
                      bias=bias),
            nn.BatchNorm2d(reduced_depth),
            nn.PReLU()
        )

        self.ext_branch_3 = nn.Sequential(
            nn.Conv2d(in_channels=reduced_depth,
                      out_channels=output_ch,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      bias=bias),
            nn.BatchNorm2d(output_ch),
            nn.PReLU()
        )
        self.max_unpool = nn.MaxUnpool2d(kernel_size=2, stride=2, padding=0)

        self.regularizer = nn.Dropout2d(p=regularizer_prob)

        self.prelu = nn.PReLU()

    def forward(self, x, ind):
        main = self.main_branch(x)
        main = self.max_unpool(main, ind)
        ext = self.ext_branch_1(x)
        ext = self.ext_branch_2(ext)
        ext = self.ext_branch_3(ext)
        ext = self.regularizer(ext)

        out = self.prelu(main + ext)

        return out


class ENet(nn.Module):
    def __init__(self, input_ch, output_ch):
        super(ENet, self).__init__()

        # Initial
        self.initial_block = InitialBlock(input_ch=input_ch, output_ch=16)

        # Shared Encoder
        # BottleNeck1
        self.bottleNeck1_0 = DownSampleBottleNeck(input_ch=16, output_ch=64, regularizer_prob=0.01)
        self.bottleNeck1_1 = RegularBottleNeck(input_ch=64, output_ch=64, regularizer_prob=0.01)
        self.bottleNeck1_2 = RegularBottleNeck(input_ch=64, output_ch=64, regularizer_prob=0.01)
        self.bottleNeck1_3 = RegularBottleNeck(input_ch=64, output_ch=64, regularizer_prob=0.01)
        self.bottleNeck1_4 = RegularBottleNeck(input_ch=64, output_ch=64, regularizer_prob=0.01)

        # BottleNeck2
        self.bottleNeck2_0 = DownSampleBottleNeck(input_ch=64, output_ch=128, regularizer_prob=0.1)
        self.bottleNeck2_1 = RegularBottleNeck(input_ch=128, output_ch=128, regularizer_prob=0.1)
        self.bottleNeck2_2 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=2, regularizer_prob=0.1)
        self.bottleNeck2_3 = RegularBottleNeck(input_ch=128, output_ch=128, assymmetric=True, regularizer_prob=0.1)
        self.bottleNeck2_4 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=4, regularizer_prob=0.1)
        self.bottleNeck2_5 = RegularBottleNeck(input_ch=128, output_ch=128, regularizer_prob=0.1)
        self.bottleNeck2_6 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=8, regularizer_prob=0.1)
        self.bottleNeck2_7 = RegularBottleNeck(input_ch=128, output_ch=128, assymmetric=True, regularizer_prob=0.1)
        self.bottleNeck2_8 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=16, regularizer_prob=0.1)

        # Binary Segmentation
        # BottleNeck3
        self.semBottleNeck3_0 = RegularBottleNeck(input_ch=128, output_ch=128, regularizer_prob=0.1)
        self.semBottleNeck3_1 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=2, regularizer_prob=0.1)
        self.semBottleNeck3_2 = RegularBottleNeck(input_ch=128, output_ch=128, assymmetric=True, regularizer_prob=0.1)
        self.semBottleNeck3_3 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=4, regularizer_prob=0.1)
        self.semBottleNeck3_4 = RegularBottleNeck(input_ch=128, output_ch=128, regularizer_prob=0.1)
        self.semBottleNeck3_5 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=8, regularizer_prob=0.1)
        self.semBottleNeck3_6 = RegularBottleNeck(input_ch=128, output_ch=128, assymmetric=True, regularizer_prob=0.1)
        self.semBottleNeck3_7 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=16, regularizer_prob=0.1)

        # BottleNeck4
        self.semBottleNeck4_0 = UpSampleBottleNeck(input_ch=128, output_ch=64, regularizer_prob=0.1)
        self.semBottleNeck4_1 = RegularBottleNeck(input_ch=64, output_ch=64, regularizer_prob=0.1)
        self.semBottleNeck4_2 = RegularBottleNeck(input_ch=64, output_ch=64, regularizer_prob=0.1)

        # BottleNeck5
        self.semBottleNeck5_0 = UpSampleBottleNeck(input_ch=64, output_ch=16, regularizer_prob=0.1)
        self.semBottleNeck5_1 = RegularBottleNeck(input_ch=16, output_ch=16, regularizer_prob=0.1)

        self.sem_out = nn.ConvTranspose2d(in_channels=16,
                                          out_channels=output_ch,
                                          kernel_size=3,
                                          stride=2,
                                          padding=1,
                                          output_padding=1,
                                          bias=False)

        # Instance Segmentation
        # BottleNeck3
        self.insBottleNeck3_0 = RegularBottleNeck(input_ch=128, output_ch=128, regularizer_prob=0.1)
        self.insBottleNeck3_1 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=2, regularizer_prob=0.1)
        self.insBottleNeck3_2 = RegularBottleNeck(input_ch=128, output_ch=128, assymmetric=True, regularizer_prob=0.1)
        self.insBottleNeck3_3 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=4, regularizer_prob=0.1)
        self.insBottleNeck3_4 = RegularBottleNeck(input_ch=128, output_ch=128, regularizer_prob=0.1)
        self.insBottleNeck3_5 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=8, regularizer_prob=0.1)
        self.insBottleNeck3_6 = RegularBottleNeck(input_ch=128, output_ch=128, assymmetric=True, regularizer_prob=0.1)
        self.insBottleNeck3_7 = RegularBottleNeck(input_ch=128, output_ch=128, dilation=16, regularizer_prob=0.1)

        # BottleNeck4
        self.insBottleNeck4_0 = UpSampleBottleNeck(input_ch=128, output_ch=64, regularizer_prob=0.1)
        self.insBottleNeck4_1 = RegularBottleNeck(input_ch=64, output_ch=64, regularizer_prob=0.1)
        self.insBottleNeck4_2 = RegularBottleNeck(input_ch=64, output_ch=64, regularizer_prob=0.1)

        # BottleNeck5
        self.insBottleNeck5_0 = UpSampleBottleNeck(input_ch=64, output_ch=16, regularizer_prob=0.1)
        self.insBottleNeck5_1 = RegularBottleNeck(input_ch=16, output_ch=16, regularizer_prob=0.1)

        self.ins_out = nn.ConvTranspose2d(in_channels=16,
                                          out_channels=5,
                                          kernel_size=3,
                                          stride=2,
                                          padding=1,
                                          output_padding=1,
                                          bias=False)


    def forward(self, x):
        # Initial
        x = self.initial_block(x)

        # Shared Encoder
        # Stage1
        x, ind_1 = self.bottleNeck1_0(x)
        x = self.bottleNeck1_1(x)
        x = self.bottleNeck1_2(x)
        x = self.bottleNeck1_3(x)
        x = self.bottleNeck1_4(x)

        # Stage2
        x, ind_2 = self.bottleNeck2_0(x)
        x = self.bottleNeck2_1(x)
        x = self.bottleNeck2_2(x)
        x = self.bottleNeck2_3(x)
        x = self.bottleNeck2_4(x)
        x = self.bottleNeck2_5(x)
        x = self.bottleNeck2_6(x)
        x = self.bottleNeck2_7(x)
        x = self.bottleNeck2_8(x)

        # Binary Segmentation
        # Stage3
        x1 = self.semBottleNeck3_0(x)
        x1 = self.semBottleNeck3_1(x1)
        x1 = self.semBottleNeck3_2(x1)
        x1 = self.semBottleNeck3_3(x1)
        x1 = self.semBottleNeck3_4(x1)
        x1 = self.semBottleNeck3_5(x1)
        x1 = self.semBottleNeck3_6(x1)
        x1 = self.semBottleNeck3_7(x1)

        # Stage4
        x1 = self.semBottleNeck4_0(x1, ind_2)
        x1 = self.semBottleNeck4_1(x1)
        x1 = self.semBottleNeck4_2(x1)

        # Stage5
        x1 = self.semBottleNeck5_0(x1, ind_1)
        x1 = self.semBottleNeck5_1(x1)

        # Instance Segmentation
        # Stage3
        x2 = self.semBottleNeck3_0(x)
        x2 = self.semBottleNeck3_1(x2)
        x2 = self.semBottleNeck3_2(x2)
        x2 = self.semBottleNeck3_3(x2)
        x2 = self.semBottleNeck3_4(x2)
        x2 = self.semBottleNeck3_5(x2)
        x2 = self.semBottleNeck3_6(x2)
        x2 = self.semBottleNeck3_7(x2)

        # Stage4
        x2 = self.semBottleNeck4_0(x2, ind_2)
        x2 = self.semBottleNeck4_1(x2)
        x2 = self.semBottleNeck4_2(x2)

        # Stage5
        x2 = self.semBottleNeck5_0(x2, ind_1)
        x2 = self.semBottleNeck5_1(x2)

        # Stage 6
        sem = self.sem_out(x1)
        ins = self.ins_out(x2)

        return sem, ins

# Loss

In [4]:
# LOSS

from torch.nn.modules.loss import _Loss
from torch.autograd import Variable
import torch


class DiscriminativeLoss(_Loss):

    def __init__(self, delta_var=0.5, delta_dist=1.5,
                 norm=2, alpha=1.0, beta=1.0, gamma=0.001,
                 usegpu=True, size_average=True):
        super(DiscriminativeLoss, self).__init__(size_average)
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.usegpu = usegpu
        assert self.norm in [1, 2]

    def forward(self, input, target, n_clusters):
#         _assert_no_grad(target)
        return self._discriminative_loss(input, target, n_clusters)

    def _discriminative_loss(self, input, target, n_clusters):
        bs, n_features, height, width = input.size()
        max_n_clusters = target.size(1)

        input = input.contiguous().view(bs, n_features, height * width)
        target = target.contiguous().view(bs, max_n_clusters, height * width)

        c_means = self._cluster_means(input, target, n_clusters)
        l_var = self._variance_term(input, target, c_means, n_clusters)
        l_dist = self._distance_term(c_means, n_clusters)
        l_reg = self._regularization_term(c_means, n_clusters)

        loss = self.alpha * l_var + self.beta * l_dist + self.gamma * l_reg

        return loss

    def _cluster_means(self, input, target, n_clusters):
        bs, n_features, n_loc = input.size()
        max_n_clusters = target.size(1)

        # bs, n_features, max_n_clusters, n_loc
        input = input.unsqueeze(2).expand(bs, n_features, max_n_clusters, n_loc)
        # bs, 1, max_n_clusters, n_loc
        target = target.unsqueeze(1)
        # bs, n_features, max_n_clusters, n_loc
        input = input * target

        means = []
        for i in range(bs):
            # n_features, n_clusters, n_loc
            input_sample = input[i, :, :n_clusters[i]]
            # 1, n_clusters, n_loc,
            target_sample = target[i, :, :n_clusters[i]]
            # n_features, n_cluster
            mean_sample = input_sample.sum(2) / (target_sample.sum(2) + 0.00001)

            # padding
            n_pad_clusters = max_n_clusters - n_clusters[i]
            assert n_pad_clusters >= 0
            if n_pad_clusters > 0:
                pad_sample = torch.zeros(n_features, n_pad_clusters)
                pad_sample = Variable(pad_sample)
                if self.usegpu:
                    pad_sample = pad_sample.cuda()
                mean_sample = torch.cat((mean_sample, pad_sample), dim=1)
            means.append(mean_sample)

        # bs, n_features, max_n_clusters
        means = torch.stack(means)

        return means

    def _variance_term(self, input, target, c_means, n_clusters):
        bs, n_features, n_loc = input.size()
        max_n_clusters = target.size(1)

        # bs, n_features, max_n_clusters, n_loc
        c_means = c_means.unsqueeze(3).expand(bs, n_features, max_n_clusters, n_loc)
        # bs, n_features, max_n_clusters, n_loc
        input = input.unsqueeze(2).expand(bs, n_features, max_n_clusters, n_loc)
        # bs, max_n_clusters, n_loc
        var = (torch.clamp(torch.norm((input - c_means), self.norm, 1) -
                           self.delta_var, min=0) ** 2) * target

        var_term = 0
        for i in range(bs):
            # n_clusters, n_loc
            var_sample = var[i, :n_clusters[i]]
            # n_clusters, n_loc
            target_sample = target[i, :n_clusters[i]]

            # n_clusters
            c_var = var_sample.sum(1) / (target_sample.sum(1) + 0.00001)
            var_term += c_var.sum() / int(n_clusters[i])
        var_term /= bs

        return var_term

    def _distance_term(self, c_means, n_clusters):
        bs, n_features, max_n_clusters = c_means.size()

        dist_term = 0
        for i in range(bs):
            if n_clusters[i] <= 1:
                continue

            # n_features, n_clusters
            mean_sample = c_means[i, :, :n_clusters[i]]

            # n_features, n_clusters, n_clusters
            means_a = mean_sample.unsqueeze(2).expand(n_features, n_clusters[i], n_clusters[i])
            means_b = means_a.permute(0, 2, 1)
            diff = means_a - means_b

            margin = 2 * self.delta_dist * (1.0 - torch.eye(n_clusters[i]))
            margin = Variable(margin)
            if self.usegpu:
                margin = margin.cuda()
            c_dist = torch.sum(torch.clamp(margin - torch.norm(diff, self.norm, 0), min=0) ** 2)
            dist_term += c_dist / (2 * n_clusters[i] * (n_clusters[i] - 1))
        dist_term /= bs

        return dist_term

    def _regularization_term(self, c_means, n_clusters):
        bs, n_features, max_n_clusters = c_means.size()

        reg_term = 0
        for i in range(bs):
            # n_features, n_clusters
            mean_sample = c_means[i, :, :n_clusters[i]]
            reg_term += torch.mean(torch.norm(mean_sample, self.norm, 0))
        reg_term /= bs

        return reg_term

# Dataset

In [5]:
# DATASET

import torch
from torch.utils import data
from skimage.transform import AffineTransform, warp
from skimage import img_as_float64, img_as_float32, img_as_ubyte
import warnings
import numpy as np
import matplotlib.pyplot as plt
import cv2
import json
import glob
import os


class tuSimpleDataset(data.Dataset):
    # refer from : 
    # https://github.com/vxy10/ImageAugmentation
    # https://github.com/TuSimple/tusimple-benchmark/blob/master/example/lane_demo.ipynb
    def __init__(self, file_path, size=[640, 360], gray=True, train=True, intensity=10):
        warnings.simplefilter("ignore")

        self.width = size[0]
        self.height = size[1]
        self.n_seg = 5
        self.file_path = file_path
        self.flags = {'size':size, 'gray':gray, 'train':train, 'intensity':intensity}
        self.json_lists = glob.glob(os.path.join(self.file_path, '*.json'))
        self.labels = []
        for json_list in self.json_lists:
            self.labels += [json.loads(line) for line in open(json_list)]
        self.lanes = [lane['lanes'] for lane in self.labels]
        self.y_samples = [y_sample['h_samples'] for y_sample in self.labels]
        self.raw_files = [raw_file['raw_file'] for raw_file in self.labels]

        self.img = np.zeros(size, np.uint8)
        self.label_img = np.zeros(size, np.uint8)
        self.ins_img = np.zeros((0,size[0],size[1]), np.uint8)
        
        self.len = len(self.labels)
        
    def random_transform(self):
        intensity=self.flags['intensity']
        def _get_delta(intensity):
            delta = np.radians(intensity)
            rand_delta = np.random.uniform(low=-delta, high=delta)
            return rand_delta

        trans_M = AffineTransform(scale=(.9, .9),
                                 translation=(-_get_delta(intensity), _get_delta(intensity)),
                                 shear=_get_delta(intensity))
        self.img = img_as_float32(self.img)
        self.label_img = img_as_float32(self.label_img)
        self.ins_img = img_as_float32(self.ins_img)

        self.img = warp(self.img, trans_M)
        self.label_img = warp(self.label_img, trans_M)
        for i in range(len(self.ins_img)):
            self.ins_img[i] = warp(self.ins_img[i], trans_M)
    
    def image_resize(self):
        ins = []
        self.img = cv2.resize(self.img, tuple(self.flags['size']), interpolation=cv2.INTER_CUBIC)
        self.label_img = cv2.resize(self.label_img, tuple(self.flags['size']), interpolation=cv2.INTER_CUBIC)
        for i in range(len(self.ins_img)):
            dst = cv2.resize(self.ins_img[i], tuple(self.flags['size']), interpolation=cv2.INTER_CUBIC)
            ins.append(dst)

        self.ins_img = np.array(ins, dtype=np.uint8)
    
    def preprocess(self):
        # CLAHE nomalization
        img = cv2.cvtColor(self.img, cv2.COLOR_RGB2LAB)
        img_plane = cv2.split(img)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        img_plane[0] = clahe.apply(img_plane[0])
        img = cv2.merge(img_plane)
        self.img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB)
   
    def get_lane_image(self, idx):
        lane_pts = [[(x,y) for (x,y) in zip(lane, self.y_samples[idx]) if x >= 0] for lane in self.lanes[idx]]
        while len(lane_pts) < self.n_seg:
            lane_pts.append(list())
        self.img = plt.imread(os.path.join(self.file_path, self.raw_files[idx]))
        self.height, self.width, _ = self.img.shape
        self.label_img = np.zeros((self.height, self.width), dtype=np.uint8)
        self.ins_img = np.zeros((0, self.height, self.width), dtype=np.uint8)
        
        for i, lane_pt in enumerate(lane_pts):
            cv2.polylines(self.label_img, np.int32([lane_pt]), isClosed=False, color=(1), thickness=15)
            gt = np.zeros((self.height, self.width), dtype=np.uint8)
            gt = cv2.polylines(gt, np.int32([lane_pt]), isClosed=False, color=(1), thickness=7)
            self.ins_img = np.concatenate([self.ins_img, gt[np.newaxis]])

    def __getitem__(self, idx):
        self.get_lane_image(idx)
        self.image_resize()
        self.preprocess()

        if self.flags['train']:
            #self.random_transform()
            self.img = np.array(np.transpose(self.img, (2,0,1)), dtype=np.float32)
            self.label_img = np.array(self.label_img, dtype=np.float32)
            self.ins_img = np.array(self.ins_img, dtype=np.float32)
            return torch.Tensor(self.img), torch.LongTensor(self.label_img), torch.Tensor(self.ins_img)
        else:
            self.img = np.array(np.transpose(self.img, (2,0,1)), dtype=np.float32)
            return torch.Tensor(self.img)
    
    def __len__(self):
        return self.len


# Logger

In [6]:
%pip install scipy==1.2.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [7]:
# LOGGER

import tensorflow as tf
import numpy as np
import scipy.misc

# from PIL import Image


try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO         # Python 3.x


class Logger(object):

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a string
            try:
                s = StringIO()
            except:
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

# Train

## Define parameters for training

In [14]:
args = {
    "train_path": "/content/tusimple/",
    "lr": 1e-5,
    "batch_size": 10,
    "img_size": [224, 224],
    "epoch": 10
}

INPUT_CHANNELS = 3
OUTPUT_CHANNELS = 2
LEARNING_RATE = args["lr"] #1e-5
BATCH_SIZE = args["batch_size"] #20
NUM_EPOCHS = args["epoch"] #100
LOG_INTERVAL = 10
SIZE = [args["img_size"][0], args["img_size"][1]] #[224, 224]

## Manage DataSet and Drive connection

In [9]:
# mound google drive

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
# unzip the dataset

from zipfile import ZipFile as TuSimpleCompressed
file_name = "/content/drive/MyDrive/tusimple.zip"

with TuSimpleCompressed(file_name, 'r') as tusimplerar:
  tusimplerar.extractall()
  print('Done')


Done


## Train method

In [10]:
def train():
    # refer from : https://github.com/Sayan98/pytorch-segnet/blob/master/src/train.py
    is_better = True
    prev_loss = float('inf')
    
    model.train()
    
    for epoch in range(NUM_EPOCHS):
        t_start = time.time()
        loss_f = []

        for batch_idx, (imgs, sem_labels, ins_labels) in enumerate(train_dataloader):
            loss = 0

            img_tensor = torch.autograd.Variable(imgs).cuda()
            sem_tensor = torch.autograd.Variable(sem_labels).cuda()
            ins_tensor = torch.autograd.Variable(ins_labels).cuda()

            # Init gradients
            optimizer.zero_grad()

            # Predictions
            sem_pred, ins_pred = model(img_tensor)

            # Discriminative Loss
            disc_loss = criterion_disc(ins_pred, ins_tensor, [5] * len(img_tensor))
            loss += disc_loss

            # CrossEntropy Loss

            ce_loss = criterion_ce(sem_pred.permute(0,2,3,1).contiguous().view(-1,OUTPUT_CHANNELS),
                                   sem_tensor.view(-1))
            loss += ce_loss

            loss.backward()
            optimizer.step()

            loss_f.append(loss.cpu().data.numpy())

            if batch_idx % LOG_INTERVAL == 0:
                print('\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(imgs), len(train_dataloader.dataset),
                    100. * batch_idx / len(train_dataloader), loss.item()))

                #Tensorboard
                info = {'loss': loss.item(), 'ce_loss': ce_loss.item(), 'disc_loss': disc_loss.item(), 'epoch': epoch}

                for tag, value in info.items():
                    logger.scalar_summary(tag, value, batch_idx + 1)

                # 2. Log values and gradients of the parameters (histogram summary)
                for tag, value in model.named_parameters():
                    tag = tag.replace('.', '/')
                    logger.histo_summary(tag, value.data.cpu().numpy(), batch_idx + 1)
                    # logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), batch_idx + 1)

                # 3. Log training images (image summary)
                info = {'images': img_tensor.view(-1, 3, SIZE[0], SIZE[1])[:BATCH_SIZE].cpu().numpy(),
                        'labels': sem_tensor.view(-1, SIZE[0], SIZE[1])[:BATCH_SIZE].cpu().numpy(),
                        'sem_preds': sem_pred.view(-1, 2, SIZE[0], SIZE[1])[:BATCH_SIZE,1].data.cpu().numpy(),
                        'ins_preds': ins_pred.view(-1, SIZE[0], SIZE[1])[:BATCH_SIZE*5].data.cpu().numpy()}

                for tag, images in info.items():
                    logger.image_summary(tag, images, batch_idx + 1)
            
        dt = time.time() - t_start
        is_better = np.mean(loss_f) < prev_loss
        scheduler.step()
        
        if is_better:
            prev_loss = np.mean(loss_f)
            print("\t\tBest Model.")
            torch.save(model.state_dict(), "model_best.pth")
            
        print("Epoch #{}\tLoss: {:.8f}\t Time: {:2f}s, Lr: {:2f}".format(epoch+1, np.mean(loss_f), dt, optimizer.param_groups[0]['lr']))

## Setup for training

In [11]:
logger = Logger('./logs')

train_path = args["train_path"]
train_dataset = tuSimpleDataset(train_path, size=SIZE)
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=16)

In [12]:
model = SegNet(input_ch=INPUT_CHANNELS, output_ch=OUTPUT_CHANNELS).cuda() 
#model = ENet(input_ch=INPUT_CHANNELS, output_ch=OUTPUT_CHANNELS).cuda() 
if os.path.isfile("model_best.pth"):
    print("Loaded model_best.pth")
    model.load_state_dict(torch.load("model_best.pth"))

criterion_ce = torch.nn.CrossEntropyLoss().cuda()
criterion_disc = DiscriminativeLoss(delta_var=0.1,
                                    delta_dist=0.6,
                                    norm=2,
                                    usegpu=True).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,30,40,50,60,70,80], gamma=0.9)

## Start training

In [15]:
train()

		Best Model.
Epoch #1	Loss: 0.05057811	 Time: 380.017411s, Lr: 0.000010
		Best Model.
Epoch #2	Loss: 0.04951959	 Time: 379.367138s, Lr: 0.000010
		Best Model.
Epoch #3	Loss: 0.04806189	 Time: 379.197200s, Lr: 0.000010
		Best Model.
Epoch #4	Loss: 0.04709388	 Time: 377.008153s, Lr: 0.000009
		Best Model.
Epoch #5	Loss: 0.04551448	 Time: 376.666710s, Lr: 0.000009
		Best Model.
Epoch #6	Loss: 0.04458997	 Time: 377.541611s, Lr: 0.000009
		Best Model.
Epoch #7	Loss: 0.04381244	 Time: 375.320836s, Lr: 0.000009
		Best Model.
Epoch #8	Loss: 0.04322505	 Time: 376.079275s, Lr: 0.000009
		Best Model.
Epoch #9	Loss: 0.04221848	 Time: 376.452669s, Lr: 0.000009
		Best Model.
Epoch #10	Loss: 0.04185854	 Time: 377.882461s, Lr: 0.000009
