In [None]:
import os, pdb, cv2, random, traceback, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from pprint import pprint, pformat
from tqdm import tqdm
from collections import OrderedDict, namedtuple
from sklearn.metrics import  roc_curve, auc

from dataloaders.Classification_Image import OCT_image_classification, OCT_classification_persample, natural_keys
from dataloaders.Image_transforms import Resize, Split_h, Normalize_divide, To_CHW    
from sklearn import metrics

from networks.classification.ResNet import ResNet34, ResNet50, ResNet101
from networks.classification.ResNet_original import ResNet34_original, ResNet50_original
from networks.classification.DenseNet_original import DenseNet121_original

from utils import aic_fundus_lesion_classification

def save_pickle(obj, save_path):
    """Pickle ``obj`` to ``save_path``, creating the parent directory if needed.

    Args:
        obj: any picklable object.
        save_path: destination file path; may be a bare file name (no directory part).

    Uses pickle protocol 1 (the original passed ``True``, which equals 1).
    """
    try:
        import cPickle as pickle  # Python 2: C implementation (the file previously never imported it)
    except ImportError:
        import pickle  # Python 3: `pickle` already uses the C accelerator
    parent_path = os.path.dirname(save_path)  # "" for a bare file name
    # os.makedirs("") raises, so only create a directory when there is one to create
    if parent_path and not os.path.exists(parent_path):
        os.makedirs(parent_path)
    with open(save_path, "wb") as fo:  # close the handle deterministically
        pickle.dump(obj, fo, 1)

def load_pickle(pickle_path):
    """Load and return the object pickled at ``pickle_path``.

    WARNING: unpickling can execute arbitrary code; only load trusted files.
    """
    try:
        import cPickle as pickle  # Python 2: C implementation (the file previously never imported it)
    except ImportError:
        import pickle  # Python 3: `pickle` already uses the C accelerator
    with open(pickle_path, 'rb') as fo:
        pickle_dict = pickle.load(fo)
    return pickle_dict

def load_model(model_config, num_classes, device, gpus):
    """Instantiate a network by name, load its checkpoint and move it to ``device``.

    Args:
        model_config: record with ``network`` (name of a class/factory found in
            this module's globals), ``checkpoint`` (path to a state dict saved
            with ``torch.save``) and ``net_config`` (first constructor argument).
        num_classes: second constructor argument.
        device: torch device the model is moved to.
        gpus: iterable of GPU ids; with more than one id the model is wrapped
            in ``nn.DataParallel`` over those ids.

    Returns:
        The loaded model, possibly wrapped in ``nn.DataParallel``.
    """
    gpus = list(gpus)  # accept any iterable (e.g. a Python 3 map object), which has no len()
    model = globals()[model_config.network](model_config.net_config, num_classes)
    print("Load %s"%(model_config.checkpoint))
    # Map tensors to CPU first: without map_location, torch.load restores them onto
    # the GPU they were saved from, which may be a different or missing device here.
    state_dict = torch.load(model_config.checkpoint, map_location="cpu")
    model.load_state_dict(state_dict)
    if len(gpus) > 1:
        model = nn.DataParallel(model, gpus)
    model.to(device)
    return model

def get_labels(label_root, label_dict = OrderedDict([(255, 0), (191, 1), (128, 2)]), num_samples=128):
    """Build per-image multi-label targets from directories of label images.

    Each sub-directory of ``label_root`` (visited in sorted order) holds the
    label images of one sample. The images are ordered with ``natural_keys``,
    and image ``i`` fills row ``i`` of a 0/1 matrix. Column ``label_dict[p]`` is
    set to 1 when pixel value ``p`` occurs anywhere in channel 0 of that image.

    Args:
        label_root: directory containing one sub-directory per sample.
        label_dict: mapping pixel value -> column index. It is only read, never
            mutated, so sharing the default object between calls is safe.
        num_samples: rows allocated per sample; rows past the last image stay 0.

    Returns:
        List of float arrays, one per sample directory, of shape
        ``(num_samples, num_columns)``. ``num_columns`` is the larger of 3 and
        the largest column index in ``label_dict`` plus one.

    Raises:
        ValueError: if a sample directory holds more than ``num_samples`` images.
        IOError: if ``cv2.imread`` cannot read an image.
    """
    # Keep the original 3-column minimum; widen only when label_dict needs more
    # (previously any column index >= 3 raised IndexError).
    num_columns = max([3] + [column + 1 for column in label_dict.values()])
    total_labels = []
    for label_sample_name in tqdm(sorted(os.listdir(label_root))):
        sample_dir = os.path.join(label_root, label_sample_name)
        # sort the image names in ascending numerical (natural) order
        image_names = os.listdir(sample_dir)
        image_names.sort(key=natural_keys)
        if len(image_names) > num_samples:
            raise ValueError("%s has %d images but num_samples=%d"
                             % (sample_dir, len(image_names), num_samples))
        sample_labels = np.zeros((num_samples, num_columns))
        for i, image_name in enumerate(image_names):
            image_path = os.path.join(sample_dir, image_name)
            label_image = cv2.imread(image_path)
            if label_image is None:
                # cv2.imread reports failure by returning None rather than raising
                raise IOError("could not read label image %s" % image_path)
            label_channel = label_image[:, :, 0]
            for target_pixel, column in label_dict.items():
                if target_pixel in label_channel:
                    sample_labels[i, column] = 1
        total_labels.append(sample_labels)
    return total_labels

