# Applying Sufficient Input Subsets to Trained Classifiers

**This notebook will apply SIS to understand which subsets of features from ATAC-seq, RNA-seq, and the bimodal input are most needed to maintain similar classification results. These subsets are desirable for the following reasons:**
- 1. For the bimodal input, they suggest which of the features from ATAC-seq and RNA-seq are most needed. Theoretically, these two modalities contain the same information, and RNA-seq may even contain more information than ATAC-seq. We therefore expect the majority of the features in these subsets to come from RNA-seq, which would preclude the need for multimodal data in classification. However, ATAC-seq may contain features that RNA-seq doesn't capture.
- 2. These subsets may suggest marker genes for a given cell type, which would be useful for the Human Cell Atlas and any other project aiming to identify and describe cell types.

**References:**
- SIS paper: https://arxiv.org/pdf/1810.03805.pdf
- GITHUB with more details: https://github.com/b-carter/SufficientInputSubsets

***
### SIS applied to a CNN trained on MNIST: (modify this for our project)
- Found this here: https://github.com/google-research/google-research/blob/master/sufficient_input_subsets/tutorials/sis_mnist_tutorial.ipynb

In [2]:
# Install sufficient input subsets from Google Research

# The following lines clone the repository that contains sufficient_input_subsets
# and change into it (comment out the clone once the directory already exists)
!git clone https://github.com/google-research/google-research.git
%cd google-research

fatal: destination path 'google-research' already exists and is not an empty directory.
/Users/tjamesso/Desktop/MIT Courses/6_874_DL/Project/6_874-Multimodal-DL/google-research


In [3]:
from __future__ import print_function

import tensorflow as tf
import tensorflow.keras as K
import matplotlib.pyplot as plt
import numpy as np
import os
import time

from sufficient_input_subsets import sis
from helpers import *

preprocess imported
module name : helpers module package: 


### SIS for our cell-state classifier

In [4]:
######### Include code below ##########

In [7]:
# Load training and test data from pkl files

# get file paths
# `root` is the parent of the current working directory. The setup cell `%cd`s
# into `google-research`, so this is expected to resolve to the project root.
root = os.path.split(os.getcwd())[0]
pkl_path = os.path.join(root, 'data', 'sci-CAR', 'pkl_files')
# NOTE(review): the ATAC and RNA file names end in an underscore before `.pkl`
# while the bimodal ones do not -- confirm these match the files on disk.
pkl_atac = os.path.join(pkl_path, 'atac_pp_fs_.pkl')
pkl_rna = os.path.join(pkl_path, 'rna_pp_fs_.pkl')
pkl_bimodal = os.path.join(pkl_path, 'bimodal_pp_fs.pkl')
pkl_bi_low = os.path.join(pkl_path, 'bi_low_pp_fs.pkl')

# read pickle files
# Each call is unpacked into a (train, test) pair. `read_pickle` is not defined
# in this notebook; presumably it is provided by `from helpers import *`.
# NOTE(review): if it wraps `pickle.load`, only load files from a trusted source.
atac_train, atac_test = read_pickle(pkl_atac)
rna_train, rna_test = read_pickle(pkl_rna)
bimodal_train, bimodal_test = read_pickle(pkl_bimodal)
bi_low_train, bi_low_test = read_pickle(pkl_bi_low)

print('pickle files read')

# convert tensors to numpy arrays
# The `_np` copies are what the mask computation and SIS code below operate on.
atac_train_np, atac_test_np = atac_train.numpy(), atac_test.numpy()
rna_train_np, rna_test_np = rna_train.numpy(), rna_test.numpy()
bimodal_train_np, bimodal_test_np = bimodal_train.numpy(), bimodal_test.numpy()
bi_low_train_np, bi_low_test_np = bi_low_train.numpy(), bi_low_test.numpy()

pickle files read


In [8]:
# Display the shapes of the ATAC, RNA and bimodal training tensors
# (axis 1 is used as the feature axis when the masks are built).
atac_train.shape, rna_train.shape, bimodal_train.shape

(TensorShape([2376, 52761]),
 TensorShape([2376, 1185]),
 TensorShape([2376, 53946]))

In [9]:
# Get the masks (mean values of each training set)

# get the num of features in each matrix
n_atac_features = atac_train.shape[1]
n_rna_features = rna_train.shape[1]
n_bimodal_features = bimodal_train.shape[1]
n_bi_low_features = bi_low_train.shape[1]


