Skip to content

Commit

Permalink
store: reorg auth refresh
Browse files Browse the repository at this point in the history
This factors out auth refresh to a helper and addresses long-standing todo to limit tries there (now without recursion involved).
  • Loading branch information
pedronis committed Feb 23, 2018
2 parents 17ab254 + aa27c91 commit 712f7ef
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 42 deletions.
1 change: 1 addition & 0 deletions store/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ func requestDeviceSession(deviceSessionEndpoint string, paramsEncoder deviceSess
if err != nil {
return "", fmt.Errorf(errorPrefix+"%v", err)
}
// TODO: retry at least once on 400

if responseData.Macaroon == "" {
return "", fmt.Errorf(errorPrefix + "empty session returned")
Expand Down
111 changes: 69 additions & 42 deletions store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,57 +828,84 @@ func (s *Store) retryRequestDecodeJSON(ctx context.Context, reqOptions *requestO

// doRequest does an authenticated request to the store handling a potential macaroon refresh required if needed
func (s *Store) doRequest(ctx context.Context, client *http.Client, reqOptions *requestOptions, user *auth.UserState) (*http.Response, error) {
req, err := s.newRequest(reqOptions, user)
if err != nil {
return nil, err
}

var resp *http.Response
if ctx != nil {
resp, err = ctxhttp.Do(ctx, client, req)
} else {
resp, err = client.Do(req)
}
if err != nil {
return nil, err
}
authRefreshes := 0
for {
req, err := s.newRequest(reqOptions, user)
if err != nil {
return nil, err
}

wwwAuth := resp.Header.Get("WWW-Authenticate")
if resp.StatusCode == 401 {
refreshed := false
if user != nil && strings.Contains(wwwAuth, "needs_refresh=1") {
// refresh user
err = s.refreshUser(user)
if err != nil {
return nil, err
}
refreshed = true
var resp *http.Response
if ctx != nil {
resp, err = ctxhttp.Do(ctx, client, req)
} else {
resp, err = client.Do(req)
}
if err != nil {
return nil, err
}
if strings.Contains(wwwAuth, "refresh_device_session=1") {
// refresh device session
if s.authContext == nil {
return nil, fmt.Errorf("internal error: no authContext")

wwwAuth := resp.Header.Get("WWW-Authenticate")
if resp.StatusCode == 401 && authRefreshes < 4 {
// 4 tries: 2 tries for each in case both user
// and device need refreshing
var refreshNeed authRefreshNeed
refresh := false
if user != nil && strings.Contains(wwwAuth, "needs_refresh=1") {
// refresh user
refreshNeed.user = true
refresh = true
}
device, err := s.authContext.Device()
if err != nil {
return nil, err
if strings.Contains(wwwAuth, "refresh_device_session=1") {
// refresh device session
refreshNeed.device = true
refresh = true
}

err = s.refreshDeviceSession(device)
if err != nil {
return nil, err
if refresh {
err := s.refreshAuth(user, refreshNeed)
if err != nil {
return nil, err
}
// close previous response and retry
resp.Body.Close()
authRefreshes++
continue
}
refreshed = true
}
if refreshed {
// close previous response and retry
// TODO: make this non-recursive or add a recursion limit
resp.Body.Close()
return s.doRequest(ctx, client, reqOptions, user)

return resp, err
}
}

type authRefreshNeed struct {
device bool
user bool
}

func (s *Store) refreshAuth(user *auth.UserState, need authRefreshNeed) error {
if need.user {
// refresh user
err := s.refreshUser(user)
if err != nil {
return err
}
}
if need.device {
// refresh device session
if s.authContext == nil {
return fmt.Errorf("internal error: no authContext")
}
device, err := s.authContext.Device()
if err != nil {
return err
}

return resp, err
err = s.refreshDeviceSession(device)
if err != nil {
return err
}
}
return nil
}

// build a new http.Request with headers for the store
Expand Down
93 changes: 93 additions & 0 deletions store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1635,6 +1635,99 @@ func (t *remoteRepoTestSuite) TestDoRequestSetsAndRefreshesDeviceAuth(c *C) {
c.Check(refreshSessionRequested, Equals, true)
}

func (t *remoteRepoTestSuite) TestDoRequestSetsAndRefreshesBothAuths(c *C) {
refresh, err := makeTestRefreshDischargeResponse()
c.Assert(err, IsNil)
c.Check(t.user.StoreDischarges[0], Not(Equals), refresh)

// mock refresh response
refreshDischargeEndpointHit := false
mockSSOServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, fmt.Sprintf(`{"discharge_macaroon": "%s"}`, refresh))
refreshDischargeEndpointHit = true
}))
defer mockSSOServer.Close()
UbuntuoneRefreshDischargeAPI = mockSSOServer.URL + "/tokens/refresh"

refreshSessionRequested := false
expiredAuth := `Macaroon root="expired-session-macaroon"`
// mock store response
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c.Check(r.UserAgent(), Equals, userAgent)

switch r.URL.Path {
case "/":
authorization := r.Header.Get("Authorization")
c.Check(authorization, Equals, t.expectedAuthorization(c, t.user))
if t.user.StoreDischarges[0] != refresh {
w.Header().Set("WWW-Authenticate", "Macaroon needs_refresh=1")
w.WriteHeader(401)
return
}

devAuthorization := r.Header.Get("X-Device-Authorization")
if devAuthorization == "" {
c.Fatalf("device authentication missing")
} else if devAuthorization == expiredAuth {
w.Header().Set("WWW-Authenticate", "Macaroon refresh_device_session=1")
w.WriteHeader(401)
} else {
c.Check(devAuthorization, Equals, `Macaroon root="refreshed-session-macaroon"`)
io.WriteString(w, "response-data")
}
case authNoncesPath:
io.WriteString(w, `{"nonce": "1234567890:9876543210"}`)
case authSessionPath:
// sanity of request
jsonReq, err := ioutil.ReadAll(r.Body)
c.Assert(err, IsNil)
var req map[string]string
err = json.Unmarshal(jsonReq, &req)
c.Assert(err, IsNil)
c.Check(strings.HasPrefix(req["device-session-request"], "type: device-session-request\n"), Equals, true)
c.Check(strings.HasPrefix(req["serial-assertion"], "type: serial\n"), Equals, true)
c.Check(strings.HasPrefix(req["model-assertion"], "type: model\n"), Equals, true)

authorization := r.Header.Get("X-Device-Authorization")
if authorization == "" {
c.Fatalf("expecting only refresh")
} else {
c.Check(authorization, Equals, expiredAuth)
io.WriteString(w, `{"macaroon": "refreshed-session-macaroon"}`)
refreshSessionRequested = true
}
default:
c.Fatalf("unexpected path %q", r.URL.Path)
}
}))
c.Assert(mockServer, NotNil)
defer mockServer.Close()

mockServerURL, _ := url.Parse(mockServer.URL)

// make sure device session is expired
t.device.SessionMacaroon = "expired-session-macaroon"
authContext := &testAuthContext{c: c, device: t.device, user: t.user}
repo := New(&Config{
StoreBaseURL: mockServerURL,
}, authContext)
c.Assert(repo, NotNil)

reqOptions := &requestOptions{Method: "GET", URL: mockServerURL}

resp, err := repo.doRequest(context.TODO(), repo.client, reqOptions, t.user)
c.Assert(err, IsNil)
defer resp.Body.Close()

c.Check(resp.StatusCode, Equals, 200)

responseData, err := ioutil.ReadAll(resp.Body)
c.Assert(err, IsNil)
c.Check(string(responseData), Equals, "response-data")
c.Check(refreshDischargeEndpointHit, Equals, true)
c.Check(refreshSessionRequested, Equals, true)
}

func (t *remoteRepoTestSuite) TestDoRequestSetsExtraHeaders(c *C) {
// Custom headers are applied last.
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 712f7ef

Please sign in to comment.