In [None]:
from neuroconnect.connect_math import discretised_rv, get_dist_mean
from neuroconnect.mpf_connection import CombProb
from neuroconnect.connectivity_patterns import OutgoingDistributionConnections
from scipy.stats import truncexpon
import matplotlib.pyplot as plt

In [None]:
# 1. Config - population sizes to sweep and the distribution scale for each
sizes = [100, 1_000, 2_500, 5_000, 10_000, 50_000, 100_000, 1_000_000]
# Every scale is one tenth of its population size; it is later passed as
# ``scale=`` to ``truncexpon`` when the outgoing-connection distribution is built.
scales = [size // 10 for size in sizes]

# Target handed to find_correct_sample_size as ``desired``.
desired_cfg = 0.4
# Forwarded as ``percent_out`` (do_mpf multiplies the region size by it).
percent_out = 0.2
# NOTE(review): not referenced by the cells visible here, which sweep
# ``sub_rates = [None]`` instead - confirm whether it is still needed.
sub_rate = 0.001

In [None]:
# Check means: for every (size, scale) pair, build the discretised
# distribution, plot it, and print its mean as a percentage of the size.
for s, scale in zip(sizes, scales):
    # max_out is used both as truncexpon's shape argument and as the upper
    # bound handed to discretised_rv.
    max_out = s // 2
    rv = truncexpon(max_out, scale=scale, loc=0)
    dist = discretised_rv(rv, 0, max_out)
    mean = get_dist_mean(dist)
    xs = list(dist.keys())
    ys = list(dist.values())
    plt.plot(xs, ys, label=s)
    print(100 * mean / s)
plt.legend()

In [None]:
# 2. How to run the stats
def do_mpf(num_samples, size, dist, percent_out, clt_start=30, subsample_rate=0.01):
    """Build a ``CombProb`` for two regions of ``size`` nodes and gather its statistics.

    Parameters
    ----------
    num_samples : sequence
        ``num_samples[0]`` and ``num_samples[1]`` are passed to ``CombProb`` as
        its 2nd and 5th positional arguments; ``num_samples[0]`` is also sent
        as ``total_samples`` and bounds the keys of ``"each_expected"``.
        Presumably the sample counts for region 1 and region 2 - verify
        against ``CombProb``.
    size : int
        Number of nodes in each of the two regions.
    dist :
        Forwarded unchanged as ``out_connections_dist``.
    percent_out : float
        Multiplied by ``size`` to obtain the number of region-1 senders
        (the product is not rounded).
    clt_start : int, optional
        Forwarded unchanged as ``clt_start``.
    subsample_rate : optional
        Forwarded both as ``subsample_rate`` and as ``sub``.

    Returns
    -------
    dict
        ``"expected"``: ``expected_connections()``;
        ``"total"``: ``get_all_prob()``;
        ``"each_expected"``: ``{k: expected_total(k)}`` for
        ``k`` in ``0 .. num_samples[0]`` inclusive.
    """
    # Only the node counts are consumed below, so lazy ranges stand in for
    # explicit node lists.
    region1_nodes = range(size)
    region2_nodes = range(size, 2 * size)
    num_start = len(region1_nodes)
    num_end = len(region2_nodes)
    num_region1_senders = size * percent_out

    first_samples = num_samples[0]
    second_samples = num_samples[1]

    # Extra keyword arguments handed to CombProb alongside the explicit ones.
    delta_params = {
        "num_start": num_start,
        "num_end": num_end,
        "num_senders": num_region1_senders,
        "out_connections_dist": dist,
        "total_samples": first_samples,
        "clt_start": clt_start,
        "sub": subsample_rate,
    }
    comb = CombProb(
        num_start,
        first_samples,
        num_region1_senders,
        num_end,
        second_samples,
        OutgoingDistributionConnections.static_expected_connections,
        subsample_rate=subsample_rate,
        approx_hypergeo=False,
        **delta_params,
    )

    expected = comb.expected_connections()
    total = comb.get_all_prob()
    each_expected = {}
    for k in range(first_samples + 1):
        each_expected[k] = comb.expected_total(k)

    return {"expected": expected, "total": total, "each_expected": each_expected}

In [None]:
# 3. Run the stats over different sizes and num samples
# Start with say 20% of the population
def function_to_minimize(
    samples_to_use, full_size, desired, dist, percent_out, subsample_rate
):
    """Return the signed gap between the expected value per sample and ``desired``.

    Runs ``do_mpf`` (with ``clt_start`` fixed at 30), divides its
    ``"expected"`` entry by ``samples_to_use[1]`` and subtracts ``desired``.
    A negative value means the achieved ratio is below the target.
    """
    stats = do_mpf(
        samples_to_use,
        full_size,
        dist,
        percent_out,
        clt_start=30,
        subsample_rate=subsample_rate,
    )
    achieved_ratio = stats["expected"] / samples_to_use[1]
    return achieved_ratio - desired


def find_correct_sample_size(
    full_size, scale, percent_out, desired, lb, ub, subsample_rate=0.01
):
    """Bisect on the sample count until the achieved ratio is close to ``desired``.

    The same sample count is used for both entries of ``samples_to_use``.
    The search starts at ``min(full_size // 100, 80)`` and is bracketed by
    ``0`` and ``min(full_size, 2000)``. The mean of the discretised
    distribution, as a percentage of ``full_size``, is printed once.

    Parameters
    ----------
    full_size : int
        Population size forwarded to ``function_to_minimize``.
    scale :
        ``scale`` argument of ``truncexpon``.
    percent_out : float
        Forwarded unchanged to ``function_to_minimize``.
    desired : float
        Target value for the ratio evaluated by ``function_to_minimize``.
    lb, ub : float
        A sample count is accepted when ``-lb < diff < ub``, where ``diff``
        is the value returned by ``function_to_minimize``.
    subsample_rate : optional
        Forwarded unchanged to ``function_to_minimize``.

    Returns
    -------
    tuple
        ``(sample_count, evaluations)``.

    Raises
    ------
    RuntimeError
        If 200 evaluations pass without landing inside the tolerance band.
    """
    min_ = 0
    max_ = min(full_size, 2000)
    # NOTE(review): for full_size < 100 this starts at 0 samples, and
    # function_to_minimize divides by the sample count - confirm small
    # populations are never passed in.
    start = min(full_size // 100, 80)
    samples_to_use = [start, start]
    max_out = full_size // 4
    dist = discretised_rv(truncexpon(max_out, scale=scale, loc=0), 0, max_out)
    mean = get_dist_mean(dist)
    print(100 * mean / full_size)

    max_iters = 200
    n = 0

    while min_ != max_ and n < max_iters:
        # ``diff`` is a plain number: achieved ratio minus ``desired``.
        diff = function_to_minimize(
            samples_to_use, full_size, desired, dist, percent_out, subsample_rate
        )
        n += 1

        # Accept first, so a hit on the last allowed evaluation is not lost.
        if -lb < diff < ub:
            return samples_to_use[0], n

        if n == max_iters:
            # Recover the achieved ratio for the message; ``diff`` is a float
            # and cannot be indexed like the dict returned by do_mpf.
            achieved = diff + desired
            raise RuntimeError(
                f"Only found expected of {achieved} with {samples_to_use[0]}, not {desired}"
            )

        if diff < 0:
            # Ratio too low: the answer lies above the current sample count.
            min_ = samples_to_use[0]
        elif diff > 0:
            # Ratio too high: the answer lies below the current sample count.
            max_ = samples_to_use[0]
        else:
            # Neither above nor below (e.g. exactly 0 with a zero-width
            # tolerance): keep the current bracket and sample count.
            continue

        midpoint = (max_ + min_) // 2
        samples_to_use = [midpoint, midpoint]

    return samples_to_use[0], n


In [None]:
# Run over the sizes: one list of (sample_count, evaluations) results per
# subsample rate, with a +/-0.02 tolerance band around ``desired_cfg``.
sub_rates = [None]
saved = []
for sr in sub_rates:
    found_sizes = []
    for s, scale in zip(sizes, scales):
        found = find_correct_sample_size(
            s, scale, percent_out, desired_cfg, lb=0.02, ub=0.02, subsample_rate=sr
        )
        found_sizes.append(found)
    saved.append(found_sizes)

In [None]:
# One scatter plot per swept subsample rate: the sample count found for each
# population size, followed by the raw numbers.
for found_sizes in saved:
    fig, ax = plt.subplots()
    sample_counts = [entry[0] for entry in found_sizes]
    ax.scatter(sizes, sample_counts)
    print(sizes, found_sizes)
    plt.show()