In [1]:
import torch
import pickle
import matplotlib.pyplot as plt
import numpy as np 
import statsmodels.api as sm
import networkx as nx
import seaborn as sns
import sys

from spikeometric.models import BernoulliGLM
from spikeometric.datasets import NormalGenerator, ConnectivityDataset
from spikeometric.stimulus import RegularStimulus

from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj, to_networkx, from_networkx

sys.path.append('..')

from CD_methods import SCM_learner
from functions import *

from tqdm import tqdm

In [4]:
# Load the connectome graph and the simulated spike recordings.
# network_data is used below via `.num_nodes` and `to_networkx`, i.e. it is
# presumably a torch_geometric `Data` object. PyTorch >= 2.6 defaults
# torch.load to weights_only=True, which refuses to unpickle such non-tensor
# objects, so opt out explicitly.
# SECURITY: both torch.load(weights_only=False) and pickle.load can execute
# arbitrary code while loading - only use them on trusted, locally produced files.
network_data = torch.load('../data/c_elegans_data.pt', weights_only=False)

# spike_data is indexed by string keys: 'null' plus str(neuron index) for each
# single-node stimulation experiment (see how it is read in the cells below).
with open('../data/c_elegans_spike_data_single_node_stimuli.pickle', 'rb') as f:
    spike_data = pickle.load(f)

In [49]:
# Build a networkx view of the connectome, keeping each neuron's 'position'.
n_neurons = network_data.num_nodes
G = to_networkx(network_data, node_attrs=['position'])
position_dict = nx.get_node_attributes(G, 'position')

# Observe a random subset of neurons (sorted indices, drawn without replacement).
n_obs = 50
index_obs = np.sort(np.random.choice(n_neurons, size=n_obs, replace=False))

# Stimulate a random fraction of the observed neurons, one neuron per experiment.
prop_intervened = 0.3
n_stimulated = int(n_obs * prop_intervened)
stimulate_nodes = np.sort(np.random.choice(index_obs, size=n_stimulated, replace=False))
stimulation_protocol = [[node] for node in stimulate_nodes]

# Experiment labels: one per stimulated neuron, plus the 'null' recording.
stimulation_protocol_str = list(map(str, stimulate_nodes)) + ['null']
print(stimulation_protocol_str)
print('num experiments=', len(stimulation_protocol_str))

['5', '7', '10', '37', '68', '76', '112', '119', '132', '207', '231', '237', '238', '251', '268', 'null']
num experiments= 16


In [50]:
# Restrict every recording to the observed neurons (indexing the first axis).
# Note: this keeps an entry for every observed neuron, not only the stimulated
# ones; the learner is told which experiments to use via stimulation_protocol.
spike_data_obs = {'null': spike_data['null'][index_obs]}
spike_data_obs.update(
    (str(target), spike_data[str(target)][index_obs]) for target in index_obs
)

In [51]:
G_obs = nx.subgraph(G, index_obs)

# A hidden (unobserved) neuron is counted as a confounder when it has outgoing
# edges to at least two observed neurons.
observed_set = set(index_obs)
index_hidden = [node for node in range(n_neurons) if node not in observed_set]
confounders = [
    hidden
    for hidden in index_hidden
    if sum(1 for _, target in G.out_edges(hidden) if target in observed_set) >= 2
]
print('num. confounding variables = ', len(confounders))

num. confounding variables =  92


In [52]:
# Learn a causal graph over the observed neurons from the restricted recordings.
# alpha is passed straight through (presumably the learner's test significance level).
G_learned = SCM_learner(
    spike_data_obs,
    alpha=0.01,
    node_list=index_obs,
    stimulation_protocol=stimulation_protocol,
)

100%|██████████████████████████████████████████| 15/15 [00:00<00:00, 314.11it/s]


In [53]:
# Ground truth: the connectome restricted to the observed neurons.
G_true = nx.subgraph(G, index_obs)

# Binary adjacency matrices in the same node order. weight=None makes every
# edge count as 1 even if a graph carries a 'weight' edge attribute, so the
# element-wise counts below are structural rather than weighted.
A_true = nx.adjacency_matrix(G_true, nodelist=index_obs, weight=None).todense()
A_learned = nx.adjacency_matrix(G_learned, nodelist=index_obs, weight=None).todense()
A_diff = A_true - A_learned

# Manual confusion counts (cross-check only). This SHD counts a reversed edge
# twice (one FP + one FN); the SHD printed below comes from compute_SHD.
SHD = np.sum(np.abs(A_diff))
TP = np.sum((A_true == 1) & (A_learned == 1))
TN = np.sum((A_true == 0) & (A_learned == 0))
FP = np.sum(A_diff == -1)
FN = np.sum(A_diff == 1)

