In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# class DoubleConv(nn.Module):
#     """(convolution => [BN] => ReLU) * 2"""

#     def __init__(self, in_channels, out_channels, mid_channels=None):
#         super().__init__()
#         if not mid_channels:
#             mid_channels = out_channels
#         self.double_conv = nn.Sequential(
#             nn.Conv2d(in_channels, mid_channels, kernel_size=2, padding=1),
#             nn.BatchNorm2d(mid_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(mid_channels, out_channels, kernel_size=2, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x):
#         return self.double_conv(x)

# class Up(nn.Module):
#     """Upscaling then double conv"""

#     def __init__(self, in_channels, out_channels, bilinear=True):
#         super().__init__()

#         # if bilinear, use the normal convolutions to reduce the number of channels
#         if bilinear:
#             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
#             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
#         else:
#             self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
#             self.conv = DoubleConv(in_channels, out_channels)


#     def forward(self, x1, x2):
#         x1 = self.up(x1)
#         # input is CHW
#         diffY = x2.size()[2] - x1.size()[2]
#         diffX = x2.size()[3] - x1.size()[3]

#         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
#                         diffY // 2, diffY - diffY // 2])
#         # if you have padding issues, see
#         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
#         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
#         x = torch.cat([x2, x1], dim=1)
#         return self.conv(x)

