In [None]:
import math
import os
import tarfile
import time
import urllib.request
import zipfile

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from PIL import Image, ImageOps, ImageFilter
from torchvision import transforms

In [None]:
class Compose(object):
    """Apply a sequence of paired transforms to an (image, annotation) pair.

    Each transform is called as ``t(img, anno_class_img)`` and must return a new
    ``(img, anno_class_img)`` pair, which is fed to the next transform.
    """

    def __init__(self, transforms):
        # Ordered list of paired transforms.
        self.transforms = transforms

    def __call__(self, img, anno_class_img):
        result = (img, anno_class_img)
        for transform in self.transforms:
            current_img, current_anno = result
            result = transform(current_img, current_anno)
        final_img, final_anno = result
        return final_img, final_anno


class Scale(object):
    """Randomly rescale an image/annotation pair, then crop or pad back to the original size.

    A factor is drawn uniformly from ``[scale[0], scale[1])``. When the factor is
    above 1 a random window of the original size is cropped out. Otherwise the
    shrunken pair is pasted at a random offset onto a black (image) / zero-label
    (annotation) canvas of the original size, and the annotation palette is restored.
    """

    def __init__(self, scale):
        # (min, max) range for the random scale factor.
        self.scale = scale

    def __call__(self, img, anno_class_img):
        width, height = img.size
        scale = np.random.uniform(self.scale[0], self.scale[1])
        scaled_w, scaled_h = int(width * scale), int(height * scale)
        new_size = (scaled_w, scaled_h)
        img = img.resize(new_size, Image.BICUBIC)
        # Nearest-neighbour keeps annotation values as valid class indices.
        anno_class_img = anno_class_img.resize(new_size, Image.NEAREST)

        if scale > 1.0:
            left = int(np.random.uniform(0, scaled_w - width))
            top = int(np.random.uniform(0, scaled_h - height))
            box = (left, top, left + width, top + height)
            return img.crop(box), anno_class_img.crop(box)

        palette = anno_class_img.copy().getpalette()
        offset_x = int(np.random.uniform(0, width - scaled_w))
        offset_y = int(np.random.uniform(0, height - scaled_h))

        padded_img = Image.new(img.mode, (width, height), (0, 0, 0))
        padded_img.paste(img, (offset_x, offset_y))

        padded_anno = Image.new(anno_class_img.mode, (width, height), 0)
        padded_anno.paste(anno_class_img, (offset_x, offset_y))
        # A fresh image has no palette; copy it over from the source annotation.
        padded_anno.putpalette(palette)

        return padded_img, padded_anno


class RandomRotation(object):
    """Rotate an image/annotation pair by the same random angle.

    The angle (degrees) is drawn uniformly from ``[angle[0], angle[1])``. The image
    uses bilinear resampling; the annotation uses nearest so labels stay discrete.
    """

    def __init__(self, angle):
        # (min, max) rotation range in degrees.
        self.angle = angle

    def __call__(self, img, anno_class_img):
        low, high = self.angle[0], self.angle[1]
        degrees = np.random.uniform(low, high)
        rotated_img = img.rotate(degrees, Image.BILINEAR)
        rotated_anno = anno_class_img.rotate(degrees, Image.NEAREST)
        return rotated_img, rotated_anno


class RandomMirror(object):
    """With probability 1/2, flip both the image and the annotation left-to-right."""

    def __call__(self, img, anno_class_img):
        flip = np.random.randint(2)
        if not flip:
            return img, anno_class_img
        return ImageOps.mirror(img), ImageOps.mirror(anno_class_img)


class Resize(object):
    """Resize an image/annotation pair to a square ``input_size x input_size``.

    The image uses bicubic resampling; the annotation uses nearest so labels stay discrete.
    """

    def __init__(self, input_size):
        # Side length of the square output.
        self.input_size = input_size

    def __call__(self, img, anno_class_img):
        target = (self.input_size, self.input_size)
        resized_img = img.resize(target, Image.BICUBIC)
        resized_anno = anno_class_img.resize(target, Image.NEAREST)
        return resized_img, resized_anno