# index_obs has not changed since the confounder-counting cell above, so its
# `confounders` list is reused rather than recomputed with a copy of that loop.
print('num. confounding variables = ', len(confounders))

print('total edges (in true observed graph) = ', G_true.number_of_edges())
print('percentage of nodes observed = ', np.round(G_obs.number_of_nodes() / G.number_of_nodes() * 100, 2), '%')
print('SHD =', compute_SHD(G_true, G_learned))
print('sensitivity=', np.round(compute_sensitivity(G_true, G_learned, nodelist=index_obs), 4))
print('specificity=', np.round(compute_specificity(G_true, G_learned, nodelist=index_obs), 4))

num. confounding variables =  92
total edges (in true observed graph) =  63
percentage of nodes observed =  17.92 %
SHD = 15
sensitivity= 0.8077
specificity= 1.0


In [63]:
# Sweep the fraction of observed neurons that are stimulated (10%, 20%, ..., 100%)
# for a fixed observed-network size, and record the recovery metrics of each run
# in experiment_random_single_node_intervention (one dict per run).
experiment_random_single_node_intervention = []

# Fixed observed-network size; to also sweep it, wrap the loop below in
# `for n_obs in range(10, 100, 10):`.
n_obs = 80
for pct in range(10, 101, 10):
    # Integer percentages avoid float drift from np.arange(0.1, 1.01, 0.1)
    # (e.g. 0.30000000000000004) and make the stimulated count an exact floor,
    # instead of relying on int() truncation of a float product.
    p = pct / 100
    n_stimulated = n_obs * pct // 100

    print('n_obs = ', n_obs, 'proportion intervened = ', np.round(p, 2))
    index_obs = np.sort(np.random.choice(n_neurons, size = n_obs, replace = False))

    # select intervened nodes randomly
    stimulate_nodes = np.sort(np.random.choice(index_obs, size = n_stimulated, replace = False))
    stimulation_protocol = [[i] for i in stimulate_nodes]
    stimulation_protocol_str = [str(i) for i in stimulate_nodes] + ['null']

    # get data: restrict every recording to the observed neurons
    spike_data_obs = {'null': spike_data['null'][index_obs]}
    for intervention in index_obs:
        spike_data_obs[str(intervention)] = spike_data[str(intervention)][index_obs]

    G_learned = SCM_learner(spike_data_obs,
                            node_list=index_obs,
                            stimulation_protocol=stimulation_protocol,
                            alpha = 0.01)

    G_true = nx.subgraph(G, index_obs)

    shd_value = compute_SHD(G_true, G_learned)
    sens_value = compute_sensitivity(G_true, G_learned, nodelist=index_obs)
    spec_value = compute_specificity(G_true, G_learned, nodelist=index_obs)

    print('SHD =', shd_value)
    print('sensitivity=', np.round(sens_value, 4))
    print('specificity=', np.round(spec_value, 4))
    print('')

    # Keep the results (the list was previously initialised but never filled).
    experiment_random_single_node_intervention.append({
        'n_obs': n_obs,
        'prop_intervened': p,
        'n_stimulated': n_stimulated,
        'SHD': shd_value,
        'sensitivity': sens_value,
        'specificity': spec_value,
    })

n_obs =  80 proportion intervened =  0.1
['18', '90', '110', '136', '168', '169', '177', '231', 'null']


100%|█████████████████████████████████████████████| 8/8 [00:00<00:00, 81.88it/s]


SHD = 63
sensitivity= 0.7225
specificity= 1.0

n_obs =  80 proportion intervened =  0.2
['33', '71', '103', '106', '127', '147', '148', '161', '169', '170', '171', '192', '195', '215', '234', '240', 'null']


100%|███████████████████████████████████████████| 16/16 [00:00<00:00, 67.48it/s]


SHD = 62
sensitivity= 0.7033
specificity= 1.0

n_obs =  80 proportion intervened =  0.30000000000000004
['35', '50', '55', '67', '84', '88', '92', '100', '103', '104', '110', '125', '128', '131', '137', '146', '151', '189', '197', '222', '232', '233', '249', '268', 'null']


100%|███████████████████████████████████████████| 24/24 [00:00<00:00, 78.36it/s]


SHD = 58
sensitivity= 0.7883
specificity= 1.0

