In [1]:
import torch
import pickle
import matplotlib.pyplot as plt
import numpy as np 
import statsmodels.api as sm
import networkx as nx
import seaborn as sns

from spikeometric.models import BernoulliGLM
from spikeometric.datasets import NormalGenerator, ConnectivityDataset
from spikeometric.stimulus import RegularStimulus

from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj, to_networkx, from_networkx
from CD_methods import SCM_learner

from tqdm import tqdm

In [2]:
network_data = torch.load('data/c_elegans_data.pt')

with open('data/c_elegans_spike_data_single_node_stimuli.pickle', 'rb') as f:
    spike_data = pickle.load(f)

In [17]:
n_neurons = network_data.num_nodes

G = to_networkx(network_data, node_attrs = ['position'])
position_dict = nx.get_node_attributes(G, 'position')

# sample neurons
n_obs = 90
index_obs = np.sort(np.random.choice(n_neurons, size = n_obs, replace = False))

stimulation_protocol = [[i] for i in index_obs]
spike_data_obs = dict()
spike_data_obs['null'] = spike_data['null'][index_obs]
for intervention in index_obs:
    spike_data_obs[str(intervention)] = spike_data[str(intervention)][index_obs]

In [20]:
G_obs = nx.subgraph(G, index_obs)
index_hidden = [node for node in range(n_neurons) if node not in index_obs]
confounders = []
for node in index_hidden:
    count = 0
    for _, v in G.out_edges(node):
        if v in index_obs:
            count += 1
    if count >= 2:
        confounders.append(node)
print('num. confounding variables = ',len(confounders))

num. confounding variables =  101


In [23]:
G_learned = SCM_learner(spike_data_obs, 
                        node_list=index_obs, 
                        stimulation_protocol=stimulation_protocol, 
                        alpha = 0.01)

100%|███████████████████████████████████████████| 90/90 [00:01<00:00, 74.39it/s]


In [24]:
G_true = nx.subgraph(G, index_obs)

A_true = nx.adjacency_matrix(G_true).todense() 
A_learned = nx.adjacency_matrix(G_learned).todense() 
A_diff = A_true - A_learned

SHD = np.sum(np.abs(A_true- A_learned))
TP = np.sum( (A_true == 1)*(A_learned==1) )
TN = np.sum( (A_true == 0)*(A_learned==0)) 
FP = np.sum(A_diff == -1)
FN = np.sum(A_diff == 1) 

index_hidden = [node for node in range(n_neurons) if node not in index_obs]
confounders = []
for node in index_hidden:
    count = 0
    for _, v in G.out_edges(node):
        if v in index_obs:
            count += 1
    if count >= 2:
        confounders.append(node)
print('num. confounding variables = ', len(confounders))

print('total edges (in true observed graph) = ',G_true.number_of_edges())
print('percentage of nodes observed = ', np.round(G_obs.number_of_nodes() / G.number_of_nodes() * 100, 2), '%')
print('SHD = ', SHD)
print('True positives = ', TP, ', Sensitivity = ', np.round(TP / (TP + FP) * 100, 2) )
print('True negatives = ', TN, ', Specificity = ', np.round(TN / (TN + FN)*100, 2) )
print('False positives = ', FP )
print('False negatives = ', FN )

num. confounding variables =  101
total edges (in true observed graph) =  190
percentage of nodes observed =  32.26 %
SHD =  0
True positives =  190 , Sensitivity =  100.0
True negatives =  7910 , Specificity =  100.0
False positives =  0
False negatives =  0


In [31]:
for n_obs in np.arange(10, 279, 10):
    print('n_obs = ', n_obs)
    index_obs = np.sort(np.random.choice(n_neurons, size = n_obs, replace = False))
    stimulation_protocol = [[i] for i in index_obs]
    spike_data_obs = dict()
    spike_data_obs['null'] = spike_data['null'][index_obs]
    for intervention in index_obs:
        spike_data_obs[str(intervention)] = spike_data[str(intervention)][index_obs]
    
    # count num. confounders
    G_obs = nx.subgraph(G, index_obs)
    index_hidden = [node for node in range(n_neurons) if node not in index_obs]
    confounders = []
    for node in index_hidden:
        count = 0
        for _, v in G.out_edges(node):
            if v in index_obs:
                count += 1
        if count >= 2:
            confounders.append(node)
    print('num. confounding variables = ',len(confounders))
    
    G_learned = SCM_learner(spike_data_obs, 
                        node_list=index_obs, 
                        stimulation_protocol=stimulation_protocol, 
                        alpha = 0.01)
    
    G_true = nx.subgraph(G, index_obs)

    A_true = nx.adjacency_matrix(G_true).todense() 
    A_learned = nx.adjacency_matrix(G_learned).todense() 
    A_diff = A_true - A_learned

    SHD = np.sum(np.abs(A_true- A_learned))
    TP = np.sum( (A_true == 1)*(A_learned==1) )
    TN = np.sum( (A_true == 0)*(A_learned==0)) 
    FP = np.sum(A_diff == -1)
    FN = np.sum(A_diff == 1) 

    print('percentage of nodes observed = ', np.round(G_obs.number_of_nodes() / G.number_of_nodes() * 100, 2), '%')
    print('SHD = ', SHD)
    print('True positives = ', TP, ', Sensitivity = ', np.round(TP / (TP + FP) * 100, 2) )
    print('True negatives = ', TN, ', Specificity = ', np.round(TN / (TN + FN)*100, 2) )
    print('False positives = ', FP )
    print('False negatives = ', FN )
    print('')

