# Applying Sufficient Input Subsets to Trained Classifiers

**This notebook will apply SIS to understand which subsets of features from ATAC-seq, RNA-seq, and the bimodal input are most needed to maintain similar classification results. These subsets are desirable for the following reasons:**
- 1. For the bimodal input, they suggest which of the features from ATAC-seq and RNA-seq are most needed. Theoretically, these two modalities contain the same information and RNA-seq may be able to contain the same (and more) information than ATAC-seq. We expect the majority of the features in these subsets to come from RNA-seq features, precluding the need for multimodal data in classification. However, ATAC-seq may contain features that RNA-seq doesn't capture
- 2. These subsets may suggest marker genes for a given cell type, which would we be useful for the Human Cell Atlas and any other project aiming to identify and describe cell types

**References:**
- SIS paper: https://arxiv.org/pdf/1810.03805.pdf
- GITHUB with more details: https://github.com/b-carter/SufficientInputSubsets

***
### SIS applied to a CNN trained on MNIST: (modify this for our project)
- Found this here: https://github.com/google-research/google-research/blob/master/sufficient_input_subsets/tutorials/sis_mnist_tutorial.ipynb

In [1]:
# Install sufficient input subsets from Google Research '''

# uncomment the following lines to install sufficient_input_subsets
!git clone https://github.com/google-research/google-research.git
%cd google-research

fatal: destination path 'google-research' already exists and is not an empty directory.
/Users/tjamesso/Desktop/MIT Courses/6_874_DL/Project/6_874-Multimodal-DL/google-research


In [2]:
from __future__ import print_function

import tensorflow as tf
import tensorflow.keras as K
import matplotlib.pyplot as plt
import numpy as np
import os
import time

from sufficient_input_subsets import sis
from helpers import *

preprocess imported
module name : helpers module package: 


### SIS for our cell-state classifier

In [3]:
######### Include code below ##########

In [168]:
# Load training and test data from pkl files

# get file paths
root = os.path.split(os.getcwd())[0]
pkl_path = os.path.join(root, 'data', 'sci-CAR', 'pkl_files')
pkl_atac = os.path.join(pkl_path, 'atacRaw_upSampled.pkl')
pkl_rna = os.path.join(pkl_path, 'rnaRaw_upSampled.pkl')
pkl_bimodal = os.path.join(pkl_path, 'bimodalRaw_upSampled.pkl')
pkl_bi_low = os.path.join(pkl_path, 'bimodal_cellLoad_upSampled.pkl')

# read pickle files
atac_train, atac_test, atac_features = read_pickle(pkl_atac)
rna_train, rna_test, rna_features = read_pickle(pkl_rna)
bimodal_train, bimodal_test, bimodal_features = read_pickle(pkl_bimodal)
bi_low_train, bi_low_test, bi_low_features = read_pickle(pkl_bi_low)

print('pickle files read')

# convert tensors to numpy arrays
atac_train_np, atac_test_np = atac_train.numpy(), atac_test.numpy()
rna_train_np, rna_test_np = rna_train.numpy(), rna_test.numpy()
bimodal_train_np, bimodal_test_np = bimodal_train.numpy(), bimodal_test.numpy()
bi_low_train_np, bi_low_test_np = bi_low_train.numpy(), bi_low_test.numpy()

pickle files read


In [169]:
atac_train.shape, rna_train.shape, bimodal_train.shape, bi_low_train.shape

(TensorShape([2582, 52761]),
 TensorShape([2582, 1185]),
 TensorShape([2582, 53946]),
 TensorShape([2902, 20]))

In [170]:
# Get the masks (mean values of each training set)

# get the num of features in each matrix
n_atac_features = atac_train.shape[1]
n_rna_features = rna_train.shape[1]
n_bimodal_features = bimodal_train.shape[1]
n_bi_low_features = bi_low_train.shape[1]


# Following the SIS paper, we use the mean pixel from training data as a mask.
ATAC_MASK = np.full((n_atac_features, 1), np.mean(atac_train_np))
RNA_MASK = np.full((n_rna_features, 1), np.mean(rna_train_np))
BIMODAL_MASK = np.full((n_bimodal_features, 1), np.mean(bimodal_train_np))
BI_LOW_MASK = np.full((n_bi_low_features, 1), np.mean(bi_low_train_np))



In [171]:
print(n_atac_features)
print(n_rna_features)
print(n_bimodal_features)
print(n_bi_low_features)

52761
1185
53946
20