# Following the SIS paper, we use the mean training value as a mask. np.mean is
# called without an axis, so each mask is one scalar (the mean over the whole
# training matrix) repeated for every feature.
# Each mask has shape (n_features, 1); RNA_MASK is squeezed to 1-D where it is
# passed to sis.sis_collection.
ATAC_MASK = np.full((n_atac_features, 1), np.mean(atac_train_np))
RNA_MASK = np.full((n_rna_features, 1), np.mean(rna_train_np))
BIMODAL_MASK = np.full((n_bimodal_features, 1), np.mean(bimodal_train_np))
BI_LOW_MASK = np.full((n_bi_low_features, 1), np.mean(bi_low_train_np))



In [10]:
# Project root: parent of the current working directory (the same expression is
# used in the data-loading cell).
root = os.path.split(os.getcwd())[0]

In [11]:
# Load the models

# get directories
# One saved-model directory per input type, all under <root>/models.
atac_model_dir = os.path.join(root, 'models', 'best_atac_model_pp_fs')
rna_model_dir = os.path.join(root, 'models', 'best_rna_model_pp_fs')
bimodal_model_dir = os.path.join(root, 'models', 'best_bimodal_model_pp_fs')
bi_low_model_dir = os.path.join(root, 'models', 'best_bi_low_model_pp_fs')


# load models
# K is tensorflow.keras; each directory is passed straight to load_model.
atac_model  = K.models.load_model(atac_model_dir)
rna_model  = K.models.load_model(rna_model_dir)
bimodal_model = K.models.load_model(bimodal_model_dir)
bi_low_model = K.models.load_model(bi_low_model_dir)

print('models loaded')

models loaded


In [12]:
# Set parameters
# HOURS is the class index whose predicted probability SIS will explain; it is
# passed to make_f_for_digit below.
HOURS = 0 # choose from [0, 1, 2] which corresponds to [0hr, 1hr, 3hr]
# THRESHOLD is both the confidence cut-off used to pick cells and the threshold
# handed to sis.sis_collection.
THRESHOLD = 0.7  #@param {type:"slider", min:0, max:1, step:0.1}

In [13]:
# Helper function that selects the probability for a single class, from the
# softmax output.
def make_f_for_digit(digit, model, max_batch_size=784):
    """Build a scoring function that returns one class's predicted value.

    The name is inherited from the MNIST tutorial; `digit` is simply the index
    of the output column of interest.

    Args:
        digit: Column index into the model's prediction output.
        model: Object exposing `predict(inputs, batch_size=...)` whose result
            can be indexed as `[:, digit]`.
        max_batch_size: Upper bound on the batch size passed to
            `model.predict`. Defaults to 784, the value hard-coded in the
            original tutorial, so existing two-argument calls are unchanged.

    Returns:
        A function mapping a batch of inputs to `model.predict(...)[:, digit]`.

    Raises:
        ValueError: If `max_batch_size` is smaller than 1.
    """
    if max_batch_size < 1:
        raise ValueError('max_batch_size must be at least 1')

    def f_digit(batch_of_inputs):
        # Never request a batch larger than the number of inputs supplied.
        batch_size = min(max_batch_size, len(batch_of_inputs))
        return model.predict(batch_of_inputs, batch_size=batch_size)[:, digit]
    return f_digit

# Scoring function used by SIS: maps a batch of RNA inputs to the RNA model's
# predicted probability for the class index HOURS.
f_digit = make_f_for_digit(HOURS, rna_model)

In [14]:
# Helper that keeps only the inputs the model scores with high confidence
# (f(input) >= threshold), together with those scores.
def select_images_for_sis(inputs, f_digit, threshold):
    """Select the inputs whose score reaches `threshold`.

    Args:
        inputs: Array of inputs, indexable along axis 0.
        f_digit: Function mapping `inputs` to one score per input.
        threshold: Minimum score an input needs in order to be kept.

    Returns:
        Tuple `(selected_inputs, selected_scores)`.
    """
    scores = f_digit(inputs)
    keep = np.where(scores >= threshold)[0]
    selected_inputs = inputs[keep]
    selected_scores = scores[keep]
    return selected_inputs, selected_scores

# Keep the RNA test cells whose predicted score for class HOURS is >= THRESHOLD.
# The result is a tuple: (selected inputs, their scores).
high_confidence_cells = select_images_for_sis(rna_test_np, f_digit,
                                                         THRESHOLD)

