# This code implements a LIMEADE update using LIME and performs a simulated experiment using COCO segments.

In [None]:
from multiprocessing import Pool, get_context, Process, set_start_method
from skimage.segmentation import felzenszwalb, slic, quickshift
from nltk.classify.scikitlearn import SklearnClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics.pairwise import cosine_distances
from sklearn.linear_model import LogisticRegression
from scipy.ndimage.measurements import find_objects
from skimage.segmentation import mark_boundaries
from sklearn.utils import check_random_state
import matplotlib.patches as patches
from skimage.color import gray2rgb
from sklearn.svm import LinearSVC, SVC
from collections import ChainMap
import matplotlib.pyplot as plt
from matplotlib import cm
from random import choice
from PIL import Image 
import numpy as np
import inspect
import sklearn
import shutil
import random
import torch
import types
import glob
import json
import math
import json
import time
from tqdm import tqdm
import sys
import os
import re

# import pycocotools COCO API
from pycocotools.coco import COCO

# import some pytorch dependencies
from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import torch

# handle pickle versioning issues
try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle
    
# import img2vec for LIME pseudo-examples (need to install on device)
from img2vec_pytorch import Img2Vec

# ignore all warnings
import warnings
warnings.filterwarnings('ignore')

# set random seed for reproducibility during development
# (both generators are seeded because the partitioning code below uses
# np.random.shuffle as well as random.choice)
random.seed(28)
np.random.seed(28)

###################### set some hyperparameters ########################

# define whether we are using ResNet 18 or ResNet 50:
# True selects the "*_50.npy" embedding files and the 'resnet-50' img2vec model,
# False selects the "*_18.npy" files and 'resnet-18'
RESNET_50_BOOL = True

# number of positive images AND (separately) number of negative images drawn
# for the initial training set (TRAIN_SIZE) and for the update set (UPDATE_SIZE)
TRAIN_SIZE = 1
UPDATE_SIZE = 1 

# candidate numbers of nearest neighbors; the experiment retrieves
# max(N_NEAREST_NEIGHBORS) neighbors per explanation
N_NEAREST_NEIGHBORS = [1, 5, 10, 25, 50, 100]

# sets weights of nearest neighbors (collectively)
# NOTE(review): not referenced in the visible part of the file - presumably
# swept together with N_NEAREST_NEIGHBORS further down; confirm
NEIGHBORS_WEIGHTS = [0.25, 0.5, 1, 2, 4]

# number of perturbed samples LIME generates per explanation
# (passed as num_samples to explain_instance)
N_LIME_SAMPLES = 500

# COCO category names to run the experiment on, one after another
category_names = ["bed", "carrot", "cow", "donut", "fire hydrant", "fork", "frisbee", "giraffe", "horse", "knife", "motorcycle", "potted plant", "scissors", "sink", "skateboard", "snowboard", "suitcase", "surfboard", "toothbrush", "baseball glove"]

# number of independent runs (fresh random partition each) per category
N_RUNS = 100
######################################################################

In [None]:
# Collect the COCO image IDs that already have a full-image embedding on disk
# for the selected ResNet variant.
if RESNET_50_BOOL:
    full_embedding_glob = "../COCO/train2017_embeddings/*_full_embedding_50.npy"
    full_embedding_suffix = '_full_embedding_50.npy'
else:
    full_embedding_glob = "../COCO/train2017_embeddings/*_full_embedding_18.npy"
    full_embedding_suffix = '_full_embedding_18.npy'

embeddings_paths = sorted(glob.glob(full_embedding_glob))

# Each file is named "<image id><suffix>", so dropping the directory part and
# the suffix leaves the numeric image ID.
preprocessed_ids = []
for embedding_path in embeddings_paths:
    file_name = embedding_path.split('/')[-1]
    img_id = int(file_name.split(full_embedding_suffix)[0])
    preprocessed_ids.append(img_id)
        
        
### Now, we load in all embeddings (this step takes a while and only needs to be run once): ####
# Every embedding file for the selected ResNet variant. This pattern is broader
# than the "_full_embedding_" one used for preprocessed_ids, so one image ID
# can own several files.
if RESNET_50_BOOL:
    all_embeddings_paths = sorted(glob.glob("../COCO/train2017_embeddings/*50.npy"))
else:
    all_embeddings_paths = sorted(glob.glob("../COCO/train2017_embeddings/*18.npy"))


# When True, rebuild the list of embedding paths and the stacked embedding
# matrix from the individual .npy files and cache both in the working
# directory. When False, reuse those cached files. Rebuilding is slow: set this
# to False after the first run, and back to True only when new images are added
# to the unlabeled pool.
REDOWNLOAD = True

preprocessed_embeddings_paths = []
preprocessed_embeddings = []

if REDOWNLOAD:

    # Group the paths by their "<directory>/<12-digit image id>" prefix in one
    # pass, so each ID is resolved with a dict lookup instead of a scan over
    # every path. File names start with the zero-padded 12-digit ID, which is
    # the same format used to build the lookup key below.
    embeddings_dir = "../COCO/train2017_embeddings/"
    prefix_length = len(embeddings_dir) + 12
    paths_by_prefix = {}
    for path in all_embeddings_paths:
        paths_by_prefix.setdefault(path[:prefix_length], []).append(path)

    # Keep ALL files belonging to an ID (not just the first match), in the
    # sorted order they have in all_embeddings_paths.
    for index in tqdm(preprocessed_ids):
        prefix = embeddings_dir + '{:>012d}'.format(index)
        preprocessed_embeddings_paths.extend(paths_by_prefix.get(prefix, []))

    # "with" guarantees the pickle is flushed and the handle closed
    with open("preprocessed_embeddings_paths.pkl", "wb") as f:
        pickle.dump(preprocessed_embeddings_paths, f)

    for path in tqdm(preprocessed_embeddings_paths):
        preprocessed_embeddings.append(np.load(path))
    preprocessed_embeddings = np.array(preprocessed_embeddings)
    np.save("preprocessed_embeddings.npy", preprocessed_embeddings)


# saves having to re-run
else:
    # NOTE: pickle must only be used on files this script wrote itself
    with open("preprocessed_embeddings_paths.pkl", "rb") as f:
        preprocessed_embeddings_paths = pickle.load(f)
    preprocessed_embeddings = np.load("preprocessed_embeddings.npy")

