Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 30 additions & 88 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"

"connectrpc.com/connect"
"github.com/google/uuid"
Expand All @@ -26,9 +24,6 @@ import (
"golang.org/x/oauth2"
)

const Auth0ClientId = "j3LylZtIosVPZtouKI8WuVHmE6Lluva1"
const Auth0Domain = "om-prod.eu.auth0.com"

var logLevel string

//go:generate sh -c "echo -n $(git describe --tags --long) > commit.txt"
Expand Down Expand Up @@ -147,6 +142,7 @@ func readLocalToken(homeDir string, expectedScopes []string) (string, []string,
}
}

log.Debugf("Using local token from %v", path)
return token.AccessToken, currentScopes, nil
}

Expand Down Expand Up @@ -204,92 +200,28 @@ func ensureToken(ctx context.Context, requiredScopes []string) (context.Context,
// keep replacing it
requestScopes := append(requiredScopes, localScopes...)

// Authenticate using the oauth resource owner password flow
// Authenticate using the oauth device authorization flow
config := oauth2.Config{
ClientID: Auth0ClientId,
Scopes: requestScopes,
ClientID: viper.GetString("cli-auth0-client-id"),
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("https://%v/authorize", Auth0Domain),
TokenURL: fmt.Sprintf("https://%v/oauth/token", Auth0Domain),
AuthURL: fmt.Sprintf("https://%v/authorize", viper.GetString("cli-auth0-domain")),
TokenURL: fmt.Sprintf("https://%v/oauth/token", viper.GetString("cli-auth0-domain")),
DeviceAuthURL: fmt.Sprintf("https://%v/oauth/device/code", viper.GetString("cli-auth0-domain")),
},
RedirectURL: "http://127.0.0.1:7837/oauth/callback",
Scopes: requestScopes,
}

tokenChan := make(chan *oauth2.Token, 1)
// create a random token for this exchange
oAuthStateString := uuid.New().String()

// Start the web server to listen for the callback
handler := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

queryParts, err := url.ParseQuery(r.URL.RawQuery)
if err != nil {
log.WithContext(ctx).WithError(err).WithFields(log.Fields{
"url": r.URL,
}).Error("Failed to parse url")
}

// Use the authorization code that is pushed to the redirect
// URL.
code := queryParts["code"][0]
log.WithContext(ctx).Debugf("Got code: %v", code)

state := queryParts["state"][0]
log.WithContext(ctx).Debugf("Got state: %v", state)

if state != oAuthStateString {
log.WithContext(ctx).Errorf("Invalid state, expected %v, got %v", oAuthStateString, state)
return
}

// Exchange will do the handshake to retrieve the initial access token.
log.WithContext(ctx).Debug("Exchanging code for token")
tok, err := config.Exchange(ctx, code)
if err != nil {
log.WithContext(ctx).Error(err)
return
}
log.WithContext(ctx).Debug("Got token")

tokenChan <- tok

// show success page
msg := "<p><strong>Success!</strong></p>"
msg = msg + "<p>You are authenticated and can now return to the CLI.</p>"
fmt.Fprint(w, msg)
deviceCode, err := config.DeviceAuth(ctx, oauth2.SetAuthURLParam("audience", "https://api.overmind.tech"))
if err != nil {
return ctx, fmt.Errorf("error getting device code: %w", err)
}

audienceOption := oauth2.SetAuthURLParam("audience", "https://api.overmind.tech")

u := config.AuthCodeURL(oAuthStateString, oauth2.AccessTypeOnline, audienceOption)
log.WithContext(ctx).Infof("Follow this link to authenticate: %v", Underline.TextStyle(u))
fmt.Printf("Go to %v and verify this code: %v\n", deviceCode.VerificationURIComplete, deviceCode.UserCode)

// Start the webserver
log.WithContext(ctx).Trace("Starting webserver to listen for callback, press Ctrl+C to cancel")
srv := &http.Server{Addr: ":7837", ReadHeaderTimeout: 30 * time.Second}
http.HandleFunc("/oauth/callback", handler)

go func() {
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
// unexpected error. port in use?
log.WithContext(ctx).Errorf("HTTP Server error: %v", err)
}
}()

// Wait for the token or cancel
var token *oauth2.Token
select {
case token = <-tokenChan:
// Keep working
case <-ctx.Done():
return ctx, ctx.Err()
}

// Stop the server
err = srv.Shutdown(ctx)
token, err := config.DeviceAccessToken(ctx, deviceCode)
if err != nil {
log.WithContext(ctx).WithError(err).Warn("failed to shutdown auth callback server, but continuing anyway")
fmt.Printf(": %v\n", err)
return ctx, fmt.Errorf("Error exchanging Device Code for for access token: %w", err)
}

// Check that we actually got the claims we asked for. If you don't have
Expand Down Expand Up @@ -438,15 +370,17 @@ func init() {
log.WithError(err).Fatal("could not bind api key to env")
}

// tracing
// internal configs
rootCmd.PersistentFlags().String("cli-auth0-client-id", "QMfjMww3x4QTpeXiuRtMV3JIQkx6mZa4", "OAuth Client ID to use when connecting with auth0")
rootCmd.PersistentFlags().String("cli-auth0-domain", "om-prod.eu.auth0.com", "Auth0 domain to connect to")
rootCmd.PersistentFlags().String("honeycomb-api-key", "", "If specified, configures opentelemetry libraries to submit traces to honeycomb. This requires --otel to be set.")
// Mark this as hidden. This means that it will still be parsed of supplied,

// Mark these as hidden. This means that it will still be parsed of supplied,
// and we will still look for it in the environment, but it won't be shown
// in the help
err = rootCmd.PersistentFlags().MarkHidden("honeycomb-api-key")
if err != nil {
log.WithError(err).Fatal("could not mark `honeycomb-api-key` flag as hidden")
}
must(rootCmd.PersistentFlags().MarkHidden("cli-auth0-client-id"))
must(rootCmd.PersistentFlags().MarkHidden("cli-auth0-domain"))
must(rootCmd.PersistentFlags().MarkHidden("honeycomb-api-key"))

// Create groups
rootCmd.AddGroup(&cobra.Group{
Expand Down Expand Up @@ -502,3 +436,11 @@ func initConfig() {
viper.SetEnvKeyReplacer(replacer)
viper.AutomaticEnv() // read in environment variables that match
}

// must panics if the passed in error is not nil
// use this for init-time error checking of viper/cobra stuff that sometimes errors if the flag does not exist
func must(err error) {
if err != nil {
panic(fmt.Errorf("error initialising: %w", err))
}
}