In [5]:
import os, sys, pdb, shutil, random, math, cv2, datetime, cPickle
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
from collections import OrderedDict, namedtuple
from PIL import Image

from dataloaders.Classification_Image import OCT_image_classification, OCT_classification_persample
from dataloaders.Segmentation_Image import OCT_image_segmentation, OCT_segmentation_persample, natural_keys
import dataloaders.Segmentation_test_transforms as test_tr
from networks.segmentation.deeplab_xception import DeepLabv3_plus_xception
from networks.segmentation.deeplab_resnet import DeepLabv3_plus_resnet

import dataloaders.Image_transforms as Image_tr
from networks.classification.ResNet_original import ResNet34_original, ResNet50_original, ResNet18_original
from networks.classification.DenseNet_original import DenseNet121_original

from dataloaders.Image_utils import decode_segmap, decode_segmap_sequence
from utils import aic_fundus_lesion_segmentation
from tqdm import tqdm

def save_pickle(obj, save_path): 
    parent_path = os.path.dirname(save_path)  # get parent path
    if not os.path.exists(parent_path):
        os.makedirs(parent_path)
    cPickle.dump(obj, open(save_path, "wb"), True)

def load_pickle(pickle_path):
    with open(pickle_path, 'rb') as fo:
        pickle_dict = cPickle.load(fo)
    return pickle_dict    
    
def load_model(network, backbone, checkpoint_path, device, gpus, os, n_classes = 4):
    model = globals()[network](nInputChannels=1, n_classes=n_classes, os=os, backbone=backbone, checkpoint=None, ignore_prefixs = [])
    print("Load %s"%(checkpoint_path))
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    if "state_dict" in checkpoint.keys():
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace("module.", "") # remove `module.`
        new_state_dict[name] = v
    
    model.load_state_dict(new_state_dict)
    if len(gpus) > 1:
        model = nn.DataParallel(model, gpus)
    model.to(device)
    return model

def get_labels(label_sample_path, num_samples=128):
    # sort the image dir in numerical ascend order
    label_images = []
    image_names = os.listdir(label_sample_path)
    image_names.sort(key=natural_keys)
    for i, image_name in enumerate(image_names):
        label_image = cv2.imread(os.path.join(label_sample_path, image_name))[:,:,0]
        label_images.append(label_image)
    return label_images

def aug_batch_inputs(batch_inputs, scale, is_flip):
    """augment a batch of image"""
    assert len(batch_inputs.shape) == 4
    batch_size, original_w, original_h = batch_inputs.shape[0], batch_inputs.shape[-1], batch_inputs.shape[-2]
    target_w, target_h = int(original_w*scale), int(original_h*scale)
    batch_images = np.squeeze(batch_inputs, 1) # remove the channel dimension
    auged_images = np.empty((batch_size, target_h, target_w))
    for image_idx in range(batch_size):
        image = batch_images[image_idx]
        image = cv2.resize(image, (target_w, target_h), interpolation = cv2.INTER_LINEAR)
        if is_flip:
            image = cv2.flip(image, 1)
        auged_images[image_idx] = image
    return auged_images
    
def aug_restore(batch_outputs, target_w, target_h, is_flip):
    """restore the augmented output to target width and height"""
    assert len(batch_outputs.shape) == 4
    batch_size, num_channel = batch_outputs.shape[0], batch_outputs.shape[1]
    restored_outputs = np.empty((batch_size, num_channel, target_h, target_w))
    for image_idx in range(batch_size):
        image = np.transpose(batch_outputs[image_idx], [1, 2, 0])
        image = cv2.resize(image, (target_w, target_h), interpolation = cv2.INTER_LINEAR)
        if is_flip:
            image = cv2.flip(image, 1)
        restored_outputs[image_idx] = np.transpose(image, [2, 0, 1])
    return restored_outputs

