Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

overlord/devicestate: best effort to go to early full retries for registration on the like of DNS no host #4263

Merged
merged 3 commits into from Nov 22, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
114 changes: 114 additions & 0 deletions overlord/devicestate/devicestate_test.go
Expand Up @@ -193,6 +193,10 @@ func (s *deviceMgrSuite) mockServer(c *C) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case requestIDURLPath, "/svc/request-id":
if s.reqID == "REQID-501" {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much nicer than foo that we use randomly in other places. Thank you!

w.WriteHeader(501)
return
}
w.WriteHeader(200)
c.Check(r.Header.Get("User-Agent"), Equals, expectedUserAgent)
io.WriteString(w, fmt.Sprintf(`{"request-id": "%s"}`, s.reqID))
Expand Down Expand Up @@ -756,6 +760,116 @@ version: gadget
c.Check(device.Serial, Equals, "9999")
}

func (s *deviceMgrSuite) TestDoRequestSerialErrorsOnNoHost(c *C) {
privKey, _ := assertstest.GenerateKey(testKeyLength)

nowhere := "http://nowhere.invalid"

mockRequestIDURL := nowhere + requestIDURLPath
restore := devicestate.MockRequestIDURL(mockRequestIDURL)
defer restore()

mockSerialRequestURL := nowhere + serialURLPath
restore = devicestate.MockSerialRequestURL(mockSerialRequestURL)
defer restore()

// setup state as done by first-boot/Ensure/doGenerateDeviceKey
s.state.Lock()
defer s.state.Unlock()

s.setupGadget(c, `
name: gadget
type: gadget
version: gadget
`, "")

auth.SetDevice(s.state, &auth.DeviceState{
Brand: "canonical",
Model: "pc",
KeyID: privKey.PublicKey().ID(),
})
s.mgr.KeypairManager().Put(privKey)

t := s.state.NewTask("request-serial", "test")
chg := s.state.NewChange("become-operational", "...")
chg.AddTask(t)

// avoid full seeding
s.seeding()

s.state.Unlock()
s.mgr.Ensure()
s.mgr.Wait()
s.state.Lock()

c.Check(chg.Status(), Equals, state.ErrorStatus)
}

func (s *deviceMgrSuite) TestDoRequestSerialMaxTentatives(c *C) {
privKey, _ := assertstest.GenerateKey(testKeyLength)

// immediate
r := devicestate.MockRetryInterval(0)
defer r()

r = devicestate.MockMaxTentatives(2)
defer r()

s.reqID = "REQID-501"
mockServer := s.mockServer(c)
defer mockServer.Close()

mockRequestIDURL := mockServer.URL + requestIDURLPath
restore := devicestate.MockRequestIDURL(mockRequestIDURL)
defer restore()

mockSerialRequestURL := mockServer.URL + serialURLPath
restore = devicestate.MockSerialRequestURL(mockSerialRequestURL)
defer restore()

restore = devicestate.MockRepeatRequestSerial("after-add-serial")
defer restore()

// setup state as done by first-boot/Ensure/doGenerateDeviceKey
s.state.Lock()
defer s.state.Unlock()

s.setupGadget(c, `
name: gadget
type: gadget
version: gadget
`, "")

auth.SetDevice(s.state, &auth.DeviceState{
Brand: "canonical",
Model: "pc",
KeyID: privKey.PublicKey().ID(),
})
s.mgr.KeypairManager().Put(privKey)

t := s.state.NewTask("request-serial", "test")
chg := s.state.NewChange("become-operational", "...")
chg.AddTask(t)

// avoid full seeding
s.seeding()

s.state.Unlock()
s.mgr.Ensure()
s.mgr.Wait()
s.state.Lock()

c.Check(chg.Status(), Equals, state.DoingStatus)

s.state.Unlock()
s.mgr.Ensure()
s.mgr.Wait()
s.state.Lock()

c.Check(chg.Status(), Equals, state.ErrorStatus)
c.Check(chg.Err(), ErrorMatches, `(?s).*cannot retrieve request-id for making a request for a serial: unexpected status 501.*`)
}

func (s *deviceMgrSuite) TestFullDeviceRegistrationPollHappy(c *C) {
r1 := devicestate.MockKeyLength(testKeyLength)
defer r1()
Expand Down
8 changes: 8 additions & 0 deletions overlord/devicestate/export_test.go
Expand Up @@ -62,6 +62,14 @@ func MockRetryInterval(interval time.Duration) (restore func()) {
}
}

func MockMaxTentatives(max int) (restore func()) {
old := maxTentatives
maxTentatives = max
return func() {
maxTentatives = old
}
}

