In [1]:
%cd ../

/Users/thayer/develop/covid_households


In [73]:
# Sanity check: simultaneous confidence intervals for multinomial proportions estimated from a large simulated sample

import numpy as np
import statsmodels.stats.proportion as smprop

p = np.array([0.7, 0.2, 0.05, 0.04, 0.01])
# Seeded generator so the intervals shown below are reproducible on Restart & Run All.
rng = np.random.default_rng(73)
# bincount(minlength=len(p)) keeps a zero-count category aligned with its entry in `p`;
# np.unique(..., return_counts=True) would silently drop it and shift every later count.
counts = np.bincount(rng.choice(len(p), 300000, p=p), minlength=len(p))
smprop.multinomial_proportions_confint(counts)

array([[0.69718585, 0.70149867],
       [0.19865032, 0.20241626],
       [0.04931836, 0.05137486],
       [0.03883804, 0.04067565],
       [0.00958512, 0.01052321]])

In [70]:
# Uniform 4-category case: every interval should straddle 0.25.
rng = np.random.default_rng(70)
# bincount(minlength=4) keeps all four categories even if one were never drawn.
counts = np.bincount(rng.choice(4, 250000), minlength=4)
smprop.multinomial_proportions_confint(counts)

array([[0.24695378, 0.25127474],
       [0.24802664, 0.25235382],
       [0.24717713, 0.25149938],
       [0.24921519, 0.25354922]])

In [2]:
import src.recipes as recipes
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict
from datetime import datetime

pd.set_option('mode.chained_assignment', 'raise')

In [5]:
!ls ./epidemics