In [172]:
root = os.path.split(os.getcwd())[0]

In [174]:
# Load the models

# get directories
atac_model_dir = os.path.join(root, 'models', 'Archive', 'best_atacRaw_upSampled_model')
rna_model_dir = os.path.join(root, 'models', 'Archive', 'best_rnaRaw_upSampled_model')
bimodal_model_dir = os.path.join(root, 'models', 'Archive', 'best_bimodal_upSampled_model')
bi_low_model_dir = os.path.join(root, 'models', 'Archive', 'best_bimodal_cellLoad_upSampled_model')


# load models
atac_model  = K.models.load_model(atac_model_dir)
rna_model  = K.models.load_model(rna_model_dir)
bimodal_model = K.models.load_model(bimodal_model_dir)
bi_low_model = K.models.load_model(bi_low_model_dir)

print('models loaded')

models loaded


In [203]:
# Set parameters
HOURS = 2 # choose from [0, 1, 2] which corresponds to [0hr, 1hr, 3hr]
THRESHOLD = 0.7  #@param {type:"slider", min:0, max:1, step:0.1}
MODEL = rna_model
TEST_SET = rna_test_np
MASK = RNA_MASK
FEATURE_VEC = rna_features
TITLE = 'RNA Raw'

In [204]:
THRESHOLD

0.7

In [205]:
# Helper function that selects the probability for a single class, from the
# softmax output.
def make_f_for_digit(digit, model):
    def f_digit(batch_of_inputs):
        return model.predict(
            batch_of_inputs,
            batch_size=min(784, len(batch_of_inputs)))[:, digit]
    return f_digit

# This function maps a list of images to a list of probabilities (probability of
# each image being a 4).
f_digit = make_f_for_digit(HOURS, MODEL)

In [206]:
# Helper function that filters input images to those the model predicts with
# high confidence (f(image) >= threshold).
def select_images_for_sis(inputs, f_digit, threshold):
    preds = f_digit(inputs)
    idxs = np.nonzero(preds >= threshold)[0]
    return inputs[idxs], preds[idxs]

# Filter test images that the model classifies as 4 with high confidence.
high_confidence_cells = select_images_for_sis(TEST_SET, f_digit,
                                                         THRESHOLD)

In [207]:
# take a look at the samples and predictions classified with high confidence

hc_samples, hc_preds = high_confidence_cells
print(f'Found {hc_preds.shape[0]} cells >= {THRESHOLD} confidence')

Found 103 cells >= 0.7 confidence


In [215]:
hc_preds