def predict_ensemble(models, inputs, device):
    """Average the softmax outputs of several models on one input batch.

    Args:
        models: non-empty iterable of torch modules; each is put in eval mode.
        inputs: batch tensor; cast to float32 and moved to ``device``.
        device: torch device the models live on.

    Returns:
        numpy array: mean over models of ``softmax(model(inputs), dim=1)``.

    Raises:
        ValueError: if ``models`` is empty (averaging nothing would give a NaN
            scalar that breaks downstream indexing).
    """
    models = list(models)
    if not models:
        raise ValueError("predict_ensemble needs at least one model")
    with torch.no_grad():
        batch = inputs.float().to(device)
        probabilities = []
        for model in models:
            model.eval()  # inference mode for dropout / batch-norm layers
            probabilities.append(F.softmax(model(batch), dim=1).cpu().numpy())
    return np.mean(probabilities, 0)

def main(normal_edemas_models, edema_PEDs_models, edema_SRFs_models,
         batch_size,
         dataloader,
         device,
         num_classes = 2,
         num_samples = 128):
    """Run the three ensembles over one sample's batches and collect positive-class scores.

    Each ensemble's averaged softmax column 1 is written into the result:
    column 0 from ``normal_edemas_models``, column 1 from ``edema_SRFs_models``
    and column 2 from ``edema_PEDs_models``. Rows follow the order in which
    ``dataloader`` yields images; rows past the last image stay 0.

    Args:
        normal_edemas_models, edema_PEDs_models, edema_SRFs_models: model lists
            passed to ``predict_ensemble``.
        batch_size: kept for backward compatibility. Row offsets now come from
            each batch's real size, so a mismatch with the loader's batch size
            no longer misplaces rows.
        dataloader: iterable of input tensors (first dim = images in the batch).
        device: torch device passed to ``predict_ensemble``.
        num_classes: unused; kept for backward compatibility.
        num_samples: rows in the result (default 128, the previous hard-coded value).

    Returns:
        ``(num_samples, 3)`` float array.

    Raises:
        ValueError: if the loader yields more than ``num_samples`` images.
    """
    predictions = np.zeros((num_samples, 3))
    offset = 0
    for inputs in dataloader:
        batch_len = inputs.size(0)
        if offset + batch_len > num_samples:
            raise ValueError("dataloader yielded more than num_samples=%d images" % num_samples)
        edema_softmax = predict_ensemble(normal_edemas_models, inputs, device)
        PED_softmax = predict_ensemble(edema_PEDs_models, inputs, device)
        SRF_softmax = predict_ensemble(edema_SRFs_models, inputs, device)
        rows = slice(offset, offset + batch_len)
        predictions[rows, 0] = edema_softmax[:, 1]
        predictions[rows, 1] = SRF_softmax[:, 1]
        predictions[rows, 2] = PED_softmax[:, 1]
        offset += batch_len
    return predictions

