In [1]:
# unet
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(
                in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
        # import pdb; pdb.set_trace()

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """softmax"""
        x = self.conv(x)
        # x = F.softmax(x, dim=1)
        return x


class UNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=2, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 32)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 256)
        self.up1 = Up(512, 128, bilinear)
        self.up2 = Up(256, 64, bilinear)
        self.up3 = Up(128, 32, bilinear)
        self.up4 = Up(64, 32, bilinear)
        self.outpred = OutConv(32, n_classes)

    def forward(self, x):
        x = x/255.0
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outpred(x)
        return logits


In [2]:
# data
import os
import gc

import cv2
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


data_transforms = transforms.Compose([
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
])

def load_file(filename):
    img = cv2.imread(filename, -1)
    img = img.astype(np.float32)
    return img


class TestSet(Dataset):
    def __init__(self, cell_dir, data_reader=load_file):

        super(TestSet, self).__init__()
        data_paths = os.listdir(cell_dir)
        self.data_paths = []
        for p in data_paths:
            self.data_paths.append(cell_dir + p)

        self.data_reader = data_reader
        pass

    def __getitem__(self, index):
        cell_path = self.data_paths[index]
        cell = self.data_reader(cell_path)

        # Normalization
        cell = cell - cell.min()
        cell = cell / cell.max() * 255
        return cell

    def __len__(self):
        return len(self.data_paths)

    pass


class CellDataset(Dataset):
    def __init__(self, txtpath, transform=data_transforms, data_reader=None):
        super(CellDataset, self).__init__()

        data_paths = []
        with open(txtpath, 'r') as fh:
            for line in fh:
                line = line.strip('\n')
                line = line.rstrip('\n')
                words = line.split()    # 0和1分别是cell和mask路径
                data_paths.append((words[0], words[1]))

        self.data_paths = data_paths
        self.transform = transform
        self.data_reader = data_reader
        pass

    def __getitem__(self, index):

        cell_path, mask_path = self.data_paths[index]
        cell = self.data_reader(cell_path)
        mask = self.data_reader(mask_path)

        # Normalization
        cell = cell - cell.min()
        cell = cell / cell.max() * 255

        if self.transform is not None:
            img = np.uint8([cell, mask, mask]).transpose(1, 2, 0)
            img = Image.fromarray(img)
            img = self.transform(img)
            cell = img[0]
            mask = img[1] * 255

        return cell, mask

    def __len__(self):
        return len(self.data_paths)


def get_dataset(cell_dir, mask_dir, valid_rate, tmp_dir, use_exist=True):

    valid_txt = tmp_dir + "valid_data.txt"
    train_txt = tmp_dir + "train_data.txt"

    use_exist = use_exist and os.path.isfile(
        valid_txt) and os.path.isfile(train_txt)

    if not use_exist:
        # generate list of file names
        cell_list = [os.path.join(cell_dir, image)
                     for image in os.listdir(cell_dir)]
        mask_list = [os.path.join(mask_dir, image)
                     for image in os.listdir(mask_dir)]

        # separate the lists according to valid_rate
        sample_size = len(cell_list)
        valid_size = int(sample_size * valid_rate)
        valid_index = np.random.choice(
            a=sample_size, size=valid_size, replace=False, p=None)

        # save the lists in txt files
        with open(valid_txt, "w+") as f:
            f.truncate()
            for i in valid_index:
                f.write(cell_list[i] + " " + mask_list[i] + '\n')

        with open(train_txt, "w+") as f:
            f.truncate()
            for i in range(sample_size):
                if i not in valid_index:
                    f.write(cell_list[i] + " " + mask_list[i] + '\n')

    # get the Dataset objects
    train_dataset = CellDataset(train_txt, data_reader=load_file)
    valid_dataset = CellDataset(valid_txt, data_reader=load_file)

    return train_dataset, valid_dataset


In [3]:
# validate
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import scipy
import torchvision.transforms as T


