Skip to content

Commit

Permalink
add profiling runtime flag, improve merkle performance (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
worm-emoji authored Jun 27, 2023
1 parent 5f28063 commit 8b608fc
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 82 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
tmp/
.env
*.pprof
10 changes: 3 additions & 7 deletions api/migrations/scripts/000-rebuild-proofs-hashes/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"context"
"errors"
"fmt"
"log"
"os"
Expand Down Expand Up @@ -43,13 +42,10 @@ func migrateTree(
)
eg.SetLimit(runtime.NumCPU())

for _, l := range leaves {
l := l //avoid capture
for i := range leaves {
i := i
eg.Go(func() error {
pf := tree.Proof(l)
if !merkle.Valid(tree.Root(), pf, l) {
return errors.New("invalid proof for tree")
}
pf := tree.Proof(i)
proofHash := hashProof(pf)
pm.Lock()
proofHashes = append(proofHashes, []any{tree.Root(), proofHash})
Expand Down
10 changes: 6 additions & 4 deletions api/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ func (s *Server) GetProof(w http.ResponseWriter, r *http.Request) {
ctx = r.Context()
root = common.FromHex(r.URL.Query().Get("root"))
leaf = common.FromHex(r.URL.Query().Get("unhashedLeaf"))
addr = common.HexToAddress(r.URL.Query().Get("address"))
addr = common.FromHex(r.URL.Query().Get("address"))
)

if len(root) == 0 {
s.sendJSONError(r, w, nil, http.StatusBadRequest, "missing root")
return
}
if len(leaf) == 0 && addr == (common.Address{}) {
if len(leaf) == 0 && len(addr) == 0 {
s.sendJSONError(r, w, nil, http.StatusBadRequest, "missing leaf")
return
}
Expand All @@ -53,7 +54,7 @@ func (s *Server) GetProof(w http.ResponseWriter, r *http.Request) {
if bytes.Equal(l, leaf) {
target = l
}
} else if leaf2Addr(l, td.Ltd, td.Packed).Hex() == addr.Hex() {
} else if bytes.Equal(leaf2Addr(l, td.Ltd, td.Packed), addr) {
target = l
}
}
Expand All @@ -67,7 +68,8 @@ func (s *Server) GetProof(w http.ResponseWriter, r *http.Request) {
}

var (
p = merkle.New(leaves).Proof(target)
mt = merkle.New(leaves)
p = mt.Proof(mt.Index(target))
phex = []hexutil.Bytes{}
)

Expand Down
61 changes: 25 additions & 36 deletions api/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"net/http"
"sync"

"github.com/contextwtf/lanyard/merkle"
"github.com/ethereum/go-ethereum/accounts/abi"
Expand All @@ -14,7 +13,6 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"golang.org/x/sync/errgroup"
)

func (s *Server) TreeHandler(w http.ResponseWriter, r *http.Request) {
Expand All @@ -31,12 +29,12 @@ func (s *Server) TreeHandler(w http.ResponseWriter, r *http.Request) {
}
}

func leaf2Addr(leaf []byte, ltd []string, packed bool) common.Address {
if len(ltd) == 0 || (len(ltd) == 1 && ltd[0] == "address") {
return common.BytesToAddress(leaf)
func leaf2Addr(leaf []byte, ltd []string, packed bool) []byte {
if len(ltd) == 0 || (len(ltd) == 1 && ltd[0] == "address" && len(leaf) == 20) {
return leaf
}
if ltd[len(ltd)-1] == "address" && len(leaf) > 20 {
return common.BytesToAddress(leaf[len(leaf)-20:])
return leaf[len(leaf)-20:]
}

if packed {
Expand All @@ -45,7 +43,7 @@ func leaf2Addr(leaf []byte, ltd []string, packed bool) common.Address {
return addrUnpacked(leaf, ltd)
}

func addrUnpacked(leaf []byte, ltd []string) common.Address {
func addrUnpacked(leaf []byte, ltd []string) []byte {
var addrStart, pos int
for _, desc := range ltd {
if desc == "address" {
Expand All @@ -54,32 +52,33 @@ func addrUnpacked(leaf []byte, ltd []string) common.Address {
}
pos += 32
}

if len(leaf) >= addrStart+32 {
return common.BytesToAddress(leaf[addrStart:(addrStart + 32)])
l := leaf[addrStart:(addrStart + 32)]
return l[len(l)-20:] // take last 20 bytes
}
return common.Address{}
return []byte{}
}

func addrPacked(leaf []byte, ltd []string) common.Address {
func addrPacked(leaf []byte, ltd []string) []byte {
var addrStart, pos int
for _, desc := range ltd {
t, err := abi.NewType(desc, "", nil)
if err != nil {
return common.Address{}
}
if desc == "address" {
return []byte{}
} else if desc == "address" {
addrStart = pos
break
}
pos += int(t.GetType().Size())
}
if addrStart == 0 && pos != 0 {
return common.Address{}
return []byte{}
}
if len(leaf) >= addrStart+20 {
return common.BytesToAddress(leaf[addrStart:(addrStart + 20)])
return leaf[addrStart:(addrStart + 20)]
}
return common.Address{}
return []byte{}
}

func hashProof(p [][]byte) []byte {
Expand Down Expand Up @@ -117,14 +116,14 @@ func (s *Server) CreateTree(w http.ResponseWriter, r *http.Request) {

var leaves [][]byte
for _, l := range req.Leaves {
// use the go-ethereum HexDecode method because it is more
// use the go-ethereum FromHex method because it is more
// lenient and will allow for odd-length hex strings (by padding them)
leaves = append(leaves, common.FromHex(l))
}

tree := merkle.New(leaves)
root := tree.Root()
var (
tree = merkle.New(leaves)
root = tree.Root()
exists bool
)

Expand All @@ -147,25 +146,15 @@ func (s *Server) CreateTree(w http.ResponseWriter, r *http.Request) {
}

var (
proofHashes = [][]any{}
eg errgroup.Group
pm sync.Mutex
proofHashes = make([][]any, 0, len(leaves))
allProofs = tree.LeafProofs()
)
for _, l := range leaves {
l := l //avoid capture
eg.Go(func() error {
pf := tree.Proof(l)
if !merkle.Valid(tree.Root(), pf, l) {
return errors.New("invalid proof for tree")
}
proofHash := hashProof(pf)
pm.Lock()
proofHashes = append(proofHashes, []any{tree.Root(), proofHash})
pm.Unlock()
return nil
})

for _, p := range allProofs {
proofHash := hashProof(p)
proofHashes = append(proofHashes, []any{root, proofHash})
}
err = eg.Wait()

if err != nil {
s.sendJSONError(r, w, err, http.StatusBadRequest, "generating proofs for tree")
return
Expand Down
91 changes: 79 additions & 12 deletions api/tree_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"bytes"
"testing"

"github.com/ethereum/go-ethereum/common"
Expand All @@ -10,23 +11,23 @@ func TestAddrUnpacked(t *testing.T) {
cases := []struct {
leaf []byte
ltd []string
want common.Address
want []byte
}{
{
common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"),
common.FromHex("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"),
[]string{"uint32", "address"},
common.HexToAddress("0x0000000000000000000000000000000000000001"),
common.FromHex("0x0000000000000000000000000000000000000001"),
},
{
common.Hex2Bytes("00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000"),
common.FromHex("00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000"),
[]string{"address", "uint32"},
common.HexToAddress("0x0000000000000000000000000000000000000001"),
common.FromHex("0x0000000000000000000000000000000000000001"),
},
}

for _, c := range cases {
addr := addrUnpacked(c.leaf, c.ltd)
if addr != c.want {
if !bytes.Equal(addr, c.want) {
t.Errorf("expected: %v got: %v", c.want, addr)
}
}
Expand All @@ -36,24 +37,90 @@ func TestAddrPacked(t *testing.T) {
cases := []struct {
leaf []byte
ltd []string
want common.Address
want []byte
}{
{
common.Hex2Bytes("000000000000000000000000000000000000000000000001"),
common.FromHex("000000000000000000000000000000000000000000000001"),
[]string{"uint32", "address"},
common.HexToAddress("0x0000000000000000000000000000000000000001"),
common.FromHex("0x0000000000000000000000000000000000000001"),
},
{
common.Hex2Bytes("000000000000000000000000000000000000000100000000"),
common.FromHex("000000000000000000000000000000000000000100000000"),
[]string{"address", "uint32"},
common.HexToAddress("0x0000000000000000000000000000000000000001"),
common.FromHex("0x0000000000000000000000000000000000000001"),
},
}

for _, c := range cases {
addr := addrPacked(c.leaf, c.ltd)
if addr != c.want {
if !bytes.Equal(addr, c.want) {
t.Errorf("expected: %v got: %v", c.want, addr)
}
}
}

func TestLeaf2Addr(t *testing.T) {
cases := []struct {
leaf []byte
ltd []string
packed bool
want []byte
}{
{
common.FromHex("000000000000000000000000000000000000000000000001"),
[]string{"uint32", "address"},
true,
common.FromHex("0x0000000000000000000000000000000000000001"),
},
{
common.FromHex("000000000000000000000000000000000000000100000000"),
[]string{"address", "uint32"},
true,
common.FromHex("0x0000000000000000000000000000000000000001"),
},
{
common.FromHex("0x0000000000000000000000000000000000000001"),
[]string{"address"},
false,
common.FromHex("0x0000000000000000000000000000000000000001"),
},
{
common.FromHex("00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"),
[]string{"uint32", "address"},
false,
common.FromHex("0x0000000000000000000000000000000000000001"),
},
{
common.FromHex("00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000"),
[]string{"address", "uint32"},
false,
common.FromHex("0x0000000000000000000000000000000000000001"),
},
{
common.FromHex("00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000001"),
[]string{"uint256", "address"},
false,
common.FromHex("0x0000000000000000000000000000000000000001"),
},
{
common.FromHex("0x0000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000002d"),
[]string{"address", "uint256"},
false,
common.FromHex("0x0000000000000000000000000000000000000001"),
},
{
common.FromHex("0x0000000000000000000000000000000000000000000000000000000000000001"),
[]string{"address"},
true,
common.FromHex("0x0000000000000000000000000000000000000001"),
},
}

for _, c := range cases {
addr := leaf2Addr(c.leaf, c.ltd, c.packed)
if !bytes.Equal(addr, c.want) {
t.Errorf("expected: %v got: %v", c.want, addr)
}
}

}
10 changes: 10 additions & 0 deletions cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"context"
"database/sql"
"flag"
"fmt"
"net"
"net/http"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/jackc/pgx/v4/pgxpool"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/opentracing/opentracing-go"
"github.com/pkg/profile"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer"
Expand All @@ -39,6 +41,14 @@ func main() {
env = "dev"
}

shouldProfile := flag.Bool("profile", false, "enable profiling")
flag.Parse()

if *shouldProfile {
prof := profile.Start(profile.CPUProfile, profile.ProfilePath("."))
defer prof.Stop()
}

ddAgent := os.Getenv("DD_AGENT_HOST")
if ddAgent != "" {
t := opentracer.New(
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/jackc/pgx/v4 v4.16.1
github.com/lib/pq v1.10.9
github.com/opentracing/opentracing-go v1.2.0
github.com/pkg/profile v1.2.1
github.com/rs/cors v1.8.2
github.com/rs/zerolog v1.29.1
golang.org/x/sync v0.3.0
Expand Down Expand Up @@ -43,6 +44,7 @@ require (
github.com/philhofer/fwd v1.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/rs/xid v1.4.0 // indirect
github.com/stretchr/testify v1.8.0 // indirect
github.com/tinylib/msgp v1.1.2 // indirect
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 // indirect
Expand Down
Loading

1 comment on commit 8b608fc

@vercel
Copy link

@vercel vercel bot commented on 8b608fc Jun 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

lanyard – ./

lanyard-git-main.mf.dev
lanyard-production.mf.dev
lanyard.mf.dev

Please sign in to comment.