In [None]:
from collections import OrderedDict
from math import isclose
from pprint import pprint

import numpy as np
from scipy.stats import truncexpon, expon, skewnorm
import matplotlib.pyplot as plt

from neuroconnect.connectivity_patterns import ConnectionStrategy
from neuroconnect.connect_math import (
    sample_from_dist,
    random_draw_dist,
    expected_unique,
    apply_fn_to_dist,
    hypergeometric_pmf,
    combine_dists,
    discretised_rv,
)
from neuroconnect.mpf_connection import CombProb
from neuroconnect.simple_graph import (
    create_graph,
    matrix_vis,
    to_matrix,
    find_connected_limited,
    reverse,
)
from neuroconnect.monte_carlo import (
    monte_carlo,
    list_to_df,
    summarise_monte_carlo,
    get_distribution,
)


In [None]:
# CONFIG
# Every tunable used by the cells below lives here.

# Shape argument given to truncexpon and upper bound of the discretised
# out-degree distribution (see the "2." cell).
max_outgoing_connections = 10000
# Node ids of the sending region: 0 .. 9999.
region1_nodes = np.arange(0, 10000)
# Node ids of the receiving region: 10000 .. 20999.
region2_nodes = np.arange(10000, 21000)
# How many region1 nodes are picked (without replacement) to send connections.
num_region1_senders = 4000
# [number of region1 nodes sampled as sources, number of region2 nodes sampled as targets]
num_samples = [40, 40]
# Forwarded as `sub` / `subsample_rate` to the statistical routines.
subsample_rate = 0.01
# Forwarded as `clt_start` to random_draw_dist via static_expected_connections.
clt_start = 50
# Iteration count handed to monte_carlo and get_distribution.
num_monte_carlo_iters = 50000
# When True, the graph cell also renders the adjacency matrix with matrix_vis.
do_matrix_visualisation = True
# Passed as `k_size` to matrix_vis.
smoothing_win_size = 20
# `scale` parameter of the truncated exponential out-degree distribution.
exp_scale = 100

In [None]:
# 1. Define a connection strategy that uses only an outgoing connections distribution
class OutgoingDistributionConnections(ConnectionStrategy):
    """Outgoing-only connections whose per-sender count is drawn from a distribution.

    A fixed number of region1 nodes are chosen as senders; each sender gets a
    number of outgoing connections sampled from ``distribution`` and its
    targets are drawn from region2 with replacement.
    """

    def __init__(self, region1_nodes, region2_nodes, distribution, num_senders):
        """Store the strategy parameters.

        Parameters
        ----------
        region1_nodes : array-like
            Ids of the nodes that may send connections.
        region2_nodes : array-like
            Ids of the nodes that may receive connections.
        distribution : mapping
            Passed to ``sample_from_dist`` to draw the number of outgoing
            connections per sender (presumably ``{count: probability}`` —
            confirm against ``sample_from_dist``).
        num_senders : int
            Number of region1 nodes that send connections.
        """
        self.region1_nodes = region1_nodes
        self.region2_nodes = region2_nodes
        self.distribution = distribution
        self.num_senders = num_senders

    def create_connections(self):
        """Build a random forward adjacency list for region1.

        Returns
        -------
        graph : list
            One entry per node in ``region1_nodes``, in order. Senders get an
            ``np.int32`` array of region2 targets (duplicates possible because
            targets are drawn with replacement); non-senders get ``[]``.
        connected : numpy.ndarray
            The region1 nodes that were chosen as senders.
        """
        # Choose the forward connectors
        connected = np.random.choice(
            self.region1_nodes, size=self.num_senders, replace=False
        )
        num_choices_for_each_sender = sample_from_dist(
            self.distribution, self.num_senders
        )

        # One vectorised membership test instead of an O(num_senders) array
        # scan for every region1 node.
        is_sender = np.isin(self.region1_nodes, connected)

        graph = []
        f_idx = 0
        for sends in is_sender:
            if sends:
                forward_connection_subset = np.random.choice(
                    self.region2_nodes, num_choices_for_each_sender[f_idx], replace=True
                )
                # atleast_1d keeps the entry an array even if a scalar comes
                # back, regardless of the platform's default integer width.
                graph.append(
                    np.atleast_1d(forward_connection_subset).astype(np.int32)
                )
                f_idx += 1
            else:
                graph.append([])

        return graph, connected

    def expected_connections(self, num_samples, **kwargs):
        """Evaluate ``static_expected_connections`` with this instance's parameters.

        Parameters
        ----------
        num_samples : int
            Forwarded as ``total_samples``.
        **kwargs
            Forwarded unchanged (e.g. ``clt_start``, ``sub``).

        Returns
        -------
        tuple
            Whatever ``static_expected_connections`` returns.
        """
        return OutgoingDistributionConnections.static_expected_connections(
            len(self.region1_nodes),
            # Was ``self.region_nodes``, an attribute that is never set.
            len(self.region2_nodes),
            self.num_senders,
            self.distribution,
            num_samples,
            **kwargs
        )

    @staticmethod
    def static_expected_connections(
        num_start,
        num_end,
        num_senders,
        out_connections_dist,
        total_samples,
        clt_start=30,
        sub=0.01,
        **kwargs
    ):
        """Statistical estimate of the region2 nodes reached from sampled region1 nodes.

        Parameters
        ----------
        num_start : int
            Number of nodes in region1.
        num_end : int
            Number of nodes in region2.
        num_senders : int
            Number of region1 nodes that send connections.
        out_connections_dist : mapping
            Distribution of outgoing connections per sender.
        total_samples : int
            Number of region1 nodes that are sampled.
        clt_start : int, optional
            Forwarded to ``random_draw_dist``.
        sub : float, optional
            Forwarded to ``random_draw_dist`` and ``apply_fn_to_dist``.
        **kwargs
            Accepted and ignored, so callers may pass extra keywords.

        Returns
        -------
        ab_un_dist : OrderedDict
            The per-key output of ``random_draw_dist`` after mapping each
            distribution through ``expected_unique(num_end, k, do_round=True)``.
        weighted_dist
            Result of ``combine_dists`` over ``range(num_end + 1)``, using
            ``ab_un_dist`` and the hypergeometric weights built below.
        """

        def fn_to_apply(k):
            return expected_unique(num_end, k, do_round=True)

        ab_dist = random_draw_dist(
            total_samples,
            out_connections_dist,
            num_end,
            apply_fn=False,
            keep_all=True,
            clt_start=clt_start,
            sub=sub,
        )

        ab_un_dist = OrderedDict()
        for k, v in ab_dist.items():
            ab_un_dist[k] = apply_fn_to_dist(v, fn_to_apply, sub=sub)

        # Presumably P(exactly i of the total_samples sampled nodes are
        # senders) — confirm the argument order of hypergeometric_pmf.
        prob_a_senders = OrderedDict()
        for i in range(total_samples + 1):
            prob_a_senders[i] = float(
                hypergeometric_pmf(num_start, num_senders, total_samples, i)
            )

        weighted_dist = combine_dists(
            range(num_end + 1), ab_un_dist, prob_a_senders, sub=None
        )

        return ab_un_dist, weighted_dist