print("DONE LOADING EMBEDDINGS")
################################################################################################

### Next, we handle COCO category filtering & the function for mapping COCO segments to superpixels: ###

# NOTE(review): dataDir is not used in the visible part of the file (annFile
# below hard-codes the '../COCO' prefix) - confirm before removing
dataDir='..'
dataType='train2017'
# instance-annotation JSON for the chosen split
annFile='../COCO/annotations/instances_' + dataType + '.json'

# initialize COCO api for instance annotations
coco=COCO(annFile)

# load (not display) the records of every COCO category
cats = coco.loadCats(coco.getCatIds())

# function for retrieving relevant superpixels for an image, given the image ID and category ID
def retrieve_relevant_superpixels(imgId, catId):
    """Return the superpixel IDs of an image that overlap one category's annotations.

    Args:
        imgId: COCO image ID; also used to locate the stored superpixel
            segmentation "../COCO/train2017_segments/<12-digit id>.npy".
        catId: COCO category ID whose annotations are looked up.

    Returns:
        Array of the unique superpixel IDs found under the (transformed)
        union of the annotation masks.

    Relies on the module-level names ``coco`` and ``pill_transf`` existing at
    call time.
    """
    # annotations of the requested category in this image
    annotations = coco.loadAnns(coco.getAnnIds(catIds=[catId], imgIds=[imgId]))

    # stored superpixel segmentation for this image
    segments = np.load("../COCO/train2017_segments/" + '{:>012d}'.format(imgId) + ".npy")

    # union of all annotation masks at the image's native resolution
    img_info = coco.loadImgs([imgId])[0]
    union_mask = np.zeros((img_info['height'], img_info['width']))
    for annotation in annotations:
        union_mask = np.maximum(coco.annToMask(annotation), union_mask)

    # push the mask through the same PIL transform as the images so it can be
    # compared against the segmentation
    # (assumes the stored segments have the transform's output size - TODO confirm)
    mask_image = Image.fromarray(np.uint8(union_mask)*255)
    transformed_mask = np.array(pill_transf(mask_image))

    # superpixels with at least one pixel under the mask
    overlapping = segments[transformed_mask != 0]
    return np.unique(overlapping)

###########################################################################################

def process_category(name):
    """Draw a fresh random train/update partition for one COCO category.

    Positives are preprocessed images that contain the category; negatives are
    preprocessed images that do not. TRAIN_SIZE positives and TRAIN_SIZE
    negatives form the training set, UPDATE_SIZE of each form the update set.

    Args:
        name: COCO category name (e.g. "giraffe").

    Returns:
        A list ``[positive_train_ids, negative_train_ids, positive_update_ids,
        negative_update_ids, unlabeled_pool, filtered_img_ids]`` where
        ``unlabeled_pool`` holds every preprocessed ID not placed in one of the
        four labeled sets and ``filtered_img_ids`` holds all (shuffled)
        preprocessed IDs containing the category.

    Raises:
        ValueError: if fewer than TRAIN_SIZE + UPDATE_SIZE preprocessed images
            lack the category, in which case negative sampling could never
            finish.
    """
    # get all images containing given category
    catIds = coco.getCatIds(catNms=[name])
    imgIds = coco.getImgIds(catIds=catIds)

    # set gives O(1) membership tests; int() normalises numpy integers in case
    # preprocessed_ids has been converted to an array by the caller
    preprocessed_id_set = {int(x) for x in preprocessed_ids}

    # filter to find ones in preprocessed
    filtered_img_ids = [img_id for img_id in imgIds if img_id in preprocessed_id_set]

    # shuffle filtered IDs to get a new classifier each time
    np.random.shuffle(filtered_img_ids)

    # positive train ids: first TRAIN_SIZE images of the selected class
    positive_train_ids = filtered_img_ids[:TRAIN_SIZE]
    print(positive_train_ids)
    print(filtered_img_ids[:5])

    # positive update ids: the next UPDATE_SIZE images of the selected class
    positive_update_ids = filtered_img_ids[TRAIN_SIZE:TRAIN_SIZE+UPDATE_SIZE]

    # IDs that may not be drawn as negatives: every image containing the
    # category, plus negatives already drawn (added as we go)
    excluded = set(filtered_img_ids)

    if len(preprocessed_id_set - excluded) < TRAIN_SIZE + UPDATE_SIZE:
        raise ValueError("Not enough preprocessed images without category '"
                         + str(name) + "' to draw the requested negatives.")

    def _draw_negative_ids(count):
        # rejection-sample `count` distinct IDs that are not excluded
        drawn = []
        for _ in range(count):
            candidate = choice(preprocessed_ids)
            while candidate in excluded:
                candidate = choice(preprocessed_ids)
            drawn.append(candidate)
            excluded.add(candidate)
        return drawn

    # train negatives are drawn before update negatives
    negative_train_ids = _draw_negative_ids(TRAIN_SIZE)
    negative_update_ids = _draw_negative_ids(UPDATE_SIZE)

    print("TRAIN IDS: ")
    print(positive_train_ids)
    print(negative_train_ids)
    print(positive_update_ids)
    print(negative_update_ids)

    # everything that received a label is removed from the unlabeled pool
    used_ids = set(positive_train_ids)
    used_ids.update(positive_update_ids)
    used_ids.update(negative_train_ids)
    used_ids.update(negative_update_ids)

    unlabeled_pool = [x for x in preprocessed_ids if x not in used_ids]

    return [positive_train_ids, negative_train_ids, positive_update_ids, negative_update_ids, unlabeled_pool, filtered_img_ids]

####################################################################################

In [None]:
# modified LIME code:

import copy
from functools import partial

import numpy as np
import sklearn
from sklearn.utils import check_random_state
from skimage.color import gray2rgb
from tqdm.auto import tqdm


from lime import lime_base
from lime.wrappers.scikit_image import SegmentationAlgorithm


