In [1]:
import pandas as pd
import numpy as np
import math
from matplotlib import pyplot as plt
from src.hdmm import workload, fairtemplates, error, fairmechanism, matrix, mechanism, templates

In [188]:
def pmw(workload, x, eps=10, beta=0.1, k=0, show_messages=True, to_return='pd', ):
    """
    Run the Private Multiplicative Weights mechanism (PMW) on a workload of
    linear queries, answering them one at a time against a normalized
    histogram.

    Each round draws a Laplace-noised answer on the true (normalized) data
    and compares it to the answer given by the maintained synthetic
    histogram. If the two agree to within `threshold` the round is "lazy" and
    the synthetic answer is released; otherwise the histogram is updated
    multiplicatively and the noisy answer is released.

    Parameters
    ----------
    workload : array-like, shape (num_queries, M)
        One linear query per row.
    x : numpy array, shape (M,)
        True database as a histogram of counts.
    eps : float
        Privacy budget; the Laplace scale `sigma` is proportional to 1 / eps.
    beta : float
        Failure probability; enters the threshold through log(1 / beta).
    k : int
        Expected number of queries; enters the threshold through log(k).
        The default 0 means "use len(workload)".
    show_messages : bool
        If True, print a summary and plot the per-query errors.
    to_return : str
        - 'pd': return a pandas DataFrame with one row per query and the
          columns queries, d_t_hat, updated, algo_ans, real_ans, abs_error,
          rel_error.
        - 'update_count': return the number of update rounds.
        Any other value makes the function return None.

    Returns
    -------
    pandas.DataFrame, int, None, or the string "failure" when the number of
    updates exceeds n * sqrt(log10(M)).
    """
    workload = np.asarray(workload)

    # initialize constants
    m = x.size  # database len
    n = x.sum()  # number of individuals
    if k == 0:  # k left at its default: fall back to the workload length
        k = len(workload)
    delta = 1 / (n * math.log(n))
    x_norm = x / n
    eta = math.log(m) ** (1 / 4) / math.sqrt(n)
    sigma = 10 * math.log(1 / delta) * math.log(m) ** (1 / 4) / (
            math.sqrt(n) * eps)
    threshold = 4 * sigma * (math.log(k) + math.log(1 / beta))

    # synthetic histogram at time 0 (prior to any queries): uniform
    x_t = np.ones(m) / m

    update_list = []
    update_count = 0
    algo_answers = []
    update_times = []
    d_t_hat_list = []

    for time, query in enumerate(workload):

        # compute noisy answer by adding Laplacian noise
        a_t = np.random.laplace(loc=0, scale=sigma, size=1)[0]
        a_t_hat = np.dot(query, x_norm) + a_t

        # difference between noisy and maintained histogram answer
        d_t_hat = a_t_hat - np.dot(query, x_t)
        d_t_hat_list.append(d_t_hat)

        # lazy round: use maintained histogram to answer the query
        if abs(d_t_hat) <= threshold:
            algo_answers.append(np.dot(query, x_t))
            update_list.append('no')
            continue

        # update round: update histogram and return noisy answer
        update_list.append('yes')
        update_times.append(time)
        update_count += 1

        # step a: penalise the entries that pushed the answer the wrong way.
        # A fresh array is built each round so earlier histograms are never
        # overwritten through aliasing.
        r_t = query if d_t_hat > 0 else np.ones(m) - query
        y_t = x_t * np.exp(-eta * r_t)

        # step b: renormalise so the histogram sums to 1
        x_t = y_t / np.sum(y_t)

        if update_count > n * math.log(m, 10) ** (1 / 2):
            return "failure"

        # a_t_hat was computed on x_norm, so it is already on the same
        # normalised scale as real_ans; it must not be divided by n again.
        algo_answers.append(a_t_hat)

    # calculate error
    real_ans = np.matmul(workload, x_norm)
    abs_error = np.abs(np.asarray(algo_answers) - real_ans)
    # relative error = absolute error / |true answer|; a tiny denominator
    # stands in for true answers that are exactly 0
    rel_error = abs_error / np.where(real_ans == 0, 0.000001,
                                     np.abs(real_ans))

    def print_outputs():
        """Print inputs/outputs to analyze the run"""
        print(f'Original database: {x}\n')
        print(f'Normalized database: {x_norm}\n')
        print(f'Updated Database = {x_t}\n')
        print(f'Update Count = {update_count}\n')
        print(f'{threshold=}\n')

    def plot_error():
        """Plot absolute and relative error, marking update rounds"""
        plt.xticks(range(0, k, 5))
        plt.title('Error across queries:')
        rel_line, = plt.plot(rel_error, label='Relative Error')
        abs_line, = plt.plot(abs_error, label='Absolute Error')
        for xc in update_times:
            plt.axvline(x=xc, color='red', label='Update Times', linestyle='dashed')
        plt.legend(handles=[rel_line, abs_line])

    if show_messages:
        print_outputs()
        plot_error()

    if to_return == "update_count":
        return update_count

    if to_return == "pd":
        d = {
            'queries': workload.tolist(),
            'd_t_hat': d_t_hat_list,
            'updated': update_list,
            'algo_ans': algo_answers,
            'real_ans': real_ans.tolist(),
            'abs_error': abs_error,
            'rel_error': rel_error,
             }
        test_data = pd.DataFrame(data=d)
        return test_data

