Skip to content

Commit

Permalink
client id middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
alessandroMiceli committed May 10, 2021
1 parent 4cc212a commit f6812e5
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 0 deletions.
112 changes: 112 additions & 0 deletions http/middleware/response_header.go
@@ -0,0 +1,112 @@
// Copyright © 2021 by PACE Telematics GmbH. All rights reserved.
// Created at 2021/05/10 by Alessandro Miceli


package middleware

import (
"context"
"net/http"
"strings"
"sync"

"github.com/pace/bricks/maintenance/log"
)

// ClientIDHeaderName name of the HTTP header that is used for reporting
const ClientIDHeaderName = "Client-ID"

// ResponseClientID middleware to report client ID
func ResponseClientID(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var rcc ResponseClientIDContext
rcw := responseClientIDWriter{
ResponseWriter: w,
rcc: &rcc,
}

r = r.WithContext(ContextWithResponseClientID(r.Context(), &rcc))
next.ServeHTTP(&rcw, r)
})
}

func AddResponseClientID(ctx context.Context, clientID string) {
cIDc := ResponseClientIDContextFromContext(ctx)
if cIDc == nil {
log.Ctx(ctx).Warn().Msgf("can't add client %s, because context is missing", clientID)
return
}
cIDc.AddResponseClientID(clientID)
}

type responseClientIDWriter struct {
http.ResponseWriter
header bool
rcc *ResponseClientIDContext
}

// addHeader adds the clientID header if not done already
func (w *responseClientIDWriter) addHeader() {
if len(w.rcc.clientIDs) > 0 {
w.ResponseWriter.Header().Add(ClientIDHeaderName, w.rcc.String())
}
w.header = true
}

func (w *responseClientIDWriter) Write(data []byte) (int, error) {
w.addHeader()
return w.ResponseWriter.Write(data)
}

// ContextWithResponseClientID creates a contex with the provided client ID
func ContextWithResponseClientID(ctx context.Context, rcc *ResponseClientIDContext) context.Context {
return context.WithValue(ctx, (*ResponseClientIDContext)(nil), rcc)
}

// ResponseClientIDContextFromContext returns the client ID context or nil
func ResponseClientIDContextFromContext(ctx context.Context) *ResponseClientIDContext {
if v := ctx.Value((*ResponseClientIDContext)(nil)); v != nil {
return v.(*ResponseClientIDContext)
}
return nil
}

// ResponseClientIDContext contains all client IDs that were seen
// during the request livecycle
type ResponseClientIDContext struct {
mu sync.RWMutex
clientIDs []responseClientID
}

func (rcc *ResponseClientIDContext) AddResponseClientID(clientID string) {
rcc.mu.Lock()
rcc.clientIDs = append(rcc.clientIDs, responseClientID{
ClientID: clientID,
})
rcc.mu.Unlock()
}

// String formats all client IDs
func (rcc *ResponseClientIDContext) String() string {
var b strings.Builder
sep := len(rcc.clientIDs) - 1
for _, dep := range rcc.clientIDs {
b.WriteString(dep.String())
if sep > 0 {
b.WriteByte(',')
sep--
}
}
return b.String()
}

// responseClientID represents the client ID that
// was sent in the request
type responseClientID struct {
ClientID string //client ID name
}

// String return a client ID
func (rcc responseClientID) String() string {
return rcc.ClientID
}
58 changes: 58 additions & 0 deletions http/middleware/response_header_test.go
@@ -0,0 +1,58 @@
// Copyright © 2021 by PACE Telematics GmbH. All rights reserved.
// Created at 2021/05/10 by Alessandro Miceli

package middleware

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_ResponseClientID_Middleare(t *testing.T) {
AddResponseClientID(context.TODO(), "test") // should not fail
t.Run("empty set", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)

h := ResponseClientID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
h.ServeHTTP(rec, req)
assert.Nil(t, rec.Result().Header[http.CanonicalHeaderKey(ClientIDHeaderName)])
})
t.Run("one client set", func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)

h := ResponseClientID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
AddResponseClientID(r.Context(), "test")
w.Write(nil) // nolint: errcheck
}))
h.ServeHTTP(rec, req)
assert.Equal(t, rec.Result().Header[http.CanonicalHeaderKey(ClientIDHeaderName)][0], "test")
})
}

func Test_ResponseClientIDContext_String(t *testing.T) {
var rcc ResponseClientIDContext

// empty
assert.Empty(t, rcc.String())

// one dependency
rcc.AddResponseClientID("test1")
assert.EqualValues(t, "test1", rcc.String())

// multiple dependencies
rcc.AddResponseClientID("test2")
assert.EqualValues(t, "test1,test2", rcc.String())

rcc.AddResponseClientID("test3")
assert.EqualValues(t, "test1,test2,test3", rcc.String())

rcc.AddResponseClientID("test4")
assert.EqualValues(t, "test1,test2,test3,test4", rcc.String())
}

0 comments on commit f6812e5

Please sign in to comment.