class ModifiedImageExplanation(object):
    """Holds an image, its superpixel segmentation and per-label explanation results.

    ``intercept``, ``local_exp``, ``local_pred`` and ``score`` start out as
    empty dicts and are filled in, keyed by label, by the explainer.
    """

    def __init__(self, image, segments):
        """Store the explained image and its segmentation.

        Args:
            image: 3d numpy array
            segments: 2d numpy array, with the output from skimage.segmentation
        """
        self.image = image
        self.segments = segments
        # per-label results
        self.score = {}
        self.local_pred = {}
        self.local_exp = {}
        self.intercept = {}

    def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
                           num_features=5, min_weight=0.):
        """Render the explanation for one label as an image plus a superpixel mask.

        Args:
            label: label to explain; must be a key of ``self.local_exp``.
            positive_only: if True, keep only superpixels with a positive weight.
            negative_only: if True, keep only superpixels with a negative weight.
                If both flags are False, superpixels of either sign are used.
                The two flags cannot both be True.
            hide_rest: if True, start from an all-zero image so that only the
                selected superpixels are visible; otherwise start from a copy
                of the original image.
            num_features: number of superpixels to include. In the one-sided
                modes the weight filter is applied first and the first
                ``num_features`` survivors are kept; in the two-sided mode the
                first ``num_features`` entries are taken and then filtered.
            min_weight: minimum absolute weight for a superpixel to be included.

        Returns:
            (image, mask): image is a 3d numpy array and mask is a 2d numpy
            array usable with skimage.segmentation.mark_boundaries. In the
            one-sided modes the mask is 1 on selected superpixels; in the
            two-sided mode it is 1 for positive and -1 for negative weights and
            the image additionally has channel 1 (positive) or channel 0
            (negative) of those superpixels set to the image maximum.

        Raises:
            KeyError: if ``label`` has no explanation.
            ValueError: if both ``positive_only`` and ``negative_only`` are set.
        """
        if label not in self.local_exp:
            raise KeyError('Label not in explanation')
        if positive_only & negative_only:
            raise ValueError("Positive_only and negative_only cannot be true at the same time.")

        superpixels = self.segments
        original = self.image
        weighted_features = self.local_exp[label]
        mask = np.zeros(superpixels.shape, superpixels.dtype)
        canvas = np.zeros(self.image.shape) if hide_rest else self.image.copy()

        # two-sided mode: colour-code by sign
        if not (positive_only or negative_only):
            for feature, weight in weighted_features[:num_features]:
                if np.abs(weight) < min_weight:
                    continue
                region = superpixels == feature
                negative = weight < 0
                mask[region] = -1 if negative else 1
                canvas[region] = original[region].copy()
                canvas[region, 0 if negative else 1] = np.max(original)
            return canvas, mask

        # one-sided modes: pick the features of the requested sign
        if negative_only:
            selected = [entry[0] for entry in weighted_features
                        if entry[1] < 0 and abs(entry[1]) > min_weight]
        else:
            selected = [entry[0] for entry in weighted_features
                        if entry[1] > 0 and entry[1] > min_weight]

        for feature in selected[:num_features]:
            region = superpixels == feature
            canvas[region] = original[region].copy()
            mask[region] = 1
        return canvas, mask