def CUDA(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


class Validator():

    def __init__(self, unet,
                 hyper_params,
                 use_cuda,
                 data_loader):
        self.unet = unet
        self.hyper_params = hyper_params
        self.use_cuda = use_cuda
        self.data_loader = data_loader
        pass

    def validate(self, SHOW_PIC=False, TTA=False):

        width_out = 628
        height_out = 628
        batch_size = self.hyper_params["batch_size"]
        use_cuda = self.use_cuda

        j_scores = []
        for i, data in enumerate(self.data_loader):

            """preprocess"""
            b_val_x, b_val_y = data

            """Test time augmentation"""
            # S b_val_x: [batch_size, width, height]
            if TTA:
                b_val_x_fh = torch.flip(b_val_x, dims=[1])
                b_val_x_fv = torch.flip(b_val_x, dims=[2])
                b_val_x_90 = torch.rot90(b_val_x, 1, dims=(1, 2))
                b_val_x_180 = torch.rot90(b_val_x, 2, dims=(1, 2))
                b_val_x_270 = torch.rot90(b_val_x, 3, dims=(1, 2))

                b_val_x_list = [
                    b_val_x,
                    b_val_x_fh,
                    b_val_x_fv,
                    b_val_x_90,
                    b_val_x_180,
                    b_val_x_270,
                ]
            else:
                b_val_x_list = [b_val_x]

            """get binary output"""
            # S b_val_x_list: [6 or 1, batch_size, width, height]

            b_y_list_cpu = []
            for b_x in b_val_x_list:

                # S b_x: [batch_size, width, height]
                if not isinstance(b_x, torch.Tensor):
                    b_x = T.ToTensor()(b_x)
                if (len(b_x.size()) == 3):
                    b_x = b_x.unsqueeze(1)
                elif (len(b_x.size() == 2)):
                    b_x = b_x.unsqueeze(0)
                    b_x = b_x.unsqueeze(1)

                # S b_x: [batch_size, 1, width, height]
                if use_cuda:
                    b_x = b_x.cuda()

                """get raw output"""
                b_predict_y = self.unet(b_x)

                """binarization"""
                # S b_predict_y: [batch_size, 2, width, height]
                b_predict_y_cpu = self.binarization(b_predict_y).detach().cpu()

                # S b_predict_y: [batch_size, width, height]
                b_y_list_cpu.append(b_predict_y_cpu)

            """Augmentation vote"""
            # S b_y_list_cpu: [6 or 1, batch_size, width, height]

            if TTA:
                # S b_y_list_cpu[n]: [batch_size, width, height]
                b_y_list_cpu[1] = torch.flip(b_y_list_cpu[1], dims=[1])
                b_y_list_cpu[2] = torch.flip(b_y_list_cpu[2], dims=[2])
                b_y_list_cpu[3] = torch.rot90(b_y_list_cpu[3], 3, dims=(1, 2))
                b_y_list_cpu[4] = torch.rot90(b_y_list_cpu[4], 2, dims=(1, 2))
                b_y_list_cpu[5] = torch.rot90(b_y_list_cpu[5], 1, dims=(1, 2))

                kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (6, 6))

                """Open operation"""
                for n in range(6):
                    b_y_list_cpu[n] = b_y_list_cpu[n].numpy()
                    for j in range(batch_size):
                        b_y_list_cpu[n][j] = cv2.morphologyEx(
                            src=b_y_list_cpu[n][j],
                            op=cv2.MORPH_OPEN,
                            kernel=kernel)
                    b_y_list_cpu[n] = torch.tensor(b_y_list_cpu[n])
                    b_y_list_cpu[n] = b_y_list_cpu[n].unsqueeze(0)

                """Vote"""
                # S b_y_list_cpu[n]: [1, batch_size, width, height]
                b_predict_y = torch.cat(tuple(b_y_list_cpu), dim=0)
                b_predict_y = torch.mean(b_predict_y, dim=0)

                # S b_predict_y: [batch_size, width, height]
                b_predict_y[b_predict_y > 0.5] = 1
                b_predict_y[b_predict_y <= 0.5] = 0

            else:
                b_predict_y = b_y_list_cpu[0]

            """Instance Sparse"""
            # S b_predict_y: [batch_size, width, height]
            b_predict_y = self.instance_sparse(b_predict_y)

            """Calculate jaccard score"""            
            for j in range(b_val_y.shape[0]):
                j_score = self.calc_jaccard(
                    b_val_y[j], b_predict_y[j], use_cuda=self.use_cuda)
                j_scores.append(j_score)

                if SHOW_PIC and j_score < 0.7:
                    b_val_x = b_val_x.cpu().detach().numpy()
                    b_predict_y = b_predict_y.cpu().detach().numpy()
                    comment = ("pic_num: %d, j_score: %f\n" % (i, j_score))
                    self.show_pic(picA=b_val_x[j],
                                  picB=b_val_y[j],
                                  picC=b_predict_y[j],
                                  comment=comment)
                pass
            pass  # end Calculate jaccard score

        j_score = np.mean(j_scores)
        return j_score

    def binarization(self, batch_predict_y):
        # S b_predict_y: [batch_size, 2, width, height]

        # sqeeze
        batch_predict_y_1 = torch.softmax(batch_predict_y, dim=1)
        batch_predict_y_1 = batch_predict_y_1[:, 1, :, :]
        THRESHOLD_1 = 0.5
        batch_predict_y_1[batch_predict_y_1 > THRESHOLD_1] = 1
        batch_predict_y_1[batch_predict_y_1 <= THRESHOLD_1] = 0

        """
        # sqeeze
        batch_predict_y_raw = torch.tensor(batch_predict_y)
        batch_predict_y_2 = batch_predict_y[:, 1, :, :]
        # binarization
        THRESHOLD_2 = 0
        batch_predict_y_2[batch_predict_y_2 > THRESHOLD_2] = 1
        batch_predict_y_2[batch_predict_y_2 <= THRESHOLD_2] = 0
        """
        batch_predict_y = batch_predict_y_1
        # S b_predict_y: [batch_size, width, height]

        return batch_predict_y

    def instance_sparse(self, batch_predict_y, KERNEL_SIZE=(3, 3)):
        """Post process the result."""
        # shape: [batch_size, width, height]

        res = []
        for predict_y in batch_predict_y:

            predict_y = predict_y.numpy().astype(np.uint8) * 255
            predict_y_old = predict_y.copy()

            # sure background area
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, KERNEL_SIZE)
            sure_bg = cv2.dilate(predict_y, kernel, iterations=2)

            # Finding sure foreground area
            dist_transform = cv2.distanceTransform(predict_y, cv2.DIST_L2, 5)
            ret, sure_fg = cv2.threshold(
                dist_transform, 0.5*dist_transform.max(), 255, 0)

            # Finding unknown region
            sure_fg = np.uint8(sure_fg)
            unknown = cv2.subtract(sure_bg, sure_fg)

            # Marker labelling
            ret, markers = cv2.connectedComponents(sure_fg)
            # Change background to 1
            markers = markers + 1
            # Mark the region of unknown with zero
            markers[unknown == 255] = 0

            # watershed
            predict_y_color = cv2.cvtColor(predict_y, cv2.COLOR_GRAY2BGR)
            markers = cv2.watershed(predict_y_color, markers)
            markers[markers == -1] = 1
            markers = markers - 1

            predict_y = markers.astype(int)
            res.append(predict_y)

        res = torch.Tensor(res)
        return res

    def calc_jaccard(self, imgA, imgB, use_cuda=True):
        """Calculate the jaccard score"""
        """All this may occur in GPU."""
        if use_cuda:
            imgA = imgA.cuda()
            imgB = imgB.cuda()

        unqA = torch.unique(imgA)
        unqB = torch.unique(imgB)
        num_A = len(unqA)
        num_B = len(unqB)

        for i in range(num_A):
            imgA[imgA == unqA[i]] = i
        for i in range(num_B):
            imgB[imgB == unqB[i]] = i

        hit_matrix = np.zeros([num_A, num_B])

        if use_cuda:
            for i in range(1, num_A):
                A_chan = (imgA == i).cuda()
                for j in range(1, num_B):
                    B_chan = (imgB == j).cuda()
                    A_and_B = torch.mul(A_chan, B_chan)
                    B_chan[A_chan == 1] = 1
                    hit_matrix[i, j] = torch.sum(
                        A_and_B).float() / torch.sum(B_chan).float()
        else:
            for i in range(1, num_A):
                A_chan = (imgA == i)
                for j in range(1, num_B):
                    B_chan = (imgB == j)
                    A_and_B = torch.mul(A_chan, B_chan)
                    B_chan[A_chan == 1] = 1
                    hit_matrix[i, j] = torch.sum(
                        A_and_B).float() / torch.sum(B_chan).float()

        jaccard_list = []
        for j in range(1, num_A):
            jac_col = np.max(hit_matrix[j, :])
            if jac_col > 0.5:
                jaccard_list.append(jac_col)
            else:
                jaccard_list.append(0)

        j_score = np.sum(jaccard_list) / (num_A - 1)

        return j_score

    def show_pic(self, picA, picB, picC=None,
                 is_gray=(True, False, False), comment=""):
        plt.subplot(1, 3, 1)
        plt.title("x")
        if is_gray[0]:
            plt.imshow(picA, cmap='gray')
        else:
            plt.imshow(picA)

        plt.subplot(1, 3, 2)
        plt.title("GT")
        if is_gray[1]:
            plt.imshow(picB, cmap='gray')
        else:
            plt.imshow(picB)

        if picC is not None:
            plt.subplot(1, 3, 3)
            plt.title("Predict")
            if is_gray[2]:
                plt.imshow(picC, cmap='gray')
            else:
                plt.imshow(picC)

        if comment is not "":
            plt.text(0, 1, comment, fontsize=14)

        plt.show()


