From 4016f721ced0b4149513079710ae9a4cd87d76e6 Mon Sep 17 00:00:00 2001 From: brandonl-stripe <52111783+brandonl-stripe@users.noreply.github.com> Date: Sat, 6 Jul 2019 00:10:14 -0700 Subject: [PATCH] New login configuration flow (#7) - go formatted cmd directory & make profile public - remove configure command moved under stripe auth --interactive - Create login command - Break cli configuration into own file (config.go) - test files with @jmuia-stripe --- cmd/configure.go | 153 ------------------ cmd/delete.go | 2 +- cmd/get.go | 2 +- cmd/listen.go | 6 +- cmd/login.go | 38 +++++ cmd/login_test.go | 66 ++++++++ cmd/post.go | 3 +- cmd/root.go | 16 +- cmd/root_test.go | 6 +- cmd/trigger.go | 5 +- go.mod | 1 + login/client_login.go | 67 ++++++++ login/interactive_login.go | 95 +++++++++++ .../interactive_login_test.go | 20 +-- login/poll.go | 6 +- profile/config.go | 85 ++++++++++ profile/config_test.go | 115 +++++++++++++ profile/profile.go | 16 +- 18 files changed, 512 insertions(+), 190 deletions(-) delete mode 100644 cmd/configure.go create mode 100644 cmd/login.go create mode 100644 cmd/login_test.go create mode 100644 login/interactive_login.go rename cmd/configure_test.go => login/interactive_login_test.go (75%) create mode 100644 profile/config.go create mode 100644 profile/config_test.go diff --git a/cmd/configure.go b/cmd/configure.go deleted file mode 100644 index 0bbeae4b..00000000 --- a/cmd/configure.go +++ /dev/null @@ -1,153 +0,0 @@ -package cmd - -import ( - "bufio" - "errors" - "fmt" - "io" - "os" - "path/filepath" - "strings" - "syscall" - - "github.com/spf13/cobra" - "github.com/spf13/viper" - "golang.org/x/crypto/ssh/terminal" - - "github.com/stripe/stripe-cli/ansi" - "github.com/stripe/stripe-cli/validators" - - log "github.com/sirupsen/logrus" -) - -type configureCmd struct { - cmd *cobra.Command -} - -func newConfigureCmd() *configureCmd { - cc := &configureCmd{} - - cc.cmd = &cobra.Command{ - Use: "configure", - Args: validators.NoArgs, - Short: "Configure the Stripe CLI", - Long: `Add your Stripe test secret API Key to connect to Stripe. - -By default, this will store the API key in the "default" namespace. You may -optionally provide a project name to store multiple API keys. - -The configure command will also prompt for a device name to identify the -connected computer. This is used to show who is currently connected to the -webhooks tunnel through the Stripe Dashboard. - -Run configuration: -$ stripe configure - -Configure for a specific project: -$ stripe configure --project-name rocket_rides`, - RunE: cc.runConfigureCmd, - } - - return cc -} - -func (cc *configureCmd) runConfigureCmd(cmd *cobra.Command, args []string) error { - configPath := profile.GetConfigFolder(os.Getenv("XDG_CONFIG_HOME")) - dotfilePath := filepath.Join(configPath, "config.toml") - - if _, err := os.Stat(configPath); os.IsNotExist(err) { - err = os.MkdirAll(configPath, os.ModePerm) - if err != nil { - return err - } - } - - apiKey, err := cc.getConfigureAPIKey(os.Stdin) - if err != nil { - return err - } - - deviceName := cc.getConfigureDeviceName(os.Stdin) - - log.WithFields(log.Fields{ - "prefix": "cmd.configureCmd.runConfigureCmd", - "path": dotfilePath, - }).Debug("Writing config file") - - viper.SetConfigType("toml") - viper.SetConfigFile(dotfilePath) - - viper.Set(profile.ProfileName+".secret_key", strings.TrimSpace(apiKey)) - viper.Set("default.device_name", strings.TrimSpace(deviceName)) - err = viper.WriteConfig() - if err != nil { - return err - } - - fmt.Println("You're configured and all set to get started") - - return nil -} - -func (cc *configureCmd) getConfigureAPIKey(input io.Reader) (string, error) { - fmt.Print("Enter your test mode secret API key: ") - apiKey, err := cc.securePrompt(input) - if err != nil { - return "", err - } - apiKey = strings.TrimSpace(apiKey) - if apiKey == "" { - return "", errors.New("API key is required, please provide your test mode secret API key") - } - err = validators.APIKey(apiKey) - if err != nil { - return "", err - } - - fmt.Printf("Your API key is: %s\n", cc.redactAPIKey(apiKey)) - - return apiKey, nil -} - -func (cc *configureCmd) getConfigureDeviceName(input io.Reader) string { - hostName, _ := os.Hostname() - reader := bufio.NewReader(input) - - color := ansi.Color(os.Stdout) - fmt.Printf("How would you like to identify this device in the Stripe Dashboard? [default: %s] ", color.Bold(color.Cyan(hostName))) - - deviceName, _ := reader.ReadString('\n') - if strings.TrimSpace(deviceName) == "" { - deviceName = hostName - } - - return deviceName -} - -// redactAPIKey returns a redacted version of API keys. The first 8 and last 4 -// characters are not redacted, everything else is replaced by "*" characters. -// -// It panics if the provided string has less than 12 characters. -func (cc *configureCmd) redactAPIKey(apiKey string) string { - var b strings.Builder - - b.WriteString(apiKey[0:8]) // #nosec G104 (gosec bug: https://github.com/securego/gosec/issues/267) - b.WriteString(strings.Repeat("*", len(apiKey)-12)) // #nosec G104 (gosec bug: https://github.com/securego/gosec/issues/267) - b.WriteString(apiKey[len(apiKey)-4:]) // #nosec G104 (gosec bug: https://github.com/securego/gosec/issues/267) - - return b.String() -} - -func (cc *configureCmd) securePrompt(input io.Reader) (string, error) { - if input == os.Stdin { - buf, err := terminal.ReadPassword(int(syscall.Stdin)) - if err != nil { - return "", err - } - fmt.Print("\n") - return string(buf), nil - } - - reader := bufio.NewReader(input) - return reader.ReadString('\n') -} diff --git a/cmd/delete.go b/cmd/delete.go index 0efcedbf..6d29cbca 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -14,7 +14,7 @@ func newDeleteCmd() *deleteCmd { gc := &deleteCmd{} gc.reqs.Method = "DELETE" - gc.reqs.Profile = profile + gc.reqs.Profile = Profile gc.reqs.Cmd = &cobra.Command{ Use: "delete", Args: validators.ExactArgs(1), diff --git a/cmd/get.go b/cmd/get.go index 2769f30e..ce0a348f 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -14,7 +14,7 @@ func newGetCmd() *getCmd { gc := &getCmd{} gc.reqs.Method = "GET" - gc.reqs.Profile = profile + gc.reqs.Profile = Profile gc.reqs.Cmd = &cobra.Command{ Use: "get", Args: validators.ExactArgs(1), diff --git a/cmd/listen.go b/cmd/listen.go index 3e459bf9..2187d74f 100644 --- a/cmd/listen.go +++ b/cmd/listen.go @@ -77,14 +77,14 @@ $ stripe listen --events charge.created,charge.updated --forward-to localhost:90 // Normally, this function would be listed alphabetically with the others declared in this file, // but since it's acting as the core functionality for the cmd above, I'm keeping it close. func (lc *listenCmd) runListenCmd(cmd *cobra.Command, args []string) error { - deviceName, err := profile.GetDeviceName() + deviceName, err := Profile.GetDeviceName() if err != nil { return err } endpointsMap := make(map[string][]string) - key, err := profile.GetSecretKey() + key, err := Profile.GetSecretKey() if err != nil { return err } @@ -135,7 +135,7 @@ func (lc *listenCmd) runListenCmd(cmd *cobra.Command, args []string) error { func (lc *listenCmd) getEndpointsFromAPI(secretKey string) requests.WebhookEndpointList { examples := requests.Examples{ - Profile: profile, + Profile: Profile, APIVersion: "2019-03-14", SecretKey: secretKey, } diff --git a/cmd/login.go b/cmd/login.go new file mode 100644 index 00000000..a2f2787c --- /dev/null +++ b/cmd/login.go @@ -0,0 +1,38 @@ +package cmd + +import ( + "github.com/spf13/cobra" + "github.com/stripe/stripe-cli/login" + "github.com/stripe/stripe-cli/validators" +) + +type loginCmd struct { + cmd *cobra.Command + interactive bool + url string +} + +func newLoginCmd() *loginCmd { + lc := &loginCmd{} + + lc.cmd = &cobra.Command{ + Use: "login", + Args: validators.NoArgs, + Short: "Log into your Stripe account", + Long: `Log into your Stripe account to write your configuration file`, + RunE: lc.runLoginCmd, + } + lc.cmd.Flags().BoolVarP(&lc.interactive, "interactive", "i", false, "interactive configuration mode") + lc.cmd.Flags().StringVarP(&lc.url, "url", "u", "", "Testing URL for login ") + lc.cmd.Flags().MarkHidden("url") + + return lc +} + + +func (lc *loginCmd) runLoginCmd(cmd *cobra.Command, args []string) error { + if lc.interactive { + return login.InteractiveLogin(Profile) + } + return login.Login(lc.url, Profile) +} diff --git a/cmd/login_test.go b/cmd/login_test.go new file mode 100644 index 00000000..124b96d8 --- /dev/null +++ b/cmd/login_test.go @@ -0,0 +1,66 @@ +package cmd + +import ( + "encoding/json" + "fmt" + "github.com/stretchr/testify/assert" + "github.com/stripe/stripe-cli/login" + "github.com/stripe/stripe-cli/profile" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + +) + +func TestLogin(t *testing.T) { + configFile := filepath.Join(os.TempDir(), "stripe", "config.toml") + p := profile.Profile{ + Color: "auto", + ConfigFile: configFile, + LogLevel: "info", + ProfileName: "tests", + DeviceName: "st-testing", + } + + + + var pollURL string + var browserURL string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "auth") { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + expectedLinks := login.Links{ + BrowserURL: browserURL, + PollURL: pollURL, + VerificationCode: "dinosaur-pineapple-polkadot", + } + json.NewEncoder(w).Encode(expectedLinks) + } + if strings.Contains(r.URL.Path,"browser") { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "text/html") + w.Write([]byte("")) + + } + if strings.Contains(r.URL.Path,"poll") { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + data := []byte(`{"redeemed": true, "account_id": "acct_123", "testmode_key_secret": "sk_test_1234"}`) + fmt.Println(string(data)) + w.Write(data) + } + })) + defer ts.Close() + + authURL := fmt.Sprintf( "%s%s", ts.URL, "/auth") + pollURL = fmt.Sprintf( "%s%s", ts.URL, "/poll") + browserURL = fmt.Sprintf( "%s%s", ts.URL, "/browser") + + err := login.Login(authURL, p) + assert.NoError(t, err) +} diff --git a/cmd/post.go b/cmd/post.go index 274cc0d8..74587ff8 100644 --- a/cmd/post.go +++ b/cmd/post.go @@ -14,7 +14,7 @@ func newPostCmd() *postCmd { gc := &postCmd{} gc.reqs.Method = "POST" - gc.reqs.Profile = profile + gc.reqs.Profile = Profile gc.reqs.Cmd = &cobra.Command{ Use: "post", Args: validators.ExactArgs(1), @@ -32,6 +32,7 @@ $ stripe post /payment_intents -d amount=2000 -d currency=usd -d payment_method_ RunE: gc.reqs.RunRequestsCmd, } + gc.reqs.InitFlags() return gc diff --git a/cmd/root.go b/cmd/root.go index 61d83353..9cc791a4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -10,7 +10,8 @@ import ( "github.com/stripe/stripe-cli/version" ) -var profile prof.Profile +// Profile is the cli configuration for the user +var Profile prof.Profile // rootCmd represents the base command when called without any subcommands var rootCmd = &cobra.Command{ @@ -82,20 +83,21 @@ Use "{{.CommandPath}} [command] --help" for more information about a command.{{e } func init() { - cobra.OnInitialize(profile.InitConfig) + cobra.OnInitialize(Profile.InitConfig) rootCmd.PersistentFlags().String("api-key", "", "Your test mode API secret key to use for the command") - rootCmd.PersistentFlags().StringVar(&profile.Color, "color", "auto", "turn on/off color output (on, off, auto)") - rootCmd.PersistentFlags().StringVar(&profile.ConfigFile, "config", "", "config file (default is $HOME/.config/stripe/config.toml)") - rootCmd.PersistentFlags().StringVar(&profile.ProfileName, "project-name", "default", "the project name to read from for config") - rootCmd.PersistentFlags().StringVar(&profile.LogLevel, "log-level", "info", "log level (debug, info, warn, error)") + rootCmd.PersistentFlags().StringVar(&Profile.Color, "color", "auto", "turn on/off color output (on, off, auto)") + rootCmd.PersistentFlags().StringVar(&Profile.ConfigFile, "config", "", "config file (default is $HOME/.config/stripe/config.toml)") + rootCmd.PersistentFlags().StringVar(&Profile.ProfileName, "project-name", "default", "the project name to read from for config") + rootCmd.PersistentFlags().StringVar(&Profile.LogLevel, "log-level", "info", "log level (debug, info, warn, error)") + rootCmd.PersistentFlags().StringVar(&Profile.DeviceName, "device-name", "", "device name") viper.BindPFlag("secret_key", rootCmd.PersistentFlags().Lookup("api-key")) // #nosec G104 viper.SetEnvPrefix("stripe") viper.AutomaticEnv() // read in environment variables that match rootCmd.AddCommand(newCompletionCmd().cmd) - rootCmd.AddCommand(newConfigureCmd().cmd) + rootCmd.AddCommand(newLoginCmd().cmd) rootCmd.AddCommand(newDeleteCmd().reqs.Cmd) rootCmd.AddCommand(newGetCmd().reqs.Cmd) rootCmd.AddCommand(newListenCmd().cmd) diff --git a/cmd/root_test.go b/cmd/root_test.go index e9b00922..1bdfc2f2 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -3,12 +3,12 @@ package cmd import ( "testing" - homedir "github.com/mitchellh/go-homedir" + "github.com/mitchellh/go-homedir" "github.com/stretchr/testify/assert" ) func TestGetPathNoXDG(t *testing.T) { - actual := profile.GetConfigFolder("") + actual := Profile.GetConfigFolder("") expected, err := homedir.Dir() expected += "/.config/stripe" @@ -17,7 +17,7 @@ func TestGetPathNoXDG(t *testing.T) { } func TestGetPathXDG(t *testing.T) { - actual := profile.GetConfigFolder("/some/xdg/path") + actual := Profile.GetConfigFolder("/some/xdg/path") expected := "/some/xdg/path/stripe" assert.Equal(t, actual, expected) diff --git a/cmd/trigger.go b/cmd/trigger.go index 8423705d..0da3e9eb 100644 --- a/cmd/trigger.go +++ b/cmd/trigger.go @@ -50,17 +50,18 @@ Supported events: }, } + return tc } func triggerEvent(event string) error { - secretKey, err := profile.GetSecretKey() + secretKey, err := Profile.GetSecretKey() if err != nil { return err } examples := requests.Examples{ - Profile: profile, + Profile: Profile, APIUrl: stripeURL, APIVersion: apiVersion, SecretKey: secretKey, diff --git a/go.mod b/go.mod index ebf71825..8f2ea819 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/stripe/stripe-cli go 1.12 require ( + github.com/BurntSushi/toml v0.3.1 github.com/briandowns/spinner v0.0.0-20190319032542-ac46072a5a91 github.com/fatih/color v1.7.0 // indirect github.com/golang/protobuf v1.3.1 // indirect diff --git a/login/client_login.go b/login/client_login.go index d934fbca..adfa0380 100644 --- a/login/client_login.go +++ b/login/client_login.go @@ -3,10 +3,14 @@ package login import ( "encoding/json" "fmt" + "github.com/stripe/stripe-cli/profile" "github.com/stripe/stripe-cli/stripeauth" + "github.com/stripe/stripe-cli/validators" "io/ioutil" "net/http" "net/url" + "os/exec" + "runtime" ) const stripeCLIAuthURL = "https://dashboard.stripe.com/stripecli/auth" @@ -18,6 +22,69 @@ type Links struct { VerificationCode string `json:"verification_code"` } +//TODO +/* +4. Observability and associated alerting? Business metrics (how many users use this flow)? +5. Rate limiting for each operation? +6. Audit trail for key generation +7. Move configuration changes to profile package +*/ + +// Login function is used to obtain credentials via stripe dashboard. +func Login(url string, profile profile.Profile) error { + + links, err := getLinks(url, profile.DeviceName) + if err != nil { + return err + } + + fmt.Printf("Opening login link %s in your browser.\nVerification code is %s\n", links.BrowserURL, links.VerificationCode) + + urlErr := openBrowser(links.BrowserURL) + if urlErr != nil { + return urlErr + } + + //Call poll function + apiKey, err := PollForKey(links.PollURL, 0, 0) + if err != nil { + return err + } + + validateErr := validators.APIKey(apiKey) + if validateErr != nil { + return validateErr + } + + configErr := profile.ConfigureProfile(apiKey) + if configErr != nil { + return configErr + } + + return nil +} + +func openBrowser(url string) error { + var err error + + switch runtime.GOOS { + case "linux": + err = exec.Command("xdg-open", url).Start() + case "windows": + err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + err = exec.Command("open", url).Start() + default: + err = fmt.Errorf("unsupported platform") + } + + if err != nil { + return err + } + + return nil +} + func getLinks(authURL string, deviceName string) (*Links, error) { client := stripeauth.NewHTTPClient("") diff --git a/login/interactive_login.go b/login/interactive_login.go new file mode 100644 index 00000000..0db3009c --- /dev/null +++ b/login/interactive_login.go @@ -0,0 +1,95 @@ +package login + +import ( + "bufio" + "errors" + "fmt" + "github.com/stripe/stripe-cli/ansi" + "github.com/stripe/stripe-cli/profile" + "github.com/stripe/stripe-cli/validators" + "golang.org/x/crypto/ssh/terminal" + "io" + "os" + "strings" + "syscall" +) + +// InteractiveLogin lets the user set configuration on the command line +func InteractiveLogin(profile profile.Profile) error { + + apiKey, err := getConfigureAPIKey(os.Stdin) + if err != nil { + return err + } + + profile.DeviceName = getConfigureDeviceName(os.Stdin) + + configErr := profile.ConfigureProfile(apiKey) + if configErr != nil { + return configErr + } + + return nil +} + +func getConfigureAPIKey(input io.Reader) (string, error) { + apiKey, err := securePrompt(input) + if err != nil { + return "", err + } + apiKey = strings.TrimSpace(apiKey) + if apiKey == "" { + return "", errors.New("API key is required, please provide your test mode secret API key") + } + err = validators.APIKey(apiKey) + if err != nil { + return "", err + } + + fmt.Printf("Your API key is: %s\n", redactAPIKey(apiKey)) + + return apiKey, nil +} + +func getConfigureDeviceName(input io.Reader) string { + hostName, _ := os.Hostname() + reader := bufio.NewReader(input) + + color := ansi.Color(os.Stdout) + fmt.Printf("How would you like to identify this device in the Stripe Dashboard? [default: %s] ", color.Bold(color.Cyan(hostName))) + + deviceName, _ := reader.ReadString('\n') + if strings.TrimSpace(deviceName) == "" { + deviceName = hostName + } + + return deviceName +} + +// redactAPIKey returns a redacted version of API keys. The first 8 and last 4 +// characters are not redacted, everything else is replaced by "*" characters. +// +// It panics if the provided string has less than 12 characters. +func redactAPIKey(apiKey string) string { + var b strings.Builder + + b.WriteString(apiKey[0:8]) // #nosec G104 (gosec bug: https://github.com/securego/gosec/issues/267) + b.WriteString(strings.Repeat("*", len(apiKey)-12)) // #nosec G104 (gosec bug: https://github.com/securego/gosec/issues/267) + b.WriteString(apiKey[len(apiKey)-4:]) // #nosec G104 (gosec bug: https://github.com/securego/gosec/issues/267) + + return b.String() +} + +func securePrompt(input io.Reader) (string, error) { + if input == os.Stdin { + buf, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + return "", err + } + fmt.Print("\n") + return string(buf), nil + } + + reader := bufio.NewReader(input) + return reader.ReadString('\n') +} diff --git a/cmd/configure_test.go b/login/interactive_login_test.go similarity index 75% rename from cmd/configure_test.go rename to login/interactive_login_test.go index c76d690b..b6a132e2 100644 --- a/cmd/configure_test.go +++ b/login/interactive_login_test.go @@ -1,19 +1,17 @@ -package cmd +package login import ( + "github.com/stretchr/testify/assert" "os" "strings" "testing" - - "github.com/stretchr/testify/assert" ) func TestAPIKeyInput(t *testing.T) { expectedKey := "sk_test_foo1234" - cc := newConfigureCmd() keyInput := strings.NewReader(expectedKey + "\n") - actualKey, err := cc.getConfigureAPIKey(keyInput) + actualKey, err := getConfigureAPIKey(keyInput) assert.Equal(t, expectedKey, actualKey) assert.Nil(t, err) @@ -23,9 +21,8 @@ func TestAPIKeyInputEmpty(t *testing.T) { expectedKey := "" expectedErrorString := "API key is required, please provide your test mode secret API key" - cc := newConfigureCmd() keyInput := strings.NewReader(expectedKey + "\n") - actualKey, err := cc.getConfigureAPIKey(keyInput) + actualKey, err := getConfigureAPIKey(keyInput) assert.Equal(t, expectedKey, actualKey) assert.NotNil(t, err) @@ -37,9 +34,8 @@ func TestAPIKeyInputLivemode(t *testing.T) { livemodeKey := "sk_live_foo123" expectedErrorString := "the CLI only supports using a test mode key" - cc := newConfigureCmd() keyInput := strings.NewReader(livemodeKey + "\n") - actualKey, err := cc.getConfigureAPIKey(keyInput) + actualKey, err := getConfigureAPIKey(keyInput) assert.Equal(t, expectedKey, actualKey) assert.NotNil(t, err) @@ -50,8 +46,7 @@ func TestDeviceNameInput(t *testing.T) { expectedDeviceName := "Bender's Laptop" deviceNameInput := strings.NewReader(expectedDeviceName) - cc := newConfigureCmd() - actualDeviceName := cc.getConfigureDeviceName(deviceNameInput) + actualDeviceName := getConfigureDeviceName(deviceNameInput) assert.Equal(t, expectedDeviceName, actualDeviceName) } @@ -60,8 +55,7 @@ func TestDeviceNameAutoDetect(t *testing.T) { hostName, _ := os.Hostname() deviceNameInput := strings.NewReader("") - cc := newConfigureCmd() - actualDeviceName := cc.getConfigureDeviceName(deviceNameInput) + actualDeviceName := getConfigureDeviceName(deviceNameInput) assert.Equal(t, hostName, actualDeviceName) } diff --git a/login/poll.go b/login/poll.go index 4928bd66..61d3423a 100644 --- a/login/poll.go +++ b/login/poll.go @@ -3,6 +3,7 @@ package login import ( "encoding/json" "errors" + "fmt" "github.com/stripe/stripe-cli/stripeauth" "io/ioutil" @@ -16,7 +17,7 @@ const intervalDefault = 1 * time.Second type pollAPIKeyResponse struct { Redeemed bool `json:"redeemed"` AccountID string `json:"account_id"` - APIKey string `json:"api_key"` + APIKey string `json:"testmode_key_secret"` } // PollForKey polls Stripe at the specified interval until either the API key is available or we've reached the max attempts. @@ -31,7 +32,7 @@ func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (string client := stripeauth.NewHTTPClient("") - var count = 0 + var count= 0 for count < maxAttempts { res, err := client.Get(pollURL) if err != nil { @@ -59,6 +60,7 @@ func PollForKey(pollURL string, interval time.Duration, maxAttempts int) (string count++ time.Sleep(interval) + } return "", errors.New("exceeded max attempts") diff --git a/profile/config.go b/profile/config.go new file mode 100644 index 00000000..f451436a --- /dev/null +++ b/profile/config.go @@ -0,0 +1,85 @@ +package profile + +import ( + "bytes" + "fmt" + "github.com/BurntSushi/toml" + "github.com/spf13/viper" + "os" + "path/filepath" + "strings" +) + +// ConfigureProfile creates a profile when logging in +func (p *Profile) ConfigureProfile(apiKey string) error { + runtimeViper, removeErr := removeKey(viper.GetViper(), "secret_key") + if removeErr != nil { + return removeErr + } + + writeErr := p.writeConfig(runtimeViper, apiKey) + if writeErr != nil { + return writeErr + } + + fmt.Println("You're configured and all set to get started") + + return nil +} + +func (p *Profile) writeConfig(runtimeViper *viper.Viper, apiKey string) error { + configFile := viper.ConfigFileUsed() + + err := makePath(configFile) + if err != nil { + return err + } + + runtimeViper.SetConfigFile(configFile) + + // Ensure we preserve the config file type + runtimeViper.SetConfigType(filepath.Ext(configFile)) + + runtimeViper.Set(p.ProfileName +".device_name", strings.TrimSpace(p.DeviceName)) + runtimeViper.Set(p.ProfileName +".secret_key", strings.TrimSpace(apiKey)) + + runtimeViper.MergeInConfig() + runtimeViper.WriteConfig() + + return nil +} + +// Temporary workaround until https://github.com/spf13/viper/pull/519 can remove a key from viper +func removeKey(v *viper.Viper, key string) (*viper.Viper, error) { + configMap := v.AllSettings() + + delete(configMap, key) + + buf := new(bytes.Buffer) + encodeErr := toml.NewEncoder(buf).Encode(configMap) + if encodeErr != nil { + return nil, encodeErr + } + + nv := viper.New() + nv.SetConfigType("toml") // hint to viper that we've encoded the data as toml + + err := nv.ReadConfig(buf) + if err != nil { + return nil, err + } + + return nv, nil +} + +func makePath(path string) error { + dir := filepath.Dir(path) + + if _, err := os.Stat(dir); os.IsNotExist(err) { + err = os.MkdirAll(dir, os.ModePerm) + if err != nil { + return err + } + } + return nil +} diff --git a/profile/config_test.go b/profile/config_test.go new file mode 100644 index 00000000..1e3b2217 --- /dev/null +++ b/profile/config_test.go @@ -0,0 +1,115 @@ +package profile + +import ( + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "io/ioutil" + "os" + "path/filepath" + "testing" +) + + +func TestRemoveKey(t *testing.T) { + v := viper.New() + v.Set("remove", "me") + v.Set("stay", "here") + + nv, err := removeKey(v, "remove") + assert.NoError(t, err) + + assert.EqualValues(t, []string{"stay"}, nv.AllKeys()) + assert.ElementsMatch(t, []string{"stay", "remove"}, v.AllKeys()) +} + + +func TestWriteConfig(t *testing.T) { + configFile := filepath.Join(os.TempDir(), "stripe", "config.toml") + p := &Profile{ + Color: "auto", + ConfigFile: configFile, + LogLevel: "info", + ProfileName: "tests", + DeviceName: "st-testing", + } + + p.InitConfig() + + apiKey := "sk_test_123" + v := viper.New() + + err := p.writeConfig(v, apiKey) + assert.NoError(t, err) + + assert.FileExists(t, p.ConfigFile) + + configValues := helperLoadBytes(t, p.ConfigFile) + expectedConfig := ` +[tests] + device_name = "st-testing" + secret_key = "sk_test_123" +` + assert.EqualValues(t, expectedConfig, string(configValues)) + + cleanUp(p.ConfigFile) +} + + +func TestWriteConfigMerge(t *testing.T) { + configFile := filepath.Join(os.TempDir(), "stripe", "config.toml") + p := &Profile{ + Color: "auto", + ConfigFile: configFile, + LogLevel: "info", + ProfileName: "tests", + DeviceName: "st-testing", + } + p.InitConfig() + v := viper.New() + writeErr := writeFile(v, p) + assert.NoError(t, writeErr) + assert.FileExists(t, p.ConfigFile) + + p.ProfileName = "tests-merge" + writeErrTwo := writeFile(v, p) + assert.NoError(t, writeErrTwo) + assert.FileExists(t, p.ConfigFile) + + configValues := helperLoadBytes(t, p.ConfigFile) + expectedConfig := ` +[tests] + device_name = "st-testing" + secret_key = "sk_test_123" + +[tests-merge] + device_name = "st-testing" + secret_key = "sk_test_123" +` + + assert.EqualValues(t, expectedConfig, string(configValues)) + + cleanUp(p.ConfigFile) + +} + +func writeFile(v *viper.Viper, p *Profile) error { + apiKey := "sk_test_123" + + err := p.writeConfig(v, apiKey) + + return err + +} + +func helperLoadBytes(t *testing.T, name string) []byte { + bytes, err := ioutil.ReadFile(name) + if err != nil { + t.Fatal(err) + } + return bytes +} + + +func cleanUp(file string) { + os.Remove(file) +} diff --git a/profile/profile.go b/profile/profile.go index 0f826b00..3f79b408 100644 --- a/profile/profile.go +++ b/profile/profile.go @@ -22,6 +22,7 @@ type Profile struct { ConfigFile string LogLevel string ProfileName string + DeviceName string } // GetDeviceName returns the configured device name @@ -126,11 +127,10 @@ func (p *Profile) InitConfig() { // Use config file from the flag. viper.SetConfigFile(p.ConfigFile) } else { + configFolder := p.GetConfigFolder(os.Getenv("XDG_CONFIG_HOME")) + configFile := filepath.Join(configFolder, "config.toml") viper.SetConfigType("toml") - // Search config in home directory or xdg path with name "config.toml". - viper.AddConfigPath(p.GetConfigFolder(os.Getenv("XDG_CONFIG_HOME"))) - // TODO(tomer) - support overriding with configs in local dir - viper.SetConfigName("config") + viper.SetConfigFile(configFile) } // If a config file is found, read it in. @@ -140,4 +140,12 @@ func (p *Profile) InitConfig() { "path": viper.ConfigFileUsed(), }).Debug("Using config file") } + + if p.DeviceName == "" { + deviceName, err := os.Hostname() + if err != nil { + deviceName = "unknown" + } + p.DeviceName = deviceName + } }