class ModifiedLimeImageExplainer(object):
    """LIME image explainer for classifiers that operate on image embeddings.

    Neighbourhood samples are produced by switching superpixels of the image
    off. Each perturbed image is embedded with an img2vec ResNet model (chosen
    by the module-level RESNET_50_BOOL flag) and the embeddings, not the raw
    pixels, are handed to the classifier function. A locally weighted linear
    model is then fitted on the on/off superpixel indicators (see lime_base).
    """

    def __init__(self, kernel_width=.25, kernel=None, verbose=False,
                 feature_selection='auto', random_state=None):
        """Init function.
        Args:
            kernel_width: kernel width for the exponential kernel; must be
                convertible with float().
            kernel: similarity kernel that takes distances and kernel
                width as input and outputs weights in (0,1). If None, defaults to
                an exponential kernel.
            verbose: if true, print local prediction values from linear model
            feature_selection: feature selection method. can be
                'forward_selection', 'lasso_path', 'none' or 'auto'.
                See function 'explain_instance_with_data' in lime_base.py for
                details on what each of the options does.
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.
        """
        kernel_width = float(kernel_width)

        if kernel is None:
            def kernel(d, kernel_width):
                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

        kernel_fn = partial(kernel, kernel_width=kernel_width)

        self.random_state = check_random_state(random_state)
        self.feature_selection = feature_selection
        self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state)

    def explain_instance(self, image, embedding, classifier_fn, labels=(1,),
                         hide_color=None,
                         top_labels=5, num_features=100000, num_samples=1000,
                         batch_size=4,
                         segmentation_fn=None,
                         distance_metric='cosine',
                         model_regressor=None,
                         random_seed=None,
                         progress_bar=True):
        """Generates explanations for a prediction.
        First, we generate neighborhood data by randomly switching superpixels
        of the image off (see data_labels). We then learn locally weighted
        linear models on this neighborhood data to explain each of the classes
        in an interpretable way (see lime_base.py).
        Args:
            image: 3 dimension RGB image. If this is only two dimensional,
                we will assume it's a grayscale image and call gray2rgb.
            embedding: embedding of the unperturbed image. Kept in the
                signature for callers; this method does not read it (the
                embeddings of all samples are recomputed in data_labels).
            classifier_fn: classifier prediction probability function, which
                takes a numpy array of embeddings and outputs prediction
                probabilities. For scikit-learn classifiers this is
                classifier.predict_proba.
            labels: iterable with labels to be explained.
            hide_color: value written into switched-off superpixels. If None,
                each superpixel is replaced by its own mean colour instead.
            top_labels: if not None, ignore labels and produce explanations for
                the K labels with highest prediction probabilities, where K is
                this parameter.
            num_features: maximum number of features present in explanation
            num_samples: size of the neighborhood to learn the linear model
            batch_size: number of perturbed images embedded and classified
                per call of classifier_fn.
            segmentation_fn: SegmentationAlgorithm, wrapped skimage
                segmentation function. Defaults to quickshift.
            distance_metric: the distance metric to use for weights.
            model_regressor: sklearn regressor to use in explanation. Defaults
                to Ridge regression in LimeBase. Must have model_regressor.coef_
                and 'sample_weight' as a parameter to model_regressor.fit()
            random_seed: integer used as random seed for the segmentation
                algorithm. If None, a random integer, between 0 and 1000,
                will be generated using the internal random number generator.
            progress_bar: if True, show tqdm progress bar.
        Returns:
            A ModifiedImageExplanation object with the corresponding
            explanations.
        """
        if len(image.shape) == 2:
            image = gray2rgb(image)
        if random_seed is None:
            random_seed = self.random_state.randint(0, high=1000)

        if segmentation_fn is None:
            segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
                                                    max_dist=200, ratio=0.2,
                                                    random_seed=random_seed)
        segments = segmentation_fn(image)

        # image shown in place of switched-off superpixels
        fudged_image = image.copy()
        if hide_color is None:
            for x in np.unique(segments):
                fudged_image[segments == x] = (
                    np.mean(image[segments == x][:, 0]),
                    np.mean(image[segments == x][:, 1]),
                    np.mean(image[segments == x][:, 2]))
        else:
            fudged_image[:] = hide_color

        # remember the requested labels; "labels" is reused for predictions below
        top = labels

        data, labels = self.data_labels(image, fudged_image, segments,
                                        classifier_fn, num_samples,
                                        batch_size=batch_size,
                                        progress_bar=progress_bar)

        # distance of every sample to the unperturbed one (row 0 is all ones)
        distances = sklearn.metrics.pairwise_distances(
            data,
            data[0].reshape(1, -1),
            metric=distance_metric
        ).ravel()

        ret_exp = ModifiedImageExplanation(image, segments)
        if top_labels:
            top = np.argsort(labels[0])[-top_labels:]
            ret_exp.top_labels = list(top)
            ret_exp.top_labels.reverse()
        for label in top:
            (ret_exp.intercept[label],
             ret_exp.local_exp[label],
             ret_exp.score[label],
             ret_exp.local_pred[label]) = self.base.explain_instance_with_data(
                data, labels, distances, label, num_features,
                model_regressor=model_regressor,
                feature_selection=self.feature_selection)
        return ret_exp

    def data_labels(self,
                    image,
                    fudged_image,
                    segments,
                    classifier_fn,
                    num_samples,
                    batch_size=10, # altered for this version
                    progress_bar=True):
        """Generates images and predictions in the neighborhood of this image.

        Every row of the returned data matrix is turned into one perturbed
        image, which is embedded with img2vec; the embeddings are classified in
        batches. Exactly one prediction is produced per row, in row order.

        Args:
            image: 3d numpy array, the image
            fudged_image: 3d numpy array, image to replace original image when
                superpixel is turned off
            segments: segmentation of the image
            classifier_fn: function that takes an array of embeddings and
                returns a matrix of prediction probabilities
            num_samples: size of the neighborhood to learn the linear model
            batch_size: number of images embedded and passed to classifier_fn
                per call.
            progress_bar: if True, show tqdm progress bar.
        Returns:
            A tuple (data, labels), where:
                data: dense num_samples * num_superpixels
                labels: prediction probabilities matrix
        """
        n_features = np.unique(segments).shape[0]
        data = self.random_state.randint(0, 2, num_samples * n_features)\
            .reshape((num_samples, n_features))
        labels = []
        # first sample is the unperturbed image
        data[0, :] = 1

        # load only the embedding model that is actually used
        if RESNET_50_BOOL:
            img2vec = Img2Vec(cuda=True, model='resnet-50')
        else:
            img2vec = Img2Vec(cuda=True, model='resnet-18')

        def classify(batch):
            # embed a list of PIL images and record one prediction per image
            batch_embeddings = img2vec.get_vec(batch, tensor=False)
            labels.extend(classifier_fn(np.array(batch_embeddings)))

        pending = []  # perturbed images waiting to be embedded
        rows = tqdm(data) if progress_bar else data
        for row in rows:
            temp = copy.deepcopy(image)
            # column index z of the data matrix is matched against superpixel id z
            switched_off = np.where(row == 0)[0]
            mask = np.isin(segments, switched_off)
            temp[mask] = fudged_image[mask]

            # each perturbed image enters the batch exactly once, so the
            # predictions stay aligned with the rows of `data`
            pending.append(Image.fromarray(np.uint8(temp)))
            if len(pending) == batch_size:
                classify(pending)
                pending = []

        # flush the last, possibly smaller, batch
        if pending:
            classify(pending)

        return data, np.array(labels)

