Skip to content

Commit

Permalink
HostClient can't switch between protocols (#800)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikdubbelboer committed May 18, 2020
2 parents 5bd1b0c + dacd035 commit d22782d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 40 deletions.
35 changes: 8 additions & 27 deletions client.go
Expand Up @@ -881,6 +881,9 @@ var (
// ErrTooManyRedirects is returned by clients when the number of redirects followed
// exceed the max count.
ErrTooManyRedirects = errors.New("too many redirects detected when doing the request")

// HostClients are only able to follow redirects to the same protocol.
ErrHostClientRedirectToDifferentScheme = errors.New("HostClient can't follow redirects to a different protocol, please use Client instead")
)

const defaultMaxRedirectsCount = 16
Expand All @@ -903,27 +906,11 @@ func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clie
}

func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer) (statusCode int, body []byte, err error) {
scheme := req.uri.Scheme()
req.schemaUpdate = false
redirectsCount := 0

for {
// In case redirect to different scheme
if redirectsCount > 0 && !bytes.Equal(scheme, req.uri.Scheme()) {
if strings.HasPrefix(url, string(strHTTPS)) {
req.isTLS = true
req.uri.SetSchemeBytes(strHTTPS)
} else {
req.isTLS = false
req.uri.SetSchemeBytes(strHTTP)
}
scheme = req.uri.Scheme()
req.schemaUpdate = true
}

req.parsedURI = false
req.Header.host = req.Header.host[:0]
req.SetRequestURI(url)
req.parseURI()

if err = c.Do(req, resp); err != nil {
break
Expand Down Expand Up @@ -1271,6 +1258,10 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
panic("BUG: resp cannot be nil")
}

if c.IsTLS != bytes.Equal(req.uri.Scheme(), strHTTPS) {
return false, ErrHostClientRedirectToDifferentScheme
}

atomic.StoreUint32(&c.lastUseTime, uint32(time.Now().Unix()-startTimeUnix))

// Free up resources occupied by response before sending the request,
Expand All @@ -1285,16 +1276,6 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
req.URI().DisablePathNormalizing = true
}

// If we detected a redirect to another schema
if req.schemaUpdate {
c.IsTLS = bytes.Equal(req.URI().Scheme(), strHTTPS)
c.Addr = addMissingPort(string(req.Host()), c.IsTLS)
c.addrIdx = 0
c.addrs = nil
req.schemaUpdate = false
req.SetConnectionClose()
}

cc, err := c.acquireConn(req.timeout)
if err != nil {
return false, err
Expand Down
49 changes: 39 additions & 10 deletions client_test.go
Expand Up @@ -245,7 +245,7 @@ func TestClientRedirectSameSchema(t *testing.T) {

urlParsed, err := url.Parse(destURL)
if err != nil {
fmt.Println(err)
t.Fatal(err)
return
}

Expand All @@ -270,7 +270,7 @@ func TestClientRedirectSameSchema(t *testing.T) {

}

func TestClientRedirectChangingSchemaHttp2Https(t *testing.T) {
func TestClientRedirectClientChangingSchemaHttp2Https(t *testing.T) {
t.Parallel()

listenHTTPS := testClientRedirectListener(t, true)
Expand All @@ -287,14 +287,7 @@ func TestClientRedirectChangingSchemaHttp2Https(t *testing.T) {

destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String())

urlParsed, err := url.Parse(destURL)
if err != nil {
fmt.Println(err)
return
}

reqClient := &HostClient{
Addr: urlParsed.Host,
reqClient := &Client{
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Expand All @@ -312,6 +305,42 @@ func TestClientRedirectChangingSchemaHttp2Https(t *testing.T) {
}
}

func TestClientRedirectHostClientChangingSchemaHttp2Https(t *testing.T) {
t.Parallel()

listenHTTPS := testClientRedirectListener(t, true)
defer listenHTTPS.Close()

listenHTTP := testClientRedirectListener(t, false)
defer listenHTTP.Close()

sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true)
defer sHTTPS.Stop()

sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false)
defer sHTTP.Stop()

destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String())

urlParsed, err := url.Parse(destURL)
if err != nil {
t.Fatal(err)
return
}

reqClient := &HostClient{
Addr: urlParsed.Host,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
}

_, _, err = reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
if err != ErrHostClientRedirectToDifferentScheme {
t.Fatal("expected HostClient error")
}
}

func testClientRedirectListener(t *testing.T, isTLS bool) net.Listener {
var ln net.Listener
var err error
Expand Down
5 changes: 2 additions & 3 deletions http.go
Expand Up @@ -46,11 +46,10 @@ type Request struct {

keepBodyBuffer bool

// Used by Server to indicate the request was received on a HTTPS endpoint.
// Client/HostClient shouldn't use this field but should depend on the uri.scheme instead.
isTLS bool

// To detect scheme changes in redirects
schemaUpdate bool

// Request timeout. Usually set by DoDealine or DoTimeout
// if <= 0, means not set
timeout time.Duration
Expand Down

0 comments on commit d22782d

Please sign in to comment.