# Tools for SUITE Risk-Limiting Election Audits



In [1]:
from __future__ import print_function

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display

from collections import OrderedDict
from itertools import product
import math

import numpy as np
from ballot_comparison import ballot_comparison_pvalue
from fishers_combination import fisher_combined_pvalue, maximize_fisher_combined_pvalue, \
     bound_fisher_fun, calculate_lambda_range    
from sprt import ballot_polling_sprt

from cryptorandom.cryptorandom import SHA256
from cryptorandom.sample import sample_by_index

  return f(*args, **kwds)


In [2]:
# global audit parameters

seed = 12345678901234567890  # use, e.g., 20 rolls of a 10-sided die
risk_limit = 0.05    # risk limit

gamma=1.03905         # gamma from Lindeman and Stark (2012)

lambda_step = 0.05    # stepsize for the discrete bounds on Fisher's combining function

# assumptions for finding initial sample sizes

o1_rate = 0.002       # expect 2 1-vote overstatements per 1000 ballots in the CVR stratum
o2_rate = 0           # expect 0 2-vote overstatements
u1_rate = 0           # expect 0 1-vote understatements
u2_rate = 0           # expect 0 2-vote understatements

stratum_sizes = [100000, 5000]  # total ballots in the two strata, CVR, no-CVR

n_ratio = stratum_sizes[0]/np.sum(stratum_sizes) 
                     # allocate sample in proportion to ballots cast in each stratum

# contest-specific parameters
num_winners = 2       # maximum number of winners, per social choice function

# Initial sample size

Reported Votes

In [3]:
# input number of winners
# input names as well as reported votes in each stratum

# candidates are a dict with name, [votes_in_stratum_1, votes_in_stratum_2]
candidates = { "candidate 3": [30000, 500],
               "candidate 2": [50000, 1000],
               "candidate 1": [10000, 500],
               "candidate 4": [500, 10]}

cvr_votes = poll_votes = 0

for name, votes in candidates.items():
    cvr_votes += votes[0]
    poll_votes += votes[1]
    votes.append(votes[0]+votes[1])

In [4]:
assert cvr_votes <= stratum_sizes[0]
assert poll_votes <= stratum_sizes[1]
assert (n_ratio >= 0) and (n_ratio <= 1)

In [5]:
# find reported winners, losers, and pairwise margins

candidates = OrderedDict(sorted(candidates.items(), key=(lambda t: t[1][2]), reverse=True))
                        
winners = list(candidates.keys())[0:num_winners]
losers = list(candidates.keys())[num_winners:]

votes = np.zeros(3)

margins = {}  # will hold the (winner, loser) pairwise margins indexed by pairs

for x in product(winners, losers):
    margins[x] = candidates[x[0]][2] - candidates[x[1]][2]

margins = OrderedDict(sorted(margins.items(), key=(lambda t: t[1]), reverse=True))

min_margin = np.amin(list(margins.values()))

print('\nTotal reported votes:\n\t\t\tCVR\tno-CVR\ttotal')
for k, v in candidates.items():
    print('\t', k, ':', v[0], '\t', v[1], '\t', v[2])
print('\n\t total votes:\t', cvr_votes, '\t', poll_votes, '\t', cvr_votes + poll_votes)
print('\n\t non-votes:\t',\
      stratum_sizes[0] - cvr_votes,'\t',\
      stratum_sizes[1] - poll_votes,'\t',\
      stratum_sizes[0] + stratum_sizes[1] - cvr_votes - poll_votes\
     )



print('\nwinners:')
for w in winners:
    print('\t',w)

print('\nlosers:')
for ell in losers:
    print('\t',ell)

print('\n\nmargins:')
for k, v in margins.items():
    dum = k[0] + ' beat ' + k[1] + ' by'
    print('\t', dum , v, 'votes')

print('\nsmallest margin:', min_margin, '\ndiluted margin:', min_margin/np.sum(stratum_sizes))


Total reported votes:
			CVR	no-CVR	total
	 candidate 2 : 50000 	 1000 	 51000
	 candidate 3 : 30000 	 500 	 30500
	 candidate 1 : 10000 	 500 	 10500
	 candidate 4 : 500 	 10 	 510

	 total votes:	 90500 	 2010 	 92510

	 non-votes:	 9500 	 2990 	 12490

winners:
	 candidate 2
	 candidate 3