In [None]:
for category_name in category_names:
    
    print("STARTING CATEGORY: " + category_name)
    
    for run_index in range(0, N_RUNS):

        print("STARTING RUN " + str(run_index) + "...")

        ########### set a partition here ############
        partition = process_category(category_name)

        positive_train_ids = partition[0]
        negative_train_ids = partition[1]
        positive_update_ids = partition[2]
        negative_update_ids = partition[3]
        unlabeled_pool = partition[4] # stores indices of preprocessed but unlabeled images to draw from during update
        filtered_img_ids = partition[5] # need to remember all of the positive examples, too, to determine whether a prediction is correct

        intersections = np.intersect1d(filtered_img_ids, unlabeled_pool)

        ############################################

        # Next, we load in the relevant embeddings based on the unlabeled pool: #
        preprocessed_ids = np.array(preprocessed_ids, dtype=int)

        relevant_embeddings = []
        relevant_embeddings_paths = []

        for i in range(0, len(preprocessed_embeddings_paths)):
            path = preprocessed_embeddings_paths[i]
            path_index = int(path[29:41])

            if path_index in positive_train_ids:
                continue
            if path_index in negative_train_ids:
                continue
            if path_index in positive_update_ids:
                continue
            if path_index in negative_update_ids:
                continue
                
            relevant_embeddings.append(preprocessed_embeddings[i])
            relevant_embeddings_paths.append(preprocessed_embeddings_paths[i]) 

        relevant_embeddings = np.array(relevant_embeddings)

        ############################################################################


        ### Next, we train a linear model on the embeddings using the labels above and evaluate performance:
        positive_train_embeddings = []
        for i in range(0, len(positive_train_ids)):
            if RESNET_50_BOOL:
                filename = "../COCO/train2017_embeddings/" + '{:>012d}'.format(positive_train_ids[i]) + "_full_embedding_50.npy"
            else:
                filename = "../COCO/train2017_embeddings/" + '{:>012d}'.format(positive_train_ids[i]) + "_full_embedding_18.npy"
            embedding = np.load(filename)
            positive_train_embeddings.append(embedding)
        positive_train_embeddings = np.array(positive_train_embeddings)

        negative_train_embeddings = []
        for i in range(0, len(negative_train_ids)):
            if RESNET_50_BOOL:
                filename = "../COCO/train2017_embeddings/" + '{:>012d}'.format(negative_train_ids[i]) + "_full_embedding_50.npy"
            else:
                filename = "../COCO/train2017_embeddings/" + '{:>012d}'.format(negative_train_ids[i]) + "_full_embedding_18.npy"
            embedding = np.load(filename)
            negative_train_embeddings.append(embedding)
        negative_train_embeddings = np.array(negative_train_embeddings)

        # sets training data
        train_X = np.concatenate((positive_train_embeddings, negative_train_embeddings), axis=0)
        train_y = np.concatenate((np.ones(len(positive_train_embeddings)), np.zeros(len(negative_train_embeddings))))
        train_sample_weight = np.ones(len(positive_train_embeddings) + len(negative_train_embeddings))

        # sets val data from precomputed partitions
        val_X = np.load("splits/" + category_name + "_val_X.npy")
        val_y = np.load("splits/" + category_name + "_val_y.npy")
        
        # sets test data from precomputed partitions
        test_X = np.load("splits/" + category_name + "_test_X.npy")
        test_y = np.load("splits/" + category_name + "_test_y.npy")
        
        # instantiate linear model and train
        clf = LogisticRegression(random_state=0)
        clf.fit(train_X, train_y)

        # evaluate linear model's performance
        train_score = clf.score(train_X, train_y)
        original_val_score = clf.score(val_X, val_y)
        original_test_score = clf.score(test_X, test_y)
        
        print("SCORES: ")
        print(train_score, original_val_score, original_test_score)

        ############################################################################

        # Next, we randomly draw a NEGATIVE image from the update set (negative_update_ids) to inspect and understand why the prediction has been made.
        def get_image(index):
            """Load the train2017 JPEG for a COCO image ID as an RGB PIL image."""
            image_path = os.path.abspath("../COCO/train2017/" + '{:>012d}'.format(index) + ".jpg")
            # convert() returns a new image, so both handles can be released here
            with open(image_path, 'rb') as handle, Image.open(handle) as source:
                rgb_image = source.convert('RGB')
            return rgb_image

        def get_embedding(index):
            """Load the stored full-image embedding for a COCO image ID.

            RESNET_50_BOOL selects between the ResNet-50 and ResNet-18 file.
            """
            suffix = "_full_embedding_50.npy" if RESNET_50_BOOL else "_full_embedding_18.npy"
            return np.load("../COCO/train2017_embeddings/" + '{:>012d}'.format(index) + suffix)

        # first, let's randomly grab a negative image from VAL
        negative_instance_index = random.choice(negative_update_ids)

        negative_img = get_image(negative_instance_index)
        negative_instance_embedding = get_embedding(negative_instance_index)

        prediction = clf.predict(np.array([negative_instance_embedding]))

        ############################################################################

        ### Next, we define the image transform needed and then run the explain instance functionality from the modified LIME code:
        # need this transform defined in https://github.com/marcotcr/lime/blob/master/doc/notebooks/Tutorial%20-%20images%20-%20Pytorch.ipynb
        def get_pil_transform():
            """Build the PIL transform: resize to 256x256, then center-crop to 224."""
            resize_then_crop = [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
            ]
            return transforms.Compose(resize_then_crop)

        pill_transf = get_pil_transform()

        from lime import lime_image

        explainer = ModifiedLimeImageExplainer()
        explanation = explainer.explain_instance(np.array(pill_transf(negative_img)),
                                                 negative_instance_embedding, # pass in embedding here
                                                 clf.predict_proba, # classification function
                                                 top_labels=5, 
                                                 hide_color=0, 
                                                 batch_size=1,
                                                 num_samples=N_LIME_SAMPLES) # number of images that will be sent to classification function

        ############################################################################


        ### Next, we show the explanation and retrieve the corresponding crop from the preprocessed data:
        # retrieve explanation
        temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=1, hide_rest=False)

        # we now find the segment mask for the explanation
        exp_segment_rows, exp_segment_cols = np.where(mask != 0)

        # now, we load the mask from the corresponding image to map back to the crop
        segments_filename = "../COCO/train2017_segments/" + '{:>012d}'.format(negative_instance_index) + ".npy"
        segments = np.load(segments_filename)

        # for each pixel in the explanation mask, we retrieve the ID in the mask and check to see what it corresponds to
        segment_ids = []
        for i in range(0, len(exp_segment_rows)):
            segment_ids.append(segments[exp_segment_rows[i], exp_segment_cols[i]])

        # we now check to make sure that this is all one segment
        try:
            assert len(set(segment_ids)) == 1
        except:
            print("Explanation weirdness - multiple superpixels reported with top explanation ID - maybe strange tie?")

        # if so, we record the segment ID
        negative_explanation_segment_id = segment_ids[0]

        # we now retrieve the cropped image
        crop_filename = "../COCO/train2017_crops/" + '{:>012d}'.format(negative_instance_index) + "_" + str(negative_explanation_segment_id) + ".jpg"

        # save most-weighted superpixel in LIME explanation as superpixel to act on
        negative_action_segment_ids = [negative_explanation_segment_id]

        # thumbs-down this
        negative_feedback = 0
        ####################################################################################################################################

        ### Next, we show the superpixel acted on, along with the annotation, and set the embedding:

        mask = np.zeros(segments.shape)
        for negative_action_segment_id in negative_action_segment_ids:
            mask += (segments == negative_action_segment_id)

        # now, we generate bounding box around these segments
        boxes = find_objects(mask.astype(int))
        try:
            box = boxes[0]
        except:
            continue
        w1 = box[1].start
        w2 = box[1].stop
        h1 = box[0].start
        h2 = box[0].stop
        pil_im = Image.fromarray(np.uint8(np.array(pill_transf(get_image(negative_instance_index))))).convert('RGB')
        cropped = pil_im.crop((w1, h1, w2, h2))


        # we now feed it to the embedding model
        # load in img2vec
        # generate appropriate ResNet embedding
        if RESNET_50_BOOL:
            img2vec_resnet_50 = Img2Vec(cuda=True, model='resnet-50') 
            crop_embedding = img2vec_resnet_50.get_vec(cropped, tensor=False)
        else:
            img2vec_resnet_18 = Img2Vec(cuda=True, model='resnet-18') 
            crop_embedding = img2vec_resnet_18.get_vec(cropped, tensor=False)

        negative_superpixel_embedding = crop_embedding

        ####################################################################################################################################

        ### Next, we retrieve the most similar crops across all images:

        # largest neighborhood size requested; the smaller neighborhoods are
        # prefixes of this one
        max_neighbors = np.max(np.array(N_NEAREST_NEIGHBORS))

        # file suffix of the full-image embeddings for the active ResNet variant
        full_embedding_suffix = "_full_embedding_50.npy" if RESNET_50_BOOL else "_full_embedding_18.npy"

        # next, we compute the nearest neighbors using the embeddings:
        # rank every candidate by its squared Euclidean distance to the crop
        # embedding (the square root is skipped because it does not change the
        # ranking); the result is deliberately not called `cosine_distances` so
        # the function of that name imported from sklearn is not overwritten
        crop_squared_distances = np.sum(np.abs(relevant_embeddings - crop_embedding)**2, axis=-1)
        negative_nearest_indices = np.argsort(crop_squared_distances)[:max_neighbors]

        # same ranking relative to the full-image embedding of the explained
        # instance (despite the name, these are squared Euclidean distances too);
        # a 1000x longer prefix is scanned because only "_full" entries are kept
        negative_instance_cosine_distances = np.sum(np.abs(relevant_embeddings - negative_instance_embedding)**2, axis=-1)
        negative_instance_nearest_indices = np.argsort(negative_instance_cosine_distances)[:max_neighbors*1000]
        negative_instance_nearest_embeddings = []

        for neighbor_index in negative_instance_nearest_indices:
            neighbor_path = relevant_embeddings_paths[neighbor_index]
            # we want the full images here!
            if "_full" in neighbor_path:
                negative_instance_nearest_embeddings.append(np.load(neighbor_path))
                if len(negative_instance_nearest_embeddings) == max_neighbors:
                    break
        # NOTE(review): this array is not read again in this section - confirm it is still needed
        negative_instance_nearest_embeddings = np.array(negative_instance_nearest_embeddings)

        nn_full_image_embeddings = []
        nn_superpixel_embeddings = []

        # we now grab, for every nearest crop, its own embedding and the
        # embedding of the full image it was cut from
        for index in negative_nearest_indices:  # if scikit-learn, iterate over [0] here
            superpixel_path = relevant_embeddings_paths[index]
            # the first two "_"-separated chunks are joined back into the prefix
            # shared with the full-image embedding file of the same image
            chunks = superpixel_path.split("_")
            nn_superpixel_embedding = np.load(superpixel_path)
            nn_full_image_embedding = np.load(chunks[0] + "_" + chunks[1] + full_embedding_suffix)

            # add to list of nearest neighbor full image embeddings
            nn_full_image_embeddings.append(nn_full_image_embedding)
            # add superpixel embeddings to list as well
            nn_superpixel_embeddings.append(nn_superpixel_embedding)

        # compute the corresponding full image embeddings for the update:
        # one array per requested neighborhood size, each a prefix of the ranking
        negative_nn_embeddings_for_update = [
            np.array(nn_full_image_embeddings[:n_neighbors])
            for n_neighbors in N_NEAREST_NEIGHBORS
        ]

        ####################################################################################################################################

        # Next, we randomly draw a POSITIVE from the validation data to inspect and understand why the prediction has been made. 

        ### We visualize the randomly-selected image and report the predicted class below:

        def get_embedding(index):
            """Load the stored full-image embedding of the image with ID `index`.

            The file is read from ../COCO/train2017_embeddings/, named by the
            zero-padded 12-digit ID; the ResNet-50 file is used when
            RESNET_50_BOOL is set, the ResNet-18 file otherwise.
            """
            suffix = "_full_embedding_50.npy" if RESNET_50_BOOL else "_full_embedding_18.npy"
            padded_id = '{:>012d}'.format(index)
            return np.load("../COCO/train2017_embeddings/" + padded_id + suffix)

        # first, let's randomly grab a positive image: one ID is drawn from
        # positive_update_ids (the original note says these come from VAL - TODO confirm)
        positive_instance_index = random.choice(positive_update_ids)

        # fetch the image itself and its stored full-image embedding
        positive_img = get_image(positive_instance_index)
        positive_instance_embedding = get_embedding(positive_instance_index)

        # classify the instance as a batch of one
        # NOTE(review): `prediction` is not read again in this section - confirm
        # whether a check against it was intended
        single_instance_batch = np.array([positive_instance_embedding])
        prediction = clf.predict(single_instance_batch)

        ####################################################################################################################################

        ### Next, we run LIME:

        # need this transform defined in https://github.com/marcotcr/lime/blob/master/doc/notebooks/Tutorial%20-%20images%20-%20Pytorch.ipynb
        def get_pil_transform():
            """Return the image pipeline: resize to 256x256, then center-crop to 224."""
            steps = [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
            ]
            return transforms.Compose(steps)

        pill_transf = get_pil_transform()

        # lime_image itself is not referenced in this section; the import is kept as-is
        from lime import lime_image

        explainer = ModifiedLimeImageExplainer()

        # the explainer receives the transformed image as an array, the stored
        # embedding of the instance, and the classifier's probability function
        lime_input_image = np.array(pill_transf(positive_img))
        lime_options = dict(
            top_labels=5,
            hide_color=0,
            batch_size=1,
            num_samples=N_LIME_SAMPLES,  # number of images that will be sent to classification function
        )
        explanation = explainer.explain_instance(
            lime_input_image,
            positive_instance_embedding,  # pass in embedding here
            clf.predict_proba,  # classification function
            **lime_options
        )


        ####################################################################################################################################

        ### Next, we show the explanation and retrieve the corresponding crop from the preprocessed data:

        # get LIME explanation
        temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=1, hide_rest=False)
        img_boundry2 = mark_boundaries(temp/255.0, mask)

        # we now find the segment mask for the explanation
        exp_segment_rows, exp_segment_cols = np.where(mask != 0)

        # now, we load the mask from the corresponding image to map back to the crop
        segments_filename = "../COCO/train2017_segments/" + '{:>012d}'.format(positive_instance_index) + ".npy"
        segments = np.load(segments_filename)

        # for each pixel in the explanation mask, we retrieve the ID in the mask and check to see what it corresponds to
        segment_ids = []
        for i in range(0, len(exp_segment_rows)):
            segment_ids.append(segments[exp_segment_rows[i], exp_segment_cols[i]])

        # we now check to make sure that this is all one segment
        try:
            assert len(set(segment_ids)) == 1
        except:
            print("Explanation weirdness - multiple superpixels reported with top explanation ID - maybe strange tie?")


        # if so, we record the segment ID
        positive_explanation_segment_id = segment_ids[0]

        # we now retrieve the cropped image
        crop_filename = "../COCO/train2017_crops/" + '{:>012d}'.format(positive_instance_index) + "_" + str(positive_explanation_segment_id) + ".jpg"

        catId = coco.getCatIds(catNms=[category_name])[0]

        # now, we set all superpixels containing that "object" as the superpixels to act on
        positive_action_segment_ids = retrieve_relevant_superpixels(positive_instance_index, catId)

        # set positive feedback as 1
        positive_feedback = 1

        ####################################################################################################################################

        # Next, we show the superpixel acted on, along with the annotation, and set the embedding:

        mask = np.zeros(segments.shape)
        for positive_action_segment_id in positive_action_segment_ids:
            mask += (segments == positive_action_segment_id)

        # now, we generate bounding box around these segments
        boxes = find_objects(mask.astype(int))
        try:
            box = boxes[0]
        except:
            continue
        w1 = box[1].start
        w2 = box[1].stop
        h1 = box[0].start
        h2 = box[0].stop
        pil_im = Image.fromarray(np.uint8(np.array(pill_transf(get_image(positive_instance_index))))).convert('RGB')
        cropped = pil_im.crop((w1, h1, w2, h2))

        # we now feed it to the embedding model
        # load in img2vec
        # generate appropriate ResNet embedding
        if RESNET_50_BOOL:
            img2vec_resnet_50 = Img2Vec(cuda=True, model='resnet-50') 
            crop_embedding = img2vec_resnet_50.get_vec(cropped, tensor=False)
        else:
            img2vec_resnet_18 = Img2Vec(cuda=True, model='resnet-18') 
            crop_embedding = img2vec_resnet_18.get_vec(cropped, tensor=False)

        positive_superpixel_embedding = crop_embedding

        ####################################################################################################################################


        ### Next, we retrieve the most similar crops across all images:

        # largest neighborhood size requested; the smaller neighborhoods are
        # prefixes of this one
        max_neighbors = np.max(np.array(N_NEAREST_NEIGHBORS))

        # file suffix of the full-image embeddings for the active ResNet variant
        full_embedding_suffix = "_full_embedding_50.npy" if RESNET_50_BOOL else "_full_embedding_18.npy"

        # next, we compute the nearest neighbors using the embeddings:
        # rank every candidate by its squared Euclidean distance to the crop
        # embedding (the square root is skipped because it does not change the
        # ranking); the result is deliberately not called `cosine_distances` so
        # the function of that name imported from sklearn is not overwritten
        crop_squared_distances = np.sum(np.abs(relevant_embeddings - crop_embedding)**2, axis=-1)
        positive_nearest_indices = np.argsort(crop_squared_distances)[:max_neighbors]

        # same ranking relative to the full-image embedding of the explained
        # instance (despite the name, these are squared Euclidean distances too);
        # a 1000x longer prefix is scanned because only "_full" entries are kept
        positive_instance_cosine_distances = np.sum(np.abs(relevant_embeddings - positive_instance_embedding)**2, axis=-1)
        positive_instance_nearest_indices = np.argsort(positive_instance_cosine_distances)[:max_neighbors*1000]
        positive_instance_nearest_embeddings = []

        for neighbor_index in positive_instance_nearest_indices:
            neighbor_path = relevant_embeddings_paths[neighbor_index]
            # we want the full images here!
            if "_full" in neighbor_path:
                positive_instance_nearest_embeddings.append(np.load(neighbor_path))
                if len(positive_instance_nearest_embeddings) == max_neighbors:
                    break
        # NOTE(review): this array is not read again in this section - confirm it is still needed
        positive_instance_nearest_embeddings = np.array(positive_instance_nearest_embeddings)

        nn_full_image_embeddings = []
        nn_superpixel_embeddings = []

        # we now grab, for every nearest crop, its own embedding and the
        # embedding of the full image it was cut from
        for index in positive_nearest_indices:  # if scikit-learn, iterate over [0] here
            superpixel_path = relevant_embeddings_paths[index]
            # the first two "_"-separated chunks are joined back into the prefix
            # shared with the full-image embedding file of the same image
            chunks = superpixel_path.split("_")
            nn_superpixel_embedding = np.load(superpixel_path)
            nn_full_image_embedding = np.load(chunks[0] + "_" + chunks[1] + full_embedding_suffix)

            # add to list of nearest neighbor full image embeddings
            nn_full_image_embeddings.append(nn_full_image_embedding)
            # add superpixel embeddings as well
            nn_superpixel_embeddings.append(nn_superpixel_embedding)

        # compute the corresponding full image embeddings for the update:
        # one array per requested neighborhood size, each a prefix of the ranking
        positive_nn_embeddings_for_update = [
            np.array(nn_full_image_embeddings[:n_neighbors])
            for n_neighbors in N_NEAREST_NEIGHBORS
        ]
            
        ####################################################################################################################################

        # Now, we update the model and re-train based on the user feedback:

        # get gold label for instance: 1 when the image ID is listed in
        # filtered_img_ids, 0 otherwise
        negative_label = 1 if negative_instance_index in filtered_img_ids else 0
        positive_label = 1 if positive_instance_index in filtered_img_ids else 0

        # reload image and full-image embedding of the negative instance
        negative_img = get_image(negative_instance_index)
        negative_instance_embedding = get_embedding(negative_instance_index)

        # now, we make sure the label and feedback are correct based on FP / FN designation
        assert negative_label == 0
        assert positive_label == 1
        assert negative_feedback == 0
        assert positive_feedback == 1

        # one entry per (neighborhood size, neighbor weight) pair, sizes outermost
        new_val_scores_neighbors = []
        new_test_scores_neighbors = []

        # sets the number of nearest neighbors to consider
        for i, n_neighbors in enumerate(N_NEAREST_NEIGHBORS):
            # set weight of neighbors relative to other examples in sample
            for NEIGHBORS_WEIGHT in NEIGHBORS_WEIGHTS:

                ### NOW, WE HANDLE NEAREST NEIGHBORS ###

                # labels for the added neighbors: the first half carries the
                # negative feedback, the second half the positive feedback
                clf_additional_labels = np.concatenate((
                    np.full(n_neighbors, negative_feedback, dtype=float),
                    np.full(n_neighbors, positive_feedback, dtype=float),
                ))

                # instantiate linear model and train
                # NOTE(review): this rebinds `clf`, the same name whose
                # predict_proba was handed to the explainer earlier
                clf = LogisticRegression(random_state=0)

                neighbor_embeddings = np.concatenate(
                    (negative_nn_embeddings_for_update[i], positive_nn_embeddings_for_update[i]), axis=0)
                augmented_X = np.concatenate((train_X, neighbor_embeddings), axis=0)
                augmented_y = np.concatenate((train_y, clf_additional_labels), axis=0)
                # every neighbor gets the weight NEIGHBORS_WEIGHT divided by the neighborhood size
                neighbor_weights = np.full(n_neighbors*2, NEIGHBORS_WEIGHT/float(n_neighbors), dtype=float)
                augmented_weights = np.concatenate((train_sample_weight, neighbor_weights), axis=0)

                clf.fit(augmented_X, augmented_y, sample_weight=augmented_weights)

                # evaluate linear model's performance
                new_val_score = clf.score(val_X, val_y)
                new_val_scores_neighbors.append(new_val_score)
                new_test_score = clf.score(test_X, test_y)
                new_test_scores_neighbors.append(new_test_score)


        ####################################################################################################################################

        # Next, we evaluate the change in performance with adding the example to the training data for comparison:

        # instantiate linear model and train
        clf = LogisticRegression(random_state=0)

        # the two explained instances join the training data with their gold
        # labels and a sample weight of 1 each
        added_embeddings = np.array([get_embedding(negative_instance_index), get_embedding(positive_instance_index)])
        added_labels = np.array([negative_label, positive_label])
        added_weights = np.array([1., 1.])

        clf.fit(
            np.concatenate((train_X, added_embeddings), axis=0),
            np.concatenate((train_y, added_labels)),
            sample_weight=np.concatenate((train_sample_weight, added_weights)),
        )

        # evaluate linear model's performance
        new_val_score_added_instance = clf.score(val_X, val_y)
        new_test_score_added_instance = clf.score(test_X, test_y)

        ####################################################################################################################################

        # Lastly, we save the experiment output:

        # everything recorded about this run, gathered in one dict that is
        # written out as JSON afterwards
        metadata = {
            # --- configuration ---
            "RESNET_50_BOOL": RESNET_50_BOOL,  # ResNet-50 or ResNet-18?
            "TRAIN_SIZE": TRAIN_SIZE,  # size of training data
            "N_NEAREST_NEIGHBORS": N_NEAREST_NEIGHBORS,  # neighborhood sizes tried
            "NEIGHBORS_WEIGHTS": NEIGHBORS_WEIGHTS,  # neighbor weights tried (updates of various strengths)
            "category_name": category_name,  # category name (i.e., "person")
            # --- data used ---
            "positive_train_ids": positive_train_ids,  # IDs of + examples in training data
            "negative_train_ids": negative_train_ids,  # IDs of - examples in training data
            "negative_instance_index": negative_instance_index,  # ID of the negative instance that was explained
            "positive_instance_index": positive_instance_index,  # ID of the positive instance that was explained
            # --- explanations and actions ---
            "negative_explanation_segment_id": negative_explanation_segment_id,  # superpixel ID in the negative explanation
            "positive_explanation_segment_id": positive_explanation_segment_id,  # superpixel ID in the positive explanation
            "negative_action_segment_id": negative_action_segment_ids,  # IDs of the superpixels acted on (negative instance)
            "positive_action_segment_id": positive_action_segment_ids,  # IDs of the superpixels acted on (positive instance)
            # nearest-neighbor positions, indexed into relevant_embeddings_paths
            # (slice [0] here if using scikit-learn nearest neighbors)
            "negative_nearest_indices": list(negative_nearest_indices),
            "positive_nearest_indices": list(positive_nearest_indices),
            # --- validation-split scores ---
            "original_val_score": original_val_score,  # validation score recorded before the update
            "new_val_score_added_instance": new_val_score_added_instance,  # after appending the two explained instances to the training data
            "new_val_scores_neighbors": new_val_scores_neighbors,  # after each nearest-neighbors update
            # --- test-split scores ---
            "original_test_score": original_test_score,  # test score recorded before the update
            "new_test_score_added_instance": new_test_score_added_instance,  # after appending the two explained instances to the training data
            "new_test_scores_neighbors": new_test_scores_neighbors,  # after each nearest-neighbors update
        }

        # https://stackoverflow.com/questions/11942364/typeerror-integer-is-not-json-serializable-when-serializing-json-in-python
        # need to handle numpy.int64 conversion to make dict JSON serializable
        def convert(o):
            """Turn NumPy values into JSON-serializable Python objects.

            Meant to be passed as `default=` to json.dump / json.dumps, which call
            it for every object they cannot encode themselves.

            Returns an int for any NumPy integer, a float for any NumPy floating
            value, a bool for np.bool_, and a (nested) list for an ndarray.

            Raises TypeError for every other type, so that an unsupported value
            is reported instead of being silently written as null.
            """
            if isinstance(o, np.bool_):
                return bool(o)
            if isinstance(o, np.integer):
                return int(o)
            if isinstance(o, np.floating):
                return float(o)
            if isinstance(o, np.ndarray):
                return o.tolist()
            raise TypeError("Object of type " + type(o).__name__ + " is not JSON serializable")

        # write to outfile with timestamp (bijection, so actual experiment time can be recovered easily)
        # the file name is the current time in milliseconds
        timestamp = round(time.time() * 1000)
        print("Saving to: " + str(timestamp) + ".json...")

        # create the per-category directory together with any missing parent;
        # exist_ok avoids a failure when the directory is already there (for
        # example because another process created it in the meantime)
        os.makedirs('results/' + str(category_name), exist_ok=True)

        # serialize before opening the file, so a value that cannot be encoded
        # does not leave a truncated JSON file behind
        serialized_metadata = json.dumps(metadata, default=convert)
        with open('results/' + str(category_name) + '/' + str(timestamp) + '.json', 'w') as f:
            f.write(serialized_metadata)

        ####################################################################################################################################
