In [None]:
from collections import OrderedDict
from math import isclose
from pprint import pprint

import numpy as np
from scipy.stats import truncexpon, expon, skewnorm
from scipy.stats import norm, uniform
import matplotlib.pyplot as plt
import seaborn as sns

from neuroconnect.connectivity_patterns import OutgoingDistributionConnections
from neuroconnect.connect_math import (
    sample_from_dist,
    random_draw_dist,
    expected_unique,
    apply_fn_to_dist,
    hypergeometric_pmf,
    combine_dists,
    discretised_rv,
    get_dist_mean,
    get_dist_var,
)
from neuroconnect.mpf_connection import CombProb
from neuroconnect.simple_graph import (
    create_graph,
    matrix_vis,
    to_matrix,
    find_connected_limited,
    reverse,
)
from neuroconnect.monte_carlo import (
    monte_carlo,
    list_to_df,
    summarise_monte_carlo,
    get_distribution,
)
# NOTE: this import rebinds the `list_to_df` name imported from
# neuroconnect.monte_carlo above; the skm_pyutils version is the one in effect.
from skm_pyutils.table import list_to_df

In [None]:
# CONFIG
# --- Network layout -------------------------------------------------------
# Node ids: region 1 occupies 0..9999, region 2 occupies 10000..20999.
region1_nodes = np.arange(10_000)
region2_nodes = np.arange(10_000, 21_000)
# Number of region 1 nodes that send connections.
num_region1_senders = 4_000

# --- Outgoing connection distribution -------------------------------------
# Upper limit used when building the truncated exponential distribution.
max_outgoing_connections = 10_000
# Intended scale of the exponential distribution.
exp_scale = 100

# --- Sampling / estimation ------------------------------------------------
# Nodes sampled per region: first entry for region 1, second for region 2.
num_samples = [40, 40]
subsample_rate = None
clt_start = 50
num_monte_carlo_iters = 50_000

# --- Visualisation --------------------------------------------------------
do_matrix_visualisation = True
# Passed as `k_size` to the matrix visualisation.
smoothing_win_size = 20

# Seed NumPy's global RNG so that the random draws below are repeatable.
np.random.seed(42)

In [None]:
# 2. Create a very skewed connections distribution and link it to a graph.
# scipy's truncexpon expresses its shape parameter `b` in units of `scale`:
# the support is [loc, loc + b * scale]. Dividing the intended cut-off by the
# scale therefore truncates at `max_outgoing_connections` itself, and the
# configured `exp_scale` is used instead of a duplicated literal.
rv = truncexpon(max_outgoing_connections / exp_scale, scale=exp_scale, loc=0)

# Discretise the continuous variable (arguments 0 and 2000 are the bounds
# handed to discretised_rv).
skewed_dist = discretised_rv(rv, 0, 2000)
# Print the last value stored in the discretised distribution.
print(list(skewed_dist.values())[-1])
mean_, var_ = get_dist_mean(skewed_dist), get_dist_var(skewed_dist)

# Connection pattern from region 1 to region 2 driven by the skewed distribution.
connection_instance = OutgoingDistributionConnections(
    region1_nodes, region2_nodes, skewed_dist, num_region1_senders
)

def check_dist_sum(d):
    """Plot the mapping ``d`` as a curve: values (y) against keys (x).

    Despite the name, no sum is computed or checked; the helper only draws
    the curve through ``plt.plot`` on the current pyplot axes.

    Args:
        d: Mapping whose keys are used as x coordinates and whose values are
            used as y coordinates.
    """
    x_values = list(d.keys())
    y_values = list(d.values())
    plt.plot(x_values, y_values)
# Container filled by later cells with the "mpf" and "graph" results.
result = {}

# Mean / variance of the skewed distribution, then its curve.
print(mean_, var_)
check_dist_sum(skewed_dist)

# Additional discretised distributions, each reported (mean, variance) and
# drawn with the same helper.
uniform_dist = discretised_rv(uniform(scale=1000, loc=0), 0, 1000)
print(get_dist_mean(uniform_dist), get_dist_var(uniform_dist))
check_dist_sum(uniform_dist)

norm_dist = discretised_rv(norm(loc=400, scale=400), 0, 2000)
print(get_dist_mean(norm_dist), get_dist_var(norm_dist))
check_dist_sum(norm_dist)

skewnorm_dist = discretised_rv(skewnorm(loc=0, scale=600, a=40), 0, 4000)
print(get_dist_mean(skewnorm_dist), get_dist_var(skewnorm_dist))
check_dist_sum(skewnorm_dist)

In [None]:
# 3a. Test whether pulling samples of neurons that send connections follows
# the CLT - the stats part.
# Keyword arguments forwarded to CombProb through ** (in addition to its
# positional arguments below).
delta_params = {
    "num_start": len(region1_nodes),
    "num_end": len(region2_nodes),
    "num_senders": num_region1_senders,
    "out_connections_dist": skewed_dist,
    "total_samples": num_samples[0],
    "clt_start": clt_start,
    "sub": subsample_rate,
}
connection_prob = CombProb(
    len(region1_nodes),
    num_samples[0],
    num_region1_senders,
    len(region2_nodes),
    num_samples[1],
    OutgoingDistributionConnections.static_expected_connections,
    subsample_rate=subsample_rate,
    approx_hypergeo=False,
    **delta_params,
)
# Store the statistical results; "each_expected" holds one entry for every
# count from 0 up to and including the region 1 sample size.
result["mpf"] = dict(
    expected=connection_prob.expected_connections(),
    total=connection_prob.get_all_prob(),
    each_expected={
        n_sampled: connection_prob.expected_total(n_sampled)
        for n_sampled in range(num_samples[0] + 1)
    },
)