def seg_predict_ensemble(models, inputs, device, test_scales, is_flip):
    with torch.no_grad():
        original_w, original_h = inputs.size(-1), inputs.size(-2)
        outputs = []
        to_flip = [True, False] if is_flip else [False]
        inputs_arr = inputs.detach().cpu().numpy()
        for model in models:
            model.eval()
            for scale in test_scales:
                for flip in to_flip:
                    aug_images = aug_batch_inputs(inputs_arr, scale, flip)
                    aug_images = np.expand_dims(aug_images, 1)# add the channel dimension
                    aug_images = torch.from_numpy(aug_images).float().to(device)
                    aug_outputs = model(aug_images)
                    aug_outputs = F.softmax(aug_outputs, dim=1)
                    restored_outputs = aug_restore(aug_outputs.detach().cpu().numpy(), original_w, original_h, flip)
                    outputs.append(restored_outputs)
        ensemble = np.mean(np.stack(outputs), 0)
    return ensemble
        
def seg_main(models, dataloader, device, test_scales, is_flip):
    samples, sample_predictions = [], []
    for inputs in dataloader:
        output_ensemble = seg_predict_ensemble(models, inputs, device, test_scales, is_flip)
        predictions = np.argmax(output_ensemble, 1)
        samples.append(inputs.cpu().numpy())
        sample_predictions.append(predictions)
    return np.concatenate(samples, 0), np.concatenate(sample_predictions, 0)

def write_disk(target_root, sample_names, sample_predictions):
    for i in range(len(sample_names)):
        target_path = os.path.join(target_root, sample_names[i]+"_volumes.npy")
        np.save(target_path, sample_predictions[i].astype("uint8")) # save as uint8 to save space


ModuleNotFoundError: No module named 'cPickle'

In [6]:
def get_classification_models(device, gpus):
    def load_classification_models(model_config, num_classes, device, gpus):
        model = globals()[model_config.network](model_config.net_config, num_classes)
        print("Load %s"%(model_config.checkpoint))
        model.load_state_dict(torch.load(model_config.checkpoint))
        if len(gpus) > 1:
            model = nn.DataParallel(model, gpus)
        model.to(device)
        return model
    
    model_config = namedtuple("Model", ["network", "checkpoint", "net_config"])
    model_configs = [
                 model_config("ResNet50_original", "checkpoint/multiple_label/ResNet50_original/aug_multipleLabel/epoch102.pth", None),
                 model_config("DenseNet121_original", "checkpoint/multiple_label/DenseNet121_original/aug_multipleLabel/epoch56.pth", None),
                model_config("ResNet34_original", "checkpoint/multiple_label/ResNet34_original/aug_multipleLabel3/epoch2.pth", None),
                model_config("ResNet34_original", "checkpoint/multiple_label/ResNet34_original/aug_multipleLabel2/epoch1.pth", None),
                model_config("ResNet34_original", "checkpoint/multiple_label/ResNet34_original/aug_multipleLabel/epoch0.pth", None),
                
                model_config("ResNet18_original", "checkpoint/multiple_label/ResNet18_original/aug_multipleLabel/epoch8.pth", None),
                model_config("ResNet18_original", "checkpoint/multiple_label/ResNet18_original/aug_multipleLabel2/epoch162.pth", None),
                model_config("ResNet18_original", "checkpoint/multiple_label/ResNet18_original/aug_multipleLabel3/epoch141.pth", None),
                
                model_config("ResNet50_original", "checkpoint/multiple_label/ResNet50_original/aug_multipleLabel2/epoch17.pth", None),
                model_config("ResNet34_original", "checkpoint/multiple_label/ResNet34_original/aug_multipleLabel4/epoch29.pth", None),
                model_config("ResNet34_original", "checkpoint/multiple_label/ResNet34_original/aug_multipleLabel5/epoch20.pth", None),
                ]
    return [load_classification_models(model_configObj, 3, device, gpus) for model_configObj in model_configs]