n_obs =  10
num. confounding variables =  16


100%|█████████████████████████████████████████| 10/10 [00:00<00:00, 3222.92it/s]

percentage of nodes observed =  3.58 %
SHD =  2
True positives =  0 , Sensitivity =  0.0
True negatives =  98 , Specificity =  98.99
False positives =  1
False negatives =  1

n_obs =  20
num. confounding variables =  43



100%|██████████████████████████████████████████| 20/20 [00:00<00:00, 346.48it/s]


percentage of nodes observed =  7.17 %
SHD =  28
True positives =  0 , Sensitivity =  0.0
True negatives =  372 , Specificity =  96.37
False positives =  14
False negatives =  14

n_obs =  30
num. confounding variables =  59


100%|██████████████████████████████████████████| 30/30 [00:00<00:00, 727.03it/s]


percentage of nodes observed =  10.75 %
SHD =  34
True positives =  0 , Sensitivity =  0.0
True negatives =  866 , Specificity =  98.07
False positives =  17
False negatives =  17

n_obs =  40
num. confounding variables =  76


100%|██████████████████████████████████████████| 40/40 [00:00<00:00, 239.13it/s]


percentage of nodes observed =  14.34 %
SHD =  82
True positives =  3 , Sensitivity =  6.82
True negatives =  1515 , Specificity =  97.37
False positives =  41
False negatives =  41

n_obs =  50
num. confounding variables =  75


100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 149.81it/s]


percentage of nodes observed =  17.92 %
SHD =  129
True positives =  0 , Sensitivity =  0.0
True negatives =  2371 , Specificity =  97.37
False positives =  65
False negatives =  64

n_obs =  60
num. confounding variables =  98


100%|██████████████████████████████████████████| 60/60 [00:00<00:00, 119.37it/s]


percentage of nodes observed =  21.51 %
SHD =  174
True positives =  5 , Sensitivity =  5.43
True negatives =  3421 , Specificity =  97.52
False positives =  87
False negatives =  87

n_obs =  70
num. confounding variables =  89


100%|██████████████████████████████████████████| 70/70 [00:00<00:00, 137.95it/s]


percentage of nodes observed =  25.09 %
SHD =  190
True positives =  2 , Sensitivity =  2.06
True negatives =  4708 , Specificity =  98.02
False positives =  95
False negatives =  95

n_obs =  80
num. confounding variables =  107


100%|███████████████████████████████████████████| 80/80 [00:01<00:00, 47.70it/s]


percentage of nodes observed =  28.67 %
SHD =  1
True positives =  232 , Sensitivity =  99.57
True negatives =  6167 , Specificity =  100.0
False positives =  1
False negatives =  0

n_obs =  90
num. confounding variables =  103


100%|███████████████████████████████████████████| 90/90 [00:01<00:00, 50.26it/s]


percentage of nodes observed =  32.26 %
SHD =  0
True positives =  254 , Sensitivity =  100.0
True negatives =  7846 , Specificity =  100.0
False positives =  0
False negatives =  0

n_obs =  100
num. confounding variables =  109


100%|█████████████████████████████████████████| 100/100 [00:01<00:00, 51.22it/s]


percentage of nodes observed =  35.84 %
SHD =  1
True positives =  268 , Sensitivity =  99.63
True negatives =  9731 , Specificity =  100.0
False positives =  1
False negatives =  0

n_obs =  110
num. confounding variables =  111


100%|█████████████████████████████████████████| 110/110 [00:02<00:00, 45.68it/s]


percentage of nodes observed =  39.43 %
SHD =  4
True positives =  333 , Sensitivity =  99.11
True negatives =  11763 , Specificity =  99.99
False positives =  3
False negatives =  1

n_obs =  120
num. confounding variables =  101


100%|█████████████████████████████████████████| 120/120 [00:03<00:00, 36.30it/s]


percentage of nodes observed =  43.01 %
SHD =  1
True positives =  410 , Sensitivity =  99.76
True negatives =  13989 , Specificity =  100.0
False positives =  1
False negatives =  0