# Source
# https://github.com/nyoki-mtl/pytorch-segmentation/blob/master/src/models/scse.py
class SCSEBlock(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (scSE) block.

    The output is the sum of two gated copies of the input ``x``:

    * channel SE: global average pool -> Linear(channel, channel // reduction)
      -> ReLU -> Linear(channel // reduction, channel) -> sigmoid, used as a
      per-channel gate on ``x``;
    * spatial SE: 1x1 conv to a single map -> sigmoid, used as a per-pixel
      gate on ``x``.

    Args:
        channel: number of input (and output) channels.
        reduction: divisor for the hidden width of the channel-excitation MLP.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.channel_excitation = nn.Sequential(nn.Linear(channel, int(channel // reduction)),
                                                nn.ReLU(inplace=True),
                                                nn.Linear(int(channel // reduction), channel))
        self.spatial_se = nn.Conv2d(channel, 1, kernel_size=1,
                                    stride=1, padding=0, bias=False)

    def forward(self, x):
        """Return channel-gated ``x`` plus spatially-gated ``x`` (same shape as ``x``)."""
        # Unpacking size() requires a 4-D (N, C, H, W) input.
        batch, channels, _, _ = x.size()

        # Channel attention: pool to (N, C), excite, then reshape to
        # (N, C, 1, 1) so the gate broadcasts over H and W.
        chn_se = self.avg_pool(x).view(batch, channels)
        chn_se = torch.sigmoid(self.channel_excitation(chn_se).view(batch, channels, 1, 1))
        chn_se = x * chn_se

        # Spatial attention: a single-channel gate broadcast over all channels.
        spa_se = torch.sigmoid(self.spatial_se(x))
        spa_se = x * spa_se

        # The original `torch.add(chn_se, 1, spa_se)` used the deprecated
        # (input, value, other) overload; plain addition is equivalent.
        return chn_se + spa_se


  


# New section

In [None]:
from torch import nn
from torch.nn import functional as F
import torch
from torchvision import models
import torchvision

# Network input size as (width, height): evaluate_img resizes each image to it
# and evaluate_img_patch uses it as the sliding-window size.
input_size = (448, 448)

# Module-level switch read by UNetResNet: True selects the 1x1-conv +
# MyDecoderBlockV2 decoder, False the DecoderBlockV2-based decoder.
custom_decoder = True

class Interpolate(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Lets resampling be placed inside an ``nn.Sequential``. As with
    ``F.interpolate``, give either ``size`` or ``scale_factor``.

    Args:
        size: target spatial size.
        scale_factor: spatial multiplier.
        mode: interpolation algorithm (e.g. 'nearest', 'bilinear').
        align_corners: only forwarded for the linear family of modes
            ('linear', 'bilinear', 'bicubic', 'trilinear'); for any other mode
            ``None`` is passed, because ``F.interpolate`` raises ValueError
            when align_corners is set for them.
    """

    # Modes for which F.interpolate accepts an explicit align_corners value.
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        self.scale_factor = scale_factor
        self.align_corners = align_corners

    def forward(self, x):
        # With the default mode='nearest' the original forwarded
        # align_corners=False, which F.interpolate rejects.
        align_corners = self.align_corners if self.mode in self._ALIGN_CORNERS_MODES else None
        return self.interp(x, size=self.size, scale_factor=self.scale_factor,
                           mode=self.mode, align_corners=align_corners)

def conv3x3(in_, out):
    """Build a 3x3 ``nn.Conv2d`` from ``in_`` to ``out`` channels (stride 1, padding 1, with bias).

    A padding of 1 keeps the spatial size unchanged for a 3x3 kernel.
    """
    layer = nn.Conv2d(in_, out, kernel_size=3, padding=1)
    return layer


class ConvRelu(nn.Module):
    """``conv3x3`` followed by an in-place ReLU.

    Args:
        in_: number of input channels.
        out: number of output channels.
    """

    def __init__(self, in_, out):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        # Convolve, then rectify the convolution output in place.
        return self.activation(self.conv(x))


class DecoderBlockV2(nn.Module):
    """Decoder stage that doubles the spatial resolution.

    Two variants, selected by ``is_deconv``:

    * ``True``: ConvRelu(in -> middle), then a ``ConvTranspose2d``
      (middle -> out, kernel 4, stride 2, padding 1) and a ReLU. These
      deconvolution parameters were chosen to avoid artifacts, following
      https://distill.pub/2016/deconv-checkerboard/
    * ``False``: bilinear x2 upsampling, then ConvRelu(in -> middle) and
      ConvRelu(middle -> out).

    Args:
        in_channels: channels of the incoming tensor.
        middle_channels: channels after the first convolution.
        out_channels: channels produced by the block.
        is_deconv: choose the transposed-convolution variant.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super(DecoderBlockV2, self).__init__()
        self.in_channels = in_channels

        if not is_deconv:
            layers = [
                Interpolate(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        else:
            layers = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels,
                                   kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ]
        # A single Sequential named `block`, so state_dict keys (block.0, block.1, ...)
        # are the same as before.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class MyDecoderBlockV2(nn.Module):
    """Custom decoder stage: ReLU -> BatchNorm -> scSE -> 2x2 transposed conv.

    Used for all the other decoder stages of the custom decoder. The final
    ``ConvTranspose2d`` (kernel 2, stride 2) doubles height and width and maps
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the block.
        is_deconv: accepted for signature parity with ``DecoderBlockV2``; not used.
    """

    def __init__(self, in_channels, out_channels, is_deconv=True):
        super(MyDecoderBlockV2, self).__init__()
        self.in_channels = in_channels

        activation = nn.ReLU()
        norm = nn.BatchNorm2d(in_channels)
        attention = SCSEBlock(in_channels)
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.block = nn.Sequential(activation, norm, attention, upsample)

    def forward(self, x):
        return self.block(x)

class UNet16(nn.Module):
    """U-Net whose encoder reuses the conv layers of torchvision's VGG16 ``features``.

    Each decoder stage is a ``DecoderBlockV2`` fed with the previous stage's
    output concatenated (on channels) with the matching encoder output. The
    decoder ends with a ``ConvRelu`` and a 1x1 classifier.
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
        """
        :param num_classes: number of output channels; if > 1, ``forward``
            returns log-softmax over channels, otherwise raw logits.
        :param num_filters: base width of the decoder.
        :param pretrained:
            False - no pre-trained network used
            True - encoder pre-trained with VGG16
        :is_deconv:
            False: bilinear interpolation is used in decoder
            True: deconvolution is used in decoder
        """
        super().__init__()
        self.num_classes = num_classes

        self.pool = nn.MaxPool2d(2, 2)

        print("creating encoder")
        self.encoder = torchvision.models.vgg16(pretrained=pretrained).features

        print("creating relu")
        self.relu = nn.ReLU(inplace=True)

        def stage(*conv_indices):
            # Interleave the chosen VGG conv layers with the shared ReLU:
            # [conv, relu, conv, relu, ...].
            modules = []
            for idx in conv_indices:
                modules.append(self.encoder[idx])
                modules.append(self.relu)
            return nn.Sequential(*modules)

        self.conv1 = stage(0, 2)
        self.conv2 = stage(5, 7)
        self.conv3 = stage(10, 12, 14)
        self.conv4 = stage(17, 19, 21)
        self.conv5 = stage(24, 26, 28)

        print("creating center")
        wide = num_filters * 8
        self.center = DecoderBlockV2(512, wide * 2, wide, is_deconv)

        print("creating decoder")
        self.dec5 = DecoderBlockV2(512 + wide, wide * 2, wide, is_deconv)
        self.dec4 = DecoderBlockV2(512 + wide, wide * 2, wide, is_deconv)
        self.dec3 = DecoderBlockV2(256 + wide, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV2(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec1 = ConvRelu(64 + num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: every stage after the first runs on a 2x max-pooled input.
        skips = [self.conv1(x)]
        for encoder_stage in (self.conv2, self.conv3, self.conv4, self.conv5):
            skips.append(encoder_stage(self.pool(skips[-1])))
        conv1, conv2, conv3, conv4, conv5 = skips

        out = self.center(self.pool(conv5))

        # Decoder: concatenate with the matching encoder output along channels.
        for decoder, skip in ((self.dec5, conv5), (self.dec4, conv4), (self.dec3, conv3),
                              (self.dec2, conv2), (self.dec1, conv1)):
            out = decoder(torch.cat([out, skip], 1))

        logits = self.final(out)
        if self.num_classes > 1:
            return F.log_softmax(logits, dim=1)
        return logits

class UNetResNet(nn.Module):
    """PyTorch U-Net model using ResNet(34, 101 or 152) encoder.
    UNet: https://arxiv.org/abs/1505.04597
    ResNet: https://arxiv.org/abs/1512.03385
    Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/

    The decoder is chosen by the module-level ``custom_decoder`` flag at
    construction time; the choice is stored in ``self.custom_decoder``:

    * custom: 1x1 convs project the stem and layer1..layer3 outputs to 128
      channels and layer4 to 512; a 2x2 transposed conv plus four
      ``MyDecoderBlockV2`` stages (each doubling H and W) upsample with skip
      concatenation. Returns the raw ``dec0`` output with ``num_classes``
      channels; ``num_filters``, ``dropout_2d`` and ``is_deconv`` are unused.
    * default: ``DecoderBlockV2`` stages, ``ConvRelu``, 2-D dropout (active
      only in training mode) and a 1x1 classifier.

    Args:
            encoder_depth (int): Depth of a ResNet encoder (34, 101 or 152).
            num_classes (int): Number of output classes.
            num_filters (int, optional): Number of filters in the last layer of decoder. Defaults to 32.
            dropout_2d (float, optional): Probability factor of dropout layer before output layer. Defaults to 0.2.
            pretrained (bool, optional):
                False - no pre-trained weights are being used.
                True  - ResNet encoder is pre-trained on ImageNet.
                Defaults to False.
            is_deconv (bool, optional):
                False: bilinear interpolation is used in decoder.
                True: deconvolution is used in decoder.
                Defaults to False.
    Raises:
            NotImplementedError: for an unsupported ``encoder_depth``.
    """

    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        # Snapshot the global switch. The original forward() re-read the global,
        # so flipping it after construction would run a decoder that was never built.
        self.custom_decoder = custom_decoder

        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')

        self.pool = nn.MaxPool2d(2, 2)

        self.relu = nn.ReLU(inplace=True)

        if self.custom_decoder:
            # 1x1 projections applied before concatenation. Input widths for
            # layer1..layer4 are bottom/8, /4, /2, /1, the same assumption the
            # default decoder below makes. For ResNet34 these equal the previously
            # hard-coded 64/128/256/512, so existing checkpoints still load. With
            # ResNet101/152 the hard-coded widths caused a channel mismatch.
            self.do_conv_1 = nn.Conv2d(64, 128, kernel_size=1, stride=1)
            self.do_conv_2 = nn.Conv2d(bottom_channel_nr // 8, 128, kernel_size=1, stride=1)
            self.do_conv_3 = nn.Conv2d(bottom_channel_nr // 4, 128, kernel_size=1, stride=1)
            self.do_conv_4 = nn.Conv2d(bottom_channel_nr // 2, 128, kernel_size=1, stride=1)
            self.do_conv_5 = nn.Conv2d(bottom_channel_nr, 512, kernel_size=1, stride=1)

            # Stem without the max-pool: its un-pooled output is used as a skip input.
            self.conv1 = nn.Sequential(self.encoder.conv1,
                                       self.encoder.bn1,
                                       self.encoder.relu)
        else:
            self.conv1 = nn.Sequential(self.encoder.conv1,
                                       self.encoder.bn1,
                                       self.encoder.relu,
                                       self.pool)
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        if self.custom_decoder:
            self.dec4 = nn.ConvTranspose2d(512, 128, kernel_size=2, stride=2)
            self.dec3 = MyDecoderBlockV2(256, 128)
            self.dec2 = MyDecoderBlockV2(256, 128)
            self.dec1 = MyDecoderBlockV2(256, 128)
            # Emit num_classes maps (was hard-coded to 1; identical when num_classes == 1).
            self.dec0 = MyDecoderBlockV2(256, num_classes)
        else:
            self.center = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8,
                                       is_deconv)
            self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2,
                                       is_deconv)
            self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                       is_deconv)
            self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
            self.dec0 = ConvRelu(num_filters, num_filters)
            self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        if self.custom_decoder:
            return self._forward_custom(x)
        return self._forward_default(x)

    def _forward_custom(self, x):
        """Encoder plus custom decoder; returns the raw output of ``dec0``."""
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        # Apply 1x1 convs before concatenating.
        conv1 = self.do_conv_1(conv1)
        conv2 = self.do_conv_2(conv2)
        conv3 = self.do_conv_3(conv3)
        conv4 = self.do_conv_4(conv4)
        conv5 = self.do_conv_5(conv5)

        # Upsampling with skip concatenation along channels.
        dec4 = self.dec4(conv5)
        dec3 = self.dec3(torch.cat([dec4, conv4], 1))
        dec2 = self.dec2(torch.cat([dec3, conv3], 1))
        dec1 = self.dec1(torch.cat([dec2, conv2], 1))
        return self.dec0(torch.cat([dec1, conv1], 1))

    def _forward_default(self, x):
        """Encoder plus DecoderBlockV2 decoder; returns logits from ``final``."""
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        center = self.center(self.pool(conv5))

        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        # F.dropout2d defaults to training=True, so the original also dropped
        # activations in eval mode (stochastic inference). Tie it to the module mode.
        return self.final(F.dropout2d(dec0, p=self.dropout_2d, training=self.training))


In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive. NOTE(review): the relative
# 'drive/MyDrive/...' paths used later presumably rely on the working
# directory being /content (the Colab default) — confirm.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import json
from datetime import datetime
from pathlib import Path

import random
import numpy as np

import torch
import tqdm


class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def map_location():
    """Pick the device for loading checkpoints and print which one was chosen.

    Returns:
        ``torch.device("cuda:0")`` when CUDA is available, otherwise ``torch.device("cpu")``.
    """
    if not torch.cuda.is_available():
        print("cuda is unavailable, using cpu")
        return torch.device("cpu")
    print("cuda is available, using gpu")
    return torch.device("cuda:0")
    

def cuda(x):
    """Return ``x`` moved to the current CUDA device (non-blocking), or ``x`` itself when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return x
    return x.cuda(non_blocking=True)

def write_event(log, step, **data):
    """Append one JSON line to ``log`` and flush it.

    The record holds the keyword fields plus ``step`` and ``dt`` (the current
    local time in ISO format), serialised with sorted keys.
    """
    data['step'] = step
    data['dt'] = datetime.now().isoformat()
    record = json.dumps(data, sort_keys=True)
    for chunk in (record, '\n'):
        log.write(chunk)
    log.flush()

def check_crop_size(image_height, image_width):
    """Tell whether a crop size is compatible with 32x downsampling.

    Args:
        image_height: crop height in pixels.
        image_width: crop width in pixels.
    Returns:
        True when both dimensions are multiples of 32, False otherwise.
    """
    return all(dim % 32 == 0 for dim in (image_height, image_width))

def create_model(device, type ='vgg16'):
    """Build a pretrained U-Net, switch it to eval mode and move it to ``device``.

    Args:
        device: target passed to ``model.to``.
        type: 'vgg16' -> UNet16; 'resnet101' -> UNetResNet (depth 101, 1 class).
    Raises:
        AssertionError: for any other ``type``.
    """
    assert type in ('vgg16', 'resnet101')
    builders = {
        'vgg16': lambda: UNet16(pretrained=True),
        'resnet101': lambda: UNetResNet(pretrained=True, encoder_depth=101, num_classes=1),
    }
    model = builders[type]()
    model.eval()
    return model.to(device)

def load_unet_vgg16(model_path):
    """Build a UNet16 and load weights from the checkpoint at ``model_path``.

    The checkpoint must be a dict holding the state dict under ``'model'``
    (as written by ``train``) or ``'state_dict'``. The model is moved to the
    device returned by ``map_location()`` and put into eval mode.

    Raises:
        Exception: if neither key is present.
    """
    model = UNet16(pretrained=True)
    print("loading model")
    device = map_location()
    checkpoint = torch.load(model_path, map_location=device)
    print("loaded model")
    if 'model' in checkpoint:
        model.load_state_dict(checkpoint['model'])
    elif 'state_dict' in checkpoint:
        # Read the key that was checked; the original read 'check_point' (KeyError).
        model.load_state_dict(checkpoint['state_dict'])
    else:
        raise Exception('undefind model format')

    # The original tested `torch.cuda.is_available` without calling it (always
    # truthy), so .cuda() ran even on CPU-only machines.
    model.to(device)
    print("evalling")
    model.eval()
    print("evalling done")

    return model

def load_unet_resnet_101(model_path):
    """Build a UNetResNet (ResNet101 encoder, 1 class) and load weights from ``model_path``.

    The checkpoint must be a dict holding the state dict under ``'model'`` or
    ``'state_dict'``. The model is moved to the device returned by
    ``map_location()`` and put into eval mode.

    Raises:
        Exception: if neither key is present.
    """
    model = UNetResNet(pretrained=True, encoder_depth=101, num_classes=1)
    device = map_location()
    checkpoint = torch.load(model_path, map_location=device)
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'state_dict' in checkpoint:
        # The original read checkpoint['check_point'] here (KeyError).
        state_dict = checkpoint['state_dict']
    else:
        raise Exception('undefind model format')
    model.load_state_dict(state_dict)

    # Follow map_location() instead of an unconditional .cuda(), which fails without a GPU.
    model.to(device)
    model.eval()

    return model

def load_unet_resnet_34(model_path):
    """Build a UNetResNet (ResNet34 encoder, 1 class) and load weights from ``model_path``.

    The checkpoint must be a dict holding the state dict under ``'model'`` or
    ``'state_dict'``. The model is moved to the device returned by
    ``map_location()`` and put into eval mode.

    Raises:
        Exception: if neither key is present.
    """
    model = UNetResNet(pretrained=True, encoder_depth=34, num_classes=1)
    device = map_location()
    checkpoint = torch.load(model_path, map_location=device)
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'state_dict' in checkpoint:
        # The original read checkpoint['check_point'] here (KeyError).
        state_dict = checkpoint['state_dict']
    else:
        raise Exception('undefind model format')
    model.load_state_dict(state_dict)

    # Follow map_location() instead of an unconditional .cuda(), which fails without a GPU.
    model.to(device)
    model.eval()

    return model

def train(args, model, criterion, train_loader, valid_loader, validation, init_optimizer, n_epochs=None, fold=None,
          num_classes=None):
    """Train ``model`` up to ``n_epochs`` (inclusive), checkpointing and logging as JSON lines.

    Resumes from ``<args.model_path>/model_{fold}.pt`` if it exists; the stored
    ``epoch`` is the next epoch to run. After each epoch the checkpoint is saved
    with ``epoch + 1``, ``validation(model, criterion, valid_loader, num_classes)``
    is called, and its metrics dict (which must contain ``'valid_loss'``) is
    appended to ``train_{fold}.log``. Ctrl+C saves a snapshot for the current
    epoch and returns.

    Args:
        args: object providing ``lr``, ``n_epochs``, ``model_path`` and ``batch_size``.
        model, criterion: network and loss.
        train_loader, valid_loader: iterables of (inputs, targets) batches.
        validation: callable returning a metrics dict.
        init_optimizer: callable mapping a learning rate to an optimizer.
        n_epochs: overrides ``args.n_epochs`` when truthy.
        fold: used in the checkpoint and log file names.
        num_classes: forwarded to ``validation``.
    """
    # Local import: later notebook cells rebind the global name `tqdm` to the
    # tqdm class, which would make `tqdm.tqdm(...)` fail with AttributeError.
    from tqdm import tqdm as progress_bar

    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    optimizer = init_optimizer(lr)

    root = Path(args.model_path)
    model_path = root / 'model_{fold}.pt'.format(fold=fold)
    if model_path.exists():
        state = torch.load(str(model_path))
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restored model, epoch {}, step {:,}'.format(epoch, step))
    else:
        epoch = 1
        step = 0

    def save(ep):
        # Reads the current value of `step` at call time.
        torch.save({
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
        }, str(model_path))

    report_each = 10
    valid_losses = []
    # `with` closes the log on normal completion and on the Ctrl+C return path;
    # the original never closed it.
    with root.joinpath('train_{fold}.log'.format(fold=fold)).open('at', encoding='utf8') as log:
        for epoch in range(epoch, n_epochs + 1):
            model.train()
            random.seed()
            tq = progress_bar(total=(len(train_loader) * args.batch_size))
            tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
            losses = []
            try:
                mean_loss = 0
                for i, (inputs, targets) in enumerate(train_loader):
                    inputs = cuda(inputs)
                    targets = cuda(targets)

                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    optimizer.zero_grad()
                    batch_size = inputs.size(0)
                    loss.backward()
                    optimizer.step()
                    step += 1
                    tq.update(batch_size)
                    losses.append(loss.item())
                    # Smoothed loss over the last `report_each` batches.
                    mean_loss = np.mean(losses[-report_each:])
                    tq.set_postfix(loss='{:.5f}'.format(mean_loss))
                    if i and i % report_each == 0:
                        write_event(log, step, loss=mean_loss)
                write_event(log, step, loss=mean_loss)
                tq.close()
                save(epoch + 1)
                valid_metrics = validation(model, criterion, valid_loader, num_classes)
                write_event(log, step, **valid_metrics)
                valid_losses.append(valid_metrics['valid_loss'])
            except KeyboardInterrupt:
                tq.close()
                print('Ctrl+C, saving snapshot')
                save(epoch)
                print('done.')
                return

In [None]:
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
import random
from PIL import Image
import matplotlib.pyplot as plt

class ImgDataSet(Dataset):
    """Image/mask segmentation dataset read from two directories.

    Item ``i`` pairs ``img_dir/img_fnames[i]`` with ``mask_dir/mask_fnames[i]``
    (the filename lists are matched by position). Each file is opened with
    PIL and passed through its own optional transform.

    Args:
        img_dir, img_fnames: image directory and filenames.
        img_transform: callable applied to the PIL image, or None.
        mask_dir, mask_fnames: mask directory and filenames.
        mask_transform: callable applied to the PIL mask, or None.
    """

    def __init__(self, img_dir, img_fnames, img_transform, mask_dir, mask_fnames, mask_transform):
        self.img_dir = img_dir
        self.img_fnames = img_fnames
        self.img_transform = img_transform

        self.mask_dir = mask_dir
        self.mask_fnames = mask_fnames
        self.mask_transform = mask_transform

        # Drawn once per dataset, so every item replays the same Python `random` draws.
        self.seed = np.random.randint(2147483647)

    def __getitem__(self, i):
        fpath = os.path.join(self.img_dir, self.img_fnames[i])
        img = Image.open(fpath)
        if self.img_transform is not None:
            random.seed(self.seed)
            img = self.img_transform(img)

        mpath = os.path.join(self.mask_dir, self.mask_fnames[i])
        mask = Image.open(mpath)
        if self.mask_transform is not None:
            # Reseed exactly as for the image, so random ops in the mask transform
            # that draw from Python's `random` make the same choices as for the image.
            # The original seeded only before the image transform.
            # NOTE(review): torchvision >= 0.8 random transforms draw from torch's
            # RNG, which this does not synchronise — confirm which transforms are used.
            random.seed(self.seed)
            mask = self.mask_transform(mask)

        return img, mask

    def __len__(self):
        return len(self.img_fnames)


class ImgDataSetJoint(Dataset):
    """Segmentation dataset whose image and mask share one joint transform.

    Item ``i`` pairs ``img_dir/img_fnames[i]`` with ``mask_dir/mask_fnames[i]``.
    ``joint_transform`` (if not None) receives ``[img, mask]`` and must return
    the transformed pair. The optional ``img_transform`` and ``mask_transform``
    are applied afterwards, each to its own item.
    """

    def __init__(self, img_dir, img_fnames, joint_transform, mask_dir, mask_fnames, img_transform = None, mask_transform = None):
        self.joint_transform = joint_transform

        self.img_dir = img_dir
        self.img_fnames = img_fnames
        self.img_transform = img_transform

        self.mask_dir = mask_dir
        self.mask_fnames = mask_fnames
        self.mask_transform = mask_transform

        # Not read anywhere in this class; kept for attribute parity with ImgDataSet.
        self.seed = np.random.randint(2147483647)

    def __getitem__(self, i):
        img = Image.open(os.path.join(self.img_dir, self.img_fnames[i]))
        mask = Image.open(os.path.join(self.mask_dir, self.mask_fnames[i]))

        # Geometric augmentations go through the joint transform so both stay aligned.
        if self.joint_transform is not None:
            img, mask = self.joint_transform([img, mask])

        if self.img_transform is not None:
            img = self.img_transform(img)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return img, mask

    def __len__(self):
        return len(self.img_fnames)

In [None]:
import sys
import os
import numpy as np
from pathlib import Path
import cv2 as cv
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import argparse
from os.path import join
from PIL import Image
import gc
from tqdm import tqdm
from google.colab import files

def evaluate_img(model, img, train_tfms, img_width, img_height):
    """Predict a probability map for a whole image in a single forward pass.

    The image is resized to ``input_size`` (width, height), transformed with
    ``train_tfms``, and run through ``model`` on the model's own device. Channel 0
    of the sigmoid output is resized back to (img_height, img_width).

    Args:
        model: segmentation network. NOTE(review): assumed to return
            (N, C, H, W) logits; only element [0, 0] is used.
        img: image array accepted by ``PIL.Image.fromarray`` and ``cv.resize``.
        train_tfms: transform from a PIL image to a (C, H, W) tensor.
        img_width, img_height: size of the returned map.
    Returns:
        numpy array of shape (img_height, img_width) with values in [0, 1].
    """
    input_width, input_height = input_size[0], input_size[1]

    # cv.resize's third positional parameter is `dst`, so the original's
    # positional cv.INTER_AREA did not select the interpolation; pass it by keyword.
    img_1 = cv.resize(img, (input_width, input_height), interpolation=cv.INTER_AREA)
    X = train_tfms(Image.fromarray(img_1)).unsqueeze(0)  # add batch dim

    # Use the model's device (the original always called .cuda(), which fails
    # on CPU-only machines) and skip autograd bookkeeping for inference.
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(X.to(device))
    mask = torch.sigmoid(logits[0, 0]).cpu().numpy()

    # Upsample the probability map with cv.resize's default bilinear interpolation.
    mask = cv.resize(mask, (img_width, img_height), interpolation=cv.INTER_LINEAR)
    return mask

def evaluate_img_patch(model, img, train_tfms, img_width, img_height):
    """Predict a probability map by sliding an ``input_size`` window over ``img``.

    The window steps by 10% of its width in both directions. Each window is
    passed through ``model`` and the predictions of overlapping windows are
    averaged per pixel. Images smaller than the window fall back to
    ``evaluate_img``.

    Args:
        model, train_tfms: as for ``evaluate_img``.
        img: (H, W, C) image array; its own shape sets the output size.
        img_width, img_height: overwritten from ``img.shape``; kept for
            signature compatibility.
    Returns:
        numpy array of shape (H, W) with averaged probabilities. Pixels that
        no window covers (right/bottom remainders) are 0. Returns None if no
        window fits.
    """
    input_width, input_height = input_size[0], input_size[1]

    # The caller's img_width/img_height are superseded by the array's own size.
    img_height, img_width, _ = img.shape

    if img_width < input_width or img_height < input_height:
        # Too small for even one window: predict on the whole image instead.
        # The original passed only (model, img) here, which raised TypeError.
        return evaluate_img(model, img, train_tfms, img_width, img_height)

    stride_ratio = 0.1
    # Guard against a zero stride (range() rejects a step of 0) for tiny input sizes.
    stride = max(1, int(input_width * stride_ratio))

    patch_locs = [(x, y)
                  for y in range(0, img_height - input_height + 1, stride)
                  for x in range(0, img_width - input_width + 1, stride)]
    if not patch_locs:
        return None

    device = next(model.parameters()).device
    probability_map = np.zeros((img_height, img_width), dtype=float)
    # Number of windows covering each pixel, used to average the summed predictions.
    normalization_map = np.zeros((img_height, img_width), dtype=np.int32)

    with torch.no_grad():
        for x, y in patch_locs:
            rows = slice(y, y + input_height)
            cols = slice(x, x + input_width)
            patch = np.ascontiguousarray(img[rows, cols])
            X = train_tfms(Image.fromarray(patch)).unsqueeze(0).to(device)
            probability_map[rows, cols] += torch.sigmoid(model(X)[0, 0]).cpu().numpy()
            normalization_map[rows, cols] += 1

    # The original built normalization_map but never divided by it, so the
    # result was a sum over overlapping windows rather than an average.
    covered = normalization_map > 0
    probability_map[covered] /= normalization_map[covered]
    return probability_map

def disable_axis():
    """Hide the current matplotlib axes: frame, x/y axes and their tick labels."""
    plt.axis('off')
    axes = plt.gca().axes
    for axis in (axes.get_xaxis(), axes.get_yaxis()):
        axis.set_visible(False)
    for axis in (axes.get_xaxis(), axes.get_yaxis()):
        axis.set_ticklabels([])

def _save_comparison_figure(save_path, title, img_small, prob_patch, img_full, prob_full, threshold):
    """Save a 2x3 figure comparing sliding-window and whole-image predictions.

    Top row: ``img_small``, the sliding-window map ``prob_patch`` (scaled by its
    maximum, then thresholded), and the map overlaid on the image. Bottom row:
    the same for ``img_full`` and the whole-image map ``prob_full``
    (thresholded only). Overlays use alpha 0.4.
    """
    viz_patch = prob_patch.copy()
    peak = viz_patch.max()
    # Skip scaling for an all-zero map; dividing by 0 used to fill it with NaNs.
    if peak > 0:
        viz_patch = viz_patch / peak
    viz_patch[viz_patch < threshold] = 0.0

    viz_full = prob_full.copy()
    viz_full[viz_full < threshold] = 0.0

    fig = plt.figure()
    fig.suptitle(title, fontsize="x-large")
    panels = [(img_small, None), (viz_patch, None), (img_small, viz_patch),
              (img_full, None), (viz_full, None), (img_full, viz_full)]
    for position, (base, overlay) in enumerate(panels, start=1):
        ax = fig.add_subplot(2, 3, position)
        ax.imshow(base)
        if overlay is not None:
            ax.imshow(overlay, alpha=0.4)

    fig.savefig(save_path, dpi=500)
    plt.close(fig)


def infer(out_viz_dir='drive/MyDrive/synced/outviz',
          out_pred_dir='drive/MyDrive/synced/outpred',
          img_dir='drive/MyDrive/synced/images',
          model_type='resnet34',
          model_path='drive/MyDrive/synced/models/model_best.pt',
          threshold=0.2,
          max_images=10):
    """Run a trained U-Net over images in ``img_dir`` and write predictions and visualisations.

    For each of the first ``max_images`` files (sorted by path; None = all):

    * the whole-image prediction (``evaluate_img``) is written as an 8-bit JPEG
      to ``out_pred_dir``;
    * a 2x3 figure comparing sliding-window (``evaluate_img_patch``) and
      whole-image predictions, thresholded at ``threshold``, is written to
      ``out_viz_dir``.

    Files already in either output directory are deleted first. Pass '' for
    a directory to skip that output. The defaults reproduce the previously
    hard-coded configuration.

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    # Start from empty output directories.
    for out_dir in (out_viz_dir, out_pred_dir):
        if out_dir != '':
            os.makedirs(out_dir, exist_ok=True)
            for path in Path(out_dir).glob('*.*'):
                os.remove(str(path))

    if model_type == 'vgg16':
        model = load_unet_vgg16(model_path)
    elif model_type == 'resnet101':
        model = load_unet_resnet_101(model_path)
    elif model_type == 'resnet34':
        model = load_unet_resnet_34(model_path)
        print(model)
    else:
        # The original called exit(), which stops the interpreter/kernel; raise instead.
        raise ValueError(f'undefined model name pattern: {model_type!r}')

    channel_means = [0.485, 0.456, 0.406]
    channel_stds  = [0.229, 0.224, 0.225]
    # The transform does not depend on the image, so build it once.
    train_tfms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(channel_means, channel_stds)])

    # Sorted so the max_images subset is deterministic (glob order is not).
    paths = sorted(Path(img_dir).glob('*.*'))[:max_images]
    for path in tqdm(paths):
        print(str(path))

        img_0 = np.asarray(Image.open(str(path)))
        if len(img_0.shape) != 3:
            print(f'incorrect image shape: {path.name}{img_0.shape}')
            continue

        # Keep the first three channels (drops alpha when present).
        img_0 = img_0[:, :, :3]
        img_height, img_width, _ = img_0.shape

        prob_map_full = evaluate_img(model, img_0, train_tfms, img_width, img_height)

        if out_pred_dir != '':
            cv.imwrite(filename=join(out_pred_dir, f'{path.stem}.jpg'), img=(prob_map_full * 255).astype(np.uint8))

        if out_viz_dir != '':
            # Very large images are processed at 20% scale for the sliding window.
            if img_0.shape[0] > 2000 or img_0.shape[1] > 2000:
                img_1 = cv.resize(img_0, None, fx=0.2, fy=0.2, interpolation=cv.INTER_AREA)
            else:
                img_1 = img_0

            prob_map_patch = evaluate_img_patch(model, img_1, train_tfms, img_width, img_height)
            if prob_map_patch is None:
                print(f'no patch prediction for {path.name}; skipping visualisation')
            else:
                _save_comparison_figure(join(out_viz_dir, f'{path.stem}.jpg'),
                                        f'name={path.stem} \n cut-off threshold = {threshold}',
                                        img_1, prob_map_patch, img_0, prob_map_full, threshold)

        gc.collect()

In [None]:
# infer()

In [None]:
import torch
from torch import nn
from pathlib import Path
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split, RandomSampler
import torch.nn.functional as F
from torch.autograd import Variable
import shutil
import os
import argparse
from tqdm.notebook import tqdm
import numpy as np
import imageio
from sklearn.metrics import precision_score, recall_score, f1_score

class AverageMeter(object):
    """Computes and stores the average and current value.

    ``val`` is the latest value, ``sum`` and ``count`` are the weighted totals,
    and ``avg`` is their ratio.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        for field in ('val', 'avg', 'sum', 'count'):
            setattr(self, field, 0)

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def create_model(device, type ='vgg16'):
    """Build a pretrained U-Net, switch it to eval mode and move it to ``device``.

    Supported ``type`` values: 'vgg16' (UNet16), and 'resnet101' / 'resnet34'
    (UNetResNet with that encoder depth and one output class). A progress
    message naming the model is printed first.

    Raises:
        AssertionError: for any other ``type``.
    """
    if type == 'vgg16':
        print('create vgg16 model')
        model = UNet16(pretrained=True)
    elif type in ('resnet101', 'resnet34'):
        print('create ' + type + ' model')
        encoder_depth = 101 if type == 'resnet101' else 34
        model = UNetResNet(encoder_depth=encoder_depth, num_classes=1, pretrained=True)
    else:
        assert False
    model.eval()
    return model.to(device)

def adjust_learning_rate(optimizer, epoch, lr):
    """Step decay: set every param group's ``'lr'`` to ``lr * 0.1 ** (epoch // 30)``.

    The base learning rate is divided by 10 once every 30 epochs.
    """
    decayed = lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group['lr'] = decayed

def find_latest_model_path(dir):
    """Return the `*.pt` checkpoint in `dir` with the highest epoch number, or None.

    Only files whose stem contains 'epoch' are considered. The epoch is the last
    '_'-separated token of the stem (e.g. 'model_epoch_12.pt' -> 12). Files whose last
    token is not an integer (e.g. 'model_epoch_12_backup.pt') are skipped rather than
    aborting the resume with a ValueError. On ties the first match in glob order wins.

    Args:
        dir: directory to search (name kept for keyword callers; shadows the builtin).
    """
    latest_path, latest_epoch = None, None
    for path in Path(dir).glob('*.pt'):
        if 'epoch' not in path.stem:
            continue
        try:
            epoch = int(path.stem.rsplit('_', 1)[-1])
        except ValueError:
            continue
        if latest_epoch is None or epoch > latest_epoch:
            latest_path, latest_epoch = path, epoch
    return latest_path

def _resume_from_checkpoint(model, optimizer, model_dir, best_model_path):
    """Load the newest `model_epoch_*.pt` in `model_dir` into `model` (and `optimizer`).

    Returns:
        (start_epoch, min_val_loss): the epoch to train next and the best validation
        loss recorded in `best_model_path`; (0, 9999) when no epoch checkpoint exists.
    """
    latest_model_path = find_latest_model_path(model_dir)
    if latest_model_path is None:
        print('Started training model from epoch 0')
        return 0, 9999

    state = torch.load(latest_model_path)
    model.load_state_dict(state['model'])
    # Checkpoints written before optimizer state was saved simply lack this key.
    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])
    epoch = state['epoch']

    # If an epoch checkpoint exists, the best-model checkpoint must exist as well.
    assert Path(best_model_path).exists(), f'best model path {best_model_path} does not exist'
    # The best loss so far lives in model_best.pt. Reading it from the *latest* checkpoint
    # (as before) could let a worse model overwrite the best one after a resume.
    best_state = torch.load(best_model_path, map_location='cpu')
    min_val_loss = best_state['valid_loss']

    print(f'Restored model at epoch {epoch}. Min validation loss so far is : {min_val_loss}')
    print(f'Started training model from epoch {epoch + 1}')
    return epoch + 1, min_val_loss


def _save_training_checkpoint(path, model, epoch, valid_loss, train_loss, optimizer=None):
    """Write model weights, epoch and losses (plus optimizer state if given) to `path`."""
    state = {
        'model': model.state_dict(),
        'epoch': epoch,
        'valid_loss': valid_loss,
        'train_loss': train_loss,
    }
    if optimizer is not None:
        state['optimizer'] = optimizer.state_dict()
    torch.save(state, path)


def train(train_loader, model, criterion, optimizer, validation, valid_loader, model_dir, n_epoch, lr, adjust_lr, batch_size):
    """Train `model`, resuming from the newest epoch checkpoint in `model_dir` if present.

    Each epoch runs one optimisation pass over `train_loader`, then calls
    ``validation(model, valid_loader, criterion)``, which must return a dict containing
    'valid_loss'. Every 5th epoch `calculate_metrics` also prints precision/recall/F1.
    `model_epoch_{epoch}.pt` (including optimizer state) is written every epoch, and
    `model_best.pt` whenever the validation loss improves on the best seen so far.

    Args:
        train_loader: yields (input, target) batches; outputs and targets are flattened
            before `criterion`.
        model: network to train; batches are moved to the device of its parameters.
        criterion: loss taking (prediction, target).
        optimizer: optimizer over `model`'s parameters.
        validation: callable as described above (e.g. `validate`).
        valid_loader: loader passed to `validation` and `calculate_metrics`.
        model_dir: directory for checkpoints.
        n_epoch: last epoch index to train (inclusive).
        lr: base learning rate, used only when `adjust_lr` is True.
        adjust_lr: if True, apply `adjust_learning_rate` (x0.1 every 30 epochs); otherwise
            the optimizer keeps the learning rate it was constructed with.
        batch_size: retained for signature compatibility; progress now counts actual
            batch sizes.

    Returns:
        List of validation losses, one per epoch trained in this call.
    """
    # train_all puts the model on cuda:0; follow the model instead of hard-coding .cuda().
    device = next(model.parameters()).device
    best_model_path = os.path.join(model_dir, 'model_best.pt')
    start_epoch, min_val_loss = _resume_from_checkpoint(model, optimizer, model_dir, best_model_path)

    valid_losses = []
    # NOTE(review): the upper bound is inclusive, so epochs start_epoch..n_epoch are run
    # (n_epoch + 1 epochs from scratch). Kept so existing checkpoint numbering still holds.
    for epoch in range(start_epoch, n_epoch + 1):
        if adjust_lr:
            adjust_learning_rate(optimizer, epoch, lr)
        # Otherwise the optimizer keeps its configured LR. The former `lr = 0.5 * lr`
        # rebinding in this branch only changed a local and never reached the optimizer.

        losses = AverageMeter()
        model.train()
        with tqdm(total=len(train_loader.dataset), desc=f'Epoch {epoch}') as tq:
            for images, masks in train_loader:
                images = images.to(device)
                masks = masks.to(device)

                masks_pred = model(images)
                loss = criterion(masks_pred.reshape(-1), masks.reshape(-1))

                # compute gradient and do an optimizer step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Record a Python float. Accumulating the loss tensor itself kept every
                # batch's autograd graph (and its GPU memory) alive for the whole epoch.
                n = images.size(0)
                losses.update(loss.item(), n)
                tq.set_postfix(loss=f'{losses.avg:.5f}')
                tq.update(n)

        valid_metrics = validation(model, valid_loader, criterion)

        if epoch % 5 == 0:
            calculate_metrics(model, valid_loader)

        valid_loss = valid_metrics['valid_loss']
        valid_losses.append(valid_loss)
        print(f'\tvalid_loss = {valid_loss:.5f}')

        # Save this epoch with optimizer state so a resume continues AdamW's moment estimates.
        epoch_model_path = os.path.join(model_dir, f'model_epoch_{epoch}.pt')
        _save_training_checkpoint(epoch_model_path, model, epoch, valid_loss, losses.avg, optimizer=optimizer)

        if valid_loss < min_val_loss:
            min_val_loss = valid_loss
            _save_training_checkpoint(best_model_path, model, epoch, valid_loss, losses.avg)

    return valid_losses

def validate(model, val_loader, criterion):
    """Compute the validation loss of `model` over `val_loader` with gradients disabled.

    Puts the model in eval mode and moves every batch to the current CUDA device.

    Returns:
        {'valid_loss': mean of criterion(output, target), weighted by batch size}.
    """
    meter = AverageMeter()
    model.eval()
    with torch.no_grad():
        for batch_inputs, batch_targets in val_loader:
            inputs_gpu = Variable(batch_inputs).cuda()
            targets_gpu = Variable(batch_targets).cuda()
            batch_loss = criterion(model(inputs_gpu), targets_gpu)
            meter.update(batch_loss.item(), inputs_gpu.size(0))
    return dict(valid_loss=meter.avg)

def calculate_metrics(model, val_loader, threshold=0.5):
    """Print and return mean per-image precision, recall and F1 of binarised predictions.

    For every image in every batch, channel 0 of the model output is passed through a
    sigmoid and thresholded at `threshold`, and channel 0 of the target is thresholded at
    0.5. Batches are assumed to be (B, C, H, W), as the original ``[0, 0]`` indexing did.
    Scores come from sklearn with ``zero_division=1`` and are averaged over all images.

    Args:
        model: network returning logits; inputs are moved to the device of its parameters.
        val_loader: iterable of (input, target) batches.
        threshold: probability cut-off for a positive pixel (default 0.5, as before).

    Returns:
        dict with keys 'precision', 'recall', 'f1' (NaN when `val_loader` yields no images).
    """
    device = next(model.parameters()).device
    precisions, recalls, f1s = [], [], []

    model.eval()
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader):
            probs = torch.sigmoid(model(inputs.to(device))).cpu()
            targets = targets.cpu()
            # Score every image in the batch. Previously only index 0 of each batch was
            # used, so with batch_size=4 three quarters of the validation set were ignored.
            for prob, target in zip(probs, targets):
                y_pred = (prob[0] >= threshold).flatten().numpy().astype(np.uint8)
                y_true = (target[0] >= 0.5).flatten().numpy().astype(np.uint8)
                precisions.append(precision_score(y_true, y_pred, zero_division=1))
                recalls.append(recall_score(y_true, y_pred, zero_division=1))
                f1s.append(f1_score(y_true, y_pred, zero_division=1))

    def _mean(values):
        # Avoid numpy's empty-mean RuntimeWarning for an empty loader.
        return float(np.mean(values)) if values else float('nan')

    metrics = {'precision': _mean(precisions), 'recall': _mean(recalls), 'f1': _mean(f1s)}
    print("precision: ", metrics['precision'])
    print("recall: ", metrics['recall'])
    print("f1 ", metrics['f1'])
    return metrics

def save_check_point(state, is_best, file_name = 'checkpoint.pth.tar', best_file_name='model_best.pth.tar'):
    """Serialize `state` to `file_name`; if `is_best`, also copy that file to `best_file_name`.

    Args:
        state: any object accepted by `torch.save` (typically a dict of state_dicts).
        is_best: whether this checkpoint should also become the "best" copy.
        file_name: destination of the regular checkpoint.
        best_file_name: destination of the best copy. It defaults to
            'model_best.pth.tar' relative to the current working directory (the
            historical behaviour); pass a path to keep it next to `file_name`.
    """
    torch.save(state, file_name)
    if is_best:
        shutil.copy(file_name, best_file_name)

def calc_crack_pixel_weight(mask_dir):
    """Estimate a BCE `pos_weight` from the mask files in `mask_dir`.

    For each file, the crack fraction ``w`` is the share of pixels with value > 0. With
    ``b = mean(1 - w)`` over all files, the result is ``b / (1 - b)``, i.e. the average
    background fraction divided by the average crack fraction.

    A multi-channel mask of shape (H, W, C) counts a pixel once if any channel is > 0.
    Previously each channel was counted separately, so ``w`` could exceed 1.

    Raises:
        ValueError: if `mask_dir` contains no files, or no mask has any crack pixel
            (both previously surfaced as ZeroDivisionError).
    """
    background_fractions = []
    for path in Path(mask_dir).glob('*.*'):
        mask = np.asarray(imageio.imread(path))
        crack = mask > 0
        if crack.ndim == 3:
            crack = crack.any(axis=-1)
        crack_fraction = float(np.count_nonzero(crack)) / (crack.shape[0] * crack.shape[1])
        background_fractions.append(1 - crack_fraction)

    if not background_fractions:
        raise ValueError(f'no mask files found in {mask_dir}')

    avg_w = sum(background_fractions) / float(len(background_fractions))
    if avg_w >= 1.0:
        raise ValueError(f'masks in {mask_dir} contain no crack pixels')

    return avg_w / (1.0 - avg_w)

def train_all():
    """Set up data, model, optimizer and loss for crack segmentation, then call `train`.

    Reads images and masks matching 'CFD*.jpg' and 'CRACK500*.jpg' from
    `<data_dir>/images` and `<data_dir>/masks`. It keeps the first half of the files and
    splits them 85/15 into train/validation with a fixed seed, so a resumed run validates
    on the same images as before. Checkpoints are written to `model_dir` by `train`.
    """
    n_epoch = 50
    lr = 0.001
    weight_decay = 1e-4
    batch_size = 4
    # Hosted runtimes often have fewer than 4 CPUs. Requesting more DataLoader workers than
    # that produces the 'cpuset_checked' warning seen in this notebook's output.
    num_workers = min(4, os.cpu_count() or 1)
    # Fixed so every run (and any evaluation code using the same seed) gets the same split.
    split_seed = 42

    # Ablation study parameters
    adjust_lr = False
    scale_imgs = True
    # BCE weight for crack pixels (e.g. 10.0, or 0.4 * calc_crack_pixel_weight(DIR_MASK)).
    # None = unweighted, which is what the previous code effectively used.
    pos_weight = None

    data_dir = 'drive/MyDrive/synced'
    model_dir = 'drive/MyDrive/synced/models'
    model_type = 'resnet34'

    os.makedirs(model_dir, exist_ok=True)

    DIR_IMG  = os.path.join(data_dir, 'images')
    DIR_MASK = os.path.join(data_dir, 'masks')

    patterns = ('CFD*.jpg', 'CRACK500*.jpg')
    img_names, mask_names = [], []
    for pattern in patterns:
        # Path.glob order is filesystem-dependent. Sort so image i and mask i share a file
        # name and the Subset/random_split below select the same files on every run.
        img_names.extend(sorted(p.name for p in Path(DIR_IMG).glob(pattern)))
        mask_names.extend(sorted(p.name for p in Path(DIR_MASK).glob(pattern)))

    print(f'total images = {len(img_names)}')
    if img_names != mask_names:
        print(f'WARNING: image names ({len(img_names)}) and mask names ({len(mask_names)}) '
              'differ; image/mask pairs may be misaligned')

    print(f'cuda available: {torch.cuda.is_available()}')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # create_model already moves the model to `device`.
    model = create_model(device, model_type)
    optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=weight_decay)

    loss_pos_weight = None if pos_weight is None else torch.tensor([pos_weight])
    criterion = nn.BCEWithLogitsLoss(pos_weight=loss_pos_weight).to(device)

    channel_means = [0.485, 0.456, 0.406]
    channel_stds  = [0.229, 0.224, 0.225]

    img_steps = [transforms.ToTensor(), transforms.Normalize(channel_means, channel_stds)]
    mask_steps = [transforms.ToTensor()]
    if scale_imgs:
        img_steps.append(transforms.Resize(224))
        mask_steps.append(transforms.Resize(224))
    train_tfms = transforms.Compose(img_steps)
    mask_tfms = transforms.Compose(mask_steps)

    dataset = ImgDataSet(img_dir=DIR_IMG, img_fnames=img_names, img_transform=train_tfms, mask_dir=DIR_MASK, mask_fnames=mask_names, mask_transform=mask_tfms)
    # Only the first half of the matched files is used.
    dataset = torch.utils.data.Subset(dataset, range(int(0.5 * len(dataset))))
    train_size = int(0.85 * len(dataset))
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = random_split(
        dataset, [train_size, valid_size],
        generator=torch.Generator().manual_seed(split_seed))

    pin_memory = torch.cuda.is_available()
    # Reshuffle training batches every epoch; validation order does not matter.
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, pin_memory=pin_memory, num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size, shuffle=False, pin_memory=pin_memory, num_workers=num_workers)

    train(train_loader, model, criterion, optimizer, validate, valid_loader, model_dir, n_epoch, lr, adjust_lr, batch_size)



In [None]:
# Run validation without training

# Architecture and checkpoint directory read by get_model() below.
# NOTE(review): keep these in sync with model_type / model_dir inside train_all.
model_type = 'resnet34'
model_dir = 'drive/MyDrive/synced/models'

def get_valid_loader(data_dir='drive/MyDrive/synced', batch_size=4, num_workers=4, split_seed=42, image_size=None):
    """Rebuild a validation DataLoader for evaluating a saved model without training.

    Mirrors the data setup of `train_all`. It takes files matching 'CFD*.jpg' /
    'CRACK500*.jpg' in `<data_dir>/images` and `<data_dir>/masks`, keeps the first half,
    and splits them 85/15, returning the 15% part.

    The file lists are sorted and `random_split` is seeded with `split_seed`, so the
    returned subset is reproducible. NOTE(review): it coincides with the split used during
    training only if training sorted and seeded the same way. With an unseeded training
    split, this "validation" set overlaps the training images.

    Args:
        data_dir: root containing `images/` and `masks/`.
        batch_size: loader batch size.
        num_workers: loader worker processes.
        split_seed: seed for the train/validation split.
        image_size: if set, images and masks are resized to it after normalisation
            (train_all uses 224 when `scale_imgs` is on). None keeps full resolution.
    """
    DIR_IMG  = os.path.join(data_dir, 'images')
    DIR_MASK = os.path.join(data_dir, 'masks')

    img_names, mask_names = [], []
    for pattern in ('CFD*.jpg', 'CRACK500*.jpg'):
        # Path.glob order depends on the filesystem. Unsorted, it can misalign images with
        # masks and change which files land in the validation split.
        img_names.extend(sorted(p.name for p in Path(DIR_IMG).glob(pattern)))
        mask_names.extend(sorted(p.name for p in Path(DIR_MASK).glob(pattern)))

    channel_means = [0.485, 0.456, 0.406]
    channel_stds  = [0.229, 0.224, 0.225]

    img_steps = [transforms.ToTensor(), transforms.Normalize(channel_means, channel_stds)]
    mask_steps = [transforms.ToTensor()]
    if image_size is not None:
        img_steps.append(transforms.Resize(image_size))
        mask_steps.append(transforms.Resize(image_size))

    dataset = ImgDataSet(img_dir=DIR_IMG, img_fnames=img_names, img_transform=transforms.Compose(img_steps),
                         mask_dir=DIR_MASK, mask_fnames=mask_names, mask_transform=transforms.Compose(mask_steps))
    dataset = torch.utils.data.Subset(dataset, range(int(0.5 * len(dataset))))
    train_size = int(0.85 * len(dataset))
    valid_size = len(dataset) - train_size
    _, valid_dataset = random_split(
        dataset, [train_size, valid_size],
        generator=torch.Generator().manual_seed(split_seed))

    return DataLoader(valid_dataset, batch_size, shuffle=False, pin_memory=torch.cuda.is_available(), num_workers=num_workers)

def get_model():
    """Create a `model_type` network and load weights from `<model_dir>/model_best.pt`.

    Uses the module-level `model_type` and `model_dir`. The checkpoint is loaded with
    ``map_location`` set to the target device, so a checkpoint saved on a GPU can be
    restored on a CPU-only runtime (previously `torch.load` failed there).
    The model stays in the eval mode set by `create_model`.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = create_model(device, model_type)

    best_model_path = os.path.join(model_dir, 'model_best.pt')

    state = torch.load(best_model_path, map_location=device)
    model.load_state_dict(state['model'])

    return model

# calculate_metrics(get_model(), get_valid_loader())


In [None]:
# Train from scratch, or resume from the newest model_epoch_*.pt in train_all's model_dir.
train_all()

total images = 2449
cuda available: True
create resnet34 model


  cpuset_checked))


Restored model at epoch 0. Min validation loss so far is : 0.6900064676352169
Started training model from epoch 1


HBox(children=(FloatProgress(value=0.0, max=1040.0), HTML(value='')))

KeyboardInterrupt: ignored

In [None]:
print("hello net")

In [None]:
# !rm -rf drive/MyDrive/synced/models