class Normalize_Tensor(object):
    """Convert the pair to tensors: a normalized image tensor and an integer label tensor.

    Annotation pixels equal to 255 are rewritten to 0 (255 is presumably the VOC
    boundary/"void" label — confirm against the dataset).
    """

    def __init__(self, color_mean, color_std):
        # Per-channel mean/std used by transforms.functional.normalize.
        self.color_mean = color_mean
        self.color_std = color_std

    def __call__(self, img, anno_class_img):
        as_tensor = transforms.functional.to_tensor(img)
        normalized = transforms.functional.normalize(as_tensor, self.color_mean, self.color_std)

        labels = np.array(anno_class_img)
        labels[labels == 255] = 0
        return normalized, torch.from_numpy(labels)

In [None]:
def make_datapath_list(rootpath):
    """Build image/annotation path lists for the PASCAL VOC segmentation train/val splits.

    Reads ``<rootpath>/ImageSets/Segmentation/{train,val}.txt`` (one image id per
    line; blank lines are ignored) and maps each id to
    ``<rootpath>/JPEGImages/<id>.jpg`` and ``<rootpath>/SegmentationClass/<id>.png``.
    ``rootpath`` may be given with or without a trailing separator.

    Returns:
        (train_img_list, train_anno_list, val_img_list, val_anno_list)

    Raises:
        OSError: if a split file cannot be opened.
    """
    imgpath_template = os.path.join(rootpath, 'JPEGImages', '%s.jpg')
    annopath_template = os.path.join(rootpath, 'SegmentationClass', '%s.png')
    # Join components instead of concatenating strings so a rootpath without a
    # trailing slash still resolves to the right directory.
    split_dir = os.path.join(rootpath, 'ImageSets', 'Segmentation')

    def _read_split(split_name):
        # Returns (img_list, anno_list) for one split file; the file is always closed.
        img_list, anno_list = [], []
        with open(os.path.join(split_dir, split_name + '.txt')) as id_file:
            for line in id_file:
                file_id = line.strip()
                if not file_id:
                    # Skip blank lines rather than emitting paths like ".../.jpg".
                    continue
                img_list.append(imgpath_template % file_id)
                anno_list.append(annopath_template % file_id)
        return img_list, anno_list

    train_img_list, train_anno_list = _read_split('train')
    val_img_list, val_anno_list = _read_split('val')

    return train_img_list, train_anno_list, val_img_list, val_anno_list

In [None]:
class DataTransform():
    """Phase-dependent preprocessing for (image, annotation) pairs.

    'train': random scale -> random rotation (+-10 deg) -> random mirror -> resize -> tensor/normalize.
    'val':   resize -> tensor/normalize.
    """

    def __init__(self, input_size, color_mean, color_std):
        train_pipeline = Compose([
            Scale(scale=[0.5, 1.5]),
            RandomRotation(angle=[-10, 10]),
            RandomMirror(),
            Resize(input_size),
            Normalize_Tensor(color_mean, color_std),
        ])
        val_pipeline = Compose([
            Resize(input_size),
            Normalize_Tensor(color_mean, color_std),
        ])
        # Keyed by phase name; __call__ looks the pipeline up by `phase`.
        self.data_transform = {'train': train_pipeline, 'val': val_pipeline}

    def __call__(self, phase, img, anno_class_img):
        pipeline = self.data_transform[phase]
        return pipeline(img, anno_class_img)
    

