In [0]:
# Reinstall pytorch-hrvvi-ext (presumably the package that provides `hutil`, imported below)
# from the head of its GitHub repository.
# NOTE(review): the install is not pinned to a tag or commit, so the installed version
# can differ between runs of this notebook.
!pip3 uninstall pytorch-hrvvi-ext -y
!pip3 install -U git+https://github.com/sbl1996/pytorch-hrvvi-ext.git


In [0]:
import sys
import os

import torch
import hutil
import matplotlib.pyplot as plt
print(hutil.__version__)

1.4.14


In [0]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules before
# each cell executes so that edits to library code are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [0]:
# Mount Google Drive inside the Colab VM and list the drive root as a sanity check.
gdrive = "/gdrive"
from google.colab import drive
drive.mount(gdrive, force_remount=True)
# Root of the user's drive inside the mount; `gpath` resolves relative paths against it.
mydrive = os.path.join(gdrive, "My Drive")
# The shell path below is hard-coded and has to be kept in sync with `mydrive` by hand.
!ls /gdrive/My\ Drive

def gpath(p):
    r"""
    Resolve ``p`` against the mounted drive root stored in ``mydrive``.
    """
    resolved = os.path.join(mydrive, p)
    return resolved

Mounted at /gdrive
'Colab Notebooks'   eng-fra.pt	 images   repo	   weixin.pkl
 datasets	    fonts	 models   result


In [0]:
import random
import copy
from toolz import curry

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader

from torchvision.models import resnet50
from torchvision.transforms import Resize, RandomCrop, Compose, RandomHorizontalFlip

from hutil import one_hot, cuda
from hutil.train import init_weights, Trainer
from hutil.ext.summary import summary
from hutil.data import train_test_split, Fullset
from hutil.train.metrics import Accuracy, TrainLoss, Loss
from hutil.datasets import VOCSegmentation
from hutil.transforms import Compose as HCompose
from hutil.transforms.segmentation import SameTransform, ToTensor
from hutil.inference import freeze


█

In [0]:

def last_conv(m):
    r"""
    Return the last submodule of ``m`` whose qualified name contains ``'conv'``.

    "Last" is taken in ``m.named_modules()`` traversal order. Picking the
    lexicographically largest name instead would rank ``'9.conv3'`` above
    ``'10.conv3'`` and return the wrong module for stages with ten or more
    blocks.

    The match is on the name only, so the returned module is whatever is
    registered under that name.

    Raises:
        ValueError: if no submodule name contains ``'conv'``.
    """
    found = None
    for name, module in m.named_modules():
        if 'conv' in name:
            found = module
    if found is None:
        raise ValueError("no submodule with 'conv' in its name was found")
    return found

def get_out_channels(m):
    r"""
    Output channel count of the module selected by ``last_conv(m)``.
    """
    conv = last_conv(m)
    return conv.out_channels

def replace_first_conv(m, in_channels):
    r"""
    Swap ``m[0].conv1`` for a new conv of the same class that accepts
    ``in_channels`` input channels.

    Output channels, kernel size, stride, padding, dilation and the presence
    of a bias are copied from the old conv; its weights are not. ``m`` is
    modified in place and returned.
    """
    old = m[0].conv1
    has_bias = old.bias is not None
    settings = dict(
        kernel_size=old.kernel_size,
        stride=old.stride,
        padding=old.padding,
        dilation=old.dilation,
        bias=has_bias,
    )
    m[0].conv1 = type(old)(in_channels, old.out_channels, **settings)
    return m


