diff --git a/.gitignore b/.gitignore index 1dca570d..1399f631 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ bin/ +.idea/ md5Sums.txt .DS_Store *.sw[op] diff --git a/cmd/drive/.gitignore b/cmd/drive/.gitignore new file mode 100644 index 00000000..d257bbba --- /dev/null +++ b/cmd/drive/.gitignore @@ -0,0 +1 @@ +/drive diff --git a/drive-gen/.gitignore b/drive-gen/.gitignore new file mode 100644 index 00000000..c56afabf --- /dev/null +++ b/drive-gen/.gitignore @@ -0,0 +1 @@ +/drive-gen diff --git a/go.mod b/go.mod index 43c267c7..70879e73 100644 --- a/go.mod +++ b/go.mod @@ -29,5 +29,6 @@ require ( golang.org/x/crypto v0.0.0-20201217014255-9d1352758620 golang.org/x/net v0.0.0-20201216054612-986b41b23924 golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5 + golang.org/x/sys v0.0.0-20221006211917-84dc82d7e875 // indirect google.golang.org/api v0.36.0 ) diff --git a/go.sum b/go.sum index e022ae75..34f53a58 100644 --- a/go.sum +++ b/go.sum @@ -302,6 +302,8 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3 h1:kzM6+9dur93BcC2kVlYl34cHU+TYZLanmpSJHVMmL64= golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20221006211917-84dc82d7e875 h1:AzgQNqF+FKwyQ5LbVrVqOcuuFB67N47F9+htZYH0wFM= +golang.org/x/sys v0.0.0-20221006211917-84dc82d7e875/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/src/remote.go b/src/remote.go index 1c09b424..f1a80a59 100644 --- a/src/remote.go +++ b/src/remote.go @@ -17,13 +17,16 @@ package drive import ( "fmt" "io" - "math/rand" + "net" "net/http" "net/url" "os" "strings" + "sync" "time" + crand "crypto/rand" + "golang.org/x/net/context" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -39,15 +42,9 @@ import ( ) const ( - // OAuth 2.0 OOB redirect URL for authorization. - RedirectURL = "urn:ietf:wg:oauth:2.0:oob" - // OAuth 2.0 full Drive scope used for authorization. DriveScope = "https://www.googleapis.com/auth/drive" - // OAuth 2.0 access type for offline/refresh access. - AccessType = "offline" - // Google Drive webpage host DriveResourceHostURL = "https://googledrive.com/host/" @@ -181,14 +178,207 @@ func (r *Remote) change(changeId string) (*drive.Change, error) { return r.service.Changes.Get(changeId).Do() } +type loopbackServer struct { + // Authorization codes come here + codeChan <-chan string + // Errors while serving the callback endpoint + serveErrChan <-chan error + // Errors on the listener, including shutdown errors. + listenerErrChan <-chan error + // Signals that the handler is done. + done <-chan struct{} + // Invoke this to begin server shutdown. + stop func() + // The server listens on this endpoint. + redirectURL string + // Auth URL including CSRF token. + authURL string +} + +func startTokenServer(config *oauth2.Config) (*loopbackServer, error) { + var buf [16]uint8 + if _, err := io.ReadFull(crand.Reader, buf[:]); err != nil { + return nil, fmt.Errorf("could not generate random request token: %v", err) + } + randState := fmt.Sprintf("%x", buf) + // We explicitly listen on the loopback device to prevent external access. + // TODO: Can we portably use localhost:0? + listenHost := "127.0.0.1" + listener, err := net.Listen("tcp", fmt.Sprintf("%s:0", listenHost)) + if err != nil { + return nil, err + } + port := listener.Addr().(*net.TCPAddr).Port + redirectURL := fmt.Sprintf("http://%s:%d/", listenHost, port) + // TODO: Consider if we can set/return the redirect URL in a more principled way. + config.RedirectURL = redirectURL + codeChan := make(chan string) + serveErrChan := make(chan error) + listenerErrChan := make(chan error) + + // NOTE: This could equally well be done with context cancellation. + // However, current guidance is to _not_ store contexts (and, presumably, + // their cancel functions) beyond individual requests (and we really only + // need simple cancellation/completion signaling anyway). Instead, we use a + // sync.Once to ensure that the done channel is only closed once. + done, cancel := func() (<-chan struct{}, func()) { + done := make(chan struct{}) + var once sync.Once + cancel := func() { + once.Do(func() { + close(done) + }) + } + return done, cancel + }() + + handleConnection := func(w http.ResponseWriter, r *http.Request) { + alreadyDoneMessage := "Already done. Return to the drive app.\n" + if r.URL.Path != "/" { + // Ignore requests at unexpected paths, e.g. /favicon.ico. + http.NotFound(w, r) + return + } + select { + case <-done: + _, _ = w.Write([]byte(alreadyDoneMessage)) + return + default: + } + + // All channel writes happen in select blocks because they might race + // with the done check above. + requestState := r.FormValue("state") + if requestState != randState { + select { + case serveErrChan <- fmt.Errorf("invalid CSRF token; rerun drive init"): + _, _ = w.Write([]byte("Error: invalid CSRF token.")) + case <-done: + _, _ = w.Write([]byte(alreadyDoneMessage)) + } + return + } + code := r.FormValue("code") + if code == "" { + select { + case serveErrChan <- fmt.Errorf("received empty request code; rerun drive init"): + _, _ = w.Write([]byte("Error: received empty code.")) + case <-done: + _, _ = w.Write([]byte(alreadyDoneMessage)) + } + return + } + + select { + case codeChan <- code: + _, _ = w.Write([]byte("Code received. Return to the drive app.")) + case <-done: + _, _ = w.Write([]byte(alreadyDoneMessage)) + } + } + + server := http.Server{ + Handler: http.HandlerFunc(handleConnection), + } + + // We use sync.Once here because we need to potentially call close on the + // listener error channel in 2 places. + var closeListenerErrChanOnce sync.Once + closeListenerErrChan := func(err error) { + closeListenerErrChanOnce.Do(func() { + listenerErrChan <- err + close(listenerErrChan) + }) + } + go func() { + // Server closer. + <-done + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + err := server.Shutdown(ctx) + if err != nil { + // Usually, we close the error channel below on server exit. + // However, if the Shutdown call hangs and we time out, we want to + // release the main goroutine. To handle the scenario where Shutdown + // times out but the underlying server somehow returns, we guard + // this in a sync.Once. In manual testing, I wasn't able to elicit + // any hangs or errors in the Shutdown call itself, even by + // wrapping the net.Listener in another listener that always returns + // an error from Close. + closeListenerErrChan(err) + } + }() + go func() { + // Listener. + err := server.Serve(listener) + if err == http.ErrServerClosed { + err = nil + } else if err != nil { + // Defensively check for non-nil errors. It's unclear if Serve() can ever + // exit with a nil error. + err = fmt.Errorf("server closed unexpectedly: %v", err) + } + closeListenerErrChan(err) + }() + authURL := config.AuthCodeURL(randState, oauth2.AccessTypeOffline) + return &loopbackServer{ + codeChan: codeChan, + serveErrChan: serveErrChan, + listenerErrChan: listenerErrChan, + done: done, + stop: cancel, + authURL: authURL, + redirectURL: redirectURL, + }, nil +} + +func (s *loopbackServer) RedirectURL() string { + return s.redirectURL +} + +func (s *loopbackServer) AuthURL() string { + return s.authURL +} + +func (s *loopbackServer) GetCode() (string, error) { + select { + case err := <-s.serveErrChan: + return "", err + case code := <-s.codeChan: + return code, nil + case <-s.done: + return "", fmt.Errorf("server already closed") + } +} + +func (s *loopbackServer) Close() error { + s.stop() + return <-s.listenerErrChan +} + +func getCodeViaLoopback(config *oauth2.Config) (string, error) { + server, err := startTokenServer(config) + if err != nil { + return "", err + } + config.RedirectURL = server.RedirectURL() + fmt.Printf("Visit this URL to get an authorization code\n%s\n", server.AuthURL()) + code, err := server.GetCode() + closeErr := server.Close() + if closeErr != nil { + // We already have either a code or root error, so no need to surface this. + fmt.Printf("warning: error closing loopback server: %v\n", closeErr) + } + return code, err +} + func RetrieveRefreshToken(ctx context.Context, context *config.Context) (string, error) { config := newAuthConfig(context) - randState := fmt.Sprintf("%s%v", time.Now(), rand.Uint32()) - url := config.AuthCodeURL(randState, oauth2.AccessTypeOffline) - - fmt.Printf("Visit this URL to get an authorization code\n%s\n", url) - code := prompt(os.Stdin, os.Stdout, "Paste the authorization code: ") + code, err := getCodeViaLoopback(config) + if err != nil { + return "", err + } token, err := config.Exchange(ctx, code) if err != nil { @@ -1207,7 +1397,6 @@ func newAuthConfig(context *config.Context) *oauth2.Config { return &oauth2.Config{ ClientID: context.ClientId, ClientSecret: context.ClientSecret, - RedirectURL: RedirectURL, Endpoint: google.Endpoint, Scopes: []string{DriveScope}, }