Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions cmd/src/login_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,40 @@ import (
"github.com/sourcegraph/src-cli/internal/oauth"
)

var loadStoredOAuthToken = oauth.LoadToken
var (
loadStoredOAuthToken = oauth.LoadToken
storeOAuthToken = oauth.StoreToken
)

func runOAuthLogin(ctx context.Context, p loginParams) error {
client, err := oauthLoginClient(ctx, p)
client, loadedFromStore, err := oauthLoginClient(ctx, p)
if err != nil {
printLoginProblem(p.out, fmt.Sprintf("OAuth Device flow authentication failed: %s", err))
fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL))
return cmderrors.ExitCode1
}

if loadedFromStore {
username, validateErr := currentUsername(ctx, client)
if validateErr == nil && username != "" {
printAuthenticatedUser(p.out, username, p.cfg.endpointURL)
fmt.Fprintln(p.out)
fmt.Fprint(p.out, "✔︎ Authenticated with OAuth credentials")
fmt.Fprintln(p.out)
return nil
}

fmt.Fprintln(p.out)
fmt.Fprintln(p.out, "⚠️ Warning: Stored OAuth credentials could not be verified. Starting a new OAuth device flow.")

client, err = newOAuthLoginClient(ctx, p)
if err != nil {
printLoginProblem(p.out, fmt.Sprintf("OAuth Device flow authentication failed: %s", err))
fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL))
return cmderrors.ExitCode1
}
}

if err := validateCurrentUser(ctx, client, p.out, p.cfg.endpointURL); err != nil {
return err
}
Expand All @@ -38,18 +62,23 @@ func runOAuthLogin(ctx context.Context, p loginParams) error {
// oauthLoginClient returns a api.Client with the OAuth token set. It will check secret storage for a token
// and use it if one is present.
// If no token is found, it will start a OAuth Device flow to get a token and storage in secret storage.
func oauthLoginClient(ctx context.Context, p loginParams) (api.Client, error) {
// if we have a stored token, used it. Otherwise run the device flow
func oauthLoginClient(ctx context.Context, p loginParams) (api.Client, bool, error) {
// if we have a stored token, use it. Otherwise run the device flow
if token, err := loadStoredOAuthToken(ctx, p.cfg.endpointURL); err == nil {
return newOAuthAPIClient(p, token), nil
return newOAuthAPIClient(p, token), true, nil
}

client, err := newOAuthLoginClient(ctx, p)
return client, false, err
}

func newOAuthLoginClient(ctx context.Context, p loginParams) (api.Client, error) {
token, err := runOAuthDeviceFlow(ctx, p.cfg.endpointURL, p.out, p.oauthClient)
if err != nil {
return nil, err
}

if err := oauth.StoreToken(ctx, token); err != nil {
if err := storeOAuthToken(ctx, token); err != nil {
fmt.Fprintln(p.out)
fmt.Fprintf(p.out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err)
}
Expand Down
93 changes: 92 additions & 1 deletion cmd/src/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,80 @@ func TestLogin(t *testing.T) {
t.Errorf("got output %q, want %q", gotOut, wantOut)
}
})

t.Run("invalid stored oauth token restarts device flow", func(t *testing.T) {
var authHeaders []string
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeaders = append(authHeaders, r.Header.Get("Authorization"))
if r.Header.Get("Authorization") != "Bearer new-oauth-token" {
http.Error(w, "", http.StatusUnauthorized)
return
}
fmt.Fprintln(w, `{"data":{"currentUser":{"username":"alice"}}}`)
}))
defer s.Close()

restoreStoredOAuthLoader(t, func(_ context.Context, _ *url.URL) (*oauth.Token, error) {
return &oauth.Token{
Endpoint: s.URL,
ClientID: oauth.DefaultClientID,
AccessToken: "old-oauth-token",
ExpiresAt: time.Now().Add(time.Hour),
}, nil
})
restoreOAuthTokenStore(t, func(context.Context, *oauth.Token) error { return nil })

u, _ := url.ParseRequestURI(s.URL)
startCalled := false
pollCalled := false
var out bytes.Buffer
err := loginCmd(context.Background(), loginParams{
cfg: &config{endpointURL: u},
client: (&config{endpointURL: u}).apiClient(nil, io.Discard),
out: &out,
oauthClient: fakeOAuthClient{
startCalled: &startCalled,
deviceResp: &oauth.DeviceAuthResponse{
DeviceCode: "device-code",
ExpiresIn: 60,
},
pollCalled: &pollCalled,
pollResp: &oauth.TokenResponse{
AccessToken: "new-oauth-token",
ExpiresIn: 3600,
TokenType: "Bearer",
},
},
})
if err != nil {
t.Fatal(err)
}
if !startCalled || !pollCalled {
t.Fatal("expected invalid stored oauth token to restart device flow")
}
if len(authHeaders) != 2 || authHeaders[0] != "Bearer old-oauth-token" || authHeaders[1] != "Bearer new-oauth-token" {
t.Fatalf("Authorization headers = %q, want old token then new token", authHeaders)
}
gotOut := out.String()
for _, want := range []string{
"⚠️ Warning: Stored OAuth credentials could not be verified. Starting a new OAuth device flow.",
"Waiting for authorization... DONE",
"✔︎ Authenticated as alice on " + s.URL,
"✔︎ Authenticated with OAuth credentials",
} {
if !strings.Contains(gotOut, want) {
t.Errorf("got output %q, want it to contain %q", gotOut, want)
}
}
})
}

