In [None]:
import os, pdb, cv2, random, traceback, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from pprint import pprint, pformat
from tqdm import tqdm
from collections import OrderedDict, namedtuple
from sklearn.metrics import  roc_curve, auc

from dataloaders.Classification_Image import OCT_image_classification, OCT_classification_persample, natural_keys
from dataloaders.Image_transforms import Resize, Split_h, Normalize_divide, To_CHW    
from sklearn import metrics

from networks.classification.ResNet import ResNet34, ResNet50, ResNet101
from networks.classification.ResNet_original import ResNet34_original, ResNet50_original, ResNet18_original
from networks.classification.DenseNet_original import DenseNet121_original, DenseNet169_original, DenseNet201_original

from utils import aic_fundus_lesion_classification

def save_pickle(obj, save_path):
    """Pickle `obj` to `save_path`, creating the parent directory if needed.

    Args:
        obj: any picklable object.
        save_path: destination file path. A bare filename (no directory part)
            is written to the current working directory.
    """
    import pickle  # stdlib; `cPickle` was never imported (and is Python-2 only)

    parent_path = os.path.dirname(save_path)
    # dirname() is "" for a bare filename; os.makedirs("") would raise.
    if parent_path and not os.path.exists(parent_path):
        os.makedirs(parent_path)
    # Context manager guarantees the file is flushed and closed.
    with open(save_path, "wb") as fout:
        pickle.dump(obj, fout, pickle.HIGHEST_PROTOCOL)

def load_pickle(pickle_path):
    """Unpickle and return the object stored at `pickle_path`.

    Security: unpickling can execute arbitrary code. Only call this on files
    this project wrote itself (e.g. via `save_pickle`), never on untrusted input.
    """
    import pickle  # stdlib; `cPickle` was never imported (and is Python-2 only)

    with open(pickle_path, 'rb') as fo:
        pickle_dict = pickle.load(fo)
    return pickle_dict

def load_model(model_config, num_classes, device, gpus):
    """Build a network by name, load its checkpoint weights and move it to `device`.

    Args:
        model_config: record with `network` (name of a network constructor
            that must exist at module level in this notebook), `checkpoint`
            (path to a saved state dict) and `net_config` (passed as the first
            constructor argument).
        num_classes: passed as the second constructor argument.
        device: torch device the model is moved to.
        gpus: sequence of GPU ids (must support `len()`); with more than one id
            the model is wrapped in `nn.DataParallel` over those ids.

    Returns:
        The loaded model (possibly wrapped in `nn.DataParallel`) on `device`.
    """
    network_cls = globals()[model_config.network]
    model = network_cls(model_config.net_config, num_classes)
    print("Load %s"%(model_config.checkpoint))
    # Deserialize onto the CPU: without map_location, torch.load restores each
    # tensor onto the device it was saved from, allocating a throwaway GPU copy
    # per checkpoint (or failing if that device is absent). The model is moved
    # to `device` below anyway.
    state_dict = torch.load(model_config.checkpoint, map_location="cpu")
    model.load_state_dict(state_dict)
    if len(gpus) > 1:
        model = nn.DataParallel(model, gpus)
    model.to(device)
    return model

def get_labels(label_root, label_dict=None, num_samples=128, num_classes=3):
    """Build a per-image multi-hot label matrix for every sample directory.

    Each sub-directory of `label_root` is one sample; its label images are
    read in natural-number order. Row `i` of a sample's matrix marks which
    classes appear in the i-th label image: column `label_dict[p]` is set to 1
    when pixel value `p` occurs in channel 0 of that image.

    Args:
        label_root: directory containing one sub-directory per sample.
        label_dict: mapping pixel value -> class column. Defaults to
            {255: 0, 191: 1, 128: 2}.
        num_samples: number of rows per sample matrix (unused rows stay 0).
        num_classes: number of columns per sample matrix.

    Returns:
        list of numpy arrays of shape (num_samples, num_classes), one per
        sample directory in sorted order.

    Raises:
        ValueError: a sample directory holds more than `num_samples` images.
        IOError: a label image cannot be read.
    """
    if label_dict is None:
        # Built per call instead of as a shared mutable default argument.
        label_dict = OrderedDict([(255, 0), (191, 1), (128, 2)])
    total_labels = []
    for label_sample_name in tqdm(sorted(os.listdir(label_root))):
        sample_dir = os.path.join(label_root, label_sample_name)
        sample_labels = np.zeros((num_samples, num_classes))
        # sort the image names in ascending numerical order
        image_names = os.listdir(sample_dir)
        image_names.sort(key=natural_keys)
        if len(image_names) > num_samples:
            raise ValueError("%s has %d images but num_samples is %d"
                             % (sample_dir, len(image_names), num_samples))
        for i, image_name in enumerate(image_names):
            image_path = os.path.join(sample_dir, image_name)
            image = cv2.imread(image_path)
            # cv2.imread signals failure by returning None, not by raising.
            if image is None:
                raise IOError("Could not read label image: %s" % image_path)
            label_image = image[:, :, 0]
            for target_pixel in label_dict:
                if target_pixel in label_image:
                    sample_labels[i, label_dict[target_pixel]] = 1
        total_labels.append(sample_labels)
    return total_labels

