In [1]:
from __future__ import print_function
from __future__ import division

import copy
import glob
import os
from os import listdir,mkdir,rmdir
from os.path import join,isdir,isfile
from PIL import Image
import time

import cv2
import numpy as np
from scipy.ndimage.morphology import binary_dilation,binary_erosion
from skimage import exposure

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
import torchvision
from torchvision import datasets, models, transforms,utils
from torchvision.transforms import functional as func
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation.fcn import FCNHead

%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
# Global plotting defaults for every figure below: large canvas and large fonts.
plt.rcParams.update({'figure.figsize': (16, 9), 'font.size': 22})

In [3]:
#SWITCHES

# Dataset root; set the CELLSEG_DATA_ROOT environment variable to run on another machine.
data_root = os.environ.get('CELLSEG_DATA_ROOT', '/home/darvin/Data/cellSegmentation')

path_tr = os.path.join(data_root, 'train_combine')
# NOTE(review): validation and test point at the same folder, so checkpoint selection
# (best validation score in train_model) and threshold tuning are done on the test
# set, which biases the reported test scores upwards. Point path_va at a held-out
# split if one exists.
path_va = os.path.join(data_root, 'test')
path_te = os.path.join(data_root, 'test')

num_classes = 2

batch_size = 4

num_epochs = 100

# True -> freeze all pretrained parameters (see set_parameter_requires_grad).
feature_extract = False

