/
oauth.go
131 lines (118 loc) · 3.83 KB
/
oauth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package sshizzleagent
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"os/exec"
"os/signal"
"runtime"
"time"
"github.com/google/uuid"
"github.com/thalesgroup/sshizzle/internal/config"
"golang.org/x/oauth2"
)
// Authenticate returns a valid OAuth2 token, running an interactive browser
// login when the supplied token is missing or no longer valid.
//
// If token is already valid it is returned unchanged. Otherwise a temporary
// HTTP listener is started on :8080 to receive the provider's redirect on
// /callback, the authorisation URL is opened in the user's browser (or logged
// when no browser can be launched), and the call blocks until one of:
//
//   - a valid token has been obtained by the callback handler,
//   - the callback listener fails to start,
//   - an interrupt signal is received, or
//   - 60 seconds have elapsed.
//
// On success the new token is copied into *token and token is returned, so a
// caller holding the original pointer observes the refreshed value. On any
// failure the returned token is nil and the error describes the cause.
func Authenticate(token *oauth2.Token, config *oauth2.Config) (*oauth2.Token, error) {
	// Nothing to do if the token we already have is still good.
	if token.Valid() {
		return token, nil
	}
	// Valid() is false for a nil token; make sure there is somewhere to
	// store the freshly issued one.
	if token == nil {
		token = &oauth2.Token{}
	}

	const (
		authTimeout     = 60 * time.Second
		shutdownTimeout = 5 * time.Second
	)

	// Create a state for later validation of the callback.
	state := uuid.New().String()

	// The callback runs on a server goroutine. Each request fills its own
	// token value and hands it over through this channel, so the caller's
	// token is only ever written from this goroutine.
	received := make(chan *oauth2.Token, 1)
	mux := http.NewServeMux()
	mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
		fresh := &oauth2.Token{}
		handleLoginCallback(fresh, config, state)(w, r)
		if !fresh.Valid() {
			return
		}
		select {
		case received <- fresh:
		default:
			// A token has already been delivered; ignore duplicates.
		}
	})

	// Build the server before starting the goroutine so that the deferred
	// shutdown below acts on the instance that is actually listening.
	// NOTE(review): ":8080" binds every interface; consider restricting this
	// to loopback if the registered redirect URI allows it.
	server := &http.Server{
		Addr:              ":8080",
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}
	listenErr := make(chan error, 1)
	go func() {
		// ErrServerClosed is the normal result of the shutdown below.
		if err := server.ListenAndServe(); err != http.ErrServerClosed {
			listenErr <- err
		}
	}()
	// Stop the callback listener on return. Shutdown (rather than Close) lets
	// an in-flight callback finish writing its response to the browser.
	defer func() {
		ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
		defer cancel()
		if err := server.Shutdown(ctx); err != nil {
			_ = server.Close()
		}
	}()

	// Get the URL required for the user to authenticate and try to open it.
	authURL := config.AuthCodeURL(state, oauth2.AccessTypeOffline)
	if err := openURL(authURL); err != nil {
		// Otherwise dump the URL to the log as a prompt.
		log.Printf("Failed to open browser. Please visit this URL and sign in:\n\n%s\n\n", authURL)
		log.Println("Waiting up to 60s for authentication...")
	}

	// Catch interrupts to exit nicely. The channel must be buffered because
	// signal.Notify never blocks when delivering; os.Kill cannot be caught,
	// so it is not registered.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, os.Interrupt)
	defer signal.Stop(sigs)

	timer := time.NewTimer(authTimeout)
	defer timer.Stop()

	// Wait for the login callback to succeed, fail, be interrupted or time out.
	select {
	case fresh := <-received:
		*token = *fresh
		return token, nil
	case err := <-listenErr:
		return nil, fmt.Errorf("error in callback listener: %v", err)
	case s := <-sigs:
		return nil, fmt.Errorf("Cancelled, signal received: %s", s.String())
	case <-timer.C:
		return nil, fmt.Errorf("azure AD authentication timed out after 60s, try again")
	}
}
// handleLoginCallback returns the HTTP handler that receives the OAuth2
// redirect after the user has signed in.
//
// The handler checks that the request's "state" form value equals state,
// exchanges the "code" form value for a token using oauth, stores the result
// in *token and then makes a best-effort attempt to write the token as
// indented JSON to the file named by config.GetSSHizzleTokenFile. A short
// HTML message is written to the browser in every case; failures are sent
// with a non-200 status and logged.
func handleLoginCallback(token *oauth2.Token, oauth *oauth2.Config, state string) func(http.ResponseWriter, *http.Request) {
	// Upper bound on the code-for-token exchange so a stalled identity
	// provider cannot hold the handler open indefinitely.
	const exchangeTimeout = 30 * time.Second

	return func(w http.ResponseWriter, r *http.Request) {
		// We'll output some basic HTML to the user, so set the header accordingly.
		w.Header().Set("Content-Type", "text/html; charset=utf-8")

		// Check the state matches the one issued with the login request.
		if r.FormValue("state") != state {
			w.WriteHeader(http.StatusBadRequest)
			fmt.Fprint(w, "<h3>Oops, that didn't work! AuthCode invalid!</h3>")
			return
		}

		// Retrieve the auth code from the response and exchange it for a token.
		code := r.FormValue("code")
		ctx, cancel := context.WithTimeout(context.Background(), exchangeTimeout)
		defer cancel()
		newToken, err := oauth.Exchange(ctx, code)
		if err != nil {
			log.Printf("unable to exchange authorisation code for a token: %v\n", err)
			w.WriteHeader(http.StatusInternalServerError)
			fmt.Fprint(w, "<h3>Oops, that didn't work! Try again?</h3>")
			return
		}
		fmt.Fprint(w, "<h3>Success! You can close this window now!</h3>")

		// Update the original token value.
		*token = *newToken

		// Caching is best effort: the login has already succeeded, so any
		// failure from here on is only logged.
		encoded, err := json.MarshalIndent(token, "", " ")
		if err != nil {
			log.Printf("unable to encode token for caching: %v\n", err)
			return
		}
		tokenFile, err := config.GetSSHizzleTokenFile()
		if err != nil {
			log.Printf("unable to locate token cache: %v\n", err)
			return
		}
		// 0600: the cache holds credentials, so keep it private to the user.
		if err := ioutil.WriteFile(tokenFile, encoded, 0600); err != nil {
			log.Printf("unable to update token cache at %s: %v\n", tokenFile, err)
		}
	}
}
// openURL attempts to open a URL in a browser in an OS-agnostic way.
//
// It launches "xdg-open" on linux and "open" on darwin and returns as soon as
// the helper process has been started; it does not wait for the browser. An
// error is returned when the helper cannot be started or when runtime.GOOS is
// any other platform.
// #nosec
func openURL(url string) (err error) {
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "linux":
		cmd = exec.Command("xdg-open", url)
	case "darwin":
		cmd = exec.Command("open", url)
	default:
		return fmt.Errorf("unsupported platform")
	}
	if err = cmd.Start(); err != nil {
		return err
	}
	// A started command must be waited on to release its resources; do so in
	// the background so the caller is not blocked. The helper's exit status is
	// deliberately ignored because the caller only needs to know it launched.
	go func() { _ = cmd.Wait() }()
	return nil
}