In [1]:
import os
import gc
import numpy as np
import torch
from torchvision import transforms as tfs
import cv2
import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

print(torch.cuda.is_available())

True


**Load the Data**

In [2]:
# load datasets
def load_images(file_names):
    images = []
    for file_name in file_names:
        img = cv2.imread(file_name, -1)
        img = img_standardization(img)
        images.append(img)
    images = np.array(images)
    return images


def unit16b2uint8(img):
    if img.dtype == 'uint8':
        return img
    elif img.dtype == 'uint16':
        return img.astype(np.uint8)
    else:
        raise TypeError(
            'No such of img transfer type: {} for img'.format(img.dtype))


def img_standardization(img):
    img = unit16b2uint8(img)
    """
    if len(img.shape) == 2:
        img = np.expand_dims(img, 2)
        img = np.tile(img, (1, 1, 3))
        return img
    elif len(img.shape) == 3:
        return img
    else:
        raise TypeError('The Depth of image large than 3 \n')
    """
    return img


def binaryzation(image):
    image[image > 0] = 1
    return image


SHOW_DATA = False

In [3]:
def get_dataset(width_in, height_in, width_out, height_out):

    # get train x
    # absp = '/'.join(os.path.abspath(__file__).split('\\')[:-1]) + '/'
    train_x_path = 'C:/me/dataset1/train/'
    train_x_list = [os.path.join(train_x_path, image)
                    for image in os.listdir(train_x_path)]
    train_X = load_images(train_x_list)

    # get train y
    train_y_path = 'C:/me/dataset1/train_GT/SEG'
    train_y_list = [os.path.join(train_y_path, image)
                    for image in os.listdir(train_y_path)]
    train_Y = load_images(train_y_list)
    

    if SHOW_DATA:
        import matplotlib.pyplot as plt
        plt.subplot(1, 2, 1)
        plt.imshow(train_x[3])
        plt.subplot(1, 2, 2)
        plt.imshow(train_y[3] * 255)
        plt.show()

    """
    result_path = absp + '../supplementary/dataset1/test_RES'
    if not os.path.exists(result_path):
        os.mkdir(result_path)
    """

    train_X, test_X, train_y, test_y = train_test_split(train_X, train_Y,
                                                        test_size=0.1,
                                                        random_state=0)
    
    train_y = binaryzation(train_y)  # for instance segmentation

    return train_X, train_y, test_X, test_y

width_in = 628
height_in = 628
width_out = 628
height_out = 628

x_train, y_train, x_val, y_val = get_dataset(
    width_in, height_in, width_out, height_out)

In [4]:
y_train[0].min(), y_train[0].max(), y_val[0].min(), y_val[0].max()

(0, 1, 0, 34)

In [5]:
x_train.shape, x_val.shape

((157, 628, 628), (18, 628, 628))

In [6]:
# input augmentation
im_aug = tfs.Compose([
    tfs.RandomResizedCrop((width_in, height_in)),
    tfs.RandomHorizontalFlip(),
    tfs.RandomVerticalFlip(),
    tfs.RandomCrop(width_in)
])

Define the network

In [7]:
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(
                in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
        # import pdb; pdb.set_trace()

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


In [8]:
class UNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=2, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 32)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 256)
        self.up1 = Up(512, 128, bilinear)
        self.up2 = Up(256, 64, bilinear)
        self.up3 = Up(128, 32, bilinear)
        self.up4 = Up(64, 32, bilinear)
        self.outpred = OutConv(32, n_classes)

    def forward(self, x):
        x = x/255.0
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outpred(x)
        return logits


In [9]:
use_gpu = True

unet = UNet(n_channels=1, n_classes=2,)
if use_gpu:
    unet = unet.cuda()
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(unet.parameters(), lr=0.01, momentum=0.99)

In [10]:
size_all = 0
for name,parameters in unet.named_parameters():
    size_all += torch.Size.numel(parameters.size())
    # print(name,':',parameters.size())
print(size_all)

3351714


Start Training

In [11]:
def train_step(inputs, labels, optimizer, criterion, unet, width_out, height_out):
    optimizer.zero_grad()
    # forward + backward + optimize
    outputs = unet(inputs)
    # outputs.shape =(batch_size, n_classes, img_cols, img_rows)
    outputs = outputs.permute(0, 2, 3, 1)
    # outputs.shape =(batch_size, img_cols, img_rows, n_classes)
    m = outputs.shape[0]
    outputs = outputs.reshape(m*width_out*height_out, 2)
    labels = labels.reshape(m*width_out*height_out)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    return loss