In [199]:
# Example histogram with two large peaks; its entries sum to 25420.
x_peaks = np.array([1000, 8000, 1300, 1250, 9000, 1450, 1700, 1720])
# Dense all-range workload over the same domain. The domain size is taken
# from the histogram so the query length always matches len(x_peaks).
W_allrange = workload.AllRange(x_peaks.size).dense_matrix()

25420

In [198]:
# Evaluate the update-count bound that pmw checks before returning
# "failure", for the x_peaks histogram.
m = len(x_peaks)  # database len
n = np.sum(x_peaks)  # total count
n * math.log(m, 10) ** 0.5

24156.89211530533

Here are some preliminary results: 

The following results show how the number of updates changes as the privacy budget (epsilon) is varied, mostly in powers of 10. 1000 experiments were run at each privacy budget. The workload was an all-range workload with 36 queries in total. The database comprised 25420 individuals in total: [1000, 8000, 1300, 1250, 9000, 1450, 1700, 1720]. 

As epsilon increases, so does the number of times the algorithm updates. 

In [181]:
# Filled below with one average update count per entry of epsilons_to_try,
# in the same order.
average_update_count_list = []
# Privacy budgets (eps) to sweep over.
epsilons_to_try = [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 250, 500, 750, 1000, 10000, 100000]

def update_count_for_eps(set_eps):
    """
    Average the pmw update count over repeated runs at one privacy budget.

    Runs pmw on W_allrange / x_peaks num_exp times with eps=set_eps, prints
    the mean number of updates and returns it.
    """
    num_exp = 1000
    run_counts = (
        pmw(workload=W_allrange,
            x=x_peaks,
            eps=set_eps,
            show_messages=False,
            to_return='update_count')
        for _ in range(num_exp)
    )
    total_update_count = sum(run_counts)

    average_update_count = total_update_count / num_exp
    print(f'average update_count for privacy budget of {set_eps} is \t {average_update_count}')
    return average_update_count
    
# Sweep every privacy budget in order, collecting the average update counts.
average_update_count_list.extend(
    update_count_for_eps(set_eps=eps_value) for eps_value in epsilons_to_try
)

print(average_update_count_list)

average update_count for privacy budget of 0.0001 is 	 0.0
average update_count for privacy budget of 0.001 is 	 0.0
average update_count for privacy budget of 0.01 is 	 0.0
average update_count for privacy budget of 0.1 is 	 0.0
average update_count for privacy budget of 1 is 	 0.0
average update_count for privacy budget of 10 is 	 0.0
average update_count for privacy budget of 100 is 	 1.892
average update_count for privacy budget of 250 is 	 18.024
average update_count for privacy budget of 500 is 	 28.983
average update_count for privacy budget of 750 is 	 32.895
average update_count for privacy budget of 1000 is 	 34.001
average update_count for privacy budget of 10000 is 	 35.0
average update_count for privacy budget of 100000 is 	 35.0
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.892, 18.024, 28.983, 32.895, 34.001, 35.0, 35.0]