In [15]:
# take a look at the samples and predictions classified with high confidence

# Unpack the (inputs, scores) tuple returned by select_images_for_sis and
# report how many cells passed the threshold.
hc_samples, hc_preds = high_confidence_cells
print(f'Found {hc_preds.shape[0]} cells >= {THRESHOLD} confidence')

Found 57 cells >= 0.7 confidence


In [16]:
# Randomly select some of the high-confidence cells to run SIS on.
# The sample size is capped at the number of available cells:
# np.random.choice(..., replace=False) raises ValueError when asked for more
# items than exist, which would happen for a class/threshold with < 5 hits.
# NOTE: no random seed is set, so the selected cells differ between runs.
n_cells_to_sample = min(5, hc_samples.shape[0])
cells_to_run_sis = hc_samples[
    np.random.choice(hc_samples.shape[0],
                     size=n_cells_to_sample,
                     replace=False)]

In [17]:
# Sanity check of the RNA mask against the RNA data.
# NOTE: only the last expression of a cell is displayed, so the shape tuple on
# the next line is evaluated but never shown; only RNA_MASK itself is printed.
RNA_MASK.shape, rna_test_np.shape, rna_train_np.shape
RNA_MASK

array([[0.46075094],
       [0.46075094],
       [0.46075094],
       ...,
       [0.46075094],
       [0.46075094],
       [0.46075094]], dtype=float32)

In [18]:
# Run SIS on each selected cell and gather the resulting SIS-collections,
# timing every cell as well as the whole run.
# NOTE(review): `tqdm` is not imported explicitly in this notebook; presumably
# it arrives via `from helpers import *` -- confirm.
begin = time.time()
collections = []
for initial_cell in tqdm(cells_to_run_sis):
    begin_loop = time.time()
    # RNA_MASK was built with shape (n_features, 1); np.squeeze drops the
    # trailing axis (presumably so the mask matches a single cell's shape).
    collection = sis.sis_collection(f_digit, THRESHOLD, initial_cell,
                                    np.squeeze(RNA_MASK))
    collections.append(collection)
    end_loop = time.time()
    print(f'Loop completed in {end_loop - begin_loop}')
end = time.time()
# The MNIST tutorial plots each collection at this point; that image-based
# plotting helper is not used for cell data.
print(f'Cell complete in {end-begin} seconds')

  0%|          | 0/5 [00:00<?, ?it/s]



 80%|████████  | 4/5 [05:10<01:22, 82.16s/it]

Loop completed in 64.65955710411072


100%|██████████| 5/5 [05:44<00:00, 68.93s/it]

Loop completed in 33.886642932891846
Cell complete in 344.6346788406372 seconds





### SIS Output Analysis
Each SIS output comes as an iterable (an `SISResult`) of length 4 with the following entries:
0. `sis`: Sufficient input subset, ordered with the most important features first
1. `ordering_over_entire_backselect`: The order of features removed by back selection -- the most important features are at the end
2. `values_over_entire_backselect`: The resulting probabilities if you remove up to that feature
3. `mask`: Mask (probably not important for this analysis)

[SISResult(sis=array([[522],
        [ 94],
        [ 62],
        [455],
        [734],
        [ 30]]), ordering_over_entire_backselect=array([[ 861],
        [ 896],
        [1104],
        ...,
        [  62],
        [  94],
        [ 522]]), values_over_entire_backselect=array([0.97537369, 0.98489261, 0.99063373, ..., 0.13124144, 0.05699748,
        0.02130504]), mask=array([False, False, False, ..., False, False, False]))]

In [70]:
# Analyze collections

# Get all subsets into a nice dataframe


***
## Reference code from Google Research -- running SIS on MNIST

In [None]:
#@title Select a digit and threshold {run: "auto"}

# Reference (MNIST) configuration. Running this cell re-assigns THRESHOLD, which
# the cell-state analysis above also uses.
DIGIT = 4  #@param ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] {type:"raw"}

THRESHOLD = 0.7  #@param {type:"slider", min:0, max:1, step:0.1}

# Following the SIS paper, we use the mean pixel from training images as a mask.
# NOTE(review): `x_train` is not defined in the visible part of this notebook;
# it presumably comes from the MNIST tutorial's data-loading step.
FULLY_MASKED_IMAGE = np.full((28, 28, 1), np.mean(x_train))

