Skip to content

Commit

Permalink
feat(lib/trie): Implement verify_proof function (ChainSafe#1883)
Browse files Browse the repository at this point in the history
* feat: add verify_proof function

* chore: adding helpers

* chore: build the tree from proof slice

* chore: remove Nibbles custom type

* chore: fix lint warns

* chore: add benchmark tests

* chore: fix deepsource warns

* chore: redefine LoadFromProof function

* chore: remove logs

* chore: address comments

* chore: fix the condition to load the proof

* chore: address comments

* chore: improve find function

* chore: use map to avoid duplicate keys

* chore: add test cases to duplicate values and nil values

* chore: fix unused param lint error

* chore: use the shortest form

* chore: use set just for find dupl keys
  • Loading branch information
EclesioMeloJunior committed Oct 19, 2021
1 parent 0f63b17 commit 67bb5ef
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 82 deletions.
62 changes: 62 additions & 0 deletions lib/trie/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ package trie

import (
"bytes"
"errors"
"fmt"

"github.com/ChainSafe/gossamer/lib/common"

"github.com/ChainSafe/chaindb"
)

// ErrEmptyProof indicates the proof slice is empty
var ErrEmptyProof = errors.New("proof slice empty")

// Store stores each trie node in the database, where the key is the hash of the encoded node and the value is the encoded node.
// Generally, this will only be used for the genesis trie.
func (t *Trie) Store(db chaindb.Database) error {
Expand Down Expand Up @@ -73,6 +77,64 @@ func (t *Trie) store(db chaindb.Batch, curr node) error {
return nil
}

// LoadFromProof create a partial trie based on the proof slice, as it only contains nodes that are in the proof afaik.
func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error {
if len(proof) == 0 {
return ErrEmptyProof
}

mappedNodes := make(map[string]node, len(proof))

// map all the proofs hash -> decoded node
// and takes the loop to indentify the root node
for _, rawNode := range proof {
decNode, err := decodeBytes(rawNode)
if err != nil {
return err
}

decNode.setDirty(false)
decNode.setEncodingAndHash(rawNode, nil)

_, computedRoot, err := decNode.encodeAndHash()
if err != nil {
return err
}

mappedNodes[common.BytesToHex(computedRoot)] = decNode

if bytes.Equal(computedRoot, root) {
t.root = decNode
}
}

t.loadProof(mappedNodes, t.root)
return nil
}

// loadProof is a recursive function that will create all the trie paths based
// on the mapped proofs slice starting by the root
func (t *Trie) loadProof(proof map[string]node, curr node) {
c, ok := curr.(*branch)
if !ok {
return
}

for i, child := range c.children {
if child == nil {
continue
}

proofNode, ok := proof[common.BytesToHex(child.getHash())]
if !ok {
continue
}

c.children[i] = proofNode
t.loadProof(proof, proofNode)
}
}

// Load reconstructs the trie from the database from the given root hash. Used when restarting the node to load the current state trie.
func (t *Trie) Load(db chaindb.Database, root common.Hash) error {
if root == EmptyHash {
Expand Down
95 changes: 23 additions & 72 deletions lib/trie/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,86 +2,37 @@ package trie

import (
"bytes"
"errors"

"github.com/ChainSafe/chaindb"
)

var (
// ErrProofNodeNotFound when a needed proof node is not in the database
ErrProofNodeNotFound = errors.New("cannot find a trie node in the database")
)

// lookup struct holds the state root and database reference
// used to retrieve trie information from database
type lookup struct {
// root to start the lookup
root []byte
db chaindb.Database
// findAndRecord search for a desired key recording all the nodes in the path including the desired node
func findAndRecord(t *Trie, key []byte, recorder *recorder) error {
return find(t.root, key, recorder)
}

// newLookup returns a Lookup to helps the proof generator
func newLookup(rootHash []byte, db chaindb.Database) *lookup {
lk := &lookup{db: db}
lk.root = make([]byte, len(rootHash))
copy(lk.root, rootHash)

return lk
}

// find will return the desired value or nil if key cannot be found and will record visited nodes
func (l *lookup) find(key []byte, recorder *recorder) ([]byte, error) {
partial := key
hash := l.root

for {
nodeData, err := l.db.Get(hash)
if err != nil {
return nil, ErrProofNodeNotFound
}

nodeHash := make([]byte, len(hash))
copy(nodeHash, hash)

recorder.record(nodeHash, nodeData)

decoded, err := decodeBytes(nodeData)
if err != nil {
return nil, err
}
func find(parent node, key []byte, recorder *recorder) error {
enc, hash, err := parent.encodeAndHash()
if err != nil {
return err
}

switch currNode := decoded.(type) {
case nil:
return nil, nil
recorder.record(hash, enc)

case *leaf:
if bytes.Equal(currNode.key, partial) {
return currNode.value, nil
}
return nil, nil
b, ok := parent.(*branch)
if !ok {
return nil
}

case *branch:
switch len(partial) {
case 0:
return currNode.value, nil
default:
if !bytes.HasPrefix(partial, currNode.key) {
return nil, nil
}
length := lenCommonPrefix(b.key, key)

if bytes.Equal(partial, currNode.key) {
return currNode.value, nil
}
// found the value at this node
if bytes.Equal(b.key, key) || len(key) == 0 {
return nil
}

length := lenCommonPrefix(currNode.key, partial)
switch child := currNode.children[partial[length]].(type) {
case nil:
return nil, nil
default:
partial = partial[length+1:]
copy(hash, child.getHash())
}
}
}
// did not find value
if bytes.Equal(b.key[:length], key) && len(key) < len(b.key) {
return nil
}

return find(b.children[key[length]], key[length+1:], recorder)
}
56 changes: 52 additions & 4 deletions lib/trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
package trie

import (
"bytes"
"encoding/hex"
"errors"
"fmt"

"github.com/ChainSafe/chaindb"
"github.com/ChainSafe/gossamer/lib/common"
Expand All @@ -26,20 +29,32 @@ import (
var (
// ErrEmptyTrieRoot occurs when trying to craft a prove with an empty trie root
ErrEmptyTrieRoot = errors.New("provided trie must have a root")

// ErrValueNotFound indicates that a returned verify proof value doesnt match with the expected value on items array
ErrValueNotFound = errors.New("expected value not found in the trie")

// ErrDuplicateKeys not allowed to verify proof with duplicate keys
ErrDuplicateKeys = errors.New("duplicate keys on verify proof")

// ErrLoadFromProof occurs when there are problems with the proof slice while building the partial proof trie
ErrLoadFromProof = errors.New("failed to build the proof trie")
)

// GenerateProof receive the keys to proof, the trie root and a reference to database
// will
func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, error) {
trackedProofs := make(map[string][]byte)

proofTrie := NewEmptyTrie()
if err := proofTrie.Load(db, common.BytesToHash(root)); err != nil {
return nil, err
}

for _, k := range keys {
nk := keyToNibbles(k)

lookup := newLookup(root, db)
recorder := new(recorder)

_, err := lookup.find(nk, recorder)
err := findAndRecord(proofTrie, nk, recorder)
if err != nil {
return nil, err
}
Expand All @@ -54,10 +69,43 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e
}

proofs := make([][]byte, 0)

for _, p := range trackedProofs {
proofs = append(proofs, p)
}

return proofs, nil
}

// Pair holds the key and value to check while verifying the proof
type Pair struct{ Key, Value []byte }

// VerifyProof ensure a given key is inside a proof by creating a proof trie based on the proof slice
// this function ignores the order of proofs
func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) {
set := make(map[string]struct{}, len(items))

// check for duplicate keys
for _, item := range items {
hexKey := hex.EncodeToString(item.Key)
if _, ok := set[hexKey]; ok {
return false, ErrDuplicateKeys
}
set[hexKey] = struct{}{}
}

proofTrie := NewEmptyTrie()
if err := proofTrie.LoadFromProof(proof, root); err != nil {
return false, fmt.Errorf("%w: %s", ErrLoadFromProof, err)
}

for _, item := range items {
recValue := proofTrie.Get(item.Key)

// here we need to compare value only if the caller pass the value
if item.Value != nil && !bytes.Equal(item.Value, recValue) {
return false, ErrValueNotFound
}
}

return true, nil
}
Loading

0 comments on commit 67bb5ef

Please sign in to comment.