In [None]:
# Bilateral Segmentation Network
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ConvBNRelu(nn.Module):
    """Bias-free 2D convolution followed by a normalisation layer and a ReLU.

    Parameters
    ----------
    in_channels, out_channels, kernel_size, stride, padding, dilation, groups
        Forwarded unchanged to ``nn.Conv2d`` (which is created with ``bias=False``).
    relu6 : bool
        Use ``nn.ReLU6`` instead of ``nn.ReLU`` as the activation.
    norm_layer : callable
        Factory called with ``out_channels`` to build the normalisation layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.bn = norm_layer(out_channels)
        # Both activations run in place on the normalised tensor.
        activation_cls = nn.ReLU6 if relu6 else nn.ReLU
        self.relu = activation_cls(True)

    def forward(self, x):
        """Apply convolution, normalisation and activation, in that order."""
        return self.relu(self.bn(self.conv(x)))

In [None]:
class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-18/34 style).

    Parameters
    ----------
    inplanes : int
        Number of channels of the block input.
    planes : int
        Number of channels produced by both convolutions.
    stride : int
        Stride of the first convolution; the second convolution always uses stride 1.
    downsample : nn.Module or None
        Optional module applied to the input so that the skip connection matches
        the shape of the main branch.
    norm_layer : callable
        Factory called with ``planes`` to build each normalisation layer.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.BatchNorm2d):
        super(BasicBlock, self).__init__()
        # Only conv1 (and the optional downsample branch) change resolution/width.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv2 consumes conv1's output, so it must take `planes` channels and keep
        # the resolution (stride 1); using `inplanes`/`stride` here breaks every
        # block whose width or resolution changes.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + skip(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

In [None]:
class SpatialPath(nn.Module):
    """Shallow stack of four ``ConvBNRelu`` stages.

    The first three stages (7x7, 3x3, 3x3) each use stride 2 and 64 channels;
    the final 1x1 stage maps those 64 channels to ``channels_out``.
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()
        hidden = 64
        bn = nn.BatchNorm2d

        self.conv_7x7_layer = ConvBNRelu(channels_in, hidden, 7, 2, 3, norm_layer=bn)
        self.conv_3x3_1_layer = ConvBNRelu(hidden, hidden, 3, 2, 1, norm_layer=bn)
        self.conv_3x3_2_layer = ConvBNRelu(hidden, hidden, 3, 2, 1, norm_layer=bn)
        self.conv_1x1_layer = ConvBNRelu(hidden, channels_out, 1, 1, 0, norm_layer=bn)

    def forward(self, x):
        """Run the input through the four stages in definition order."""
        stages = (
            self.conv_7x7_layer,
            self.conv_3x3_1_layer,
            self.conv_3x3_2_layer,
            self.conv_1x1_layer,
        )
        for stage in stages:
            x = stage(x)
        return x