losers:
	 candidate 1
	 candidate 4


margins:
	 candidate 2 beat candidate 4 by 50490 votes
	 candidate 2 beat candidate 1 by 40500 votes
	 candidate 3 beat candidate 4 by 29990 votes
	 candidate 3 beat candidate 1 by 20000 votes

smallest margin: 20000 
diluted margin: 0.19047619047619047


Expected sample sizes

In [6]:
def estimate_n(N_w1, N_w2, N_l1, N_l2, N1, N2,\
               o1_rate=0, o2_rate=0, u1_rate=0, u2_rate=0,\
               n_ratio=None,
               risk_limit=0.05,\
               gamma=1.03905,\
               stepsize=0.05):
    '''
    Estimate the initial sample sizes for the audit.
    
    Parameters
    ----------
    N_w1 : int
        votes for the reported winner in the ballot comparison stratum
    N_w2 : int
        votes for the reported winner in the ballot polling stratum
    N_l1 : int
        votes for the reported loser in the ballot comparison stratum
    N_l2 : int
        votes for the reported loser in the ballot polling stratum
    N1 : int
        total number of votes in the ballot comparison stratum
    N2 : int
        total number of votes in the ballot polling stratum
    o1_rate : float
        expected percent of ballots with 1-vote overstatements in the CVR stratum
    o2_rate : float
        expected percent of ballots with 2-vote overstatements in the CVR stratum
    u1_rate : float
        expected percent of ballots with 1-vote understatements in the CVR stratum
    u2_rate : float
        expected percent of ballots with 2-vote understatements in the CVR stratum
    n_ratio : float
        ratio of sample allocated to each stratum.
        If None, allocate sample in proportion to ballots cast in each stratum
    risk_limit : float
        risk limit
    gamma : float
        gamma from Lindeman and Stark (2012)
    stepsize : float
        stepsize for the discrete bounds on Fisher's combining function
    Returns
    -------
    tuple : estimated initial sample sizes in the CVR stratum and no-CVR stratum
    '''
    n_ratio = n_ratio if n_ratio else N1/(N1+N2)
    n = 5
    reported_margin = (N_w1+N_w2)-(N_l1+N_l2)
    expected_pvalue = 1
    
    def try_n(n):
        n1 = math.ceil(n_ratio * n)
        n2 = n - n1
        if n1 < 1 or n2 < 1:
            return 1
        o1 = math.ceil(o1_rate*n1)
        o2 = math.ceil(o2_rate*n1)
        u1 = math.floor(u1_rate*n1)
        u2 = math.floor(u2_rate*n1)

        cvr_pvalue = lambda alloc: ballot_comparison_pvalue(n=n1, gamma=1.03905, o1=o1, 
                                                    u1=u1, o2=o2, u2=u2, 
                                                    reported_margin=reported_margin, N=N1, 
                                                    null_lambda=alloc)
        sample = sample = np.array([0]*int(n2*N_l2/N2)+[1]*int(n2*N_w2/N2)+\
                                             [np.nan]*int(n2*(N2-N_l2-N_w2)/N2))
        if len(sample) < 1:
            return 1
        nocvr_pvalue = lambda alloc: ballot_polling_sprt(sample=sample, \
                            popsize=N2, \
                            alpha=risk_limit,\
                            Vw=N_w2, Vl=N_l2, null_margin=(N_w2-N_l2) - alloc*reported_margin)['pvalue']
        # Crude maximizer for now 
        # TO DO: this isn't rigorous yet--needs to be fixed
        res = bound_fisher_fun(N_w1, N_l1, N1, N_w2, N_l2, N2,
                       pvalue_funs=(cvr_pvalue, nocvr_pvalue),\
                       stepsize=stepsize, feasible_lambda_range=None)
        expected_pvalue = np.amax(res['upper_bounds'])
        if (n % 10000)==0:
            print('...trying...', n, expected_pvalue)
        return(expected_pvalue)
    
    # step 1: linear search, doubling n each time
    while (expected_pvalue > risk_limit) or (expected_pvalue is np.nan):
        n = 2*n
        expected_pvalue = try_n(n)
        
    # step 2: bisection between n/2 and n
    low_n = n/2
    high_n = n
    mid_pvalue = 1
    # TODO: should there be a tolerance here? e.g. risk_limit - 0.005 
    while  (mid_pvalue > risk_limit) or (expected_pvalue is np.nan):
        mid_n = np.floor((low_n+high_n)/2)
        mid_pvalue = try_n(mid_n)
        if mid_pvalue <= risk_limit:
            high_n = mid_n
        else:
            low_n = mid_n
    
    n1 = math.ceil(n_ratio * mid_n)
    n2 = int(mid_n - n1)
    return (n1, n2)