In [None]:
def discretised_rv(rv, min_, max_):
    """Bin a continuous random variable onto the integers ``min_`` .. ``max_``.

    Each integer ``v`` is assigned the probability mass that ``rv`` places on
    ``[v, v + 1]``, with both edges clamped to ``[min_, max_]``; the final key
    ``max_`` therefore always gets a zero-width bin.

    Note that this definition rebinds the ``discretised_rv`` name imported
    from ``neuroconnect.connect_math`` in the imports cell.

    Parameters
    ----------
    rv
        Any object exposing ``cdf(x)``.
    min_, max_ : int
        Inclusive integer bounds of the discretised support.

    Returns
    -------
    OrderedDict
        ``{value: probability mass}`` in increasing order of value.
    """

    def bin_mass(left_edge):
        upper = min(max_, left_edge + 1)
        lower = max(min_, left_edge)
        return rv.cdf(upper) - rv.cdf(lower)

    return OrderedDict(
        (left_edge, bin_mass(left_edge)) for left_edge in range(min_, max_ + 1)
    )

In [None]:
# 2. Create a very skewed connections distribution and link it to a graph.
# scipy's truncexpon takes its shape parameter `b` in units of `scale`: the
# support is [loc, loc + b * scale]. Dividing by the scale truncates the
# distribution at max_outgoing_connections, which is also the last bin used by
# discretised_rv, so the discretised mass sums to 1 for any exp_scale.
rv = truncexpon(max_outgoing_connections / exp_scale, scale=exp_scale, loc=0)
x = np.linspace(rv.ppf(0.01), rv.ppf(0.99), 100)

# Continuous density between the 1st and 99th percentiles.
plt.plot(x, rv.pdf(x))

skewed_dist = discretised_rv(rv, 0, max_outgoing_connections)

connection_instance = OutgoingDistributionConnections(
    region1_nodes, region2_nodes, skewed_dist, num_region1_senders
)

# Running total of the discretised probabilities; the last entry is the total
# mass and is checked against 1 below.
summed_vals = []
running_total = 0.0
for val in skewed_dist.values():
    running_total += val
    summed_vals.append(running_total)
plt.plot(list(skewed_dist.keys()), list(skewed_dist.values()))

if not isclose(1.0, summed_vals[-1]):
    print(f"ERROR distribution sums to {summed_vals[-1]}")

# Collects the outputs of the statistical ("mpf") and simulated ("graph") cells.
result = {}

