In [2]:
import numpy as np
import pandas as pd
import math

import src.hdmm.workload as workload
import src.census_workloads as census
from src.workload_selection import workload_selection
import online_workloads as online_workloads

In [225]:
def pmw_naive(workload, x, analyst_labels, T, eps=0.01, total_k=None, 
         show_messages=False, to_return='error', show_plot=False, show_failure_step=False, eta = None):
    """
    Implement Private Multiplicative Weights Mechanism (PMW) on a workload of
    linear queries where analysts can run out of privacy budget if they use too much of others'. 
    
    In other words, all analysts share from the same privacy budget. 
    
    Last Updated: 4-10-2022

    Algorithm Parameters: 
    - workload = workload of queries (M x k numpy array)
    - x = true database (M x 1 numpy array)
    - T = update threshold
    - eps = privacy budget
    - total_k = total number of update steps alloted for the entire group
    - analyst_labels = list of analyst names corresponding to each query in the workload
    
    Output Controls: 
    - show_messages argument determines whether the function will print information such as 
    error scale, threshold, update steps used, etc.
    - to_return argument determines what the function will return. 
        - if 'pd', pmw() returns pandas df with test data for each 
        query in the workload(showing query, d_t_hat, updated, algo_ans, real_ans, 
        abs_error, rel_error). 
        - if 'update_count', pmw() returns the update count for the total
        amount of queries.
    - show_plot - T/F whether the function will display a plot
    - show_failure_step - T/F whether function prints what step failure mode is reached
    """ 
    
    # initialize constants
    m = x.size  # database len
    n = x.sum()
    if(eta == None):
        eta = (math.log(m, np.e) / ((math.sqrt(n))) )
    #print(eta)
    delta = 1 / (n * math.log(n, np.e))
    x_norm = x / np.sum(x)
    
    # initialize synthetic databases at time 0 (prior to any queries)
    x_t = np.ones(m) / m
    y_t = np.ones(m) / m

    # initialize tracker lists to construct pandas dataframe at the end 
    x_list = [x_t] # create a list of x_t synthetic database at every time step
    update_list = []
    update_count = 0
    pmw_answers = []
    update_times = [] # record times that database is updated
    d_t_hat_list = []
    
    # initialize total_k, the total number of update steps if not default
    if total_k == None:
        total_k = round(n * math.log(math.sqrt(m))/770)
        print(f'{total_k=}')
    
    def lazy_round():
        """
        "Lazy Round" of querying using the stored synthetic database, x_t, in list x_list.
        
        We call this the lazy round because it is contrasted with the updated step where we update the 
        sythetic database and answer the query using the real database.
        """
        update_list.append('no')
        answer = np.dot(query, x_list[time])
        if answer < 0:
            pmw_answers.append(0)
        else: 
            pmw_answers.append(answer)
        x_list.append(x_list[time].round(3))
    
    # inititate first instance of SVT with half the budget and k updates; will be reset in the main loop
    SVTtrigger = False 
    SVTepsilon1 = ((eps/2)/2)
    SVTepsilon2 = ((eps/2)/2)
    rho = np.random.laplace(loc=0, scale=(1/SVTepsilon1), size=1)[0]
    #print(rho + T)
    
    
    for time, query in enumerate(workload):
        
        analyst = analyst_labels[time]
        
        # Do one round of sparse vector technique; compute noisy answer by adding Laplacian noise
        A_t = np.random.laplace(loc=0, scale=(total_k/SVTepsilon2), size=1)[0]
        a_t_hat = (np.dot(query, x_norm)*n ) + A_t
        d_t_hat = a_t_hat - (n*np.dot(query, x_list[time]))
        
        # LAZY ROUND: QUERY USING THE SYNTHETIC DATABASE
        if (abs(d_t_hat) <= T + rho):
            d_t_hat_list.append(d_t_hat)
            lazy_round()

        # UPDATE ROUND: UPDATE SYNTHETIC DATABASE AND RETURN NOISY ANSWER, A_T-HAT
        else:
            # noise
            A_t = np.random.laplace(loc=0, scale=(2*total_k/eps), size=1)[0]
            
            # noisy answer
            a_t_hat = (np.dot(query, x_norm)*n ) + A_t
            d_t_hat = a_t_hat - (n*np.dot(query, x_list[time]))
            d_t_hat_list.append(d_t_hat)
            update_times.append(time)
            
            # step a
            if d_t_hat < 0:
                r_t = query
            else:
                r_t = np.ones(m) - query
            for i in range(len(y_t)):
                y_t[i] = x_list[time][i] * math.exp(-( eta * r_t[i]))# eta is the learning rate
            
            # step b
            x_t = y_t / np.sum(y_t)
            update_count = update_list.count('yes')
            
            # if threshold for num updates is reached, just do a lazy round (synthetic database) answer
            if total_k == 0: 
                if show_failure_step:
                    print(f'Failure mode reached at query number {time}: {query}')
                lazy_round()
                
            # if there are still update steps that the analyst can use, 
            # 1. update the synthetic database
            # 2. answer the query using the noisy answer from the database itself 
            else: 
                x_list.append(x_t.round(3))
                update_list.append('yes') # increment number of updates counter
                answer = a_t_hat / np.sum(x)
                
                if answer < 0:
                    pmw_answers.append(0)
                else: 
                    pmw_answers.append(answer)
                
                total_k -= 1 # use one of the total update steps
        
        #print(f'{x_list[time] - x_list[time - 1]=}')
        
        
    update_count = update_list.count('yes')      

    # calculate error
    real_ans = np.matmul(workload, x_norm)
    abs_error = np.abs(pmw_answers - real_ans)
    rel_error = np.abs(abs_error / np.where(real_ans == 0, 0.000001,
                                                real_ans))
    
    if show_messages:
        np.set_printoptions(suppress=True)
        """Print inputes/outputs to analyze each query"""
        print(f'Original database: {x}\n')
        print(f'Normalized database: {x_norm}\n')
        print(f'Synthetic Database (before) = {x_list[0]}\n')
        print(f'Synthetic Database (after) = {x_list[len(x_list) - 1]}\n')
        print(f'Difference btw. Final Synthetic and true database = {x_list[len(x_list) - 1] - x_norm}\n')
        print(f'Update Count = {update_count}\n')
        print(f'{T=}\n')
        print(f'Error Scale Query Answer= {2*((2*total_k/eps)**2)}\n')
        print(f'Error Scale SVT= {2*((2*total_k/SVTepsilon2)**2)}\n')
        print(f'Update Parameter Scale = {eta}\n')
        print(f'{delta=}\n')
        
    if show_plot: 
        plt.title('Error across queries:')
        rel_line, = plt.plot(rel_error, label='Relative Error')
        abs_line, = plt.plot(abs_error, label='Absolute Error')
        for xc in update_times:
            plt.axvline(x=xc, color='red', label='Update Times', linestyle='dashed')
        plt.legend(handles=[abs_line,rel_line])
        plt.xticks(range(0, len(workload), round(len(workload)/5)))
    
    if to_return == "pd":
        # hacky fix: remove the first synthetic database to keep length of lists consistent with the
        # other lists that comprise of the pandas dataframe
        x_list.pop(0).tolist() 
        d = {
            'algo_ans': pmw_answers,
            'real_ans': real_ans.tolist(),
            'queries': workload.tolist(), 
            'updated': update_list,
            'abs_error': abs_error,               
            'rel_error': rel_error,
            'synthetic database': x_list,
            'analyst': analyst_labels,
            'd_t_hat': d_t_hat_list, 

             }
        test_data = pd.DataFrame(data=d)
        #test_data = test_data.round(3)
        return test_data
    
    if to_return == "error":
        d = {'analyst': analyst_labels,
             'abs_error': abs_error,               
             'rel_error': rel_error,}
        data = pd.DataFrame(data=d)
        data = data.round(3)
        
        analyst_error = {}
        for analyst in list(sorted(analyst_labels)):
            analyst_error[analyst] = data[data.analyst==analyst]['abs_error'].sum()
        return analyst_error
    
    if to_return == "tse":
        d = {'analyst': analyst_labels,
             'abs_error': abs_error,}
        data = pd.DataFrame(data=d)
        data['squared_err'] = data['abs_error'] ** 2
        
        analyst_error = {}
        for analyst in list(sorted(analyst_labels)):
            analyst_error[analyst] = data[data.analyst==analyst]['abs_error'].sum()
        return analyst_error