class VOCDataset(data.Dataset):
    """Map-style dataset yielding transformed (image, annotation) pairs.

    Args:
        img_list: image file paths.
        anno_list: annotation file paths, index-aligned with ``img_list``.
        phase: phase key passed to ``transform`` (e.g. 'train' / 'val').
        transform: callable ``(phase, img, anno) -> (img, anno)``.
    """

    def __init__(self, img_list, anno_list, phase, transform):
        self.img_list = img_list
        self.anno_list = anno_list
        self.phase = phase
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        image, label = self.pull_item(index)
        return (image, label)

    def pull_item(self, index):
        """Open the index-th image/annotation files and apply the phase transform."""
        raw_img = Image.open(self.img_list[index])
        raw_anno = Image.open(self.anno_list[index])
        image, label = self.transform(self.phase, raw_img, raw_anno)
        return (image, label)

In [None]:
class PSPNet_ResNet50(nn.Module):
    """PSPNet with a dilated ResNet-50 style encoder and an auxiliary head.

    Args:
        n_classes: number of output classes for both heads.
        img_size: side length of the square input images (default 475). The
            pyramid-pooling upsampling size and both heads' output size are derived
            from it, so inputs are expected to be ``img_size x img_size``.

    ``forward`` returns ``(output, output_aux)``, both upsampled to ``img_size x img_size``.
    """

    def __init__(self, n_classes, img_size=475):
        super(PSPNet_ResNet50, self).__init__()

        block_config = [3, 4, 6, 3]
        # Side length after the encoder's three stride-2 stages (stem conv, stem
        # max-pool, feature_res_2); each maps a side s to ceil(s / 2), so the
        # overall result is ceil(img_size / 8). Keep it a plain Python int: a
        # narrow NumPy dtype such as int8 would silently overflow for large sizes.
        img_size_8 = int(math.ceil(img_size / 8))

        # encode
        self.feature_conv = FeatureMap_convolution()
        self.feature_res_1 = ResidualBlockPSP(n_blocks=block_config[0], in_channels=128, mid_channels=64, out_channels=256, stride=1, dilation=1)
        self.feature_res_2 = ResidualBlockPSP(n_blocks=block_config[1], in_channels=256, mid_channels=128, out_channels=512, stride=2, dilation=1)
        self.feature_dilated_res_1 = ResidualBlockPSP(n_blocks=block_config[2], in_channels=512, mid_channels=256, out_channels=1024, stride=1, dilation=2)
        self.feature_dilated_res_2 = ResidualBlockPSP(n_blocks=block_config[3], in_channels=1024, mid_channels=512, out_channels=2048, stride=1, dilation=4)

        # pyramid pooling
        self.pyramid_pooling = PyramidPooling(in_channels=2048, pool_sizes=[6, 3, 2, 1], height=img_size_8, width=img_size_8)

        # decoder
        self.decode_feature = DecodePSPFeature(height=img_size, width=img_size, n_classes=n_classes)

        # auxiliary head on the 1024-channel feature_dilated_res_1 output
        self.aux = AuxiliaryPSPlayers(in_channels=1024, height=img_size, width=img_size, n_classes=n_classes)

    def forward(self, x):
        x = self.feature_conv(x)
        x = self.feature_res_1(x)
        x = self.feature_res_2(x)
        x = self.feature_dilated_res_1(x)
        # The auxiliary head branches off before the last residual stage.
        output_aux = self.aux(x)
        x = self.feature_dilated_res_2(x)
        x = self.pyramid_pooling(x)
        output = self.decode_feature(x)

        return (output, output_aux)