[34mcomplete[m[m                          trials-10-powers-23-09-21.xlsx
fine_grain_sar.parquet            trials-10-powers-23-09-23.xlsx
[34mparts[m[m                             trials-300-powers-23-09-55.xlsx
trials-10-powers-23-09-08.xlsx    trials-300-powers-24-13-59.xlsx
trials-10-powers-23-09-11.xlsx    ~$trials-300-powers-23-09-55.xlsx
trials-10-powers-23-09-14.xlsx


In [8]:
# Load the simulation results object that every later cell reads as `results`.
# NOTE(review): the `!ls ./epidemics` listing above shows fine_grain_sar.parquet directly under
# ./epidemics and no `final` directory — confirm which file Results.load reads here.
results = recipes.Results.load('./epidemics/final', 'fine_grain_sar.parquet')

In [9]:
# Compute outcome frequencies in place on `results`; the Series displayed below is the call's
# return value (fraction of each infection count per (s80, p80, SAR, size)).
# presumably this is what adds the 'frequency' column that later cells read from
# results.df — confirm in src.recipes.
results.find_frequencies(inplace=True)

s80   p80   SAR   size  infections
0.02  0.02  0.01  2     1             0.997812
                        2             0.002188
                  3     1             0.994432
                        2             0.005052
                        3             0.000516
                                        ...   
0.80  0.80  0.60  3     3             0.761556
                  4     1             0.081860
                        2             0.010840
                        3             0.014612
                        4             0.892688
Name: count, Length: 1713577, dtype: float64

# Power calculations

In [10]:
def restrict_parameters(base_results, included_parameters, fixed_value=0.8):
    """Restrict the simulated outcome frequencies to a sub-hypothesis.

    Every parameter in ``base_results.metadata.parameters`` that is *not* listed in
    ``included_parameters`` is pinned: only rows whose index level for that
    parameter equals ``fixed_value`` (within floating-point tolerance) are kept.

    Parameters
    ----------
    base_results : object
        Must expose ``df['frequency']`` (a pandas Series/column indexed by a
        MultiIndex whose leading levels are ordered like
        ``metadata.parameters``) and ``metadata.parameters`` (a sequence of
        parameter names).
    included_parameters : iterable of str
        Parameters that remain free under the hypothesis.
    fixed_value : float, default 0.8
        Value every excluded parameter is pinned to.

    Returns
    -------
    pandas.Series
        A filtered copy of ``base_results.df['frequency']``.

    Raises
    ------
    ValueError
        If a parameter other than 's80' or 'p80' would have to be excluded;
        only those two have a default value to pin to.
    """
    parameter_names = list(base_results.metadata.parameters)
    freqs = base_results.df['frequency'].copy()

    for parameter in set(parameter_names) - set(included_parameters):
        if parameter not in ('s80', 'p80'):
            raise ValueError(f"can't exclude {parameter} as it has no default hypothesis.")
        parameter_level = freqs.index.get_level_values(parameter_names.index(parameter))
        # Tolerant comparison: grid values read back from disk need not be bit-exact.
        freqs = freqs[np.isclose(parameter_level, fixed_value)]

    return freqs

def restrict_on_sizes(frequencies, included_sizes):
    """Return only the rows of ``frequencies`` whose 'size' index level is in ``included_sizes``."""
    size_level = frequencies.index.get_level_values('size')
    in_scope = size_level.isin(included_sizes)
    return frequencies[in_scope]

In [11]:
import src.likelihood as likelihood

def _SAR_marginal_from_logl(logl):
    """Turn a log-likelihood Series into a normalised posterior summed onto its 'SAR' index level."""
    # Subtract the maximum before exponentiating so the largest term is exp(0) == 1.
    posterior = np.exp(logl - logl.max())
    posterior = posterior / posterior.sum()
    # Sum over every other parameter value to capture all the probability at each SAR.
    return posterior.groupby('SAR').sum()


def _SAR_confidence_interval(logl):
    """SAR interval from the project's likelihood helpers at percentiles=(0.9,)."""
    mask = likelihood.confidence_mask_from_logl(logl, percentiles=(0.9,))
    return likelihood.confidence_interval_from_confidence_mask(mask, key='SAR')


def SAR_pvalue_for_trial(baseline_logl, comparison_logl, for_increase=False, n_samples=10000, rng=None):
    """Monte Carlo estimate of how often the comparison SAR lies below (or above) the baseline SAR.

    Each log-likelihood surface is normalised into a posterior over the parameter
    grid and marginalised onto the 'SAR' index level. ``n_samples`` SAR values are
    then drawn from each marginal. Despite its name, the returned "p-value" is the
    fraction of paired draws in which the comparison SAR is strictly lower than the
    baseline SAR (strictly higher when ``for_increase`` is True). Ties count as neither.

    Parameters
    ----------
    baseline_logl, comparison_logl : pandas.Series
        Log-likelihoods over the parameter grid for a single trial; the index
        must have a level named 'SAR'.
    for_increase : bool, default False
        Count increases instead of decreases.
    n_samples : int, default 10000
        Number of draws from each SAR marginal.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state, so
        ``np.random.seed`` keeps controlling reproducibility.

    Returns
    -------
    tuple
        ``(pvalue, baseline_SAR_confidence_interval, comparison_SAR_confidence_interval)``.
        The intervals come from ``likelihood.confidence_interval_from_confidence_mask``
        with ``percentiles=(0.9,)``.
    """
    sampler = np.random if rng is None else rng

    baseline_probability_over_sars = _SAR_marginal_from_logl(baseline_logl)
    probability_over_sars = _SAR_marginal_from_logl(comparison_logl)

    baseline_SAR_confidence_interval = _SAR_confidence_interval(baseline_logl)
    comparison_SAR_confidence_interval = _SAR_confidence_interval(comparison_logl)

    # Use the probability surfaces to generate imagined MLEs (baseline drawn first, as before).
    sample1 = sampler.choice(baseline_probability_over_sars.index, n_samples, p=baseline_probability_over_sars)
    sample2 = sampler.choice(probability_over_sars.index, n_samples, p=probability_over_sars)

    # What fraction of the time the comparison SAR is above/below the baseline SAR.
    if for_increase:
        hits = np.count_nonzero(sample2 > sample1)
    else:
        hits = np.count_nonzero(sample2 < sample1)
    pvalue = hits / n_samples

    return pvalue, baseline_SAR_confidence_interval, comparison_SAR_confidence_interval

# NOTE(review): nothing in this notebook reads or writes `interval_notes`; kept so any external use keeps working.
interval_notes = defaultdict(list)


def _pvalues_by_trial(baseline_logl, comparison_logl, for_increase):
    """Pair baseline/comparison log-likelihoods trial by trial; return (trial keys, p-values) in matching order."""
    comparison_by_trial = comparison_logl.groupby('trial')
    trial_keys, pvalues = [], []
    for trial_key, baseline_trial in baseline_logl.groupby('trial'):
        comparison_trial = comparison_by_trial.get_group(trial_key)
        pvalue, _, _ = SAR_pvalue_for_trial(baseline_trial, comparison_trial, for_increase=for_increase)
        trial_keys.append(trial_key)
        pvalues.append(pvalue)
    return trial_keys, pvalues


def calculate_power_over_SAR_range(population, trials, basline_parameters, sar_range, hypotheses, frequencies_by_hypothesis, for_increase=False, base_results=None):
    """Per-trial SAR p-values for each hypothesis and each target SAR.

    For every (hypothesis, sar) pair, simulated outcomes are resampled both at
    ``basline_parameters`` (a fresh baseline draw each time) and at the same
    parameters with SAR replaced by ``sar``. Both samples are scored against the
    hypothesis' restricted frequencies, and ``SAR_pvalue_for_trial`` is applied
    trial by trial.

    Parameters
    ----------
    population : dict
        Passed through to ``base_results.resample``; presumably household size ->
        number of households, e.g. ``{2: 36, 3: 24, 4: 18}``.
    trials : int
        Number of resampled trials per parameter set.
    basline_parameters : tuple
        Baseline parameter values ordered like ``base_results.metadata.parameters``
        (name kept, typo included, for backward compatibility).
    sar_range : iterable of float
        Target SARs; each is rounded to 3 decimals to land on the simulated grid.
    hypotheses : dict
        Hypothesis name -> included parameters; only the keys are used here.
    frequencies_by_hypothesis : dict
        Hypothesis name -> restricted frequency Series.
    for_increase : bool, default False
        Forwarded to ``SAR_pvalue_for_trial``.
    base_results : object, optional
        Results object to resample from; defaults to the notebook-global ``results``.

    Returns
    -------
    pandas.DataFrame
        Columns 'pvalue', 'SAR', 'hypothesis', 'trial'. The columns are still
        present (with no rows) when there is nothing to compute.
    """
    if base_results is None:
        base_results = results  # notebook-global loaded earlier
    parameter_names = base_results.metadata.parameters
    sar_position = parameter_names.index('SAR')

    pvalue_sets = []
    for hypothesis_name in hypotheses:
        frequencies = frequencies_by_hypothesis[hypothesis_name]
        for sar in sar_range:
            # Round so the target SAR matches a simulated grid point; the same value labels the output.
            target_sar = float(f'{sar:0.3f}')
            parameters = list(basline_parameters)
            parameters[sar_position] = target_sar
            parameters = tuple(parameters)

            # Imagined infections at the baseline -> probability surface for the baseline MLE.
            baseline_samples = base_results.resample(basline_parameters, population, trials=trials)
            baseline_logl = likelihood.logl_from_frequencies_and_counts(frequencies, baseline_samples['count'], parameter_names)

            # Imagined infections at the comparison point -> probability surface for its MLE.
            comparison_samples = base_results.resample(parameters, population, trials=trials)
            comparison_logl = likelihood.logl_from_frequencies_and_counts(frequencies, comparison_samples['count'], parameter_names)

            # Use the actual trial labels rather than assuming they are range(trials).
            trial_keys, pvalues = _pvalues_by_trial(baseline_logl, comparison_logl, for_increase)
            pvalue_sets.append(pd.DataFrame({'pvalue': pvalues, 'SAR': target_sar, 'hypothesis': hypothesis_name, 'trial': trial_keys}))

    if not pvalue_sets:
        return pd.DataFrame(columns=['pvalue', 'SAR', 'hypothesis', 'trial'])
    return pd.concat(pvalue_sets)

In [17]:
sar_range = np.linspace(0.10, 0.25, 4)
trials = 300
# A simulated trial counts towards power when its p-value exceeds this threshold.
power_pvalue = 0.9

# no, medium, and high heterogeneity as defined in the paper
# (other settings explored previously: (0.5, 0.5, 0.25) and (0.2, 0.2, 0.25))
baseline_parameter_sets = [
    (0.8, 0.8, 0.25),
    (0.8, 0.2, 0.25),
]

# target population = 216 (divisible by 36); each dict maps household size -> number of households
target_population = 216
populations = [
    {2: 36, 3: 24, 4: 18},
]

for population in populations:
    assert sum(size * count for size, count in population.items()) == target_population, population

all_sizes = set().union(*(population.keys() for population in populations))
print(all_sizes)

hypotheses = {
    'all': ['s80', 'p80', 'SAR'],
    'inf-and-SAR-vary': ['p80', 'SAR'],
    'sus-and-SAR-vary': ['s80', 'SAR'],
    'only-SAR-varies': ['SAR'],
}
# Restrict the simulated frequencies to each hypothesis' free parameters and to the household sizes in use.
frequencies_by_hypothesis = {
    name: restrict_on_sizes(restrict_parameters(results, included_parameters), all_sizes)
    for name, included_parameters in hypotheses.items()
}

pvalue_df_pieces = []
for baseline_parameters in baseline_parameter_sets:
    for population in populations:
        print(population)
        pvalue_df_piece = calculate_power_over_SAR_range(population, trials, baseline_parameters, sar_range, hypotheses, frequencies_by_hypothesis)
        pvalue_df_piece['parameters'] = str(baseline_parameters)
        pvalue_df_piece['population'] = str(population)
        pvalue_df_pieces.append(pvalue_df_piece)

# One row per trial, one column per (population, parameters, hypothesis, SAR).
# Selecting 'pvalue' (instead of .squeeze()) cannot collapse to a scalar when only one cell exists.
per_trial_pvalues = (
    pd.concat(pvalue_df_pieces)
    .set_index(['population', 'parameters', 'hypothesis', 'SAR', 'trial'])['pvalue']
    .unstack([0, 1, 2, 3])
)
# Power = fraction of trials whose p-value exceeds the threshold.
# Still named pvalue_df because later cells display and export it under that name.
pvalue_df = (per_trial_pvalues > power_pvalue).sum() / trials
pvalue_df.name = 'power'

timestamp = datetime.now().strftime('%d-%H-%M')
path = f'./epidemics/fine-sar-trials-{trials}-powers-{timestamp}.xlsx'
pvalue_df.unstack([1, 2]).round(2).to_excel(path)

{2, 3, 4}
{2: 36, 3: 24, 4: 18}
{2: 36, 3: 24, 4: 18}


In [18]:
# Per-trial p-value table: one row per trial, one column per
# (population, parameters, hypothesis, SAR).
raw_df = (
    pd.concat(pvalue_df_pieces)
    .set_index(['population', 'parameters', 'hypothesis', 'SAR', 'trial'])
    .squeeze()
    .unstack([0, 1, 2, 3])
)

In [19]:
raw_df

population,"{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}","{2: 36, 3: 24, 4: 18}"
parameters,"(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)",...,"(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)"
hypothesis,all,all,all,all,inf-and-SAR-vary,inf-and-SAR-vary,inf-and-SAR-vary,inf-and-SAR-vary,sus-and-SAR-vary,sus-and-SAR-vary,...,inf-and-SAR-vary,inf-and-SAR-vary,sus-and-SAR-vary,sus-and-SAR-vary,sus-and-SAR-vary,sus-and-SAR-vary,only-SAR-varies,only-SAR-varies,only-SAR-varies,only-SAR-varies
SAR,0.10,0.15,0.20,0.25,0.10,0.15,0.20,0.25,0.10,0.15,...,0.20,0.25,0.10,0.15,0.20,0.25,0.10,0.15,0.20,0.25
trial,Unnamed: 1_level_4,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4,Unnamed: 13_level_4,Unnamed: 14_level_4,Unnamed: 15_level_4,Unnamed: 16_level_4,Unnamed: 17_level_4,Unnamed: 18_level_4,Unnamed: 19_level_4,Unnamed: 20_level_4,Unnamed: 21_level_4
0,0.9606,0.6958,0.1750,0.3960,0.9846,0.9801,0.9860,0.7545,0.9977,0.7556,...,0.1244,0.8891,0.9293,0.8558,0.5461,0.3065,0.9996,0.9668,0.6785,0.2314
1,0.9811,0.8471,0.9679,0.6649,0.9997,0.8847,0.9906,0.0968,0.9997,0.8378,...,0.5315,0.2404,0.9955,0.8771,0.8033,0.8793,1.0000,0.9994,0.9626,0.5046
2,0.9993,0.9285,0.0622,0.4903,0.9945,0.9657,0.9082,0.2761,0.9968,0.8378,...,0.8257,0.1683,0.8729,0.5884,0.6401,0.8616,0.8820,0.9873,0.4562,0.7678
3,0.9890,0.9289,0.6319,0.3302,0.9999,0.9955,0.9099,0.8587,0.9939,0.9308,...,0.8607,0.5081,0.9714,0.7296,0.8730,0.0238,0.9899,0.9933,0.9100,0.9998
4,0.9966,0.8906,0.7687,0.5159,0.9988,0.9678,0.8826,0.2210,0.9944,0.7987,...,0.9621,0.5145,0.9990,0.9923,0.7534,0.6472,0.9998,0.9847,0.9617,0.0395
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,0.9922,0.8615,0.5306,0.3607,0.9943,0.9975,0.5799,0.1523,0.9044,0.9356,...,0.9398,0.3718,0.9997,0.5781,0.8639,0.6454,1.0000,0.9858,0.9212,0.1320
296,0.9924,0.8303,0.9283,0.7750,0.8741,0.9725,0.7086,0.5703,0.9951,0.8165,...,0.6939,0.4529,0.7273,0.9746,0.9364,0.7792,0.9985,0.9721,0.9200,0.3861
297,0.9903,0.9926,0.7694,0.5726,1.0000,0.9765,0.6964,0.5825,0.9261,0.5363,...,0.8805,0.2828,0.9773,0.2323,0.6317,0.3408,0.9982,0.9933,0.9096,0.1122
298,0.9911,0.7642,0.7056,0.3332,1.0000,0.9888,0.5154,0.3462,0.8999,0.7502,...,0.8435,0.1881,0.9444,0.9978,0.8061,0.8550,0.9997,0.9873,0.8948,0.2106


In [13]:
# Power table: rows are (population, SAR), columns are (parameters, hypothesis), rounded to 2 dp.
# NOTE(review): this cell ran as In [13], before In [17] last rebuilt pvalue_df, so the stored
# output below may come from an earlier configuration.
pvalue_df.unstack([1,2]).round(2)

Unnamed: 0_level_0,parameters,"(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.8, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)","(0.8, 0.2, 0.25)"
Unnamed: 0_level_1,hypothesis,all,inf-and-SAR-vary,sus-and-SAR-vary,only-SAR-varies,all,inf-and-SAR-vary,sus-and-SAR-vary,only-SAR-varies
population,SAR,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
"{2: 36, 3: 24, 4: 18}",0.1,0.9,0.99,0.88,1.0,0.88,0.94,0.84,0.96
"{2: 36, 3: 24, 4: 18}",0.15,0.59,0.85,0.47,0.85,0.59,0.7,0.45,0.75
"{2: 36, 3: 24, 4: 18}",0.2,0.18,0.39,0.16,0.4,0.23,0.33,0.18,0.34
"{2: 36, 3: 24, 4: 18}",0.25,0.02,0.09,0.04,0.08,0.06,0.09,0.05,0.1


In [None]:
# Export the power table. The trial count in the filename comes from `trials` so it cannot
# disagree with the number of simulated trials actually used (it was hard-coded as 1000).
pvalue_df.unstack([1, 2]).round(2).to_excel(f'./figures/powers/powers_{trials}_trials_SAR_25_America_fixed_0s_fixed_pop_part2.xlsx')

In [59]:
# Inspect the simulated results table, including the 'frequency' column that restrict_parameters reads.
# (Large frame — results.df.head() is usually enough for a quick look.)
results.df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,count,sus_variance,inf_variance,beta,inf_constant_value,sus_constant_value,frequency
s80,p80,SAR,size,infections,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
0.02,0.02,0.01,2,1,399026.0,1000.0,1000.0,0.009,,,0.997565
0.02,0.02,0.01,2,2,974.0,1000.0,1000.0,0.009,,,0.002435
0.02,0.02,0.01,3,1,397577.0,1000.0,1000.0,0.009,,,0.993942
0.02,0.02,0.01,3,2,2192.0,1000.0,1000.0,0.009,,,0.005480
0.02,0.02,0.01,3,3,231.0,1000.0,1000.0,0.009,,,0.000577
...,...,...,...,...,...,...,...,...,...,...,...
0.80,0.80,0.60,8,4,0.0,,,0.165,1.0,1.0,0.000000
0.80,0.80,0.60,8,5,0.0,,,0.165,1.0,1.0,0.000000
0.80,0.80,0.60,8,6,0.0,,,0.165,1.0,1.0,0.000000
0.80,0.80,0.60,8,7,0.0,,,0.165,1.0,1.0,0.000000
