In [None]:
import cv2

def remove_text(img):
    '''
    Best-effort removal of very bright text burned into an X-ray image (e.g. the white 'R' that marks the right
    side of the body). Pixels above the brightness threshold are treated as text and filled in by inpainting.
    :param img: Numpy array of image; indexed as [row, column, channel]
    :return: float32 array of the image with (ideally) any bright characters inpainted
    '''
    # Binary threshold: values above 230 become 255, everything else 0
    _, thresholded = cv2.threshold(img, 230, 255, cv2.THRESH_BINARY)

    # Only the first channel of the thresholded image is used as the inpainting mask
    text_mask = thresholded[:, :, 0].astype(np.uint8)

    # Inpaint on an 8-bit copy of the image, using the Navier-Stokes method with a radius of 10
    img_uint8 = img.astype(np.uint8)
    inpainted = cv2.inpaint(img_uint8, text_mask, 10, cv2.INPAINT_NS)
    return inpainted.astype(np.float32)

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib as mpl
import datetime
import io
import os
import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve
from skimage.segmentation import mark_boundaries

def visualize_explanation(orig_img, explanation, img_filename, label, probs, class_names, label_to_see='top', dir_path=None):
    '''
    Visualize an explanation for the prediction of a single X-ray image.
    :param orig_img: Original X-Ray image
    :param explanation: ImageExplanation object. Its `image` attribute is overwritten with orig_img.
    :param img_filename: Filename of the image explained
    :param label: Ground truth class of the example (used to index class_names), or None if unknown
    :param probs: Prediction probabilities
    :param class_names: Ordered list of class names
    :param label_to_see: Label to visualize in explanation, or 'top' to use the explanation's first top label
    :param dir_path: Path to directory where to save the generated image (created if missing).
                     If None, nothing is saved and the figure is left open.
    :return: Path to saved image, or None if dir_path is None
    '''

    # Plot original image on the left
    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(orig_img)

    # Plot the image and its explanation on the right
    if label_to_see == 'top':
        label_to_see = explanation.top_labels[0]
    explanation.image = orig_img    # Draw the explanation over the original image
    temp, mask = explanation.get_image_and_mask(label_to_see, positive_only=False, num_features=10,
                                                hide_rest=False)
    ax[1].imshow(mark_boundaries(temp, mask))

    # Display some information about the example
    pred_class = np.argmax(probs)
    fig.text(0.02, 0.8, "Prediction probabilities: " + str(['{:.2f}'.format(prob) for prob in probs]),
             fontsize=10)
    fig.text(0.02, 0.82, "Predicted Class: " + str(pred_class) + ' (' + class_names[pred_class] + ')', fontsize=10)
    if label is not None:
        fig.text(0.02, 0.84, "Ground Truth Class: " + str(label) + ' (' + class_names[label] + ')', fontsize=10)
    fig.suptitle("LIME Explanation for image " + img_filename, fontsize=13)
    fig.tight_layout()

    # Save the image
    filename = None
    if dir_path is not None:
        # Imported locally under a private alias: a later cell rebinds the notebook-global name `datetime` to the
        # datetime class (`from datetime import datetime`), which would break `datetime.datetime.now()` here.
        from datetime import datetime as _datetime

        # exist_ok avoids the check-then-create race of a separate os.path.exists() test
        os.makedirs(dir_path, exist_ok=True)
        timestamp = _datetime.now().strftime("%Y%m%d-%H%M%S")

        # os.path.join keeps the file inside dir_path whether or not dir_path ends with a separator
        filename = os.path.join(dir_path, img_filename.split('/')[-1] + '_exp_' + timestamp + '.png')
        fig.savefig(filename)

        # pyplot keeps every figure alive until it is closed; release it once it is on disk so that
        # batch runs do not accumulate one open figure per image
        plt.close(fig)
    return filename

In [None]:
import pandas as pd
import yaml
import os
import dill
import cv2
import numpy as np
from tqdm import tqdm
from datetime import datetime
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from lime.wrappers.scikit_image import SegmentationAlgorithm
#from src.data.preprocess import remove_text
from src.visualization.visualize import visualize_explanation


