In [50]:
import syft as sy
import torch as th
import string

hook = sy.TorchHook(th)
from torch import nn, optim

In [51]:
bob = sy.VirtualWorker(hook, id="bob").add_worker(sy.local_worker)
alice = sy.VirtualWorker(hook, id="alice").add_worker(sy.local_worker)
secure_worker = sy.VirtualWorker(hook, id="secure_worker").add_worker(sy.local_worker)

In [2]:
char2index = {}
index2char = {}

In [3]:
for i, char in enumerate('' + string.ascii_lowercase + '0123456789' + string.punctuation):
    char2index[char] = i
    index2char[i] = char

In [4]:
str_input = "Hello"
max_len = 8

In [5]:
def string2values(str_input, max_len=8):

    str_input = str_input[:max_len].lower()

    # pad strings shorter than max len
    if(len(str_input) < max_len):
        str_input = str_input + "." * (max_len - len(str_input))

    values = list()
    for char in str_input:
        values.append(char2index[char])

    return th.tensor(values).long()

In [6]:
string2values("Hello!")

tensor([ 7,  4, 11, 11, 14, 36, 49, 49])

In [7]:
def one_hot(index, length):
    vect = th.zeros(length).long()
    vect[index] = 1
    return vect

In [8]:
one_hot(char2index['p'], len(index2char))

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [13]:
def string2one_hot_matrix(str_input, max_len=8):
    
    str_input = str_input[:max_len].lower()
    
    #pad strings shorter than max len
    if(len(str_input) < max_len):
        str_input = str_input + "," * (max_len - len(str_input))
        
    
    char_vectors = list()
    for char in str_input:
        char_v = one_hot(char2index[char], len(index2char)).unsqueeze(0)
        char_vectors.append(char_v)
        
    return th.cat(char_vectors, dim=0)
    
    

In [14]:
matrix = string2one_hot_matrix("Howdy!")

In [15]:
matrix.shape

torch.Size([8, 68])

In [16]:
matrix

tensor([[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [17]:
str_a = string2one_hot_matrix("Hello")
str_b = string2one_hot_matrix("Hello")

In [18]:
(str_a * str_b).sum()

tensor(8)

In [19]:
vect = (str_a * str_b).sum(1)
vect

tensor([1, 1, 1, 1, 1, 1, 1, 1])

In [20]:
x = vect[0]
x

tensor(1)

In [21]:
vect = (str_a * str_b).sum(1)

x = vect[0]

for i in range(vect.shape[0] - 1):
    x = x * vect[i + 1]
    
key_match = x
key_match

tensor(1)

In [22]:
keys = list()
values = list()

keys.append(string2one_hot_matrix("key1"))
values.append(string2values("value1"))

keys.append(string2one_hot_matrix("key2"))
values.append(string2values("value2"))

In [23]:
values

[tensor([21,  0, 11, 20,  4, 27, 49, 49]),
 tensor([21,  0, 11, 20,  4, 28, 49, 49])]

In [24]:
query_str = "key1"

In [25]:
query_matrix = string2one_hot_matrix(query_str)

In [26]:
query_matrix

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [27]:
def strings_equal(str_a, str_b):
    
    vect = (str_a * str_b).sum(1)
    
    x = vect[0]
    
    for i in range(vect.shape[0] - 1):
        x = x * vect[i + 1]
        
    str_match = x
    return str_match

In [28]:
key_matches = list()
for key in keys:
    key_match = strings_equal(key, query_matrix)
    key_matches.append(key_match)

In [29]:
key_matches

[tensor(1), tensor(0)]

In [30]:
query_str = "key1"

query_matrix = string2one_hot_matrix(query_str)

key_matches = list()
for key in keys:
    
    key_match = strings_equal(key, query_matrix)
    key_matches.append(key_match)
    
result = values[0] * key_matches[0]

for i in range(len(values) - 1):
    result += values[i+1] * key_matches[i+1]
    
result

tensor([21,  0, 11, 20,  4, 27, 49, 49])

In [42]:
def string2values(str_input, max_len=8):
    
    str_input = str_input[:max_len].lower()
    
    # pad strings shorter than max len
    if(len(str_input) < max_len):
        str_input = str_input + "." * (max_len - len(str_input))
        
        values = list()
        for char in str_input:
            values.append(char2index[char])
            
        return th.tensor(values).long()
    
def values2string(input_values):
    s = ""
    for value in input_values:
        s += index2char[int(value)]
    return s

In [43]:
for i in range(len(values) - 1):
    result += values[i+1] * key_matches[i+1]
    
values2string(result).replace(".","")

'value1'

In [44]:
def query(query_str):
    query_matrix = string2one_hot_matrix(query_str)
    
    key_matches = list()
    for key in keys:
        
        key_match = strings_equal(key, query_matrix)
        key_matches.append(key_match)
        
    result = values[0] * key_matches[0]
    
    for i in range(len(values) - 1):
        result += values[i+1] * key_matches[i+1]
        
    return values2string(result).replace(".","")

In [45]:
query("key2")

'value2'

In [46]:
class EncryptedDB():
    
    def __init__(self, max_key_len=8, max_val_len=8):
        self.max_key_len = 8
        self.max_val_len = 8
        
        self.keys = list()
        self.values = list()
        
    def add_entry(self, key, value):
        self.keys.append(string2one_hot_matrix(key))
        self.values.append(string2values(value))
        
    def query(query_str):
        query_matrix = string2one_hot_matrix(query_str)
        
        key_matches = list()
        for key in self.keys:
            
            key_match = strings_equal(key, query_matrix)
            key_matches.append(key_match)
            
        result = self.values[0] * key_matches[0]
        
        for i in range(len(self.values) - 1):
            result += self.values[i+1] * key_matches[i+1]
            
        return values2string(result).replace(".","")

In [47]:
db = EncryptedDB()

In [48]:
db.add_entry("key1", "value1")
db.add_entry("key2", "value2")
db.add_entry("key3", "value3")
db.add_entry("key4", "value4")

In [49]:
db.query("key1")

TypeError: query() takes 1 positional argument but 2 were given

In [None]:
class EncryptedDB():
    
    def __init__(self, *owners, max_key_len=8, max_val_len=8):
        self.max_key_len = 8
        self.max_val_len = 8
        
        self.keys = list()
        self.values = list()
        self.owners = owners
        
    def add_entry(self, key, value):
        key = string2one_hot_matrix(key)
        key = key.share(*self.owners)
        self.keys.append(key)
        
        value = string2values(value)
        value = value.share(*self.owners)
        self.values.append(value)
        
    def query(self, query_str):
        query_matrix = string2one_hot_matrix(query_str)
        
        query_matrix = query_matrix.share(*self.owners)
        
        key_matches = list()
        for key in self.keys:
            
            key_match = strings_equal(key, query_matrix)
            key_matches.append(key_match)
            
        result = self.values[0] * key_matches[0]
        
        for i in range(len(self.values) - 1):
            result += self.values[i+1] * key_matches[i+1]
            
        result = result.get()
        
        return values2string(result).replace(".","")