Skip to content

Commit

Permalink
make follow redirects configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Mancke committed Jun 24, 2016
1 parent 1abcacc commit 79e9fda
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 27 deletions.
7 changes: 7 additions & 0 deletions composition/composition_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,15 @@ func (agg *CompositionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
for _, res := range results {
if res.Err == nil && res.Content != nil {

if res.Content.HttpStatusCode() == 301 || res.Content.HttpStatusCode() == 302 || res.Content.HttpStatusCode() == 303 {
copyHeaders(res.Content.HttpHeader(), w.Header(), ForwardResponseHeaders)
w.WriteHeader(res.Content.HttpStatusCode())
return
}

if res.Content.Reader() != nil {
copyHeaders(res.Content.HttpHeader(), w.Header(), ForwardResponseHeaders)
w.WriteHeader(res.Content.HttpStatusCode())
io.Copy(w, res.Content.Reader())
res.Content.Reader().Close()
return
Expand Down
10 changes: 5 additions & 5 deletions composition/composition_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,9 @@ func Test_CompositionHandler_CorrectHeaderAndStatusCodeReturned(t *testing.T) {
ch.ServeHTTP(resp, r)

a.Equal(302, resp.Code)
a.Equal(4, len(resp.Header()))
a.Equal(2, len(resp.Header()))
a.Equal("/look/somewhere", resp.Header().Get("Location"))
a.Equal("", resp.Header().Get("Transfer-Encoding"))
a.NotEqual("", resp.Header().Get("Content-Length"))
a.Contains(resp.Header()["Set-Cookie"], "cookie-content 1")
a.Contains(resp.Header()["Set-Cookie"], "cookie-content 2")
}
Expand All @@ -108,8 +107,9 @@ func Test_CompositionHandler_ReturnStream(t *testing.T) {
}

contentWithReader := &MemoryContent{
reader: ioutil.NopCloser(strings.NewReader("bar")),
httpHeader: http.Header{"Content-Type": {"text/plain"}},
reader: ioutil.NopCloser(strings.NewReader("bar")),
httpHeader: http.Header{"Content-Type": {"text/plain"}},
httpStatusCode: 201,
}

contentFetcherFactory := func(r *http.Request) FetchResultSupplier {
Expand All @@ -132,7 +132,7 @@ func Test_CompositionHandler_ReturnStream(t *testing.T) {

a.Equal("bar", string(resp.Body.Bytes()))
a.Equal("text/plain", resp.Header().Get("Content-Type"))
a.Equal(200, resp.Code)
a.Equal(201, resp.Code)
}

func Test_CompositionHandler_ErrorInMerging(t *testing.T) {
Expand Down
28 changes: 15 additions & 13 deletions composition/fetch_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ const (

// FetchDefinition is a descriptor for fetching Content from an endpoint.
type FetchDefinition struct {
URL string
Timeout time.Duration
Required bool
Header http.Header
Method string
Body io.Reader
RespProc ResponseProcessor
ErrHandler ErrorHandler
URL string
Timeout time.Duration
Required bool
FollowRedirects bool
Header http.Header
Method string
Body io.Reader
RespProc ResponseProcessor
ErrHandler ErrorHandler
//ServeResponseHeaders bool
//IsPrimary bool
//FallbackURL string
Expand All @@ -77,11 +78,12 @@ func NewFetchDefinitionWithErrorHandler(url string, errHandler ErrorHandler) *Fe
errHandler = NewDefaultErrorHandler()
}
return &FetchDefinition{
URL: url,
Timeout: DefaultTimeout,
Required: true,
Method: "GET",
ErrHandler: errHandler,
URL: url,
Timeout: DefaultTimeout,
FollowRedirects: false,
Required: true,
Method: "GET",
ErrHandler: errHandler,
}
}

Expand Down
31 changes: 22 additions & 9 deletions composition/http_content_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@ import (
"strings"
"time"

"errors"
"github.com/tarent/lib-compose/logging"
"net/url"
)

var redirectAttemptedError = errors.New("do not follow redirects")
var noRedirectFunc = func(req *http.Request, via []*http.Request) error {
return redirectAttemptedError
}

type HttpContentLoader struct {
parser map[string]ContentParser
}
Expand All @@ -30,7 +37,11 @@ func (loader *HttpContentLoader) Load(fd *FetchDefinition) (Content, error) {
c.url = fd.URL
c.httpStatusCode = 502

var err error
// redirects can only be stopped by returning an error in the CheckRedirect function
if !fd.FollowRedirects {
client.CheckRedirect = noRedirectFunc
}

request, err := http.NewRequest(fd.Method, fd.URL, fd.Body)
if err != nil {
return c, err
Expand All @@ -44,29 +55,31 @@ func (loader *HttpContentLoader) Load(fd *FetchDefinition) (Content, error) {
start := time.Now()

resp, err := client.Do(request)

logging.Call(request, resp, start, err)
if resp != nil {
c.httpStatusCode = resp.StatusCode
c.httpHeader = resp.Header
}

// do not handle our own redirects returns as errors
if urlError, ok := err.(*url.Error); ok && urlError.Err == redirectAttemptedError {
return c, nil
}

logging.Call(request, resp, start, err)
if err != nil {
return c, err
}

if fd.RespProc != nil {
err = fd.RespProc.Process(resp, fd.URL)
}
if err != nil {
return c, err
if err := fd.RespProc.Process(resp, fd.URL); err != nil {
return c, err
}
}

if c.httpStatusCode < 200 || c.httpStatusCode > 399 {
return c, fmt.Errorf("(http %v) on loading url %q", c.httpStatusCode, fd.URL)
}

c.httpHeader = resp.Header

// take the first parser for the content type
// direct access to the map does not work, because the
// content type may have encoding information at the end
Expand Down
53 changes: 53 additions & 0 deletions composition/http_content_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,59 @@ func Test_HttpContentLoader_LoadErrorNetwork(t *testing.T) {
a.Contains(err.Error(), "unsupported protocol scheme")
}

func Test_HttpContentLoader_FollowRedirects(t *testing.T) {
a := assert.New(t)

for _, status := range []int{301, 302} {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/redirected" {
http.Redirect(w, r, "/redirected", status)
return
}
w.Write([]byte("ok"))
}))

loader := &HttpContentLoader{}
fd := NewFetchDefinition(server.URL)
fd.FollowRedirects = true
c, err := loader.Load(fd)
a.NoError(err)
a.Equal(200, c.HttpStatusCode())

a.NotNil(c.Reader())
body, err := ioutil.ReadAll(c.Reader())
a.NoError(err)
a.Equal("ok", string(body))

server.Close()
}
}

func Test_HttpContentLoader_DoNotFollowRedirects(t *testing.T) {
a := assert.New(t)

for _, status := range []int{301, 302} {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/redirected" {
http.Redirect(w, r, "/redirected", status)
return
}
w.Write([]byte("ok"))
}))

loader := &HttpContentLoader{}
fd := NewFetchDefinition(server.URL)
fd.FollowRedirects = false
c, err := loader.Load(fd)
a.NoError(err)

a.Equal(status, c.HttpStatusCode())
a.Equal("/redirected", c.HttpHeader().Get("Location"))

server.Close()
}
}

func testServer(content string, timeout time.Duration) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
Expand Down

0 comments on commit 79e9fda

Please sign in to comment.