# Private information retrieval

This notebook shows, in a simple way, how to do private information retrieval (PIR) with Concrete, with applications to blocking spam phone numbers or malicious URLs. In PIR, a large unencrypted database sits on the server side and cannot be moved to the client side, for instance because it is too big, because it is updated too often, or because it should not be disclosed for privacy reasons. PIR lets the user query this database without the server seeing either the query input or the result in the clear.

In [1]:
# Importing some libraries

import argparse
import random
import time

import numpy as np

from concrete import fhe

## One-hot vector

Our database is represented as a table T. Inputs to T are `n_input_bits` bits long, and each cell of T holds `m_output_bits` bits.

Before querying the database, the input `i` is encoded as a [one-hot vector](https://en.wikipedia.org/wiki/One-hot), i.e. a vector of `2**n_input_bits` bits that are all 0 except the one in the i-th position.

In [2]:
def make_one_hot_vector(index: int, size: int) -> np.ndarray:
    """Return a vector of `size` zeros with a single 1 at position `index`."""
    one_hot_vector = np.zeros(shape=(size,), dtype=np.int8)
    one_hot_vector[index] = 1
    return one_hot_vector

Here are a few examples of one-hot vectors.

In [3]:
x = make_one_hot_vector(0, size=5)
assert np.array_equal(x, np.array([1, 0, 0, 0, 0]))

x = make_one_hot_vector(2, size=7)
assert np.array_equal(x, np.array([0, 0, 1, 0, 0, 0, 0]))

## Querying the database with a one-hot vector

Querying the database is very simple: the client encodes the input as a one-hot vector and encrypts it, while the database stays in the clear on the server side. A dot product between the one-hot vector and the database then returns the requested entry.

In [4]:
def get_ith_element_of_database(one_hot_vector: np.ndarray, database: np.ndarray) -> int:
    return np.dot(one_hot_vector, database)

In [5]:
database = np.array([3, 5, 21, 7, 11, 13, 2, 17])
assert database.ndim == 1
database_length = database.shape[0]
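# Number of bits needed to store the largest entry
# (ceil(log2(max)) would be one bit short when max is a power of two)
database_output_bits = int(np.max(database)).bit_length()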

assert database_output_bits == 5

# We have not compiled our functions yet, so all the computations
# in the following asserts are done in the clear, just to check the semantics
# of the functions
assert (
    get_ith_element_of_database(make_one_hot_vector(0, size=database_length), database)
    == database[0]
)
assert (
    get_ith_element_of_database(make_one_hot_vector(3, size=database_length), database)
    == database[3]
)
assert (
    get_ith_element_of_database(make_one_hot_vector(4, size=database_length), database)
    == database[4]
)

## Private information retrieval with FHE

Now, let's do this privately, without the server seeing the query in the clear. First, we compile the function with Concrete.

In [6]:
def compile_function(database, **kwargs):
    assert database.ndim == 1
    database_length = database.shape[0]
    inputset_length = 100
    inputset = [
        (make_one_hot_vector(np.random.randint(database_length), database_length), database)
        for _ in range(inputset_length)
    ]
    # Also add the extreme value, which would actually be sufficient alone, as it reaches the maximal
    # values in get_ith_element_of_database
    inputset.append((make_one_hot_vector(np.argmax(database), database_length), database))

    compiler = fhe.Compiler(
        get_ith_element_of_database, {"one_hot_vector": "encrypted", "database": "clear"}
    )
    circuit = compiler.compile(inputset, **kwargs)

    return circuit


circuit = compile_function(database, show_mlir=True, show_graph=True)


Computation Graph
--------------------------------------------------------------------------------
%0 = one_hot_vector        # EncryptedTensor<uint1, shape=(8,)>        ∈ [0, 1]
%1 = database              # ClearTensor<uint5, shape=(8,)>            ∈ [2, 21]
%2 = dot(%0, %1)           # EncryptedScalar<uint5>                    ∈ [2, 21]
return %2
--------------------------------------------------------------------------------

MLIR
--------------------------------------------------------------------------------
module {
  func.func @main(%arg0: tensor<8x!FHE.eint<5>>, %arg1: tensor<8xi6>) -> !FHE.eint<5> {
    %0 = "FHELinalg.to_signed"(%arg0) : (tensor<8x!FHE.eint<5>>) -> tensor<8x!FHE.esint<5>>
    %1 = "FHELinalg.dot_eint_int"(%0, %arg1) : (tensor<8x!FHE.esint<5>>, tensor<8xi6>) -> !FHE.esint<5>
    %2 = "FHE.to_unsigned"(%1) : (!FHE.esint<5>) -> !FHE.eint<5>
    return %2 : !FHE.eint<5>
  }
}
--------------------------------------------------------------------------------



Then we can run queries over encrypted inputs.

In [7]:
def test_encrypted_queries(database, circuit, how_many_tests=1, verbose=True):

    times = []

    for _ in range(how_many_tests):
        database_length = database.shape[0]
        log_database_length = np.ceil(np.log2(database_length)).astype(np.int32)

        # Random index in the database
        random_index = np.random.randint(database_length)

        # Turn it into one hot vector
        x = make_one_hot_vector(random_index, database_length)

        # Encrypt the query, on the client side
        encrypted_x, _ = circuit.encrypt(x, None)

        # Run the FHE computation on the server side
        time_begin = time.time()
        encrypted_y = circuit.run(encrypted_x, database)
        time_end = time.time()
        times.append(time_end - time_begin)

        if verbose:
            print(
                f"FHE computation done in {(time_end - time_begin) * 1000:.1f} milliseconds -- database is {database_length} (2**{log_database_length}) elements of {database_output_bits} bits"
            )

        # Decrypt the result on the client side
        y = circuit.decrypt(encrypted_y)

        # And check the computations worked fine
        assert y == database[random_index]

    return times


_ = test_encrypted_queries(database, circuit)

FHE computation done in 2.6 milliseconds -- database is 8 (2**3) elements of 5 bits


## Performance

The same approach works for much larger databases. Let's measure the performance.

In [8]:
how_many_tests = 10

sample_list = [
    (4, 8),
    (4, 16),
    (8, 8),
    (8, 16),
    (9, 8),
    (9, 16),
    (10, 4),
    (10, 8),
    (12, 4),
    (12, 8),
    (14, 4),
    (14, 8),
]
timings_dic = {}

for database_input_bits, database_output_bits in sample_list:

    # Take a random database of expected size and output_bits
    database_length = 2**database_input_bits
    database = np.array(
        [np.random.randint(2**database_output_bits) for _ in range(database_length)]
    )

    circuit = compile_function(database, show_mlir=False, show_graph=False)

    # Benchmark
    times = test_encrypted_queries(database, circuit, how_many_tests=how_many_tests, verbose=False)
    mean_time = np.mean(times)
    timings_dic[(database_input_bits, database_output_bits)] = mean_time
    print(
        f"For a database of 2**{str(database_input_bits):>2s} elements of {str(database_output_bits):>2s} bits, average execution time is {1000 * mean_time:.1f} milliseconds"
    )

For a database of 2** 4 elements of  8 bits, average execution time is 1.2 milliseconds
For a database of 2** 4 elements of 16 bits, average execution time is 1.3 milliseconds
For a database of 2** 8 elements of  8 bits, average execution time is 4.5 milliseconds
For a database of 2** 8 elements of 16 bits, average execution time is 5.0 milliseconds
For a database of 2** 9 elements of  8 bits, average execution time is 6.8 milliseconds
For a database of 2** 9 elements of 16 bits, average execution time is 13.4 milliseconds
For a database of 2**10 elements of  4 bits, average execution time is 8.1 milliseconds
For a database of 2**10 elements of  8 bits, average execution time is 13.3 milliseconds
For a database of 2**12 elements of  4 bits, average execution time is 36.2 milliseconds
For a database of 2**12 elements of  8 bits, average execution time is 55.1 milliseconds
For a database of 2**14 elements of  4 bits, average execution time is 156.5 milliseconds
For a database of 2**14 el

## Using several tables

Finally, note that to store more information per entry without paying the price of a large `database_output_bits`, we can compute several dot products in the function and concatenate the results on the client side. Let's do an example with 4 subtables of 8 bits each, so that each query returns 32 bits instead of 8.

In [9]:
# We'll have 4 subdatabases of 8 bits each
number_of_subdatabases = 4
database_output_bits_subdatabases = 8

# Total number of output bits once the 4 results are concatenated
database_output_bits = database_output_bits_subdatabases * number_of_subdatabases


def get_ith_element_of_database(
    one_hot_vector: np.ndarray,
    database0: np.ndarray,
    database1: np.ndarray,
    database2: np.ndarray,
    database3: np.ndarray,
) -> (int, int, int, int):
    return (
        np.dot(one_hot_vector, database0),
        np.dot(one_hot_vector, database1),
        np.dot(one_hot_vector, database2),
        np.dot(one_hot_vector, database3),
    )


# Take a random database of expected size and output_bits
# (database_input_bits still holds its last value from the benchmark loop above)
database_length = 2**database_input_bits
database = np.array([np.random.randint(2**database_output_bits) for _ in range(database_length)])

database0 = (database >> 0) & 0xFF
database1 = (database >> 8) & 0xFF
database2 = (database >> 16) & 0xFF
database3 = (database >> 24) & 0xFF


def compile_function_split_database(database0, database1, database2, database3, **kwargs):
    database_length = database0.shape[0]
    inputset_length = 100
    inputset = [
        (
            make_one_hot_vector(np.random.randint(database_length), database_length),
            database0,
            database1,
            database2,
            database3,
        )
        for _ in range(inputset_length)
    ]
    compiler = fhe.Compiler(
        get_ith_element_of_database,
        {
            "one_hot_vector": "encrypted",
            "database0": "clear",
            "database1": "clear",
            "database2": "clear",
            "database3": "clear",
        },
    )
    circuit = compiler.compile(inputset, **kwargs)
    return circuit


circuit = compile_function_split_database(
    database0, database1, database2, database3, show_mlir=False, show_graph=False
)


def test_encrypted_queries_split_database(database, circuit, how_many_tests=1, verbose=True):

    times = []

    for _ in range(how_many_tests):
        database_length = database.shape[0]
        log_database_length = np.ceil(np.log2(database_length)).astype(np.int32)

        # Random index in the database
        random_index = np.random.randint(database_length)

        # Turn it into one hot vector
        x = make_one_hot_vector(random_index, database_length)

        # Encrypt the query, on the client side
        encrypted_x, _, _, _, _ = circuit.encrypt(x, None, None, None, None)

        # Run the FHE computation on the server side
        time_begin = time.time()
        encrypted_y = circuit.run(encrypted_x, database0, database1, database2, database3)
        time_end = time.time()
        times.append(time_end - time_begin)

        if verbose:
            print(
                f"FHE computation done in {(time_end - time_begin) * 1000:.1f} milliseconds -- database is {database_length} (2**{log_database_length}) elements of {database_output_bits} bits"
            )

        # Decrypt the result on the client side
        y_bits = circuit.decrypt(encrypted_y)

        # Recombine the four 8-bit results into the full 32-bit value
        y = y_bits[0] + y_bits[1] * 256 + y_bits[2] * 256**2 + y_bits[3] * 256**3

        # And check the computations worked fine
        assert (
            y == database[random_index]
        ), f"{y} {y:x} {y_bits} {y_bits[0]:x} {y_bits[1]:x} {y_bits[2]:x} {y_bits[3]:x} {database[random_index]:x}"

    return times


# Benchmark
times = test_encrypted_queries_split_database(
    database, circuit, how_many_tests=how_many_tests, verbose=False
)
print(
    f"For a database of 2**{str(database_input_bits):>2s} elements of {str(database_output_bits):>2s} bits with 4 sub-databases, average execution time is {1000 * np.mean(times):.1f} milliseconds"
)

For a database of 2**14 elements of 32 bits with 4 sub-databases, average execution time is 276.4 milliseconds
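
The byte splitting and recombination used above generalize to any chunk size and number of subdatabases. Here is a minimal NumPy-only sketch of that idea, with no encryption involved: the dot products are computed in the clear as a stand-in for the FHE circuit. The helper names `split_database` and `recombine` are illustrative and not part of Concrete.

```python
import numpy as np


def split_database(database: np.ndarray, chunk_bits: int, n_chunks: int) -> list:
    """Split each entry into n_chunks pieces of chunk_bits bits, least significant first."""
    mask = (1 << chunk_bits) - 1
    return [(database >> (chunk_bits * k)) & mask for k in range(n_chunks)]


def recombine(chunks, chunk_bits: int) -> int:
    """Rebuild the original value from its pieces (inverse of split_database)."""
    return sum(int(c) << (chunk_bits * k) for k, c in enumerate(chunks))


rng = np.random.default_rng(0)
db = rng.integers(0, 2**32, size=16, dtype=np.int64)
subdbs = split_database(db, chunk_bits=8, n_chunks=4)

# Plaintext stand-in for the encrypted query: one dot product per subdatabase
index = 5
one_hot = np.zeros(16, dtype=np.int64)
one_hot[index] = 1
pieces = [np.dot(one_hot, sub) for sub in subdbs]

assert recombine(pieces, chunk_bits=8) == db[index]
```

Keeping each piece small is what matters for the encrypted circuit, since the client can do the recombination after decryption at negligible cost.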


## Use-cases for phone spamming

Now, let's see where PIR could be used. Imagine we want to build a spam database. In France, there are about 10**9 ≈ 2**30 phone numbers, so we could have a table T[i], for i a 30-bit integer, returning a boolean that states whether the phone number is a known spam number. The database would live on the server side and be updated often. Phones could query the database for a number and, if the result is positive, filter the call as spam. All of this would be done without the server learning the calling numbers.

Then, the goal is to represent this table T as a database D:
- with inputs of bitsize `database_input_bits`
- with `number_of_subdatabases` subdatabases, each of which outputs `database_output_bits` bits of information

such that `database_input_bits + log2(database_output_bits * number_of_subdatabases) >= 30`.

Let's suppose the phone number to query is N, where N is represented as N0 || N1: N0 is the first `database_input_bits` bits of N, and N1 is the remaining bits.

The user would query D with N0 and receive `database_output_bits * number_of_subdatabases` bits of information. By looking at the bit at the N1-th position, the user would know whether the phone number is spam.
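
The N0 || N1 packing can be sketched in the clear as follows. This is only an illustration under stated assumptions: the sizes are toy values (10-bit numbers instead of 30), the table `T` is random, `is_spam` is a hypothetical helper, and the dot products are plain NumPy rather than FHE.

```python
import numpy as np

total_bits = 10                                     # toy size; the text uses 30
database_input_bits = 6                             # bits of N0
remaining_bits = total_bits - database_input_bits   # bits of N1
chunk_bits = 8                                      # output bits per subdatabase
row_bits = 2**remaining_bits                        # flags stored per row
number_of_subdatabases = -(-row_bits // chunk_bits)  # ceiling division

rng = np.random.default_rng(0)
T = rng.integers(0, 2, size=2**total_bits)  # T[N] == 1 means N is flagged as spam

# Pack: row N0 holds the flags of all numbers sharing the prefix N0, one bit per N1
rows = np.zeros(2**database_input_bits, dtype=np.int64)
for N in range(2**total_bits):
    N0, N1 = N >> remaining_bits, N & (row_bits - 1)
    rows[N0] |= int(T[N]) << N1

# Split each row into subdatabases of chunk_bits bits, least significant first
subdatabases = [
    (rows >> (chunk_bits * k)) & (2**chunk_bits - 1) for k in range(number_of_subdatabases)
]


def is_spam(N: int) -> bool:
    N0, N1 = N >> remaining_bits, N & (row_bits - 1)
    one_hot = np.zeros(2**database_input_bits, dtype=np.int64)
    one_hot[N0] = 1
    # In the real protocol, these dot products run on the encrypted one-hot vector
    chunks = [int(np.dot(one_hot, sub)) for sub in subdatabases]
    # Client side: select the chunk, then the bit inside it, that corresponds to N1
    return bool((chunks[N1 // chunk_bits] >> (N1 % chunk_bits)) & 1)


assert all(is_spam(N) == bool(T[N]) for N in range(2**total_bits))
```

The server only ever sees the encrypted one-hot encoding of N0, and the selection of the N1-th bit happens on the client after decryption.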

For example, we can take
- database_input_bits = 14
- number_of_subdatabases = 8192
- database_output_bits = 8
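
As a quick sanity check, these example values can be plugged into the condition (the variable `covered_bits` is introduced here only for illustration):

```python
import math

database_input_bits = 14
number_of_subdatabases = 8192
database_output_bits = 8

# 14 input bits + log2(8 * 8192) = 14 + 16 bits selected on the client side
covered_bits = database_input_bits + math.log2(database_output_bits * number_of_subdatabases)
assert covered_bits >= 30
```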

Obviously, the goal is to find the combination that satisfies the condition while minimizing the execution time.

Another way to make this less compute-intensive is to leave some of the bits of the phone number in the clear and hide only the remaining bits. E.g., we could leave the first 3 digits in the clear and hide only the last 6 digits (10**6 ≈ 2**20), which turns the condition into `database_input_bits + log2(database_output_bits * number_of_subdatabases) >= 20`.

Below, we search exhaustively for the best combination, for 30-bit and 20-bit phone numbers.


In [10]:
# Finding the best combination
def find_best_combination(expected_total_bits):
    best_combination = None

    for database_input_bits, database_output_bits in timings_dic.keys():
        remaining_bits = expected_total_bits - database_input_bits
        assert remaining_bits > 0
        number_of_subdatabases = np.ceil(2**remaining_bits / database_output_bits).astype(np.int32)
        estimated_time = np.ceil(
            number_of_subdatabases * timings_dic[(database_input_bits, database_output_bits)]
        )

        print(
            f"Estimated time would be {str(estimated_time):>8s} seconds for {str(number_of_subdatabases):>8s} DBs of {(database_input_bits, database_output_bits)}"
        )

        if best_combination is None or estimated_time < best_combination[0]:
            best_combination = (
                estimated_time,
                number_of_subdatabases,
                database_input_bits,
                database_output_bits,
            )

    print(
        f"\nBest combination: {best_combination[0]} seconds for a DB of {expected_total_bits} bits\n"
    )


find_best_combination(30)
find_best_combination(20)

Estimated time would be  10070.0 seconds for  8388608 DBs of (4, 8)
Estimated time would be   5356.0 seconds for  4194304 DBs of (4, 16)
Estimated time would be   2362.0 seconds for   524288 DBs of (8, 8)
Estimated time would be   1308.0 seconds for   262144 DBs of (8, 16)
Estimated time would be   1775.0 seconds for   262144 DBs of (9, 8)
Estimated time would be   1757.0 seconds for   131072 DBs of (9, 16)
Estimated time would be   2135.0 seconds for   262144 DBs of (10, 4)
Estimated time would be   1740.0 seconds for   131072 DBs of (10, 8)
Estimated time would be   2372.0 seconds for    65536 DBs of (12, 4)
Estimated time would be   1806.0 seconds for    32768 DBs of (12, 8)
Estimated time would be   2564.0 seconds for    16384 DBs of (14, 4)
Estimated time would be   1814.0 seconds for     8192 DBs of (14, 8)

Best combination: 1308.0 seconds for a DB of 30 bits

Estimated time would be     10.0 seconds for     8192 DBs of (4, 8)
Estimated time would be      6.0 seconds for     409

## Another use case: URL checking

It might also be tempting to keep (and refresh very often) a list of bad URLs on the server side, and to use it to protect users from clicking on bad links. Of course, there are too many URLs to handle with the previous system. Fortunately, there is a hash-based solution.

The principle is to use a small non-cryptographic hash function that maps strings to small integers, let's say 20-bit integers. Any time the server sees a bad URL `u`, it hashes it to `h` and stores `T[h] = 1` to mark this hash as potentially dangerous. Then, with our system, the user can privately check whether a given URL `u'` is dangerous by hashing it to `h'` and checking whether `T[h'] == 1`.

With such a small hash function, there will be collisions, which means that the user will sometimes receive false positives: `T[h] == 1` does not mean that this given URL is dangerous, only that there exists a dangerous URL with the same hash. These collisions are not a problem per se: the user may just see a "Warning, this URL is potentially dangerous" message and still access the page if they are confident. Alternatively, we could use several different hash functions with different tables `T_i`, and flag a URL as dangerous only if all `T_i` return 1, which greatly reduces the probability of false positives.
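
The multi-table variant can be sketched in the clear as follows. Everything here is an assumption made for illustration: the hash is a 32-bit FNV-1a folded down to 20 bits and salted with the table index, the number of tables is arbitrary, the URLs are made up, and the table lookups stand in for what would be one private query per table.

```python
import numpy as np

HASH_BITS = 20
NUMBER_OF_TABLES = 3  # illustrative choice
FNV_OFFSET_BASIS = 2166136261
FNV_PRIME = 16777619


def small_hash(url: str, table_index: int) -> int:
    """32-bit FNV-1a of (table_index byte + url), xor-folded down to HASH_BITS bits."""
    h = FNV_OFFSET_BASIS
    for byte in bytes([table_index]) + url.encode("utf-8"):
        h ^= byte
        h = (h * FNV_PRIME) & 0xFFFFFFFF
    return (h >> HASH_BITS) ^ (h & (2**HASH_BITS - 1))


# Server side: mark the hashes of known bad URLs in every table T_i
bad_urls = ["http://malware.example/download", "http://phishing.example/login"]
tables = [np.zeros(2**HASH_BITS, dtype=np.int8) for _ in range(NUMBER_OF_TABLES)]
for url in bad_urls:
    for i, table in enumerate(tables):
        table[small_hash(url, i)] = 1


def is_potentially_dangerous(url: str) -> bool:
    # In the real protocol, each lookup would be a private query on table T_i
    return all(tables[i][small_hash(url, i)] == 1 for i in range(NUMBER_OF_TABLES))


# A known bad URL is always flagged: there are no false negatives
assert all(is_potentially_dangerous(url) for url in bad_urls)
```

A harmless URL is flagged only if it collides with some bad URL in every table at once, which is why adding tables lowers the false-positive rate at the cost of more queries.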