Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5a9f65e
commit ef548c8
Showing
3 changed files
with
159 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
package login | ||
|
||
import ( | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"github.com/stripe/stripe-cli/stripeauth" | ||
"io/ioutil" | ||
"net/http" | ||
"time" | ||
) | ||
|
||
const maxAttemptsDefault = 2 * 60 | ||
const intervalDefault = 1 * time.Second | ||
|
||
type pollAPIKeyResponse struct { | ||
Redeemed bool `json:"redeemed"` | ||
AccountID string `json:"account_id"` | ||
APIKey string `json:"api_key"` | ||
} | ||
|
||
// PollForKey polls Stripe at the specified interval until either the API key is available or we've reached the max attempts. | ||
func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (string, error) { | ||
if maxAttempts == 0 { | ||
maxAttempts = maxAttemptsDefault | ||
} | ||
|
||
if interval == 0 { | ||
interval = intervalDefault | ||
} | ||
|
||
client := stripeauth.NewHTTPClient("") | ||
|
||
var count = 0 | ||
for count < maxAttempts { | ||
res, err := client.Get(pollURL) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
bodyBytes, err := ioutil.ReadAll(res.Body) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
if res.StatusCode != http.StatusOK { | ||
return "", fmt.Errorf("unexpected http status code: %d %s", res.StatusCode, string(bodyBytes)) | ||
} | ||
|
||
var response pollAPIKeyResponse | ||
jsonErr := json.Unmarshal(bodyBytes, &response) | ||
if jsonErr != nil { | ||
return "", jsonErr | ||
} | ||
|
||
if response.Redeemed { | ||
return response.APIKey, nil | ||
} | ||
|
||
count++ | ||
time.Sleep(interval) | ||
} | ||
|
||
return "", errors.New("exceeded max attempts") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package login | ||
|
||
import ( | ||
"encoding/json" | ||
assert "github.com/stretchr/testify/require" | ||
"net/http" | ||
"net/http/httptest" | ||
"sync/atomic" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestRedeemed(t *testing.T) { | ||
var attempts uint64 | ||
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
assert.Equal(t, "GET", r.Method) | ||
|
||
atomic.AddUint64(&attempts, 1) | ||
|
||
response := &pollAPIKeyResponse{ | ||
Redeemed: false, | ||
} | ||
if atomic.LoadUint64(&attempts) == 2 { | ||
response.Redeemed = true | ||
response.AccountID = "acct_123" | ||
response.APIKey = "sk_test_123" | ||
} | ||
w.WriteHeader(http.StatusOK) | ||
w.Header().Set("Content-Type", "application/json") | ||
json.NewEncoder(w).Encode(response) | ||
})) | ||
defer ts.Close() | ||
|
||
apiKey, err := PollForKey(ts.URL, 1*time.Millisecond, 3) | ||
assert.NoError(t, err) | ||
assert.Equal(t, "sk_test_123", apiKey) | ||
assert.Equal(t, uint64(2), atomic.LoadUint64(&attempts)) | ||
} | ||
|
||
func TestExceedMaxAttempts(t *testing.T) { | ||
var attempts uint64 | ||
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
assert.Equal(t, "GET", r.Method) | ||
|
||
atomic.AddUint64(&attempts, 1) | ||
|
||
response := pollAPIKeyResponse{ | ||
Redeemed: false, | ||
} | ||
w.WriteHeader(http.StatusOK) | ||
w.Header().Set("Content-Type", "application/json") | ||
json.NewEncoder(w).Encode(response) | ||
})) | ||
defer ts.Close() | ||
|
||
apiKey, err := PollForKey(ts.URL, 1*time.Millisecond, 3) | ||
assert.EqualError(t, err, "exceeded max attempts") | ||
assert.Empty(t, apiKey) | ||
assert.Equal(t, uint64(3), atomic.LoadUint64(&attempts)) | ||
} | ||
|
||
func TestHTTPStatusError(t *testing.T) { | ||
var attempts uint64 | ||
|
||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
assert.Equal(t, "GET", r.Method) | ||
|
||
atomic.AddUint64(&attempts, 1) | ||
|
||
w.WriteHeader(http.StatusInternalServerError) | ||
})) | ||
defer ts.Close() | ||
|
||
apiKey, err := PollForKey(ts.URL, 1*time.Millisecond, 3) | ||
assert.EqualError(t, err, "unexpected http status code: 500 ") | ||
assert.Empty(t, apiKey) | ||
assert.Equal(t, uint64(1), atomic.LoadUint64(&attempts)) | ||
} | ||
|
||
func TestHTTPRequestError(t *testing.T) { | ||
// Immediately close the HTTP server so that the poll request fails. | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) | ||
ts.Close() | ||
|
||
apiKey, err := PollForKey(ts.URL, 1*time.Millisecond, 3) | ||
assert.Error(t, err) | ||
assert.Contains(t, err.Error(), "connect: connection refused") | ||
assert.Empty(t, apiKey) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters