Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add X-Request-Id and User-Agent headers to attestation requests #509

Merged
merged 3 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions tpm/attestation/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -214,7 +215,7 @@ func (ac *Client) attest(ctx context.Context, info *tpm.Info, ek *tpm.EK, attest
}

attestURL := ac.baseURL.JoinPath("attest").String()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, attestURL, bytes.NewReader(body))
req, err := newRequest(ctx, http.MethodPost, attestURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed creating POST http request for %q: %w", attestURL, err)
}
Expand Down Expand Up @@ -258,7 +259,7 @@ func (ac *Client) secret(ctx context.Context, secret []byte) (*secretResponse, e
}

secretURL := ac.baseURL.JoinPath("secret").String()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, secretURL, bytes.NewReader(body))
req, err := newRequest(ctx, http.MethodPost, secretURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed creating POST http request for %q: %w", secretURL, err)
}
Expand All @@ -280,3 +281,14 @@ func (ac *Client) secret(ctx context.Context, secret []byte) (*secretResponse, e

return &secretResp, nil
}

func newRequest(ctx context.Context, method, requestURL string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, requestURL, body)
if err != nil {
return nil, err
}
enforceRequestID(req)
setUserAgent(req)

return req, nil
}
7 changes: 7 additions & 0 deletions tpm/attestation/client_simulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func mustParseURL(t *testing.T, urlString string) *url.URL {

func TestClient_Attest(t *testing.T) {
ctx := context.Background()
ctx = NewRequestIDContext(ctx, "requestID")
instance := newSimulatedTPM(t)
ak, err := instance.CreateAK(ctx, "ak1")
require.NoError(t, err)
Expand Down Expand Up @@ -140,6 +141,9 @@ func TestClient_Attest(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/attest":
assert.Equal(t, "step-attestation-http-client/1.0", r.Header.Get("User-Agent"))
assert.Equal(t, "requestID", r.Header.Get("X-Request-Id"))

var ar attestationRequest
err := json.NewDecoder(r.Body).Decode(&ar)
require.NoError(t, err)
Expand All @@ -165,6 +169,9 @@ func TestClient_Attest(t *testing.T) {
Secret: encryptedCredentials.Secret,
})
case "/secret":
assert.Equal(t, "step-attestation-http-client/1.0", r.Header.Get("User-Agent"))
assert.Equal(t, "requestID", r.Header.Get("X-Request-Id"))

var sr secretRequest
err := json.NewDecoder(r.Body).Decode(&sr)
require.NoError(t, err)
Expand Down
52 changes: 52 additions & 0 deletions tpm/attestation/requestid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package attestation

import (
"context"
"net/http"

"go.step.sm/crypto/randutil"
)

type requestIDContextKey struct{}

// NewRequestIDContext returns a new context with the given request ID added to the
// context.
func NewRequestIDContext(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDContextKey{}, requestID)
}

// RequestIDFromContext returns the request ID from the context if it exists.
// and is not empty.
func RequestIDFromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(requestIDContextKey{}).(string)
return v, ok && v != ""
}

// requestIDHeader is the header name used for propagating request IDs from
// the attestation client to the attestation CA and back again.
const requestIDHeader = "X-Request-Id"

// newRequestID generates a new random UUIDv4 request ID. If it fails,
// the request ID will be the empty string.
func newRequestID() string {
requestID, err := randutil.UUIDv4()
if err != nil {
return ""
}

return requestID
}

// enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's
// empty, the context is searched for a request ID. If that's also empty, a new
// request ID is generated.
func enforceRequestID(r *http.Request) {
if requestID := r.Header.Get(requestIDHeader); requestID == "" {
if reqID, ok := RequestIDFromContext(r.Context()); ok {
requestID = reqID
} else {
requestID = newRequestID()
}
r.Header.Set(requestIDHeader, requestID)
}
}
12 changes: 12 additions & 0 deletions tpm/attestation/useragent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package attestation

import "net/http"

// UserAgent is the value of the User-Agent HTTP header that will
// be set in HTTP requests to the attestation CA.
var UserAgent = "step-attestation-http-client/1.0"

// setUserAgent sets the User-Agent header in HTTP requests.
func setUserAgent(r *http.Request) {
r.Header.Set("User-Agent", UserAgent)
}
Loading