/
msi.go
145 lines (118 loc) · 3.78 KB
/
msi.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package azure
import (
"encoding/json"
"io"
"net/http"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/spire/pkg/common/agentpathtemplate"
"github.com/spiffe/spire/pkg/common/idutil"
"github.com/zeebo/errs"
)
// DefaultMSIResourceID is the resource ID used by default as the intended
// audience of the MSI token. It is the service ID of the Azure Resource
// Manager API.
const DefaultMSIResourceID = "https://management.azure.com/"

// PluginName is the name of this plugin. It is exposed to agent path
// templates as {{ .PluginName }}.
const PluginName = "azure_msi"
// DefaultAgentPathTemplate is the default text/template used to render the
// path portion of an agent SPIFFE ID. It produces
// "/<plugin name>/<tenant ID>/<principal ID>", where the tenant and principal
// IDs come from the "tid" and "sub" claims of the MSI token.
var (
	DefaultAgentPathTemplate = agentpathtemplate.MustParse("/{{ .PluginName }}/{{ .TenantID }}/{{ .PrincipalID }}")
)
type (
	// ComputeMetadata is the subset of the "compute" section of the Azure
	// instance metadata response that this package reads.
	ComputeMetadata struct {
		Name              string `json:"name"`
		SubscriptionID    string `json:"subscriptionId"`
		ResourceGroupName string `json:"resourceGroupName"`
	}

	// InstanceMetadata is the decoded instance metadata document. Only the
	// "compute" section is retained.
	InstanceMetadata struct {
		Compute ComputeMetadata `json:"compute"`
	}

	// MSIAttestationData wraps an MSI token under the "token" JSON key.
	// NOTE(review): presumably the payload the agent sends for attestation;
	// confirm against the agent/server plugins.
	MSIAttestationData struct {
		Token string `json:"token"`
	}
)
// MSITokenClaims holds the claims read from an MSI access token: the standard
// JWT claims plus the Azure tenant ("tid") and principal ("sub") identifiers.
//
// NOTE(review): PrincipalID uses the same "sub" JSON key as the embedded
// jwt.Claims Subject field. Under Go's JSON field-visibility rules the
// shallower PrincipalID field wins, so Subject is presumably left empty on
// decode. Confirm against the JSON package go-jose uses.
type MSITokenClaims struct {
	jwt.Claims

	// TenantID is decoded from the "tid" claim.
	TenantID    string `json:"tid,omitempty"`
	// PrincipalID is decoded from the "sub" claim.
	PrincipalID string `json:"sub,omitempty"`
}
// HTTPClient is the minimal HTTP client surface needed by the metadata
// fetchers in this file. *http.Client satisfies it, as does HTTPClientFunc.
type HTTPClient interface {
	Do(*http.Request) (*http.Response, error)
}

// HTTPClientFunc adapts an ordinary function to the HTTPClient interface.
type HTTPClientFunc func(*http.Request) (*http.Response, error)

// Compile-time check that HTTPClientFunc implements HTTPClient.
var _ HTTPClient = HTTPClientFunc(nil)