def classification_prediction(inputs, models, device):
    with torch.no_grad():
        inputs = inputs.float()
        inputs = inputs.to(device)
        softmaxs = []
        for model in models:
            model.eval()
            output = model(inputs)
            softmax = torch.sigmoid(output).detach().cpu().numpy()
            softmaxs.append(softmax)
        ensemble = np.mean(softmaxs, 0)
    return ensemble


In [7]:
def count_abnormal(image_arr, targets):
    counts = [np.count_nonzero(image_arr == target) for target in targets]
    return counts

def divide_blocks(image_predictions, target_pixel, num_threshold):
    """divide the prediction of images in an sample by counting the number of target pixel"""
    blocks, blocks_numpixel = [], []
    start_idx, end_idx = -1, -1
    for i, prediction in enumerate(image_predictions):
        num_target_pixel = np.count_nonzero(prediction == target_pixel)
        if num_target_pixel >= num_threshold:
            if start_idx == -1: # start a new block
                start_idx = i
                end_idx = -1
        else:
            if start_idx > -1: # the block is supposed to be ended
                end_idx = i # end of a block 
                if (end_idx - 1) == start_idx: # the block contains only one image
                    blocks.append([start_idx])
                else:
                    blocks.append([start_idx, end_idx - 1])
                start_idx = -1 # restart a new block
                end_idx = -1
        if i == (len(image_predictions) - 1): # only the last image is into consideration
            if start_idx > -1 and start_idx == i:
                blocks.append([i])
            elif start_idx > -1 and start_idx != i:
                blocks.append([start_idx, i])
    return blocks

def count_pixels(blocks, image_predictions, target_pixel):
    """count number of target pixel in blocks"""
    blocks_numpixel = []
    for block in blocks:
        assert len(block) in [1, 2]
        if len(block) == 1:
            num_pixels = np.count_nonzero(image_predictions[block[0]] == target_pixel)
            blocks_numpixel.append(num_pixels)
        else:
            start_idx, end_idx = block[0], block[1]
            num_pixels = np.sum([np.count_nonzero(image_predictions[i] == target_pixel) for i in range(start_idx, end_idx + 1)])
            blocks_numpixel.append(num_pixels)
    return blocks_numpixel

def remove_block_connection(classification_index, classification_dataloader, classification_models, device,
                     image_predictions,
                     blocks, block_numpixels,
                     class_type, target_class_type = 0,
                     abnormal_threshold = 0.5, passed_percent = 0.5):
    for block_idx, block in enumerate(blocks):
        images_passed = []
        if len(block) == 1:
            start_idx, end_idx = block[0], block[0]
        else:
            start_idx, end_idx = block[0], block[1]
            
        for inputs_idx, inputs in enumerate(classification_dataloader):
            if inputs_idx >= start_idx and inputs_idx <= end_idx:
                classification_softmax = classification_prediction(inputs, classification_models, device)
                if classification_softmax[0, classification_index] >= abnormal_threshold:
                    images_passed.append(1)
                else:
                    images_passed.append(0)
        
        if float(np.sum(images_passed)) / len(images_passed) < passed_percent: # replace the class type
            if block_idx == np.argmax(block_numpixels):
                print("Segmentation class %d, Total %d blocks, %d-th block has %d valid images, %d abnormal pixels (largest block), ignore !!!"%(class_type, len(blocks),block_idx, (end_idx - start_idx + 1), block_numpixels[block_idx]))
#                 print("Segmentation class %d, Total %d blocks, Remove %d-th block has %d valid images, %d abnormal pixels (largest block) !!!"%(class_type, len(blocks),block_idx, (end_idx - start_idx + 1), block_numpixels[block_idx]))
            else:
                print("Segmentation class %d, Total %d blocks, Remove %d-th block has %d valid images, %d abnormal pixels"%(class_type, len(blocks), block_idx, (end_idx - start_idx + 1), block_numpixels[block_idx]))
                for i in range(start_idx, end_idx + 1):
                    np.place(image_predictions[i], image_predictions[i]==class_type, target_class_type)
    return image_predictions    