In [None]:
# 3a. Statistical side: estimate the connection distribution analytically.
# Keyword arguments for static_expected_connections; they are handed to
# CombProb, which presumably forwards them to that callback (confirm).
delta_params = {
    "num_start": len(region1_nodes),
    "num_end": len(region2_nodes),
    "num_senders": num_region1_senders,
    "out_connections_dist": skewed_dist,
    "total_samples": num_samples[0],
    "clt_start": clt_start,
    "sub": subsample_rate,
}
connection_prob = CombProb(
    len(region1_nodes),
    num_samples[0],
    num_region1_senders,
    len(region2_nodes),
    num_samples[1],
    OutgoingDistributionConnections.static_expected_connections,
    subsample_rate=subsample_rate,
    approx_hypergeo=False,
    **delta_params,
)

# Query the estimator in the same order as the result entries are stored.
mpf_expected = connection_prob.expected_connections()
mpf_total = connection_prob.get_all_prob()
mpf_each_expected = {
    k: connection_prob.expected_total(k) for k in range(num_samples[0] + 1)
}
result["mpf"] = {
    "expected": mpf_expected,
    "total": mpf_total,
    "each_expected": mpf_each_expected,
}

In [None]:
# 3b. Graph side: build one concrete random graph to sample from.
g_graph, g_connected = connection_instance.create_connections()
# create_connections yields one adjacency entry per region1 node; add an empty
# entry for every region2 node as well before the graph is reversed.
g_graph.extend([] for _ in range(len(region2_nodes)))
g_reverse_graph = reverse(g_graph)


def random_var_gen(iter_val):
    """Produce the inputs for one Monte Carlo evaluation.

    The graph is the same module-level ``g_graph`` on every call; only the
    sampled source and target nodes change.

    Parameters
    ----------
    iter_val
        Not used. The parameter is kept because this function is handed to
        ``monte_carlo`` (presumably called with the iteration index — confirm).

    Returns
    -------
    tuple
        ``(graph, sources, targets)`` where ``sources`` are ``num_samples[0]``
        distinct region1 nodes and ``targets`` are ``num_samples[-1]`` distinct
        region2 nodes.
    """
    sources = np.random.choice(region1_nodes, num_samples[0], replace=False)
    targets = np.random.choice(region2_nodes, num_samples[-1], replace=False)

    return g_graph, sources, targets


def fn_to_eval(graph, sources, targets):
    """Count what ``find_connected_limited`` finds between sources and targets.

    The search is limited to ``max_depth=1`` and is given the module-level
    reversed graph ``g_reverse_graph``.

    Returns
    -------
    tuple
        A one-element tuple holding the number of items returned by
        ``find_connected_limited``.
    """
    found = find_connected_limited(
        graph, sources, targets, max_depth=1, reverse_graph=g_reverse_graph
    )
    num_found = len(found)
    return (num_found,)


# Run the simulation: monte_carlo is given the evaluation function and the
# sample generator (presumably it calls random_var_gen and feeds the result to
# fn_to_eval once per iteration — confirm against neuroconnect.monte_carlo).
# save_name / save_every / progress are forwarded as-is.
mc_res = monte_carlo(
    fn_to_eval,
    random_var_gen,
    num_monte_carlo_iters,
    num_cpus=1,
    headers=["Connections"],
    save_name="graph_mc.csv",
    save_every=10000,
    progress=True,
)
# Tabulate the raw results under the single "Connections" column.
df = list_to_df(mc_res, ["Connections"])
# NOTE: mc_res is rebound here from the raw results to the summary returned by
# summarise_monte_carlo; the raw values remain available through df.
mc_res = summarise_monte_carlo(
    df, to_plot=["Connections"], plt_outfile="graph_dist.png"
)
distrib = get_distribution(df, "Connections", num_monte_carlo_iters)

if do_matrix_visualisation:
    # random_var_gen returns the shared g_graph; the freshly drawn sources and
    # targets are discarded (the call still advances numpy's global RNG).
    graph, _, _ = random_var_gen(0)
    # Only the first of the four matrices returned by to_matrix is visualised.
    AB, BA, AA, BB = to_matrix(graph, len(region1_nodes), len(region2_nodes))
    matrix_vis(
        AB, None, None, None, k_size=smoothing_win_size, name="graph_mat_vis.pdf"
    )

# Store the simulated side next to result["mpf"] for the comparison cell.
result["graph"] = {"full_results": df, "summary_stats": mc_res, "dist": distrib}
plt.show()


In [None]:
# 4. Compare the graph and stats
# Either side may be missing if its cell was skipped, so both the plot and the
# printout only touch the entries that are actually present in `result`.
fig, ax = plt.subplots()
if "mpf" in result:
    mpf_dist = result["mpf"]["total"]
    ax.plot(
        list(mpf_dist.keys()),
        list(mpf_dist.values()),
        c="k",
        label="Statistical estimation",
    )
if "graph" in result:
    graph_dist = result["graph"]["dist"]
    ax.plot(
        list(graph_dist.keys()),
        list(graph_dist.values()),
        c="b",
        linestyle="--",
        label="Monte Carlo simulation",
    )
ax.set_xlabel("Connections")
ax.set_ylabel("Probability")
ax.legend()
plt.show()

# Same order as before: simulated results first, then the statistical ones.
for result_key in ("graph", "mpf"):
    if result_key in result:
        pprint(result[result_key])