// Do invokes f with req and returns its results unchanged.
func (f HTTPClientFunc) Do(req *http.Request) (*http.Response, error) {
	return f(req)
}
// FetchMSIToken requests an MSI access token for the given resource (the
// token audience) from the Azure metadata endpoint at 169.254.169.254.
// The request is sent through cl.
//
// It returns the "access_token" field of the JSON response. It returns an
// error if:
//   - the request cannot be built or sent;
//   - the status is not 200 OK (the error includes part of the body);
//   - the body cannot be decoded;
//   - the token is empty.
func FetchMSIToken(cl HTTPClient, resource string) (string, error) {
	const tokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01"

	req, err := http.NewRequest(http.MethodGet, tokenURL, nil)
	if err != nil {
		return "", errs.Wrap(err)
	}
	// Azure IMDS documents this header as required on metadata requests.
	req.Header.Add("Metadata", "true")

	// Merge the resource parameter with the api-version already in the URL.
	query := req.URL.Query()
	query.Set("resource", resource)
	req.URL.RawQuery = query.Encode()

	resp, err := cl.Do(req)
	if err != nil {
		return "", errs.Wrap(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", errs.New("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body))
	}

	var payload struct {
		AccessToken string `json:"access_token"`
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&payload); decodeErr != nil {
		return "", errs.New("unable to decode response: %v", decodeErr)
	}
	if payload.AccessToken == "" {
		return "", errs.New("response missing access token")
	}
	return payload.AccessToken, nil
}
// FetchInstanceMetadata retrieves the instance metadata document from the
// Azure metadata endpoint at 169.254.169.254 through cl and decodes it.
//
// It returns an error if:
//   - the request cannot be built or sent;
//   - the status is not 200 OK (the error includes part of the body);
//   - the body cannot be decoded;
//   - any of the compute name, subscription ID or resource group name is empty.
func FetchInstanceMetadata(cl HTTPClient) (*InstanceMetadata, error) {
	const instanceURL = "http://169.254.169.254/metadata/instance?api-version=2017-08-01&format=json"

	req, err := http.NewRequest(http.MethodGet, instanceURL, nil)
	if err != nil {
		return nil, errs.Wrap(err)
	}
	// Azure IMDS documents this header as required on metadata requests.
	req.Header.Add("Metadata", "true")

	resp, err := cl.Do(req)
	if err != nil {
		return nil, errs.Wrap(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, errs.New("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body))
	}

	var metadata InstanceMetadata
	if decodeErr := json.NewDecoder(resp.Body).Decode(&metadata); decodeErr != nil {
		return nil, errs.New("unable to decode response: %v", decodeErr)
	}

	// All three compute fields are required. Report the first missing one.
	compute := metadata.Compute
	if compute.Name == "" {
		return nil, errs.New("response missing instance name")
	}
	if compute.SubscriptionID == "" {
		return nil, errs.New("response missing instance subscription id")
	}
	if compute.ResourceGroupName == "" {
		return nil, errs.New("response missing instance resource group name")
	}
	return &metadata, nil
}
// agentPathTemplateData is the value passed to the agent path template.
type agentPathTemplateData struct {
	// Embedded so that claim fields (e.g. .TenantID, .PrincipalID) are
	// promoted to the template's top level.
	MSITokenClaims

	// PluginName is exposed to the template as .PluginName. MakeAgentID sets
	// it from the package-level PluginName constant.
	PluginName string
}
// MakeAgentID renders agentPathTemplate with the MSI token claims and this
// plugin's name. It then builds the agent SPIFFE ID for that path in trust
// domain td via idutil.AgentID.
//
// A nil agentPathTemplate falls back to DefaultAgentPathTemplate. A nil
// claims value is reported as an error instead of panicking on dereference.
// Template execution errors are returned as-is.
func MakeAgentID(td spiffeid.TrustDomain, agentPathTemplate *agentpathtemplate.Template, claims *MSITokenClaims) (spiffeid.ID, error) {
	if claims == nil {
		return spiffeid.ID{}, errs.New("missing MSI token claims")
	}
	if agentPathTemplate == nil {
		agentPathTemplate = DefaultAgentPathTemplate
	}

	agentPath, err := agentPathTemplate.Execute(agentPathTemplateData{
		MSITokenClaims: *claims,
		PluginName:     PluginName,
	})
	if err != nil {
		return spiffeid.ID{}, err
	}
	return idutil.AgentID(td, agentPath)
}
// tryRead returns at most the first 1024 bytes of r as a string, for use in
// error messages. It is best-effort: a read error ends reading, and whatever
// was read before the error is returned.
func tryRead(r io.Reader) string {
	// A single Read may legitimately return fewer bytes than are available
	// (e.g. one chunk of a chunked body). Keep reading up to the limit or EOF.
	// The error is deliberately ignored because partial data is still useful
	// context.
	b, _ := io.ReadAll(io.LimitReader(r, 1024))
	return string(b)
}