/
main.go
107 lines (87 loc) · 2.2 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package main
import (
"context"
"flag"
"fmt"
"os"
"os/exec"
"syscall"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sts"
)
// usage is the help text printed by main for -h and when the command line is
// incomplete. Everything after the role ARN is the command to run together
// with its own arguments, which are passed through unchanged.
const usage = `NAME:
    usurp - Temporary AWS role assumption shim

USAGE:
    usurp [-h] <role arn> <command> [args...]
`
// Credentials is the set of temporary security credentials produced by
// assumeRole, together with the instant they stop being valid.
type Credentials struct {
	// The key pair and session token; runCommand exports these to the
	// executed command as AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and
	// AWS_SESSION_TOKEN respectively.
	AccessKeyId, SecretAccessKey, SessionToken string

	// Expiration is when the temporary credentials expire.
	Expiration time.Time
}
// abort writes message to stderr, prefixed with "ERROR: ", and terminates the
// process with the given exit status. It never returns, and deferred calls in
// the caller do not run (os.Exit).
//
// message is formatted with %v rather than %s: errors and strings print the
// same either way, but %v also renders arbitrary values (ints, structs, ...)
// instead of producing "%!s(...)".
func abort(status int, message interface{}) {
	fmt.Fprintf(os.Stderr, "ERROR: %v\n", message)
	os.Exit(status)
}
// Constraints the STS AssumeRole API places on RoleSessionName: 2-64
// characters drawn from [\w+=,.@-].
const (
	minSessionNameLen  = 2
	maxSessionNameLen  = 64
	defaultSessionName = "usurp-user"
)

// assumeRole assumes roleArn via STS AssumeRole, using the AWS SDK's default
// configuration (config.LoadDefaultConfig) for the caller's own credentials,
// and returns the resulting temporary credentials.
//
// The role session name is derived from $USER (see sessionName), falling back
// to "usurp-user" when USER is unset or unusable.
//
// On error the returned Credentials is the zero value.
func assumeRole(roleArn string) (Credentials, error) {
	ctx := context.Background()

	cfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		return Credentials{}, fmt.Errorf("cant get aws config: %w", err)
	}

	o, err := sts.NewFromConfig(cfg).AssumeRole(ctx, &sts.AssumeRoleInput{
		RoleArn:         aws.String(roleArn),
		RoleSessionName: aws.String(sessionName(os.Getenv("USER"))),
	})
	if err != nil {
		return Credentials{}, fmt.Errorf("cant assume role %s: %w", roleArn, err)
	}
	// Guard against a response without a Credentials struct rather than
	// panicking on a nil dereference.
	if o.Credentials == nil {
		return Credentials{}, fmt.Errorf("cant assume role %s: response contained no credentials", roleArn)
	}

	c := o.Credentials
	return Credentials{
		AccessKeyId:     aws.ToString(c.AccessKeyId),
		SecretAccessKey: aws.ToString(c.SecretAccessKey),
		SessionToken:    aws.ToString(c.SessionToken),
		Expiration:      aws.ToTime(c.Expiration),
	}, nil
}

// sessionName converts user into a RoleSessionName that STS will accept.
// Every byte outside [A-Za-z0-9_+=,.@-] becomes '-', and the result is capped
// at maxSessionNameLen bytes. A result shorter than minSessionNameLen
// (including an empty or unset USER) is replaced by defaultSessionName.
func sessionName(user string) string {
	out := make([]byte, 0, len(user))
	for i := 0; i < len(user) && len(out) < maxSessionNameLen; i++ {
		c := user[i]
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
			out = append(out, c)
		case c == '_', c == '+', c == '=', c == ',', c == '.', c == '@', c == '-':
			out = append(out, c)
		default:
			// Spaces, backslashes (DOMAIN\user), multi-byte UTF-8, etc.
			out = append(out, '-')
		}
	}
	if len(out) < minSessionNameLen {
		return defaultSessionName
	}
	return string(out)
}
// runCommand replaces the current process with command, where command[0] is
// resolved through exec.LookPath and command is passed as argv. creds are
// exported as AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_SESSION_TOKEN,
// and the rest of the current environment is inherited.
//
// On success it never returns. Any failure (empty command, command not found,
// environment update or exec error) aborts the process with status 1.
func runCommand(creds Credentials, command []string) {
	if len(command) == 0 {
		abort(1, "no command given")
	}

	commandPath, err := exec.LookPath(command[0])
	if err != nil {
		abort(1, err)
	}

	// Set the variables in this process so os.Environ() below carries them
	// alongside everything the caller already had.
	vars := [...]struct{ key, value string }{
		{"AWS_ACCESS_KEY_ID", creds.AccessKeyId},
		{"AWS_SECRET_ACCESS_KEY", creds.SecretAccessKey},
		{"AWS_SESSION_TOKEN", creds.SessionToken},
	}
	for _, v := range vars {
		if err := os.Setenv(v.key, v.value); err != nil {
			abort(1, fmt.Errorf("setting %s: %w", v.key, err))
		}
	}

	// syscall.Exec only returns if the exec itself failed.
	if err := syscall.Exec(commandPath, command, os.Environ()); err != nil {
		abort(1, fmt.Errorf("exec %s: %w", commandPath, err))
	}
}
// main implements `usurp [-h] <role arn> <command> [args...]`. It assumes the
// role and execs the command with the temporary credentials in its
// environment.
//
// Exit statuses:
//   - 0: -h (help requested)
//   - 64 (EX_USAGE): fewer than two positional arguments
//   - 2: unknown flags (flag package)
//   - 1: the role cannot be assumed or the command cannot be started
func main() {
	// Make flag-package errors (e.g. unknown flags) show our usage text.
	flag.Usage = func() { fmt.Fprint(os.Stderr, usage) }

	var help bool
	flag.BoolVar(&help, "h", false, "show program help")
	flag.Parse()

	if help {
		// Explicitly requested help is not an error.
		fmt.Print(usage)
		os.Exit(0)
	}
	if flag.NArg() < 2 {
		fmt.Fprint(os.Stderr, usage)
		os.Exit(64)
	}

	roleArn := flag.Arg(0)
	command := flag.Args()[1:]

	fmt.Fprintf(os.Stderr, "💅 Assuming role: %s\n", roleArn)
	creds, err := assumeRole(roleArn)
	if err != nil {
		abort(1, err)
	}

	// Show the expiry in the user's local time zone.
	fmt.Fprintf(os.Stderr, "⏲ Session expires: %s\n", creds.Expiration.Local().Format(time.RFC1123))
	runCommand(creds, command)
}