In [12]:
def get_val_loss(x_val, y_val, width_out, height_out,
                 unet, batch_size=1, use_gpu=True):

    epoch_iter = np.ceil(x_val.shape[0] / batch_size).astype(int)

    j_scores = []

    predict_ys = []
    for i in range(epoch_iter):
        # preprocess
        batch_val_x = torch.from_numpy(
            x_val[i*batch_size:(i + 1)*batch_size]).float()
        if (len(batch_val_x.size()) == 3):
            batch_val_x = batch_val_x.unsqueeze(1)
        if use_gpu:
            batch_val_x = batch_val_x.cuda()

        # get predict
        batch_predict_y = unet(batch_val_x)
        batch_predict_y = batch_predict_y.cpu().detach().numpy()
        predict_ys.append(batch_predict_y)

    predict_ys = np.array(predict_ys)
    shape = predict_ys.shape
    # print("predict_ys.shape", predict_ys.shape)
    predict_ys = np.reshape(
        predict_ys, (shape[0]*shape[1], shape[2], shape[3], shape[4]))

    # post process
    predict_ys = post_process(predict_ys)

    # get GT
    gt_ys = y_val

    if False :
        show_pic(x_val[0], gt_ys[0], predict_ys[0])

    # calc jaccard score
    for j in range(len(predict_ys)):
        j_score = calc_jaccard(predict_ys[j], gt_ys[j])
        j_scores.append(j_score)

    j_score = np.mean(j_scores)
    return j_score


def post_process(batch_predict_y):
    """post process of the result"""
    # shape: [batch_size, 2, width, height]

    batch_predict_y = batch_predict_y[:, 1, :, :]

    res = []
    for predict_y in batch_predict_y:
        # binarization
        predict_y[predict_y > 0] = 1
        predict_y[predict_y <= 0] = 0

        # open
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
        predict_y = cv2.erode(predict_y, kernel)  # 腐蚀
        predict_y = cv2.dilate(predict_y, kernel)  # 膨胀

        # parse
        predict_y = predict_y.astype(np.uint8) * 255
        __, contours, _ = cv2.findContours(
            predict_y, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)  # 寻找连通域

        areas = [cv2.contourArea(cnt) for cnt in contours]
        cellIndexs = np.argsort(areas)

        predict_y = np.zeros([predict_y.shape[0], predict_y.shape[1]])
        for j in range(len(cellIndexs)):
            cv2.drawContours(predict_y, contours, j, j, cv2.FILLED)

        predict_y = predict_y.astype(int)
        res.append(predict_y)

    return res


def calc_jaccard(imgA, imgB):
    """calculate the jaccard score"""
    num_A = len(np.unique(imgA))
    num_B = len(np.unique(imgB))

    if num_A < num_B:
        i = imgA
        imgA = imgB
        imgB = i

    unqA = np.unique(imgA)
    for i in range(len(unqA)):
        imgA[imgA == unqA[i]] = i

    unqB = np.unique(imgB)
    for i in range(len(unqB)):
        imgB[imgB == unqB[i]] = i

    hit_matrix = np.zeros([unqA.size, unqB.size])

    for i in range(1, unqA.size):
        A_chan = (imgA == i)
        for j in range(1, unqB.size):
            B_chan = (imgB == j)
            A_and_B = A_chan * B_chan
            B_chan[A_chan == 1] = 1
            hit_matrix[i, j] = np.sum(A_and_B) / np.sum(B_chan)

    jaccard_list = []
    for j in range(1, unqB.size):
        jac_col = np.max(hit_matrix[:, j])
        jaccard_list.append(jac_col)

    j_score = np.sum(jaccard_list) / max(num_A, num_B)
    """
    print("A, B:", num_A, num_B, "max:", max(jaccard_list),
          "min:", min(jaccard_list), "mean:", j_score)
    print(jaccard_list)
    """

    return j_score

def show_pic(picA, picB, picC):
    plt.subplot(1, 3, 1)
    plt.title("x")
    plt.imshow(picA, cmap='gray')

    plt.subplot(1, 3, 2)
    plt.title("GT")
    plt.imshow(picB)

    if picC is not None:
        plt.subplot(1, 3, 3)
        plt.title("Predict")
        plt.imshow(picC)

    plt.show()


