### Thresh algorithm

### TODO:
- Check estimation part (works well only for epoch=0)

In [1]:
import numpy as np

In [2]:
def Ber(p, size=1):
    return np.random.binomial(1,p, size=size)

In [3]:
class Users():
    def __init__(self, T, l, a, b, eps, tau, n):
        self.a = a
        self.b = b
        self.T = T
        self.eps = eps
        self.l = l
        self.cV = np.zeros((n,))
        self.cE = np.zeros((n,))
        self.tau = tau
        self.n = n
#         self.xi = np.random.randint(0,2, size=n*l).reshape(n, l)
        self.p_last = -1
        self.true_val = 0.5
        
    def local_estimate(self):
        self.last_est = self.true_val + np.random.randn(n) / 10000
        return self.last_est
#         return np.mean(self.xi[:,-self.l:], axis=1)
    
    def vote(self, t, debug=False):        
        log_T = int(np.floor(np.log(self.T)))
        diff = np.abs(self.local_estimate() - self.p_last)
        b_star = np.c_[[b*(diff > self.tau(b)) for b in range(-1, log_T+1)]].max(axis=0)
        
        print(f"b*={b_star.mean()}")
        VoteYes = (self.cV < self.eps/4) & np.logical_not(t % 2**(log_T - b_star))
        
        at = Ber(1 / (np.exp(self.a) + 1), size=self.n)
        at[VoteYes] = Ber(np.exp(self.a) / (np.exp(self.a) + 1), size=VoteYes.sum())
        if debug:
            return b_star
        return at
    
    def est(self):
        SendEstimate = self.cE < self.eps/4
        self.cE += self.b * SendEstimate
        
        p_t = Ber(1/(np.exp(self.b) + 1), size=self.n)
        p_t[SendEstimate] = Ber(
            (1 + self.local_estimate()*(np.exp(self.b)-1)) /\
            (np.exp(self.b) + 1), size=self.n
        )[SendEstimate]
        return p_t
    
    def fake_est(self, t):
        return Ber(1 / (np.exp(self.b) + 1), size=self.n)
    
    def update_p(self):
        self.p_last = self.last_est

class Aggregator():
    def __init__(self, eps, delta, num_epochs, epoch_length, num_users, min_subgroup_size):
        self.eps = eps
        self.delta = delta
        self.T = num_epochs
        self.l = epoch_length
        self.n = num_users
        self.L = min_subgroup_size
        self.m = self.n // self.L
        tmp0 = np.log(12*self.m*self.T / delta)
        self.a = 4 * np.sqrt(2 * self.n * tmp0) /\
                 (self.L - 3/np.sqrt(2)*np.sqrt(self.n * tmp0))
        # vote noise level
        tmp1 = np.log(12*self.T / delta) / 2
        self.b = np.sqrt(2 * tmp1/self.n) /\
                 (np.log(self.T)*np.sqrt(tmp1 / self.l) - np.sqrt(tmp1 / self.n))
        # estimate noise level
        
        self.last_pub = -1
        self.p_last = -1
        
        assert self.L > (3/np.sqrt(2) + np.sqrt(32)/self.eps)*np.sqrt(n*np.log(12*self.m*self.T / self.delta)), "Assumption 4.2"
        
    def init_users(self):
        def tau(b):
            return 2 * (b + 1) * np.sqrt(np.log(10 * self.n * self.T / self.delta) / (2*self.l))
        return Users(self.T, self.l, self.a, self.b, self.eps, tau, self.n)
    
    def epoch(self, t, users):
        ats = users.vote(t)
        
        print(np.mean(ats))
        
        val = (1/(np.exp(self.a) + 1) + np.sqrt(np.log(10*self.T/self.delta)/(2*self.n)))
        print(val)
        global_update = (
            np.mean(ats) > val
        )
        
        
        if global_update:
            print("Global update!")
            self.last_pub = t
            users.update_p()
            pts = users.est()
            self.p_last = np.mean( (pts*(np.exp(self.b)+1) - 1) / (np.exp(self.b) - 1) ) #not clear how this is supposed to be accurate
        else:
            pts = users.fake_est(t)
        return self.p_last

In [5]:
n = 10**6 # num users
T = 10**3 # num epochs
L = 10**5 # min subgroup size
m = 1e2 # num subgroups
l = 10**7  # epochs length
eps = 5 # privacy parameter
delta = 1e-6 # failure parameter

In [6]:
agg = Aggregator(eps, delta, T, l, n, L)

In [9]:
usrs = agg.init_users()
usrs.true_val

0.5

In [10]:
agg.epoch(0, usrs)

b*=6.0
0.579472
0.42406787050183253
Global update!


0.5001196452230695

In [11]:
usrs.true_val = 0.49999

In [12]:
agg.epoch(1, usrs)

b*=0.0
0.420463
0.42406787050183253


0.5001196452230695

In [13]:
usrs.true_val = 0.45

In [14]:
agg.epoch(2, usrs)

b*=6.0
0.580198
0.42406787050183253
Global update!


0.45078094635979704

In [15]:
usrs.true_val = 0.46

In [16]:
agg.epoch(3, usrs)

b*=2.0
0.419714
0.42406787050183253


0.45078094635979704