In [1]:
import sys
sys.path.append('..')
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sympy import simplify_logic
import time
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.tree import _tree, export_text
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split

from lens.utils.base import validate_network, set_seed, tree_to_formula
from lens.utils.relu_nn import get_reduced_model, prune_features
from lens.utils.psi_nn import prune_equal_fanin
from lens import logic

results_dir = 'results/celldiff'
# exist_ok=True replaces a separate isdir() check, avoiding the check-then-create race.
os.makedirs(results_dir, exist_ok=True)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_rep = 10  # NOTE(review): not referenced by any visible cell of this notebook
tot_epochs = 15001  # full-batch training epochs per model
prune_epochs = 5000  # pruning in the trainers only happens after this epoch
seed = 42  # seed for set_seed() and the decision-tree baseline

In [2]:
# Raw inputs: an expression matrix (rows are samples, columns are gene names),
# one cluster label per row, and a marker table (loaded but not used below).
gene_expression_matrix = pd.read_csv('./data/celldiff/data_matrix.csv', index_col=0)
clustering_labels = pd.read_csv('./data/celldiff/cluster_labels.csv', index_col=0)
biomarkers = pd.read_csv('./data/celldiff/markers.csv', index_col=0)

# One-vs-rest target: a row is positive iff its cluster id equals task_id.
task_id = 2
labels = clustering_labels.values == task_id

# Rescale each column independently into [0, 1].
scaler = MinMaxScaler((0, 1))
data_normalized = scaler.fit(gene_expression_matrix.values).transform(gene_expression_matrix.values)

x = torch.tensor(data_normalized, dtype=torch.float32)
y = torch.tensor(labels, dtype=torch.long).squeeze()
print(x.shape)
print(y.shape)

torch.Size([1422, 1845])
torch.Size([1422])


In [3]:
# Column (gene) names in feature order; passed to tree_to_formula / replace_names
# when rendering explanations.
concepts = gene_expression_matrix.columns.tolist()
concepts[:10]

['alas2',
 'si:xx-by187g17.1',
 'si:ch211-5k11.6',
 'tspan36',
 'ba1l',
 'mmp9',
 'il21r.1',
 'plekhf1',
 'slc3a2a',
 'si:dkey-222f2.7']

In [4]:
# Stratified 10-fold CV; shuffling with a fixed random_state keeps folds reproducible.
n_splits = 10
skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)

In [5]:
def train_nn(x_train, y_train, need_pruning, seed, device, relu=False, verbose=False):
    """Full-batch training of an in -> 200 -> 80 -> 30 -> 2 MLP classifier.

    Args:
        x_train: input features (moved to ``device``).
        y_train: integer class labels for ``CrossEntropyLoss`` (moved to ``device``).
        need_pruning: if True, ``prune_features(model, n_classes=1, ...)`` is called
            at every epoch greater than ``prune_epochs`` with ``epoch % 1000 == 1``.
        seed: passed to ``set_seed`` before any layer is created.
        device: torch device for the model and data.
        relu: ReLU activations if True, LeakyReLU otherwise.
        verbose: print training accuracy every 500 epochs.

    Returns:
        The trained ``torch.nn.Sequential`` (its last layer is a Softmax).
    """
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device)

    # Hidden Linear+activation pairs are created in the same order as an explicit
    # layer list, so the seeded weight initialization is identical.
    activation_cls = torch.nn.ReLU if relu else torch.nn.LeakyReLU
    widths = [x_train.size(1), 200, 80, 30]
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(torch.nn.Linear(fan_in, fan_out))
        layers.append(activation_cls())
    layers.append(torch.nn.Linear(widths[-1], 2))
    # NOTE(review): CrossEntropyLoss expects unnormalized logits and applies
    # log-softmax itself, so this Softmax normalizes twice inside the loss. Kept
    # because the explanation code downstream consumes this exact model — TODO
    # confirm whether it is safe to remove.
    layers.append(torch.nn.Softmax(dim=1))
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
    criterion = torch.nn.CrossEntropyLoss()
    l1_strength = 0.000001
    model.train()
    for epoch in range(tot_epochs):
        optimizer.zero_grad()
        y_pred = model(x_train)
        loss = criterion(y_pred, y_train)

        # L1 penalty on the FIRST Linear layer only (weights, then bias).
        first_linear = next(
            (layer for layer in model.children() if isinstance(layer, torch.nn.Linear)),
            None,
        )
        if first_linear is not None:
            loss = loss + l1_strength * torch.norm(first_linear.weight, 1)
            loss = loss + l1_strength * torch.norm(first_linear.bias, 1)

        loss.backward()
        optimizer.step()

        # Repeated pruning: once past prune_epochs, every 1000 epochs (offset by 1).
        if need_pruning and epoch > prune_epochs and epoch % 1000 == 1:
            prune_features(model, n_classes=1, device=device)

        # Accuracy of the predictions from this epoch's forward pass.
        if verbose and epoch % 500 == 0:
            n_correct = y_pred.argmax(dim=1).eq(y_train).sum().item()
            accuracy = n_correct / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')

    return model