def replace_pixel(sample_predictions, source_pixel, target_pixel):
    """replace source pixel with target pixel in sample prediction """
    replaced_predictions = []
    for sample_prediction in sample_predictions:
        np.place(sample_prediction, sample_prediction==source_pixel, target_pixel)
        replaced_predictions.append(sample_prediction)
    return replaced_predictions
    

In [8]:
def vanilla_seg(root_path, seg_models, seg_tr, device, batch_size, num_workers, test_scales, is_flip):
    sample_images, sample_predictions, sample_names = [], [], []
    for sample_name in tqdm(sorted(os.listdir(root_path))):
        sample_names.append(sample_name.replace(".img", ""))
        sample_path = os.path.join(root_path, sample_name)
        seg_dataset = OCT_segmentation_persample(sample_path, transform = seg_tr)
        seg_datasetloader = torch.utils.data.DataLoader(seg_dataset, batch_size = batch_size, shuffle=False, num_workers=num_workers)
        images, predictions = seg_main(seg_models, seg_datasetloader, device, test_scales, is_flip)
        sample_images.append(images)
        sample_predictions.append(predictions)
    return sample_images, sample_predictions, sample_names

def seg_connection(root_path, seg_models, classification_models, seg_tr, classification_tr, 
                   device, seg_batch_size, num_workers, class_included, class_thresholds, abnormal_threshold, passed_percent,
                   test_scales, is_flip):
    sample_images, sample_predictions, sample_names = [], [], []
    for sample_idx, sample_name in enumerate(sorted(os.listdir(root_path))):
        time_start = datetime.datetime.now()
        sample_names.append(sample_name.replace(".img", ""))
        sample_path = os.path.join(root_path, sample_name)
        seg_dataset = OCT_segmentation_persample(sample_path, transform = seg_tr)
        seg_datasetloader = torch.utils.data.DataLoader(seg_dataset, batch_size = seg_batch_size, shuffle=False, num_workers=num_workers)
        images, predictions = seg_main(seg_models, seg_datasetloader, device, test_scales, is_flip)
        
        classification_dataset = OCT_classification_persample(sample_path, classification_tr)
        classification_loader = torch.utils.data.DataLoader(classification_dataset, batch_size = 1,
                                             shuffle=False, num_workers=config.num_workers)
        
        for class_type in class_included:
            classification_index = class_type - 1
            blocks = divide_blocks(predictions, class_type, class_thresholds[class_type])
            block_numpixels = count_pixels(blocks, predictions, class_type)
            predictions = remove_block_connection(classification_index, classification_loader, classification_models, device,
                         predictions,
                         blocks, block_numpixels,
                         class_type, target_class_type = 0,
                         abnormal_threshold=abnormal_threshold, passed_percent=passed_percent)
        time_end = datetime.datetime.now()
        sample_images.append(images)
        sample_predictions.append(predictions)
        print("{}-th sample processed, cost {} seconds----------------------------".format(sample_idx, (time_end-time_start).seconds))
    return sample_images, sample_predictions, sample_names

In [9]:
class Config(object):
    def __init__(self):
        
        self.label_dict = OrderedDict([(0, 0), (255, 1), (191, 2), (128, 3)])
        self.gpus = "0, 1, 2, 3"
        self.os = 16
        self.n_classes = 4
        self.batch_size = 20
        
        self.num_workers = 6

config = Config()

IndentationError: unexpected indent (<ipython-input-9-f86d3b1bdcc2>, line 8)

In [None]:
gpus = map(int, config.gpus.split(","))
device = torch.device("cuda:{}".format(gpus[0]))

