Skip to content

Commit

Permalink
fix: recursive loop on network errors in password validator (#589)
Browse files Browse the repository at this point in the history
The old code no error when ignoreNetworkErrors was set to true, but did not set a hash result which caused an infinite loop.

Closes #316 

Co-authored-by: aeneasr <3372410+aeneasr@users.noreply.github.com>
  • Loading branch information
jakhog and aeneasr committed Jul 27, 2020
1 parent 5eb14ed commit b4d5a42
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 9 deletions.
17 changes: 8 additions & 9 deletions selfservice/strategy/password/validator.go
Expand Up @@ -34,6 +34,8 @@ type ValidationProvider interface {
}

var _ Validator = new(DefaultPasswordValidator)
var ErrNetworkFailure = errors.New("unable to check if password has been leaked because an unexpected network error occurred")
var ErrUnexpectedStatusCode = errors.New("unexpected status code")

// DefaultPasswordValidator implements Validator. It is based on best
// practices as defined in the following blog posts:
Expand Down Expand Up @@ -104,18 +106,12 @@ func (s *DefaultPasswordValidator) fetch(hpw []byte) error {
loc := fmt.Sprintf("https://api.pwnedpasswords.com/range/%s", prefix)
res, err := s.c.Get(loc)
if err != nil {
if s.ignoreNetworkErrors {
return nil
}
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to check if password has been breached before: %s", err))
return errors.Wrapf(ErrNetworkFailure, "%s", err)
}
defer res.Body.Close()

if res.StatusCode != http.StatusOK {
if s.ignoreNetworkErrors {
return nil
}
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to check if password has been breached before, expected status code 200 but got %d", res.StatusCode))
return errors.Wrapf(ErrUnexpectedStatusCode, "%d", res.StatusCode)
}

s.Lock()
Expand Down Expand Up @@ -172,7 +168,10 @@ func (s *DefaultPasswordValidator) Validate(identifier, password string) error {
s.RUnlock()

if !ok {
if err := s.fetch(hpw); err != nil {
err := s.fetch(hpw)
if (errors.Is(err, ErrNetworkFailure) || errors.Is(err, ErrUnexpectedStatusCode)) && s.ignoreNetworkErrors {
return nil
} else if err != nil {
return err
}

Expand Down
156 changes: 156 additions & 0 deletions selfservice/strategy/password/validator_test.go
@@ -1,7 +1,11 @@
package password

import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"net/http"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -75,4 +79,156 @@ func TestDefaultPasswordValidationStrategy(t *testing.T) {
}
})
}

fakeClient := NewFakeHTTPClient()
s.c = &fakeClient.Client

t.Run("case=should send request to pwnedpasswords.com", func(t *testing.T) {
s.ignoreNetworkErrors = false
s.Validate("mohutdesub", "damrumukuh")
require.Contains(t, fakeClient.RequestedURLs(), "https://api.pwnedpasswords.com/range/BCBA9")
})

t.Run("case=should fail if request fails and ignoreNetworkErrors is not set", func(t *testing.T) {
s.ignoreNetworkErrors = false
fakeClient.RespondWithError("Network request failed")
require.Error(t, s.Validate("", "sumdarmetp"))
})

t.Run("case=should not fail if request fails and ignoreNetworkErrors is set", func(t *testing.T) {
s.ignoreNetworkErrors = true
fakeClient.RespondWithError("Network request failed")
require.NoError(t, s.Validate("", "pepegtawni"))
})

t.Run("case=should fail if response has non 200 code and ignoreNetworkErrors is not set", func(t *testing.T) {
s.ignoreNetworkErrors = false
fakeClient.RespondWith(http.StatusForbidden, "")
require.Error(t, s.Validate("", "jolhakowef"))
})

t.Run("case=should not fail if response has non 200 code code and ignoreNetworkErrors is set", func(t *testing.T) {
s.ignoreNetworkErrors = true
fakeClient.RespondWith(http.StatusInternalServerError, "")
require.NoError(t, s.Validate("", "jenuzuhjoj"))
})

for _, tc := range []struct {
cs string
pw string
res string
pass bool
}{
{
cs: "contains invalid data",
pw: "lufsokpugo",
res: "0225BDB8F106B1B4A5DF4C31B80AC695874:2\ninvalid",
pass: false,
},
{
cs: "contains invalid hash count",
pw: "gimekvizec",
res: "0248B3D6077106761CC84F4B9CF680C6D84:text\n1A34C526A9D14832C6ACFEAE90261ED78F8:2",
pass: false,
},
{
cs: "is missing hash count",
pw: "bofulosasm",
res: "1D29CF237A57F6FEA8F29E8D907DCF1EBBA\n026364A8EE59DEDCF9E2DC80B9D7BAB7389:2",
pass: false,
},
{
cs: "response contains no matches",
pw: "lizrafakha",
res: "0D6CF6289C9CA71B47D2167EB7FE89690E7:57",
pass: true,
},
{
cs: "contains less than maxBreachesThreshold",
pw: "tafpabdopa",
res: fmt.Sprintf("280915F3B572F94217D86F1D63BED53F66A:%d\n0F76A7D21E7C3E653E98236897AD7888937:%d", s.maxBreachesThreshold, s.maxBreachesThreshold+1),
pass: true,
},
{
cs: "contains more than maxBreachesThreshold",
pw: "hicudsumla",
res: fmt.Sprintf("5656812AA72561AAA6663E486A46D5711BE:%d", s.maxBreachesThreshold+1),
pass: false,
},
} {
fakeClient.RespondWith(http.StatusOK, tc.res)
format := "case=shuold not fail if response %s"
if !tc.pass {
format = "case=shuold fail if response %s"
}
t.Run(fmt.Sprintf(format, tc.cs), func(t *testing.T) {
err := s.Validate("", tc.pw)
if tc.pass {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}

type fakeHttpClient struct {
http.Client

requestedURLs []string
responder func(*http.Request) (*http.Response, error)
}

func NewFakeHTTPClient() *fakeHttpClient {
client := fakeHttpClient{
responder: func(*http.Request) (*http.Response, error) {
return nil, errors.New("No responder defined in fake HTTP client")
},
}
client.Client = http.Client{
Transport: &fakeRoundTripper{&client},
}
return &client
}

func (c *fakeHttpClient) RespondWith(status int, body string) {
c.responder = func(request *http.Request) (*http.Response, error) {
buffer := bytes.NewBufferString(body)
return &http.Response{
StatusCode: status,
Body: ioutil.NopCloser(buffer),
ContentLength: int64(buffer.Len()),
Request: request,
}, nil
}
}

func (c *fakeHttpClient) RespondWithError(err string) {
c.responder = func(*http.Request) (*http.Response, error) {
return nil, errors.New(err)
}
}

func (c *fakeHttpClient) Reset() {
c.requestedURLs = nil
}

func (c *fakeHttpClient) RequestedURLs() []string {
return c.requestedURLs
}

func (c *fakeHttpClient) handle(request *http.Request) (*http.Response, error) {
c.requestedURLs = append(c.requestedURLs, request.URL.String())
if request.Body != nil {
request.Body.Close()
}
return c.responder(request)
}

type fakeRoundTripper struct {
client *fakeHttpClient
}

func (rt *fakeRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
return rt.client.handle(request)
}

0 comments on commit b4d5a42

Please sign in to comment.