Skip to content

Commit

Permalink
Rename ClientPool to ClientManager
Browse files Browse the repository at this point in the history
  • Loading branch information
imhoffd authored and sideshow committed Oct 4, 2016
1 parent 30ae022 commit e42f05d
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 301 deletions.
2 changes: 1 addition & 1 deletion client.go
Expand Up @@ -52,7 +52,7 @@ type Client struct {
// connection and disconnection as a denial-of-service attack.
//
// If your use case involves multiple long-lived connections, consider using
// the ClientPool, which manages connections for you.
// the ClientManager, which manages clients for you.
func NewClient(certificate tls.Certificate) *Client {
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{certificate},
Expand Down
162 changes: 162 additions & 0 deletions client_manager.go
@@ -0,0 +1,162 @@
package apns2

import (
"container/list"
"crypto/sha1"
"crypto/tls"
"sync"
"time"
)

type managerItem struct {
key [sha1.Size]byte
client *Client
lastUsed time.Time
}

// ClientManager is a way to manage multiple connections to the APNs.
type ClientManager struct {
// MaxSize is the maximum number of clients allowed in the manager. When
// this limit is reached, the least recently used client is evicted. Set
// zero for no limit.
MaxSize int

// MaxAge is the maximum age of clients in the manager. Upon retrieval, if
// a client has remained unused in the manager for this duration or longer,
// it is evicted and nil is returned. Set zero to disable this
// functionality.
MaxAge time.Duration

// Factory is the function which constructs clients if not found in the
// manager.
Factory func(certificate tls.Certificate) *Client

cache map[[sha1.Size]byte]*list.Element
ll *list.List
mu sync.Mutex
}

// NewClientManager returns a new ClientManager for prolonged, concurrent usage
// of multiple APNs clients. ClientManager is flexible enough to work best for
// your use case. When a client is not found in the manager, Get will return
// the result of calling Factory, which can be a Client or nil.
//
// Having multiple clients per certificate in the manager is not allowed.
//
// By default, MaxSize is 64, MaxAge is 10 minutes, and Factory always returns
// a Client with default options.
func NewClientManager() *ClientManager {
manager := &ClientManager{
MaxSize: 64,
MaxAge: 10 * time.Minute,
Factory: NewClient,
}

manager.initInternals()

return manager
}

// Add adds a Client to the manager. You can use this to individually configure
// Clients in the manager.
func (m *ClientManager) Add(client *Client) {
if m.cache == nil {
m.initInternals()
}
m.mu.Lock()
defer m.mu.Unlock()
key := cacheKey(client.Certificate)
now := time.Now()
if ele, hit := m.cache[key]; hit {
item := ele.Value.(*managerItem)
item.client = client
item.lastUsed = now
m.ll.MoveToFront(ele)
return
}
ele := m.ll.PushFront(&managerItem{key, client, now})
m.cache[key] = ele
if m.MaxSize != 0 && m.ll.Len() > m.MaxSize {
m.mu.Unlock()
m.removeOldest()
m.mu.Lock()
}
}

// Get gets a Client from the manager. If a Client is not found in the manager
// or if a Client has remained in the manager longer than MaxAge, Get will call
// the ClientManager's Factory function, store the result in the manager if
// non-nil, and return it.
func (m *ClientManager) Get(certificate tls.Certificate) *Client {
if m.cache == nil {
m.initInternals()
}
m.mu.Lock()
defer m.mu.Unlock()
key := cacheKey(certificate)
now := time.Now()
if ele, hit := m.cache[key]; hit {
item := ele.Value.(*managerItem)
if m.MaxAge != 0 && item.lastUsed.Before(now.Add(-m.MaxAge)) {
c := m.Factory(certificate)
if c == nil {
return nil
}
item.client = c
}
item.lastUsed = now
m.ll.MoveToFront(ele)
return item.client
}

c := m.Factory(certificate)
if c == nil {
return nil
}
m.mu.Unlock()
m.Add(c)
m.mu.Lock()
return c
}

// Len returns the current size of the ClientManager.
func (m *ClientManager) Len() int {
if m.cache == nil {
return 0
}
m.mu.Lock()
defer m.mu.Unlock()
return m.ll.Len()
}

func (m *ClientManager) initInternals() {
m.cache = map[[sha1.Size]byte]*list.Element{}
m.ll = list.New()
m.mu = sync.Mutex{}
}

func (m *ClientManager) removeOldest() {
m.mu.Lock()
ele := m.ll.Back()
m.mu.Unlock()
if ele != nil {
m.removeElement(ele)
}
}

func (m *ClientManager) removeElement(e *list.Element) {
m.mu.Lock()
defer m.mu.Unlock()
m.ll.Remove(e)
delete(m.cache, e.Value.(*managerItem).key)
}

func cacheKey(certificate tls.Certificate) [sha1.Size]byte {
var data []byte

for _, cert := range certificate.Certificate {
data = append(data, cert...)
}

return sha1.Sum(data)
}
139 changes: 139 additions & 0 deletions client_manager_test.go
@@ -0,0 +1,139 @@
package apns2_test

import (
"bytes"
"crypto/tls"
"reflect"
"testing"
"time"

"github.com/sideshow/apns2"
"github.com/sideshow/apns2/certificate"
"github.com/stretchr/testify/assert"
)

func TestNewClientManager(t *testing.T) {
manager := apns2.NewClientManager()
assert.Equal(t, manager.MaxSize, 64)
assert.Equal(t, manager.MaxAge, 10*time.Minute)
}

func TestClientManagerGetWithoutNew(t *testing.T) {
manager := apns2.ClientManager{
MaxSize: 32,
MaxAge: 5 * time.Minute,
Factory: apns2.NewClient,
}

c1 := manager.Get(mockCert())
c2 := manager.Get(mockCert())
v1 := reflect.ValueOf(c1)
v2 := reflect.ValueOf(c2)
assert.NotNil(t, c1)
assert.Equal(t, v1.Pointer(), v2.Pointer())
assert.Equal(t, 1, manager.Len())
}

func TestClientManagerAddWithoutNew(t *testing.T) {
manager := apns2.ClientManager{
MaxSize: 32,
MaxAge: 5 * time.Minute,
Factory: apns2.NewClient,
}

manager.Add(apns2.NewClient(mockCert()))
assert.Equal(t, 1, manager.Len())
}

func TestClientManagerLenWithoutNew(t *testing.T) {
manager := apns2.ClientManager{
MaxSize: 32,
MaxAge: 5 * time.Minute,
Factory: apns2.NewClient,
}

assert.Equal(t, 0, manager.Len())
}

func TestClientManagerGetDefaultOptions(t *testing.T) {
manager := apns2.NewClientManager()
c1 := manager.Get(mockCert())
c2 := manager.Get(mockCert())
v1 := reflect.ValueOf(c1)
v2 := reflect.ValueOf(c2)
assert.NotNil(t, c1)
assert.Equal(t, v1.Pointer(), v2.Pointer())
assert.Equal(t, 1, manager.Len())
}

func TestClientManagerGetNilClientFactory(t *testing.T) {
manager := apns2.NewClientManager()
manager.Factory = func(certificate tls.Certificate) *apns2.Client {
return nil
}
c1 := manager.Get(mockCert())
c2 := manager.Get(mockCert())
assert.Nil(t, c1)
assert.Nil(t, c2)
assert.Equal(t, 0, manager.Len())
}

func TestClientManagerGetMaxAgeExpiration(t *testing.T) {
manager := apns2.NewClientManager()
manager.MaxAge = time.Nanosecond
c1 := manager.Get(mockCert())
time.Sleep(time.Microsecond)
c2 := manager.Get(mockCert())
v1 := reflect.ValueOf(c1)
v2 := reflect.ValueOf(c2)
assert.NotNil(t, c1)
assert.NotEqual(t, v1.Pointer(), v2.Pointer())
assert.Equal(t, 1, manager.Len())
}

func TestClientManagerGetMaxAgeExpirationWithNilFactory(t *testing.T) {
manager := apns2.NewClientManager()
manager.Factory = func(certificate tls.Certificate) *apns2.Client {
return nil
}
manager.MaxAge = time.Nanosecond
manager.Add(apns2.NewClient(mockCert()))
c1 := manager.Get(mockCert())
time.Sleep(time.Microsecond)
c2 := manager.Get(mockCert())
assert.Nil(t, c1)
assert.Nil(t, c2)
assert.Equal(t, 1, manager.Len())
}

func TestClientManagerGetMaxSizeExceeded(t *testing.T) {
manager := apns2.NewClientManager()
manager.MaxSize = 1
cert1 := mockCert()
_ = manager.Get(cert1)
cert2, _ := certificate.FromP12File("certificate/_fixtures/certificate-valid.p12", "")
_ = manager.Get(cert2)
cert3, _ := certificate.FromP12File("certificate/_fixtures/certificate-valid-encrypted.p12", "password")
c := manager.Get(cert3)
assert.True(t, bytes.Equal(cert3.Certificate[0], c.Certificate.Certificate[0]))
assert.Equal(t, 1, manager.Len())
}

func TestClientManagerAdd(t *testing.T) {
fn := func(certificate tls.Certificate) *apns2.Client {
t.Fatal("factory should not have been called")
return nil
}

manager := apns2.NewClientManager()
manager.Factory = fn
manager.Add(apns2.NewClient(mockCert()))
manager.Get(mockCert())
}

func TestClientManagerAddTwice(t *testing.T) {
manager := apns2.NewClientManager()
manager.Add(apns2.NewClient(mockCert()))
manager.Add(apns2.NewClient(mockCert()))
assert.Equal(t, 1, manager.Len())
}

0 comments on commit e42f05d

Please sign in to comment.