diff --git a/htpasswd.go b/htpasswd.go index 1e26a04..4ff514f 100644 --- a/htpasswd.go +++ b/htpasswd.go @@ -14,6 +14,7 @@ package htpasswd import ( "bufio" "fmt" + "io" "os" "strings" "sync" @@ -83,6 +84,21 @@ func New(filename string, parsers []PasswdParser, bad BadLineHandler) (*Htpasswd return &bf, nil } +// NewFromReader is like new but reads from r instead of a named file. Calling +// Reload on the returned HtpasswdFile will result in an error; use +// ReloadFromReader instead. +func NewFromReader(r io.Reader, parsers []PasswdParser, bad BadLineHandler) (*HtpasswdFile, error) { + bf := HtpasswdFile{ + parsers: parsers, + } + + if err := bf.ReloadFromReader(r, bad); err != nil { + return nil, err + } + + return &bf, nil +} + // Match checks the username and password combination to see if it represents // a valid account from the htpassword file. func (bf *HtpasswdFile) Match(username, password string) bool { @@ -111,11 +127,18 @@ func (bf *HtpasswdFile) Reload(bad BadLineHandler) error { } defer f.Close() + return bf.ReloadFromReader(f, bad) +} + +// ReloadFromReader is like Reload but reads credentials from r instead of a named +// file. If HtpasswdFile was created by New, it is okay to call Reload and +// ReloadFromReader as desired. +func (bf *HtpasswdFile) ReloadFromReader(r io.Reader, bad BadLineHandler) error { // ... and a new map ... newPasswdMap := passwdTable{} // ... for each line ... - scanner := bufio.NewScanner(f) + scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() diff --git a/htpasswd_test.go b/htpasswd_test.go index 9564206..9a9ef0e 100644 --- a/htpasswd_test.go +++ b/htpasswd_test.go @@ -3,6 +3,7 @@ package htpasswd import ( "io/ioutil" "os" + "strings" "testing" ) @@ -1420,6 +1421,26 @@ user198:$1$D89ubl/e$SdHoMvPduS1kS3KVqEw9W. user199:$1$D89ubl/e$1FQtoOElFQQCBL53IT2LL0 user200:$1$D89ubl/e$xO3.z/20nsNEXnaWJdsfB/` +func testSystemReader(t *testing.T, name string, contents string) { + r := strings.NewReader(contents) + + htp, err := NewFromReader(r, DefaultSystems, nil) + if err != nil { + t.Fatalf("Failed to read htpasswd reader") + } + + for _, u := range testUsers { + if good := htp.Match(u.username, u.password); !good { + t.Errorf("%s user %s, password %s failed to authenticate: %t", name, u.username, u.password, good) + } + + notPass := u.password + "not" + if bad := htp.Match(u.username, notPass); bad { + t.Errorf("%s user %s, password %s erroneously authenticated: %t", name, u.username, notPass, bad) + } + } +} + func testSystem(t *testing.T, name string, contents string) { f, err := ioutil.TempFile("", "??") if err != nil { @@ -1450,27 +1471,22 @@ func testSystem(t *testing.T, name string, contents string) { t.Errorf("%s user %s, password %s erroneously authenticated: %t", name, u.username, notPass, bad) } } - } -func Test_PlainFile(t *testing.T) { - testSystem(t, "plain", textPlain) -} -func Test_ShaFile(t *testing.T) { - testSystem(t, "sha", textSha) -} -func Test_Apr1File(t *testing.T) { - testSystem(t, "md5", textApr1) -} +func Test_PlainReader(t *testing.T) { testSystemReader(t, "plain", textPlain) } +func Test_PlainFile(t *testing.T) { testSystem(t, "plain", textPlain) } -func Test_Md5File(t *testing.T) { - testSystem(t, "md5", textMd5Crypt) -} +func Test_ShaReader(t *testing.T) { testSystemReader(t, "sha", textSha) } +func Test_ShaFile(t *testing.T) { testSystem(t, "sha", textSha) } -func Test_BcryptFile(t *testing.T) { - testSystem(t, "bcrypt", textBcrypt) -} +func Test_Apr1Reader(t *testing.T) { testSystemReader(t, "md5", textApr1) } +func Test_Apr1File(t *testing.T) { testSystem(t, "md5", textApr1) } -func Test_SshaFile(t *testing.T) { - testSystem(t, "ssha", textSsha) -} +func Test_Md5Reader(t *testing.T) { testSystemReader(t, "md5", textMd5Crypt) } +func Test_Md5File(t *testing.T) { testSystem(t, "md5", textMd5Crypt) } + +func Test_BcryptReader(t *testing.T) { testSystemReader(t, "bcrypt", textBcrypt) } +func Test_BcryptFile(t *testing.T) { testSystem(t, "bcrypt", textBcrypt) } + +func Test_SshaReader(t *testing.T) { testSystemReader(t, "ssha", textSsha) } +func Test_SshaFile(t *testing.T) { testSystem(t, "ssha", textSsha) }