diff --git a/cmd/rig/cmd/auth/login.go b/cmd/rig/cmd/auth/login.go index 715c6ab4b..a9c8afbd5 100644 --- a/cmd/rig/cmd/auth/login.go +++ b/cmd/rig/cmd/auth/login.go @@ -2,6 +2,7 @@ package auth import ( "context" + "fmt" "github.com/bufbuild/connect-go" "github.com/rigdev/rig-go-api/api/v1/authentication" @@ -10,41 +11,13 @@ import ( "github.com/rigdev/rig/cmd/common" "github.com/rigdev/rig/cmd/rig/cmd/base" "github.com/rigdev/rig/pkg/auth" + "github.com/rigdev/rig/pkg/errors" "github.com/rigdev/rig/pkg/uuid" "github.com/spf13/cobra" ) func AuthLogin(ctx context.Context, cmd *cobra.Command, client rig.Client, cfg *base.Config) error { - var identifier *model.UserIdentifier - var err error - if authUserIdentifier == "" { - identifier, err = common.PromptUserIndentifier() - if err != nil { - return err - } - } else { - identifier, err = common.ParseUserIdentifier(authUserIdentifier) - } - - if authPassword == "" { - pw, err := common.GetPasswordPrompt("Enter Password") - if err != nil { - return err - } - authPassword = string(pw) - } - - res, err := client.Authentication().Login(ctx, &connect.Request[authentication.LoginRequest]{ - Msg: &authentication.LoginRequest{ - Method: &authentication.LoginRequest_UserPassword{ - UserPassword: &authentication.UserPassword{ - Identifier: identifier, - Password: authPassword, - ProjectId: auth.RigProjectID.String(), - }, - }, - }, - }) + res, err := loginWithRetry(ctx, client, authUserIdentifier, authPassword, auth.RigProjectID.String()) if err != nil { return err } @@ -65,3 +38,59 @@ func AuthLogin(ctx context.Context, cmd *cobra.Command, client rig.Client, cfg * return nil } + +func loginWithRetry(ctx context.Context, client rig.Client, identifierStr, password, project string) (*connect.Response[authentication.LoginResponse], error) { + shouldPromptIdentifier := identifierStr == "" + shouldPromptPassword := password == "" + var identifier *model.UserIdentifier + for { + var err error + if shouldPromptIdentifier { + identifier, err = common.PromptUserIndentifier() + } else if identifier == nil { + identifier, err = common.ParseUserIdentifier(authUserIdentifier) + } + if err != nil { + return nil, err + } + + if shouldPromptPassword { + password, err = common.GetPasswordPrompt("Enter Password") + if err != nil { + return nil, err + } + } + + res, err := client.Authentication().Login(ctx, &connect.Request[authentication.LoginRequest]{ + Msg: &authentication.LoginRequest{ + Method: &authentication.LoginRequest_UserPassword{ + UserPassword: &authentication.UserPassword{ + Identifier: identifier, + Password: password, + ProjectId: project, + }, + }, + }, + }) + if err == nil { + return res, nil + } + + if errors.IsNotFound(err) { + if !shouldPromptIdentifier { + return nil, err + } + fmt.Println("User not found") + continue + } + + if errors.IsUnauthenticated(err) { + if !shouldPromptPassword { + return nil, err + } + shouldPromptIdentifier = false + fmt.Println("Wrong password") + continue + } + } +}