def write_disk(target_root, sample_names, sample_predictions):
    """Save each sample's predictions as ``<target_root>/<name>_detections.npy``.

    Args:
        target_root: output directory; created (with parents) if missing.
        sample_names: names used as file-name prefixes.
        sample_predictions: arrays paired with ``sample_names`` by position.

    Raises:
        ValueError: if the two sequences have different lengths. Previously the
            extras were silently dropped, or an IndexError was raised.
    """
    if len(sample_names) != len(sample_predictions):
        raise ValueError("got %d names but %d prediction arrays"
                         % (len(sample_names), len(sample_predictions)))
    if not os.path.isdir(target_root):
        os.makedirs(target_root)  # np.save does not create missing directories
    for name, prediction in zip(sample_names, sample_predictions):
        np.save(os.path.join(target_root, name + "_detections.npy"), prediction)

In [None]:
class Config(object):
    """Settings for the per-sample inference run."""

    def __init__(self):
        # data loading
        self.batch_size, self.num_workers = 60, 4
        # classifier / preprocessing geometry
        self.num_classes = 2
        self.num_split = 2
        self.target_h, self.target_w = 224, 224
        # comma-separated CUDA ids; the setup cell uses the first for `device`
        self.gpus = "0, 3"


config = Config()

In [None]:
gpus = map(int, config.gpus.split(","))
device = torch.device("cuda:{}".format(gpus[0]))
resnet_config = {"ensemble_idx": 7, "feature_size": 7}

model_config = namedtuple("Model", ["network", "checkpoint", "net_config"])

# train
# normal/edemas
# model_config("ResNet34_original", "checkpoint/normal_edema/ResNet34_original/aug/epoch23.pth", None), AUC 0.9907
# model_config("ResNet50_original", "checkpoint/normal_edema/ResNet50_original/aug/epoch89.pth", None), epoch 89 AUC 0.9864, epoch 70 AUC 0.9861
# model_config("DenseNet121_original", "checkpoint/normal_edema/DenseNet121_original/aug/epoch41.pth", None), epoch 41 AUC 0.9869, epoch 56 0.9848
# model_config("DenseNet121_original", "checkpoint/normal_edema/DenseNet121_original/aug_20180924_150746/epoch88.pth", None) AUC 0.9831

# used, AUC 0.9912 
# model_config("ResNet34_original", "checkpoint/normal_edema/ResNet34_original/aug/epoch23.pth", None),
# model_config("ResNet50_original", "checkpoint/normal_edema/ResNet50_original/aug/epoch89.pth", None),
# model_config("DenseNet121_original", "checkpoint/normal_edema/DenseNet121_original/aug/epoch41.pth", None)

# edema_SRFs, AUC 0.9913
# model_config("ResNet34_original", "checkpoint/edema_SRF/ResNet34_original/aug_oversample_includeNormal/epoch19.pth", None)
# model_config("ResNet50_original", "checkpoint/edema_SRF/ResNet50_original/aug_includeNormal/epoch7.pth", None)

# PED
# model_config("ResNet50_original", "checkpoint/edema_PED/ResNet50_original/aug_oversample_includeNormal/epoch15.pth", None), AUC 0.9913
# model_config("ResNet50_original", "checkpoint/edema_PED/ResNet50_original/aug_oversample_includeNormal_20180924_150856/epoch65.pth", None), AUC 0.9924

# model_config("DenseNet121_original", "checkpoint/edema_PED/DenseNet121_original/aug_oversample_includeNormal/epoch13.pth", None), epoch 13 AUC 0.9924, epoch 16 AUC 0.9924
# model_config("DenseNet121_original", "checkpoint/edema_PED/DenseNet121_original/aug_oversample_includeNormal_20180924_150921/epoch67.pth", None),  AUC 0.9915

# desired, checkpoint/normal_edema/ResNet50_original/aug/epoch70.pth
# Checkpoints currently used by each binary ensemble; alternatives that were
# tried are kept commented out for reference.
normal_edemas = [
    model_config(network="ResNet34_original",
                 checkpoint="checkpoint/normal_edema/ResNet34_original/aug/epoch23.pth",
                 net_config=None),
    # model_config("ResNet50_original", "checkpoint/normal_edema/ResNet50_original/aug/epoch70.pth", None),
    # model_config("DenseNet121_original", "checkpoint/normal_edema/DenseNet121_original/aug/epoch41.pth", None),
]

edema_SRFs = [
    model_config(network="ResNet50_original",
                 checkpoint="checkpoint/edema_SRF/ResNet50_original/aug_includeNormal/epoch7.pth",
                 net_config=None),
    # model_config("ResNet34_original", "checkpoint/edema_SRF/ResNet34_original/aug_oversample_includeNormal/epoch19.pth", None),
    # model_config("ResNet50_original", "checkpoint/edema_SRF/ResNet50_original/aug_includeNormal/epoch7.pth", None),
]