n_obs =  130
num. confounding variables =  102


100%|█████████████████████████████████████████| 130/130 [00:04<00:00, 26.76it/s]


percentage of nodes observed =  46.59 %
SHD =  5
True positives =  521 , Sensitivity =  99.43
True negatives =  16374 , Specificity =  99.99
False positives =  3
False negatives =  2

n_obs =  140
num. confounding variables =  94


100%|█████████████████████████████████████████| 140/140 [00:05<00:00, 23.37it/s]


percentage of nodes observed =  50.18 %
SHD =  2
True positives =  635 , Sensitivity =  100.0
True negatives =  18963 , Specificity =  99.99
False positives =  0
False negatives =  2

n_obs =  150
num. confounding variables =  94


100%|█████████████████████████████████████████| 150/150 [00:06<00:00, 21.66it/s]


percentage of nodes observed =  53.76 %
SHD =  3
True positives =  720 , Sensitivity =  99.59
True negatives =  21777 , Specificity =  100.0
False positives =  3
False negatives =  0

n_obs =  160
num. confounding variables =  91


100%|█████████████████████████████████████████| 160/160 [00:08<00:00, 18.27it/s]


percentage of nodes observed =  57.35 %
SHD =  2
True positives =  855 , Sensitivity =  99.88
True negatives =  24743 , Specificity =  100.0
False positives =  1
False negatives =  1

n_obs =  170
num. confounding variables =  83


100%|█████████████████████████████████████████| 170/170 [00:07<00:00, 23.09it/s]


percentage of nodes observed =  60.93 %
SHD =  3
True positives =  757 , Sensitivity =  99.61
True negatives =  28140 , Specificity =  100.0
False positives =  3
False negatives =  0

n_obs =  180
num. confounding variables =  76


100%|█████████████████████████████████████████| 180/180 [00:08<00:00, 20.40it/s]


percentage of nodes observed =  64.52 %
SHD =  5
True positives =  886 , Sensitivity =  99.44
True negatives =  31509 , Specificity =  100.0
False positives =  5
False negatives =  0

n_obs =  190
num. confounding variables =  74


100%|█████████████████████████████████████████| 190/190 [00:10<00:00, 18.53it/s]


percentage of nodes observed =  68.1 %
SHD =  3
True positives =  938 , Sensitivity =  99.68
True negatives =  35159 , Specificity =  100.0
False positives =  3
False negatives =  0

n_obs =  200
num. confounding variables =  67


100%|█████████████████████████████████████████| 200/200 [00:12<00:00, 16.52it/s]


percentage of nodes observed =  71.68 %
SHD =  1
True positives =  1061 , Sensitivity =  100.0
True negatives =  38938 , Specificity =  100.0
False positives =  0
False negatives =  1

n_obs =  210
num. confounding variables =  54


100%|█████████████████████████████████████████| 210/210 [00:20<00:00, 10.37it/s]


percentage of nodes observed =  75.27 %
SHD =  6
True positives =  1281 , Sensitivity =  99.61
True negatives =  42813 , Specificity =  100.0
False positives =  5
False negatives =  1

n_obs =  220
num. confounding variables =  48


100%|█████████████████████████████████████████| 220/220 [00:22<00:00,  9.88it/s]


percentage of nodes observed =  78.85 %
SHD =  5
True positives =  1415 , Sensitivity =  99.65
True negatives =  46980 , Specificity =  100.0
False positives =  5
False negatives =  0

n_obs =  230
num. confounding variables =  37


100%|█████████████████████████████████████████| 230/230 [00:24<00:00,  9.48it/s]


percentage of nodes observed =  82.44 %
SHD =  6
True positives =  1484 , Sensitivity =  99.66
True negatives =  51410 , Specificity =  100.0
False positives =  5
False negatives =  1

n_obs =  240
num. confounding variables =  35


100%|█████████████████████████████████████████| 240/240 [00:23<00:00, 10.11it/s]


percentage of nodes observed =  86.02 %
SHD =  11
True positives =  1545 , Sensitivity =  99.36
True negatives =  56044 , Specificity =  100.0
False positives =  10
False negatives =  1

n_obs =  250
num. confounding variables =  24


KeyboardInterrupt: 

### Notes
- Seems to be a 'critical threshold' for causal discovery, where the learned graph is reliable and accurate whenever we observe more than 28 % of the network 
 - Why this happens is unclear. But somehow the signal of the observed neurons is not strong enough with small samples. 
     Could it be:
     - Some set of 'critical nodes' are almost always included in the sampled data, fx high degree nodes
     - The number of confounders relative to observed nodes is reduced