In [4]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of `model` when `feature_extracting` is truthy.

    When `feature_extracting` is falsy the model is left untouched. Returns None.
    """
    if not feature_extracting:
        return
    for p in model.parameters():
        p.requires_grad = False

In [5]:
# Initialize Model: DeepLabV3 with a ResNet-50 backbone, adapted to single-channel
# input and `num_classes` output classes.
# (Alternatives tried: deeplabv3_resnet101(pretrained=True) with the same conv1 /
#  DeepLabHead swap, and fcn_resnet50(pretrained=False) with FCNHead(2048, num_classes).)
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False)

# Replace the stem convolution with one that accepts a single input channel.
model.backbone.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7,
                                 stride=2, padding=3, bias=False)
# Fresh segmentation head with 2048 input channels and `num_classes` outputs.
model.classifier = DeepLabHead(2048, num_classes)

set_parameter_requires_grad(model, feature_extracting=feature_extract)

print(model)

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [6]:
# Augmentation Classes

class RandomRotate(object):
    """Augmentation: rotate image and mask together by k*90 degrees, k drawn from {0, 1, 2, 3}.

    `sample` is a dict with a 3-D array 'X' (first axis iterated over, e.g. channels)
    and a 2-D array 'Y'. Both arrays are overwritten in place and returned in a new dict.
    Writing the rotation back in place requires square planes.
    """
    def __call__(self, sample):
        image, mask = sample['X'], sample['Y']
        quarter_turns = np.random.choice(4)
        for plane in image:
            plane[...] = np.rot90(plane, k=quarter_turns, axes=(0, 1))
        mask[...] = np.rot90(mask, k=quarter_turns, axes=(0, 1))
        return {'X': image, 'Y': mask}
    
class RandomShift(object):
    """Augmentation: translate image and mask together by up to +/- max_shift pixels per axis.

    Regions exposed by the shift are zero-filled. 'X' (3-D, shifted along its last two
    axes) and 'Y' (2-D) are overwritten in place and returned in a new dict.
    `max_shift=0` is a no-op.
    """
    def __init__(self, max_shift=32):
        self.max_shift = int(max_shift)

    def __call__(self, sample):
        X, Y = sample['X'], sample['Y']
        s = self.max_shift
        h, w = X.shape[1:]
        X_pad = np.pad(X, ((0, 0), (s, s), (s, s)), mode='constant')
        Y_pad = np.pad(Y, s, mode='constant')
        # randint's upper bound is exclusive: +1 so the crop offset covers [0, 2*s],
        # i.e. symmetric shifts -s..+s (and s == 0 does not raise).
        top = np.random.randint(0, 2 * s + 1)
        left = np.random.randint(0, 2 * s + 1)
        X[:, :, :] = X_pad[:, top:(top + h), left:(left + w)]
        Y[:, :] = Y_pad[top:(top + h), left:(left + w)]
        return {'X': X, 'Y': Y}

class RandomFlip(object):
    """Augmentation: mirror image and mask along the last axis with probability `flip_prob`.

    'X' (3-D) and 'Y' (2-D) are overwritten in place and returned in a new dict.
    """
    def __init__(self, flip_prob=0.5):
        self.flip_prob = flip_prob

    def __call__(self, sample):
        X, Y = sample['X'], sample['Y']
        # Flip when the uniform draw falls below flip_prob, so P(flip) == flip_prob.
        if np.random.rand() < self.flip_prob:
            X[:, :, :] = X[:, :, ::-1].copy()
            Y[:, :] = Y[:, ::-1].copy()
        return {'X': X, 'Y': Y}
    
class ToTensor(object):
    """Convert the numpy arrays of a sample to tensors: 'X' -> float32, 'Y' -> int64."""
    def __call__(self, sample):
        image = torch.from_numpy(sample['X']).float()
        mask = torch.from_numpy(sample['Y']).long()
        return {'X': image, 'Y': mask}


In [None]:
class cellSegmentationDataset(Dataset):
    """Paired image / mask dataset read from two folders with identical file names.

    Every file in `path_imgs` (sorted by name) is one sample; its mask is the file
    with the same name in `path_segs`. Each item is a dict:
      'X': float32 array of shape (1, H, W) -- channel 0 of the image scaled to [0, 1]
      'Y': int array of shape (H, W)        -- 1 where channel 0 of the mask is > 0, else 0
    `transform`, if given, is applied to that dict.
    """
    def __init__(self, path_imgs, path_segs, transform=None):
        self.path_imgs = path_imgs
        self.path_segs = path_segs
        self.transform = transform
        self.list_imgs = sorted(listdir(path_imgs))

    def __len__(self):
        return len(self.list_imgs)

    @staticmethod
    def _read(path):
        """Read an image with OpenCV, raising instead of returning None on failure."""
        arr = cv2.imread(path)
        if arr is None:
            raise FileNotFoundError('could not read image: {}'.format(path))
        return arr

    def __getitem__(self, idx):
        name = self.list_imgs[idx]

        img = self._read(join(self.path_imgs, name))[:, :, 0]
        img = img.astype(np.float32) / 255
        img = img.reshape([1, img.shape[0], img.shape[1]])

        seg = self._read(join(self.path_segs, name))
        seg = (seg[:, :, 0] > 0) + 0

        sample = {'X': img, 'Y': seg}
        if self.transform:
            sample = self.transform(sample)

        return sample

In [None]:
# Datasets: rotate/flip/shift augmentation for training only; val/test are just tensorised.
# NOTE(review): the augmentations draw from np.random inside DataLoader worker
# processes; depending on the installed PyTorch version NumPy may not be reseeded
# per worker, which repeats augmentations -- consider a worker_init_fn. Verify.
data_tr = cellSegmentationDataset(
    path_imgs=join(path_tr, 'img'),
    path_segs=join(path_tr, 'seg'),
    transform=transforms.Compose([RandomRotate(), RandomFlip(), RandomShift(), ToTensor()]),
)
data_va = cellSegmentationDataset(
    path_imgs=join(path_va, 'img'),
    path_segs=join(path_va, 'seg'),
    transform=transforms.Compose([ToTensor()]),
)
data_te = cellSegmentationDataset(
    path_imgs=join(path_te, 'img'),
    path_segs=join(path_te, 'seg'),
    transform=transforms.Compose([ToTensor()]),
)

# Only the training loader shuffles and drops the last incomplete batch.
loader_tr = DataLoader(data_tr, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
loader_va = DataLoader(data_va, batch_size=batch_size, num_workers=4, shuffle=False)
loader_te = DataLoader(data_te, batch_size=batch_size, num_workers=4, shuffle=False)

dataloaders_tr = {'train': loader_tr, 'val': loader_va}

In [None]:
# Send the model to GPU
model = model.cuda()

# Collect the parameters that still require gradients and list their names.
# When fine-tuning (feature_extract False) the optimizer receives every parameter;
# in feature-extract mode it receives only the ones that still require grad.
trainable_named = [(pname, p) for pname, p in model.named_parameters() if p.requires_grad]
print("Params to learn:")
for pname, _ in trainable_named:
    print("\t", pname)
params_to_update = [p for _, p in trainable_named] if feature_extract else model.parameters()

optimizer = optim.Adam(params_to_update, lr=0.001, weight_decay=0.001)
# Alternative tried: optim.SGD(params_to_update, lr=0.0001, momentum=0.9, weight_decay=0.001)

In [None]:
def _boundary_weights(labels, extra=9):
    """Per-pixel loss weights for a (batch, H, W) label array: 1 everywhere, 1 + `extra` on mask boundaries."""
    weights = np.ones(labels.shape, dtype=np.float32)
    for ii in range(labels.shape[0]):
        fg = labels[ii] > 0
        # Boundary = dilation minus erosion of the foreground mask (a thin band around edges).
        boundary = binary_dilation(fg).astype(np.float32) - binary_erosion(fg).astype(np.float32)
        weights[ii] += boundary * extra
    return weights


def _dice_times_batch(pred, target, eps=1e-6):
    """Dice score pooled over the batch, times the batch size (summed later and divided by #samples).

    `eps` keeps the score finite when prediction and target are both empty
    (that case scores 1, i.e. perfect agreement) instead of producing NaN.
    """
    pred = pred.float()
    target = target.float()
    intersection = torch.sum(pred * target)
    denominator = torch.sum(pred) + torch.sum(target)
    return ((2 * intersection + eps) / (denominator + eps)).item() * target.size(0)


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,
                device='cuda', save_path='/home/darvin/Models/cellSegmentation.pth',
                accum_steps=4):
    """Train `model` with boundary-weighted loss and gradient accumulation; keep the best weights.

    Args:
        model: segmentation network whose forward returns a dict with logits under 'out'.
        dataloaders: dict with 'train' and 'val' iterables of {'X': images, 'Y': labels}
            batches; each must expose `.dataset`.
        criterion: per-pixel loss called as criterion(logits, Y); it must return an
            unreduced (batch, H, W) tensor, because it is multiplied by the boundary weights.
        optimizer: optimizer over the parameters to train.
        num_epochs: number of train+val epochs.
        device: where batches and loss weights are placed (default 'cuda').
        save_path: file the best state_dict is written to whenever validation improves.
        accum_steps: number of batches whose gradients are accumulated per optimizer step.
            NOTE(review): the loss is not divided by accum_steps, so the accumulated
            gradient is a sum rather than a mean of the batch gradients.

    Returns:
        (model with best validation weights loaded,
         [val_iou_history, val_loss_history, tr_iou_history, tr_loss_history]) as float lists.

    The score reported as "IOU" is the Dice coefficient 2|P∩Y| / (|P| + |Y|) of the
    argmax prediction, pooled per batch and averaged over the samples actually seen.
    """
    since = time.time()

    tr_iou_history = []
    tr_loss_history = []
    val_iou_history = []
    val_loss_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_iou = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train() for training, eval() for validation

            running_loss = 0.0
            running_iou = 0.0
            n_seen = 0  # samples actually processed (drop_last may skip some)
            counter = 0

            optimizer.zero_grad()
            for sample_batch in dataloaders[phase]:
                X = sample_batch['X'].to(device)
                Y = sample_batch['Y'].to(device)
                batch_n = X.size(0)

                # Track autograd history only while training.
                with torch.set_grad_enabled(is_train):
                    logits = model(X)['out']
                    pixel_loss = criterion(logits, Y)
                    weights = torch.from_numpy(_boundary_weights(Y.detach().cpu().numpy())).to(device)
                    loss = torch.mean(pixel_loss * weights)

                    if is_train:
                        loss.backward()
                        counter += 1
                        if counter % accum_steps == 0:
                            optimizer.step()
                            optimizer.zero_grad()

                running_loss += loss.item() * batch_n
                pred = torch.argmax(logits, dim=1)
                running_iou += _dice_times_batch(pred, Y)
                n_seen += batch_n

            # Apply any gradients left over from an incomplete accumulation window.
            if is_train and counter % accum_steps != 0:
                optimizer.step()
                optimizer.zero_grad()

            denom = max(n_seen, 1)
            epoch_loss = running_loss / denom
            epoch_iou = running_iou / denom

            print('{} Loss: {:.4f} IOU: {:.4f}'.format(phase, epoch_loss, epoch_iou))

            if phase == 'val':
                # deep copy the model (and checkpoint it) whenever validation improves
                if epoch_iou > best_iou:
                    best_iou = epoch_iou
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(model.state_dict(), save_path)
                val_loss_history.append(epoch_loss)
                val_iou_history.append(epoch_iou)
            else:
                tr_iou_history.append(epoch_iou)
                tr_loss_history.append(epoch_loss)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val IOU: {:.4f}'.format(best_iou))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, [val_iou_history, val_loss_history, tr_iou_history, tr_loss_history]

In [None]:
# Setup the loss fxn: unreduced per-pixel cross-entropy, so train_model can apply
# its boundary weights before averaging.
# (Alternative tried: class-weighted nn.CrossEntropyLoss(torch.Tensor([1,3]).cuda()).)
criterion = nn.CrossEntropyLoss(reduction='none')

# Train and evaluate
model, history = train_model(model=model, dataloaders=dataloaders_tr, criterion=criterion,
                             optimizer=optimizer, num_epochs=num_epochs)

In [None]:
# Reload the best checkpoint saved during training.
checkpoint_path = '/home/darvin/Models/cellSegmentation.pth'
best_state_dict = torch.load(checkpoint_path)
model.load_state_dict(best_state_dict)

In [None]:
def apply_model(model, path_img):
    """Run `model` on one image file and return (image, class-1 probability map).

    The image is read with OpenCV, only channel 0 is kept, and it is scaled to [0, 1].
    Returns the (H, W) float32 input and the (H, W) softmax probability of class 1.
    Raises FileNotFoundError if the image cannot be read.
    """
    model = model.cuda().eval()
    raw = cv2.imread(path_img)
    if raw is None:
        raise FileNotFoundError('could not read image: {}'.format(path_img))
    img = raw[:, :, 0].astype(np.float32) / 255
    x = torch.from_numpy(img.reshape([1, img.shape[0], img.shape[1]])).float().unsqueeze(0).cuda()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        seg = F.softmax(model(x)['out'], dim=1)
    seg = seg[0, :, :, :].cpu().numpy()
    return img, seg[1, :, :]

def calculate_ious(seg, gt, ts):
    """Mean of foreground and background IoU between `seg > t` and `gt > 0`, for each t in `ts`.

    Works for any image size (the pixel count comes from `gt`). An IoU whose union
    is empty (both masks empty for that class) counts as 1.
    Returns a float array with one entry per threshold.
    """
    ts = np.asarray(ts, dtype=float)
    ious = np.zeros(ts.shape, dtype=float)
    truth = np.asarray(gt) > 0
    n_pix = truth.size
    for ii, t in enumerate(ts):
        pred = np.asarray(seg) > t
        intersection = np.sum(pred & truth)
        union = np.sum(pred | truth)
        fg_iou = float(intersection) / float(union) if union else 1.0
        # Background: its intersection is N - fg union, its union is N - fg intersection.
        bg_union = n_pix - intersection
        bg_iou = float(n_pix - union) / float(bg_union) if bg_union else 1.0
        ious[ii] = fg_iou + bg_iou
    return ious / 2

def save_visualization(model, path_img, path_gt, path_save):
    """Write an overlay of image, ground truth and thresholded prediction to `path_save`.

    All three output channels get 0.8 * image (channel 0 of the input), channel 1 adds
    the ground-truth mask and channels 0 and 2 add the prediction (class-1 probability
    > 0.5). The foreground IoU is drawn as text; it is 1.0 when both masks are empty.
    Raises FileNotFoundError if either input cannot be read.
    """
    model = model.cuda().eval()
    raw = cv2.imread(path_img)
    raw_gt = cv2.imread(path_gt)
    if raw is None or raw_gt is None:
        raise FileNotFoundError('could not read {} or {}'.format(path_img, path_gt))
    img = raw.astype(np.float32) / 255

    vis = np.zeros_like(img)
    gt = (raw_gt[:, :, 0] > 0) + 0

    img = img[:, :, 0]
    x = torch.from_numpy(img.reshape([1, img.shape[0], img.shape[1]])).float().unsqueeze(0).cuda()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        seg = F.softmax(model(x)['out'], dim=1)
    seg = seg[0, :, :, :].cpu().numpy()

    pred = seg[1, :, :] > 0.5
    truth = gt > 0
    union = np.sum(pred | truth)
    iou = float(np.sum(pred & truth)) / float(union) if union else 1.0

    for ii in range(3):
        vis[:, :, ii] += img * 0.8
    vis[:, :, 1] += gt * 0.5
    vis[:, :, 0] += (pred + 0) * 0.5
    vis[:, :, 2] += (pred + 0) * 0.5
    vis = np.clip(vis, 0, 1)
    vis = cv2.putText((255*vis).astype(np.uint8),'IoU: ' + str(iou),org=(10,30), fontFace = cv2.FONT_HERSHEY_SIMPLEX ,fontScale=1,color=(0,255,255), thickness=2)
    cv2.imwrite(path_save, vis)
    
def normalize_img(img):
    """Return a copy of `img` min-max rescaled so its minimum is 0 and its maximum is ~1.

    The input array is not modified. Float inputs keep their dtype; integer inputs
    are promoted to floating point (float32 for 8/16-bit types).
    """
    img = np.asarray(img)
    out_dtype = np.result_type(img.dtype, np.float32)
    shifted = img.astype(out_dtype) - img.min()
    # The small epsilon avoids dividing by zero for constant images.
    return shifted / (shifted.max() + 1e-6)

def preprocess_img(path_img):
    """Load an image with PIL, min-max normalise it, apply CLAHE (clip 0.03) and resize to 512x512."""
    raw = np.array(Image.open(path_img)).astype(np.float32)
    equalized = exposure.equalize_adapthist(normalize_img(raw), clip_limit=0.03)
    return cv2.resize(equalized, (512, 512))

def do_full_inference(model, path_img, path_save=None):
    """Preprocess an image file (see preprocess_img), run `model`, return (image, class-1 probability).

    If the preprocessed image has 3 dimensions only channel 0 is used. When
    `path_save` is given, 255 * probability is also written there with cv2.imwrite.
    """
    model = model.cuda().eval()

    img = preprocess_img(path_img)
    if len(img.shape) == 3:
        img = img[:, :, 0]

    x = torch.from_numpy(img.reshape([1, img.shape[0], img.shape[1]])).float().unsqueeze(0).cuda()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        seg = F.softmax(model(x)['out'], dim=1)
    seg = seg[0, :, :, :].cpu().numpy()

    if path_save:
        # NOTE(review): a float32 array is passed to cv2.imwrite; how it is stored depends
        # on the file format -- an explicit uint8 cast may have been intended. Confirm.
        cv2.imwrite(path_save, (255*seg[1, :, :]).astype(np.float32))

    return img, seg[1, :, :]

In [None]:
# Inspect one test image: input | thresholded prediction | ground truth.
# The middle panel's title is the mean of foreground and background IoU.
ind = 30

path_testing = '/home/darvin/Data/cellSegmentation/test/'
list_imgs = sorted(listdir(join(path_testing, 'img')))
path_img = join(path_testing, 'img', list_imgs[ind])
path_gt = join(path_testing, 'seg', list_imgs[ind])
img, seg = apply_model(model, path_img)
gt = cv2.imread(path_gt)[:, :, 0]

fig, ax = plt.subplots(1, 3)
for axis, panel in zip(ax, (img, seg > 0.5, gt)):
    axis.imshow(panel, cmap='bone')

fg_pred, fg_true = seg > 0.5, gt > 0
bg_pred, bg_true = seg < 0.5, gt == 0
iou1 = float(np.sum(fg_pred * fg_true)) / float(np.sum((fg_pred + fg_true) > 0))
iou2 = float(np.sum(bg_pred * bg_true)) / float(np.sum((bg_pred + bg_true) > 0))
ax[1].set_title((iou1 + iou2) / 2)

In [None]:
# Left: foreground intersection of prediction and ground truth; right: their union.
fig, ax = plt.subplots(1, 2)
pred_fg, true_fg = seg > 0.5, gt > 0
ax[0].imshow(pred_fg * true_fg)
ax[1].imshow((pred_fg + true_fg) > 0)

In [None]:
# Sweep the probability threshold and average the mean (fg/bg) IoU over all test images.
# NOTE(review): this re-assigns the global path_va set in the config cell.
ts = np.arange(0.01, 1.0, 0.01)
ious = np.zeros_like(ts)
path_va = '/home/darvin/Data/cellSegmentation/test/'
list_imgs = sorted(listdir(join(path_va, 'img')))
n_imgs = len(list_imgs)
for name_img in list_imgs:
    path_img, path_gt = join(path_va, 'img', name_img), join(path_va, 'seg', name_img)
    img, seg = apply_model(model, path_img)
    gt = cv2.imread(path_gt)[:, :, 0]
    ious += calculate_ious(seg, gt, ts) / n_imgs

In [None]:
# Mean (foreground/background) IoU as a function of the threshold on the class-1 probability.
plt.xlabel('threshold on class-1 probability')
plt.ylabel('mean IoU (fg/bg)')
plt.plot(ts, ious)

In [None]:
# Threshold with the highest average IoU in the sweep above.
ts[ious.argmax()]

In [None]:
path_va = '/home/darvin/Data/cellSegmentation/test/'
path_save = join(path_va, 'vis')
if not isdir(path_save):
    mkdir(path_save)
list_imgs = sorted(listdir(join(path_va,'img')))
for name_img in list_imgs:
    path_img = join(path_va,'img',name_img)
    path_gt = join(path_va,'seg',name_img)
    save_visualization(model,path_img,path_gt,join(path_save,name_img))

In [None]:
def prepare_array(path_img, path_seg):
    """Load an image/mask pair and return 512x512 uint8 arrays.

    Image: min-max normalised, CLAHE-equalised (clip 0.03), resized, scaled to 0..255.
    Mask: resized, then 255 wherever the resized value is > 0, else 0.
    """
    img, seg = (np.array(Image.open(p)).astype(np.float32) for p in (path_img, path_seg))

    img = exposure.equalize_adapthist(normalize_img(img), clip_limit=0.03)
    img = cv2.resize(img, (512, 512))
    # NOTE(review): the mask is resized with cv2's default interpolation and then
    # thresholded at > 0, which can grow foreground at the edges; nearest-neighbour
    # may be intended. Confirm.
    seg = cv2.resize(seg, (512, 512)) > 0

    return (255 * img).astype(np.uint8), (255 * seg).astype(np.uint8)

def read_img(path_img):
    """Load an image with PIL, min-max normalise it, keep channel 0 if 3-D, and resize to 512x512."""
    arr = normalize_img(np.array(Image.open(path_img)).astype(np.float32))
    if arr.ndim == 3:
        arr = arr[:, :, 0]
    return cv2.resize(arr, (512, 512))

def calculate_iou(seg, gt):
    """Foreground IoU between `seg > 0.5` and `gt > 0`; 1.0 when both masks are empty."""
    pred = np.asarray(seg) > 0.5
    truth = np.asarray(gt) > 0
    union = np.sum(pred | truth)
    if union == 0:
        return 1.0
    return float(np.sum(pred & truth)) / float(union)

In [None]:
path_imgs = "/home/darvin/Data/cellSegmentation/rawDataFromHumanExperts"
path_save = "/home/darvin/Data/cellSegmentation/humanVisualization"
list_imgs = sorted(listdir(path_imgs))
# Create the output folder if it is missing; every visualization is written into it.
if not isdir(path_save):
    mkdir(path_save)

# name_human -> ([IoU of human annotation vs. 'Nul' mask], [IoU of model vs. 'Nul' mask])
dict_hum2mac = {}

for name_img in list_imgs:
    parts = name_img.split('_')
    if parts[0] != 'DIC':
        continue
    name_root = '_'+'_'.join(parts[1:])
    path_img = join(path_imgs, 'DIC' + name_root)
    path_seg = join(path_imgs, 'Nul' + name_root)
    # Find which annotator labelled this image (first match wins).
    for name_human in ['James', 'LiAng', 'Sheng', 'Yitong']:
        path_hum = join(path_imgs, name_human + name_root)
        if isfile(path_hum):
            break
    else:
        # No annotator file exists: skip, rather than reading a missing file and
        # attributing the image to the last name in the list.
        print('No human annotation found for {}, skipping'.format(name_img))
        continue
    if name_human not in dict_hum2mac:
        dict_hum2mac[name_human] = ([],[])
    img = read_img(path_img)
    gt  = read_img(path_seg)
    hum = read_img(path_hum)
    _,pred = do_full_inference(model,path_img,path_save=None)

    path_save_img = join(path_save, name_root[1:])

    # Left half: image + 'Nul' mask (channel 1) + model prediction (channel 2).
    # Right half: image + 'Nul' mask (channel 1) + human annotation (channel 0).
    vis = np.zeros((512,1024,3))
    for ii in range(3):
        for jj in range(2):
            vis[:,(jj*512):(jj*512+512),ii] += img * 0.8
            if ii == 1:
                vis[:,(jj*512):(jj*512+512),ii] += gt * 0.5
    vis[:,512:1024,0] += hum * 0.5
    vis[:,:512,2] += ((pred > 0.5) + 0) * 0.5
    vis = np.clip(vis,0,1)

    iou_machine = calculate_iou(pred,gt)
    iou_human = calculate_iou(hum,gt)
    dict_hum2mac[name_human][0].append(iou_human)
    dict_hum2mac[name_human][1].append(iou_machine)

    vis = (255*vis).astype(np.uint8)
    vis = cv2.putText(vis,'IoU: ' + str(iou_machine),org=(10,30), fontFace = cv2.FONT_HERSHEY_SIMPLEX ,fontScale=1,color=(0,255,255), thickness=2)
    vis = cv2.putText(vis,'IoU: ' + str(iou_human),org=(522,30), fontFace = cv2.FONT_HERSHEY_SIMPLEX ,fontScale=1,color=(255,255,0), thickness=2)
    vis = cv2.putText(vis, name_human,org=(522,80), fontFace = cv2.FONT_HERSHEY_SIMPLEX ,fontScale=1,color=(255,255,0), thickness=2)
    if not cv2.imwrite(path_save_img,vis):
        print('Failed to write {}'.format(path_save_img))

In [None]:
# One colour per annotator; each point is one image (x: human IoU, y: model IoU).
dict_hum2col = {'Yitong':'r', 'Sheng':'g', 'LiAng':'b', 'James':'k'}

for name_human, col in dict_hum2col.items():
    # Annotators who labelled no images have no entry; skip instead of raising KeyError.
    if name_human not in dict_hum2mac:
        continue
    x, y = dict_hum2mac[name_human]
    plt.scatter(x, y, c=col, label=name_human)
plt.xlabel('human IoU')
plt.ylabel('machine IoU')
plt.legend()
# Points above the diagonal: the model agrees with the reference better than the human.
plt.plot([0,1],[0,1])