In [13]:
batch_size = 2
epochs = 100
epoch_lapse = 5
threshold = 0.5
learning_rate = 0.0001

In [14]:
epoch_iter = np.ceil(x_train.shape[0] / batch_size).astype(int)

for _ in range(20):
    total_loss = 0
    for i in tqdm.tqdm(range(epoch_iter), ascii=True, ncols=120):
        batch_train_x = torch.from_numpy(
            x_train[i * batch_size: (i + 1) * batch_size]).float()
        batch_train_y = torch.from_numpy(
            y_train[i*batch_size:(i + 1)*batch_size]).long()

        if (len(batch_train_x.size()) == 3):
            batch_train_x = batch_train_x.unsqueeze(1)
        if (len(batch_train_y.size()) == 3):
            batch_train_y = batch_train_y.unsqueeze(1)

        if use_gpu:
            batch_train_x = batch_train_x.cuda()
            batch_train_y = batch_train_y.cuda()
        batch_loss = train_step(
            batch_train_x, batch_train_y, optimizer, criterion, unet, width_out, height_out)
        total_loss += batch_loss
    # print("i:", _, "total_loss:", total_loss.cpu().item())
    
    if (_+1) % 5 == 0:
        val_loss = get_val_loss(x_val, y_val, width_out, height_out, unet, batch_size)
        print("Total loss in epoch %f : %f and validation loss : %f" %
              (_+1, total_loss, val_loss))
    
    continue
import datetime
filename = 'unet-' + datetime.datetime.now().strftime('%Y%m%d%H%M%S') +'.pth'
torch.save(unet.state_dict(), 'C:/me/test/save/' + filename)

100%|###################################################################################| 79/79 [00:16<00:00,  4.93it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.20it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.17it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.16it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.11it/s]
  0%|                                                                                            | 0/79 [00:00<?, ?it/s]

Total loss in epoch 5.000000 : 4.261444 and validation loss : 0.480019


100%|###################################################################################| 79/79 [00:15<00:00,  5.14it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.11it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.15it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.12it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.05it/s]
  0%|                                                                                            | 0/79 [00:00<?, ?it/s]

Total loss in epoch 10.000000 : 3.233885 and validation loss : 0.563821


100%|###################################################################################| 79/79 [00:15<00:00,  5.14it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.14it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.09it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.07it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.01it/s]
  0%|                                                                                            | 0/79 [00:00<?, ?it/s]

Total loss in epoch 15.000000 : 2.601982 and validation loss : 0.563295


100%|###################################################################################| 79/79 [00:15<00:00,  5.19it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.12it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.01it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.15it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.00it/s]
  0%|                                                                                            | 0/79 [00:00<?, ?it/s]

Total loss in epoch 20.000000 : 2.311662 and validation loss : 0.551328


100%|###################################################################################| 79/79 [00:15<00:00,  5.14it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.09it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.10it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.09it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.01it/s]
  0%|                                                                                            | 0/79 [00:00<?, ?it/s]

Total loss in epoch 25.000000 : 2.427837 and validation loss : 0.590202


100%|###################################################################################| 79/79 [00:15<00:00,  5.14it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.09it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.05it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.00it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.12it/s]
  0%|                                                                                            | 0/79 [00:00<?, ?it/s]

Total loss in epoch 30.000000 : 1.830118 and validation loss : 0.578145


100%|###################################################################################| 79/79 [00:15<00:00,  5.16it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.09it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.09it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  4.99it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.06it/s]
  0%|                                                                                            | 0/79 [00:00<?, ?it/s]

Total loss in epoch 35.000000 : 1.950086 and validation loss : 0.609834


100%|###################################################################################| 79/79 [00:15<00:00,  5.19it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.13it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.05it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.10it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.10it/s]
  0%|                                                                                            | 0/79 [00:00<?, ?it/s]

Total loss in epoch 40.000000 : 1.605719 and validation loss : 0.610865


100%|###################################################################################| 79/79 [00:15<00:00,  5.14it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.11it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.09it/s]
100%|###################################################################################| 79/79 [00:15<00:00,  5.04it/s]
 67%|#######################################################6                           | 53/79 [00:10<00:05,  5.06it/s]


KeyboardInterrupt: 