Skip to content

Commit

Permalink
Merge pull request #2 from pusher/context-for-req-cancellation
Browse files Browse the repository at this point in the history
Add context support for request cancellation
  • Loading branch information
Luís Fonseca committed Feb 18, 2019
2 parents bf541fe + 2e5bd9a commit 25559b3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
15 changes: 8 additions & 7 deletions gcm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package googlemessaging

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -125,7 +126,7 @@ type Notification struct {

// httpClient is an interface to stub the http client in tests.
type httpClient interface {
send(apiKey string, m HttpMessage) (*HttpResponse, error)
send(ctx context.Context, apiKey string, m HttpMessage) (*HttpResponse, error)
getRetryAfter() string
}

Expand All @@ -137,7 +138,7 @@ type httpGcmClient struct {
}

// httpGcmClient implementation to send a message through GCM Http server.
func (c *httpGcmClient) send(apiKey string, m HttpMessage) (*HttpResponse, error) {
func (c *httpGcmClient) send(ctx context.Context, apiKey string, m HttpMessage) (*HttpResponse, error) {
bs, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("error marshalling message>%v", err)
Expand All @@ -149,7 +150,7 @@ func (c *httpGcmClient) send(apiKey string, m HttpMessage) (*HttpResponse, error
}
req.Header.Add(http.CanonicalHeaderKey("Content-Type"), "application/json")
req.Header.Add(http.CanonicalHeaderKey("Authorization"), authHeader(apiKey))
httpResp, err := c.HttpClient.Do(req)
httpResp, err := c.HttpClient.Do(req.WithContext(ctx))
if err != nil {
return nil, fmt.Errorf("error sending request to HTTP connection server>%v", err)
}
Expand Down Expand Up @@ -228,14 +229,14 @@ func (eb exponentialBackoff) wait() {
}

// Send a message using the HTTP GCM connection server.
func SendHttp(platform Platform, apiKey string, m HttpMessage) (*HttpResponse, error) {
func SendHttp(ctx context.Context, platform Platform, apiKey string, m HttpMessage) (*HttpResponse, error) {
c := &httpGcmClient{string(platform), &http.Client{}, "0"}
b := newExponentialBackoff()
return sendHttp(apiKey, m, c, b)
return sendHttp(ctx, apiKey, m, c, b)
}

// sendHttp sends an http message using exponential backoff, handling multicast replies.
func sendHttp(apiKey string, m HttpMessage, c httpClient, b backoffProvider) (*HttpResponse, error) {
func sendHttp(ctx context.Context, apiKey string, m HttpMessage, c httpClient, b backoffProvider) (*HttpResponse, error) {
// TODO(silvano): check this with responses for topic/notification group
gcmResp := &HttpResponse{}
var multicastId int
Expand All @@ -248,7 +249,7 @@ func sendHttp(apiKey string, m HttpMessage, c httpClient, b backoffProvider) (*H
copy(localTo, targets)
resultsState := &multicastResultsState{}
for b.sendAnother() {
gcmResp, err = c.send(apiKey, m)
gcmResp, err = c.send(ctx, apiKey, m)
if err != nil {
return gcmResp, fmt.Errorf("error sending request to GCM HTTP server: %v", err)
}
Expand Down
5 changes: 3 additions & 2 deletions gcm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package googlemessaging

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -125,7 +126,7 @@ func TestHttpClientSend(t *testing.T) {
}
httpClient := &http.Client{Transport: transport}
c := &httpGcmClient{server.URL, httpClient, "0"}
response, error := c.send("apiKey", *singleTargetMessage)
response, error := c.send(context.Background(), "apiKey", *singleTargetMessage)
expectedAuthHeader := "key=apiKey"
expResp := &HttpResponse{}
err := json.Unmarshal([]byte(expectedResp), &expResp)
Expand All @@ -149,7 +150,7 @@ func TestSendHttp(t *testing.T) {
if err != nil {
t.Fatalf("error: %v", err)
}
response, err := sendHttp("apiKey", *multipleTargetMessage, c, b)
response, err := sendHttp(context.Background(), "apiKey", *multipleTargetMessage, c, b)
assertDeepEqual(t, err, nil)
assertDeepEqual(t, response, expResp)
}
Expand Down

0 comments on commit 25559b3

Please sign in to comment.