n_obs =  80 proportion intervened =  0.4
['4', '10', '21', '29', '31', '41', '44', '54', '60', '71', '94', '95', '102', '114', '115', '117', '120', '125', '135', '139', '148', '159', '178', '193', '221', '222', '226', '228', '232', '250', '272', '273', 'null']


100%|███████████████████████████████████████████| 32/32 [00:00<00:00, 66.17it/s]


SHD = 48
sensitivity= 0.7714
specificity= 1.0

n_obs =  80 proportion intervened =  0.5
['0', '3', '4', '7', '8', '18', '28', '29', '35', '47', '53', '57', '74', '77', '87', '93', '95', '107', '108', '113', '119', '122', '127', '128', '133', '135', '159', '167', '176', '178', '183', '187', '203', '207', '227', '229', '230', '250', '261', '264', 'null']


100%|███████████████████████████████████████████| 40/40 [00:00<00:00, 65.46it/s]


SHD = 40
sensitivity= 0.8444
specificity= 1.0

n_obs =  80 proportion intervened =  0.6000000000000001
['7', '15', '29', '31', '33', '38', '39', '41', '42', '43', '47', '59', '71', '79', '89', '91', '93', '97', '99', '113', '118', '119', '128', '130', '131', '133', '138', '143', '146', '162', '185', '194', '197', '198', '204', '210', '211', '215', '221', '224', '232', '247', '251', '255', '268', '269', '270', '277', 'null']


100%|███████████████████████████████████████████| 48/48 [00:00<00:00, 91.61it/s]


SHD = 21
sensitivity= 0.8618
specificity= 1.0

n_obs =  80 proportion intervened =  0.7000000000000001
['1', '2', '4', '5', '8', '11', '13', '18', '26', '29', '33', '44', '58', '59', '64', '68', '73', '84', '96', '101', '102', '105', '108', '114', '122', '124', '137', '143', '151', '153', '154', '156', '157', '166', '169', '171', '173', '174', '175', '179', '180', '184', '191', '204', '220', '221', '233', '236', '240', '246', '250', '253', '258', '261', '271', '273', 'null']


100%|███████████████████████████████████████████| 56/56 [00:00<00:00, 90.72it/s]


SHD = 24
sensitivity= 0.8509
specificity= 1.0

n_obs =  80 proportion intervened =  0.8
['4', '13', '14', '18', '19', '20', '24', '29', '31', '38', '41', '42', '46', '50', '54', '58', '67', '79', '83', '84', '87', '101', '106', '111', '113', '118', '120', '121', '128', '132', '137', '139', '147', '149', '162', '164', '165', '172', '179', '180', '181', '183', '186', '189', '190', '197', '204', '209', '215', '221', '227', '230', '234', '236', '237', '239', '241', '244', '249', '254', '257', '266', '273', '278', 'null']


100%|███████████████████████████████████████████| 64/64 [00:01<00:00, 46.75it/s]


SHD = 21
sensitivity= 0.8934
specificity= 1.0

n_obs =  80 proportion intervened =  0.9
['3', '7', '10', '12', '19', '32', '42', '50', '55', '56', '57', '59', '63', '68', '69', '70', '72', '80', '85', '99', '100', '101', '107', '108', '111', '118', '121', '131', '141', '142', '143', '144', '145', '148', '149', '150', '156', '157', '163', '166', '170', '172', '173', '176', '178', '180', '181', '183', '187', '191', '196', '201', '210', '213', '214', '224', '225', '229', '232', '233', '237', '241', '242', '245', '249', '252', '255', '263', '264', '267', '268', '272', 'null']


100%|███████████████████████████████████████████| 72/72 [00:01<00:00, 58.96it/s]


SHD = 7
sensitivity= 0.9662
specificity= 1.0

n_obs =  80 proportion intervened =  1.0
['1', '2', '13', '17', '21', '22', '24', '29', '38', '41', '54', '55', '58', '60', '63', '64', '66', '67', '68', '72', '75', '77', '79', '83', '86', '87', '89', '100', '102', '112', '118', '120', '121', '122', '125', '126', '128', '131', '133', '136', '137', '138', '142', '154', '158', '159', '160', '161', '166', '168', '169', '170', '174', '179', '181', '194', '198', '204', '212', '213', '217', '219', '222', '226', '227', '229', '231', '238', '241', '243', '246', '249', '251', '257', '261', '263', '264', '267', '268', '277', 'null']


100%|███████████████████████████████████████████| 80/80 [00:01<00:00, 75.06it/s]

SHD = 1
sensitivity= 0.994
specificity= 1.0






