Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New beacon state: Only populate merkle layers as needed, copy merkle layers on copy/clone. #4689

Merged
merged 5 commits into from Jan 29, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -159,7 +159,7 @@ func TestStore_SaveCheckpointState(t *testing.T) {
}

cp1 := &ethpb.Checkpoint{Epoch: 1, Root: []byte{'A'}}
s1, err := beaconstate.InitializeFromProto(ss.Clone())
s1, err := beaconstate.InitializeFromProto(ss.CloneInnerState())
if err != nil {
t.Fatal(err)
}
Expand All @@ -173,7 +173,7 @@ func TestStore_SaveCheckpointState(t *testing.T) {

cp2 := &ethpb.Checkpoint{Epoch: 2, Root: []byte{'B'}}

s2, err := beaconstate.InitializeFromProto(ss.Clone())
s2, err := beaconstate.InitializeFromProto(ss.CloneInnerState())
if err != nil {
t.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion beacon-chain/blockchain/forkchoice/process_block.go
Expand Up @@ -311,8 +311,9 @@ func (s *Store) verifyBlkPreState(ctx context.Context, b *ethpb.BeaconBlock) (*s
if preState == nil {
return nil, fmt.Errorf("pre state of slot %d does not exist", b.Slot)
}
return preState, nil // No copy needed from newly hydrated DB object.
}
return stateTrie.InitializeFromProto(preState.Clone())
return preState.Copy(), nil
}
preState, err := s.db.State(ctx, bytesutil.ToBytes32(b.ParentRoot))
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion beacon-chain/blockchain/forkchoice/process_block_test.go
Expand Up @@ -471,7 +471,7 @@ func TestCachedPreState_CanGetFromCacheWithFeature(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(s, received) {
if !reflect.DeepEqual(s.InnerStateUnsafe(), received.InnerStateUnsafe()) {
t.Error("cached state not the same")
}
}
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/blockchain/forkchoice/service_test.go
Expand Up @@ -65,7 +65,7 @@ func TestStore_GenesisStoreOk(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(cachedState, genesisState) {
if !reflect.DeepEqual(cachedState.InnerStateUnsafe(), genesisState.InnerStateUnsafe()) {
t.Error("Incorrect genesis state cached")
}
}
Expand Down Expand Up @@ -380,7 +380,7 @@ func TestCacheGenesisState_Correct(t *testing.T) {
}

for _, state := range store.initSyncState {
if !reflect.DeepEqual(s, state) {
if !reflect.DeepEqual(s.InnerStateUnsafe(), state.InnerStateUnsafe()) {
t.Error("Did not get wanted state")
}
}
Expand Down
5 changes: 1 addition & 4 deletions beacon-chain/blockchain/process_attestation_helpers.go
Expand Up @@ -76,10 +76,7 @@ func (s *Service) saveCheckpointState(ctx context.Context, baseState *stateTrie.

// Advance slots only when it's higher than current state slot.
if helpers.StartSlot(c.Epoch) > baseState.Slot() {
stateCopy, err := stateTrie.InitializeFromProto(baseState.Clone())
if err != nil {
return nil, err
}
stateCopy := baseState.Copy()
stateCopy, err = state.ProcessSlots(ctx, stateCopy, helpers.StartSlot(c.Epoch))
if err != nil {
return nil, errors.Wrapf(err, "could not process slots up to %d", helpers.StartSlot(c.Epoch))
Expand Down
3 changes: 2 additions & 1 deletion beacon-chain/blockchain/process_block_helpers.go
Expand Up @@ -64,8 +64,9 @@ func (s *Service) verifyBlkPreState(ctx context.Context, b *ethpb.BeaconBlock) (
if preState == nil {
return nil, fmt.Errorf("pre state of slot %d does not exist", b.Slot)
}
return preState, nil // No copy needed from newly hydrated DB object.
}
return stateTrie.InitializeFromProto(preState.Clone())
return preState.Copy(), nil
}

preState, err := s.beaconDB.State(ctx, bytesutil.ToBytes32(b.ParentRoot))
Expand Down
2 changes: 1 addition & 1 deletion beacon-chain/blockchain/process_block_test.go
Expand Up @@ -348,7 +348,7 @@ func TestCachedPreState_CanGetFromCacheWithFeature(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(s, received) {
if !reflect.DeepEqual(s.InnerStateUnsafe(), received.InnerStateUnsafe()) {
t.Error("cached state not the same")
}
}
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/cache/skip_slot_cache.go
Expand Up @@ -80,7 +80,7 @@ func (c *SkipSlotCache) Get(ctx context.Context, slot uint64) (*stateTrie.Beacon

if exists && item != nil {
skipSlotCacheHit.Inc()
return stateTrie.InitializeFromProto(item.(*stateTrie.BeaconState).Clone())
return item.(*stateTrie.BeaconState).Copy(), nil
}
skipSlotCacheMiss.Inc()
return nil, nil
Expand Down Expand Up @@ -123,7 +123,7 @@ func (c *SkipSlotCache) Put(ctx context.Context, slot uint64, state *stateTrie.B
return nil
}

// Clone state so cached value is not mutated.
// CloneInnerState state so cached value is not mutated.
c.cache.Add(slot, state)

return nil
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/cache/skip_slot_cache_test.go
Expand Up @@ -31,7 +31,7 @@ func TestSkipSlotCache_RoundTrip(t *testing.T) {
t.Error(err)
}

state, err = stateTrie.InitializeFromProto(&pb.BeaconState{
state, err = stateTrie.InitializeFromProtoUnsafe(&pb.BeaconState{
Slot: 10,
})
if err != nil {
Expand All @@ -51,7 +51,7 @@ func TestSkipSlotCache_RoundTrip(t *testing.T) {
t.Error(err)
}

if !reflect.DeepEqual(state, res) {
if !reflect.DeepEqual(state.InnerStateUnsafe(), res.InnerStateUnsafe()) {
t.Error("Expected equal protos to return from cache")
}
}
4 changes: 2 additions & 2 deletions beacon-chain/core/blocks/spectest/block_header_test.go
Expand Up @@ -73,8 +73,8 @@ func runBlockHeaderTest(t *testing.T, config string) {
if err := ssz.Unmarshal(postBeaconStateFile, postBeaconState); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
if !proto.Equal(beaconState.Clone(), postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState.Clone(), postBeaconState)
if !proto.Equal(beaconState.CloneInnerState(), postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState.CloneInnerState(), postBeaconState)
t.Log(diff)
t.Fatal("Post state does not match expected")
}
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/core/blocks/spectest/block_processing_test.go
Expand Up @@ -94,8 +94,8 @@ func runBlockProcessingTest(t *testing.T, config string) {
t.Fatalf("Failed to unmarshal: %v", err)
}

if !proto.Equal(beaconState.Clone(), postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState.Clone(), postBeaconState)
if !proto.Equal(beaconState.CloneInnerState(), postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState.CloneInnerState(), postBeaconState)
t.Log(diff)
t.Fatal("Post state does not match expected")
}
Expand Down
2 changes: 1 addition & 1 deletion beacon-chain/core/helpers/committee_test.go
Expand Up @@ -368,7 +368,7 @@ func TestCommitteeAssignments_AgreesWithSpecDefinitionMethod(t *testing.T) {
})
// Test for 2 epochs.
for epoch := uint64(0); epoch < 2; epoch++ {
state, _ := beaconstate.InitializeFromProto(state.Clone())
state, _ := beaconstate.InitializeFromProto(state.CloneInnerState())
assignments, proposers, err := CommitteeAssignments(state, epoch)
if err != nil {
t.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion beacon-chain/core/state/benchmarks_test.go
Expand Up @@ -163,7 +163,7 @@ func BenchmarkHashTreeRootState_FullState(b *testing.B) {
func clonedStates(beaconState *beaconstate.BeaconState) []*beaconstate.BeaconState {
clonedStates := make([]*beaconstate.BeaconState, runAmount)
for i := 0; i < runAmount; i++ {
c, err := beaconstate.InitializeFromProto(beaconState.Clone())
c, err := beaconstate.InitializeFromProto(beaconState.CloneInnerState())
if err != nil {
panic(err)
}
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/core/state/skip_slot_cache_test.go
Expand Up @@ -14,7 +14,7 @@ import (

func TestSkipSlotCache_OK(t *testing.T) {
bState, privs := testutil.DeterministicGenesisState(t, params.MinimalSpecConfig().MinGenesisActiveValidatorCount)
originalState, _ := beaconstate.InitializeFromProto(bState.Clone())
originalState, _ := beaconstate.InitializeFromProto(bState.CloneInnerState())

blkCfg := testutil.DefaultBlockGenConfig()
blkCfg.NumAttestations = 1
Expand Down Expand Up @@ -43,7 +43,7 @@ func TestSkipSlotCache_OK(t *testing.T) {
t.Fatalf("Could not process state transition: %v", err)
}

if !ssz.DeepEqual(originalState.Clone(), bState.Clone()) {
if !ssz.DeepEqual(originalState.CloneInnerState(), bState.CloneInnerState()) {
t.Fatal("Skipped slots cache leads to different states")
}
}
2 changes: 1 addition & 1 deletion beacon-chain/core/state/spectest/slot_processing_test.go
Expand Up @@ -60,7 +60,7 @@ func runSlotProcessingTests(t *testing.T, config string) {
t.Fatal(err)
}

if !proto.Equal(postState.Clone(), postBeaconState) {
if !proto.Equal(postState.CloneInnerState(), postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState, postBeaconState)
t.Fatalf("Post state does not match expected. Diff between states %s", diff)
}
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/core/state/state_test.go
Expand Up @@ -140,8 +140,8 @@ func TestGenesisState_HashEquality(t *testing.T) {
t.Error(err)
}

root1, err1 := hashutil.HashProto(state1.Clone())
root2, err2 := hashutil.HashProto(state2.Clone())
root1, err1 := hashutil.HashProto(state1.CloneInnerState())
root2, err2 := hashutil.HashProto(state2.CloneInnerState())

if err1 != nil || err2 != nil {
t.Fatalf("Failed to marshal state to bytes: %v %v", err1, err2)
Expand Down
2 changes: 1 addition & 1 deletion beacon-chain/core/state/stateutils/validator_index_map.go
Expand Up @@ -8,7 +8,7 @@ import (
// ValidatorIndexMap builds a lookup map for quickly determining the index of
// a validator by their public key.
func ValidatorIndexMap(validators []*ethpb.Validator) map[[48]byte]uint64 {
m := make(map[[48]byte]uint64)
m := make(map[[48]byte]uint64, len(validators))
if validators == nil {
return m
}
Expand Down
7 changes: 2 additions & 5 deletions beacon-chain/core/state/transition.go
Expand Up @@ -158,15 +158,12 @@ func CalculateStateRoot(
}

// Copy state to avoid mutating the state reference.
state, err := stateTrie.InitializeFromProto(state.Clone())
if err != nil {
return [32]byte{}, err
}
state = state.Copy()

b.ClearEth1DataVoteCache()

// Execute per slots transition.
state, err = ProcessSlots(ctx, state, signed.Block.Slot)
state, err := ProcessSlots(ctx, state, signed.Block.Slot)
if err != nil {
return [32]byte{}, errors.Wrap(err, "could not process slot")
}
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/state/getters.go
Expand Up @@ -70,8 +70,8 @@ func (b *BeaconState) InnerStateUnsafe() *pbp2p.BeaconState {
return b.state
}

// Clone the beacon state into a protobuf for usage.
func (b *BeaconState) Clone() *pbp2p.BeaconState {
// CloneInnerState the beacon state into a protobuf for usage.
func (b *BeaconState) CloneInnerState() *pbp2p.BeaconState {
if b.state == nil {
return nil
}
Expand Down
57 changes: 44 additions & 13 deletions beacon-chain/state/types.go
Expand Up @@ -39,27 +39,58 @@ func InitializeFromProto(st *pbp2p.BeaconState) (*BeaconState, error) {
// InitializeFromProtoUnsafe directly uses the beacon state protobuf pointer
// and sets it as the inner state of the BeaconState type.
func InitializeFromProtoUnsafe(st *pbp2p.BeaconState) (*BeaconState, error) {
fieldRoots, err := stateutil.ComputeFieldRoots(st)
if err != nil {
return nil, err
b := &BeaconState{
state: st,
dirtyFields: make(map[fieldIndex]interface{}, 20),
valIdxMap: coreutils.ValidatorIndexMap(st.Validators),
}
layers := merkleize(fieldRoots)
valMap := coreutils.ValidatorIndexMap(st.Validators)
return &BeaconState{
state: st,
merkleLayers: layers,
dirtyFields: make(map[fieldIndex]interface{}),
valIdxMap: valMap,
}, nil
return b, nil
}

// Copy returns a deep copy of the beacon state.
func (b *BeaconState) Copy() *BeaconState {
b.lock.RLock()
defer b.lock.RUnlock()
dst := &BeaconState{
state: b.CloneInnerState(),
dirtyFields: make(map[fieldIndex]interface{}, 20),
valIdxMap: make(map[[48]byte]uint64, len(b.valIdxMap)),
}

for i := range b.dirtyFields {
dst.dirtyFields[i] = true
}

for i := range b.valIdxMap {
dst.valIdxMap[i] = b.valIdxMap[i]
}

dst.merkleLayers = make([][][]byte, len(b.merkleLayers))
for i, layer := range b.merkleLayers {
dst.merkleLayers[i] = make([][]byte, len(layer))
for j, content := range layer {
dst.merkleLayers[i][j] = make([]byte, len(content))
copy(dst.merkleLayers[i][j], content)
}
}

return dst
}

// HashTreeRoot of the beacon state retrieves the Merkle root of the trie
// representation of the beacon state based on the eth2 Simple Serialize specification.
func (b *BeaconState) HashTreeRoot() ([32]byte, error) {
b.lock.Lock()
defer b.lock.Unlock()
if len(b.merkleLayers) == 0 {
return [32]byte{}, errors.New("state merkle layers not initialized")

if b.merkleLayers == nil || len(b.merkleLayers) == 0 {
fieldRoots, err := stateutil.ComputeFieldRoots(b.state)
if err != nil {
return [32]byte{}, err
}
layers := merkleize(fieldRoots)
b.merkleLayers = layers
b.dirtyFields = make(map[fieldIndex]interface{})
}

for field := range b.dirtyFields {
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/state/types_test.go
Expand Up @@ -21,7 +21,7 @@ func TestBeaconState_ProtoBeaconStateCompatibility(t *testing.T) {
t.Fatal(err)
}
cloned := proto.Clone(genesis).(*pb.BeaconState)
custom := customState.Clone()
custom := customState.CloneInnerState()
if !proto.Equal(cloned, custom) {
t.Fatal("Cloned states did not match")
}
Expand Down Expand Up @@ -149,7 +149,7 @@ func BenchmarkStateClone_Manual(b *testing.B) {
}
b.StartTimer()
for i := 0; i < b.N; i++ {
_ = st.Clone()
_ = st.CloneInnerState()
}
}

Expand Down
2 changes: 1 addition & 1 deletion shared/interop/generate_genesis_state.go
Expand Up @@ -56,7 +56,7 @@ func GenerateGenesisState(genesisTime, numValidators uint64) (*pb.BeaconState, [
if err != nil {
return nil, nil, errors.Wrap(err, "could not generate genesis state")
}
return beaconState.Clone(), deposits, nil
return beaconState.CloneInnerState(), deposits, nil
}

// GenerateDepositsFromData a list of deposit items by creating proofs for each of them from a sparse Merkle trie.
Expand Down
10 changes: 4 additions & 6 deletions shared/testutil/block.go
Expand Up @@ -53,15 +53,13 @@ func GenerateFullBlock(
if currentSlot > slot {
return nil, fmt.Errorf("current slot in state is larger than given slot. %d > %d", currentSlot, slot)
}
bState, err := stateTrie.InitializeFromProto(bState.Clone())
if err != nil {
return nil, err
}
bState = bState.Copy()

if conf == nil {
conf = &BlockGenConfig{}
}

var err error
pSlashings := []*ethpb.ProposerSlashing{}
numToGen := conf.NumProposerSlashings
if numToGen > 0 {
Expand Down Expand Up @@ -302,7 +300,7 @@ func GenerateAttestations(
currentEpoch := helpers.SlotToEpoch(slot)
attestations := []*ethpb.Attestation{}
generateHeadState := false
bState, err := stateTrie.InitializeFromProtoUnsafe(bState.Clone())
bState, err := stateTrie.InitializeFromProtoUnsafe(bState.CloneInnerState())
if err != nil {
return nil, err
}
Expand All @@ -316,7 +314,7 @@ func GenerateAttestations(
headRoot := make([]byte, 32)
// Only calculate head state if its an attestation for the current slot or future slot.
if generateHeadState || slot == bState.Slot() {
headState, err := stateTrie.InitializeFromProtoUnsafe(bState.Clone())
headState, err := stateTrie.InitializeFromProtoUnsafe(bState.CloneInnerState())
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions shared/testutil/spectest.go
Expand Up @@ -116,7 +116,7 @@ func RunBlockOperationTest(
t.Fatalf("Failed to unmarshal: %v", err)
}

if !proto.Equal(beaconState.Clone(), postBeaconState) {
if !proto.Equal(beaconState.CloneInnerState(), postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState, postBeaconState)
t.Log(diff)
t.Fatal("Post state does not match expected")
Expand Down Expand Up @@ -177,8 +177,8 @@ func RunEpochOperationTest(
t.Fatalf("Failed to unmarshal: %v", err)
}

if !proto.Equal(beaconState.Clone(), postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState.Clone(), postBeaconState)
if !proto.Equal(beaconState.InnerStateUnsafe(), postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState.InnerStateUnsafe(), postBeaconState)
t.Log(diff)
t.Fatal("Post state does not match expected")
}
Expand Down