def predict_instance(x, model):
    '''
    Get class probabilities from the model for one or more input images.
    :param x: Image(s) to predict
    :param model: A Keras model
    :return: A numpy array comprising a list of class probabilities for each prediction
    '''
    outputs = model.predict(x)

    # More than one output column: the model output is returned as-is
    if outputs.shape[1] != 1:
        return outputs

    # Single output column p: expand each row to the two-class form [1 - p, p]
    return np.concatenate([1.0 - outputs, outputs], axis=1)


def predict_and_explain(x, model, exp, num_features, num_samples):
    '''
    Predict a single example with the model and generate a LIME explanation for it.
    :param x: Preprocessed image to predict
    :param model: The trained neural network model
    :param exp: A LimeImageExplainer object
    :param num_features: # of features to use in explanation
    :param num_samples: # of times to perturb the example to be explained
    :return: Tuple of (explanation produced by exp.explain_instance, class probabilities predicted for x)
    '''

    # Quickshift superpixel segmentation. Parameters set to limit size of superpixels and promote border smoothness
    superpixel_fn = SegmentationAlgorithm('quickshift', kernel_size=2.25, max_dist=50, ratio=0.1, sigma=0.15)

    def classifier_fn(perturbed_batch):
        '''
        Prediction callback handed to the LIME explainer.
        :param perturbed_batch: List of perturbed examples from an example
        :return: A numpy array constituting a list of class probabilities for each predicted perturbation
        '''
        return predict_instance(perturbed_batch, model)

    # Explain the example first, then predict the unperturbed example as a batch of one
    lime_explanation = exp.explain_instance(x, classifier_fn, num_features=num_features,
                                            num_samples=num_samples, segmentation_fn=superpixel_fn)
    instance_probs = predict_instance(np.expand_dims(x, axis=0), model)
    return lime_explanation, instance_probs