### Notes
- There seems to be a 'critical threshold' for causal discovery: the learned graph is reliable and accurate whenever we observe more than ~28% of the network (see the update further down: this turned out to be caused by a mistake in the SHD computation).
  - Why this happens is unclear, but somehow the signal of the observed neurons is not strong enough with small samples. Could it be that:
    - some set of 'critical nodes' (e.g. high-degree nodes) is almost always included in the sampled data, or
    - the number of confounders relative to the number of observed nodes is reduced?
- Need to investigate:
  - the exact number of observed nodes required for the threshold to be reached;
  - the out- and in-degree distributions of the observed nodes.

In [60]:
# Candidate neurons for sampling: all 279 indices except a hand-picked set of
# nodes to exclude (presumably the 'critical' nodes discussed in the notes).
# NOTE(review): the next cell samples from range(128) and does not use this list.
excluded_nodes = {0, 6, 12, 98, 111, 129, 134, 142, 143, 146, 188, 230, 238, 270}
sample_space = [idx for idx in np.arange(279) if idx not in excluded_nodes]

print(len(sample_space))

265


In [61]:
# Sweep the observed-network size (72..80 neurons): each sample draws from the
# neurons with index < 128 and always adds neurons 146 and 188. Every observed
# neuron is stimulated once (full single-node protocol).
for n_obs in np.arange(70, 80, 2):
    sampled = np.random.choice(128, size = n_obs, replace = False)
    index_obs = np.sort(list(sampled) + [146, 188])
    print('n_obs = ', len(index_obs))

    stimulation_protocol = [[i] for i in index_obs]
    spike_data_obs = {'null': spike_data['null'][index_obs]}
    for target in index_obs:
        key = str(target)
        spike_data_obs[key] = spike_data[key][index_obs]

    # Count hidden neurons with edges into at least two observed neurons.
    G_obs = nx.subgraph(G, index_obs)
    observed_set = set(index_obs)
    index_hidden = [node for node in range(n_neurons) if node not in observed_set]
    confounders = [
        hidden
        for hidden in index_hidden
        if sum(v in observed_set for _, v in G.out_edges(hidden)) >= 2
    ]
    print('num. confounding variables = ', len(confounders))

    G_learned = SCM_learner(
        spike_data_obs,
        alpha=0.01,
        node_list=index_obs,
        stimulation_protocol=stimulation_protocol,
    )

    G_true = nx.subgraph(G, index_obs)

    print('SHD =', compute_SHD(G_true, G_learned))
    print('sensitivity=', np.round(compute_sensitivity(G_true, G_learned, nodelist=index_obs), 4))
    print('specificity=', np.round(compute_specificity(G_true, G_learned, nodelist=index_obs), 4))
    print('')

n_obs =  72
num. confounding variables =  95


100%|██████████████████████████████████████████| 72/72 [00:00<00:00, 108.26it/s]


SHD = 0
sensitivity= 1.0
specificity= 1.0

n_obs =  74
num. confounding variables =  104


100%|███████████████████████████████████████████| 74/74 [00:00<00:00, 81.00it/s]


SHD = 0
sensitivity= 1.0
specificity= 1.0

n_obs =  76
num. confounding variables =  112


100%|██████████████████████████████████████████| 76/76 [00:00<00:00, 109.90it/s]


SHD = 0
sensitivity= 1.0
specificity= 1.0

n_obs =  78
num. confounding variables =  94


100%|██████████████████████████████████████████| 78/78 [00:00<00:00, 104.35it/s]


SHD = 1
sensitivity= 0.9928
specificity= 1.0

n_obs =  80
num. confounding variables =  104


100%|███████████████████████████████████████████| 80/80 [00:00<00:00, 91.74it/s]

SHD = 0
sensitivity= 1.0
specificity= 1.0






# Notes
- When high-degree nodes are included in the dataset, the learner struggles below the 28% threshold; if they are not included, it still works. There may be other 'critical nodes' to consider, such as nodes 0 and 6. Node 129 seems to be a problem.
- Removing all nodes with in-degree > 20 from the observed data also gives bad results.
- This can be useful for experimental design: check how many experiments are needed to get good results when high-degree nodes are targeted first.

### Update
- I was making a mistake when computing SHD; it turns out there is no threshold. This is surprising, because yesterday it seemed like some nodes were making a big difference.
- A complete single-node stimulation protocol now gives an essentially perfect result for any size of observed network. That is in line with what we should expect, given that the model captures every feature of the spiking pattern and we correct for any confounding.