array([0.9976012 , 0.96586925, 0.9993926 , 0.99842477, 0.9998272 ,
       0.99827373, 0.99877816, 0.9998398 , 0.9948368 , 0.9978096 ,
       0.8405217 , 0.9996517 , 0.7881396 , 0.9998242 , 0.99765766,
       0.99863845, 0.9931543 , 0.9950682 , 0.99349517, 0.99988425,
       0.99986684, 0.99932146, 0.9993299 , 0.99999833, 0.9967277 ,
       0.7025534 , 0.99702793, 0.9999447 , 0.9981192 , 0.9988502 ,
       0.9913591 , 0.99999726, 0.9998596 , 0.999694  , 0.9991084 ,
       0.91017216, 0.9685043 , 0.99781585, 0.9981871 , 0.98401225,
       0.9999182 , 0.9995876 , 0.81144524, 0.7733769 , 0.9862895 ,
       0.99928856, 0.9996793 , 0.9989491 , 0.99851197, 0.9957625 ,
       0.8860866 , 0.78436047, 0.9995573 , 0.9988908 , 0.9996855 ,
       0.99999297, 0.9997377 , 0.9974618 , 0.99632794, 0.9996006 ,
       0.9995221 , 0.859427  , 0.9908602 , 0.99588984, 0.9998354 ,
       0.9993368 , 0.8837964 , 0.999998  , 0.9961836 , 0.95274043,
       0.9996762 , 0.9966247 , 0.999126  , 0.99093306, 0.99985

In [208]:
# # Randomly select some of these digits to run SIS.
# cells_to_run_sis = hc_samples[
#     np.random.choice(hc_samples.shape[0],
#                      size=5,
#                      replace=False)]

In [209]:
# OR run all cells through SIS
cells_to_run_sis = tf.convert_to_tensor(high_confidence_cells[0])

In [210]:
# Run SIS on each digit and visualize the resulting SIS-collections.
begin = time.time()
collections = []
for initial_cell in cells_to_run_sis:
    begin_loop = time.time()
    collection = sis.sis_collection(f_digit, THRESHOLD, initial_cell,
                                    np.squeeze(MASK))
    collections.append(collection)
    end_loop = time.time()
    print(f'Loop completed in {end_loop - begin_loop}')
end = time.time()
#     plot_sis_collection(initial_image, collection, FULLY_MASKED_IMAGE)
print(f'Cell complete in {end-begin} seconds')

Loop completed in 25.36620593070984
Loop completed in 23.25635313987732
Loop completed in 23.580237865447998
Loop completed in 26.5559139251709
Loop completed in 24.454350233078003
Loop completed in 23.697070121765137
Loop completed in 23.95415687561035
Loop completed in 26.453312158584595
Loop completed in 22.88639998435974
Loop completed in 22.92171311378479
Loop completed in 22.813381910324097
Loop completed in 22.54314374923706
Loop completed in 24.342291831970215
Loop completed in 22.770493745803833
Loop completed in 22.46404504776001
Loop completed in 22.618826866149902
Loop completed in 24.66104292869568
Loop completed in 26.004273176193237
Loop completed in 24.972266912460327
Loop completed in 22.396268129348755
Loop completed in 23.614142179489136
Loop completed in 22.15380597114563
Loop completed in 23.419946670532227
Loop completed in 23.61780285835266
Loop completed in 24.359530925750732
Loop completed in 23.04483199119568
Loop completed in 23.00208830833435
Loop completed 

### SIS Output Analysis
Each SIS Output comes as an interable of length 4 with the following entries:
0. sis: Sufficient input subset, ordered by most important features
1. The order of features removed by back selection -- the most important features are at the end
2. The resulting probabilities if you remove up to that feature
3. Mask (probably not important for this analysis)

In [211]:
# Analyze collections
def get_features_from_ixs(ix, feature_vec):
    return feature_vec[ix].to_numpy()


# Get all subsets into a nice dataframe
def get_features_from_collection(collection, feature_vec):
    features_mat = []
    for ix, cell in enumerate(collection):
        for disjoint_subset in cell:
            sis_ixs = disjoint_subset[0]
            features = get_features_from_ixs(sis_ixs, feature_vec)
            features_mat.append((ix, features))
    return features_mat


# get a list of most important features
important_features = get_features_from_collection(collections, FEATURE_VEC)

def count_genes(important_features):
    gene_dict = {}
    for ix, features in important_features:
        for f in features:
            if f[0] not in gene_dict:
                gene_dict[f[0]] = 1
            else:
                gene_dict[f[0]] += 1
    # sort gene dict
    
    return gene_dict

gene_dict = count_genes(important_features)

classes = ['0hr', '1hr', '3hr']

def print_features(important_features, class_, dataset):
    print('_'*100)
    print(f'Dataset: {dataset}')
    print(f'Num cells above threshold : {len(cells_to_run_sis)}')
    print(f'Important features to classify as {class_}')
    print(count_genes(important_features))
    print('_'*100)
    for row in important_features:
        print(f'Cell {row[0]} Disjoint SIS: {[g[0] for g in row[1]]}')
        

# get an ordered list of the importance of each feature, ordered by fraction of cells the feature appears in
def order_features(gene_dict, num_confident_cells):
    feature_list = []
    for feature, count in gene_dict.items():
        feature_list.append((feature, count/num_confident_cells))
    # sort by number of appearances
    sorted_features = sorted(feature_list, key=lambda x: x[1], reverse=True)
    return sorted_features
#     # cast as DataFrame
#     return pd.DataFrame()

# print(f'{line}\n' for line in order_features(gene_dict, len(cells_to_run_sis)))
features = order_features(gene_dict, len(cells_to_run_sis))
print(features)

# get a print out of the important features
print_features(important_features, class_=classes[HOURS], dataset=TITLE)
        




[]
____________________________________________________________________________________________________
Dataset: RNA Raw
Num cells above threshold : 103
Important features to classify as 3hr
{}
____________________________________________________________________________________________________
Cell 0 Disjoint SIS: []
Cell 1 Disjoint SIS: []
Cell 2 Disjoint SIS: []
Cell 3 Disjoint SIS: []
Cell 4 Disjoint SIS: []
Cell 5 Disjoint SIS: []
Cell 6 Disjoint SIS: []
Cell 7 Disjoint SIS: []
Cell 8 Disjoint SIS: []
Cell 9 Disjoint SIS: []
Cell 10 Disjoint SIS: []
Cell 11 Disjoint SIS: []
Cell 12 Disjoint SIS: []
Cell 13 Disjoint SIS: []
Cell 14 Disjoint SIS: []
Cell 15 Disjoint SIS: []
Cell 16 Disjoint SIS: []
Cell 17 Disjoint SIS: []
Cell 18 Disjoint SIS: []
Cell 19 Disjoint SIS: []
Cell 20 Disjoint SIS: []
Cell 21 Disjoint SIS: []
Cell 22 Disjoint SIS: []
Cell 23 Disjoint SIS: []
Cell 24 Disjoint SIS: []
Cell 25 Disjoint SIS: []
Cell 26 Disjoint SIS: []
Cell 27 Disjoint SIS: []
Cell 28 Disjoin

In [212]:
data = pd.DataFrame(features, columns=['Feature','% occurrence'])


In [213]:
data

Unnamed: 0,Feature,% occurrence


In [214]:
write_zipped_pickle(data, filename=os.path.join(root, 'results', 'SIS', f'{TITLE}_class_{HOURS}'))

In [155]:
# return sorted list of gene counts
# data = count_genes(important_features), important_features
# write_zipped_pickle(data, filename=os.path.join(root, 'results', 'initial_model_pp_fs_scai', f'ATAC_SVD_{classes[HOURS]}_SIS_thresh={THRESHOLD}.pkl'))

***
## Reference code from Google Research -- running SIS on MNIST

In [None]:
#@title Select a digit and threshold {run: "auto"}

DIGIT = 4  #@param ['0', 1', '2', '3', '4', '5', '6', '7', '8', '9'] {type:"raw"}

THRESHOLD = 0.7  #@param {type:"slider", min:0, max:1, step:0.1}

# Following the SIS paper, we use the mean pixel from training images as a mask.
FULLY_MASKED_IMAGE = np.full((28, 28, 1), np.mean(x_train))

# Helper function that selects the probability for a single class, from the
# softmax output.
def make_f_for_digit(digit, model):
    def f_digit(batch_of_inputs):
        return model.predict(
            batch_of_inputs,
            batch_size=min(784, len(batch_of_inputs)))[:, digit]
    return f_digit

# This function maps a list of images to a list of probabilities (probability of
# each image being a 4).
f_digit = make_f_for_digit(DIGIT, model)

# Helper function that filters input images to those the model predicts with
# high confidence (f(image) >= threshold).
def select_images_for_sis(inputs, f_digit, threshold):
    preds = f_digit(inputs)
    idxs = np.nonzero(preds >= threshold)[0]
    return inputs[idxs]

# Filter test images that the model classifies as 4 with high confidence.
high_confidence_images_for_digit = select_images_for_sis(x_test, f_digit,
                                                         THRESHOLD)

# Randomly select some of these digits to run SIS.
digits_to_run_sis = high_confidence_images_for_digit[
    np.random.choice(high_confidence_images_for_digit.shape[0],
                     size=5,
                     replace=False)]

# Helpers for plotting an MNIST digit and its corresponding SIS-collection.
def plot_mnist_digit(ax, image):
    ax.imshow(image[:, :, 0], cmap=plt.get_cmap('gray'))
    ax.axis('off')

def plot_sis_collection(initial_image, collection, fully_masked_image):
    # Grid contains initial image, an empty cell (for spacing), and collection.
    width = len(collection) + 2
    plt.figure(figsize=(width, 1))
    gs = plt.GridSpec(1, width, wspace=0.1)

    # Plot initial image.
    ax = plt.subplot(gs[0])
    plot_mnist_digit(ax, initial_image)

    # Plot each SIS.
    for i, sis_result in enumerate(collection):
        ax = plt.subplot(gs[i+2])
        masked_image = sis.produce_masked_inputs(
            initial_image, fully_masked_image, [sis_result.mask])[0]
        plot_mnist_digit(ax, masked_image)

    plt.show()

print('Running SIS on {} examples of digit {}. '
      'This might take a couple minutes.'.format(len(digits_to_run_sis), DIGIT))

# Run SIS on each digit and visualize the resulting SIS-collections.
for initial_image in digits_to_run_sis:
    collection = sis.sis_collection(f_digit, THRESHOLD, initial_image,
                                    FULLY_MASKED_IMAGE)
    plot_sis_collection(initial_image, collection, FULLY_MASKED_IMAGE)