In [7]:
# Find largest expected sample size across (winner, loser) pairs

sample_sizes = {}

for k in product(winners, losers):
    sample_sizes[k] = estimate_n(N_w1 = candidates[k[0]][0],\
                                 N_w2 = candidates[k[0]][1],\
                                 N_l1 = candidates[k[1]][0],\
                                 N_l2 = candidates[k[1]][1],\
                                 N1 = stratum_sizes[0],\
                                 N2 = stratum_sizes[1],\
                                 o1_rate = o1_rate,\
                                 o2_rate = o2_rate,\
                                 u1_rate = u1_rate,\
                                 u2_rate = u2_rate,\
                                 n_ratio = n_ratio,\
                                 risk_limit = risk_limit,\
                                 gamma = gamma,\
                                 stepsize = lambda_step)

In [8]:
sample_size = np.amax([v[0]+v[1] for v in sample_sizes.values()])
n1 = math.ceil(sample_size*n_ratio)    
n2 = sample_size-n1

print(sample_sizes, '\nexpected minimum sample size:', sample_size)

{('candidate 3', 'candidate 1'): (67, 3), ('candidate 3', 'candidate 4'): (58, 2), ('candidate 2', 'candidate 4'): (58, 2), ('candidate 2', 'candidate 1'): (58, 2)} 
expected minimum sample size: 70


# Random sampling

If this section is giving errors, you probably need to update your version of `cryptorandom`.

```
pip install [--update] cryptorandom
```

In [9]:
prng = SHA256(seed)   # initialize the PRNG

In [10]:
# CVR stratum initial sample size, sampled with replacement
sample1 = prng.randint(1, stratum_sizes[0]+1, size=n1)

# No-CVR ballots are sampled without replacement
sample2 = sample_by_index(stratum_sizes[1], n2, prng)

Stratum 1 sample

In [11]:
print("CVR stratum sample:\n", sample1)

CVR stratum sample:
 [76116 45424 33501 45326  2081 56264 25122 16602 79743 61814 57922 41676
 95332 38891 17757 64352 84257 47365 10908 97791 77941 73573 51855 88527
 35549 20934 61419 70683 70220 45067 67903 94304 20823 50570 88735  9973
 44578 34320  8262 32532 85102 87511 63375 96612 52917 91127 84152 74227
 76674 76640 62444 83868  3974 81503 82205 41161 28136 12244 97608  9057
 43082  6522 39347 45600 57836 10233 75516]


In [12]:
print("CVR stratum sample, sorted:\n", np.sort(sample1))

CVR stratum sample, sorted:
 [ 2081  3974  6522  8262  9057  9973 10233 10908 12244 16602 17757 20823
 20934 25122 28136 32532 33501 34320 35549 38891 39347 41161 41676 43082
 44578 45067 45326 45424 45600 47365 50570 51855 52917 56264 57836 57922
 61419 61814 62444 63375 64352 67903 70220 70683 73573 74227 75516 76116
 76640 76674 77941 79743 81503 82205 83868 84152 84257 85102 87511 88527
 88735 91127 94304 95332 96612 97608 97791]


In [13]:
print("CVR stratum sample, sorted, duplicates removed:\n", np.unique(np.sort(sample1)))

CVR stratum sample, sorted, duplicates removed:
 [ 2081  3974  6522  8262  9057  9973 10233 10908 12244 16602 17757 20823
 20934 25122 28136 32532 33501 34320 35549 38891 39347 41161 41676 43082
 44578 45067 45326 45424 45600 47365 50570 51855 52917 56264 57836 57922
 61419 61814 62444 63375 64352 67903 70220 70683 73573 74227 75516 76116
 76640 76674 77941 79743 81503 82205 83868 84152 84257 85102 87511 88527
 88735 91127 94304 95332 96612 97608 97791]


In [14]:
m = np.zeros_like(sample1, dtype=bool)
m[np.unique(sample1, return_index=True)[1]] = True
print("Stratum 1 repeated ballots:\n", sample1[~m])