def predict_and_explain_set(raw_img_dir=None, preds_dir=None, save_results=True, give_explanations=True):
    '''
    Preprocess a raw dataset. Then get model predictions and corresponding explanations.
    :param raw_img_dir: Directory in which to look for raw images. Defaults to cfg['PATHS']['BATCH_PRED_IMGS'].
    :param preds_dir: Directory in which to save results of this prediction. Defaults to a new timestamped
                      subdirectory of cfg['PATHS']['BATCH_PREDS'].
    :param save_results: Flag specifying whether to save the prediction results to disk
    :param give_explanations: Flag specifying whether to provide LIME explanations with predictions spreadsheet
    :return: Dataframe of prediction results, optionally including explanations.
    '''

    # Load project config data (context managers make sure the file handles are closed)
    with open(os.path.join(os.getcwd(), "config.yml"), 'r') as cfg_file:
        cfg = yaml.full_load(cfg_file)
    cur_date = datetime.now().strftime('%Y%m%d-%H%M%S')

    # Restore the model, LIME explainer, and model class indices from their respective serializations.
    # NOTE(security): dill.load can execute arbitrary code; only point these config paths at trusted files.
    model = load_model(cfg['PATHS']['MODEL_TO_LOAD'], compile=False)
    with open(cfg['PATHS']['LIME_EXPLAINER'], 'rb') as explainer_file:
        explainer = dill.load(explainer_file)
    with open(cfg['PATHS']['OUTPUT_CLASS_INDICES'], 'rb') as class_indices_file:
        class_indices = dill.load(class_indices_file)

    # Load LIME and prediction constants from config
    NUM_SAMPLES = cfg['LIME']['NUM_SAMPLES']
    NUM_FEATURES = cfg['LIME']['NUM_FEATURES']
    CLASS_NAMES = cfg['DATA']['CLASSES']

    # Define column names of the DataFrame representing the prediction results
    col_names = ['Image Filename', 'Predicted Class'] + ['p(' + c + ')' for c in CLASS_NAMES]

    # Add columns for client explanation
    if give_explanations:
        col_names.append('Explanation Filename')

    # Set raw image directory based on project config, if not specified
    if raw_img_dir is None:
        raw_img_dir = cfg['PATHS']['BATCH_PRED_IMGS']

    # If no path is specified, use a new timestamped directory for predictions
    if preds_dir is None:
        preds_dir = os.path.join(cfg['PATHS']['BATCH_PREDS'], cur_date)
    # Guarantee a trailing separator so that consumers that concatenate onto preds_dir stay inside the directory
    preds_dir = os.path.join(preds_dir, '')
    if save_results:
        # makedirs also creates missing parents and covers a caller-supplied directory that does not exist yet
        os.makedirs(preds_dir, exist_ok=True)

    # Create DataFrame for raw image file names
    raw_img_df = pd.DataFrame({'filename': os.listdir(raw_img_dir)})
    # Enforce image files: match on the file extension only (case-insensitive), not anywhere in the name
    raw_img_df = raw_img_df[raw_img_df['filename'].str.contains(r'\.(?:jpg|jpeg|png)$', case=False, na=False)]

    # Create generator for the image files
    img_gen = ImageDataGenerator(preprocessing_function=remove_text, samplewise_std_normalization=True,
                                 samplewise_center=True)
    img_iter = img_gen.flow_from_dataframe(dataframe=raw_img_df, directory=raw_img_dir, x_col="filename",
                                           target_size=cfg['DATA']['IMG_DIM'], batch_size=1, class_mode=None,
                                           shuffle=False)

    # Keras target_size is (height, width), whereas cv2.resize expects (width, height)
    img_height, img_width = cfg['DATA']['IMG_DIM'][0], cfg['DATA']['IMG_DIM'][1]

    # Predict (and optionally explain) all images in the specified directory
    rows = []
    print('Predicting and explaining examples.')

    # Iterate over the filenames the iterator actually kept: flow_from_dataframe drops entries it cannot
    # validate, so walking raw_img_df instead could pair predictions with the wrong filenames.
    for filename in img_iter.filenames:

        # Get preprocessed image and make a prediction.
        x = next(img_iter)
        y = np.squeeze(predict_instance(x, model))

        # Rearrange prediction probability vector to reflect original ordering of classes in project config.
        # class_indices maps class name -> model output index, so look up each config class's output position.
        p = [y[class_indices[c]] for c in CLASS_NAMES]
        predicted_class = CLASS_NAMES[np.argmax(p)]
        row = [filename, predicted_class]
        row.extend(p)

        # Explain this prediction
        if give_explanations:
            explanation, _ = predict_and_explain(np.squeeze(x, axis=0), model, explainer, NUM_FEATURES, NUM_SAMPLES)
            if cfg['LIME']['COVID_ONLY'] == True:
                label_to_see = class_indices['COVID-19']
            else:
                label_to_see = 'top'

            # Load and resize the corresponding original image (no preprocessing)
            orig_img = cv2.imread(os.path.join(raw_img_dir, filename))
            orig_img = cv2.resize(orig_img, (img_width, img_height), interpolation=cv2.INTER_NEAREST)

            # Generate visual for explanation
            exp_filename = visualize_explanation(orig_img, explanation, filename, None, p, CLASS_NAMES,
                                                 label_to_see=label_to_see, dir_path=preds_dir)
            row.append(os.path.basename(exp_filename))
        rows.append(row)

    # Convert results to a Pandas DataFrame and save
    results_df = pd.DataFrame(rows, columns=col_names)
    if save_results:
        results_path = os.path.join(preds_dir, 'predictions.csv')
        results_df.to_csv(results_path, columns=col_names, index_label=False, index=False)
    return results_df


In [None]:
# Run batch prediction with LIME explanations. raw_img_dir and preds_dir are left as None, so both
# locations are taken from config.yml; the predictions are saved to disk and also returned as a DataFrame.
results = predict_and_explain_set(preds_dir=None, save_results=True, give_explanations=True)