Skip to content

Commit

Permalink
Code improvements
Browse files Browse the repository at this point in the history
- Added Authorization to HeaderBuilder.
- New constructor for Chain type.
- Changed GetReader/SetWriter methods to Read/Write (Header).
- Improved SessionStore testing to use SalterFast to more reliable timing.
- New testing for BasicAuthenticator and Chain types.
  • Loading branch information
skarllot committed Jan 26, 2016
1 parent 4342757 commit 803f941
Show file tree
Hide file tree
Showing 10 changed files with 205 additions and 16 deletions.
20 changes: 19 additions & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

package web

import "net/http"
import (
"encoding/base64"
"net/http"
)

// A Authenticable defines rules for a type that offers HTTP authentication.
type Authenticable interface {
Expand All @@ -35,3 +38,18 @@ func (HeaderBuilder) WwwAuthenticate() *Header {
"", // type and params
}
}

// Authorization creates a HTTP header to request authentication.
func (HeaderBuilder) Authorization(user, secret string) *Header {
var value string

if len(user)+len(secret) > 0 {
credentials := []byte(user + ":" + secret)
value = basicPrefix + base64.StdEncoding.EncodeToString(credentials)
}

return &Header{
authHeaderName,
value, // credentials
}
}
9 changes: 5 additions & 4 deletions authbasic.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ import (
)

const (
basicPrefix = "Basic "
basicRealm = basicPrefix + "realm=\"Restricted\""
authHeaderName = "Authorization"
basicPrefix = "Basic "
basicRealm = basicPrefix + "realm=\"Restricted\""
)

// A BasicAuthenticator represents a handler for HTTP basic authentication.
Expand All @@ -40,7 +41,7 @@ func (auth BasicAuthenticator) AuthHandler(next http.Handler) http.Handler {
}