Stratum 1 repeated ballots:
 []


Stratum 2 sample

In [15]:
print("No-CVR stratum sample:\n", sample2)

No-CVR stratum sample:
 [1783 4275 4915]


In [16]:
print("No-CVR stratum sample, sorted:\n", np.sort(sample2))

No-CVR stratum sample, sorted:
 [1783 4275 4915]


# Find ballots using ballot manifest

Ballot manifest: Each line must have a batch label, a comma, and one of the following:
  1. the number of ballots in the batch 
  1. a range specified with a colon (e.g., 131:302), or 
  1. a list of ballot identifiers within parentheses, separated by spaces (e.g., (996 998 1000)).
  
Each line should have exactly one comma.

In [17]:
# I'm imagining this is is a list for now
ballot_manifest_cvr = ['1, 10000', '2, 10001:99998', '3, (205 210)']
ballot_manifest_poll = ['1, 1000', '2, 1001:4998', '3, (205 210)']

In [18]:
# step 1: expand the ballot manifest into a dict. keys are batches, values are ballot numbers.

def parse_manifest(manifest):
    '''
    Parses a ballot manifest to put the batches in a canonical ordering and be able
    to look up specific ballots.
    
    Input
    -----
    a ballot manifest in the syntax described above
    
    Returns
    -------
    an ordered dict containing batch ID (key) and ballot identifiers within the batch, 
    either from sequential enumeration or from the given labels. Identifiers are not
    necessarily unique *across* batches.
    '''
    ballots = 0
    ballot_manifest_dict = OrderedDict()
    for i in manifest:
        # assert that the entry is a string with a comma in it
        # pull out batch label
        (batch, val) = i.split(",")
        batch = batch.strip()
        val = val.strip()    
        if (batch in ballot_manifest_dict.keys()):
             raise ValueError('batch is listed more than once')
        else:
            ballot_manifest_dict[batch] = []
    
        # parse what comes after the batch label
        if '(' in val:     # list of identifiers
            val = val[1:-1] # strip out the parentheses  TO DO: use regex to remove )(
            ballot_manifest_dict[batch] += [int(num) for num in val.split()]
            
        elif ':' in val:   # range of identifiers
            limits = val.split(':')
            ballot_manifest_dict[batch] += list(range(int(limits[0]), int(limits[1])+1))  
            
        else:  # this should be an integer number of ballots
            try:
                ballot_manifest_dict[batch] += list(range(1, int(val)+1))
            except:
                print('malformed row in ballot manifest:\n\t', i)
    return(ballot_manifest_dict)

In [19]:
cvr_manifest_parsed = parse_manifest(ballot_manifest_cvr)
poll_manifest_parsed = parse_manifest(ballot_manifest_poll)

In [20]:
# count ballots listed in the manifests
listed_cvr = np.sum([len(v) for v in cvr_manifest_parsed.values()])
listed_poll = np.sum([len(v) for v in poll_manifest_parsed.values()])

# test that manifest matches reported ballot totals

assert listed_cvr == stratum_sizes[0]
assert listed_poll == stratum_sizes[1]

In [21]:
# step 2: give ballots unique IDs

def unique_manifest(parsed_manifest):
    second_manifest = {}
    ballots_counted = 0
    for batch in parsed_manifest.keys():
        batch_size = len(parsed_manifest[batch])
        second_manifest[batch] = list(range(ballots_counted + 1, ballots_counted + batch_size + 1))
        ballots_counted += batch_size
    return(second_manifest)

In [22]:
unique_cvr_manifest = unique_manifest(cvr_manifest_parsed)
unique_poll_manifest = unique_manifest(poll_manifest_parsed)

In [23]:
# step 3: look up sample values

def find_ballot(ballot_num, unique_ballot_manifest, parsed_ballot_manifest):
    '''
    Find ballot in the batches
    
    Input
    -----
    ballot_num : int
        a ballot number that was sampled
    unique_ballot_manifest : dict
        ballot manifest with unique IDs across batches
    parsed_ballot_manifest : dict
        ballot manifest with original ballot IDs supplied in the manifest
        
    Returns
    -------
    tuple : (original_ballot_label, batch_label, which_ballot_in_batch)
    '''
    for batch, ballots in unique_ballot_manifest.items():
        if ballot_num in ballots:
            position = ballots.index(ballot_num)
            original_ballot_label = parsed_ballot_manifest[batch][position]
            return (original_ballot_label, batch, position)
    print("Ballot %i not found" % ballot_num)
    return None