def predict_ensemble(models, inputs, device, last_activation="sigmoid"):
    """Predict one batch from the data loader with several models and average.

    Args:
        models: non-empty iterable of torch modules; each is put in eval mode.
        inputs: batch tensor; cast to float and moved to `device`.
        device: torch device the models live on.
        last_activation: "sigmoid" (applied element-wise) or "softmax"
            (normalized over dim 1).

    Returns:
        numpy array: element-wise mean of the activated model outputs.

    Raises:
        ValueError: unknown `last_activation`, or no models given.
    """
    activations = {
        "sigmoid": torch.sigmoid,
        "softmax": lambda logits: F.softmax(logits, dim=1),
    }
    # Validate before doing any work. The previous `raise("...")` raised a
    # str, which is itself a TypeError and hid the intended message.
    if last_activation not in activations:
        raise ValueError("Unknown activation function in last layer: {}".format(last_activation))
    models = list(models)
    if not models:
        raise ValueError("predict_ensemble needs at least one model")
    activate = activations[last_activation]

    with torch.no_grad():
        inputs = inputs.float().to(device)
        scores = []
        for model in models:
            model.eval()
            scores.append(activate(model(inputs)).detach().cpu().numpy())
    return np.mean(scores, 0)

def main(main_models,
         batch_size,
         dataloader,
         device,
         num_classes = 2,
         num_samples = 128):
    """Run the sigmoid ensemble over every batch of one sample's data loader.

    Rows are filled in loader order; rows beyond the number of images stay 0.

    Args:
        main_models: models passed to `predict_ensemble`.
        batch_size: kept for backward compatibility only. Row offsets now come
            from the actual length of each batch, so they stay correct even if
            this differs from the loader's batch size.
        dataloader: iterable of input batches (batch dimension first).
        device: torch device passed to `predict_ensemble`.
        num_classes: column count used only when the loader yields no batches;
            otherwise the width of the ensemble output is used.
        num_samples: number of output rows (images per sample).

    Returns:
        numpy array of shape (num_samples, n_outputs) with per-image scores.

    Raises:
        ValueError: the loader yields more than `num_samples` images.
    """
    predictions = None
    offset = 0
    for inputs in dataloader:
        scores = predict_ensemble(main_models, inputs, device, "sigmoid")
        if predictions is None:
            # Size columns from the real model output rather than a constant.
            predictions = np.zeros((num_samples, scores.shape[1]))
        end = offset + scores.shape[0]
        if end > num_samples:
            raise ValueError("Data loader yielded more than %d images" % num_samples)
        predictions[offset:end, :] = scores
        offset = end
    if predictions is None:
        predictions = np.zeros((num_samples, num_classes))
    return predictions

def write_disk(target_root, sample_names, sample_predictions):
    """Save each sample's predictions as `<target_root>/<name>_detections.npy`.

    Args:
        target_root: output directory; created if it does not exist.
        sample_names: sample identifiers used as file-name stems.
        sample_predictions: arrays aligned one-to-one with `sample_names`.

    Raises:
        ValueError: the two sequences have different lengths.
    """
    if len(sample_names) != len(sample_predictions):
        raise ValueError("Got %d sample names but %d prediction arrays"
                         % (len(sample_names), len(sample_predictions)))
    # np.save does not create missing directories.
    if not os.path.exists(target_root):
        os.makedirs(target_root)
    for name, prediction in zip(sample_names, sample_predictions):
        target_path = os.path.join(target_root, name + "_detections.npy")
        np.save(target_path, prediction)

In [None]:
class Config(object):
    """Inference settings shared by the notebook cells that follow."""

    def __init__(self):
        # Data loading
        self.batch_size = 240
        self.num_workers = 6
        # Number of output columns per image
        self.num_classes = 3
        # Height and width each slice is resized to before inference
        self.target_h, self.target_w = 224, 224
        # Comma-separated CUDA device ids
        self.gpus = "0,1,2,3"


