In [65]:
import argparse
import os
import time
import sys

import numpy as np
import torch
import pickle

from meta_ood.meta_ood_master.config import config_evaluation_setup
from meta_ood.meta_ood_master.src.imageaugmentations import Compose, Normalize, ToTensor
# from meta_ood.meta_ood_master.src.model_utils import inference
from scipy.stats import entropy
# from meta_ood.meta_ood_master.src.calc import calc_precision_recall, calc_sensitivity_specificity
from meta_ood.meta_ood_master.src.helper import concatenate_metrics
from meta_ood.meta_ood_master.meta_classification import meta_classification
from meta_ood.meta_ood_master.UNet import ResNetUNet, convrelu

In [66]:
class eval_pixels(object):
    """
    Evaluate in vs. out separability on pixel-level.

    Builds histograms of the per-pixel normalized prediction entropy, separately
    for in-distribution and out-of-distribution pixels, caches them as a pickle
    file and derives AUROC, FPR95 and AUPRC from those histograms.
    """

    def __init__(self, params, roots, dataset):
        """
        :param params: config object; reads ``val_epoch``, ``pareto_alpha`` and ``batch_size``
        :param roots: config object; reads ``io_root`` (also passed on to ``inference``)
        :param dataset: dataset class/object; ``counts`` reads ``num_eval_classes``,
                        ``train_id_in`` and ``train_id_out`` from it
        """
        self.params = params
        self.epoch = params.val_epoch
        self.alpha = params.pareto_alpha
        self.batch_size = params.batch_size
        self.roots = roots
        self.dataset = dataset
        self.save_dir_data = os.path.join(self.roots.io_root, "results/entropy_counts_per_pixel")
        self.save_dir_plot = os.path.join(self.roots.io_root, "plots")
        print("Save dir of plots: {}".format(self.save_dir_plot))
        # epoch 0 is treated as the baseline model; other epochs are keyed by epoch and alpha
        if self.epoch == 0:
            self.pattern = "baseline"
        else:
            self.pattern = "epoch_" + str(self.epoch) + "_alpha_" + str(self.alpha)
        self.save_path_data = os.path.join(self.save_dir_data, self.pattern + ".p")

    def counts(self, loader, num_bins=100, save_path=None, rewrite=False):
        """
        Count the number of in-distribution and out-distribution pixels per
        normalized-entropy bin and pickle the resulting histograms.

        When ``save_path`` already exists and ``rewrite`` is False nothing is
        recomputed.

        :param loader: dataset loader for evaluation data
        :param num_bins: (int) number of equal-width histogram bins on [0, 1]
        :param save_path: (str) path where to save the counts data
                          (defaults to ``self.save_path_data``)
        :param rewrite: (bool) whether to rewrite the data file if already exists
        """
        print("\nCounting in-distribution and out-distribution pixels")
        if save_path is None:
            save_path = self.save_path_data
        if not os.path.exists(save_path) or rewrite:
            save_dir = os.path.dirname(save_path)
            # dirname is "" for a bare filename; os.makedirs("") would raise
            if save_dir and not os.path.exists(save_dir):
                print("Create directory", save_dir)
                os.makedirs(save_dir, exist_ok=True)
            num_classes = self.dataset.num_eval_classes
            # maximum entropy of a num_classes-way distribution, used for normalization
            log_num_classes = np.log(num_classes)
            bins = np.linspace(start=0, stop=1, num=num_bins + 1)
            counts = {"in": np.zeros(num_bins, dtype="int64"), "out": np.zeros(num_bins, dtype="int64")}
            inf = inference(self.params, self.roots, loader, num_classes)
            print(inf.model_name)
            for i in range(len(loader)):
                probs, gt_train, _, _ = inf.probs_gt_load(i)
                # entropy over axis 0 (assumed to be the class axis after squeezing — TODO confirm)
                ent = entropy(probs, axis=0) / log_num_classes
                counts["in"] += np.histogram(ent[gt_train == self.dataset.train_id_in], bins=bins)[0]
                counts["out"] += np.histogram(ent[gt_train == self.dataset.train_id_out], bins=bins)[0]
                print("\rImages Processed: {}/{}".format(i + 1, len(loader)), end=' ')
                sys.stdout.flush()
            torch.cuda.empty_cache()
            with open(save_path, "wb") as f:
                pickle.dump(counts, f)
        print("Counts data saved:", save_path)

    def oodd_metrics_pixel(self, datloader=None, load_path=None):
        """
        Calculate 3 OoD detection metrics, namely AUROC, FPR95, AUPRC.

        Computes the counts data first (via ``counts``) when ``load_path`` does
        not exist yet.

        :param datloader: dataset loader (only needed when the counts data is missing)
        :param load_path: (str) path to counts data (run 'counts' first)
        :return: tuple ``(auroc, fpr95, auprc)``
        :raises ValueError: if the counts data is missing and no ``datloader`` is given
        """
        if load_path is None:
            load_path = self.save_path_data
        if not os.path.exists(load_path):
            if datloader is None:
                # raising instead of exit() keeps an interactive kernel alive
                raise ValueError("No counts data found at {}. Please, specify dataset loader".format(load_path))
            self.counts(loader=datloader, save_path=load_path)
        print("Load Path: {}".format(load_path))
        # NOTE: pickle can execute arbitrary code; only load self-generated counts files
        with open(load_path, "rb") as f:
            data = pickle.load(f)
        fpr, tpr, _, auroc = calc_sensitivity_specificity(data, balance=True)
        # FPR at the operating point whose TPR is closest to 95%
        fpr95 = fpr[(np.abs(tpr - 0.95)).argmin()]
        _, _, _, auprc = calc_precision_recall(data)
        if self.epoch == 0:
            print("\nOoDD Metrics - Epoch %d - Baseline" % self.epoch)
        else:
            print("\nOoDD Metrics - Epoch %d - Lambda %.2f" % (self.epoch, self.alpha))
        print("AUROC:", auroc)
        print("FPR95:", fpr95)
        print("AUPRC:", auprc)
        return auroc, fpr95, auprc

