Skip to content

Commit

Permalink
Merge pull request #78 from privacylab/ryscheng-pir
Browse files Browse the repository at this point in the history
PIR client
  • Loading branch information
ryscheng committed Jun 7, 2017
2 parents 8d08ee5 + 04031f4 commit 1126f5a
Show file tree
Hide file tree
Showing 10 changed files with 324 additions and 42 deletions.
5 changes: 4 additions & 1 deletion libtalek/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ func TestRead(t *testing.T) {
&common.Config{NumBuckets: 64, BucketDepth: 4, DataSize: 1024, BloomFalsePositive: 0.05, MaxLoadFactor: 0.95, LoadFactorStep: 0.05},
time.Second,
time.Second,
[]*common.TrustDomainConfig{common.NewTrustDomainConfig("TestTrustDomain", "127.0.0.1", true, false)},
[]*common.TrustDomainConfig{
common.NewTrustDomainConfig("TestTrustDomain0", "127.0.0.1", true, false),
common.NewTrustDomainConfig("TestTrustDomain1", "127.0.0.1", true, false),
},
"",
}

Expand Down
25 changes: 11 additions & 14 deletions libtalek/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/dchest/siphash"
"github.com/privacylab/talek/common"
"github.com/privacylab/talek/drbg"
"github.com/privacylab/talek/pir/pirclient"
"golang.org/x/crypto/nacl/box"
)

Expand Down Expand Up @@ -77,26 +78,22 @@ func makeReadArg(config *ClientConfig, bucket uint64, rand io.Reader) *common.Re
arg := &common.ReadArgs{}
num := len(config.TrustDomains)
arg.TD = make([]common.PirArgs, num)
arg.TD[0].RequestVector = make([]byte, (config.Config.NumBuckets+7)/8)
arg.TD[0].RequestVector[bucket/8] |= 1 << (bucket % 8)
arg.TD[0].PadSeed = make([]byte, drbg.SeedLength)
if _, err := rand.Read(arg.TD[0].PadSeed); err != nil {

pirClient := pirclient.NewClient("pirclient")
reqVec, err := pirClient.GenerateRequestVectors(bucket, uint64(num), config.Config.NumBuckets)
if err != nil {
return nil
}

for j := 1; j < num; j++ {
arg.TD[j].RequestVector = make([]byte, (config.Config.NumBuckets+7)/8)
if _, err := rand.Read(arg.TD[j].RequestVector); err != nil {
return nil
}
arg.TD[j].PadSeed = make([]byte, drbg.SeedLength)
if _, err := rand.Read(arg.TD[j].PadSeed); err != nil {
for i := 0; i < num; i++ {
arg.TD[i].RequestVector = reqVec[i]
arg.TD[i].PadSeed = make([]byte, drbg.SeedLength)
if _, err := rand.Read(arg.TD[i].PadSeed); err != nil {
return nil
}
for k := 0; k < len(arg.TD[j].RequestVector); k++ {
arg.TD[0].RequestVector[k] ^= arg.TD[j].RequestVector[k]
}

}

return arg
}

Expand Down
3 changes: 3 additions & 0 deletions pir/pircl/shard_cl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func TestShardCLReadv0(t *testing.T) {
t.Fatalf("cannot create new ShardCL: error=%v\n", err)
}
pt.HelperTestShardRead(t, shard)
pt.HelperTestClientRead(t, shard)
pt.AfterEach(t, shard, context)
fmt.Printf("... done \n")
}
Expand All @@ -62,6 +63,7 @@ func TestShardCLReadv1(t *testing.T) {
t.Fatalf("cannot create new ShardCL: error=%v\n", err)
}
pt.HelperTestShardRead(t, shard)
pt.HelperTestClientRead(t, shard)
pt.AfterEach(t, shard, context)
fmt.Printf("... done \n")
}
Expand All @@ -78,6 +80,7 @@ func TestShardCLReadv2(t *testing.T) {
t.Fatalf("cannot create new ShardCL: error=%v\n", err)
}
pt.HelperTestShardRead(t, shard)
pt.HelperTestClientRead(t, shard)
pt.AfterEach(t, shard, context)
fmt.Printf("... done \n")
}
Expand Down
84 changes: 84 additions & 0 deletions pir/pirclient/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package pirclient

