Skip to content

Commit

Permalink
Merge branch 'master' into patch-7
Browse files Browse the repository at this point in the history
  • Loading branch information
moredure committed Apr 4, 2022
2 parents cc3044c + 45981da commit 46601f8
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 54 deletions.
2 changes: 1 addition & 1 deletion certificate/certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func FromPemBytes(bytes []byte, password string) (tls.Certificate, error) {
if block.Type == "CERTIFICATE" {
cert.Certificate = append(cert.Certificate, block.Bytes)
}
if block.Type == "PRIVATE KEY" || strings.HasSuffix(block.Type, "PRIVATE KEY") {
if strings.HasSuffix(block.Type, "PRIVATE KEY") {
key, err := unencryptPrivateKey(block, password)
if err != nil {
return tls.Certificate{}, err
Expand Down
40 changes: 11 additions & 29 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"sync"
"strconv"
"time"

"github.com/sideshow/apns2/token"
Expand Down Expand Up @@ -144,14 +143,7 @@ func (c *Client) Production() *Client {
//
// Use PushWithContext if you need better cancellation and timeout control.
func (c *Client) Push(n *Notification) (*Response, error) {
return c.PushWithContext(nil, n)
}

// payloads pool of bytes.Buffer holding notifications
var payloads = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
return c.PushWithContext(context.Background(), n)
}

// PushWithContext sends a Notification to the APNs gateway. Context carries a
Expand All @@ -164,16 +156,13 @@ var payloads = sync.Pool{
// return a Response indicating whether the notification was accepted or
// rejected by the APNs gateway, or an error if something goes wrong.
func (c *Client) PushWithContext(ctx Context, n *Notification) (*Response, error) {
payload := payloads.Get().(*bytes.Buffer)
payload.Reset()
defer payloads.Put(payload)
if err := json.NewEncoder(payload).Encode(n); err != nil {
payload, err := n.MarshalJSON()
if err != nil {
return nil, err
}
payload.Truncate(payload.Len() - len("\n"))

url := fmt.Sprintf("%v/3/device/%v", c.Host, n.DeviceToken)
req, err := http.NewRequest(http.MethodPost, url, payload)
url := c.Host + "/3/device/" + n.DeviceToken
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
if err != nil {
return nil, err
}
Expand All @@ -184,7 +173,7 @@ func (c *Client) PushWithContext(ctx Context, n *Notification) (*Response, error

setHeaders(req, n)

httpRes, err := c.requestWithContext(ctx, req)
httpRes, err := c.HTTPClient.Do(req)
if err != nil {
return nil, err
}
Expand All @@ -195,7 +184,7 @@ func (c *Client) PushWithContext(ctx Context, n *Notification) (*Response, error
response.ApnsID = httpRes.Header.Get("apns-id")

decoder := json.NewDecoder(httpRes.Body)
if err := decoder.Decode(&response); err != nil && err != io.EOF {
if err := decoder.Decode(response); err != nil && err != io.EOF {
return &Response{}, err
}
return response, nil
Expand All @@ -210,7 +199,7 @@ func (c *Client) CloseIdleConnections() {

func (c *Client) setTokenHeader(r *http.Request) {
bearer := c.Token.GenerateIfExpired()
r.Header.Set("authorization", fmt.Sprintf("bearer %v", bearer))
r.Header.Set("authorization", "bearer "+bearer)
}

func setHeaders(r *http.Request, n *Notification) {
Expand All @@ -225,10 +214,10 @@ func setHeaders(r *http.Request, n *Notification) {
r.Header.Set("apns-collapse-id", n.CollapseID)
}
if n.Priority > 0 {
r.Header.Set("apns-priority", fmt.Sprintf("%v", n.Priority))
r.Header.Set("apns-priority", strconv.Itoa(n.Priority))
}
if !n.Expiration.IsZero() {
r.Header.Set("apns-expiration", fmt.Sprintf("%v", n.Expiration.Unix()))
r.Header.Set("apns-expiration", strconv.FormatInt(n.Expiration.Unix(), 10))
}
if n.PushType != "" {
r.Header.Set("apns-push-type", string(n.PushType))
Expand All @@ -237,10 +226,3 @@ func setHeaders(r *http.Request, n *Notification) {
}

}

func (c *Client) requestWithContext(ctx Context, req *http.Request) (*http.Response, error) {
if ctx != nil {
req = req.WithContext(ctx)
}
return c.HTTPClient.Do(req)
}
53 changes: 29 additions & 24 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package apns2_test

import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -213,6 +211,35 @@ func TestClientPushWithContext(t *testing.T) {
assert.Equal(t, res.ApnsID, apnsID)
}

func TestClientPushWithNilNotification(t *testing.T) {
var apnsID = "02ABC856-EF8D-4E49-8F15-7B8A61D978D6"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("apns-id", apnsID)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

assert.Panics(t, func() {
mockClient(server.URL).PushWithContext(context.Background(), nil)
})
}

func TestClientPushWithNilContext(t *testing.T) {
n := mockNotification()
var apnsID = "02ABC856-EF8D-4E49-8F15-7B8A61D978D6"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Header().Set("apns-id", apnsID)
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

res, err := mockClient(server.URL).PushWithContext(nil, n)
assert.EqualError(t, err, "net/http: nil Context")
assert.Nil(t, res)
}

func TestHeaders(t *testing.T) {
n := mockNotification()
n.ApnsID = "84DB694F-464F-49BD-960A-D6DB028335C9"
Expand Down Expand Up @@ -418,25 +445,3 @@ func TestCloseIdleConnections(t *testing.T) {
client.CloseIdleConnections()
assert.Equal(t, true, transport.closed)
}

func BenchmarkEncoding(b *testing.B) {
buf := new(bytes.Buffer)
n := mockNotification()
b.Run("old", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := json.Marshal(n); err != nil {
panic(err)
}
}
})
b.Run("new", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if err := json.NewEncoder(buf).Encode(n); err != nil {
panic(err)
}
buf.Reset()
}
})
}
3 changes: 3 additions & 0 deletions notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ type Notification struct {

// MarshalJSON converts the notification payload to JSON.
func (n *Notification) MarshalJSON() ([]byte, error) {
if n == nil {
return []byte("null"), nil
}
switch payload := n.Payload.(type) {
case string:
return []byte(payload), nil
Expand Down

0 comments on commit 46601f8

Please sign in to comment.