diff --git a/beacon-chain/blockchain/service.go b/beacon-chain/blockchain/service.go index 0a75ae84ed7e..8fda38225d9b 100644 --- a/beacon-chain/blockchain/service.go +++ b/beacon-chain/blockchain/service.go @@ -126,8 +126,7 @@ func NewService(ctx context.Context, cfg *Config) (*Service, error) { // Start a blockchain service's main event loop. func (s *Service) Start() { - ctx := context.TODO() - beaconState, err := s.beaconDB.HeadState(ctx) + beaconState, err := s.beaconDB.HeadState(s.ctx) if err != nil { log.Fatalf("Could not fetch beacon state: %v", err) } @@ -135,13 +134,13 @@ func (s *Service) Start() { // For running initial sync with state cache, in an event of restart, we use // last finalized check point as start point to sync instead of head // state. This is because we no longer save state every slot during sync. - cp, err := s.beaconDB.FinalizedCheckpoint(ctx) + cp, err := s.beaconDB.FinalizedCheckpoint(s.ctx) if err != nil { log.Fatalf("Could not fetch finalized cp: %v", err) } if beaconState == nil { - beaconState, err = s.stateGen.StateByRoot(ctx, bytesutil.ToBytes32(cp.Root)) + beaconState, err = s.stateGen.StateByRoot(s.ctx, bytesutil.ToBytes32(cp.Root)) if err != nil { log.Fatalf("Could not fetch beacon state by root: %v", err) } @@ -155,29 +154,29 @@ func (s *Service) Start() { log.Info("Blockchain data already exists in DB, initializing...") s.genesisTime = time.Unix(int64(beaconState.GenesisTime()), 0) s.opsService.SetGenesisTime(beaconState.GenesisTime()) - if err := s.initializeChainInfo(ctx); err != nil { + if err := s.initializeChainInfo(s.ctx); err != nil { log.Fatalf("Could not set up chain info: %v", err) } // We start a counter to genesis, if needed. - gState, err := s.beaconDB.GenesisState(ctx) + gState, err := s.beaconDB.GenesisState(s.ctx) if err != nil { log.Fatalf("Could not retrieve genesis state: %v", err) } - go slotutil.CountdownToGenesis(ctx, s.genesisTime, uint64(gState.NumValidators())) + go slotutil.CountdownToGenesis(s.ctx, s.genesisTime, uint64(gState.NumValidators())) - justifiedCheckpoint, err := s.beaconDB.JustifiedCheckpoint(ctx) + justifiedCheckpoint, err := s.beaconDB.JustifiedCheckpoint(s.ctx) if err != nil { log.Fatalf("Could not get justified checkpoint: %v", err) } - finalizedCheckpoint, err := s.beaconDB.FinalizedCheckpoint(ctx) + finalizedCheckpoint, err := s.beaconDB.FinalizedCheckpoint(s.ctx) if err != nil { log.Fatalf("Could not get finalized checkpoint: %v", err) } // Resume fork choice. s.justifiedCheckpt = stateTrie.CopyCheckpoint(justifiedCheckpoint) - if err := s.cacheJustifiedStateBalances(ctx, bytesutil.ToBytes32(s.justifiedCheckpt.Root)); err != nil { + if err := s.cacheJustifiedStateBalances(s.ctx, bytesutil.ToBytes32(s.justifiedCheckpt.Root)); err != nil { log.Fatalf("Could not cache justified state balances: %v", err) } s.prevJustifiedCheckpt = stateTrie.CopyCheckpoint(justifiedCheckpoint) @@ -214,7 +213,7 @@ func (s *Service) Start() { return } log.WithField("starttime", data.StartTime).Debug("Received chain start event") - s.processChainStartTime(ctx, data.StartTime) + s.processChainStartTime(s.ctx, data.StartTime) return } case <-s.ctx.Done(): @@ -296,6 +295,10 @@ func (s *Service) initializeBeaconChain( // Stop the blockchain service's main event loop and associated goroutines. func (s *Service) Stop() error { defer s.cancel() + + if s.stateGen != nil && s.head != nil && s.head.state != nil { + return s.stateGen.ForceCheckpoint(s.ctx, s.head.state.FinalizedCheckpoint().Root) + } return nil } diff --git a/beacon-chain/blockchain/service_test.go b/beacon-chain/blockchain/service_test.go index ae56d448b9ff..ca3e8e006b6b 100644 --- a/beacon-chain/blockchain/service_test.go +++ b/beacon-chain/blockchain/service_test.go @@ -15,8 +15,6 @@ import ( "github.com/prysmaticlabs/prysm/beacon-chain/cache" "github.com/prysmaticlabs/prysm/beacon-chain/cache/depositcache" b "github.com/prysmaticlabs/prysm/beacon-chain/core/blocks" - "github.com/prysmaticlabs/prysm/beacon-chain/core/feed" - statefeed "github.com/prysmaticlabs/prysm/beacon-chain/core/feed/state" "github.com/prysmaticlabs/prysm/beacon-chain/core/helpers" "github.com/prysmaticlabs/prysm/beacon-chain/core/state" "github.com/prysmaticlabs/prysm/beacon-chain/db" @@ -121,6 +119,11 @@ func setupBeaconChain(t *testing.T, beaconDB db.Database, sc *cache.StateSummary OpsService: opsService, } + // Safe a state in stategen to purposes of testing a service stop / shutdown. + if err := cfg.StateGen.SaveState(ctx, bytesutil.ToBytes32(bState.FinalizedCheckpoint().Root), bState); err != nil { + t.Fatal(err) + } + chainService, err := NewService(ctx, cfg) require.NoError(t, err, "Unable to setup chain service") chainService.genesisTime = time.Unix(1, 0) // non-zero time @@ -128,54 +131,6 @@ func setupBeaconChain(t *testing.T, beaconDB db.Database, sc *cache.StateSummary return chainService } -func TestChainStartStop_Uninitialized(t *testing.T) { - hook := logTest.NewGlobal() - db, sc := testDB.SetupDB(t) - chainService := setupBeaconChain(t, db, sc) - - // Listen for state events. - stateSubChannel := make(chan *feed.Event, 1) - stateSub := chainService.stateNotifier.StateFeed().Subscribe(stateSubChannel) - - // Test the chain start state notifier. - genesisTime := time.Unix(1, 0) - chainService.Start() - event := &feed.Event{ - Type: statefeed.ChainStarted, - Data: &statefeed.ChainStartedData{ - StartTime: genesisTime, - }, - } - // Send in a loop to ensure it is delivered (busy wait for the service to subscribe to the state feed). - for sent := 1; sent == 1; { - sent = chainService.stateNotifier.StateFeed().Send(event) - if sent == 1 { - // Flush our local subscriber. - <-stateSubChannel - } - } - - // Now wait for notification the state is ready. - for stateInitialized := false; stateInitialized == false; { - recv := <-stateSubChannel - if recv.Type == statefeed.Initialized { - stateInitialized = true - } - } - stateSub.Unsubscribe() - - beaconState, err := db.HeadState(context.Background()) - require.NoError(t, err) - if beaconState == nil || beaconState.Slot() != 0 { - t.Error("Expected canonical state feed to send a state with genesis block") - } - require.NoError(t, chainService.Stop(), "Unable to stop chain service") - // The context should have been canceled. - assert.Equal(t, context.Canceled, chainService.ctx.Err(), "Context was not canceled") - testutil.AssertLogsContain(t, hook, "Waiting") - testutil.AssertLogsContain(t, hook, "Initialized beacon chain genesis state") -} - func TestChainStartStop_Initialized(t *testing.T) { hook := logTest.NewGlobal() ctx := context.Background() diff --git a/beacon-chain/state/stategen/getter.go b/beacon-chain/state/stategen/getter.go index 9da3ce41e531..4329b3ca68c0 100644 --- a/beacon-chain/state/stategen/getter.go +++ b/beacon-chain/state/stategen/getter.go @@ -112,7 +112,7 @@ func (s *State) StateBySlot(ctx context.Context, slot uint64) (*state.BeaconStat // StateSummaryExists returns true if the corresponding state summary of the input block root either // exists in the DB or in the cache. func (s *State) StateSummaryExists(ctx context.Context, blockRoot [32]byte) bool { - return s.beaconDB.HasStateSummary(ctx, blockRoot) || s.stateSummaryCache.Has(blockRoot) + return s.stateSummaryCache.Has(blockRoot) || s.beaconDB.HasStateSummary(ctx, blockRoot) } // This returns the state summary object of a given block root, it first checks the cache @@ -120,6 +120,9 @@ func (s *State) StateSummaryExists(ctx context.Context, blockRoot [32]byte) bool func (s *State) stateSummary(ctx context.Context, blockRoot [32]byte) (*pb.StateSummary, error) { var summary *pb.StateSummary var err error + if s.stateSummaryCache == nil { + return nil, errors.New("nil stateSummaryCache") + } if s.stateSummaryCache.Has(blockRoot) { summary = s.stateSummaryCache.Get(blockRoot) } else { diff --git a/beacon-chain/state/stategen/setter.go b/beacon-chain/state/stategen/setter.go index 2cc5af952520..ad3c64b205bc 100644 --- a/beacon-chain/state/stategen/setter.go +++ b/beacon-chain/state/stategen/setter.go @@ -4,6 +4,7 @@ import ( "context" "github.com/prysmaticlabs/prysm/beacon-chain/state" + "github.com/prysmaticlabs/prysm/shared/bytesutil" "go.opencensus.io/trace" ) @@ -20,3 +21,21 @@ func (s *State) SaveState(ctx context.Context, root [32]byte, state *state.Beaco return s.saveHotState(ctx, root, state) } + +// ForceCheckpoint initiates a cold state save of the given state. This method does not update the +// "last archived state" but simply saves the specified state from the root argument into the DB. +func (s *State) ForceCheckpoint(ctx context.Context, root []byte) error { + ctx, span := trace.StartSpan(ctx, "stateGen.ForceCheckpoint") + defer span.End() + + root32 := bytesutil.ToBytes32(root) + fs, err := s.loadHotStateByRoot(ctx, root32) + if err != nil { + return err + } + if err := s.beaconDB.SaveState(ctx, fs, root32); err != nil { + return err + } + + return nil +} diff --git a/beacon-chain/state/stategen/setter_test.go b/beacon-chain/state/stategen/setter_test.go index 1e80f71543de..e1a1b967a5b5 100644 --- a/beacon-chain/state/stategen/setter_test.go +++ b/beacon-chain/state/stategen/setter_test.go @@ -4,7 +4,6 @@ import ( "context" "testing" - //"github.com/gogo/protobuf/proto" "github.com/prysmaticlabs/prysm/beacon-chain/cache" testDB "github.com/prysmaticlabs/prysm/beacon-chain/db/testing" "github.com/prysmaticlabs/prysm/shared/params" @@ -77,3 +76,27 @@ func TestSaveState_HotStateCached(t *testing.T) { assert.Equal(t, false, service.beaconDB.HasStateSummary(ctx, r), "Should have saved the state summary") testutil.AssertLogsDoNotContain(t, hook, "Saved full state on epoch boundary") } + +func TestState_ForceCheckpoint_SavesStateToDatabase(t *testing.T) { + ctx := context.Background() + db, ssc := testDB.SetupDB(t) + + svc := New(db, ssc) + beaconState, _ := testutil.DeterministicGenesisState(t, 32) + if err := beaconState.SetSlot(params.BeaconConfig().SlotsPerEpoch); err != nil { + t.Fatal(err) + } + + r := [32]byte{'a'} + svc.hotStateCache.Put(r, beaconState) + + if db.HasState(ctx, r) { + t.Fatal("Database has state stored already") + } + if err := svc.ForceCheckpoint(ctx, r[:]); err != nil { + t.Error(err) + } + if !db.HasState(ctx, r) { + t.Error("Did not save checkpoint to database") + } +}