Skip to content

Commit

Permalink
stripe login: poll for api key (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmuia-stripe authored and ob-stripe committed Jul 1, 2019
1 parent 5a9f65e commit ef548c8
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 2 deletions.
65 changes: 65 additions & 0 deletions login/poll.go
@@ -0,0 +1,65 @@
package login

import (
"encoding/json"
"errors"
"fmt"
"github.com/stripe/stripe-cli/stripeauth"
"io/ioutil"
"net/http"
"time"
)

const maxAttemptsDefault = 2 * 60
const intervalDefault = 1 * time.Second

type pollAPIKeyResponse struct {
Redeemed bool `json:"redeemed"`
AccountID string `json:"account_id"`
APIKey string `json:"api_key"`
}

// PollForKey polls Stripe at the specified interval until either the API key is available or we've reached the max attempts.
func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (string, error) {
if maxAttempts == 0 {
maxAttempts = maxAttemptsDefault
}

if interval == 0 {
interval = intervalDefault
}

client := stripeauth.NewHTTPClient("")

var count = 0
for count < maxAttempts {
res, err := client.Get(pollURL)
if err != nil {
return "", err
}

bodyBytes, err := ioutil.ReadAll(res.Body)
if err != nil {
return "", err
}

if res.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected http status code: %d %s", res.StatusCode, string(bodyBytes))
}

var response pollAPIKeyResponse
jsonErr := json.Unmarshal(bodyBytes, &response)
if jsonErr != nil {
return "", jsonErr
}

if response.Redeemed {
return response.APIKey, nil
}

count++
time.Sleep(interval)
}

return "", errors.New("exceeded max attempts")
}
91 changes: 91 additions & 0 deletions login/poll_test.go
@@ -0,0 +1,91 @@
package login

import (
"encoding/json"
assert "github.com/stretchr/testify/require"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)

func TestRedeemed(t *testing.T) {
var attempts uint64

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)

atomic.AddUint64(&attempts, 1)

response := &pollAPIKeyResponse{
Redeemed: false,
}
if atomic.LoadUint64(&attempts) == 2 {
response.Redeemed = true
response.AccountID = "acct_123"
response.APIKey = "sk_test_123"
}
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer ts.Close()

apiKey, err := PollForKey(ts.URL, 1*time.Millisecond, 3)
assert.NoError(t, err)
assert.Equal(t, "sk_test_123", apiKey)
assert.Equal(t, uint64(2), atomic.LoadUint64(&attempts))
}

func TestExceedMaxAttempts(t *testing.T) {
var attempts uint64

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)

atomic.AddUint64(&attempts, 1)

response := pollAPIKeyResponse{
Redeemed: false,
}
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer ts.Close()

apiKey, err := PollForKey(ts.URL, 1*time.Millisecond, 3)
assert.EqualError(t, err, "exceeded max attempts")
assert.Empty(t, apiKey)
assert.Equal(t, uint64(3), atomic.LoadUint64(&attempts))
}

func TestHTTPStatusError(t *testing.T) {
var attempts uint64

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)

atomic.AddUint64(&attempts, 1)

w.WriteHeader(http.StatusInternalServerError)
}))
defer ts.Close()

apiKey, err := PollForKey(ts.URL, 1*time.Millisecond, 3)
assert.EqualError(t, err, "unexpected http status code: 500 ")
assert.Empty(t, apiKey)
assert.Equal(t, uint64(1), atomic.LoadUint64(&attempts))
}

func TestHTTPRequestError(t *testing.T) {
// Immediately close the HTTP server so that the poll request fails.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
ts.Close()

apiKey, err := PollForKey(ts.URL, 1*time.Millisecond, 3)
assert.Error(t, err)
assert.Contains(t, err.Error(), "connect: connection refused")
assert.Empty(t, apiKey)
}
5 changes: 3 additions & 2 deletions stripeauth/client.go
Expand Up @@ -101,7 +101,7 @@ func NewClient(key string, cfg *Config) *Client {
cfg.Log = &log.Logger{Out: ioutil.Discard}
}
if cfg.HTTPClient == nil {
cfg.HTTPClient = newHTTPClient(cfg.UnixSocket)
cfg.HTTPClient = NewHTTPClient(cfg.UnixSocket)
}
if cfg.URL == "" {
cfg.URL = defaultAuthorizeURL
Expand All @@ -121,7 +121,8 @@ const (
defaultAuthorizeURL = "https://api.stripe.com/v1/stripecli/sessions"
)

func newHTTPClient(unixSocket string) *http.Client {
// NewHTTPClient returns a configured HTTP client.
func NewHTTPClient(unixSocket string) *http.Client {
var httpTransport *http.Transport
if unixSocket != "" {
dialFunc := func(network, addr string) (net.Conn, error) {
Expand Down

0 comments on commit ef548c8

Please sign in to comment.