In [24]:
print("CVR Stratum")
print("sampled ballot, original ballot label, batch label, which ballot in batch")
i = 0
for s in sample1:
    i += 1
    original_ballot_label, batch_label, which_ballot = find_ballot(s, \
                                                                   unique_cvr_manifest, \
                                                                   cvr_manifest_parsed)
    print(s, original_ballot_label, batch_label, which_ballot) # This uses 0-indexing still. Should we change it be 1,...,n?

CVR Stratum
sampled ballot, original ballot label, batch label, which ballot in batch
76116 76116 2 66115
45424 45424 2 35423
33501 33501 2 23500
45326 45326 2 35325
2081 2081 1 2080
56264 56264 2 46263
25122 25122 2 15121
16602 16602 2 6601
79743 79743 2 69742
61814 61814 2 51813
57922 57922 2 47921
41676 41676 2 31675
95332 95332 2 85331
38891 38891 2 28890
17757 17757 2 7756
64352 64352 2 54351
84257 84257 2 74256
47365 47365 2 37364
10908 10908 2 907
97791 97791 2 87790
77941 77941 2 67940
73573 73573 2 63572
51855 51855 2 41854
88527 88527 2 78526
35549 35549 2 25548
20934 20934 2 10933
61419 61419 2 51418
70683 70683 2 60682
70220 70220 2 60219
45067 45067 2 35066
67903 67903 2 57902
94304 94304 2 84303
20823 20823 2 10822
50570 50570 2 40569
88735 88735 2 78734
9973 9973 1 9972
44578 44578 2 34577
34320 34320 2 24319
8262 8262 1 8261
32532 32532 2 22531
85102 85102 2 75101
87511 87511 2 77510
63375 63375 2 53374
96612 96612 2 86611
52917 52917 2 42916
91127 91127 2 81126
84152 8

In [25]:
print("Polling Stratum")
print("sampled ballot, original ballot label, batch label, which ballot in batch")
i = 0
for s in sample2:
    i += 1
    original_ballot_label, batch_label, which_ballot = find_ballot(s, \
                                                                   unique_poll_manifest, \
                                                                   poll_manifest_parsed)
    print(i, s, batch_label, which_ballot) # This uses 0-indexing still. Should we change it be 1,...,n?

Polling Stratum
sampled ballot, original ballot label, batch label, which ballot in batch
1 1783 2 782
2 4275 2 3274
3 4915 2 3914


# Enter the sample data

Sample statistics for the CVR stratum (stratum 1)

In [26]:
# Number of observed...

o1 = 1 # 1-vote overstatements
o2 = 0 # 2-vote overstatements
u1 = 0 # 1-vote understatements
u2 = 0 # 2-vote understatements

Sample statistics for the no-CVR stratum (stratum 2)

In [27]:
# Number of votes for each candidate
# recall that in the provided example, n2=3 so the totals here must add up to <= 3.

# no-CVR sample is stored in a dict with name, votes in the sample
observed_poll = { "candidate 3": 1,
               "candidate 2": 2,
               "candidate 1": 0,
               "candidate 4": 0}

# Should more ballots be audited?

In [28]:
# Find audit p-values across (winner, loser) pairs

audit_pvalues = {}

for k in product(winners, losers):
    N_w1 = candidates[k[0]][0]
    N_w2 = candidates[k[0]][1]
    N_l1 = candidates[k[1]][0]
    N_l2 = candidates[k[1]][1]
    reported_margin = (N_w1+N_w2)-(N_l1+N_l2)
    cvr_pvalue = lambda alloc: ballot_comparison_pvalue(n=n1, gamma=1.03905, o1=o1, 
                                                    u1=u1, o2=o2, u2=u2, 
                                                    reported_margin=reported_margin, 
                                                    N=stratum_sizes[0], 
                                                    null_lambda=alloc)
    
    n2w = observed_poll[k[0]]
    n2l = observed_poll[k[1]]
    nocvr_pvalue = lambda alloc: ballot_polling_sprt(\
                            sample= np.array([0]*n2l+[1]*n2w+[np.nan]*(n2-n2w-n2l)), \
                            popsize=stratum_sizes[1], \
                            alpha=0.05,  # set this param but we don't need to use it
                            Vw=N_w2, Vl=N_l2, \
                            null_margin=(N_w2-N_l2) - alloc*reported_margin)['pvalue']
    # Crude maximizer for now
    res = bound_fisher_fun(N_w1, N_l1, stratum_sizes[0], N_w2, N_l2, stratum_sizes[1],
                       pvalue_funs=(cvr_pvalue, nocvr_pvalue), stepsize=lambda_step,\
                       feasible_lambda_range=None)
    audit_pvalues[k] = np.max(res['upper_bounds'])
                           