In [None]:
class conv2DBatchNormRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU (in-place).

    Constructor arguments are forwarded to ``nn.Conv2d``; BatchNorm runs over
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, bias):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.batchnorm(self.conv(x)))
    
    
class FeatureMap_convolution(nn.Module):
    """Encoder stem: three 3x3 conv-BN-ReLU layers (first with stride 2), then a stride-2 max-pool.

    Maps 3 input channels to 128 output channels.
    """

    def __init__(self):
        super().__init__()
        # Positional args: in, out, kernel, stride, padding, dilation, bias.
        self.cbnr_1 = conv2DBatchNormRelu(3, 64, 3, 2, 1, 1, False)
        self.cbnr_2 = conv2DBatchNormRelu(64, 64, 3, 1, 1, 1, False)
        self.cbnr_3 = conv2DBatchNormRelu(64, 128, 3, 1, 1, 1, False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        for layer in (self.cbnr_1, self.cbnr_2, self.cbnr_3, self.maxpool):
            x = layer(x)
        return x

In [None]:
class ResidualBlockPSP(nn.Sequential):
    """One residual stage: a projection bottleneck followed by ``n_blocks - 1`` identity bottlenecks.

    Submodules are registered as ``block1`` .. ``block<n_blocks>`` and run in that order.
    """

    def __init__(self, n_blocks, in_channels, mid_channels, out_channels, stride, dilation):
        super().__init__()

        self.add_module("block1", bottleNeckPSP(in_channels, mid_channels, out_channels, stride, dilation))

        for block_idx in range(2, n_blocks + 1):
            identity_block = bottleNeckIdentifyPSP(out_channels, mid_channels, stride, dilation)
            self.add_module("block" + str(block_idx), identity_block)
            

class conv2DBatchNorm(nn.Module):
    """Conv2d -> BatchNorm2d with no activation.

    Constructor arguments are forwarded to ``nn.Conv2d``; BatchNorm runs over
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, bias):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.batchnorm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.batchnorm(self.conv(x))
            
            
class bottleNeckPSP(nn.Module):
    """First bottleneck of a residual stage, with a projection (1x1 conv-BN) shortcut.

    Main path: 1x1 conv-BN-ReLU -> 3x3 (strided, dilated) conv-BN-ReLU -> 1x1 conv-BN.
    Shortcut: 1x1 conv-BN with the same stride, so both paths produce
    ``out_channels`` at the same resolution. ReLU is applied after the sum.

    Args:
        in_channels: input channels.
        mid_channels: bottleneck width.
        out_channels: output channels.
        stride: stride of the 3x3 conv and of the shortcut conv.
        dilation: dilation (and matching padding) of the 3x3 conv.
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride, dilation):
        super(bottleNeckPSP, self).__init__()

        self.cbr_1 = conv2DBatchNormRelu(in_channels, mid_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.cbr_2 = conv2DBatchNormRelu(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
        # No ReLU on the last conv (hence "cb", not "cbr"): the activation comes
        # after adding the shortcut, so this branch can contribute negative values
        # to the residual sum. Parameter names are unchanged (conv/batchnorm).
        self.cb_3 = conv2DBatchNorm(mid_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.cb_residual = conv2DBatchNorm(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, dilation=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv = self.cb_3(self.cbr_2(self.cbr_1(x)))
        residual = self.cb_residual(x)

        return self.relu(conv + residual)
    
    
class bottleNeckIdentifyPSP(nn.Module):
    """Bottleneck residual block with an identity shortcut (channels and resolution preserved).

    Main path: 1x1 conv-BN-ReLU -> 3x3 dilated conv-BN-ReLU -> 1x1 conv-BN back to
    ``in_channels``. The input is added unchanged and ReLU is applied after the sum.

    Args:
        in_channels: channels of both the input and the output.
        mid_channels: bottleneck width.
        stride: accepted for signature compatibility but unused; the 3x3 conv always
            uses stride 1 so the identity shortcut's shape matches.
        dilation: dilation (and matching padding) of the 3x3 conv.
    """

    def __init__(self, in_channels, mid_channels, stride, dilation):
        super(bottleNeckIdentifyPSP, self).__init__()

        self.cbr_1 = conv2DBatchNormRelu(in_channels, mid_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.cbr_2 = conv2DBatchNormRelu(mid_channels, mid_channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation, bias=False)
        # conv + BN only: the ReLU is applied after adding the identity shortcut.
        self.cb_3 = conv2DBatchNorm(mid_channels, in_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv = self.cb_3(self.cbr_2(self.cbr_1(x)))
        residual = x
        return self.relu(conv + residual)

In [None]:
class PyramidPooling(nn.Module):
    """Pyramid pooling module.

    Each of the four branches average-pools the input to ``pool_sizes[i]``, reduces
    channels with a 1x1 conv-BN-ReLU to ``in_channels / len(pool_sizes)``, then
    bilinearly upsamples to ``(height, width)``. The input and the four branch
    outputs are concatenated along the channel dimension.
    """

    def __init__(self, in_channels, pool_sizes, height, width):
        super(PyramidPooling, self).__init__()

        self.height = height
        self.width = width

        out_channels = int(in_channels / len(pool_sizes))

        def _branch_conv():
            # 1x1 conv-BN-ReLU reducing in_channels to out_channels.
            return conv2DBatchNormRelu(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)

        # Attribute names are kept explicit so state_dict keys stay stable.
        self.avpool_1 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[0])
        self.cbr_1 = _branch_conv()

        self.avpool_2 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[1])
        self.cbr_2 = _branch_conv()

        self.avpool_3 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[2])
        self.cbr_3 = _branch_conv()

        self.avpool_4 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[3])
        self.cbr_4 = _branch_conv()

    def forward(self, x):
        target_size = (self.height, self.width)
        branches = (
            (self.avpool_1, self.cbr_1),
            (self.avpool_2, self.cbr_2),
            (self.avpool_3, self.cbr_3),
            (self.avpool_4, self.cbr_4),
        )
        upsampled = [
            F.interpolate(conv(pool(x)), size=target_size, mode='bilinear', align_corners=True)
            for pool, conv in branches
        ]
        return torch.cat([x] + upsampled, dim=1)

In [None]:
class DecodePSPFeature(nn.Module):
    """Main PSPNet head: 3x3 conv-BN-ReLU to 512 channels, Dropout2d, 1x1 classifier, upsample.

    Args:
        height, width: output spatial size of the per-class logits.
        n_classes: number of output classes.
        in_channels: channels of the incoming feature map (default 4096, i.e. the
            2048-channel encoder output concatenated with four 512-channel pyramid branches).
    """

    def __init__(self, height, width, n_classes, in_channels=4096):
        super(DecodePSPFeature, self).__init__()

        self.height = height
        self.width = width
        self.cbr = conv2DBatchNormRelu(in_channels=in_channels, out_channels=512, kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
        self.dropout = nn.Dropout2d(p=0.1)
        self.classification = nn.Conv2d(in_channels=512, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.cbr(x)
        # Apply the dropout layer that __init__ creates (the auxiliary head does the same);
        # it is a no-op in eval mode.
        x = self.dropout(x)
        x = self.classification(x)
        output = F.interpolate(x, size=(self.height, self.width), mode='bilinear', align_corners=True)

        return output
    
    
class AuxiliaryPSPlayers(nn.Module):
    """Auxiliary head: 3x3 conv-BN-ReLU to 256 channels, Dropout2d, 1x1 classifier, upsample.

    Args:
        in_channels: channels of the incoming feature map.
        height, width: output spatial size of the per-class logits.
        n_classes: number of output classes.
    """

    def __init__(self, in_channels, height, width, n_classes):
        super().__init__()

        self.height = height
        self.width = width
        self.cbr = conv2DBatchNormRelu(in_channels=in_channels, out_channels=256, kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
        self.dropout = nn.Dropout2d(p=0.1)
        self.classification = nn.Conv2d(in_channels=256, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        logits = self.classification(self.dropout(self.cbr(x)))
        return F.interpolate(logits, size=(self.height, self.width), mode='bilinear', align_corners=True)

In [None]:
# Build a 21-class (PASCAL VOC) model and display its module tree.
net = PSPNet_ResNet50(21)
net

In [None]:
data_dir = '/home/dotdash/data/'
weights_dir = '/home/dotdash/weights/'
target_path = os.path.join(data_dir, 'VOCtrainval_11-May-2012.tar')
url = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar'

# makedirs also creates missing parent directories and is a no-op if the directory exists.
os.makedirs(data_dir, exist_ok=True)
os.makedirs(weights_dir, exist_ok=True)

if not os.path.exists(target_path):
    # Download under a temporary name and rename only on success, so an interrupted
    # download never leaves a truncated file at target_path that later runs would
    # mistake for a complete archive.
    partial_path = target_path + '.part'
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, target_path)

# Extract whenever the VOCdevkit folder is absent, not only right after a fresh
# download, so an archive that was downloaded but never extracted is still unpacked.
if not os.path.isdir(os.path.join(data_dir, 'VOCdevkit')):
    # NOTE(review): extractall trusts member paths in the archive; on Python >= 3.12
    # consider passing filter='data'.
    with tarfile.open(target_path) as tar:
        tar.extractall(data_dir)

In [None]:
rootpath = '/home/dotdash/data/VOCdevkit/VOC2012/'
train_img_list, train_anno_list, val_img_list, val_anno_list = make_datapath_list(rootpath=rootpath)
# ImageNet RGB mean/std (the standard torchvision normalization constants).
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)
train_dataset = VOCDataset(
    train_img_list,
    train_anno_list,
    phase='train',
    transform=DataTransform(
        input_size=475,
        color_mean=color_mean,
        color_std=color_std
    )
)
val_dataset = VOCDataset(
    val_img_list,
    val_anno_list,
    phase='val',
    transform=DataTransform(
        input_size=475,
        color_mean=color_mean,
        color_std=color_std
    )
)
# batch_size must be at least 2. train_model skips training batches that contain a
# single image (the BatchNorm after the 1x1 pyramid-pooling branch needs more than
# one value per channel in training mode), so batch_size = 1 would skip every
# training batch and the model would never be updated.
batch_size = 2
train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
dataloaders_dict = {'train': train_dataloader, 'val': val_dataloader}

In [None]:
# Build the 150-class (ADE20K) architecture so the pretrained checkpoint's head shapes match.
net = PSPNet_ResNet50(n_classes=150)
# map_location='cpu' lets the checkpoint load on machines without the GPU it was saved from.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load('/home/dotdash/weights/pspnet50_ADE20K.pth', map_location='cpu')
# Actually copy the pretrained weights into the model (strict: every key must match).
net.load_state_dict(state_dict)

In [None]:
n_classes = 21
# Replace both 150-class classifiers with fresh n_classes ones. Each new head takes
# its in_channels from the layer it replaces (512 for the decoder, 256 for the aux
# branch), so it matches the feature map that layer receives.
net.decode_feature.classification = nn.Conv2d(
    in_channels=net.decode_feature.classification.in_channels,
    out_channels=n_classes, kernel_size=1, stride=1, padding=0)
net.aux.classification = nn.Conv2d(
    in_channels=net.aux.classification.in_channels,
    out_channels=n_classes, kernel_size=1, stride=1, padding=0)

def weights_init(m):
    """Xavier-normal init for Conv2d weights and zero biases; other modules are left untouched."""
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

net.decode_feature.classification.apply(weights_init)
net.aux.classification.apply(weights_init)

In [None]:
class PSPLoss(nn.Module):
    """Cross-entropy on the main output plus ``aux_weight`` times cross-entropy on the aux output.

    ``outputs`` is indexed as ``(main_logits, aux_logits)``; both are scored against
    the same integer ``targets`` with mean reduction.
    """

    def __init__(self, aux_weight=0.4):
        super().__init__()
        self.aux_weight = aux_weight

    def forward(self, outputs, targets):
        main_logits, aux_logits = outputs[0], outputs[1]
        main_loss = F.cross_entropy(main_logits, targets, reduction='mean')
        aux_loss = F.cross_entropy(aux_logits, targets, reduction='mean')
        return main_loss + self.aux_weight * aux_loss

In [None]:
# Relative weight of the auxiliary-branch loss.
AUX_LOSS_WEIGHT = 0.4
criterion = PSPLoss(aux_weight=AUX_LOSS_WEIGHT)

In [None]:
# Pretrained backbone + pyramid pooling get a 10x smaller learning rate than the
# freshly re-initialized heads. Group order matches the module order listed here.
BACKBONE_LR = 1e-3
HEAD_LR = 1e-2
_backbone_modules = [
    net.feature_conv,
    net.feature_res_1,
    net.feature_res_2,
    net.feature_dilated_res_1,
    net.feature_dilated_res_2,
    net.pyramid_pooling,
]
_head_modules = [net.decode_feature, net.aux]
optimizer = optim.SGD(
    [{'params': module.parameters(), 'lr': BACKBONE_LR} for module in _backbone_modules]
    + [{'params': module.parameters(), 'lr': HEAD_LR} for module in _head_modules],
    momentum=0.9,
    weight_decay=0.0001,
)

In [None]:
def lambda_epoch(epoch, max_epoch=30):
    """Poly learning-rate multiplier ``(1 - epoch / max_epoch) ** 0.9`` for LambdaLR.

    Args:
        epoch: scheduler step count (0 at construction).
        max_epoch: epoch at which the multiplier reaches 0 (default 30).

    Returns:
        1.0 at epoch 0, decaying to 0.0 at ``max_epoch``. Epochs beyond
        ``max_epoch`` are clamped to 0.0 instead of raising; ``math.pow`` of a
        negative base with a fractional exponent raises ValueError.
    """
    progress = min(epoch / max_epoch, 1.0)
    return math.pow(1 - progress, 0.9)

In [None]:
# Each param group's lr = its base lr * lambda_epoch(step count).
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda_epoch)

In [None]:
def train_model(net, dataloaders_dict, criterion, scheduler, optimizer, num_epochs,
                batch_multiplier=3, weights_dir='weights', log_path='log_output.csv'):
    """Train ``net`` with gradient accumulation, validating every 5th epoch.

    Gradients from ``batch_multiplier`` consecutive training mini-batches are
    accumulated before each ``optimizer.step()``. Any leftover accumulation at the
    end of an epoch is applied too, rather than discarded. ``scheduler.step()`` is
    called once per epoch, after that epoch's optimizer updates.

    Args:
        net: model whose outputs are accepted by ``criterion``.
        dataloaders_dict: ``{'train': DataLoader, 'val': DataLoader}``.
        criterion: callable ``(outputs, targets) -> scalar loss tensor``.
        scheduler: learning-rate scheduler, stepped once per epoch.
        optimizer: optimizer over ``net``'s parameters.
        num_epochs: number of epochs to run.
        batch_multiplier: mini-batches accumulated per optimizer step.
        weights_dir: directory for the final checkpoint (created if missing).
        log_path: CSV file rewritten after every epoch with the per-epoch losses.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    torch.backends.cudnn.benchmark = True

    num_train_imgs = len(dataloaders_dict['train'].dataset)
    num_val_imgs = len(dataloaders_dict['val'].dataset)
    batch_size = dataloaders_dict['train'].batch_size

    iteration = 1
    logs = []

    for epoch in range(num_epochs):
        t_epoch_start = time.time()
        t_iter_start = time.time()
        epoch_train_loss = 0.0
        epoch_val_loss = 0.0
        print(f"Epoch {epoch+1}/{num_epochs}")

        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()
                optimizer.zero_grad()
                print('train')
            elif (epoch + 1) % 5 == 0:
                net.eval()
                print('val')
            else:
                continue

            # Training mini-batches whose gradients are accumulated but not yet applied.
            pending = 0
            for imges, anno_class_imges in dataloaders_dict[phase]:
                # In training mode, BatchNorm cannot normalize a single value per channel
                # (e.g. after the 1x1 pyramid-pooling branch), so skip singleton batches.
                # Eval mode uses running statistics, so validation keeps them.
                if phase == 'train' and imges.size()[0] == 1:
                    continue

                imges = imges.to(device)
                anno_class_imges = anno_class_imges.to(device)

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(imges)
                    # Divide so the accumulated gradient approximates one large batch.
                    loss = criterion(outputs, anno_class_imges.long()) / batch_multiplier

                if phase == 'train':
                    loss.backward()
                    pending += 1
                    if pending == batch_multiplier:
                        optimizer.step()
                        optimizer.zero_grad()
                        pending = 0

                    if iteration % 10 == 0:
                        duration = time.time() - t_iter_start
                        print(f"iteration {iteration} | Loss: {loss.item()/batch_size*batch_multiplier:.4f} | 10iter: {duration} sec.")
                        t_iter_start = time.time()

                    epoch_train_loss += loss.item() * batch_multiplier
                    iteration += 1
                else:
                    epoch_val_loss += loss.item() * batch_multiplier

            if phase == 'train':
                # Apply gradients from a final partial accumulation instead of dropping them.
                if pending:
                    optimizer.step()
                    optimizer.zero_grad()
                # Step the LR schedule after this epoch's optimizer updates (PyTorch's
                # required order); stepping first would skip the schedule's first value.
                scheduler.step()

        t_epoch_finish = time.time()
        print(f"epoch {epoch+1} | Epoch_Train_Loss: {epoch_train_loss/num_train_imgs:.4f} | Epoch_Val_Loss: {epoch_val_loss/num_val_imgs:.4f} | timer: {t_epoch_finish-t_epoch_start:.4f}")

        logs.append({'epoch': epoch+1, 'train_loss': epoch_train_loss/num_train_imgs, 'val_loss': epoch_val_loss/num_val_imgs})
        pd.DataFrame(logs).to_csv(log_path)

    # Nothing was trained (and `epoch` would be undefined) when num_epochs == 0.
    if num_epochs > 0:
        os.makedirs(weights_dir, exist_ok=True)
        torch.save(net.state_dict(), os.path.join(weights_dir, f"pspnet50_{num_epochs}.pth"))

In [None]:
# Only 2 epochs here: the local GPU is too slow for a longer run.
num_epochs = 2
train_model(net, dataloaders_dict, criterion, scheduler, optimizer, num_epochs)

In [None]:
import torch.onnx

In [None]:
model = PSPNet_ResNet50(n_classes=21)
# map_location='cpu' remaps tensors saved on any CUDA device (not only cuda:0) to the CPU.
# NOTE(review): this expects a 30-epoch checkpoint; train_model above saves
# weights/pspnet50_<num_epochs>.pth relative to the working directory — confirm the path.
state_dict = torch.load("/home/dotdash/weights/pspnet50_30.pth", map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
# Dummy (batch, channels, height, width) input used to trace the graph for ONNX export.
# Separate names avoid overwriting the dataloader's global `batch_size`.
export_shape = (1, 3, 475, 475)
x = torch.randn(*export_shape, requires_grad=True)
# Sanity-check forward pass; no_grad avoids building an autograd graph that is never used.
with torch.no_grad():
    torch_out = model(x)

In [None]:
torch.onnx.export(
    model,  # model being run
    x,  # dummy input used for tracing (or a tuple for multiple inputs)
    '/home/dotdash/weights/temp.onnx',  # where to save the model
    export_params=True,  # store the trained parameter weights inside the model file
    # Every F.interpolate in this model uses align_corners=True. ONNX Resize only
    # supports that from opset 11 (coordinate_transformation_mode="align_corners");
    # the opset-10 exporter cannot express it.
    opset_version=11,
    do_constant_folding=True,  # execute constant folding for optimization
    input_names=['modelInput'],
    # The model returns (output, output_aux); name both outputs explicitly.
    output_names=['modelOutput', 'auxOutput'],
    # variable-length batch axis on the input and on both outputs
    dynamic_axes={
        'modelInput': {0: 'batch_size'},
        'modelOutput': {0: 'batch_size'},
        'auxOutput': {0: 'batch_size'},
    },
)