In [11]:
tenq = np.vstack((online_workloads.identity(5), online_workloads.identity(5)))
fiftyq = np.vstack((tenq, 
                   tenq, 
                   tenq, 
                   tenq, 
                   tenq))

In [12]:
# initialize databases

import pandas as pd
data_path = "migration_tworace.csv"
x_race = pd.read_csv(data_path, header=None).iloc[:, 1].to_numpy()
n = x_race.shape[0]
x_race

x_example = np.array([1000, 2000, 3000, 4000, 5000])
new_x = np.array([.1, .15, .2, .25, .3]) * 1500

In [7]:
ten_identity_q = np.vstack((online_workloads.identity(10), 
                            online_workloads.identity(10),
                            online_workloads.identity(10),
                            online_workloads.identity(10),
                            online_workloads.identity(10),
                            online_workloads.identity(10),
                            online_workloads.identity(10),
                            online_workloads.identity(10),
                            online_workloads.identity(10),
                            online_workloads.identity(10)))

hundred_identity_q = np.vstack((ten_identity_q, 
                               ten_identity_q, 
                               ten_identity_q, 
                               ten_identity_q, 
                               ten_identity_q, 
                               ten_identity_q, 
                               ten_identity_q, 
                               ten_identity_q, 
                               ten_identity_q, 
                               ten_identity_q))

print(f'{ten_identity_q.shape=}')
print(f'{hundred_identity_q.shape=}')

ten_identity_q.shape=(100, 10)
hundred_identity_q.shape=(1000, 10)


In [247]:
x_race_first_ten = x_race[:10]

# try identity on the race database
pmw_naive(hundred_identity_q, x_race_first_ten, ['A'] * 1000, eps=1, T=40, 
          show_messages=True, to_return = 'pd')#to_return='tse')

