diff --git a/overlord/auth/auth.go b/overlord/auth/auth.go index 7ac3a6725c12..31c522a4b8a8 100644 --- a/overlord/auth/auth.go +++ b/overlord/auth/auth.go @@ -28,6 +28,7 @@ import ( "os" "sort" "strconv" + "time" "golang.org/x/net/context" @@ -373,6 +374,11 @@ type DeviceAssertions interface { // Serial returns the device serial assertion. Serial() (*asserts.Serial, error) + // EnsureSerial does a best-effort of triggering and waiting + // up to timeout for registration to occur and returns the + // serial if now available, or ErrNoState otherwise. + EnsureSerial(context.Context, time.Duration) (*asserts.Serial, error) + // DeviceSessionRequestParams produces a device-session-request with the given nonce, together with other required parameters, the device serial and model assertions. DeviceSessionRequestParams(nonce string) (*DeviceSessionRequestParams, error) // ProxyStore returns the store assertion for the proxy store if one is set. @@ -404,6 +410,8 @@ type AuthContext interface { StoreID(fallback string) (string, error) + EnsureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) + DeviceSessionRequestParams(nonce string) (*DeviceSessionRequestParams, error) ProxyStoreParams(defaultURL *url.URL) (proxyStoreID string, proxySroreURL *url.URL, err error) @@ -499,6 +507,20 @@ func (ac *authContext) StoreID(fallback string) (string, error) { return fallback, nil } +// EnsureSerial does a best-effort of triggering and waiting +// up to timeout for registration to occur and returns the +// serial if now available, or ErrNoSerial otherwise. +func (ac *authContext) EnsureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) { + if ac.deviceAsserts == nil { + return nil, ErrNoSerial + } + serial, err := ac.deviceAsserts.EnsureSerial(ctx, timeout) + if err == state.ErrNoState { + return nil, ErrNoSerial + } + return serial, err +} + // DeviceSessionRequestParams produces a device-session-request with the given nonce, together with other required parameters, the device serial and model assertions. It returns ErrNoSerial if the device serial is not yet initialized. func (ac *authContext) DeviceSessionRequestParams(nonce string) (*DeviceSessionRequestParams, error) { if ac.deviceAsserts == nil { diff --git a/overlord/auth/auth_test.go b/overlord/auth/auth_test.go index c1226504bfbc..d88618670558 100644 --- a/overlord/auth/auth_test.go +++ b/overlord/auth/auth_test.go @@ -552,6 +552,13 @@ func (as *authSuite) TestAuthContextDeviceSessionRequestParamsNilDeviceAssertion c.Check(err, Equals, auth.ErrNoSerial) } +func (as *authSuite) TestAuthContextEnsureSerialNilDeviceAssertions(c *C) { + authContext := auth.NewAuthContext(as.state, nil) + + _, err := authContext.EnsureSerial(context.TODO(), 5*time.Second) + c.Check(err, Equals, auth.ErrNoSerial) +} + func (as *authSuite) TestAuthContextCloudInfo(c *C) { authContext := auth.NewAuthContext(as.state, nil) @@ -661,6 +668,13 @@ func (da *testDeviceAssertions) Serial() (*asserts.Serial, error) { return a.(*asserts.Serial), nil } +func (da *testDeviceAssertions) EnsureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) { + if ctx == nil { + panic("context required") + } + return da.Serial() +} + func (da *testDeviceAssertions) DeviceSessionRequestParams(nonce string) (*auth.DeviceSessionRequestParams, error) { if da.nothing { return nil, state.ErrNoState @@ -704,7 +718,10 @@ func (as *authSuite) TestAuthContextMissingDeviceAssertions(c *C) { // no assertions in state authContext := auth.NewAuthContext(as.state, &testDeviceAssertions{nothing: true}) - _, err := authContext.DeviceSessionRequestParams("NONCE") + _, err := authContext.EnsureSerial(context.TODO(), 0) + c.Check(err, Equals, auth.ErrNoSerial) + + _, err = authContext.DeviceSessionRequestParams("NONCE") c.Check(err, Equals, auth.ErrNoSerial) storeID, err := authContext.StoreID("fallback") @@ -721,6 +738,10 @@ func (as *authSuite) TestAuthContextWithDeviceAssertions(c *C) { // having assertions in state authContext := auth.NewAuthContext(as.state, &testDeviceAssertions{}) + serialAssert, err := authContext.EnsureSerial(context.TODO(), 0) + c.Check(err, IsNil) + c.Check(serialAssert.Serial(), Equals, "9999") + params, err := authContext.DeviceSessionRequestParams("NONCE-1") c.Assert(err, IsNil) diff --git a/overlord/devicestate/devicemgr.go b/overlord/devicestate/devicemgr.go index b1d231e527c3..a767b5cbe73e 100644 --- a/overlord/devicestate/devicemgr.go +++ b/overlord/devicestate/devicemgr.go @@ -297,7 +297,9 @@ func (m *DeviceManager) ensureOperational() error { // on classic the first time around we just do // prepare-device + key gen and then pause the process // until the first store interaction - full = err != state.ErrNoState + if err == state.ErrNoState { + full = false + } } if full { @@ -357,13 +359,10 @@ func (m *DeviceManager) Registered() <-chan struct{} { return m.reg } -// waitForRegistrationTimeout is the timeout after which waitForRegistration will give up, about the same as network retries max timeout -var waitForRegistrationTimeout = 30 * time.Second - // ensureSerial does a best-effort of triggering and waiting for // registration to occur and returns the serial if now available, or // ErrNoState otherwise. Assumes state is locked. -func (m *DeviceManager) ensureSerial(ctx context.Context) (*asserts.Serial, error) { +func (m *DeviceManager) ensureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) { serial, err := Serial(m.state) if err != nil && err != state.ErrNoState { return nil, err @@ -388,7 +387,7 @@ func (m *DeviceManager) ensureSerial(ctx context.Context) (*asserts.Serial, erro // registered case <-m.regFirstAttempt: // first registration attempt finished (successfully or not) - case <-time.After(waitForRegistrationTimeout): + case <-time.After(timeout): // neither, timed out } m.state.Lock() @@ -605,12 +604,12 @@ func (m *DeviceManager) ProxyStore() (*asserts.Store, error) { return ProxyStore(m.state) } -// EnsureSerial does a best-effort of triggering and waiting for +// EnsureSerial does a best-effort of triggering and waiting up to timeout for // registration to occur and returns the serial if now available, or // ErrNoState otherwise. -func (m *DeviceManager) EnsureSerial(ctx context.Context) (*asserts.Serial, error) { +func (m *DeviceManager) EnsureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) { m.state.Lock() defer m.state.Unlock() - return m.ensureSerial(ctx) + return m.ensureSerial(ctx, timeout) } diff --git a/overlord/devicestate/devicestate.go b/overlord/devicestate/devicestate.go index 25b73f7e1d01..0cc1921c3cc9 100644 --- a/overlord/devicestate/devicestate.go +++ b/overlord/devicestate/devicestate.go @@ -24,6 +24,7 @@ package devicestate import ( "fmt" "sync" + "time" "golang.org/x/net/context" @@ -280,11 +281,22 @@ func CanManageRefreshes(st *state.State) bool { return false } -// EnsureRegistration does a best-effort of triggering and waiting for registration to occur. -// XXX opts -func EnsureRegistration(ctx context.Context, st *state.State) (proceed bool, err error) { +// ensureRegistrationDefaultTimeout is the default timeout after which waitForRegistration will give up, about the same as network retries max timeout +var ensureRegistrationDefaultTimeout = 30 * time.Second + +// EnsureRegistration does a best-effort of triggering and waiting for registration to occur controlled by optional opts. +func EnsureRegistration(ctx context.Context, st *state.State, opts *snapstate.EnsureRegistrationOptions) (proceed bool, err error) { m := deviceMgr(st) - serial, err := m.ensureSerial(ctx) + + if opts == nil { + opts = &snapstate.EnsureRegistrationOptions{} + } + timeout := ensureRegistrationDefaultTimeout + if opts.Timeout != 0 { + timeout = opts.Timeout + } + + serial, err := m.ensureSerial(ctx, timeout) if err != nil && err != state.ErrNoState { return false, err } diff --git a/overlord/devicestate/devicestate_test.go b/overlord/devicestate/devicestate_test.go index 24641349011e..f1e71547c09d 100644 --- a/overlord/devicestate/devicestate_test.go +++ b/overlord/devicestate/devicestate_test.go @@ -502,7 +502,7 @@ func (s *deviceMgrSuite) TestFullDeviceRegistrationHappyClassic(c *C) { s.state.Unlock() go func() { // retriggers registration - serial, err := s.mgr.EnsureSerial(context.TODO()) + serial, err := s.mgr.EnsureSerial(context.TODO(), 5*time.Second) c.Check(err, IsNil) ch <- serial }() @@ -525,7 +525,7 @@ func (s *deviceMgrSuite) TestFullDeviceRegistrationHappyClassic(c *C) { select { case serial0 = <-ch: case <-time.After(5 * time.Second): - c.Fatal("waitForSerial should have returned by now") + c.Fatal("EnsureSerial should have returned by now") } s.state.Lock() @@ -564,12 +564,12 @@ func (s *deviceMgrSuite) TestFullDeviceRegistrationHappyClassic(c *C) { // now EnsureSerial just returns s.state.Unlock() - serial1, err := s.mgr.EnsureSerial(context.TODO()) + serial1, err := s.mgr.EnsureSerial(context.TODO(), 0) s.state.Lock() c.Assert(err, IsNil) c.Check(serial1.Serial(), Equals, "9999") // and EnsureRegistration returns proceed - p, err := devicestate.EnsureRegistration(context.TODO(), s.state) + p, err := devicestate.EnsureRegistration(context.TODO(), s.state, nil) c.Check(err, IsNil) c.Check(p, Equals, true) } @@ -582,7 +582,7 @@ func (s *deviceMgrSuite) TestEnsureRegistrationGiveUpNotFirstAttempts(c *C) { devicestate.IncEnsureOperationalAttempts(s.state) devicestate.IncEnsureOperationalAttempts(s.state) // number of attempts is at 2, give up - p, err := devicestate.EnsureRegistration(context.TODO(), s.state) + p, err := devicestate.EnsureRegistration(context.TODO(), s.state, nil) c.Check(err, IsNil) c.Check(p, Equals, false) c.Check(logbuf.String(), Equals, "") @@ -593,24 +593,22 @@ func (s *deviceMgrSuite) TestEnsureRegistrationGiveUpInsideEnsureLoop(c *C) { defer r() s.state.Lock() defer s.state.Unlock() - devicestate.IncEnsureOperationalAttempts(s.state) - devicestate.IncEnsureOperationalAttempts(s.state) // inside Ensure loop, give up - p, err := devicestate.EnsureRegistration(auth.EnsureContextTODO(), s.state) + p, err := devicestate.EnsureRegistration(auth.EnsureContextTODO(), s.state, nil) c.Check(err, IsNil) c.Check(p, Equals, false) c.Check(logbuf.String(), Equals, "") } -func (s *deviceMgrSuite) TestEnsureRegistrationGiveUpTimeout(c *C) { +func (s *deviceMgrSuite) TestEnsureRegistrationGiveUpDefaultTimeout(c *C) { logbuf, r1 := logger.MockLogger() defer r1() - r2 := devicestate.MockWaitForRegistraionTimeout(100 * time.Millisecond) + r2 := devicestate.MockEnsureRegistrationDefaultTimeout(100 * time.Millisecond) defer r2() s.state.Lock() defer s.state.Unlock() // will wait full timeout - p, err := devicestate.EnsureRegistration(context.TODO(), s.state) + p, err := devicestate.EnsureRegistration(context.TODO(), s.state, nil) c.Check(err, IsNil) c.Check(p, Equals, false) c.Check(logbuf.String(), Matches, "(?s).*first device registration attempt did not succeed, registration will be retried but some operations might fail until the device is registered\n") @@ -1273,7 +1271,7 @@ version: gadget // fails fast s.state.Unlock() - _, err = s.mgr.EnsureSerial(context.TODO()) + _, err = s.mgr.EnsureSerial(context.TODO(), 5*time.Second) s.state.Lock() c.Check(err, Equals, state.ErrNoState) diff --git a/overlord/devicestate/export_test.go b/overlord/devicestate/export_test.go index a6cce5701467..ca542f202754 100644 --- a/overlord/devicestate/export_test.go +++ b/overlord/devicestate/export_test.go @@ -86,11 +86,11 @@ func SetLastBecomeOperationalAttempt(m *DeviceManager, t time.Time) { m.lastBecomeOperationalAttempt = t } -func MockWaitForRegistraionTimeout(timeout time.Duration) (restore func()) { - old := waitForRegistrationTimeout - waitForRegistrationTimeout = timeout +func MockEnsureRegistrationDefaultTimeout(timeout time.Duration) (restore func()) { + old := ensureRegistrationDefaultTimeout + ensureRegistrationDefaultTimeout = timeout return func() { - waitForRegistrationTimeout = old + ensureRegistrationDefaultTimeout = old } } diff --git a/overlord/managers_test.go b/overlord/managers_test.go index dda907a3bac1..dd698d438f2c 100644 --- a/overlord/managers_test.go +++ b/overlord/managers_test.go @@ -1930,7 +1930,7 @@ func (s *authContextSetupSuite) TestDeviceSessionRequestParams(c *C) { defer st.Unlock() st.Unlock() - _, err := s.ac.DeviceSessionRequestParams("NONCE") + _, err = s.ac.DeviceSessionRequestParams("NONCE") st.Lock() c.Check(err, Equals, auth.ErrNoSerial) @@ -1958,6 +1958,11 @@ func (s *authContextSetupSuite) TestDeviceSessionRequestParams(c *C) { c.Check(params.EncodedSerial(), DeepEquals, string(asserts.Encode(s.serial))) c.Check(params.EncodedModel(), DeepEquals, string(asserts.Encode(s.model))) + st.Unlock() + serial, err := s.ac.EnsureSerial(context.TODO(), 10*time.Second) + st.Lock() + c.Check(err, IsNil) + c.Check(serial.Serial(), Equals, "") } func (s *authContextSetupSuite) TestProxyStoreParams(c *C) { diff --git a/overlord/snapstate/autorefresh.go b/overlord/snapstate/autorefresh.go index 3b7a31c31e5a..c41cf5ef19bc 100644 --- a/overlord/snapstate/autorefresh.go +++ b/overlord/snapstate/autorefresh.go @@ -24,6 +24,8 @@ import ( "os" "time" + "golang.org/x/net/context" + "github.com/snapcore/snapd/i18n" "github.com/snapcore/snapd/logger" "github.com/snapcore/snapd/overlord/auth" @@ -45,8 +47,20 @@ const maxPostponement = 60 * 24 * time.Hour var ( CanAutoRefresh func(st *state.State) (bool, error) CanManageRefreshes func(st *state.State) bool + EnsureRegistration func(ctx context.Context, st *state.State, opts EnsureRegistrationOptions) (proceed bool, err error) ) +type EnsureRegistrationOptions struct { + // AfterAttemptsProceedAnyway is the number (if != 0) of + // registration attempts after which is allowed to proceed anyway. + AfterAttemptsProceedAnyway int + // CustomStoreOnly means registration is required only with a + // custom store. + CustomStoreOnly bool + // Timeout if waiting for registration. + Timeout time.Duration +} + // refreshRetryDelay specified the minimum time to retry failed refreshes var refreshRetryDelay = 10 * time.Minute diff --git a/store/store.go b/store/store.go index c699690f5be9..64f4a81f4f60 100644 --- a/store/store.go +++ b/store/store.go @@ -85,6 +85,8 @@ var defaultRetryStrategy = retry.LimitCount(5, retry.LimitTime(38*time.Second, }, )) +var ensureSerialTimeout = 30 * time.Second + // Config represents the configuration to access the snap store type Config struct { // Store API base URLs. The assertions url is only separate because it can @@ -557,6 +559,14 @@ func (s *Store) refreshDeviceSession(ctx context.Context, device *auth.DeviceSta return fmt.Errorf("internal error: no authContext") } + if device.Serial == "" { + // best-effort to be registered and have a serial + _, err := s.authContext.EnsureSerial(ctx, ensureSerialTimeout) + if err != nil { + return err + } + } + nonce, err := requestStoreDeviceNonce(s.endpointURL(deviceNonceEndpPath, nil).String()) if err != nil { return err diff --git a/store/store_test.go b/store/store_test.go index ff4c54920776..b2a737533f1d 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -225,6 +225,8 @@ type testAuthContext struct { storeID string cloudInfo *auth.CloudInfo + + ensuredSerial bool } func (ac *testAuthContext) Device() (*auth.DeviceState, error) { @@ -257,6 +259,14 @@ func (ac *testAuthContext) StoreID(fallback string) (string, error) { return fallback, nil } +func (ac *testAuthContext) EnsureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) { + if ctx == nil { + panic("context required") + } + ac.ensuredSerial = true + return nil, nil +} + func (ac *testAuthContext) DeviceSessionRequestParams(nonce string) (*auth.DeviceSessionRequestParams, error) { model, err := asserts.Decode([]byte(exModel)) if err != nil { @@ -1694,6 +1704,7 @@ func (s *storeTestSuite) TestDoRequestSetsAndRefreshesDeviceAuth(c *C) { c.Check(string(responseData), Equals, "response-data") c.Check(deviceSessionRequested, Equals, true) c.Check(refreshSessionRequested, Equals, true) + c.Check(authContext.ensuredSerial, Equals, true) } func (s *storeTestSuite) TestDoRequestSetsAndRefreshesBothAuths(c *C) {