edema_PEDs = [
    model_config(network="ResNet50_original",
                 checkpoint="checkpoint/edema_PED/ResNet50_original/aug_oversample_includeNormal/epoch15.pth",
                 net_config=None),
    # model_config("ResNet50_original", "checkpoint/edema_PED/ResNet50_original/aug_oversample_includeNormal_20180924_150856/epoch54.pth", None),
    # model_config("DenseNet121_original", "checkpoint/edema_PED/DenseNet121_original/aug_oversample_includeNormal/epoch13.pth", None),
    # model_config("DenseNet121_original", "checkpoint/edema_PED/DenseNet121_original/aug_oversample_includeNormal_20180924_150921/epoch67.pth", None),
]

# Load each ensemble in turn (normal/edema, then SRF, then PED); every member
# shares the same class count, device and GPU list.
normal_edemas_models, edema_SRFs_models, edema_PEDs_models = (
    [load_model(entry, config.num_classes, device, gpus) for entry in ensemble]
    for ensemble in (normal_edemas, edema_SRFs, edema_PEDs)
)

In [None]:
root_path = "./data/Edema_validationset/original_images"
sample_predictions, sample_names = [], []
# One DataLoader per entry of root_path (sorted). shuffle=False keeps batch order,
# which main() relies on to fill prediction rows sequentially.
for sample_name in tqdm(sorted(os.listdir(root_path))):
    sample_names.append(sample_name.replace(".img", ""))
    sample_path = os.path.join(root_path, sample_name)

    slice_transform = transforms.Compose([
        Resize((config.target_h, config.target_w)),
        Normalize_divide(255.0),
    ])
    dataset = OCT_classification_persample(sample_path, transform=slice_transform)
    dataset_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
    )

    predictions = main(normal_edemas_models, edema_PEDs_models, edema_SRFs_models,
                       config.batch_size, dataset_loader, device,
                       num_classes=config.num_classes)
    sample_predictions.append(predictions)

In [None]:
# for sample_prediction in sample_predictions:
#     for i in range(128):
#         if sample_prediction[i][0] < 0.5 and sample_prediction[i][2] > 0.5:
#             print(sample_prediction[i])


In [None]:
# write_disk("./predictions/classification/test/lisijia/20180925_3", sample_names, sample_predictions)

In [None]:
vallabel_root = "./data/Edema_validationset/label_images"
valsample_labels = get_labels(vallabel_root)
# Labels and predictions are paired purely by position (both come from sorted
# directory listings), so differing counts mean the pairing is wrong.
if len(valsample_labels) != len(sample_predictions):
    raise ValueError("label samples (%d) != prediction samples (%d)"
                     % (len(valsample_labels), len(sample_predictions)))
valsample_aucs = [aic_fundus_lesion_classification(sample_label, sample_prediction)
                  for sample_label, sample_prediction in zip(valsample_labels, sample_predictions)]
# flatten every per-sample AUC, skipping NaN entries
valid_aucs = [auc_value
              for sample_auc in valsample_aucs
              for auc_value in sample_auc
              if not math.isnan(auc_value)]
pprint(valsample_aucs)
print("mean auc: %.4f"%(np.mean(valid_aucs)))

In [None]:
# Mean AUC per output column, skipping NaN entries (presumably returned when a
# column is all-positive or all-negative for a sample — TODO confirm in
# aic_fundus_lesion_classification). Columns presumably follow the prediction
# columns written by main(): 0 normal/edema, 1 SRF, 2 PED — confirm.
# The column count follows the data (at least 3) instead of being hard-coded.
num_auc_columns = max([3] + [len(sample_auc) for sample_auc in valsample_aucs])
valid_aucs = [[] for _col in range(num_auc_columns)]
for sample_auc in valsample_aucs:
    for i, auc_value in enumerate(sample_auc):
        if not math.isnan(auc_value):
            valid_aucs[i].append(auc_value)
for i, auc_values in enumerate(valid_aucs):
    print("column %d mean auc: %.4f" % (i, np.mean(auc_values)))

In [None]:
# mean of the non-NaN AUCs collected for column 0
np.asarray(valid_aucs[0]).mean()