@curry
def conv_to_atrous(m, rate):
    r"""
    Turn a non-1x1 ``Conv2d`` into an atrous (dilated) convolution, in place.

    Modules whose class name does not contain ``'Conv2d'`` and 1x1
    convolutions are returned untouched. For the rest, stride is forced to
    ``(1, 1)``, dilation is set to ``(rate, rate)`` and the padding per axis
    becomes ``int(((k - 1) * (rate - 1) + k - 1) / 2)`` for kernel size ``k``.
    """
    is_conv2d = 'Conv2d' in type(m).__name__
    if not is_conv2d or m.kernel_size == (1, 1):
        return m

    def dilated_padding(k):
        # Half of the extent added by a kernel of size k dilated by `rate`.
        return int(((k - 1) * (rate - 1) + k - 1) / 2)

    kernel_h, kernel_w = m.kernel_size
    m.padding = (dilated_padding(kernel_h), dilated_padding(kernel_w))
    m.stride = (1, 1)
    m.dilation = (rate, rate)
    return m


class DeepLabV3(nn.Module):
    r"""
    Segmentation network on a ResNet-style backbone with dilated late stages.

    The stem and ``layer1``..``layer4`` are the backbone's own modules (shared,
    not copied), so the stride/dilation changes made here also modify
    ``backbone``. ``layer5`` is a freshly initialised deep copy of ``layer4``.
    The head is a single 1x1 convolution (no ASPP module is built here), and
    the logits are bilinearly upsampled to the input's spatial size.

    Args:
        backbone: module exposing ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
            ``layer1``..``layer4``; ``layer4[0]`` must have ``conv1`` and
            ``downsample[0]``.
        num_classes: number of output channels of the classifier.
        multi_grid: per-block multipliers of the dilation rate, indexed by the
            block's position inside a stage; needs at least as many entries as
            ``layer4`` has blocks.
    """

    def __init__(self, backbone, num_classes, multi_grid=(1,1,1)):
        super().__init__()
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool

        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4
        # Drop the downsampling at the entry of layer4; dilation (set below)
        # takes its place.
        self.layer4[0].conv1.stride = (1,1)
        self.layer4[0].downsample[0].stride = (1,1)
        # layer5 is cloned from layer4. Its first block consumes layer4's
        # output, so conv1 is widened and the projection shortcut is removed.
        self.layer5 = copy.deepcopy(self.layer4)
        replace_first_conv(self.layer5, get_out_channels(self.layer4))
        self.layer5[0].downsample = None
#         self.layer6 = copy.deepcopy(self.layer5)
#         self.layer7 = copy.deepcopy(self.layer6)

        self.fc = nn.Conv2d(get_out_channels(self.layer5), num_classes, kernel_size=1)

        # Dilation rate doubles per stage: 2 for layer4, 4 for layer5, each
        # scaled by the multi_grid entry of the block.
        r = 1
        for l in [self.layer4, self.layer5]:
            r *= 2
            for i, m in enumerate(l):
                m.apply(conv_to_atrous(rate=r*multi_grid[i]))

        # Only the new stage is (re)initialised. layer4 is shared with the
        # backbone and has to keep the weights it was given; initialising it
        # here would discard them.
        self.layer5.apply(init_weights(nonlinearity='relu'))

    def base_parameters(self):
        r"""
        Yield the parameters of ``layer5``.
        """
        for l in [self.layer5]:
            yield from l.parameters()

    def forward(self, x):
        r"""
        Run the network and return logits resized to the spatial size of ``x``.
        """
        size = x.size()[2:]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
#         x = self.layer6(x)
#         x = self.layer7(x)
        x = self.fc(x)

        # Upsample the coarse logits back to the input resolution.
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=True)

        return x


In [0]:

# Input resolution and number of foreground classes.
# NOTE(review): WIDTH is not used by the transforms below; only HEIGHT is passed to
# Resize and RandomCrop.
WIDTH = 320
HEIGHT = 320
NUM_CLASSES = 20

# Geometric augmentation: resize, random crop, random horizontal flip.
img_transform = Compose([
    Resize(HEIGHT),
    RandomCrop(HEIGHT),
    RandomHorizontalFlip(),
])

# Wrap the augmentation with SameTransform (presumably so image and mask receive the
# same random parameters - confirm in hutil), then convert to tensors.
transforms = HCompose([
    SameTransform(img_transform),
    ToTensor(NUM_CLASSES),
])