# Reference helper: builds a function returning the softmax probability of a
# single class.
def make_f_for_digit(digit, model):
    """Return a function that scores a batch of inputs for class `digit`.

    Args:
        digit: Column index into the model's prediction output.
        model: Object exposing `predict(inputs, batch_size=...)`.

    Returns:
        A function mapping a batch of inputs to `model.predict(...)[:, digit]`,
        predicting with a batch size of at most 784.
    """
    def f_digit(batch_of_inputs):
        n_inputs = len(batch_of_inputs)
        predictions = model.predict(batch_of_inputs,
                                    batch_size=min(784, n_inputs))
        return predictions[:, digit]
    return f_digit

# This function maps a list of images to a list of probabilities (probability of
# each image being the digit selected by DIGIT, 4 by default).
# NOTE(review): `model` is not defined in the visible part of this notebook, and
# this assignment replaces the `f_digit` built for the RNA model above.
f_digit = make_f_for_digit(DIGIT, model)

# Reference helper: keeps only the input images the model predicts with high
# confidence (f(image) >= threshold).
def select_images_for_sis(inputs, f_digit, threshold):
    """Return the inputs whose score reaches `threshold`.

    Unlike the definition of the same name earlier in this notebook, this
    reference version returns only the selected inputs, not their scores.

    Args:
        inputs: Array of inputs, indexable along axis 0.
        f_digit: Function mapping `inputs` to one score per input.
        threshold: Minimum score an input needs in order to be kept.

    Returns:
        The rows of `inputs` whose score is >= `threshold`.
    """
    scores = f_digit(inputs)
    keep = np.where(scores >= threshold)[0]
    return inputs[keep]

# Filter test images that the model classifies as DIGIT with high confidence.
# NOTE(review): `x_test` is not defined in the visible part of this notebook.
high_confidence_images_for_digit = select_images_for_sis(x_test, f_digit,
                                                         THRESHOLD)

# Randomly select five of these images (without replacement) to run SIS on.
digits_to_run_sis = high_confidence_images_for_digit[
    np.random.choice(high_confidence_images_for_digit.shape[0],
                     size=5,
                     replace=False)]

# Plotting helpers for an MNIST digit and its SIS-collection.
def plot_mnist_digit(ax, image):
    """Draw channel 0 of `image` on `ax` in grayscale with the axes hidden.

    Args:
        ax: Axes-like object providing `imshow` and `axis`.
        image: Array indexed as `image[:, :, 0]`.
    """
    first_channel = image[:, :, 0]
    grayscale = plt.get_cmap('gray')
    ax.imshow(first_channel, cmap=grayscale)
    ax.axis('off')

def plot_sis_collection(initial_image, collection, fully_masked_image):
    """Show an image next to the masked image for each SIS in its collection.

    The figure is a single row: the initial image, one empty slot used as
    spacing, then one panel per SIS result.

    Args:
        initial_image: The image SIS was run on.
        collection: Sequence of SIS results, each exposing a `mask` attribute.
        fully_masked_image: Mask values substituted outside each SIS.
    """
    n_columns = len(collection) + 2
    plt.figure(figsize=(n_columns, 1))
    grid = plt.GridSpec(1, n_columns, wspace=0.1)

    # First slot: the unmasked starting image.
    plot_mnist_digit(plt.subplot(grid[0]), initial_image)

    # Slots 2..: one masked image per SIS (slot 1 is left empty as spacing).
    for column, sis_result in enumerate(collection, start=2):
        panel = plt.subplot(grid[column])
        masked_batch = sis.produce_masked_inputs(
            initial_image, fully_masked_image, [sis_result.mask])
        plot_mnist_digit(panel, masked_batch[0])

    plt.show()

# Announce how many examples will be processed before the slow SIS loop starts.
print('Running SIS on {} examples of digit {}. '
      'This might take a couple minutes.'.format(len(digits_to_run_sis), DIGIT))

# Run SIS on each digit and visualize the resulting SIS-collections.
for initial_image in digits_to_run_sis:
    # Same threshold is used for selecting images and for the SIS search.
    collection = sis.sis_collection(f_digit, THRESHOLD, initial_image,
                                    FULLY_MASKED_IMAGE)
    plot_sis_collection(initial_image, collection, FULLY_MASKED_IMAGE)