In [67]:
import os
import sys

import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from PIL import Image
class inference(object):
    """
    Run (or load cached) softmax predictions of the evaluated network.

    Per-image predictions are cached as HDF5 files ``probs<i>.hdf5`` under
    ``<io_root>/probs/<pattern>``. ``probs_gt_load`` reads such a file and
    falls back to running the network when the file cannot be opened.
    """

    def __init__(self, params, roots, loader, num_classes=None, init_net=True):
        """
        :param params: config; reads ``val_epoch``, ``pareto_alpha`` and ``batch_size``
        :param roots: config; reads ``model_name``, ``io_root``, ``init_ckpt`` and ``weights_dir``
        :param loader: indexable dataset. NOTE(review): ``prob_gt_calc`` unpacks items as
                       ``(image, gt_train)`` while ``probs_gt_load_batch`` expects 4-tuples
                       from the batched loader — confirm which form the dataset returns
        :param num_classes: (int) number of output classes of the network
        :param init_net: (bool) whether to build the network and load its checkpoint
        """
        self.epoch = params.val_epoch
        self.alpha = params.pareto_alpha
        self.batch_size = params.batch_size
        self.model_name = roots.model_name
        print(self.model_name)
        self.batch = 0
        # ceil(len(loader) / batch_size)
        self.batch_max = -(-len(loader) // self.batch_size)
        self.loader = loader
        self.batchloader = iter(DataLoader(loader, batch_size=self.batch_size, shuffle=False))
        self.probs_root = os.path.join(roots.io_root, "probs")

        if self.epoch == 0:
            pattern = "baseline"
            ckpt_path = roots.init_ckpt
        else:
            pattern = "epoch_" + str(self.epoch) + "_alpha_" + str(self.alpha)
            ckpt_path = os.path.join(roots.weights_dir, self.model_name + "_" + pattern + ".pth")
        self.probs_load_dir = os.path.join(self.probs_root, pattern)
        self.ckpt_path = ckpt_path

        if init_net and num_classes is not None:
            self.net = self._load_net(num_classes, ckpt_path)

    @staticmethod
    def _load_net(num_classes, ckpt_path):
        """Build a ResNetUNet, load its weights from ``ckpt_path`` and return it in eval mode on GPU if available."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = ResNetUNet(num_classes)
        # map_location lets GPU-saved checkpoints load on CPU-only machines
        checkpoint = torch.load(ckpt_path, map_location=device)
        # fine-tuned checkpoints wrap the weights under 'state_dict'; the initial one is a bare state dict
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        net.load_state_dict(checkpoint)
        net = net.to(device)
        # evaluation must not use training-mode BatchNorm/dropout behaviour
        net.eval()
        return net

    def probs_gt_load(self, i, load_dir=None):
        """
        Load cached probabilities and ground truth of image ``i``.

        Falls back to ``prob_gt_calc`` (running the network) when the HDF5 file
        cannot be opened or read (``OSError``).

        :param i: (int) image index
        :param load_dir: (str) cache directory (defaults to ``self.probs_load_dir``)
        :return: ``(probs, gt_train, gt_label, im_path)``
        """
        if load_dir is None:
            load_dir = self.probs_load_dir
        filename = os.path.join(load_dir, "probs" + str(i) + ".hdf5")
        try:
            with h5py.File(filename, "r") as f_probs:
                probs = np.squeeze(np.asarray(f_probs['probabilities']))
                gt_train = np.squeeze(np.asarray(f_probs['gt_train_ids']))
                gt_label = np.squeeze(np.asarray(f_probs['gt_label_ids']))
                im_path = f_probs['image_path'][0].decode("utf8")
        except OSError:
            print("No probs file for image %d, therefore run inference..." % i)
            probs, gt_train, gt_label, im_path = self.prob_gt_calc(i)
        return probs, gt_train, gt_label, im_path

    def probs_gt_save(self, i, save_dir=None):
        """
        Run the network on image ``i`` and store the result as ``probs<i>.hdf5``.

        :param i: (int) image index
        :param save_dir: (str) cache directory (defaults to ``self.probs_load_dir``)
        """
        if save_dir is None:
            save_dir = self.probs_load_dir
        if not os.path.exists(save_dir):
            print("Create directory:", save_dir)
            os.makedirs(save_dir, exist_ok=True)
        probs, gt_train, gt_label, im_path = self.prob_gt_calc(i)
        file_name = os.path.join(save_dir, "probs" + str(i) + ".hdf5")
        with h5py.File(file_name, "w") as f:
            f.create_dataset("probabilities", data=probs)
            f.create_dataset("gt_train_ids", data=gt_train)
            f.create_dataset("gt_label_ids", data=gt_label)
            f.create_dataset("image_path", data=[im_path.encode('utf8')])
        print("file stored:", file_name)

    def probs_gt_load_batch(self):
        """
        Run the network on the next batch of the batched loader.

        :return: ``(probs, gt_train, gt_label, im_paths)`` for the batch
        :raises StopIteration: when all batches have been consumed
        """
        assert self.batch_size > 1, "Please use batch size > 1 or use function 'probs_gt_load()' instead, bye bye..."
        x, y, z, im_paths = next(self.batchloader)
        probs = prediction(self.net, x)
        gt_train = y.numpy()
        gt_label = z.numpy()
        self.batch += 1
        print("\rBatch %d/%d processed" % (self.batch, self.batch_max))
        sys.stdout.flush()
        return probs, gt_train, gt_label, im_paths

    def prob_gt_calc(self, i):
        """
        Run the network on image ``i`` of the loader.

        :param i: (int) image index
        :return: ``(probs, gt_train, gt_label, im_path)``; ``gt_label`` is all zeros
                 when the loader has no ``annotations`` attribute
        """
        x, y = self.loader[i]
        # add a batch dimension without mutating the tensor returned by the loader
        probs = np.squeeze(prediction(self.net, x.unsqueeze(0)))
        gt_train = y.numpy()
        try:
            gt_label = np.array(Image.open(self.loader.annotations[i]).convert('L'))
        except AttributeError:
            gt_label = np.zeros(gt_train.shape)
        im_path = self.loader.images[i]
        return probs, gt_train, gt_label, im_path

def prediction(net, image):
    """
    Run ``net`` on ``image`` without gradients and return softmax probabilities.

    The input is moved to the device of the network's parameters (falling back
    to the input's own device for parameter-less modules), so this works for
    both CPU and GPU networks.

    :param net: torch module; if it returns a tuple, the first element is used
    :param image: input tensor (batched)
    :return: numpy array of probabilities, softmax taken over dimension 1
    """
    first_param = next(iter(net.parameters()), None) if hasattr(net, "parameters") else None
    device = first_param.device if first_param is not None else image.device
    with torch.no_grad():
        out = net(image.to(device))
    if isinstance(out, tuple):
        out = out[0]
    out = out.detach().cpu()
    return F.softmax(out, 1).numpy()

In [68]:
from meta_ood.meta_ood_master.src.helper import counts_array_to_data_list
from sklearn.metrics import roc_curve, precision_recall_curve, average_precision_score, auc
def calc_precision_recall(data, balance=False):
    """
    Precision-recall curve and average precision from in/out histogram counts.

    The out-distribution samples are the positive class (label 1). Histogram
    counts are expanded to sample lists via ``counts_array_to_data_list``; with
    ``balance`` both classes use the same size cap (1e5), otherwise a total of
    1e7 is split in proportion to the in/out pixel counts.

    :param data: dict with count arrays under keys "in" and "out"
    :param balance: (bool) whether to cap both classes at the same size
    :return: ``(precision, recall, thresholds, average_precision)``
    """
    counts_in = np.array(data["in"])
    counts_out = np.array(data["out"])
    if not balance:
        total_in = np.sum(data["in"])
        frac_in = total_in / (total_in + np.sum(data["out"]))
        frac_out = 1 - frac_in
        size_in, size_out = 1e+7 * frac_in, 1e+7 * frac_out
    else:
        size_in = size_out = 1e+5
    scores_in = np.array(counts_array_to_data_list(counts_in, size_in)) / 100
    scores_out = np.array(counts_array_to_data_list(counts_out, size_out)) / 100
    labels = np.concatenate((np.zeros(len(scores_in)), np.ones(len(scores_out))))
    scores = np.concatenate((scores_in, scores_out))
    ap = average_precision_score(labels, scores)
    return precision_recall_curve(labels, scores) + (ap, )

In [69]:
def calc_sensitivity_specificity(data, balance=False):
    """
    ROC curve and its area from in/out histogram counts.

    The out-distribution samples are the positive class (label 1). Histogram
    counts are expanded to sample lists via ``counts_array_to_data_list``; with
    ``balance`` both classes are capped at ``max_size=1e5``, otherwise the
    helper's default size is used.

    :param data: dict with count arrays under keys "in" and "out"
    :param balance: (bool) whether to cap both classes at the same size
    :return: ``(fpr, tpr, thresholds, auroc)``
    """
    size_kwargs = {"max_size": 1e+5} if balance else {}
    scores_in = np.array(counts_array_to_data_list(np.array(data["in"]), **size_kwargs)) / 100
    scores_out = np.array(counts_array_to_data_list(np.array(data["out"]), **size_kwargs)) / 100
    labels = np.concatenate((np.zeros(len(scores_in)), np.ones(len(scores_out)))).astype("uint8")
    scores = np.concatenate((scores_in, scores_out))
    false_pos_rate, true_pos_rate, cut_points = roc_curve(labels, scores)
    return false_pos_rate, true_pos_rate, cut_points, auc(false_pos_rate, true_pos_rate)

In [70]:
# Evaluation configuration: training/validation sets, model and checkpoint selection.
args = {"TRAINSET": "Cityscapes+COCO",
        "VALSET": "Fishyscapes",
        "MODEL": "UNetResNet",
        "val_epoch": 76,
        "pareto_alpha": 0.9,
        "pixel_eval": True,
        "segment_eval": False}

config = config_evaluation_setup(args)
# Nothing selected means "evaluate everything".
if not args["pixel_eval"] and not args["segment_eval"]:
    args["pixel_eval"] = args["segment_eval"] = True

transform = Compose([ToTensor(), Normalize(config.dataset.mean, config.dataset.std)])
datloader = config.dataset(root=config.roots.eval_dataset_root, transform=transform)

start = time.time()

# Perform evaluation
print("\nEVALUATE MODEL: ", config.roots.model_name)
pixel_metrics = None  # (AUROC, FPR95, AUPRC) once pixel-level evaluation has run
if args["pixel_eval"]:
    print("\nPIXEL-LEVEL EVALUATION")
    pixel_metrics = eval_pixels(config.params, config.roots, config.dataset).oodd_metrics_pixel(datloader=datloader)
if args["segment_eval"]:
    # Segment-level evaluation is not implemented in this notebook; report it instead of skipping silently.
    print("\nSEGMENT-LEVEL EVALUATION is not implemented in this notebook; skipped")

end = time.time()
hours, rem = divmod(end - start, 3600)
minutes, seconds = divmod(rem, 60)
print("\nFINISHED {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))



EVALUATE MODEL:  UNetResNet

PIXEL-LEVEL EVALUATION
Save dir of plots: /work/pi_noah_daniels_uri_edu/said_harb_uri_edu_data/io/ood_detection/meta_ood_UNetResNet/fs_eval/plots
Load Path: /work/pi_noah_daniels_uri_edu/said_harb_uri_edu_data/io/ood_detection/meta_ood_UNetResNet/fs_eval/results/entropy_counts_per_pixel/epoch_76_alpha_0.9.p

OoDD Metrics - Epoch 76 - Lambda 0.90
AUROC: 0.8773581436648372
FPR95: 0.3853149511765647
AUPRC: 0.1906332271956456

FINISHED 00:00:04.79
