diff --git a/pkg/cmd/listen.go b/pkg/cmd/listen.go index 5d1b67c5c..8b5ace22b 100644 --- a/pkg/cmd/listen.go +++ b/pkg/cmd/listen.go @@ -90,7 +90,7 @@ func (lc *listenCmd) runListenCmd(cmd *cobra.Command, args []string) error { endpointRoutes := make([]proxy.EndpointRoute, 0) - key, err := Config.Profile.GetAPIKey() + key, err := Config.Profile.GetAPIKey(false) if err != nil { return err } diff --git a/pkg/cmd/logs/tail.go b/pkg/cmd/logs/tail.go index 32eed271e..c9b94fecf 100644 --- a/pkg/cmd/logs/tail.go +++ b/pkg/cmd/logs/tail.go @@ -132,7 +132,7 @@ func (tailCmd *TailCmd) runTailCmd(cmd *cobra.Command, args []string) error { return err } - key, err := tailCmd.cfg.Profile.GetAPIKey() + key, err := tailCmd.cfg.Profile.GetAPIKey(false) if err != nil { return err } diff --git a/pkg/cmd/resource/operation.go b/pkg/cmd/resource/operation.go index 0c6cc66e0..01052bbf8 100644 --- a/pkg/cmd/resource/operation.go +++ b/pkg/cmd/resource/operation.go @@ -32,7 +32,7 @@ type OperationCmd struct { } func (oc *OperationCmd) runOperationCmd(cmd *cobra.Command, args []string) error { - apiKey, err := oc.Profile.GetAPIKey() + apiKey, err := oc.Profile.GetAPIKey(false) if err != nil { return err } diff --git a/pkg/cmd/trigger.go b/pkg/cmd/trigger.go index c7bfcc19d..d8f33dce0 100644 --- a/pkg/cmd/trigger.go +++ b/pkg/cmd/trigger.go @@ -81,7 +81,7 @@ needed to create the triggered event. } func (tc *triggerCmd) runTriggerCmd(cmd *cobra.Command, args []string) error { - apiKey, err := Config.Profile.GetAPIKey() + apiKey, err := Config.Profile.GetAPIKey(false) if err != nil { return err } diff --git a/pkg/config/profile.go b/pkg/config/profile.go index 809d57076..0770f03c2 100644 --- a/pkg/config/profile.go +++ b/pkg/config/profile.go @@ -13,10 +13,13 @@ import ( // Profile handles all things related to managing the project specific configurations type Profile struct { - DeviceName string - ProfileName string - APIKey string - PublishableKey string + DeviceName string + ProfileName string + APIKey string + LiveModeAPIKey string + LiveModePublishableKey string + TestModeAPIKey string + TestModePublishableKey string } // CreateProfile creates a profile when logging in @@ -64,7 +67,7 @@ func (p *Profile) GetDeviceName() (string, error) { } // GetAPIKey will return the existing key for the given profile -func (p *Profile) GetAPIKey() (string, error) { +func (p *Profile) GetAPIKey(livemode bool) (string, error) { if p.APIKey != "" { err := validators.APIKey(p.APIKey) if err != nil { @@ -75,13 +78,17 @@ func (p *Profile) GetAPIKey() (string, error) { // If the user doesn't have an api_key field set, they might be using an // old configuration so try to read from secret_key - if !viper.IsSet(p.GetConfigField("api_key")) { - p.RegisterAlias("api_key", "secret_key") + if !livemode { + if !viper.IsSet(p.GetConfigField("api_key")) { + p.RegisterAlias("api_key", "secret_key") + } else { + p.RegisterAlias("test_mode_api_key", "api_key") + } } // Try to fetch the API key from the configuration file if err := viper.ReadInConfig(); err == nil { - key := viper.GetString(p.GetConfigField("api_key")) + key := viper.GetString(p.GetConfigField(livemodeKeyField(livemode))) err := validators.APIKey(key) if err != nil { return "", err @@ -129,27 +136,28 @@ func (p *Profile) writeProfile(runtimeViper *viper.Viper) error { if p.DeviceName != "" { runtimeViper.Set(p.GetConfigField("device_name"), strings.TrimSpace(p.DeviceName)) } - if p.APIKey != "" { - runtimeViper.Set(p.GetConfigField("api_key"), strings.TrimSpace(p.APIKey)) + if p.LiveModeAPIKey != "" { + runtimeViper.Set(p.GetConfigField("live_mode_api_key"), strings.TrimSpace(p.LiveModeAPIKey)) + } + if p.LiveModePublishableKey != "" { + runtimeViper.Set(p.GetConfigField("live_mode_publishable_key"), strings.TrimSpace(p.LiveModePublishableKey)) + } + if p.TestModeAPIKey != "" { + runtimeViper.Set(p.GetConfigField("test_mode_api_key"), strings.TrimSpace(p.TestModeAPIKey)) } - if p.PublishableKey != "" { - runtimeViper.Set(p.GetConfigField("publishable_key"), strings.TrimSpace(p.PublishableKey)) + if p.TestModePublishableKey != "" { + runtimeViper.Set(p.GetConfigField("test_mode_publishable_key"), strings.TrimSpace(p.TestModePublishableKey)) } runtimeViper.MergeInConfig() // Do this after we merge the old configs in - if p.APIKey != "" { - if runtimeViper.IsSet(p.GetConfigField("secret_key")) { - newViper, err := removeKey(runtimeViper, p.GetConfigField("secret_key")) - if err == nil { - // I don't want to fail the entire login process on not being able to remove - // the old secret_key field so if there's no error - runtimeViper = newViper - } else { - fmt.Println(err) - } - } + if p.TestModeAPIKey != "" { + runtimeViper = p.safeRemove(runtimeViper, "secret_key") + runtimeViper = p.safeRemove(runtimeViper, "api_key") + } + if p.TestModePublishableKey != "" { + runtimeViper = p.safeRemove(runtimeViper, "publishable_key") } runtimeViper.SetConfigFile(profilesFile) @@ -163,3 +171,24 @@ func (p *Profile) writeProfile(runtimeViper *viper.Viper) error { return nil } + +func (p *Profile) safeRemove(v *viper.Viper, key string) *viper.Viper { + if v.IsSet(p.GetConfigField(key)) { + newViper, err := removeKey(v, p.GetConfigField(key)) + if err == nil { + // I don't want to fail the entire login process on not being able to remove + // the old secret_key field so if there's no error + return newViper + } + } + + return v +} + +func livemodeKeyField(livemode bool) string { + if livemode { + return "live_mode_api_key" + } + + return "test_mode_api_key" +} diff --git a/pkg/config/profile_test.go b/pkg/config/profile_test.go index fd61354c3..566e96199 100644 --- a/pkg/config/profile_test.go +++ b/pkg/config/profile_test.go @@ -14,9 +14,9 @@ import ( func TestWriteProfile(t *testing.T) { profilesFile := filepath.Join(os.TempDir(), "stripe", "config.toml") p := Profile{ - DeviceName: "st-testing", - ProfileName: "tests", - APIKey: "sk_test_123", + DeviceName: "st-testing", + ProfileName: "tests", + TestModeAPIKey: "sk_test_123", } c := &Config{ @@ -38,8 +38,8 @@ func TestWriteProfile(t *testing.T) { configValues := helperLoadBytes(t, c.ProfilesFile) expectedConfig := ` [tests] - api_key = "sk_test_123" device_name = "st-testing" + test_mode_api_key = "sk_test_123" ` require.EqualValues(t, expectedConfig, string(configValues)) @@ -49,9 +49,9 @@ func TestWriteProfile(t *testing.T) { func TestWriteProfilesMerge(t *testing.T) { profilesFile := filepath.Join(os.TempDir(), "stripe", "config.toml") p := Profile{ - ProfileName: "tests", - DeviceName: "st-testing", - APIKey: "sk_test_123", + ProfileName: "tests", + DeviceName: "st-testing", + TestModeAPIKey: "sk_test_123", } c := &Config{ @@ -76,12 +76,12 @@ func TestWriteProfilesMerge(t *testing.T) { configValues := helperLoadBytes(t, c.ProfilesFile) expectedConfig := ` [tests] - api_key = "sk_test_123" device_name = "st-testing" + test_mode_api_key = "sk_test_123" [tests-merge] - api_key = "sk_test_123" device_name = "st-testing" + test_mode_api_key = "sk_test_123" ` require.EqualValues(t, expectedConfig, string(configValues)) diff --git a/pkg/login/client_login.go b/pkg/login/client_login.go index 405dd246f..68cf2ca71 100644 --- a/pkg/login/client_login.go +++ b/pkg/login/client_login.go @@ -56,24 +56,24 @@ func Login(baseURL string, config *config.Config, input io.Reader) error { } //Call poll function - apiKey, publishableKey, account, err := PollForKey(links.PollURL, 0, 0) + response, account, err := PollForKey(links.PollURL, 0, 0) if err != nil { return err } - validateErr := validators.APIKey(apiKey) + validateErr := validators.APIKey(response.TestModeAPIKey) if validateErr != nil { return validateErr } - config.Profile.APIKey = apiKey - config.Profile.PublishableKey = publishableKey + config.Profile.TestModeAPIKey = response.TestModeAPIKey + config.Profile.TestModePublishableKey = response.TestModePublishableKey profileErr := config.Profile.CreateProfile() if profileErr != nil { return profileErr } - message, err := SuccessMessage(account, stripe.DefaultAPIBaseURL, apiKey) + message, err := SuccessMessage(account, stripe.DefaultAPIBaseURL, response.TestModeAPIKey) if err != nil { fmt.Println(fmt.Sprintf("> Error verifying the CLI was set up successfully: %s", err)) } else { diff --git a/pkg/login/interactive_login.go b/pkg/login/interactive_login.go index 6b7958a4b..4feb6b05f 100644 --- a/pkg/login/interactive_login.go +++ b/pkg/login/interactive_login.go @@ -26,7 +26,7 @@ func InteractiveLogin(config *config.Config) error { } config.Profile.DeviceName = getConfigureDeviceName(os.Stdin) - config.Profile.APIKey = apiKey + config.Profile.TestModeAPIKey = apiKey profileErr := config.Profile.CreateProfile() if profileErr != nil { diff --git a/pkg/login/poll.go b/pkg/login/poll.go index dfd3096d2..974e315d0 100644 --- a/pkg/login/poll.go +++ b/pkg/login/poll.go @@ -15,16 +15,19 @@ import ( const maxAttemptsDefault = 2 * 60 const intervalDefault = 1 * time.Second -type pollAPIKeyResponse struct { - Redeemed bool `json:"redeemed"` - AccountID string `json:"account_id"` - AccountDisplayName string `json:"account_display_name"` - APIKey string `json:"testmode_key_secret"` - PublishableKey string `json:"testmode_key_publishable"` +// PollAPIKeyResponse returns the data of the polling client login +type PollAPIKeyResponse struct { + Redeemed bool `json:"redeemed"` + AccountID string `json:"account_id"` + AccountDisplayName string `json:"account_display_name"` + TestModeAPIKey string `json:"testmode_key_secret"` + TestModePublishableKey string `json:"testmode_key_publishable"` } // PollForKey polls Stripe at the specified interval until either the API key is available or we've reached the max attempts. -func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (string, string, *Account, error) { +func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (*PollAPIKeyResponse, *Account, error) { + var response PollAPIKeyResponse + if maxAttempts == 0 { maxAttempts = maxAttemptsDefault } @@ -35,7 +38,7 @@ func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (string parsedURL, err := url.Parse(pollURL) if err != nil { - return "", "", nil, err + return nil, nil, err } baseURL := &url.URL{Scheme: parsedURL.Scheme, Host: parsedURL.Host} @@ -48,23 +51,22 @@ func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (string for count < maxAttempts { res, err := client.PerformRequest(http.MethodGet, parsedURL.Path, parsedURL.Query().Encode(), nil) if err != nil { - return "", "", nil, err + return nil, nil, err } defer res.Body.Close() bodyBytes, err := ioutil.ReadAll(res.Body) if err != nil { - return "", "", nil, err + return nil, nil, err } if res.StatusCode != http.StatusOK { - return "", "", nil, fmt.Errorf("unexpected http status code: %d %s", res.StatusCode, string(bodyBytes)) + return nil, nil, fmt.Errorf("unexpected http status code: %d %s", res.StatusCode, string(bodyBytes)) } - var response pollAPIKeyResponse jsonErr := json.Unmarshal(bodyBytes, &response) if jsonErr != nil { - return "", "", nil, jsonErr + return nil, nil, jsonErr } if response.Redeemed { @@ -74,7 +76,7 @@ func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (string account.Settings.Dashboard.DisplayName = response.AccountDisplayName - return response.APIKey, response.PublishableKey, account, nil + return &response, account, nil } count++ @@ -82,5 +84,5 @@ func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (string } - return "", "", nil, errors.New("exceeded max attempts") + return nil, nil, errors.New("exceeded max attempts") } diff --git a/pkg/login/poll_test.go b/pkg/login/poll_test.go index 4497329df..8ba8ea145 100644 --- a/pkg/login/poll_test.go +++ b/pkg/login/poll_test.go @@ -19,15 +19,15 @@ func TestRedeemed(t *testing.T) { atomic.AddUint64(&attempts, 1) - response := &pollAPIKeyResponse{ + response := &PollAPIKeyResponse{ Redeemed: false, } if atomic.LoadUint64(&attempts) == 2 { response.Redeemed = true response.AccountID = "acct_123" response.AccountDisplayName = "test_disp_name" - response.APIKey = "sk_test_123" - response.PublishableKey = "pk_test_123" + response.TestModeAPIKey = "sk_test_123" + response.TestModePublishableKey = "pk_test_123" } w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") @@ -35,10 +35,10 @@ func TestRedeemed(t *testing.T) { })) defer ts.Close() - apiKey, publishableKey, account, err := PollForKey(ts.URL, 1*time.Millisecond, 3) + response, account, err := PollForKey(ts.URL, 1*time.Millisecond, 3) require.NoError(t, err) - require.Equal(t, "sk_test_123", apiKey) - require.Equal(t, "pk_test_123", publishableKey) + require.Equal(t, "sk_test_123", response.TestModeAPIKey) + require.Equal(t, "pk_test_123", response.TestModePublishableKey) require.Equal(t, "acct_123", account.ID) require.Equal(t, "test_disp_name", account.Settings.Dashboard.DisplayName) require.Equal(t, uint64(2), atomic.LoadUint64(&attempts)) @@ -52,14 +52,14 @@ func TestRedeemedNoDisplayName(t *testing.T) { atomic.AddUint64(&attempts, 1) - response := &pollAPIKeyResponse{ + response := &PollAPIKeyResponse{ Redeemed: false, } if atomic.LoadUint64(&attempts) == 2 { response.Redeemed = true response.AccountID = "acct_123" - response.APIKey = "sk_test_123" - response.PublishableKey = "pk_test_123" + response.TestModeAPIKey = "sk_test_123" + response.TestModePublishableKey = "pk_test_123" } w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") @@ -67,10 +67,10 @@ func TestRedeemedNoDisplayName(t *testing.T) { })) defer ts.Close() - apiKey, publishableKey, account, err := PollForKey(ts.URL, 1*time.Millisecond, 3) + response, account, err := PollForKey(ts.URL, 1*time.Millisecond, 3) require.NoError(t, err) - require.Equal(t, "sk_test_123", apiKey) - require.Equal(t, "pk_test_123", publishableKey) + require.Equal(t, "sk_test_123", response.TestModeAPIKey) + require.Equal(t, "pk_test_123", response.TestModePublishableKey) require.Equal(t, "acct_123", account.ID) require.Equal(t, "", account.Settings.Dashboard.DisplayName) require.Equal(t, uint64(2), atomic.LoadUint64(&attempts)) @@ -84,7 +84,7 @@ func TestExceedMaxAttempts(t *testing.T) { atomic.AddUint64(&attempts, 1) - response := pollAPIKeyResponse{ + response := PollAPIKeyResponse{ Redeemed: false, } w.WriteHeader(http.StatusOK) @@ -93,10 +93,9 @@ func TestExceedMaxAttempts(t *testing.T) { })) defer ts.Close() - apiKey, publishableKey, account, err := PollForKey(ts.URL, 1*time.Millisecond, 3) + response, account, err := PollForKey(ts.URL, 1*time.Millisecond, 3) require.EqualError(t, err, "exceeded max attempts") - require.Empty(t, apiKey) - require.Empty(t, publishableKey) + require.Nil(t, response) require.Empty(t, account) require.Equal(t, uint64(3), atomic.LoadUint64(&attempts)) } @@ -113,10 +112,9 @@ func TestHTTPStatusError(t *testing.T) { })) defer ts.Close() - apiKey, publishableKey, account, err := PollForKey(ts.URL, 1*time.Millisecond, 3) + response, account, err := PollForKey(ts.URL, 1*time.Millisecond, 3) require.EqualError(t, err, "unexpected http status code: 500 ") - require.Empty(t, apiKey) - require.Empty(t, publishableKey) + require.Nil(t, response) require.Nil(t, account) require.Equal(t, uint64(1), atomic.LoadUint64(&attempts)) } @@ -126,10 +124,9 @@ func TestHTTPRequestError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) ts.Close() - apiKey, publishableKey, account, err := PollForKey(ts.URL, 1*time.Millisecond, 3) + response, account, err := PollForKey(ts.URL, 1*time.Millisecond, 3) require.Error(t, err) require.Contains(t, err.Error(), "connect: connection refused") - require.Empty(t, apiKey) - require.Empty(t, publishableKey) + require.Nil(t, response) require.Nil(t, account) } diff --git a/pkg/requests/base.go b/pkg/requests/base.go index 28727510c..18c621cf8 100644 --- a/pkg/requests/base.go +++ b/pkg/requests/base.go @@ -67,7 +67,7 @@ func (rb *Base) RunRequestsCmd(cmd *cobra.Command, args []string) error { return nil } - secretKey, err := rb.Profile.GetAPIKey() + secretKey, err := rb.Profile.GetAPIKey(false) if err != nil { return err }