-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #78 from privacylab/ryscheng-pir
PIR client
- Loading branch information
Showing
10 changed files
with
324 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
package pirclient | ||
|
||
import ( | ||
"crypto/rand" | ||
"fmt" | ||
|
||
"github.com/privacylab/talek/common" | ||
"github.com/privacylab/talek/pir/xor" | ||
) | ||
|
||
// Client handles the basic functionalities of a PIR client, | ||
// generating request vectors and combining partial responses | ||
type Client struct { | ||
log *common.Logger | ||
name string | ||
} | ||
|
||
// NewClient creates a new PIR client | ||
func NewClient(name string) *Client { | ||
c := &Client{} | ||
c.log = common.NewLogger(name) | ||
c.name = name | ||
return c | ||
} | ||
|
||
// GenerateRequestVectors creates numServers requestVectors | ||
// to retrieve data at the specified bucket | ||
func (c *Client) GenerateRequestVectors(bucket uint64, numServers uint64, numBuckets uint64) ([][]byte, error) { | ||
if numServers < 2 { | ||
c.log.Error.Printf("GenerateRequestVectors called with too few servers=%v", numServers) | ||
return nil, fmt.Errorf("numServers=%v must be >1", numServers) | ||
} | ||
if bucket >= numBuckets { | ||
c.log.Error.Printf("GenerateRequestVectors called with invalid bucket=%v, numBuckets=%v", bucket, numBuckets) | ||
return nil, fmt.Errorf("bucket=%v must be <numBuckets=%v", bucket, numBuckets) | ||
} | ||
|
||
req := make([][]byte, numServers) | ||
numBytes := numBuckets / 8 | ||
if (numBuckets % 8) != 0 { | ||
numBytes++ | ||
} | ||
|
||
// Encode the secret | ||
req[0] = make([]byte, numBytes) | ||
req[0][bucket/8] |= 1 << (bucket % 8) | ||
|
||
var err error | ||
// Generate numServers-1 random request vectors | ||
for i := uint64(1); i < numServers; i++ { | ||
req[i] = make([]byte, numBytes) | ||
_, err = rand.Read(req[i]) | ||
if err != nil { | ||
c.log.Error.Printf("GenerateRequestVectors failed: error generating random numbers %v", err) | ||
return nil, err | ||
} | ||
// XOR this request vector into the secret | ||
xor.Bytes(req[0], req[0], req[i]) | ||
} | ||
|
||
return req, nil | ||
} | ||
|
||
// CombineResponses returns the result from XORing all responses together | ||
// Precondition: all responses are the same length | ||
// Returns a byte array of the result | ||
func (c *Client) CombineResponses(responses [][]byte) ([]byte, error) { | ||
if responses == nil || len(responses) < 1 { | ||
c.log.Error.Printf("CombineResponses failed: no responses input") | ||
return nil, fmt.Errorf("no responses input") | ||
} | ||
length := len(responses[0]) | ||
result := make([]byte, length) | ||
copy(result, responses[0]) | ||
|
||
for i := 1; i < len(responses); i++ { | ||
// Combine into result | ||
if xor.Bytes(result, result, responses[i]) != length { | ||
c.log.Error.Printf("CombineResponses failed: malformed response %v", i) | ||
return nil, fmt.Errorf("malformed response %v", i) | ||
} | ||
} | ||
return result, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
package pirclient | ||
|
||
import ( | ||
"encoding/binary" | ||
"testing" | ||
) | ||
|
||
func TestNewClient(t *testing.T) { | ||
c := NewClient("test") | ||
if c == nil { | ||
t.Errorf("Client should not be nil") | ||
} | ||
} | ||
|
||
func TestGenerateRequestVectors(t *testing.T) { | ||
c := NewClient("test") | ||
reqVec, err := c.GenerateRequestVectors(1, 3, 64) | ||
if err != nil { | ||
t.Errorf("GenerateRequestVectors failed: %v", err) | ||
} | ||
if len(reqVec) != 3 { | ||
t.Errorf("GenerateRequestVectors produced too few request vectors %v, expected 3", len(reqVec)) | ||
} | ||
resultBytes, err := c.CombineResponses(reqVec) | ||
if err != nil { | ||
t.Errorf("CombineResponses failed: %v", err) | ||
} | ||
result, _ := binary.Uvarint(resultBytes) | ||
if result != 2 { | ||
t.Errorf("Secret request vector should translate to 2, not %v", result) | ||
} | ||
} | ||
|
||
func TestGenerateRequestVectorsOddNumBuckets(t *testing.T) { | ||
c := NewClient("test") | ||
reqVec, err := c.GenerateRequestVectors(1, 3, 65) | ||
if err != nil { | ||
t.Errorf("GenerateRequestVectors failed: %v", err) | ||
} | ||
if len(reqVec) != 3 { | ||
t.Errorf("GenerateRequestVectors produced too few request vectors %v, expected 3", len(reqVec)) | ||
} | ||
resultBytes, err := c.CombineResponses(reqVec) | ||
if err != nil { | ||
t.Errorf("CombineResponses failed: %v", err) | ||
} | ||
result, _ := binary.Uvarint(resultBytes) | ||
if result != 2 { | ||
t.Errorf("Secret request vector should translate to 2, not %v", result) | ||
} | ||
} | ||
|
||
func TestGenerateRequestVectorsInvalidNumServers(t *testing.T) { | ||
c := NewClient("test") | ||
_, err := c.GenerateRequestVectors(1, 1, 64) | ||
if err == nil { | ||
t.Errorf("GenerateRequestVectors should fail with 1 server") | ||
} | ||
} | ||
|
||
func TestGenerateRequestVectorsInvalidBucket(t *testing.T) { | ||
c := NewClient("test") | ||
_, err := c.GenerateRequestVectors(65, 3, 64) | ||
if err == nil { | ||
t.Errorf("GenerateRequestVectors should fail with out of bounds bucket") | ||
} | ||
} | ||
|
||
func TestCombineResponses(t *testing.T) { | ||
c := NewClient("test") | ||
result, err := c.CombineResponses([][]byte{ | ||
{1, 2, 3, 4, 5}, | ||
{1, 2, 3, 4, 5}, | ||
}) | ||
if err != nil { | ||
t.Errorf("CombineResponses shouldn't have failed") | ||
} | ||
for _, b := range result { | ||
if b != 0 { | ||
t.Errorf("CombineResponses should return 0") | ||
} | ||
} | ||
} | ||
|
||
func TestCombineResponsesNone(t *testing.T) { | ||
c := NewClient("test") | ||
_, err := c.CombineResponses(make([][]byte, 0)) | ||
if err == nil { | ||
t.Errorf("CombineResponses should have failed with no responses to combine") | ||
} | ||
} | ||
|
||
func TestCombineResponsesInvalid(t *testing.T) { | ||
c := NewClient("test") | ||
_, err := c.CombineResponses([][]byte{ | ||
{1, 2, 3}, | ||
{1}, | ||
}) | ||
if err == nil { | ||
t.Errorf("CombineResponses should have failed with mismatched responses") | ||
} | ||
_, err = c.CombineResponses([][]byte{ | ||
{1}, | ||
{1, 2, 3}, | ||
}) | ||
if err != nil { | ||
t.Errorf("CombineResponses is okay with bigger later responses") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.