- Need to investigate: 
    - The exact number of nodes required for threshold to be reached
    - Maybe look at the out- and in-degree distribution of nodes observed

In [62]:
sample_space = list(np.arange(279))
sample_space.remove(0)
sample_space.remove(6)
sample_space.remove(12)
sample_space.remove(98)
sample_space.remove(111)
sample_space.remove(129)
sample_space.remove(134)
sample_space.remove(142)
sample_space.remove(230)
sample_space.remove(238)
sample_space.remove(143)
sample_space.remove(146)
sample_space.remove(188)
sample_space.remove(270)

print(len(sample_space))

267


In [71]:
for n_obs in np.arange(70, 80, 2):
    index_obs = list(np.random.choice(128, size = n_obs, replace = False))
    index_obs += [129]
    index_obs=np.sort(index_obs)
    print('n_obs = ', len(index_obs))
    
    stimulation_protocol = [[i] for i in index_obs]
    spike_data_obs = dict()
    spike_data_obs['null'] = spike_data['null'][index_obs]
    for intervention in index_obs:
        spike_data_obs[str(intervention)] = spike_data[str(intervention)][index_obs]
    
    # count num. confounders
    G_obs = nx.subgraph(G, index_obs)
    index_hidden = [node for node in range(n_neurons) if node not in index_obs]
    confounders = []
    for node in index_hidden:
        count = 0
        for _, v in G.out_edges(node):
            if v in index_obs:
                count += 1
        if count >= 2:
            confounders.append(node)
    print('num. confounding variables = ',len(confounders))
    
    G_learned = SCM_learner(spike_data_obs, 
                        node_list=index_obs, 
                        stimulation_protocol=stimulation_protocol, 
                        alpha = 0.01)
    
    G_true = nx.subgraph(G, index_obs)

    A_true = nx.adjacency_matrix(G_true).todense() 
    A_learned = nx.adjacency_matrix(G_learned).todense() 
    A_diff = A_true - A_learned

    SHD = np.sum(np.abs(A_true- A_learned))
    TP = np.sum( (A_true == 1)*(A_learned==1) )
    TN = np.sum( (A_true == 0)*(A_learned==0)) 
    FP = np.sum(A_diff == -1)
    FN = np.sum(A_diff == 1) 

    print('percentage of nodes observed = ', np.round(G_obs.number_of_nodes() / G.number_of_nodes() * 100, 2), '%')
    print('SHD = ', SHD)
    print('True positives = ', TP, ', Sensitivity = ', np.round(TP / (TP + FP) * 100, 2) )
    print('True negatives = ', TN, ', Specificity = ', np.round(TN / (TN + FN)*100, 2) )
    print('False positives = ', FP )
    print('False negatives = ', FN )
    print('')

n_obs =  71
num. confounding variables =  103


100%|███████████████████████████████████████████| 71/71 [00:00<00:00, 80.70it/s]


percentage of nodes observed =  25.45 %
SHD =  270
True positives =  5 , Sensitivity =  3.57
True negatives =  4766 , Specificity =  97.25
False positives =  135
False negatives =  135

n_obs =  73
num. confounding variables =  84


100%|███████████████████████████████████████████| 73/73 [00:00<00:00, 87.02it/s]


percentage of nodes observed =  26.16 %
SHD =  284
True positives =  0 , Sensitivity =  0.0
True negatives =  5045 , Specificity =  97.26
False positives =  142
False negatives =  142

n_obs =  75
num. confounding variables =  95


100%|███████████████████████████████████████████| 75/75 [00:00<00:00, 82.19it/s]


percentage of nodes observed =  26.88 %
SHD =  296
True positives =  5 , Sensitivity =  3.27
True negatives =  5324 , Specificity =  97.3
False positives =  148
False negatives =  148

n_obs =  77
num. confounding variables =  99


100%|███████████████████████████████████████████| 77/77 [00:00<00:00, 79.04it/s]


percentage of nodes observed =  27.6 %
SHD =  0
True positives =  152 , Sensitivity =  100.0
True negatives =  5777 , Specificity =  100.0
False positives =  0
False negatives =  0

n_obs =  79
num. confounding variables =  97


100%|███████████████████████████████████████████| 79/79 [00:00<00:00, 79.46it/s]

percentage of nodes observed =  28.32 %
SHD =  0
True positives =  177 , Sensitivity =  100.0
True negatives =  6064 , Specificity =  100.0
False positives =  0
False negatives =  0






# Notes
- when high degree nodes are included in dataset, the learner struggles below 28% threshold. if not included, it still works. might be other 'critical nodes' to consider, like 0 and 6. Node 129 seems to be a problem.
- removing all nodes with in_degree > 20 from observed data is also bad...
- This info can be useful wrt designing experiments! see how many experiments are needed to get good results when targeting high degree nodes first.