f := func(w http.ResponseWriter, r *http.Request) {
user, secret := parseAuthHeader(r.Header.Get("Authorization"))
user, secret := parseAuthHeader(r.Header.Get(authHeaderName))
if len(user) > 0 &&
len(secret) > 0 &&
auth.TryAuthentication(r, user, secret) {
Expand All @@ -51,7 +52,7 @@ func (auth BasicAuthenticator) AuthHandler(next http.Handler) http.Handler {
NewHeader().
WwwAuthenticate().
SetValue(basicRealm).
SetWriter(w.Header())
Write(w.Header())
http.Error(w, http.StatusText(http.StatusUnauthorized),
http.StatusUnauthorized)
}
Expand Down
98 changes: 98 additions & 0 deletions authbasic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2016 Fabrício Godoy
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package web

import (
"net/http"
"net/http/httptest"
"testing"
)

type FooAuthenticator int

func (a *FooAuthenticator) TryAuthentication(
r *http.Request,
user, secret string,
) bool {
if user == "user" && secret == "secret" {
*a = FooAuthenticator(int(*a) + 1)
return true
}
return false
}

func (a *FooAuthenticator) EndPoint(w http.ResponseWriter, r *http.Request) {
*a = FooAuthenticator(int(*a) * 10)
}

func TestBasicAuthenticator(t *testing.T) {
testValues := []struct {
user string
secret string
}{
{"user", "user"},
{"user", "secret"},
{"user", "123"},
{"user", ""},
{"secret", "user"},
{"secret", "secret"},
{"secret", "123"},
{"secret", ""},
{"123", "user"},
{"123", "secret"},
{"123", "123"},
{"123", ""},
{"", "user"},
{"", "secret"},
{"", "123"},
{"", ""},
}

for idx, testVal := range testValues {
foo := FooAuthenticator(1)
basicauth := BasicAuthenticator{&foo}

chain := NewChain()
chain = append(chain, basicauth.AuthHandler)
server := chain.Get(http.HandlerFunc(foo.EndPoint))

w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://localhost", nil)
NewHeader().
Authorization(testVal.user, testVal.secret).
Write(req.Header)

server.ServeHTTP(w, req)

switch int(foo) {
case 1:
if idx == 1 {
t.Error("Failed authentication: neither the middleware and endpoint was called")
}
case 2:
t.Error("Failed authentication: endpoint was not called")
case 10:
t.Error("Failed authentication: enpoint was called without authentication")
case 20:
if idx != 1 {
t.Errorf("Should not authenticate")
}
default:
t.Errorf("Failed authentication: unexpected value %d", int(foo))
}
}
}
4 changes: 2 additions & 2 deletions bench_result.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
BenchmarkSessionCreation-4 2000 781448 ns/op 12961 B/op 201 allocs/op
BenchmarkSessionCreationFast-4 100000 14922 ns/op 1077 B/op 15 allocs/op
BenchmarkSessionCreation-4 2000 774228 ns/op 10982 B/op 170 allocs/op
BenchmarkSessionCreationFast-4 100000 14867 ns/op 1009 B/op 14 allocs/op
5 changes: 5 additions & 0 deletions chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ type MiddlewareFunc func(http.Handler) http.Handler
// HTTP handler.
type Chain []MiddlewareFunc

// NewChain creates a new empty slice of MiddlewareFunc.
func NewChain() Chain {
return make(Chain, 0)
}

// Get returns a HTTP handler which is a chain of middlewares and then the
// specified handler.
func (s Chain) Get(handler http.Handler) http.Handler {
Expand Down
67 changes: 67 additions & 0 deletions chain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2016 Fabrício Godoy
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package web

import (
"net/http"
"testing"
)

var stacker []int

func TestChainOrder(t *testing.T) {
stacker = make([]int, 0)

chain := NewChain()
chain = append(chain, FooHandler(1).Middleware)
chain = append(chain, FooHandler(2).Middleware)
chain = append(chain, FooHandler(3).Middleware)
chain = append(chain, FooHandler(4).Middleware)
chain = append(chain, FooHandler(5).Middleware)
handler := chain.Get(http.HandlerFunc(FooHandler(6).EndPoint))
handler.ServeHTTP(nil, nil)

if len(stacker) != 6 {
t.Errorf("Not all handlers was called: %d instead of %d",
len(stacker), 6)
}

counter := 1
for _, v := range stacker {
if v != counter {
t.Errorf("Chain not called in order: got %d instead of %d",
v, counter)
}

counter++
}
}

type FooHandler int

func (h FooHandler) Middleware(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
stacker = append(stacker, int(h))
next.ServeHTTP(w, r)
}

return http.HandlerFunc(fn)
}

func (h FooHandler) EndPoint(w http.ResponseWriter, r *http.Request) {
stacker = append(stacker, int(h))
}
10 changes: 5 additions & 5 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ func (s Header) Clone() *Header {
return &s
}

// GetReader gets HTTP header value, as defined by current instance, from
// Request Header and sets to current instance.
func (s *Header) GetReader(h http.Header) *Header {
// Read gets HTTP header value, as defined by current instance, from Request
// Header and sets to current instance.
func (s *Header) Read(h http.Header) *Header {
s.Value = h.Get(s.Name)
return s
}
Expand All @@ -52,9 +52,9 @@ func (s *Header) SetValue(value string) *Header {
return s
}

// SetWriter sets HTTP header, as defined by current instance, to ResponseWriter
// Write sets HTTP header, as defined by current instance, to ResponseWriter
// Header.
func (s *Header) SetWriter(h http.Header) *Header {
func (s *Header) Write(h http.Header) *Header {
h.Set(s.Name, s.Value)
return s
}
2 changes: 1 addition & 1 deletion json.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const (
// JSONWrite sets response content type to JSON, sets HTTP status and serializes
// defined content to JSON format.
func JSONWrite(w http.ResponseWriter, status int, content interface{}) {
NewHeader().ContentType().JSON().SetWriter(w.Header())
NewHeader().ContentType().JSON().Write(w.Header())
w.WriteHeader(status)
if content != nil {
json.NewEncoder(w).Encode(content)
Expand Down
4 changes: 2 additions & 2 deletions sessionstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const TokenSalt = "CvoTVwDw685Ve0qjGn//zmHGKvoCcslYNQT4AQ9FygSk9t6NuzBHuohyO" +
func TestSessionLifetime(t *testing.T) {
store := memstore.New(time.Millisecond*10, false)
ts := NewSessionStore().
SalterSecure([]byte(TokenSalt)).
SalterFast([]byte(TokenSalt)).
Store(store).
Build()

Expand Down Expand Up @@ -87,7 +87,7 @@ func TestSessionHandling(t *testing.T) {

store := memstore.New(time.Millisecond*100, false)
ts := NewSessionStore().
SalterSecure([]byte(TokenSalt)).
SalterFast([]byte(TokenSalt)).
Store(store).
Build()
if _, err := ts.Count(); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion sessionstorebuilder.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (C) 2015 Fabrício Godoy <skarllot@gmail.com>
* Copyright (C) 2016 Fabrício Godoy <skarllot@gmail.com>
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
Expand Down

0 comments on commit 803f941

Please sign in to comment.