In [13]:
def train_psi_nn(x_train, y_train, need_pruning, seed, device, verbose=False):
    """Full-batch training of a small all-sigmoid (psi) network, in -> 10 -> 4 -> 1.

    Args:
        x_train: input features (moved to ``device``).
        y_train: binary targets; cast to float for ``BCELoss``.
        need_pruning: if True, ``prune_equal_fanin(model, 2, ...)`` is applied once,
            at the first epoch greater than ``prune_epochs``.
        seed: passed to ``set_seed`` before any layer is created.
        device: torch device for the model and data.
        verbose: print training accuracy every 500 epochs.

    Returns:
        The trained (and possibly pruned) model with a single sigmoid output.
    """
    set_seed(seed)
    x_train = x_train.to(device)
    y_train = y_train.to(device).to(torch.float)

    # Each Linear is followed by a Sigmoid; Linear layers are created in the same
    # order as an explicit list, so seeded initialization is unchanged.
    widths = [x_train.size(1), 10, 4, 1]
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers += [torch.nn.Linear(fan_in, fan_out), torch.nn.Sigmoid()]
    model = torch.nn.Sequential(*layers).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.BCELoss()
    l1_strength = 0.000001
    model.train()
    for epoch in range(tot_epochs):
        optimizer.zero_grad()
        y_pred = model(x_train).squeeze()
        loss = criterion(y_pred, y_train)

        # L1 penalty on the weights (not the biases) of every Linear layer, in order.
        linear_layers = (layer for layer in model.children() if isinstance(layer, torch.nn.Linear))
        for linear in linear_layers:
            loss = loss + l1_strength * torch.norm(linear.weight, 1)

        loss.backward()
        optimizer.step()

        if need_pruning and epoch > prune_epochs:
            # NOTE(review): `optimizer` was built from the pre-pruning parameters. If
            # prune_equal_fanin returns a new module instead of mutating in place,
            # later steps would not update the returned model — TODO confirm.
            model = prune_equal_fanin(model, 2, validate=True, device=device)
            need_pruning = False

        if verbose and epoch % 500 == 0:
            n_correct = (y_pred > 0.5).eq(y_train).sum().item()
            accuracy = n_correct / y_train.size(0)
            print(f'\t Epoch {epoch}: train accuracy: {accuracy:.4f}')

    return model

