# Applying Sufficient Input Subsets to Trained Classifiers

**This notebook applies Sufficient Input Subsets (SIS) to identify which subsets of features from the ATAC-seq, RNA-seq, and bimodal inputs a trained classifier needs in order to keep making the same prediction. These subsets are useful for two reasons:**
- 1. For the bimodal input, they indicate which ATAC-seq and RNA-seq features the classifier relies on most. In theory the two modalities carry overlapping information, and RNA-seq may capture everything ATAC-seq does and more. If so, we expect most features in these subsets to come from RNA-seq, which would suggest that multimodal data is not required for classification. However, ATAC-seq may contain features that RNA-seq does not capture.
- 2. These subsets may suggest marker genes for a given cell type, which would be useful for the Human Cell Atlas and any other project aiming to identify and describe cell types.
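
The expectation in point 1 can be checked once SIS masks are available. The following is a minimal sketch, assuming the bimodal input is a 1-D concatenation with the RNA-seq features first and the ATAC-seq features after them, and that a mask entry of `True` marks a feature kept in the SIS. The function name `modality_fractions`, the parameter `n_rna_features`, and the example mask are hypothetical and only for illustration.

```python
import numpy as np

def modality_fractions(sis_mask, n_rna_features):
    """Return the fraction of SIS features that come from each modality.

    Assumes (hypothetically) that the first `n_rna_features` positions are
    RNA-seq features and the remaining positions are ATAC-seq features.
    """
    sis_mask = np.asarray(sis_mask, dtype=bool)
    n_selected = int(sis_mask.sum())
    if n_selected == 0:
        # An empty SIS has no features to attribute to either modality.
        return {"rna": 0.0, "atac": 0.0}
    n_rna = int(sis_mask[:n_rna_features].sum())
    return {
        "rna": n_rna / n_selected,
        "atac": (n_selected - n_rna) / n_selected,
    }

# Made-up mask over 5 RNA-seq features followed by 3 ATAC-seq features.
example_mask = np.array([True, True, False, True, False, False, True, False])
fractions = modality_fractions(example_mask, n_rna_features=5)
print(fractions)
```

Averaging these fractions over many cells would give a rough picture of how much the bimodal classifier depends on each modality.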

**References:**
- SIS paper: https://arxiv.org/pdf/1810.03805.pdf
- GitHub repository with more details: https://github.com/b-carter/SufficientInputSubsets

***
### SIS applied to a CNN trained on MNIST (to be modified for our project)
- Source: https://github.com/google-research/google-research/blob/master/sufficient_input_subsets/tutorials/sis_mnist_tutorial.ipynb

In [None]:
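#@title Select a digit and threshold {run: "auto"}

# Imports needed by this cell. `model`, `x_train`, and `x_test` are assumed to
# be defined by the earlier cells of the MNIST tutorial linked above, and the
# `sis` import path is assumed to match the one used in that tutorial.
import matplotlib.pyplot as plt
import numpy as np
from sufficient_input_subsets import sis

DIGIT = 4  #@param ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] {type:"raw"}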

THRESHOLD = 0.7  #@param {type:"slider", min:0, max:1, step:0.1}

# Following the SIS paper, we use the mean pixel from training images as a mask.
FULLY_MASKED_IMAGE = np.full((28, 28, 1), np.mean(x_train))

# Helper function that selects the probability for a single class, from the
# softmax output.
def make_f_for_digit(digit, model):
    def f_digit(batch_of_inputs):
        return model.predict(
            batch_of_inputs,
            batch_size=min(784, len(batch_of_inputs)))[:, digit]
    return f_digit

# This function maps a batch of images to an array of probabilities (the
# probability of each image being DIGIT).
f_digit = make_f_for_digit(DIGIT, model)

# Helper function that filters input images to those the model predicts with
# high confidence (f(image) >= threshold).
def select_images_for_sis(inputs, f_digit, threshold):
    preds = f_digit(inputs)
    idxs = np.nonzero(preds >= threshold)[0]
    return inputs[idxs]

# Filter test images that the model classifies as DIGIT with high confidence.
high_confidence_images_for_digit = select_images_for_sis(x_test, f_digit,
                                                         THRESHOLD)

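# Randomly select up to 5 of these digits to run SIS (fewer if the filter
# returned fewer than 5 images, since replace=False cannot oversample).
num_to_sample = min(5, high_confidence_images_for_digit.shape[0])
digits_to_run_sis = high_confidence_images_for_digit[
    np.random.choice(high_confidence_images_for_digit.shape[0],
                     size=num_to_sample,
                     replace=False)]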

# Helpers for plotting an MNIST digit and its corresponding SIS-collection.
def plot_mnist_digit(ax, image):
    ax.imshow(image[:, :, 0], cmap=plt.get_cmap('gray'))
    ax.axis('off')

def plot_sis_collection(initial_image, collection, fully_masked_image):
    # Grid contains initial image, an empty cell (for spacing), and collection.
    width = len(collection) + 2
    plt.figure(figsize=(width, 1))
    gs = plt.GridSpec(1, width, wspace=0.1)

    # Plot initial image.
    ax = plt.subplot(gs[0])
    plot_mnist_digit(ax, initial_image)

    # Plot each SIS.
    for i, sis_result in enumerate(collection):
        ax = plt.subplot(gs[i+2])
        masked_image = sis.produce_masked_inputs(
            initial_image, fully_masked_image, [sis_result.mask])[0]
        plot_mnist_digit(ax, masked_image)

    plt.show()

print('Running SIS on {} examples of digit {}. '
      'This might take a couple minutes.'.format(len(digits_to_run_sis), DIGIT))

# Run SIS on each digit and visualize the resulting SIS-collections.
for initial_image in digits_to_run_sis:
    collection = sis.sis_collection(f_digit, THRESHOLD, initial_image,
                                    FULLY_MASKED_IMAGE)
    plot_sis_collection(initial_image, collection, FULLY_MASKED_IMAGE)

### SIS for our cell-state classifier
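
As a starting point for adapting the MNIST example, here is a minimal sketch of the backward-selection idea from the SIS paper applied to a single 1-D feature vector, assuming each cell is represented as such a vector. This is a simplified NumPy re-implementation for illustration, not the `sis` library code, and it finds only one SIS rather than a full SIS-collection. The names `toy_classifier`, `backselect`, and `find_sis`, the weights, and the all-zeros mask are hypothetical. In practice the classifier would be the trained model's probability for one cell state, and the mask could be the per-feature training mean, mirroring the mean-pixel mask used for MNIST.

```python
import numpy as np

def toy_classifier(batch):
    """Hypothetical stand-in for f: maps a batch of feature vectors to P(class)."""
    weights = np.array([3.0, 0.1, 2.0, 0.0])
    logits = batch @ weights - 2.0
    return 1.0 / (1.0 + np.exp(-logits))

def backselect(f, x, mask_values):
    """Order features by repeatedly masking the one whose removal keeps f highest."""
    remaining = list(range(x.shape[0]))
    current = x.astype(float)
    removal_order = []
    while remaining:
        # One candidate per remaining feature, each with that feature masked.
        candidates = np.repeat(current[None, :], len(remaining), axis=0)
        candidates[np.arange(len(remaining)), remaining] = mask_values[remaining]
        best = int(np.argmax(f(candidates)))
        feature = remaining.pop(best)
        current[feature] = mask_values[feature]
        removal_order.append(feature)
    return removal_order

def find_sis(f, x, mask_values, threshold):
    """Return sorted indices of one sufficient input subset, or None."""
    if f(x[None, :].astype(float))[0] < threshold:
        return None  # The full input is not confidently classified.
    order = backselect(f, x, mask_values)
    masked = mask_values.astype(float)
    kept = []
    # Restore features in reverse removal order until f reaches the threshold.
    for feature in reversed(order):
        masked[feature] = x[feature]
        kept.append(feature)
        if f(masked[None, :])[0] >= threshold:
            return sorted(kept)
    return None

cell = np.array([1.0, 1.0, 1.0, 1.0])   # made-up feature vector for one cell
mask_values = np.zeros(4)               # made-up mask (e.g. training means)
sis_features = find_sis(toy_classifier, cell, mask_values, threshold=0.9)
print(sis_features)
```

The returned indices could then be mapped back to gene or peak names to inspect which features are sufficient for the prediction. For the real classifier, the `sis.sis_collection` call shown in the MNIST cell is likely the better choice, since it batches model evaluations and returns the full collection.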

In [3]:
######### Include code below ##########