total_k=3
Original database: [412 333 285 231 202 174 160 142 146 149]

Normalized database: [0.18442256 0.14905998 0.12757386 0.10340197 0.09042077 0.0778872
 0.07162041 0.06356312 0.06535363 0.06669651]

Synthetic Database (before) = [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]

Synthetic Database (after) = [0.103 0.104 0.104 0.099 0.099 0.099 0.099 0.099 0.099 0.099]

Difference btw. Final Synthetic and true database = [-0.08142256 -0.04505998 -0.02357386 -0.00440197  0.00857923  0.0211128
  0.02737959  0.03543688  0.03364637  0.03230349]

Update Count = 3

T=40

Error Scale Query Answer= 0.0

Error Scale SVT= 0.0

Update Parameter Scale = 0.048716278470739886

delta=5.8046389258630686e-05



Unnamed: 0,algo_ans,real_ans,queries,updated,abs_error,rel_error,synthetic database,analyst,d_t_hat
0,0.183751,0.184423,"[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",yes,0.000672,0.003644,"[0.104, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1...",A,187.098815
1,0.147198,0.149060,"[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",yes,0.001862,0.012490,"[0.103, 0.104, 0.099, 0.099, 0.099, 0.099, 0.0...",A,105.440862
2,0.126888,0.127574,"[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",yes,0.000686,0.005375,"[0.103, 0.104, 0.104, 0.099, 0.099, 0.099, 0.0...",A,62.302205
3,0.099000,0.103402,"[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",no,0.004402,0.042571,"[0.103, 0.104, 0.104, 0.099, 0.099, 0.099, 0.0...",A,9.834000
4,0.099000,0.090421,"[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...",no,0.008579,0.094881,"[0.103, 0.104, 0.104, 0.099, 0.099, 0.099, 0.0...",A,-19.166000
...,...,...,...,...,...,...,...,...,...
995,0.099000,0.077887,"[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...",no,0.021113,0.271069,"[0.103, 0.104, 0.104, 0.099, 0.099, 0.099, 0.0...",A,-47.166000
996,0.099000,0.071620,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...",no,0.027380,0.382288,"[0.103, 0.104, 0.104, 0.099, 0.099, 0.099, 0.0...",A,-61.166000
997,0.099000,0.063563,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ...",no,0.035437,0.557507,"[0.103, 0.104, 0.104, 0.099, 0.099, 0.099, 0.0...",A,-79.166000
998,0.099000,0.065354,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...",no,0.033646,0.514836,"[0.103, 0.104, 0.104, 0.099, 0.099, 0.099, 0.0...",A,-75.166000


In [213]:
n=64
W_name = ['identity', 'total', 'H2', 'race1', 'race2', 'race3', 'custom', 'prefix_sum']
W_lst = [online_workloads.identity(n), online_workloads.total(n), online_workloads.H2(n), online_workloads.race1(), online_workloads.race2(), online_workloads.race3(), online_workloads.custom(n), online_workloads.prefix_sum(n)]

c = np.random.randint(len(W_lst))
final_W = W_lst[c] 

t = 10
for i in range(t): 
    c = np.random.randint(len(W_lst))
    print(c)
    final_W = np.vstack((final_W, W_lst[c]))
    
print(len(final_W))

final_W

6
1
7
0
5
4
5
6
0
6
470


array([[0.        , 1.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 1.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [1.75807666, 2.34611181, 1.06233559, ..., 0.1837023 , 1.26802981,
        1.40702366]])

In [224]:
pmw_naive(final_W, x_race[:64], ['A'] * len(final_W), eps=1, T=40, total_k = 100, 
          show_messages=True, to_return='error')

Original database: [412 333 285 231 202 174 160 142 146 149 145 181 174 190 213 287 372 499
 619 715 785 821 822 816 799 742 717 697 658 593 564 519 447 403 388 365
 336 306 311 289 261 231 213 196 194 170 175 168 149 142 131 119 112 118
 114 116 112 114 106 111 109 112 113 109]

Normalized database: [0.0200956  0.01624232 0.01390108 0.01126719 0.0098527  0.00848698
 0.00780412 0.00692615 0.00712126 0.00726758 0.00707248 0.00882841
 0.00848698 0.00926739 0.01038923 0.01399863 0.01814457 0.02433909
 0.03019218 0.03487465 0.03828895 0.04004487 0.04009365 0.039801
 0.03897181 0.03619159 0.0349722  0.03399668 0.03209443 0.02892401
 0.02750951 0.0253146  0.02180275 0.01965662 0.01892498 0.01780314
 0.01638865 0.01492537 0.01516925 0.01409619 0.01273047 0.01126719
 0.01038923 0.00956004 0.00946249 0.00829187 0.00853575 0.00819432
 0.00726758 0.00692615 0.00638962 0.00580431 0.00546288 0.00575554
 0.00556043 0.00565798 0.00546288 0.00556043 0.00517023 0.00541411
 0.00531655 0.00546288 0.00551

{'A': 14.629000000000001}