In [7]:
def c_to_y(method, need_pruning, relu, verbose=False):
    """Cross-validate one explanation method on the concepts -> label task.

    For every fold of the module-level ``skf`` splitter over ``x`` / ``y``, a model
    is fit on the train+val portion and scored on the held-out test fold. A global
    logic explanation for class 1 is then extracted and scored on the same test fold.

    Args:
        method: ``'tree'`` (decision tree), ``'psi'`` (psi network trained by
            ``train_psi_nn``), or any other value. Any other value trains the MLP from
            ``train_nn`` and is forwarded as ``method`` to
            ``logic.relu_nn.combine_local_explanations`` (this notebook uses
            ``'pruning'`` and ``'weights'``).
        need_pruning: forwarded to the neural trainers; unused for ``'tree'``.
        relu: forwarded to ``train_nn``; unused for ``'tree'`` and ``'psi'``.
        verbose: print per-fold model accuracy, explanation and timing.

    Returns:
        A DataFrame with one row per fold and columns ``method``, ``split``,
        ``explanation``, ``model_accuracy``, ``explanation_accuracy`` and
        ``elapsed_time``. It is also written to ``results_dir/results_{method}.csv``.

    Note:
        ``model_accuracy`` is a fraction in [0, 1]. ``explanation_accuracy`` is on
        the 0-100 scale that ``logic.base.test_explanation`` appears to use; the
        tree baseline is rescaled to the same scale.
    """
    methods = []
    splits = []
    explanations = []
    explanations_inv = []
    model_accuracies = []
    explanation_accuracies = []
    explanation_accuracies_inv = []
    elapsed_times = []
    elapsed_times_inv = []
    for split, (trainval_index, test_index) in enumerate(skf.split(x.cpu().detach().numpy(), y.cpu().detach().numpy())):
        print(f'Split [{split+1}/{n_splits}]')
        # x / y are already float / long CPU tensors, so plain indexing suffices.
        x_trainval, x_test = x[trainval_index], x[test_index]
        y_trainval, y_test = y[trainval_index], y[test_index]
        # Only the validation part is used: it is the sample the MLP methods extract
        # explanations from. NOTE(review): every model is fit on all of x_trainval,
        # so these validation rows are also training rows.
        _, x_val, _, y_val = train_test_split(x_trainval, y_trainval, test_size=0.3, random_state=42)

        explanation, explanation_inv = '', ''
        explanation_accuracy, explanation_accuracy_inv = 0, 0

        if method == 'tree':
            classifier = DecisionTreeClassifier(random_state=seed)
            classifier.fit(x_trainval.cpu().detach().numpy(), y_trainval.cpu().detach().numpy())
            y_preds = classifier.predict(x_test.cpu().detach().numpy())
            model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds)

            # The formula extracted from the tree is taken to be exactly as accurate
            # as the tree. It is rescaled to 0-100 to match test_explanation, which
            # the other methods use: their recorded explanation_accuracy is ~85 while
            # model_accuracy is ~0.93. Leaving it as a fraction made the summary
            # compare incompatible units.
            target_class = 1
            start = time.time()
            explanation = tree_to_formula(classifier, concepts, target_class)
            elapsed_time = time.time() - start
            explanation_accuracy = model_accuracy * 100

            target_class_inv = 0
            start = time.time()
            explanation_inv = tree_to_formula(classifier, concepts, target_class_inv)
            elapsed_time_inv = time.time() - start
            explanation_accuracy_inv = model_accuracy * 100

        else:
            if method == 'psi':
                # positive class
                target_class = 1
                model = train_psi_nn(x_trainval, y_trainval.eq(target_class), need_pruning, seed, device, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy()
                # The psi model has one sigmoid output of shape (n, 1); flatten before thresholding.
                model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.ravel() > 0.5)

            else:
                model = train_nn(x_trainval, y_trainval, need_pruning, seed, device, relu, verbose)
                y_preds = model(x_test.to(device)).cpu().detach().numpy()
                model_accuracy = accuracy_score(y_test.cpu().detach().numpy(), y_preds.argmax(axis=1))

            # positive class
            target_class = 1
            start = time.time()
            if method == 'psi':
                global_explanation = logic.generate_fol_explanations(model, device)[0]
            else:
                global_explanation, _, _ = logic.relu_nn.combine_local_explanations(model,
                                                                                   x_val.to(device),
                                                                                   y_val.to(device),
                                                                                   topk_explanations=1,
                                                                                   target_class=target_class,
                                                                                   method=method, device=device)
            elapsed_time = time.time() - start

            # An empty explanation keeps the default accuracy of 0.
            if global_explanation:
                explanation_accuracy, _ = logic.base.test_explanation(global_explanation, target_class, x_test, y_test)
                explanation = logic.base.replace_names(global_explanation, concepts)

#             # negative class
#             target_class_inv = 0
#             if method == 'psi':
#                 model = train_psi_nn(x_trainval, y_trainval.eq(target_class_inv), need_pruning, seed, device, verbose)