type fakeOAuthClient struct {
startErr error
startCalled *bool
deviceResp *oauth.DeviceAuthResponse
pollCalled *bool
pollResp *oauth.TokenResponse
}

func (f fakeOAuthClient) ClientID() string {
Expand All @@ -156,10 +225,22 @@ func (f fakeOAuthClient) Start(context.Context, *url.URL, []string) (*oauth.Devi
if f.startCalled != nil {
*f.startCalled = true
}
return nil, f.startErr
if f.startErr != nil {
return nil, f.startErr
}
if f.deviceResp != nil {
return f.deviceResp, nil
}
return nil, fmt.Errorf("unexpected call to Start")
}

func (f fakeOAuthClient) Poll(context.Context, *url.URL, string, time.Duration, int) (*oauth.TokenResponse, error) {
if f.pollCalled != nil {
*f.pollCalled = true
}
if f.pollResp != nil {
return f.pollResp, nil
}
return nil, fmt.Errorf("unexpected call to Poll")
}

Expand Down Expand Up @@ -242,3 +323,13 @@ func restoreStoredOAuthLoader(t *testing.T, loader func(context.Context, *url.UR
loadStoredOAuthToken = prev
})
}

func restoreOAuthTokenStore(t *testing.T, store func(context.Context, *oauth.Token) error) {
t.Helper()

prev := storeOAuthToken
storeOAuthToken = store
t.Cleanup(func() {
storeOAuthToken = prev
})
}
31 changes: 23 additions & 8 deletions cmd/src/login_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@ func runValidatedLogin(ctx context.Context, p loginParams) error {
}

func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, endpointURL *url.URL) error {
query := `query CurrentUser { currentUser { username } }`
var result struct {
CurrentUser *struct{ Username string }
}
if _, err := client.NewRequest(query, nil).Do(ctx, &result); err != nil {
username, err := currentUsername(ctx, client)
if err != nil {
if strings.HasPrefix(err.Error(), "error: 401 Unauthorized") || strings.HasPrefix(err.Error(), "error: 403 Forbidden") {
printLoginProblem(out, "Invalid access token.")
} else {
Expand All @@ -31,14 +28,32 @@ func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer,
return cmderrors.ExitCode1
}

if result.CurrentUser == nil {
if username == "" {
// This should never happen; we verified there is an access token, so there should always be
// a user.
printLoginProblem(out, fmt.Sprintf("Unable to determine user on %s.", endpointURL))
return cmderrors.ExitCode1
}
printAuthenticatedUser(out, username, endpointURL)
return nil
}

func printAuthenticatedUser(out io.Writer, username string, endpointURL *url.URL) {
fmt.Fprintln(out)
fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointURL)
fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", username, endpointURL)
fmt.Fprintln(out)
return nil
}

func currentUsername(ctx context.Context, client api.Client) (string, error) {
query := `query CurrentUser { currentUser { username } }`
var result struct {
CurrentUser *struct{ Username string }
}
if _, err := client.NewRequest(query, nil).Do(ctx, &result); err != nil {
return "", err
}
if result.CurrentUser == nil {
return "", nil
}
return result.CurrentUser.Username, nil
}
Loading