In [1]:
import pandas as pd
import numpy as np
import math
from matplotlib import pyplot as plt
from src.hdmm import workload, fairtemplates, error, fairmechanism, matrix, mechanism, templates

In [2]:
def pmw(workload, x, eps=0.01, beta=0.1, k=0, show_messages=True, to_return='pd', analyst_labels=None):
    """
    Run the Private Multiplicative Weights (PMW) mechanism on a workload of
    linear queries and collect per-query diagnostics.

    Each query is answered either from the maintained synthetic histogram
    (a "lazy" round) or with a freshly noised answer that is also used to
    update the histogram (an "update" round). Once the number of updates
    exceeds ``n * sqrt(ln m)`` the mechanism enters failure mode and every
    remaining query is answered lazily.

    Parameters
    ----------
    workload : numpy array, shape (num_queries, m)
        One linear query per row. NOTE: inside this function the name shadows
        the ``workload`` module imported at the top of the notebook.
    x : numpy array, length m
        True database (histogram of counts).
    eps : float
        Privacy parameter; the Laplace noise scale is inversely proportional
        to it.
    beta : float
        Enters only through ``ln(1 / beta)`` in the update threshold.
    k : int
        Number of queries assumed when computing the update threshold (and
        the x-axis ticks of the error plot). ``0`` means "use len(workload)".
    show_messages : bool
        If True, print a summary, announce failure mode, and plot the errors.
    to_return : str
        - 'pd': return a pandas DataFrame with one row per query (columns:
          queries, d_t_hat, updated, algo_ans, real_ans, abs_error,
          rel_error, synthetic database, analyst).
        - 'update_count': return the number of update rounds (int).
        Any other value makes the function return None.
    analyst_labels : sequence or None
        One label per query, stored in the 'analyst' column. ``None`` or an
        empty sequence fills the column with None.

    Raises
    ------
    ValueError
        If ``to_return == 'pd'`` and a non-empty ``analyst_labels`` does not
        have exactly one entry per query.
    """

    # --- constants -------------------------------------------------------
    m = x.size       # number of histogram bins
    n = x.sum()      # total count in the true database
    if k == 0:       # default: one threshold term per query in the workload
        k = len(workload)
    log_m = math.log(m)
    delta = 1 / (n * math.log(n))
    x_norm = x / n
    eta = (log_m ** (1 / 4)) / math.sqrt(n)
    sigma = 10 * math.log(1 / delta) * (log_m ** (1 / 4)) / (math.sqrt(n) * eps)
    threshold = 4 * sigma * (math.log(k) + math.log(1 / beta))
    max_updates = n * log_m ** (1 / 2)  # loop-invariant update budget

    # --- state -----------------------------------------------------------
    # x_list[t] is the synthetic database used to answer query t; it starts
    # uniform and gains one entry per processed query.
    x_list = [np.ones(m) / m]
    update_list = []     # 'yes' / 'no' per query
    update_count = 0     # running number of 'yes' rounds
    pmw_answers = []
    update_times = []    # indices of the rounds that really updated
    d_t_hat_list = []
    failure_mode = False  # once True, only lazy rounds are executed

    for time, query in enumerate(workload):
        current = x_list[time]

        # noisy answer on the true data (scaled to counts) ...
        a_t = np.random.laplace(loc=0, scale=sigma, size=1)[0]
        a_t_hat = (np.dot(query, x_norm) * n) + a_t

        # ... and its difference from the maintained histogram's answer
        d_t_hat = a_t_hat - (n * np.dot(query, current))
        d_t_hat_list.append(d_t_hat)

        lazy = abs(d_t_hat) <= threshold or failure_mode
        if not lazy and update_count > max_updates:
            # update budget exhausted: switch to failure mode for good and
            # answer this very query lazily as well
            failure_mode = True
            if show_messages:
                print(f'Failure mode reached at query number {time}: {query}')
            lazy = True

        if lazy:
            update_list.append('no')
            pmw_answers.append(np.dot(query, current))
            x_list.append(current)
            continue

        # update round: multiplicative-weights step followed by normalisation
        r_t = query if d_t_hat < 0 else np.ones(m) - query
        y_t = current * np.exp(-eta * r_t)
        x_list.append(y_t / np.sum(y_t))

        update_times.append(time)
        update_list.append('yes')
        update_count += 1
        pmw_answers.append(a_t_hat / n)

    # --- error against the true normalised answers ------------------------
    real_ans = np.matmul(workload, x_norm)
    abs_error = np.abs(pmw_answers - real_ans)
    # avoid dividing by zero when the true answer is exactly 0
    rel_error = np.abs(abs_error / np.where(real_ans == 0, 0.000001, real_ans))

    final_database = x_list[-1]  # histogram after the last processed query

    def print_outputs():
        """Print inputs/outputs to analyze each query"""
        np.set_printoptions(suppress=True)
        print(f'Original database: {x}\n')
        print(f'Normalized database: {x_norm}\n')
        print(f'Updated Database = {final_database}\n')
        print(f'Update Count = {update_count}\n')
        print(f'{threshold=}\n')
        print(f'Error Scale = {sigma}\n')
        print(f'Update Parameter Scale = {eta}\n')
        print(f'{delta=}\n')

    def plot_error():
        """Plot absolute and relative error"""
        plt.xticks(range(0, k, 5))
        plt.title('Error across queries:')
        rel_line, = plt.plot(rel_error, label='Relative Error')
        abs_line, = plt.plot(abs_error, label='Absolute Error')
        for xc in update_times:
            plt.axvline(x=xc, color='red', label='Update Times', linestyle='dashed')
        plt.locator_params(axis='x', nbins=10)
        plt.legend(handles=[rel_line, abs_line])

    if show_messages:
        print_outputs()
        plot_error()

    if to_return == "update_count":
        return update_count

    if to_return == "pd":
        num_queries = len(update_list)
        if analyst_labels is None or len(analyst_labels) == 0:
            # no labels supplied: keep the column but leave it empty
            analyst_labels = [None] * num_queries
        elif len(analyst_labels) != num_queries:
            raise ValueError(
                f'analyst_labels has {len(analyst_labels)} entries but the '
                f'workload has {num_queries} queries'
            )
        d = {
            'queries': workload.tolist(),
            'd_t_hat': d_t_hat_list,
            'updated': update_list,
            'algo_ans': pmw_answers,
            'real_ans': real_ans.tolist(),
            'abs_error': abs_error,
            'rel_error': rel_error,
            # drop the initial uniform database so that row t holds the
            # synthetic database at the end of time t
            'synthetic database': x_list[1:],
            'analyst': list(analyst_labels),
        }
        return pd.DataFrame(data=d)

# Failure Scenario:
Alice and Bob are analysts who query a database, x_small, that is served by private multiplicative weights in an online setting. Alice submits a set of 500 queries that touch only the first half of the database (indices 0–3). Alice's queries come first and use up the entire update budget; in practice, this means Alice runs queries until the mechanism reaches failure mode. Bob then submits queries on the second half of the database (indices 4–7), which are answered from the synthetic database.
Parameters:
- x_small: [1, 8, 1, 1, 1, 9, 1, 1]
- Number of Alice's queries: 500, they only query from the first four slots in x_small
- Number of Bob's queries: 500, they only query from the last four slots in x_small
- individual Alice or Bob settings: eps=250, beta=0.2, k=500
- collective Alice and Bob settings: eps=500, beta=0.2, k=1000

Results:
- alice's total error individually: 7.286189769066673
- bob's total error individually: 9.116719651841597
- alice's total error in collective: 3.5804403438238794
- bob's total error in collective: 7.894955328194034

Conclusion:
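- In this specific experiment, the sharing-incentive and non-interference desiderata do not appear to be violated: both Alice's and Bob's average total error is lower in the collective run than in their individual runs.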

Next Steps:
- Try this on the new implementation of PMW. Epsilon values of 250 and 500 are absurdly high, but they were the values I needed to use to reach failure mode relatively consistently in the experiments where both Alice and Bob were submitting queries.

In [3]:
# Alternative (larger) database kept for reference:
# x_small = np.array([10, 80, 13, 12, 90, 14, 17, 17])
x_small = np.array([1, 8, 1, 1, 1, 9, 1, 1])

# database length and total count, as used by the failure threshold below
m = x_small.size
n = x_small.sum()

# show the database as a probability vector
normalized = x_small / x_small.sum()
print(normalized)

print(f'the threshold for failure is {n * math.log(m, np.e) ** (1 / 2)}')

[0.04347826 0.34782609 0.04347826 0.04347826 0.04347826 0.39130435
 0.04347826 0.04347826]
the threshold for failure is 33.16661839182031


In [4]:
# 500 all-zero half-queries, used to pad the half of the database an analyst never touches
zero_array = np.zeros((500, 4))
# 500 random 0/1 half-queries over four slots
random_array = np.random.randint(2, size=(500, 4))

# Alice queries only slots 0-3, Bob only slots 4-7 (both built from the same random bits)
alice = np.concatenate((random_array, zero_array), axis=1)
bob = np.concatenate((zero_array, random_array), axis=1)

In [28]:
# just alice: 1000 independent PMW runs on a fresh 500-query workload over slots 0-3
alice_labels = ['Alice' for _ in range(500)]

alice_error_list = []
for _trial in range(1000):
    # new random workload for every trial; the last four slots stay zero
    random_array = np.random.randint(2, size=(500, 4))
    alice = np.hstack((random_array, zero_array))
    alice_query_data = pmw(
        workload=alice,
        x=x_small,
        eps=250,
        beta=0.2,
        k=500,
        to_return='pd',
        analyst_labels=alice_labels,
        show_messages=False,
    )
    # total absolute error of this run
    alice_error_list.append(sum(alice_query_data['abs_error']))

# average of the per-run totals
print(f"alice's total error individually: {sum(alice_error_list) / len(alice_error_list)}")

alice's total error individually: 7.286189769066673


In [27]:
# just bob: 1000 independent PMW runs on a fresh 500-query workload over slots 4-7
bob_labels = ['Bob' for _ in range(500)]

bob_error_list = []
for _trial in range(1000):
    # new random workload for every trial; the first four slots stay zero
    random_array = np.random.randint(2, size=(500, 4))
    bob = np.hstack((zero_array, random_array))
    bob_query_data = pmw(
        workload=bob,
        x=x_small,
        eps=250,
        beta=0.2,
        k=500,
        to_return='pd',
        analyst_labels=bob_labels,
        show_messages=False,
    )
    # total absolute error of this run
    bob_error_list.append(sum(bob_query_data['abs_error']))

# average of the per-run totals
print(f"bob's total error individually: {sum(bob_error_list) / len(bob_error_list)}")

bob's total error individually: 9.116719651841597


In [25]:
# alice and bob combined: Alice's 500 queries are followed by Bob's 500 in a single PMW run
combined_alice_error_list, combined_bob_error_list = [], []

for _trial in range(1000):
    # one set of random bits per trial, placed in slots 0-3 for Alice and 4-7 for Bob
    random_array = np.random.randint(2, size=(500, 4))
    alice = np.hstack((random_array, zero_array))
    bob = np.hstack((zero_array, random_array))
    combined_query_data = pmw(
        workload=np.vstack((alice, bob)),
        x=x_small,
        eps=500,
        beta=0.2,
        k=1000,
        to_return='pd',
        analyst_labels=alice_labels + bob_labels,
        show_messages=False,
    )

    # split the per-query absolute errors by the analyst who asked the query
    alice_rows = combined_query_data[combined_query_data['analyst'] == 'Alice']
    bob_rows = combined_query_data[combined_query_data['analyst'] == 'Bob']
    combined_alice_error_list.append(sum(alice_rows['abs_error']))
    combined_bob_error_list.append(sum(bob_rows['abs_error']))

# averages of the per-run totals
print(f"alice's total error in collective: {sum(combined_alice_error_list) / len(combined_alice_error_list)}")
print(f"bob's total error in collective: {sum(combined_bob_error_list) / len(combined_bob_error_list)}")

alice's total error in collective: 3.5804403438238794
bob's total error in collective: 7.894955328194034


Failure Task: 
- [DO] Run test on new iteration of PMW
- [DO] Try to create an even more adversarial setting where Alice tries to take all the updates (i.e., give Alice more queries than Bob)
- Calculate error difference. 
- Create normal distribution workloads
- Use a smaller database so that the failure threshold is lower
- Output when failure happened, 
- no longer update from that point and just use synthetic database (only do lazy rounds)
- in order to make it only do lazy rounds, create variable that goes into failure_mode
- fix synthetic database so that it actually updates
- Create a parallel list of people - output the labelled person, i.e. "alice" into a list. output this as a column as well
- PMW with just Alice and just bob - Split k and epsilon in half