config = Config()

In [None]:
# Parse the comma-separated device ids into a concrete list. On Python 3,
# `map` returns a one-shot iterator that supports neither `gpus[0]` below nor
# `len(gpus)` inside load_model.
gpus = list(map(int, config.gpus.split(",")))
# The first listed GPU is the primary device that inputs and models go to.
device = torch.device("cuda:{}".format(gpus[0]))

model_config = namedtuple("Model", ["network", "checkpoint", "net_config"])

# Ensemble members as (network, training run, epoch) rows. Each checkpoint is
# stored at checkpoint/multiple_label/<network>/<run>/epoch<epoch>.pth, and no
# member uses a net_config. Row order is the order the ensemble averages in.
# Trailing numbers are scores recorded next to some runs (presumably a
# validation metric; not stated).
_ENSEMBLE_MEMBERS = [
    ("ResNet50_original", "aug_multipleLabel", 102),
    ("DenseNet121_original", "aug_multipleLabel", 56),
    ("ResNet34_original", "aug_multipleLabel3", 2),
    ("ResNet34_original", "aug_multipleLabel2", 1),
    ("ResNet34_original", "aug_multipleLabel", 0),

    ("ResNet18_original", "aug_multipleLabel", 8),
    ("ResNet18_original", "aug_multipleLabel2", 162),
    ("ResNet18_original", "aug_multipleLabel3", 141),

    ("ResNet50_original", "aug_multipleLabel2", 17),
    ("ResNet34_original", "aug_multipleLabel4", 29),
    ("ResNet34_original", "aug_multipleLabel5", 20),

    ("DenseNet121_original", "aug_multipleLabel1", 23),   # 0.9863
    ("DenseNet121_original", "aug_multipleLabel2", 8),    # 0.9831
    ("DenseNet121_original", "aug_multipleLabel3", 72),   # 0.9811
    ("DenseNet121_original", "aug_multipleLabel4", 17),   # 0.9891
    ("DenseNet121_original", "aug_multipleLabel5", 112),  # 0.9859
    ("DenseNet121_original", "aug_multipleLabel6", 11),   # 0.9814

    ("DenseNet169_original", "aug_multipleLabel", 138),   # 0.9846
    ("DenseNet169_original", "aug_multipleLabel1", 132),  # 0.9869
    ("DenseNet169_original", "aug_multipleLabel2", 42),   # 0.9863
    ("DenseNet169_original", "aug_multipleLabel3", 42),   # 0.9792
    ("DenseNet169_original", "aug_multipleLabel4", 151),  # 0.9847
    ("DenseNet169_original", "aug_multipleLabel5", 164),  # 0.9813

    ("DenseNet201_original", "aug_multipleLabel", 11),
    ("DenseNet201_original", "aug_multipleLabel2", 21),

    ("DenseNet201_original", "aug_multipleLabel4", 78),
]

model_configs = [
    model_config(_net,
                 "checkpoint/multiple_label/{}/{}/epoch{}.pth".format(_net, _run, _epoch),
                 None)
    for _net, _run, _epoch in _ENSEMBLE_MEMBERS
]



# Instantiate every ensemble member; list order matches model_configs so the
# ensemble averages the same models in the same sequence.
main_models = []
for _model_cfg in model_configs:
    main_models.append(load_model(_model_cfg, config.num_classes, device, gpus))

In [None]:
# Run the ensemble on every test volume, visiting entries in sorted order so
# sample_names and sample_predictions stay aligned.
root_path = "./data/Edema_testset/original_images"
sample_predictions, sample_names = [], []
for sample_name in tqdm(sorted(os.listdir(root_path))):
    # Drop ".img" from the entry name to get the id used for output files.
    sample_names.append(sample_name.replace(".img", ""))
    sample_path = os.path.join(root_path, sample_name)

    # Resize every slice to the network input size, then apply
    # Normalize_divide(255.0) (presumably scales pixels into [0, 1]).
    dataset = OCT_classification_persample(
        sample_path,
        transform=transforms.Compose([
            Resize((config.target_h, config.target_w)),
            Normalize_divide(255.0),
        ]),
    )

    # Keep slice order (shuffle=False) so output row i is slice i.
    dataset_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
    )
    predictions = main(main_models, config.batch_size, dataset_loader, device,
                       num_classes=config.num_classes)
    sample_predictions.append(predictions)

In [None]:
# Persist one "<sample>_detections.npy" file per test volume.
write_disk(target_root="./predictions/classification/20181009",
           sample_names=sample_names,
           sample_predictions=sample_predictions)