In [None]:
# 3b the graph part
# Build the graph from the configured connection pattern.
g_graph, g_connected = connection_instance.create_connections()
# Append one empty list per region 2 node (presumably so region 2 nodes have
# an entry with no outgoing connections - confirm against create_connections).
g_graph.extend([] for _ in range(len(region2_nodes)))
# Reversed graph, computed once and reused by every Monte Carlo evaluation.
g_reverse_graph = reverse(g_graph)


def random_var_gen(iter_val):
    """Draw the random inputs for one Monte Carlo iteration.

    Args:
        iter_val: Iteration value supplied by the caller; not used here.

    Returns:
        tuple: ``(graph, sources, targets)`` where ``graph`` is the shared
        module-level ``g_graph``, ``sources`` holds ``num_samples[0]`` distinct
        entries of ``region1_nodes`` and ``targets`` holds ``num_samples[1]``
        distinct entries of ``region2_nodes``.
    """
    sources = np.random.choice(region1_nodes, num_samples[0], replace=False)
    # Index 1 (not -1) so the simulated target sample size is the same entry
    # that the statistical estimate is built with.
    targets = np.random.choice(region2_nodes, num_samples[1], replace=False)

    return g_graph, sources, targets


def fn_to_eval(graph, sources, targets):
    """Evaluate one Monte Carlo draw.

    Calls ``find_connected_limited`` with ``max_depth=1`` and the precomputed
    module-level ``g_reverse_graph``, then counts the entries it returns
    (presumably the targets directly connected to the sources - confirm
    against ``find_connected_limited``).

    Args:
        graph: Graph passed through to ``find_connected_limited``.
        sources: Sampled source nodes.
        targets: Sampled target nodes.

    Returns:
        tuple: A 1-tuple holding the number of entries returned.
    """
    connected_found = find_connected_limited(
        graph,
        sources,
        targets,
        max_depth=1,
        reverse_graph=g_reverse_graph,
    )
    n_connected = len(connected_found)
    return (n_connected,)


# Run the simulation: each iteration draws inputs with random_var_gen and
# evaluates them with fn_to_eval.
mc_res = monte_carlo(
    fn_to_eval,
    random_var_gen,
    num_monte_carlo_iters,
    num_cpus=1,
    headers=["Connections"],
    save_name="graph_mc.csv",
    save_every=10000,
    progress=True,
)
df = list_to_df(mc_res, ["Connections"])
# From here on `mc_res` holds the summary rather than the raw results.
mc_res = summarise_monte_carlo(
    df, to_plot=["Connections"], plt_outfile="graph_dist.png"
)
distrib = get_distribution(df, "Connections", num_monte_carlo_iters)

if do_matrix_visualisation:
    # random_var_gen always hands back g_graph, so use it directly rather than
    # drawing (and discarding) a random sample: toggling the visualisation
    # flag then no longer changes the state of the global random generator.
    graph = g_graph
    # Only AB is visualised; the other three matrices are not used below.
    AB, BA, AA, BB = to_matrix(graph, len(region1_nodes), len(region2_nodes))
    matrix_vis(
        AB, None, None, None, k_size=smoothing_win_size, name="graph_mat_vis.pdf"
    )

result["graph"] = {"full_results": df, "summary_stats": mc_res, "dist": distrib}
plt.show()


In [None]:
# 4. Compare the graph and stats
fig, ax = plt.subplots()

# (key in `result`, field holding the distribution, line styling) per curve.
curves = [
    ("mpf", "total", dict(c="k", label="Statistical estimation")),
    ("graph", "dist", dict(c="b", linestyle="--", label="Monte Carlo simulation")),
]
for source, field, line_kwargs in curves:
    # Either computation may have been skipped; plot only what is available.
    if source not in result:
        continue
    dist = result[source][field]
    ax.plot(list(dist.keys()), list(dist.values()), **line_kwargs)

# Label the axes so the figure can be read on its own.
ax.set_xlabel("Number of recorded connected neurons")
ax.set_ylabel("Probability")
ax.legend()
plt.show()

# Same availability guard as for the plot, so a missing entry does not raise
# KeyError here.
for source in ("graph", "mpf"):
    if source in result:
        pprint(result[source])

In [None]:
# Long-format table with one row per (count, calculation) pair, used by the
# seaborn comparison plot below.
mpf_res = result["mpf"]["total"]
graph_res = result["graph"]["dist"]

dist_list = []
for n_connected, stat_prob in mpf_res.items():
    # Counts missing from the simulated distribution get probability 0.
    mc_prob = graph_res.get(n_connected, 0)
    dist_list += [
        [n_connected, stat_prob, "Statistical estimation"],
        [n_connected, mc_prob, "Monte Carlo simulation"],
    ]

cols = [
    "Number of recorded connected neurons",
    "Probability",
    "Calculation",
]
df = list_to_df(dist_list, headers=cols)

fig, ax = plt.subplots(1, 2)

# Left panel: the outgoing connection distribution.
forward_counts = list(skewed_dist.keys())
forward_probs = list(skewed_dist.values())
ax[0].plot(forward_counts, forward_probs)
ax[0].set_xlabel("Forward connections")
ax[0].set_ylabel("Probability")
sns.despine()

# Right panel: statistical estimate against the Monte Carlo simulation.
sns.lineplot(
    data=df,
    ax=ax[1],
    x=cols[0],
    y=cols[1],
    hue=cols[2],
    style=cols[2],
)
sns.despine()