data_home = "./VOC"
# ds1 = VOCDetection(data_home, year='2007', image_set='trainval', download=True)
# ds2 = VOCDetection(data_home, year='2012', image_set='trainval', download=True)
# ds = ConcatDataset([ds1, ds2])
ds = VOCSegmentation(data_home, year='2012', image_set='trainaug', download=True)
# rest, ds = train_test_split(
#     ds, test_ratio=0.001
# )
# ds_train = Fullset(ds, transforms)
# ds_val = Fullset(ds, transforms)

# 0.05 is passed positionally; presumably the held-out fraction (cf. `test_ratio` in
# the commented-out call above).
# NOTE(review): test_transform reuses the random crop/flip pipeline, so validation
# samples are randomly augmented as well - confirm this is intended.
ds_train, ds_val = train_test_split(
    ds, 0.05, transform=transforms,
    test_transform=transforms)


VOC found. Skip download or extract
SBT found. Skip download or extract


In [0]:
@curry
def seg_loss(x, y, p=0.1):
    r"""
    Cross-entropy between ``x`` and ``y`` with target value 255 ignored.

    With probability ``p`` the scalar loss is also printed. The loss tensor is
    returned either way.
    """
    value = F.cross_entropy(x, y, ignore_index=255)
    should_print = random.random() < p
    if should_print:
        print(value.item())
    return value


# ResNet-50 backbone; True is passed positionally (torchvision's `pretrained` flag).
backbone = resnet50(True)
# DeepLabV3 never reads backbone.fc, so the classification head is dropped.
del backbone.fc

# One extra output channel on top of NUM_CLASSES (presumably background - confirm).
net = DeepLabV3(backbone, NUM_CLASSES + 1)
# net.conv1 .. net.layer4 are the same module objects as the backbone's, so whatever
# freeze() does to `backbone` applies to them too; net.layer5 (a deep copy) and net.fc
# (new) are not part of `backbone`.
freeze(backbone)

criterion = seg_loss(p=0.1)
# Two parameter groups: layer5 (via base_parameters) at the default lr of 0.001 and
# the classifier at lr 0.01.
# NOTE(review): dampening=0.9 scales each gradient's contribution to the momentum
# buffer by 0.1 - confirm this is intended.
optimizer = SGD([
    {"params": filter(lambda x: x.requires_grad, net.base_parameters())},
    {"params": net.fc.parameters(), "lr": 0.01}],
    lr=0.001, momentum=0.9, dampening=0.9, weight_decay=5e-4)
# optimizer = Adam(filter(lambda x: x.requires_grad,
#                         net.parameters()), lr=1e-3, weight_decay=1e-4)

# Multiply every group's lr by 0.1 at scheduler steps 200, 250 and 300.
lr_scheduler = MultiStepLR(optimizer, [200, 250, 300], gamma=0.1)

metrics = {
    'loss': TrainLoss(),
    'acc': Accuracy(),
}

# The trainer is given the Drive "models" folder as save_path.
trainer = Trainer(net, criterion, optimizer, lr_scheduler,
                  metrics=metrics, save_path=gpath("models"), name="DeepLabV3-VOC")