# ResNet101, checkpoint/segmentation/DeepLabv3_plus_resnet/aug_512_1024/epoch13.pth
# ResNet50, checkpoint/segmentation/DeepLabv3_plus_resnet/aug_classweight_1_1_1.5_10/epoch34.pth
# ResNet50, checkpoint/segmentation/DeepLabv3_plus_resnet/aug_dice_scale_0.75_1.5_weight_1_1.5_1.5_6/epoch31.pth

seg_models = [load_model("DeepLabv3_plus_resnet", "ResNet101", "checkpoint/segmentation/DeepLabv3_plus_resnet/aug_ResNet101_cross_entropy_avg1_scale_0.5_2.0_weight_1_1.5_1.5_10/epoch12.pth", 
               device, gpus, 16, config.n_classes)]

seg_tr = transforms.Compose([
        test_tr.Normalize_divide(255.0),
        test_tr.ToTensor()])

classification_models = get_classification_models(device, gpus)
classification_tr = transforms.Compose([
                             Image_tr.Resize((224, 224)),
                             Image_tr.Normalize_divide(255.0)])

root_path = "./data/Edema_testset/original_images"

sample_images, sample_predictions, sample_names = vanilla_seg(root_path, seg_models, seg_tr, device, 
                                                              batch_size=config.batch_size, num_workers=config.num_workers,
                                                              test_scales = [0.5, 0.75, 1.0, 1.25, 1.50, 1.75], is_flip = True)

# sample_images, sample_predictions, sample_names = seg_connection(root_path, seg_models, classification_models, 
#                                                                  seg_tr, classification_tr, 
#                                                                  device, config.batch_size, config.num_workers,
#                                                                  class_included = [1, 2],
#                                                                  class_thresholds = {1: 500, 2: 10, 3:5},
#                                                                  abnormal_threshold=0.75, passed_percent=0.5,
#                                                                  test_scales = [0.5, 0.75, 1.0, 1.25, 1.50, 1.75], is_flip = True)
    

In [None]:
save_pickle(sample_names, "./predictions/pickles/20181007/sample_names_epoch12.pkl")
save_pickle(sample_predictions, "./predictions/pickles/20181007/segmentation_sample_predictions_epoch12.pkl")

In [None]:
sample_names = load_pickle("./predictions/pickles/20181007/sample_names_epoch12.pkl")
sample_predictions = load_pickle("./predictions/pickles/20181007/segmentation_sample_predictions_epoch12.pkl")

In [None]:
sample_predictions = replace_pixel(sample_predictions, 2, 0)
sample_predictions = replace_pixel(sample_predictions, 3, 0)

In [None]:
write_disk("./predictions/segmentation/test/xxy/20181007_1/", sample_names, sample_predictions)

In [None]:
vallabel_root = "./data/Edema_validationset/label_images"
sample_labels = [get_labels(os.path.join(vallabel_root, sample_name)) for sample_name in tqdm(sorted(os.listdir(vallabel_root)))]
sample_dices = []
for sample_idx in range(len(sample_labels)):
    label = np.array(sample_labels[sample_idx])
    for target_pixel in config.label_dict:
        np.place(label, label==target_pixel, config.label_dict[target_pixel])
    
    prediction = np.array(sample_predictions[sample_idx])
    sample_dice = aic_fundus_lesion_segmentation(label, prediction)
    sample_dices.append(sample_dice)
    print(sample_idx, sample_dice)

In [None]:
valid_dices = [[], [], [], []]
for sample_idx in range(len(sample_dices)):
    for data_type in range(4):
        dice_score = sample_dices[sample_idx][data_type]
        if not math.isnan(dice_score):
            valid_dices[data_type].append(dice_score)
print("mean class dices: %s"%([np.mean(dices) for dices in valid_dices]))

NameError: name 'sample_labels' is not defined