func (m *DeviceManager) KeypairManager() asserts.KeypairManager {
return m.keypairMgr
}
Expand Down
38 changes: 29 additions & 9 deletions overlord/devicestate/handlers.go
Expand Up @@ -23,6 +23,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -62,6 +63,7 @@ func deviceAPIBaseURL() string {
var (
keyLength = 4096
retryInterval = 60 * time.Second
maxTentatives = 15
deviceAPIBase = deviceAPIBaseURL()
requestIDURL = deviceAPIBase + "request-id"
serialRequestURL = deviceAPIBase + "devices"
Expand Down Expand Up @@ -113,9 +115,12 @@ type requestIDResp struct {
RequestID string `json:"request-id"`
}

func retryErr(t *state.Task, reason string, a ...interface{}) error {
func retryErr(t *state.Task, nTentatives int, reason string, a ...interface{}) error {
t.State().Lock()
defer t.State().Unlock()
if nTentatives >= maxTentatives {
return fmt.Errorf(reason, a...)
}
t.Errorf(reason, a...)
return &state.Retry{After: retryInterval}
}
Expand All @@ -125,10 +130,10 @@ type serverError struct {
Errors []*serverError `json:"error_list"`
}

func retryBadStatus(t *state.Task, reason string, resp *http.Response) error {
func retryBadStatus(t *state.Task, nTentatives int, reason string, resp *http.Response) error {
if resp.StatusCode > 500 {
// likely temporary
return retryErr(t, "%s: unexpected status %d", reason, resp.StatusCode)
return retryErr(t, nTentatives, "%s: unexpected status %d", reason, resp.StatusCode)
}
if resp.Header.Get("Content-Type") == "application/json" {
var srvErr serverError
Expand All @@ -148,6 +153,16 @@ func retryBadStatus(t *state.Task, reason string, resp *http.Response) error {
}

func prepareSerialRequest(t *state.Task, privKey asserts.PrivateKey, device *auth.DeviceState, client *http.Client, cfg *serialRequestConfig) (string, error) {
// limit tentatives starting from scratch before going to
// slower full retries
var nTentatives int
err := t.Get("pre-poll-tentatives", &nTentatives)
if err != nil && err != state.ErrNoState {
return "", err
}
nTentatives++
t.Set("pre-poll-tentatives", nTentatives)

st := t.State()
st.Unlock()
defer st.Lock()
Expand All @@ -161,18 +176,23 @@ func prepareSerialRequest(t *state.Task, privKey asserts.PrivateKey, device *aut

resp, err := client.Do(req)
if err != nil {
return "", retryErr(t, "cannot retrieve request-id for making a request for a serial: %v", err)
if netErr, ok := err.(net.Error); ok && !netErr.Temporary() {
// a non temporary net error, like a DNS no
// host, error out and do full retries
return "", fmt.Errorf("cannot retrieve request-id for making a request for a serial: %v", err)
}
return "", retryErr(t, nTentatives, "cannot retrieve request-id for making a request for a serial: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return "", retryBadStatus(t, "cannot retrieve request-id for making a request for a serial", resp)
return "", retryBadStatus(t, nTentatives, "cannot retrieve request-id for making a request for a serial", resp)
}

dec := json.NewDecoder(resp.Body)
var requestID requestIDResp
err = dec.Decode(&requestID)
if err != nil { // assume broken i/o
return "", retryErr(t, "cannot read response with request-id for making a request for a serial: %v", err)
return "", retryErr(t, nTentatives, "cannot read response with request-id for making a request for a serial: %v", err)
}

encodedPubKey, err := asserts.EncodePublicKey(privKey.PublicKey())
Expand Down Expand Up @@ -216,7 +236,7 @@ func submitSerialRequest(t *state.Task, serialRequest string, client *http.Clien

resp, err := client.Do(req)
if err != nil {
return nil, retryErr(t, "cannot deliver device serial request: %v", err)
return nil, retryErr(t, 0, "cannot deliver device serial request: %v", err)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding this little helper?

func retryFinalErr(t *state.Task, reason string, a ...interface{}) error {
    return retryErr(t, 0, reason, a...)
}

}
defer resp.Body.Close()

Expand All @@ -225,7 +245,7 @@ func submitSerialRequest(t *state.Task, serialRequest string, client *http.Clien
case 202:
return nil, errPoll
default:
return nil, retryBadStatus(t, "cannot deliver device serial request", resp)
return nil, retryBadStatus(t, 0, "cannot deliver device serial request", resp)
}

// TODO: support a stream of assertions instead of just the serial
Expand All @@ -234,7 +254,7 @@ func submitSerialRequest(t *state.Task, serialRequest string, client *http.Clien
dec := asserts.NewDecoder(resp.Body)
got, err := dec.Decode()
if err != nil { // assume broken i/o
return nil, retryErr(t, "cannot read response to request for a serial: %v", err)
return nil, retryErr(t, 0, "cannot read response to request for a serial: %v", err)
}

serial, ok := got.(*asserts.Serial)
Expand Down
4 changes: 2 additions & 2 deletions store/store_test.go
Expand Up @@ -2308,7 +2308,7 @@ func (t *remoteRepoTestSuite) TestUbuntuStoreRepositoryProxyStoreFromAuthContext
defer mockServer.Close()

mockServerURL, _ := url.Parse(mockServer.URL)
nowhereURL, err := url.Parse("http://nowhere.nowhere")
nowhereURL, err := url.Parse("http://nowhere.invalid")
c.Assert(err, IsNil)
cfg := DefaultConfig()
cfg.StoreBaseURL = nowhereURL
Expand Down Expand Up @@ -4268,7 +4268,7 @@ func (t *remoteRepoTestSuite) TestUbuntuStoreRepositoryAssertionProxyStoreFromAu
defer mockServer.Close()

mockServerURL, _ := url.Parse(mockServer.URL)
nowhereURL, err := url.Parse("http://nowhere.nowhere")
nowhereURL, err := url.Parse("http://nowhere.invalid")
c.Assert(err, IsNil)
cfg := Config{
AssertionsBaseURL: nowhereURL,
Expand Down