Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
pedronis committed Apr 6, 2018
1 parent ac824fb commit ee1b28e
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 31 deletions.
22 changes: 22 additions & 0 deletions overlord/auth/auth.go
Expand Up @@ -28,6 +28,7 @@ import (
"os"
"sort"
"strconv"
"time"

"golang.org/x/net/context"

Expand Down Expand Up @@ -373,6 +374,11 @@ type DeviceAssertions interface {
// Serial returns the device serial assertion.
Serial() (*asserts.Serial, error)

// EnsureSerial does a best-effort of triggering and waiting
// up to timeout for registration to occur and returns the
// serial if now available, or ErrNoState otherwise.
EnsureSerial(context.Context, time.Duration) (*asserts.Serial, error)

// DeviceSessionRequestParams produces a device-session-request with the given nonce, together with other required parameters, the device serial and model assertions.
DeviceSessionRequestParams(nonce string) (*DeviceSessionRequestParams, error)
// ProxyStore returns the store assertion for the proxy store if one is set.
Expand Down Expand Up @@ -404,6 +410,8 @@ type AuthContext interface {

StoreID(fallback string) (string, error)

EnsureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error)

DeviceSessionRequestParams(nonce string) (*DeviceSessionRequestParams, error)
ProxyStoreParams(defaultURL *url.URL) (proxyStoreID string, proxySroreURL *url.URL, err error)

Expand Down Expand Up @@ -499,6 +507,20 @@ func (ac *authContext) StoreID(fallback string) (string, error) {
return fallback, nil
}

// EnsureSerial does a best-effort of triggering and waiting
// up to timeout for registration to occur and returns the
// serial if now available, or ErrNoSerial otherwise.
func (ac *authContext) EnsureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) {
if ac.deviceAsserts == nil {
return nil, ErrNoSerial
}
serial, err := ac.deviceAsserts.EnsureSerial(ctx, timeout)
if err == state.ErrNoState {
return nil, ErrNoSerial
}
return serial, err
}

// DeviceSessionRequestParams produces a device-session-request with the given nonce, together with other required parameters, the device serial and model assertions. It returns ErrNoSerial if the device serial is not yet initialized.
func (ac *authContext) DeviceSessionRequestParams(nonce string) (*DeviceSessionRequestParams, error) {
if ac.deviceAsserts == nil {
Expand Down
23 changes: 22 additions & 1 deletion overlord/auth/auth_test.go
Expand Up @@ -552,6 +552,13 @@ func (as *authSuite) TestAuthContextDeviceSessionRequestParamsNilDeviceAssertion
c.Check(err, Equals, auth.ErrNoSerial)
}

func (as *authSuite) TestAuthContextEnsureSerialNilDeviceAssertions(c *C) {
authContext := auth.NewAuthContext(as.state, nil)

_, err := authContext.EnsureSerial(context.TODO(), 5*time.Second)
c.Check(err, Equals, auth.ErrNoSerial)
}

func (as *authSuite) TestAuthContextCloudInfo(c *C) {
authContext := auth.NewAuthContext(as.state, nil)

Expand Down Expand Up @@ -661,6 +668,13 @@ func (da *testDeviceAssertions) Serial() (*asserts.Serial, error) {
return a.(*asserts.Serial), nil
}

func (da *testDeviceAssertions) EnsureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) {
if ctx == nil {
panic("context required")
}
return da.Serial()
}

func (da *testDeviceAssertions) DeviceSessionRequestParams(nonce string) (*auth.DeviceSessionRequestParams, error) {
if da.nothing {
return nil, state.ErrNoState
Expand Down Expand Up @@ -704,7 +718,10 @@ func (as *authSuite) TestAuthContextMissingDeviceAssertions(c *C) {
// no assertions in state
authContext := auth.NewAuthContext(as.state, &testDeviceAssertions{nothing: true})

_, err := authContext.DeviceSessionRequestParams("NONCE")
_, err := authContext.EnsureSerial(context.TODO(), 0)
c.Check(err, Equals, auth.ErrNoSerial)

_, err = authContext.DeviceSessionRequestParams("NONCE")
c.Check(err, Equals, auth.ErrNoSerial)

storeID, err := authContext.StoreID("fallback")
Expand All @@ -721,6 +738,10 @@ func (as *authSuite) TestAuthContextWithDeviceAssertions(c *C) {
// having assertions in state
authContext := auth.NewAuthContext(as.state, &testDeviceAssertions{})

serialAssert, err := authContext.EnsureSerial(context.TODO(), 0)
c.Check(err, IsNil)
c.Check(serialAssert.Serial(), Equals, "9999")

params, err := authContext.DeviceSessionRequestParams("NONCE-1")
c.Assert(err, IsNil)

Expand Down
17 changes: 8 additions & 9 deletions overlord/devicestate/devicemgr.go
Expand Up @@ -297,7 +297,9 @@ func (m *DeviceManager) ensureOperational() error {
// on classic the first time around we just do
// prepare-device + key gen and then pause the process
// until the first store interaction
full = err != state.ErrNoState
if err == state.ErrNoState {
full = false
}
}

if full {
Expand Down Expand Up @@ -357,13 +359,10 @@ func (m *DeviceManager) Registered() <-chan struct{} {
return m.reg
}

// waitForRegistrationTimeout is the timeout after which waitForRegistration will give up, about the same as network retries max timeout
var waitForRegistrationTimeout = 30 * time.Second

// ensureSerial does a best-effort of triggering and waiting for
// registration to occur and returns the serial if now available, or
// ErrNoState otherwise. Assumes state is locked.
func (m *DeviceManager) ensureSerial(ctx context.Context) (*asserts.Serial, error) {
func (m *DeviceManager) ensureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) {
serial, err := Serial(m.state)
if err != nil && err != state.ErrNoState {
return nil, err
Expand All @@ -388,7 +387,7 @@ func (m *DeviceManager) ensureSerial(ctx context.Context) (*asserts.Serial, erro
// registered
case <-m.regFirstAttempt:
// first registration attempt finished (successfully or not)
case <-time.After(waitForRegistrationTimeout):
case <-time.After(timeout):
// neither, timed out
}
m.state.Lock()
Expand Down Expand Up @@ -605,12 +604,12 @@ func (m *DeviceManager) ProxyStore() (*asserts.Store, error) {
return ProxyStore(m.state)
}

// EnsureSerial does a best-effort of triggering and waiting for
// EnsureSerial does a best-effort of triggering and waiting up to timeout for
// registration to occur and returns the serial if now available, or
// ErrNoState otherwise.
func (m *DeviceManager) EnsureSerial(ctx context.Context) (*asserts.Serial, error) {
func (m *DeviceManager) EnsureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) {
m.state.Lock()
defer m.state.Unlock()

return m.ensureSerial(ctx)
return m.ensureSerial(ctx, timeout)
}
20 changes: 16 additions & 4 deletions overlord/devicestate/devicestate.go
Expand Up @@ -24,6 +24,7 @@ package devicestate
import (
"fmt"
"sync"
"time"

"golang.org/x/net/context"

Expand Down Expand Up @@ -280,11 +281,22 @@ func CanManageRefreshes(st *state.State) bool {
return false
}

// EnsureRegistration does a best-effort of triggering and waiting for registration to occur.
// XXX opts
func EnsureRegistration(ctx context.Context, st *state.State) (proceed bool, err error) {
// ensureRegistrationDefaultTimeout is the default timeout after which waitForRegistration will give up, about the same as network retries max timeout
var ensureRegistrationDefaultTimeout = 30 * time.Second

// EnsureRegistration does a best-effort of triggering and waiting for registration to occur controlled by optional opts.
func EnsureRegistration(ctx context.Context, st *state.State, opts *snapstate.EnsureRegistrationOptions) (proceed bool, err error) {
m := deviceMgr(st)
serial, err := m.ensureSerial(ctx)

if opts == nil {
opts = &snapstate.EnsureRegistrationOptions{}
}
timeout := ensureRegistrationDefaultTimeout
if opts.Timeout != 0 {
timeout = opts.Timeout
}

serial, err := m.ensureSerial(ctx, timeout)
if err != nil && err != state.ErrNoState {
return false, err
}
Expand Down
22 changes: 10 additions & 12 deletions overlord/devicestate/devicestate_test.go
Expand Up @@ -502,7 +502,7 @@ func (s *deviceMgrSuite) TestFullDeviceRegistrationHappyClassic(c *C) {
s.state.Unlock()
go func() {
// retriggers registration
serial, err := s.mgr.EnsureSerial(context.TODO())
serial, err := s.mgr.EnsureSerial(context.TODO(), 5*time.Second)
c.Check(err, IsNil)
ch <- serial
}()
Expand All @@ -525,7 +525,7 @@ func (s *deviceMgrSuite) TestFullDeviceRegistrationHappyClassic(c *C) {
select {
case serial0 = <-ch:
case <-time.After(5 * time.Second):
c.Fatal("waitForSerial should have returned by now")
c.Fatal("EnsureSerial should have returned by now")
}

s.state.Lock()
Expand Down Expand Up @@ -564,12 +564,12 @@ func (s *deviceMgrSuite) TestFullDeviceRegistrationHappyClassic(c *C) {

// now EnsureSerial just returns
s.state.Unlock()
serial1, err := s.mgr.EnsureSerial(context.TODO())
serial1, err := s.mgr.EnsureSerial(context.TODO(), 0)
s.state.Lock()
c.Assert(err, IsNil)
c.Check(serial1.Serial(), Equals, "9999")
// and EnsureRegistration returns proceed
p, err := devicestate.EnsureRegistration(context.TODO(), s.state)
p, err := devicestate.EnsureRegistration(context.TODO(), s.state, nil)
c.Check(err, IsNil)
c.Check(p, Equals, true)
}
Expand All @@ -582,7 +582,7 @@ func (s *deviceMgrSuite) TestEnsureRegistrationGiveUpNotFirstAttempts(c *C) {
devicestate.IncEnsureOperationalAttempts(s.state)
devicestate.IncEnsureOperationalAttempts(s.state)
// number of attempts is at 2, give up
p, err := devicestate.EnsureRegistration(context.TODO(), s.state)
p, err := devicestate.EnsureRegistration(context.TODO(), s.state, nil)
c.Check(err, IsNil)
c.Check(p, Equals, false)
c.Check(logbuf.String(), Equals, "")
Expand All @@ -593,24 +593,22 @@ func (s *deviceMgrSuite) TestEnsureRegistrationGiveUpInsideEnsureLoop(c *C) {
defer r()
s.state.Lock()
defer s.state.Unlock()
devicestate.IncEnsureOperationalAttempts(s.state)
devicestate.IncEnsureOperationalAttempts(s.state)
// inside Ensure loop, give up
p, err := devicestate.EnsureRegistration(auth.EnsureContextTODO(), s.state)
p, err := devicestate.EnsureRegistration(auth.EnsureContextTODO(), s.state, nil)
c.Check(err, IsNil)
c.Check(p, Equals, false)
c.Check(logbuf.String(), Equals, "")
}

func (s *deviceMgrSuite) TestEnsureRegistrationGiveUpTimeout(c *C) {
func (s *deviceMgrSuite) TestEnsureRegistrationGiveUpDefaultTimeout(c *C) {
logbuf, r1 := logger.MockLogger()
defer r1()
r2 := devicestate.MockWaitForRegistraionTimeout(100 * time.Millisecond)
r2 := devicestate.MockEnsureRegistrationDefaultTimeout(100 * time.Millisecond)
defer r2()
s.state.Lock()
defer s.state.Unlock()
// will wait full timeout
p, err := devicestate.EnsureRegistration(context.TODO(), s.state)
p, err := devicestate.EnsureRegistration(context.TODO(), s.state, nil)
c.Check(err, IsNil)
c.Check(p, Equals, false)
c.Check(logbuf.String(), Matches, "(?s).*first device registration attempt did not succeed, registration will be retried but some operations might fail until the device is registered\n")
Expand Down Expand Up @@ -1273,7 +1271,7 @@ version: gadget

// fails fast
s.state.Unlock()
_, err = s.mgr.EnsureSerial(context.TODO())
_, err = s.mgr.EnsureSerial(context.TODO(), 5*time.Second)
s.state.Lock()
c.Check(err, Equals, state.ErrNoState)

Expand Down
8 changes: 4 additions & 4 deletions overlord/devicestate/export_test.go
Expand Up @@ -86,11 +86,11 @@ func SetLastBecomeOperationalAttempt(m *DeviceManager, t time.Time) {
m.lastBecomeOperationalAttempt = t
}

func MockWaitForRegistraionTimeout(timeout time.Duration) (restore func()) {
old := waitForRegistrationTimeout
waitForRegistrationTimeout = timeout
func MockEnsureRegistrationDefaultTimeout(timeout time.Duration) (restore func()) {
old := ensureRegistrationDefaultTimeout
ensureRegistrationDefaultTimeout = timeout
return func() {
waitForRegistrationTimeout = old
ensureRegistrationDefaultTimeout = old
}
}

Expand Down
7 changes: 6 additions & 1 deletion overlord/managers_test.go
Expand Up @@ -1930,7 +1930,7 @@ func (s *authContextSetupSuite) TestDeviceSessionRequestParams(c *C) {
defer st.Unlock()

st.Unlock()
_, err := s.ac.DeviceSessionRequestParams("NONCE")
_, err = s.ac.DeviceSessionRequestParams("NONCE")
st.Lock()
c.Check(err, Equals, auth.ErrNoSerial)

Expand Down Expand Up @@ -1958,6 +1958,11 @@ func (s *authContextSetupSuite) TestDeviceSessionRequestParams(c *C) {
c.Check(params.EncodedSerial(), DeepEquals, string(asserts.Encode(s.serial)))
c.Check(params.EncodedModel(), DeepEquals, string(asserts.Encode(s.model)))

st.Unlock()
serial, err := s.ac.EnsureSerial(context.TODO(), 10*time.Second)
st.Lock()
c.Check(err, IsNil)
c.Check(serial.Serial(), Equals, "")
}

func (s *authContextSetupSuite) TestProxyStoreParams(c *C) {
Expand Down
14 changes: 14 additions & 0 deletions overlord/snapstate/autorefresh.go
Expand Up @@ -24,6 +24,8 @@ import (
"os"
"time"

"golang.org/x/net/context"

"github.com/snapcore/snapd/i18n"
"github.com/snapcore/snapd/logger"
"github.com/snapcore/snapd/overlord/auth"
Expand All @@ -45,8 +47,20 @@ const maxPostponement = 60 * 24 * time.Hour
var (
CanAutoRefresh func(st *state.State) (bool, error)
CanManageRefreshes func(st *state.State) bool
EnsureRegistration func(ctx context.Context, st *state.State, opts EnsureRegistrationOptions) (proceed bool, err error)
)

type EnsureRegistrationOptions struct {
// AfterAttemptsProceedAnyway is the number (if != 0) of
// registration attempts after which is allowed to proceed anyway.
AfterAttemptsProceedAnyway int
// CustomStoreOnly means registration is required only with a
// custom store.
CustomStoreOnly bool
// Timeout if waiting for registration.
Timeout time.Duration
}

// refreshRetryDelay specified the minimum time to retry failed refreshes
var refreshRetryDelay = 10 * time.Minute

Expand Down
10 changes: 10 additions & 0 deletions store/store.go
Expand Up @@ -85,6 +85,8 @@ var defaultRetryStrategy = retry.LimitCount(5, retry.LimitTime(38*time.Second,
},
))

var ensureSerialTimeout = 30 * time.Second

// Config represents the configuration to access the snap store
type Config struct {
// Store API base URLs. The assertions url is only separate because it can
Expand Down Expand Up @@ -557,6 +559,14 @@ func (s *Store) refreshDeviceSession(ctx context.Context, device *auth.DeviceSta
return fmt.Errorf("internal error: no authContext")
}

if device.Serial == "" {
// best-effort to be registered and have a serial
_, err := s.authContext.EnsureSerial(ctx, ensureSerialTimeout)
if err != nil {
return err
}
}

nonce, err := requestStoreDeviceNonce(s.endpointURL(deviceNonceEndpPath, nil).String())
if err != nil {
return err
Expand Down
11 changes: 11 additions & 0 deletions store/store_test.go
Expand Up @@ -225,6 +225,8 @@ type testAuthContext struct {
storeID string

cloudInfo *auth.CloudInfo

ensuredSerial bool
}

func (ac *testAuthContext) Device() (*auth.DeviceState, error) {
Expand Down Expand Up @@ -257,6 +259,14 @@ func (ac *testAuthContext) StoreID(fallback string) (string, error) {
return fallback, nil
}

func (ac *testAuthContext) EnsureSerial(ctx context.Context, timeout time.Duration) (*asserts.Serial, error) {
if ctx == nil {
panic("context required")
}
ac.ensuredSerial = true
return nil, nil
}

func (ac *testAuthContext) DeviceSessionRequestParams(nonce string) (*auth.DeviceSessionRequestParams, error) {
model, err := asserts.Decode([]byte(exModel))
if err != nil {
Expand Down Expand Up @@ -1694,6 +1704,7 @@ func (s *storeTestSuite) TestDoRequestSetsAndRefreshesDeviceAuth(c *C) {
c.Check(string(responseData), Equals, "response-data")
c.Check(deviceSessionRequested, Equals, true)
c.Check(refreshSessionRequested, Equals, true)
c.Check(authContext.ensuredSerial, Equals, true)
}

func (s *storeTestSuite) TestDoRequestSetsAndRefreshesBothAuths(c *C) {
Expand Down

0 comments on commit ee1b28e

Please sign in to comment.