In [4]:
# train
import os
import gc

import cv2
import numpy as np
import scipy
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader


SHOW_NET = False


class Trainer():
    def __init__(self):
        pass

    def setup(self, valid_rate=0.1, use_cuda=True,
              cell_dir="", mask_dir="", module_save_dir="", tmp_dir="",
              criterion=None, optimizer=None, hyper_params=None,
              ):
        """setup the module"""
        self.train_dataset, self.valid_dataset = get_dataset(
            cell_dir, mask_dir, valid_rate, tmp_dir)

        self.hyper_params = hyper_params
        self.train_data_loader = DataLoader(
            dataset=self.train_dataset,
            num_workers=self.hyper_params["threads"],
            batch_size=self.hyper_params["batch_size"],
            shuffle=True
        )
        self.valid_data_loader = DataLoader(
            dataset=self.valid_dataset,
            num_workers=self.hyper_params["threads"],
            batch_size=self.hyper_params["batch_size"],
            shuffle=False
        )

        self.use_cuda = use_cuda
        self.unet = UNet(n_channels=1, n_classes=2,)
        if use_cuda:
            self.unet = self.unet.cuda()
        if SHOW_NET:
            from torchsummary import summary
            batch_size = self.hyper_params["batch_size"]
            summary(self.unet, (batch_size, 628, 628))

        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(
            self.unet.parameters(), lr=self.hyper_params["learning_rate"], momentum=0.99)
        self.module_save_dir = module_save_dir

        self.v = Validator(unet=self.unet,
                           hyper_params=hyper_params,
                           use_cuda=use_cuda,
                           data_loader=self.valid_data_loader)

    def train(self):
        """train the model"""
        epochs = self.hyper_params["epochs"]
        epoch_lapse = self.hyper_params["epoch_lapse"]
        batch_size = self.hyper_params["batch_size"]
        epoch_save = self.hyper_params["epoch_save"]
        width_out = 628
        height_out = 628

        for _ in range(epochs):
            total_loss = 0
            for i, data in enumerate(self.train_data_loader):

                batch_train_x, batch_train_y = data
                batch_train_y = batch_train_y.long()
                batch_train_y[batch_train_y > 0] = 1  # important!!!
                if (len(batch_train_x.size()) == 3):
                    batch_train_x = batch_train_x.unsqueeze(1)
                if (len(batch_train_y.size()) == 3):
                    batch_train_y = batch_train_y.unsqueeze(1)

                if self.use_cuda:
                    batch_train_x = batch_train_x.cuda()
                    batch_train_y = batch_train_y.cuda()

                batch_loss = self.train_step(
                    batch_train_x, batch_train_y,
                    optimizer=self.optimizer,
                    criterion=self.criterion,
                    unet=self.unet,
                    width_out=width_out,
                    height_out=height_out,
                    batch_size=batch_size)

                total_loss += batch_loss
            print("epoch", _)

            if (_+1) % epoch_lapse == 0:
                val_acc = self.v.validate()
                print("Total loss in epoch %d : %f and validation accuracy : %f" %
                      (_ + 1, total_loss, val_acc))

            if (_+1) % epoch_save == 0:
                self.save_module(name_else="epoch-" + str(_ + 1))
                print("MODULE SAVED.")
        gc.collect()
        pass

    def train_step(self, inputs, labels, optimizer,
                   criterion, unet, batch_size,
                   width_out, height_out):
        optimizer.zero_grad()
        outputs = unet(inputs)
        outputs = outputs.permute(0, 2, 3, 1)

        outputs = outputs.reshape(batch_size * width_out * height_out, 2)
        labels = labels.reshape(batch_size * width_out * height_out)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        return loss

    def save_module(self, name_else=""):
        import datetime
        module_save_dir = self.module_save_dir
        filename = 'unet-' + datetime.datetime.now().strftime('%Y%m%d%H%M%S') + \
            name_else + '.pth'
        torch.save(self.unet.state_dict(), module_save_dir + filename)
        pass