In [None]:
class FeatureFusionModule(nn.Module):
    """Fuse two feature maps and re-weight the result with channel attention.

    Parameters
    ----------
    channels_in : int
        Channel count of the concatenation of the two inputs.
    channels_out : int
        Channel count of the fused output.
    reduction : int
        Bottleneck ratio of the attention branch (``channels_out // reduction``).
    norm_layer : callable
        Normalisation-layer factory passed to every ``ConvBNRelu``.
    """

    def __init__(self, channels_in, channels_out, reduction=1, norm_layer=nn.BatchNorm2d):
        super(FeatureFusionModule, self).__init__()
        self.conv1x1 = ConvBNRelu(channels_in, channels_out, 1, 1, 0, norm_layer=norm_layer)

        # Global pooling -> bottleneck -> expand -> sigmoid gives one weight per channel.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBNRelu(channels_out, channels_out // reduction, 1, 1, 0, norm_layer=norm_layer),
            ConvBNRelu(channels_out // reduction, channels_out, 1, 1, 0, norm_layer=norm_layer),
            nn.Sigmoid()
        )

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` along dim 1 and return ``out + out * attention``."""
        fusion = torch.cat([x1, x2], dim=1)
        out = self.conv1x1(fusion)
        # The attention branch is a submodule, so it must be reached through `self`.
        attention = self.channel_attention(out)
        # Out-of-place on purpose: `out` is an operand of the multiplication and feeds
        # the attention branch, so overwriting it in place would invalidate values
        # autograd needs for the backward pass.
        return out + out * attention

In [None]:
class GlobalAvgPooling(nn.Module):
    """Global-average-pooled context vector, broadcast back to the input size.

    The input is pooled to 1x1, passed through a bias-free 1x1 convolution,
    ``norm_layer`` and an in-place ReLU, then bilinearly resized to the spatial
    size of the input.
    """

    def __init__(self, in_channels, out_channels, norm_layer):
        super().__init__()
        stages = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True),
        ]
        self.gap = nn.Sequential(*stages)

    def forward(self, x):
        """Return the pooled descriptor resized to ``x``'s spatial dimensions."""
        spatial_size = x.size()[2:]
        pooled = self.gap(x)
        return F.interpolate(pooled, spatial_size, mode='bilinear', align_corners=True)

In [None]:
class AttentionRefineModule(nn.Module):
    """3x3 ``ConvBNRelu`` whose output is scaled by a per-channel sigmoid gate.

    The gate is computed from the globally average-pooled output of the 3x3
    stage via a 1x1 ``ConvBNRelu`` followed by a sigmoid.
    """

    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.conv3x3 = ConvBNRelu(in_channels, out_channels, 3, 1, 1, norm_layer=norm_layer)
        gate_stages = [
            nn.AdaptiveAvgPool2d(1),
            ConvBNRelu(out_channels, out_channels, 1, 1, 0, norm_layer=norm_layer),
            nn.Sigmoid(),
        ]
        self.channel_attention = nn.Sequential(*gate_stages)

    def forward(self, x):
        """Return the refined features multiplied by their channel gate."""
        refined = self.conv3x3(x)
        gate = self.channel_attention(refined)
        return refined * gate

In [None]:
class ContextPath(nn.Module):
    """ResNet-18-shaped encoder with attention refinement on its two deepest stages.

    Parameters
    ----------
    backbone : str
        Accepted for interface compatibility. Only the ResNet-18 layout
        (``BasicBlock`` with ``[2, 2, 2, 2]`` blocks) is built, whatever value is given.
    norm_layer : callable
        Normalisation-layer factory used throughout the path.
    """

    def __init__(self, backbone='resnet18', norm_layer=nn.BatchNorm2d):
        super(ContextPath, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Channel count entering the next residual stage; updated by conv_layer so
        # that each stage is built against the width the previous one produced.
        self.inplanes = 64
        layers = [2, 2, 2, 2]
        self.layer1 = self.conv_layer(BasicBlock, 64, layers[0], norm_layer=norm_layer)
        # Stages 2-4 halve the resolution, as in a standard ResNet.
        self.layer2 = self.conv_layer(BasicBlock, 128, layers[1], stride=2, norm_layer=norm_layer)
        self.layer3 = self.conv_layer(BasicBlock, 256, layers[2], stride=2, norm_layer=norm_layer)
        self.layer4 = self.conv_layer(BasicBlock, 512, layers[3], stride=2, norm_layer=norm_layer)

        channels_intermediate = 128
        self.global_context = GlobalAvgPooling(512, channels_intermediate, norm_layer=norm_layer)
        # One refinement module per deep stage, ordered deepest first (layer4, layer3).
        self.arms = nn.ModuleList(
            [
                AttentionRefineModule(512, channels_intermediate, norm_layer),
                AttentionRefineModule(256, channels_intermediate, norm_layer)
            ]
        )
        self.refines = nn.ModuleList(
            [
                ConvBNRelu(channels_intermediate, channels_intermediate, 3, 1, 1, norm_layer=norm_layer),
                ConvBNRelu(channels_intermediate, channels_intermediate, 3, 1, 1, norm_layer=norm_layer)
            ]
        )

    def conv_layer(self, block, planes, blocks, stride=1, norm_layer=nn.BatchNorm2d):
        """Build one residual stage of ``blocks`` blocks and advance ``self.inplanes``.

        The first block receives ``stride`` and, when the shape changes, a 1x1
        convolution + normalisation on the skip connection; the remaining blocks
        keep resolution and width.
        """
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                norm_layer(out_channels)
            )

        layers = [block(self.inplanes, planes, stride, downsample, norm_layer=norm_layer)]
        self.inplanes = out_channels

        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return a list of two refined feature maps.

        The first is sized like the layer3 output, the second like the layer2
        output; each has 128 channels.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)

        context_blocks = []
        context_blocks.append(x)
        x = self.layer2(x)
        context_blocks.append(x)
        c3 = self.layer3(x)
        context_blocks.append(c3)
        c4 = self.layer4(c3)
        context_blocks.append(c4)
        # Deepest first: [layer4, layer3, layer2, layer1].
        context_blocks.reverse()

        global_context = self.global_context(c4)
        last_feature = global_context
        context_outputs = []
        for i, (feature, arm, refine) in enumerate(zip(context_blocks[:2], self.arms, self.refines)):
            feature = arm(feature)
            feature = feature + last_feature
            # Resize to the next shallower stage so it can be merged there.
            last_feature = F.interpolate(feature, size=context_blocks[i + 1].size()[2:], mode='bilinear', align_corners=True)
            last_feature = refine(last_feature)
            context_outputs.append(last_feature)

        return context_outputs

In [None]:
class BSNHead(nn.Module):
    """Classification head: 3x3 ``ConvBNRelu``, dropout (p=0.1), then a 1x1 convolution
    that outputs ``num_classes`` channels."""

    def __init__(self, channels_in, channels_intermediate, num_classes, norm_layer=nn.BatchNorm2d):
        super().__init__()
        stages = [
            ConvBNRelu(channels_in, channels_intermediate, 3, 1, 1, norm_layer=norm_layer),
            nn.Dropout(0.1),
            nn.Conv2d(channels_intermediate, num_classes, 1),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return the per-class score map for ``x``."""
        return self.block(x)

In [None]:
class BilateralSegmentationNetwork(nn.Module):
    """Two-branch segmentation network.

    A ``SpatialPath`` and a ``ContextPath`` process the same input; the spatial
    output and the last context output are merged by a ``FeatureFusionModule``
    and classified by a ``BSNHead``.
    """

    def __init__(self, num_classes, backbone='resnet18'):
        super().__init__()

        self.spatial_path = SpatialPath(3, 128)
        self.context_path = ContextPath(backbone)
        self.ffm = FeatureFusionModule(256, 256, 4)
        self.head = BSNHead(256, 64, num_classes)

    def forward(self, x):
        """Return a one-element tuple holding the score map resized to the input size."""
        input_size = x.size()[2:]
        spatial_features = self.spatial_path(x)
        context_features = self.context_path(x)
        fused = self.ffm(spatial_features, context_features[-1])
        scores = self.head(fused)
        scores = F.interpolate(scores, input_size, mode='bilinear', align_corners=True)
        return (scores,)

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageOps, ImageFilter
import numpy as np
import glob, random

class MyDataset(Dataset):
    """Paired image/mask dataset read from two directories of ``*.png`` files.

    Parameters
    ----------
    images : str
        Directory containing the input images.
    masks : str
        Directory containing the masks. Files are paired with the images by
        sorted filename order.
    num_classes : int
        Stored on the instance; not used by the loading code itself.

    Raises
    ------
    ValueError
        If the two directories contain a different number of PNG files.
    """

    def __init__(self, images, masks, num_classes):
        # glob returns files in arbitrary order; sort both lists so that index i
        # of `images` and index i of `masks` refer to the same sample.
        self.images = sorted(glob.glob(images + '/*.png'))
        self.masks = sorted(glob.glob(masks + '/*.png'))
        if len(self.images) != len(self.masks):
            raise ValueError(
                'Found {} images but {} masks; every image needs exactly one mask'.format(
                    len(self.images), len(self.masks)))
        self.num_classes = num_classes

        self.base_size = 520
        self.crop_size = 480

    def __getitem__(self, index):
        """Return ``(image, mask)`` tensors for sample ``index``.

        The image is a float32 tensor in channel-first layout scaled to [0, 1];
        the mask is an int64 tensor, the dtype ``nn.CrossEntropyLoss`` expects
        for class-index targets.
        """
        img = Image.open(self.images[index]).convert('RGB')
        mask = Image.open(self.masks[index])
        img, mask = self.transform(img, mask)
        # HWC uint8 -> CHW float so the sample can be fed to nn.Conv2d after batching.
        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float().div(255.0)
        # NOTE(review): assumes masks are single-channel class-index images — confirm.
        mask = torch.from_numpy(mask).long()
        return img, mask

    def __len__(self):
        return len(self.images)

    def transform(self, img, mask):
        """Apply the same random flip / scale / pad / crop to image and mask.

        Returns two numpy arrays of spatial size ``crop_size`` x ``crop_size``.
        The Gaussian blur is applied to the image only.
        """
        # random mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        crop_size = self.crop_size
        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        # Nearest-neighbour for the mask so no new label values are interpolated.
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        # gaussian blur as in PSP
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
        # final transform
        img, mask = np.array(img), np.array(mask)
        return img, mask

In [None]:
from torch.autograd import Variable

class MixSoftmaxCrossEntropyLoss(nn.CrossEntropyLoss):
    """Cross-entropy over a main prediction plus optional weighted auxiliary predictions.

    Parameters
    ----------
    aux : bool
        When true, every prediction after the first contributes
        ``aux_weight * cross_entropy`` to the loss.
    aux_weight : float
        Weight of each auxiliary prediction.
    ignore_index : int
        Target value excluded from the loss (forwarded to ``nn.CrossEntropyLoss``).
    **kwargs
        Accepted and ignored.
    """

    def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
        # Must be the real constructor (`__init__`): otherwise `aux`/`aux_weight`
        # are never set and `forward` fails on the first call.
        super(MixSoftmaxCrossEntropyLoss, self).__init__(ignore_index=ignore_index)
        self.aux = aux
        self.aux_weight = aux_weight

    def _aux_forward(self, *inputs, **kwargs):
        """Compute main loss + ``aux_weight`` * each auxiliary loss.

        ``inputs`` is every prediction followed by the target as the last item.
        """
        *preds, target = tuple(inputs)

        loss = super(MixSoftmaxCrossEntropyLoss, self).forward(preds[0], target)
        for i in range(1, len(preds)):
            aux_loss = super(MixSoftmaxCrossEntropyLoss, self).forward(preds[i], target)
            loss = loss + self.aux_weight * aux_loss
        return loss

    def forward(self, *inputs, **kwargs):
        """Return ``{'loss': tensor}`` for ``forward(preds, target)``.

        ``preds`` is an iterable of predictions (e.g. the tuple returned by the
        network); ``target`` is the tensor of class indices.
        """
        preds, target = tuple(inputs)
        inputs = tuple(list(preds) + [target])
        if self.aux:
            return dict(loss=self._aux_forward(*inputs))
        else:
            return dict(loss=super(MixSoftmaxCrossEntropyLoss, self).forward(*inputs))

In [None]:
class LRScheduler(object):
    r"""Learning Rate Scheduler

    Parameters
    ----------
    mode : str
        Modes for learning rate scheduler.
        Currently it supports 'constant', 'step', 'linear', 'poly' and 'cosine'.
    base_lr : float
        Base learning rate, i.e. the starting learning rate.
    target_lr : float
        Target learning rate, i.e. the ending learning rate.
        With constant mode target_lr is ignored.
    niters : int
        Number of iterations to be scheduled.
    nepochs : int
        Number of epochs to be scheduled.
    iters_per_epoch : int
        Number of iterations in each epoch.
    offset : int
        Number of iterations before this scheduler.
    power : float
        Power parameter of poly scheduler.
    step_iter : list
        A list of iterations to decay the learning rate.
    step_epoch : list
        A list of epochs to decay the learning rate.
    step_factor : float
        Learning rate decay factor.
    warmup_epochs : int
        Number of warm-up epochs; converted to iterations with iters_per_epoch.
    """

    def __init__(self, mode, base_lr=0.01, target_lr=0, niters=0, nepochs=0, iters_per_epoch=0,
                 offset=0, power=0.9, step_iter=None, step_epoch=None, step_factor=0.1, warmup_epochs=0):
        super(LRScheduler, self).__init__()
        assert (mode in ['constant', 'step', 'linear', 'poly', 'cosine'])

        if mode == 'step':
            assert (step_iter is not None or step_epoch is not None)
        self.niters = niters
        self.step = step_iter
        # An epoch-based specification, when given, overrides the iteration-based one.
        epoch_iters = nepochs * iters_per_epoch
        if epoch_iters > 0:
            self.niters = epoch_iters
            if step_epoch is not None:
                self.step = [s * iters_per_epoch for s in step_epoch]

        self.step_factor = step_factor
        self.base_lr = base_lr
        self.target_lr = base_lr if mode == 'constant' else target_lr
        self.offset = offset
        self.power = power
        self.warmup_iters = warmup_epochs * iters_per_epoch
        self.mode = mode

    def __call__(self, optimizer, num_update):
        """Compute the rate for iteration ``num_update`` and write it into ``optimizer``."""
        self.update(num_update)
        assert self.learning_rate >= 0
        self._adjust_learning_rate(optimizer, self.learning_rate)

    def update(self, num_update):
        """Set ``self.learning_rate`` for iteration ``num_update``."""
        N = self.niters - 1
        T = num_update - self.offset
        T = min(max(0, T), N)
        # Fraction of the schedule consumed. A schedule with no span (N <= 0) is
        # treated as finished, which also avoids dividing by zero when niters == 1.
        progress = T / N if N > 0 else 1.0

        if self.mode == 'constant':
            # target_lr equals base_lr in this mode, so the factor has no effect.
            factor = 0
        elif self.mode == 'linear':
            factor = 1 - progress
        elif self.mode == 'poly':
            factor = pow(1 - progress, self.power)
        elif self.mode == 'cosine':
            factor = (1 + math.cos(math.pi * progress)) / 2
        elif self.mode == 'step':
            if self.step is not None:
                count = sum([1 for s in self.step if s <= T])
                factor = pow(self.step_factor, count)
            else:
                factor = 1
        else:
            raise NotImplementedError

        # warm up lr schedule
        if self.warmup_iters > 0 and T < self.warmup_iters:
            factor = factor * 1.0 * T / self.warmup_iters

        if self.mode == 'step':
            self.learning_rate = self.base_lr * factor
        else:
            self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * factor

    def _adjust_learning_rate(self, optimizer, lr):
        """Write ``lr`` to the first parameter group and ``lr * 10`` to all others."""
        optimizer.param_groups[0]['lr'] = lr
        # enlarge the lr at the head
        for i in range(1, len(optimizer.param_groups)):
            optimizer.param_groups[i]['lr'] = lr * 10

In [None]:
class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
    """Warm-up phase followed by polynomial decay from each base rate to ``target_lr``.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer whose parameter-group rates are scheduled.
    target_lr : float
        Rate reached at the end of the decay.
    max_iters : int
        Total number of scheduler steps, warm-up included. If it does not exceed
        ``warmup_iters`` there is nothing to decay over and the base rate is held
        once warm-up is finished.
    power : float
        Exponent of the polynomial decay.
    warmup_factor : float
        Fraction of the base-to-target span used at the start of warm-up.
    warmup_iters : int
        Number of warm-up steps.
    warmup_method : str
        ``'constant'`` or ``'linear'``.
    last_epoch : int
        Forwarded to the base scheduler.

    Raises
    ------
    ValueError
        If ``warmup_method`` is neither ``'constant'`` nor ``'linear'``.
    """

    def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3,
                 warmup_iters=500, warmup_method='linear', last_epoch=-1):
        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted "
                "got {}".format(warmup_method))

        self.target_lr = target_lr
        self.max_iters = max_iters
        self.power = power
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method

        # The base constructor performs an initial step, which calls get_lr(), so
        # every attribute above has to exist before this call.
        super(WarmupPolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per parameter group for the current step."""
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == 'constant':
                warmup_factor = self.warmup_factor
            elif self.warmup_method == 'linear':
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
            else:
                raise ValueError("Unknown warmup type.")
            return [self.target_lr + (base_lr - self.target_lr) * warmup_factor for base_lr in self.base_lrs]

        N = self.max_iters - self.warmup_iters
        if N <= 0:
            # No decay horizon was configured: hold the base rate rather than
            # dividing by a non-positive span (which made the rate grow without bound).
            factor = 1.0
        else:
            # Clamp so that stepping past max_iters pins the rate at target_lr; a
            # negative base raised to a fractional power would be a complex number.
            T = min(max(self.last_epoch - self.warmup_iters, 0), N)
            factor = pow(1 - T / N, self.power)
        return [self.target_lr + (base_lr - self.target_lr) * factor for base_lr in self.base_lrs]

In [None]:
import torch.utils.data as data
from score import SegmentationMetric

class NetworkTrainer(object):
    """Builds the dataset, model, loss, optimizer and scheduler, and runs one training pass.

    Parameters
    ----------
    gpus : int
        Number of GPUs; stored on the instance, not otherwise used by this class.
    device : str
        Device string passed to ``torch.device`` (e.g. ``'cuda'`` or ``'cpu'``).
    learning_rate : float
        Learning rate handed to the Adam optimizer.
    num_classes : int
        Number of output classes of the network and of the metric.
    batch_size : int
        Number of samples per training batch.
    images_dir, masks_dir : str
        Directories holding the training images and masks.
    weights_path : str
        File the trained ``state_dict`` is written to at the end of ``train``.
    """

    def __init__(self, gpus, device, learning_rate, num_classes, batch_size,
                 images_dir='E:\\pytorch\\bisenet\\dataset\\images',
                 masks_dir='E:\\pytorch\\bisenet\\dataset\\masks',
                 weights_path='./weights.pth'):
        self.device = torch.device(device)
        self.gpus = gpus
        self.weights_path = weights_path

        train_dataset = MyDataset(images_dir, masks_dir, num_classes)

        # drop_last: a trailing batch of a single sample would make BatchNorm fail on
        # the 1x1 globally-pooled feature maps while the model is in training mode.
        self.train_loader = data.DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=0,
            pin_memory=self.device.type == 'cuda',
        )
        self.model = BilateralSegmentationNetwork(num_classes=num_classes, backbone='resnet18')
        # Parameters must live on the same device as the batches fed in train().
        self.model.to(self.device)
        self.criterion = MixSoftmaxCrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        # train() takes one scheduler step per batch, so the decay horizon is the
        # number of batches in the loader.
        self.lr_scheduler = WarmupPolyLR(self.optimizer, max_iters=len(self.train_loader), power=0.9)
        self.metric = SegmentationMetric(num_classes)

    def train(self):
        """Run one pass over the training loader and save the model weights."""
        self.model.train()

        for iteration, (images, masks) in enumerate(self.train_loader, start=1):
            images = images.to(self.device)
            masks = masks.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, masks)
            losses = sum(loss for loss in loss_dict.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            # Step the scheduler after the optimizer so the first update uses the
            # initial learning rate.
            self.lr_scheduler.step()

            print('iteration ', iteration, ' over')

        torch.save(self.model.state_dict(), self.weights_path)

In [None]:
import torch.backends.cudnn as cudnn

if __name__ == '__main__':
    # Training configuration for a single run.
    num_gpus = 1
    learning_rate = 0.0001
    if torch.cuda.is_available():
        cudnn.benchmark = True
        device = 'cuda'
    else:
        # Without this branch `device` is undefined on machines without CUDA.
        device = 'cpu'

    learning_rate = learning_rate * num_gpus
    batch_size = 4
    num_classes = 4
    network_trainer = NetworkTrainer(num_gpus, device, learning_rate, num_classes, batch_size)
    network_trainer.train()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()