In [None]:
# coding: utf-8
import math
import numpy as np
import matplotlib.pyplot as plt

import itertools
from scipy.special import comb
from scipy import stats

class Subsetselection:
    """Subset-selection privatization mechanism over a finite alphabet.

    Every input symbol in ``range(absz)`` is randomly mapped to one of the
    ``d``-element subsets of the alphabet.  A subset is represented by a
    boolean indicator list of length ``absz`` and is reported by its index
    in ``outset``.  Subsets containing the true symbol are ``exp(pri_para)``
    times more likely than subsets that do not contain it.

    Attributes:
        insz: input alphabet size k.
        exp: ``math.exp(pri_para)``.
        d: number of True entries in every output indicator list.
        outset: all indicator lists, in ``itertools.combinations`` order.
        outsz: number of output symbols (``len(outset)``).
        channel: ``channel[x][y]`` is the probability of emitting output
            index ``y`` when the input symbol is ``x``.
    """

    def __init__(self, absz, pri_para):
        """Build the output alphabet and the channel matrix.

        Args:
            absz: alphabet size.
            pri_para: privacy parameter.
        """
        self.insz = absz  # input alphabet size k
        self.exp = math.exp(pri_para)
        # number of 1s in output bit string
        self.d = int(math.ceil(1.0 * self.insz / (self.exp + 1)))
        self.outset = self.kbits(self.insz, self.d)
        # Count the subsets that were actually enumerated instead of
        # truncating the floating-point result of comb(), so outsz can
        # never disagree with len(outset).
        self.outsz = len(self.outset)
        self.channel = self.generate_channel()
        # Lazily filled cache: input symbol -> frozen discrete sampler.
        self._samplers = {}

    def encode_symbol(self, ori):
        """Encode a single symbol into a privatized version.

        Args:
            ori: input symbol, used as an index into ``channel``.

        Returns:
            The result of ``rvs(size=1)`` on the symbol's output
            distribution: a length-1 array holding an index into ``outset``.
        """
        row = self.channel[ori]  # same index validation as plain list indexing
        key = int(ori)
        sampler = self._samplers.get(key)
        if sampler is None:
            # Constructing the distribution object is the expensive part, so
            # it is done once per input symbol rather than once per call.
            pmf = np.asarray(row, dtype=float)
            sampler = stats.rv_discrete(values=(range(self.outsz), pmf))
            self._samplers[key] = sampler
        return sampler.rvs(size=1)

    def encode_string(self, in_list):
        """Encode every symbol of ``in_list``; returns a list of samples."""
        return [self.encode_symbol(x) for x in in_list]

    def decode_string(self, out_list):
        """Estimate the input distribution from privatized samples.

        Args:
            out_list: sequence of samples.  Each sample is either a
                length-1 sequence/array holding an output index (the form
                produced by ``encode_symbol``) or a plain integer index.

        Returns:
            A list of ``insz`` floats.  The values are not clipped, so
            individual entries may fall outside ``[0, 1]``.

        Raises:
            ZeroDivisionError: if ``out_list`` is empty, or if the mechanism
                parameters make the estimator's denominators zero
                (``pri_para == 0`` or ``d == insz``).
        """
        l = len(out_list)
        if l == 0:
            raise ZeroDivisionError(
                "decode_string needs at least one privatized sample")

        k, d, e = self.insz, self.d, self.exp
        # Scale and offset that map "fraction of samples whose subset
        # contains symbol m" to the estimate for symbol m.
        temp1 = ((k-1)*e+1.0*(k-1)*(k-d)/d) / ((k-d)*(e-1))
        temp2 = ((d-1)*e+k-d) / (1.0*(k-d)*(e-1))

        # Tally how often each output index was observed so that every
        # distinct subset is scanned only once.
        tally = {}
        for sample in out_list:
            try:
                first = sample[0]
            except (TypeError, IndexError):
                first = sample  # plain integer / scalar index
            idx = int(first)
            tally[idx] = tally.get(idx, 0) + 1

        # hits[m]: number of samples whose subset contains symbol m.
        hits = [0] * k
        for idx, count in tally.items():
            for m, present in enumerate(self.outset[idx]):
                if present:
                    hits[m] += count

        return [temp1 * t / l - temp2 for t in hits]

    def kbits(self, n, k):
        """Generate the output alphabet set.

        Returns one boolean list of length ``n`` with exactly ``k`` True
        entries for every ``k``-combination of ``range(n)``, in the order
        produced by ``itertools.combinations``.
        """
        result = []
        for chosen in itertools.combinations(range(n), k):
            row = [False] * n
            for pos in chosen:
                row[pos] = True
            result.append(row)
        return result

    def generate_channel(self):
        """Generate the channel matrix as a list of ``insz`` rows.

        Entry ``[m][j]`` is the high probability when ``outset[j]`` contains
        symbol ``m`` and the low probability otherwise; the high probability
        is ``exp`` times the low one.
        """
        low_transfer_probability = 1.0 / (
            comb(self.insz-1, self.d-1)*self.exp + comb(self.insz-1, self.d))
        high_transfer_probability = low_transfer_probability * self.exp
        return [
            [high_transfer_probability if subset[m] else low_transfer_probability
             for subset in self.outset]
            for m in range(self.insz)
        ]