In [7]:
hyper_parameters = {
            "batch_size": 2,
            "learning_rate": 1e-4,
            "threads": 0,
            "epochs": 10000,
            "epoch_lapse": 200,
            "epoch_save": 400,
        }

cell_dir = "C:/me/dataset1/train/"
mask_dir = "C:/me/dataset1/train_GT/SEG"
module_save_dir = "C:/me/test/save/"
tmp_dir = "C:/me/test/_tmp/"

valid_rate = 0.1
use_cuda = True
trainer = Trainer()

In [8]:
trainer.setup(cell_dir=cell_dir,
              mask_dir=mask_dir,
              module_save_dir=module_save_dir,
              tmp_dir=tmp_dir,
              valid_rate=valid_rate,
              hyper_params=hyper_parameters,
              use_cuda=use_cuda)

RuntimeError: cuda runtime error (999) : unknown error at ..\aten\src\THC\THCGeneral.cpp:47

In [7]:
trainer.train()

epoch 0
epoch 1
epoch 2
epoch 3
epoch 4
epoch 5
epoch 6
epoch 7
epoch 8
epoch 9
epoch 10
epoch 11
epoch 12
epoch 13
epoch 14
epoch 15
epoch 16
epoch 17
epoch 18
epoch 19
epoch 20
epoch 21
epoch 22
epoch 23
epoch 24
epoch 25
epoch 26
epoch 27
epoch 28
epoch 29
epoch 30
epoch 31
epoch 32
epoch 33
epoch 34
epoch 35
epoch 36
epoch 37
epoch 38
epoch 39
epoch 40
epoch 41
epoch 42
epoch 43
epoch 44
epoch 45
epoch 46
epoch 47
epoch 48
epoch 49
epoch 50
epoch 51
epoch 52
epoch 53
epoch 54
epoch 55
epoch 56
epoch 57
epoch 58
epoch 59
epoch 60
epoch 61
epoch 62
epoch 63
epoch 64
epoch 65
epoch 66
epoch 67
epoch 68
epoch 69
epoch 70
epoch 71
epoch 72
epoch 73
epoch 74
epoch 75
epoch 76
epoch 77
epoch 78
epoch 79
epoch 80
epoch 81
epoch 82
epoch 83
epoch 84
epoch 85
epoch 86
epoch 87
epoch 88
epoch 89
epoch 90
epoch 91
epoch 92
epoch 93
epoch 94
epoch 95
epoch 96
epoch 97
epoch 98
epoch 99
epoch 100
epoch 101
epoch 102
epoch 103
epoch 104
epoch 105
epoch 106
epoch 107
epoch 108
epoch 109
epoch 110


RuntimeError: CUDA error: the launch timed out and was terminated

In [8]:
import datetime
trainer.save_module()

In [10]:
torch.cuda.is_available()

False