audit_pvalues

{('candidate 2', 'candidate 1'): 3.662272354854057e-05,
 ('candidate 2', 'candidate 4'): 5.15130637146477e-07,
 ('candidate 3', 'candidate 1'): 0.02115848454394742,
 ('candidate 3', 'candidate 4'): 0.0008500403360027775}

# Escalation guidance: how many more ballots should be drawn?

In [29]:
def estimate_escalation_n(N_w1, N_w2, N_l1, N_l2, N1, N2, n1, n2, \
                          o1_obs, o2_obs, u1_obs, u2_obs, \
                          n2l_obs, n2w_obs, \
                          o1_rate=0, o2_rate=0, u1_rate=0, u2_rate=0, \
                          n_ratio=None,
                          risk_limit=0.05,\
                          gamma=1.03905,\
                          stepsize=0.05):
    '''
    Estimate the initial sample sizes for the audit.
    
    Parameters
    ----------
    N_w1 : int
        votes for the reported winner in the ballot comparison stratum
    N_w2 : int
        votes for the reported winner in the ballot polling stratum
    N_l1 : int
        votes for the reported loser in the ballot comparison stratum
    N_l2 : int
        votes for the reported loser in the ballot polling stratum
    N1 : int
        total number of votes in the ballot comparison stratum
    N2 : int
        total number of votes in the ballot polling stratum
    n1 : int
        size of sample already drawn in the ballot comparison stratum
    n2 : int
        size of sample already drawn in the ballot polling stratum
    o1_obs : int
        observed number of ballots with 1-vote overstatements in the CVR stratum
    o2_obs : int
        observed number of ballots with 2-vote overstatements in the CVR stratum
    u1_obs : int
        observed number of ballots with 1-vote understatements in the CVR stratum
    u2_obs : int
        observed number of ballots with 2-vote understatements in the CVR stratum
    n2l_obs : int
        observed number of votes for the reported loser in the no-CVR stratum
    n2w_obs : int
        observed number of votes for the reported winner in the no-CVR stratum
    o1_rate : float
        expected percent of ballots with 1-vote overstatements in the CVR stratum
    o2_rate : float
        expected percent of ballots with 2-vote overstatements in the CVR stratum
    u1_rate : float
        expected percent of ballots with 1-vote understatements in the CVR stratum
    u2_rate : float
        expected percent of ballots with 2-vote understatements in the CVR stratum
    n_ratio : float
        ratio of sample allocated to each stratum.
        If None, allocate sample in proportion to ballots cast in each stratum
    risk_limit : float
        risk limit
    gamma : float
        gamma from Lindeman and Stark (2012)
    stepsize : float
        stepsize for the discrete bounds on Fisher's combining function
    Returns
    -------
    tuple : estimated initial sample sizes in the CVR stratum and no-CVR stratum
    '''
    n_ratio = n_ratio if n_ratio else N1/(N1+N2)
    n = n1+n2
    reported_margin = (N_w1+N_w2)-(N_l1+N_l2)
    expected_pvalue = 1
    
    n1_original = n1
    n2_original = n2
    observed_nocvr_sample = [0]*n2l_obs + [1]*n2w_obs + [np.nan]*(n2_original-n2l_obs-n2w_obs)
    
    def try_n(n):
        n1 = math.ceil(n_ratio * n)
        n2 = n - n1
        o1 = math.ceil(o1_rate*(n1-n1_original)) + o1_obs
        o2 = math.ceil(o2_rate*(n1-n1_original)) + o2_obs
        u1 = math.floor(u1_rate*(n1-n1_original)) + u1_obs
        u2 = math.floor(u2_rate*(n1-n1_original)) + u2_obs

        cvr_pvalue = lambda alloc: ballot_comparison_pvalue(n=n1, gamma=1.03905, o1=o1, 
                                                    u1=u1, o2=o2, u2=u2, 
                                                    reported_margin=reported_margin, N=N1, 
                                                    null_lambda=alloc)
        expected_new_sample = [0]*int((n2-n2_original)*N_l2/N2)+\
                                             [1]*int((n2-n2_original)*N_w2/N2)+\
                                             [np.nan]*int((n2-n2_original)*(N2-N_l2-N_w2)/N2)
        nocvr_pvalue = lambda alloc: ballot_polling_sprt( \
                            sample=np.array(observed_nocvr_sample+expected_new_sample),\
                            popsize=N2, \
                            alpha=risk_limit,\
                            Vw=N_w2, Vl=N_l2, null_margin=(N_w2-N_l2) - alloc*reported_margin)['pvalue']
        # Crude maximizer for now 
        # TO DO: this isn't rigorous yet--needs to be fixed
        res = bound_fisher_fun(N_w1, N_l1, N1, N_w2, N_l2, N2,
                       pvalue_funs=(cvr_pvalue, nocvr_pvalue),\
                       stepsize=stepsize, feasible_lambda_range=None)
        expected_pvalue = np.amax(res['upper_bounds'])
        if (n % 10000)==0:
            print('...trying...', n, expected_pvalue)
        return(expected_pvalue)
    
    # step 1: linear search, increasing n by a factor of 1.1 each time
    while (expected_pvalue > risk_limit) or (expected_pvalue is np.nan):
        n = np.ceil(1.1*n)
        expected_pvalue = try_n(n)
        
    # step 2: bisection between n/1.1 and n
    low_n = n/1.1
    high_n = n
    mid_pvalue = 1
    # TODO: should there be a tolerance here? e.g. risk_limit - 0.005 
    while  (mid_pvalue > risk_limit) or (expected_pvalue is np.nan):
        mid_n = np.floor((low_n+high_n)/2)
        mid_pvalue = try_n(mid_n)
        if mid_pvalue <= risk_limit:
            high_n = mid_n
        else:
            low_n = mid_n
    
    n1 = math.ceil(n_ratio * mid_n)
    n2 = int(mid_n - n1)
    return (n1, n2)