#             start = time.time()
#             if method == 'psi':
#                 global_explanation_inv = logic.generate_fol_explanations(model, device)[0]
#             else:
#                 global_explanation_inv, _, _ = logic.relu_nn.combine_local_explanations(model,
#                                                                                        x_val.to(device),
#                                                                                        y_val.to(device),
#                                                                                        topk_explanations=2,
#                                                                                        target_class=target_class_inv,
#                                                                                        method=method, device=device)
#             elapsed_time_inv = time.time() - start
#             if global_explanation_inv:
#                 explanation_accuracy_inv, _ = logic.base.test_explanation(global_explanation_inv,
#                                                                           target_class_inv, x_test, y_test)
#                 explanation_inv = logic.base.replace_names(global_explanation_inv, concepts)

        if verbose:
            print(f'\t Model\'s accuracy: {model_accuracy:.4f}')
            print(f'\t Class {target_class} - Global explanation: "{explanation}" - Accuracy: {explanation_accuracy:.4f}')
            print(f'\t Elapsed time {elapsed_time}')
#             print(f'\t Class {target_class_inv} - Global explanation: "{explanation_inv}" - Accuracy: {explanation_accuracy_inv:.4f}')
#             print(f'\t Elapsed time {elapsed_time_inv}')

        methods.append(method)
        splits.append(split)
        explanations.append(explanation)
#         explanations_inv.append(explanation_inv)
        model_accuracies.append(model_accuracy)
        explanation_accuracies.append(explanation_accuracy)
#         explanation_accuracies_inv.append(explanation_accuracy_inv)
        elapsed_times.append(elapsed_time)
#         elapsed_times_inv.append(elapsed_time_inv)

    results = pd.DataFrame({
        'method': methods,
        'split': splits,
        'explanation': explanations,
#         'explanation_inv': explanations_inv,
        'model_accuracy': model_accuracies,
        'explanation_accuracy': explanation_accuracies,
#         'explanation_accuracy_inv': explanation_accuracies_inv,
        'elapsed_time': elapsed_times,
#         'elapsed_time_inv': elapsed_times_inv,
    })
    results.to_csv(os.path.join(results_dir, f'results_{method}.csv'))

    return results

# General pruning

In [8]:
# MLP with LeakyReLU activations, trained with periodic prune_features() calls.
method, need_pruning, relu = 'pruning', True, False
results_pruning = c_to_y(method=method, need_pruning=need_pruning, relu=relu)
results_pruning

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]


Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,elapsed_time
0,pruning,0,dpysl5b & ap5z1 & pcxa & nr2c2 & ~si:ch211-103...,0.93007,85.314685,1.163631
1,pruning,1,FO704779.1 & pcxa & ~CD37 & ~si:ch211-103n10.5...,0.909091,83.916084,1.183507
2,pruning,2,dpysl5b & pcxa & nr2c2 & ~si:ch211-103n10.5 & ...,0.950704,89.43662,0.559986
3,pruning,3,si:dkey-183p4.10 & slc43a2b & si:ch211-132g1.1...,0.943662,83.098592,3.528456
4,pruning,4,dpysl5b & ap5z1 & pcxa & nr2c2 & ~CD37 & ~mafb...,0.971831,85.211268,1.18482
5,pruning,5,gyg1b & dpysl5b & yrk & asap2b & riok3 & hsd3b...,0.922535,83.098592,14.036478
6,pruning,6,dpysl5b & ap5z1 & ppip5k1b & pcxa & ~CD37 & ~s...,0.93662,83.802817,3.152489
7,pruning,7,anxa3b & Unnamed: 247 & si:ch211-132g1.1 & zgc...,0.950704,83.802817,15.447285
8,pruning,8,si:dkeyp-118b1.2 & si:ch211-132g1.1 & plekha6 ...,0.93662,83.802817,9.054447
9,pruning,9,anxa3b & FYB & si:ch211-132g1.1 & zgc:154055 &...,0.929577,83.802817,15.162649


# Weights

In [9]:
# MLP with ReLU activations and no pruning; explanations use method='weights'.
method, need_pruning, relu = 'weights', False, True
results_weights = c_to_y(method=method, need_pruning=need_pruning, relu=relu)
results_weights