In [0]:
# Print a per-layer table of output shapes and parameter counts for a
# 3 x HEIGHT x WIDTH input.
summary(net, (3, HEIGHT, WIDTH))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 160, 160]           9,408
       BatchNorm2d-2         [-1, 64, 160, 160]             128
              ReLU-3         [-1, 64, 160, 160]               0
         MaxPool2d-4           [-1, 64, 80, 80]               0
            Conv2d-5           [-1, 64, 80, 80]           4,096
       BatchNorm2d-6           [-1, 64, 80, 80]             128
              ReLU-7           [-1, 64, 80, 80]               0
            Conv2d-8           [-1, 64, 80, 80]          36,864
       BatchNorm2d-9           [-1, 64, 80, 80]             128
             ReLU-10           [-1, 64, 80, 80]               0
           Conv2d-11          [-1, 256, 80, 80]          16,384
      BatchNorm2d-12          [-1, 256, 80, 80]             512
           Conv2d-13          [-1, 256, 80, 80]          16,384
      BatchNorm2d-14          [-1, 256,

In [0]:
# Restore trainer state from a checkpoint stored on Drive.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted
# source; on a CPU-only runtime a `map_location` argument may be needed.
trainer.load_state_dict(torch.load(gpath("models/DeepLabV3-VOC_trainer_20.pth")))

In [0]:
# Intended to lower how often seg_loss prints the loss.
# NOTE(review): `criterion` is the curried object returned by seg_loss(p=0.1); setting
# an attribute on it most likely does not change the bound keyword `p`, so the print
# rate may stay at 0.1 - verify.
criterion.p = 0.01

In [0]:

# Training batches of 16 are reshuffled every epoch and loaded by one worker process
# into pinned memory; the validation loader uses batches of 32 with DataLoader defaults
# (no shuffling, loading in the main process).
train_loader = DataLoader(
    ds_train, batch_size=16, shuffle=True, num_workers=1, pin_memory=True)
val_loader = DataLoader(ds_val, batch_size=32)


In [0]:
hist = trainer.fit(train_loader, 10)
plot_history(hist)

In [0]:
def plot_history(hist):
    for k, v in hist.items():
        fig, ax = plt.subplots()
        ax.plot(v)
        ax.set_title(k)

In [0]:
# Run the trainer's evaluation over the validation loader and report how long it took.
%time trainer.evaluate(val_loader)

In [0]:

def extract_class(img, class_, background_class=0):
    r"""
    Return a copy of ``img`` in which every value other than ``class_`` and
    ``background_class`` is replaced by ``background_class``.
    """
    not_background = img != background_class
    not_target = img != class_
    return img.masked_fill(not_background & not_target, background_class)


def mean_iou(pred, gt, num_classes, verbose=True):
    r"""
    Average the per-class IoU over classes ``0 .. num_classes - 1``.

    Classes for which ``class_seg_iou`` returns ``None`` are left out of the
    average.

    Args:
        pred: predicted label tensor.
        gt: ground-truth label tensor.
        num_classes: number of class indices to evaluate.
        verbose: when true, print each evaluated class index and its IoU.

    Returns:
        The mean IoU as a float, or ``nan`` when no class contributed (the
        mean is undefined then; previously this raised ZeroDivisionError).
    """
    ious = []
    for c in range(num_classes):
        class_iou = class_seg_iou(pred, gt, c)
        if class_iou is None:
            continue
        if verbose:
            print(c)
            print(class_iou)
        ious.append(class_iou)
    if not ious:
        return float('nan')
    return sum(ious) / len(ious)


def class_seg_iou(pred, gt, class_, ignore_index=255):
    r"""
    Intersection over union of one class between ``pred`` and ``gt``.

    Args:
        pred: predicted label tensor.
        gt: ground-truth label tensor, same shape as ``pred``.
        class_: class index to score.
        ignore_index: ground-truth value excluded from the false-positive
            count, matching the ``ignore_index=255`` used by the training
            loss. Pass ``None`` to count every pixel.

    Returns:
        ``tp / (tp + fn + fp)`` as a float, or ``None`` when ``class_`` does
        not occur in ``gt``.
    """
    gt_is_class = gt == class_
    pred_is_class = pred == class_
    tp = (gt_is_class & pred_is_class).sum().item()
    fn = (gt_is_class & (pred != class_)).sum().item()
    false_pos = (gt != class_) & pred_is_class
    if ignore_index is not None:
        # A prediction on an ignored pixel is neither right nor wrong.
        false_pos = false_pos & (gt != ignore_index)
    fp = false_pos.sum().item()
    if tp + fn == 0:
        return None
    return tp / (tp + fn + fp)
