Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4cc212a
commit f6812e5
Showing
2 changed files
with
170 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
// Copyright © 2021 by PACE Telematics GmbH. All rights reserved. | ||
// Created at 2021/05/10 by Alessandro Miceli | ||
|
||
|
||
package middleware | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"strings" | ||
"sync" | ||
|
||
"github.com/pace/bricks/maintenance/log" | ||
) | ||
|
||
// ClientIDHeaderName name of the HTTP header that is used for reporting | ||
const ClientIDHeaderName = "Client-ID" | ||
|
||
// ResponseClientID middleware to report client ID | ||
func ResponseClientID(next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
var rcc ResponseClientIDContext | ||
rcw := responseClientIDWriter{ | ||
ResponseWriter: w, | ||
rcc: &rcc, | ||
} | ||
|
||
r = r.WithContext(ContextWithResponseClientID(r.Context(), &rcc)) | ||
next.ServeHTTP(&rcw, r) | ||
}) | ||
} | ||
|
||
func AddResponseClientID(ctx context.Context, clientID string) { | ||
cIDc := ResponseClientIDContextFromContext(ctx) | ||
if cIDc == nil { | ||
log.Ctx(ctx).Warn().Msgf("can't add client %s, because context is missing", clientID) | ||
return | ||
} | ||
cIDc.AddResponseClientID(clientID) | ||
} | ||
|
||
type responseClientIDWriter struct { | ||
http.ResponseWriter | ||
header bool | ||
rcc *ResponseClientIDContext | ||
} | ||
|
||
// addHeader adds the clientID header if not done already | ||
func (w *responseClientIDWriter) addHeader() { | ||
if len(w.rcc.clientIDs) > 0 { | ||
w.ResponseWriter.Header().Add(ClientIDHeaderName, w.rcc.String()) | ||
} | ||
w.header = true | ||
} | ||
|
||
func (w *responseClientIDWriter) Write(data []byte) (int, error) { | ||
w.addHeader() | ||
return w.ResponseWriter.Write(data) | ||
} | ||
|
||
// ContextWithResponseClientID creates a contex with the provided client ID | ||
func ContextWithResponseClientID(ctx context.Context, rcc *ResponseClientIDContext) context.Context { | ||
return context.WithValue(ctx, (*ResponseClientIDContext)(nil), rcc) | ||
} | ||
|
||
// ResponseClientIDContextFromContext returns the client ID context or nil | ||
func ResponseClientIDContextFromContext(ctx context.Context) *ResponseClientIDContext { | ||
if v := ctx.Value((*ResponseClientIDContext)(nil)); v != nil { | ||
return v.(*ResponseClientIDContext) | ||
} | ||
return nil | ||
} | ||
|
||
// ResponseClientIDContext contains all client IDs that were seen | ||
// during the request livecycle | ||
type ResponseClientIDContext struct { | ||
mu sync.RWMutex | ||
clientIDs []responseClientID | ||
} | ||
|
||
func (rcc *ResponseClientIDContext) AddResponseClientID(clientID string) { | ||
rcc.mu.Lock() | ||
rcc.clientIDs = append(rcc.clientIDs, responseClientID{ | ||
ClientID: clientID, | ||
}) | ||
rcc.mu.Unlock() | ||
} | ||
|
||
// String formats all client IDs | ||
func (rcc *ResponseClientIDContext) String() string { | ||
var b strings.Builder | ||
sep := len(rcc.clientIDs) - 1 | ||
for _, dep := range rcc.clientIDs { | ||
b.WriteString(dep.String()) | ||
if sep > 0 { | ||
b.WriteByte(',') | ||
sep-- | ||
} | ||
} | ||
return b.String() | ||
} | ||
|
||
// responseClientID represents the client ID that | ||
// was sent in the request | ||
type responseClientID struct { | ||
ClientID string //client ID name | ||
} | ||
|
||
// String return a client ID | ||
func (rcc responseClientID) String() string { | ||
return rcc.ClientID | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// Copyright © 2021 by PACE Telematics GmbH. All rights reserved. | ||
// Created at 2021/05/10 by Alessandro Miceli | ||
|
||
package middleware | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func Test_ResponseClientID_Middleare(t *testing.T) { | ||
AddResponseClientID(context.TODO(), "test") // should not fail | ||
t.Run("empty set", func(t *testing.T) { | ||
rec := httptest.NewRecorder() | ||
req := httptest.NewRequest(http.MethodGet, "/", nil) | ||
|
||
h := ResponseClientID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
})) | ||
h.ServeHTTP(rec, req) | ||
assert.Nil(t, rec.Result().Header[http.CanonicalHeaderKey(ClientIDHeaderName)]) | ||
}) | ||
t.Run("one client set", func(t *testing.T) { | ||
rec := httptest.NewRecorder() | ||
req := httptest.NewRequest(http.MethodGet, "/", nil) | ||
|
||
h := ResponseClientID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
AddResponseClientID(r.Context(), "test") | ||
w.Write(nil) // nolint: errcheck | ||
})) | ||
h.ServeHTTP(rec, req) | ||
assert.Equal(t, rec.Result().Header[http.CanonicalHeaderKey(ClientIDHeaderName)][0], "test") | ||
}) | ||
} | ||
|
||
func Test_ResponseClientIDContext_String(t *testing.T) { | ||
var rcc ResponseClientIDContext | ||
|
||
// empty | ||
assert.Empty(t, rcc.String()) | ||
|
||
// one dependency | ||
rcc.AddResponseClientID("test1") | ||
assert.EqualValues(t, "test1", rcc.String()) | ||
|
||
// multiple dependencies | ||
rcc.AddResponseClientID("test2") | ||
assert.EqualValues(t, "test1,test2", rcc.String()) | ||
|
||
rcc.AddResponseClientID("test3") | ||
assert.EqualValues(t, "test1,test2,test3", rcc.String()) | ||
|
||
rcc.AddResponseClientID("test4") | ||
assert.EqualValues(t, "test1,test2,test3,test4", rcc.String()) | ||
} |