Split [1/10]
Split [2/10]
Split [3/10]
Split [4/10]
Split [5/10]
Split [6/10]
Split [7/10]
Split [8/10]
Split [9/10]
Split [10/10]


Unnamed: 0,method,split,explanation,model_accuracy,explanation_accuracy,elapsed_time
0,weights,0,ap5z1 & pcxa & nr2c2 & ~si:ch211-137j23.8 & ~s...,0.944056,85.314685,1.326543
1,weights,1,pcxa & nr2c2 & ~si:ch211-103n10.5 & ~im:715144...,0.944056,90.20979,0.605859
2,weights,2,pcxa & nr2c2 & ~si:ch211-103n10.5 & ~tmem206 &...,0.971831,88.732394,0.620259
3,weights,3,si:ch211-132g1.1 & pcxa & nr2c2 & ~CD37 & ~si:...,0.971831,90.140845,0.794775
4,weights,4,pcxa & nr2c2 & ~si:ch211-103n10.5 & ~im:715144...,0.985915,89.43662,0.740607
5,weights,5,si:ch211-132g1.1 & ppip5k1b & pcxa & nr2c2 & ~...,0.93662,84.507042,1.093919
6,weights,6,pcxa & ~si:ch211-103n10.5 & ~im:7151449 & ~smi...,0.978873,92.957746,0.712162
7,weights,7,pcxa & nr2c2 & ~si:ch211-103n10.5 & ~tmem206 &...,0.964789,87.323944,1.033122
8,weights,8,si:ch211-132g1.1 & pcxa & nr2c2 & ~si:ch211-10...,0.943662,89.43662,0.812649
9,weights,9,pcxa & nr2c2 & ~si:ch211-103n10.5 & ~im:715144...,0.957746,85.211268,0.544006


# Psi network

In [1]:
# Psi network, pruned once with prune_equal_fanin(); relu is unused for this method.
method, need_pruning, relu = 'psi', True, False
results_psi = c_to_y(method=method, need_pruning=need_pruning, relu=relu, verbose=False)
results_psi

NameError: name 'c_to_y' is not defined

# Decision tree

In [None]:
# Decision-tree baseline; c_to_y ignores need_pruning and relu for this method.
method, need_pruning, relu = 'tree', False, False
results_tree = c_to_y(method=method, need_pruning=need_pruning, relu=relu)
results_tree

# Summary

In [None]:
# Metrics recorded per fold by c_to_y. The negative-class metrics
# ('explanation_accuracy_inv', 'elapsed_time_inv') are commented out there, so
# requesting them here raised a KeyError.
cols = ['model_accuracy', 'explanation_accuracy', 'elapsed_time']
mean_cols = [f'{c}_mean' for c in cols]
sem_cols = [f'{c}_sem' for c in cols]


def summarize_results(results, name):
    """Collapse per-fold results into one Series of means and standard errors.

    Args:
        results: per-fold DataFrame returned by ``c_to_y``; must contain ``cols``.
        name: label for the returned Series (becomes the row name in ``summary``).

    Returns:
        A Series indexed by ``mean_cols + sem_cols``.
    """
    means = results[cols].mean()
    sems = results[cols].sem()
    # A Series is relabelled through `.index`; assigning `.columns` only attached an
    # unused attribute and triggered a pandas warning.
    means.index = mean_cols
    sems.index = sem_cols
    summary_row = pd.concat([means, sems])
    summary_row.name = name
    return summary_row


summary_pruning = summarize_results(results_pruning, 'pruning')
# summary_lime = summarize_results(results_lime, 'lime')  # lime is not run in this notebook
summary_weights = summarize_results(results_weights, 'weights')
summary_psi = summarize_results(results_psi, 'psi')
summary_tree = summarize_results(results_tree, 'tree')

# One row per method; the columns are already mean_cols + sem_cols.
summary = pd.concat([summary_pruning,
                     summary_weights,
                     summary_psi,
                     summary_tree], axis=1).T
summary

In [None]:
summary_path = os.path.join(results_dir, 'summary.csv')
summary.to_csv(summary_path)