In [30]:
# CVR stratum
n1 = 70
n2 = 3
o1 = 1
u1 = 0
o2 = 0
u2 = 0
n2l = 1
n2w = 1

sample_sizes_new = {}

for k in product(winners, losers):
    sample_sizes_new[k] = estimate_escalation_n(\
                                 N_w1 = candidates[k[0]][0],\
                                 N_w2 = candidates[k[0]][1],\
                                 N_l1 = candidates[k[1]][0],\
                                 N_l2 = candidates[k[1]][1],\
                                 N1 = stratum_sizes[0],\
                                 N2 = stratum_sizes[1],\
                                 n1 = n1,\
                                 n2 = n2,\
                                 o1_obs = o1,\
                                 o2_obs = o2,\
                                 u1_obs = u1,\
                                 u2_obs = u2,\
                                 n2l_obs = n2l,\
                                 n2w_obs = n2w,\
                                 o1_rate = o1_rate,\
                                 o2_rate = o2_rate,\
                                 u1_rate = u1_rate,\
                                 u2_rate = u2_rate,\
                                 n_ratio = n_ratio,\
                                 risk_limit = risk_limit,\
                                 gamma = gamma,\
                                 stepsize = lambda_step)

In [31]:
## TODO: Check that we like this sort of output

sample_size_new = np.amax([v[0]+v[1] for v in sample_sizes_new.values()])
n1_new = np.amax([v[0] for v in sample_sizes_new.values()])
n2_new = np.amax([v[1] for v in sample_sizes_new.values()])

print(sample_sizes_new, '\n\nExpected minimum sample size:', sample_size_new)
print("\nBallots to draw in the CVR stratum:", n1_new - n1)
print("Ballots to draw in the no-CVR stratum:", n2_new - n2)

{('candidate 3', 'candidate 1'): (76, 3), ('candidate 3', 'candidate 4'): (74, 3), ('candidate 2', 'candidate 4'): (74, 3), ('candidate 2', 'candidate 1'): (74, 3)} 

Expected minimum sample size: 79

Ballots to draw in the CVR stratum: 6
Ballots to draw in the no-CVR stratum: 0