In [191]:
# Reset the results list; filled below with one average update count per
# entry of betas_to_try, in the same order.
average_update_count_list = []
# Failure probabilities (beta) to sweep over.
betas_to_try = [0.001, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

def update_count_for_beta(set_beta, set_eps):
    """
    Average the pmw update count over repeated runs at one beta value.

    Runs pmw on W_allrange / x_peaks num_exp times with beta=set_beta and
    eps=set_eps, prints the mean number of updates and returns it.
    """
    # Defined locally: the loop bound and the divisor must be the same
    # number, and no module-level num_exp exists on a fresh kernel.
    num_exp = 1000
    total_update_count = 0
    for _ in range(num_exp):
        total_update_count += pmw(workload=W_allrange,
                                  x=x_peaks,
                                  eps=set_eps,
                                  beta=set_beta,
                                  show_messages=False,
                                  to_return='update_count')

    average_update_count = total_update_count / num_exp
    print(f'average update_count for beta value of {set_beta} is \t {average_update_count}')
    return average_update_count
    
# Sweep every beta in order at a fixed privacy budget of 100.
average_update_count_list.extend(
    update_count_for_beta(set_beta=beta_value, set_eps=100)
    for beta_value in betas_to_try
)

print(average_update_count_list)

average update_count for beta value of 0.001 is 	 0.0
average update_count for beta value of 0.01 is 	 0.006
average update_count for beta value of 0.1 is 	 1.933
average update_count for beta value of 0.2 is 	 3.417
average update_count for beta value of 0.3 is 	 5.399
average update_count for beta value of 0.4 is 	 6.258
average update_count for beta value of 0.5 is 	 7.013
average update_count for beta value of 0.6 is 	 7.861
average update_count for beta value of 0.7 is 	 8.695
average update_count for beta value of 0.8 is 	 9.329
average update_count for beta value of 0.9 is 	 9.818
average update_count for beta value of 1 is 	 10.235
[0.0, 0.006, 1.933, 3.417, 5.399, 6.258, 7.013, 7.861, 8.695, 9.329, 9.818, 10.235]


In [197]:
# Reset the results list; filled below with one average update count per
# entry of k_to_try, in the same order.
average_update_count_list = []
# Values of k (expected number of queries, passed to pmw) to sweep over.
k_to_try = [1, 5, 10, 15, 20, 25, 30, 36, 100]

def update_count_for_k(set_beta, set_eps, set_k):
    """
    Average the pmw update count over repeated runs at one value of k.

    Runs pmw on W_allrange / x_peaks num_exp times with beta=set_beta,
    eps=set_eps and k=set_k, prints the mean number of updates and returns
    it.
    """
    # Defined locally: the loop bound and the divisor must be the same
    # number, and no module-level num_exp exists on a fresh kernel.
    num_exp = 1000
    total_update_count = 0
    for _ in range(num_exp):
        total_update_count += pmw(workload=W_allrange,
                                  x=x_peaks,
                                  eps=set_eps,
                                  beta=set_beta,
                                  k=set_k,
                                  show_messages=False,
                                  to_return='update_count')

    average_update_count = total_update_count / num_exp
    print(f'average update_count for k value of {set_k} is \t {average_update_count}')
    return average_update_count
    
# Sweep every k in order at fixed beta = 0.2 and privacy budget 100.
average_update_count_list.extend(
    update_count_for_k(set_beta=0.2, set_eps=100, set_k=k_value)
    for k_value in k_to_try
)

print(average_update_count_list)

average update_count for k value of 1 is 	 25.382
average update_count for k value of 5 is 	 12.233
average update_count for k value of 10 is 	 8.833
average update_count for k value of 15 is 	 6.838
average update_count for k value of 20 is 	 5.946
average update_count for k value of 25 is 	 5.191
average update_count for k value of 30 is 	 4.245
average update_count for k value of 36 is 	 3.402
average update_count for k value of 100 is 	 1.372
[25.382, 12.233, 8.833, 6.838, 5.946, 5.191, 4.245, 3.402, 1.372]


- What I would suggest is fixing some query set,
- then plotting the number of update steps as you vary some values,
- like the privacy budget (0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000), the failure probability beta, the expected number of queries k, and the number of individuals n.