import (
"crypto/rand"
"fmt"

"github.com/privacylab/talek/common"
"github.com/privacylab/talek/pir/xor"
)

// Client handles the basic functionalities of a PIR client,
// generating request vectors and combining partial responses
type Client struct {
log *common.Logger
name string
}

// NewClient creates a new PIR client
func NewClient(name string) *Client {
c := &Client{}
c.log = common.NewLogger(name)
c.name = name
return c
}

// GenerateRequestVectors creates numServers requestVectors
// to retrieve data at the specified bucket
func (c *Client) GenerateRequestVectors(bucket uint64, numServers uint64, numBuckets uint64) ([][]byte, error) {
if numServers < 2 {
c.log.Error.Printf("GenerateRequestVectors called with too few servers=%v", numServers)
return nil, fmt.Errorf("numServers=%v must be >1", numServers)
}
if bucket >= numBuckets {
c.log.Error.Printf("GenerateRequestVectors called with invalid bucket=%v, numBuckets=%v", bucket, numBuckets)
return nil, fmt.Errorf("bucket=%v must be <numBuckets=%v", bucket, numBuckets)
}

req := make([][]byte, numServers)
numBytes := numBuckets / 8
if (numBuckets % 8) != 0 {
numBytes++
}

// Encode the secret
req[0] = make([]byte, numBytes)
req[0][bucket/8] |= 1 << (bucket % 8)

var err error
// Generate numServers-1 random request vectors
for i := uint64(1); i < numServers; i++ {
req[i] = make([]byte, numBytes)
_, err = rand.Read(req[i])
if err != nil {
c.log.Error.Printf("GenerateRequestVectors failed: error generating random numbers %v", err)
return nil, err
}
// XOR this request vector into the secret
xor.Bytes(req[0], req[0], req[i])
}

return req, nil
}

// CombineResponses returns the result from XORing all responses together
// Precondition: all responses are the same length
// Returns a byte array of the result
func (c *Client) CombineResponses(responses [][]byte) ([]byte, error) {
if responses == nil || len(responses) < 1 {
c.log.Error.Printf("CombineResponses failed: no responses input")
return nil, fmt.Errorf("no responses input")
}
length := len(responses[0])
result := make([]byte, length)
copy(result, responses[0])

for i := 1; i < len(responses); i++ {
// Combine into result
if xor.Bytes(result, result, responses[i]) != length {
c.log.Error.Printf("CombineResponses failed: malformed response %v", i)
return nil, fmt.Errorf("malformed response %v", i)
}
}
return result, nil
}
109 changes: 109 additions & 0 deletions pir/pirclient/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package pirclient

import (
"encoding/binary"
"testing"
)

func TestNewClient(t *testing.T) {
c := NewClient("test")
if c == nil {
t.Errorf("Client should not be nil")
}
}

func TestGenerateRequestVectors(t *testing.T) {
c := NewClient("test")
reqVec, err := c.GenerateRequestVectors(1, 3, 64)
if err != nil {
t.Errorf("GenerateRequestVectors failed: %v", err)
}
if len(reqVec) != 3 {
t.Errorf("GenerateRequestVectors produced too few request vectors %v, expected 3", len(reqVec))
}
resultBytes, err := c.CombineResponses(reqVec)
if err != nil {
t.Errorf("CombineResponses failed: %v", err)
}
result, _ := binary.Uvarint(resultBytes)
if result != 2 {
t.Errorf("Secret request vector should translate to 2, not %v", result)
}
}

func TestGenerateRequestVectorsOddNumBuckets(t *testing.T) {
c := NewClient("test")
reqVec, err := c.GenerateRequestVectors(1, 3, 65)
if err != nil {
t.Errorf("GenerateRequestVectors failed: %v", err)
}
if len(reqVec) != 3 {
t.Errorf("GenerateRequestVectors produced too few request vectors %v, expected 3", len(reqVec))
}
resultBytes, err := c.CombineResponses(reqVec)
if err != nil {
t.Errorf("CombineResponses failed: %v", err)
}
result, _ := binary.Uvarint(resultBytes)
if result != 2 {
t.Errorf("Secret request vector should translate to 2, not %v", result)
}
}