In [None]:
sample_idx = 13
for image_idx in range(128):
    image = np.squeeze(sample_images[sample_idx][image_idx])
    label = sample_labels[sample_idx][image_idx]
    prediction = sample_predictions[sample_idx][image_idx]
    plt.figure()
    plt.subplot(131)
    plt.imshow(image, cmap = "gray")
    plt.subplot(132)
    plt.imshow(label, cmap = "gray")
    plt.subplot(133)
    plt.imshow(decode_segmap(prediction), cmap = "gray")
    plt.show()
    print("%d, label   abnormal pixels: %s"%(image_idx, count_abnormal(label, [255, 191, 128])))
    print("%d, predict abnormal pixels: %s"%(image_idx, count_abnormal(prediction, [1, 2, 3])))

In [None]:
class_thresholds = {1: 500, 2: 10, 3:5}
class_type = 1

rea_blocks = divide_blocks(sample_predictions[sample_idx], class_type, class_thresholds[class_type])
block_numpixels = count_pixels(rea_blocks, sample_predictions[sample_idx], class_type)
print(rea_blocks)
print(block_numpixels)

In [None]:
# original dice scores per sample, DeepLab V3+, aug_512_1024, epoch 13
# 0, [0.9953913572955032, 0.8340105323339809, 0.897754417707465, nan]
# 1, [0.9711029532694698, 0.5825124021131763, nan, nan]
# 2, [0.984575074181537, 0.87036935702214, nan, nan]
# 3, [0.9804807003555078, 0.7994834911793456, 0.49694016868922786, nan]
# 4, [0.978341909341848, 0.835843416784337, 0.5977946941352797, nan]
# 5, [0.9954415541089505, 0.8556120417099563, nan, nan]
# 6, [0.9978480137764166, 0.6011353655545492, nan, 0.7816266971196548]
# 7, [0.9810973092135972, 0.7503458588967595, 0.5657404002731523, nan]
# 8, [0.9738636436197384, 0.6333658331661316, nan, nan]
# 9, [0.9960554992443774, 0.8708277615796239, 0.9339293126129993, nan]
# 10, [0.9968853684332056, 0.9228649151460038, 0.9291939554231939, 0.06471816283924843]
# 11, [0.9970407295810517, 0.918798821803022, 0.9196883861597178, 0.6642642642642642]
# 12, [0.9979058149555472, 0.9435037670132433, 0.9506723939470507, 0.6248153618906942]
# 13, [0.9996719836525478, 0.6989183047530227, nan, 0.7698178237321517]
# 14, [0.9994861342220099, 0.6812009412102815, nan, 0.7936418359668924]
# mean class dices: [0.9896792030167539, 0.7865861873510382, 0.7864642161185108, 0.6164806909688175]

In [None]:
# {1: 800, 2: 10, 3:5}, abnormal threshold: 0.5, passed_percent: 0.5
# [0.9899642521913241, 0.7941926848197653, 0.7969739480984086, 0.616742385984297]

# {1: 800, 2: 10, 3:5}, abnormal threshold: 0.6, passed_percent: 0.5
# [0.9901855736327988, 0.7961173806069044, 0.7969738600277836, 0.6045752257517375]

# {1: 800, 2: 10, 3:5}, abnormal threshold: 0.4, passed_percent: 0.5
# [0.9899414191899925, 0.7934372731537295, 0.7969738600277836, 0.616742385984297]

# {1: 800, 2: 10, 3:5}, abnormal threshold: 0.4, passed_percent: 0.4
# [0.9899414191899925, 0.7934372731537295, 0.7969738600277836, 0.616742385984297]

# {1: 800, 2: 10, 3:5}, abnormal threshold: 0.7, passed_percent: 0.5, classes: [1, 2]
# [0.9902553894011289, 0.7967099289736387, 0.7969738600277836, 0.6164806909688175]

# {1: 800, 2: 10, 3:5}, abnormal threshold: 0.75, passed_percent: 0.5, classes: [1, 2]
# [0.9904760316560257, 0.7977752882140663, 0.7969738600277836, 0.6164806909688175]

In [4]:
round(0.9902149958506881, 6)

0.990215