func TestGenerateRequestVectorsInvalidNumServers(t *testing.T) {
c := NewClient("test")
_, err := c.GenerateRequestVectors(1, 1, 64)
if err == nil {
t.Errorf("GenerateRequestVectors should fail with 1 server")
}
}

func TestGenerateRequestVectorsInvalidBucket(t *testing.T) {
c := NewClient("test")
_, err := c.GenerateRequestVectors(65, 3, 64)
if err == nil {
t.Errorf("GenerateRequestVectors should fail with out of bounds bucket")
}
}

func TestCombineResponses(t *testing.T) {
c := NewClient("test")
result, err := c.CombineResponses([][]byte{
{1, 2, 3, 4, 5},
{1, 2, 3, 4, 5},
})
if err != nil {
t.Errorf("CombineResponses shouldn't have failed")
}
for _, b := range result {
if b != 0 {
t.Errorf("CombineResponses should return 0")
}
}
}

func TestCombineResponsesNone(t *testing.T) {
c := NewClient("test")
_, err := c.CombineResponses(make([][]byte, 0))
if err == nil {
t.Errorf("CombineResponses should have failed with no responses to combine")
}
}

func TestCombineResponsesInvalid(t *testing.T) {
c := NewClient("test")
_, err := c.CombineResponses([][]byte{
{1, 2, 3},
{1},
})
if err == nil {
t.Errorf("CombineResponses should have failed with mismatched responses")
}
_, err = c.CombineResponses([][]byte{
{1},
{1, 2, 3},
})
if err != nil {
t.Errorf("CombineResponses is okay with bigger later responses")
}
}
14 changes: 9 additions & 5 deletions pir/pircpu/shard_cpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/privacylab/talek/common"
"github.com/privacylab/talek/pir/pirinterface"
"github.com/privacylab/talek/pir/xor"
)

// ShardCPU represents a read-only shard of the database
Expand Down Expand Up @@ -63,6 +64,10 @@ func NewShardCPU(name string, bucketSize int, data []byte, readVersion int) (*Sh
return nil, fmt.Errorf("NewShardCPU(%v) failed: data(len=%v) not multiple of bucketSize=%v", name, len(data), bucketSize)
}

if readVersion < 0 || readVersion > 2 {
return nil, fmt.Errorf("NewShardCPU(%v) failed: readVersion=%v must be 0, 1, or 2", name, readVersion)
}

s.bucketSize = bucketSize
s.numBuckets = (len(data) / bucketSize)
s.data = data
Expand Down Expand Up @@ -115,15 +120,14 @@ func (s *ShardCPU) Insert(bucket int, offset int, toCopy []byte) int {
func (s *ShardCPU) Read(reqs []byte, reqLength int) ([]byte, error) {
if len(reqs)%reqLength != 0 {
return nil, fmt.Errorf("ShardCPU.Read expects len(reqs)=%d to be a multiple of reqLength=%d", len(reqs), reqLength)
} else if s.readVersion == 0 {
return s.read0(reqs, reqLength)
} else if s.readVersion == 1 {
return s.read1(reqs, reqLength)
} else if s.readVersion == 2 {
return s.read2(reqs, reqLength)
}

return nil, fmt.Errorf("ShardCPU.Read: invalid readVersion=%d", s.readVersion)
// Default to version 0
return s.read0(reqs, reqLength)
}

func (s *ShardCPU) read0(reqs []byte, reqLength int) ([]byte, error) {
Expand All @@ -141,7 +145,7 @@ func (s *ShardCPU) read0(reqs []byte, reqLength int) ([]byte, error) {
bucketOffset := bucketIndex * s.bucketSize
bucket := s.data[bucketOffset:(bucketOffset + s.bucketSize)]
response := responses[respOffset:(respOffset + s.bucketSize)]
xorWords(response, response, bucket)
xor.Words(response, response, bucket)
}
}
}
Expand All @@ -165,7 +169,7 @@ func (s *ShardCPU) read1(reqs []byte, reqLength int) ([]byte, error) {
bucketOffset := bucketIndex * s.bucketSize
bucket := s.data[bucketOffset:(bucketOffset + s.bucketSize)]
response := responses[respOffset:(respOffset + s.bucketSize)]
xorBytes(response, response, bucket)
xor.Bytes(response, response, bucket)
}
}
}
Expand Down

0 comments on commit 1126f5a

Please sign in to comment.