Skip to content

Commit

Permalink
Add thread safety to metrics context (#1517)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ardagan committed May 4, 2021
1 parent 9dc9ac8 commit 3e19d63
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 38 deletions.
150 changes: 150 additions & 0 deletions common/metrics/baggage_bench_test.go
@@ -0,0 +1,150 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package metrics

import (
"math/rand"
"sync"
"testing"
"time"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)

type (
// baggageBenchTest Was used to compare the behavior of sync.Map vs map+mutex in the context of usage for
// metricsContext.
// See testMapBaggage for test logic.
// As a summary, using mutex performed considerably faster than sync.Map.
// These results are to be reviewed if the usage scenario of metricsContext changes.
baggageBenchTest struct {
suite.Suite
*require.Assertions
controller *gomock.Controller
}

testBaggage interface {
Add(k string, v int64)
Get(k string) int64
}

baggageSyncMap struct {
data *sync.Map
}

baggageMutexMap struct {
sync.Mutex
data map[string]int64
}
)

func TestBaggageBenchSuite(t *testing.T) {
s := new(baggageBenchTest)
suite.Run(t, s)
}

func (s *baggageBenchTest) SetupTest() {
s.Assertions = require.New(s.T())
s.controller = gomock.NewController(s.T())
}

func (s *baggageBenchTest) TearDownTest() {}

func (b *baggageSyncMap) Add(k string, v int64) {
for done := false; !done; {
metricInterface, _ := b.data.LoadAndDelete(k)
var newValue = v
if metricInterface != nil {
newValue += metricInterface.(int64)
}
_, loaded := b.data.LoadOrStore(k, newValue)
done = !loaded
}
}

func (b *baggageSyncMap) Get(k string) int64 {
metricInterface, _ := b.data.LoadAndDelete(k)
if metricInterface == nil {
return 0
}
return metricInterface.(int64)
}

func (b *baggageMutexMap) Add(k string, v int64) {
b.Lock()
defer b.Unlock()

value, _ := b.data[k]
value += v
b.data[k] = value
}

func (b *baggageMutexMap) Get(k string) int64 {
b.Lock()
defer b.Unlock()
return b.data[k]
}

// roughly 1.7s/7.5s for mutex/sync
//baggageCount := 1000
//threadCount := 20
//updatesPerThread := 1000
func testMapBaggage(createTestObj func() testBaggage) {
baggageCount := 10
threadCount := 10
updatesPerThread := 10

keys := []string{"k1", "k2", "k3", "k4", "k5"}
start := time.Now()
sum := int64(0)
for bag := 0; bag < baggageCount; bag++ {
testObj := createTestObj()
wg := sync.WaitGroup{}
wg.Add(threadCount)
for th := 0; th < threadCount; th++ {
go func(key string) {
for upd := 0; upd < updatesPerThread; upd++ {
testObj.Add(key, rand.Int63())
}
wg.Done()
}(keys[th%len(keys)])
}
wg.Wait()
val := testObj.Get(keys[0])
sum += val
}
println("sum: ", sum)
println("duration: ", time.Since(start))
}

func (s *baggageBenchTest) TestSyncMapBaggage() {
testMapBaggage(func() testBaggage { return &baggageSyncMap{data: &sync.Map{}} })
}

func (s *baggageBenchTest) TestMutexMapBaggage() {
testMapBaggage(func() testBaggage { return &baggageMutexMap{data: make(map[string]int64)} })
}
93 changes: 67 additions & 26 deletions common/metrics/grpc.go
Expand Up @@ -26,6 +26,7 @@ package metrics

import (
"context"
"sync"

metricspb "go.temporal.io/server/api/metrics/v1"
"go.temporal.io/server/common/log"
Expand All @@ -35,13 +36,21 @@ import (
"google.golang.org/grpc/metadata"
)

type baggageContextKey struct{}
type (
metricsContextKey struct{}

// metricsContext is used to propagate metrics across single gRPC call within server
metricsContext struct {
sync.Mutex
CountersInt map[string]int64
}
)

var (
// "-bin" suffix is a reserved in gRPC that signals that metadata string value is actually a byte data
// If trailer key has such a suffix, value will be base64 encoded.
baggageTrailerKey = "metrics-baggage-bin"
baggageCtxKey = baggageContextKey{}
metricsTrailerKey = "metrics-trailer-bin"
metricsCtxKey = metricsContextKey{}
)

// NewServerMetricsContextInjectorInterceptor returns grpc server interceptor that adds metrics context to golang
Expand All @@ -53,7 +62,7 @@ func NewServerMetricsContextInjectorInterceptor() grpc.UnaryServerInterceptor {
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
ctxWithMetricsBaggage := AddMetricsBaggageToContext(ctx)
ctxWithMetricsBaggage := addMetricsContext(ctx)
return handler(ctxWithMetricsBaggage, req)
}
}
Expand All @@ -69,7 +78,7 @@ func NewClientMetricsTrailerPropagatorInterceptor(logger log.Logger) grpc.UnaryC
optsWithTrailer := append(opts, grpc.Trailer(&trailer))
err := invoker(ctx, method, req, reply, cc, optsWithTrailer...)

baggageStrings := trailer.Get(baggageTrailerKey)
baggageStrings := trailer.Get(metricsTrailerKey)
if len(baggageStrings) == 0 {
return err
}
Expand Down Expand Up @@ -109,17 +118,25 @@ func NewServerMetricsTrailerPropagatorInterceptor(logger log.Logger) grpc.UnaryS
default:
}

baggage := GetMetricsBaggageFromContext(ctx)
if baggage == nil {
metricsCtx := getMetricsContext(ctx)
if metricsCtx == nil {
return resp, err
}

bytes, marshalErr := baggage.Marshal()
metricsBaggage := &metricspb.Baggage{CountersInt: make(map[string]int64)}

metricsCtx.Lock()
for k, v := range metricsCtx.CountersInt {
metricsBaggage.CountersInt[k] = v
}
metricsCtx.Unlock()

bytes, marshalErr := metricsBaggage.Marshal()
if marshalErr != nil {
logger.Error("unable to marshal metric baggage", tag.Error(marshalErr))
}

md := metadata.Pairs(baggageTrailerKey, string(bytes))
md := metadata.Pairs(metricsTrailerKey, string(bytes))

marshalErr = grpc.SetTrailer(ctx, md)
if marshalErr != nil {
Expand All @@ -130,34 +147,58 @@ func NewServerMetricsTrailerPropagatorInterceptor(logger log.Logger) grpc.UnaryS
}
}

// GetMetricsBaggageFromContext extracts metrics context from golang context.
func GetMetricsBaggageFromContext(ctx context.Context) *metricspb.Baggage {
metricsBaggage := ctx.Value(baggageCtxKey)
if metricsBaggage == nil {
// getMetricsContext extracts metrics context from golang context.
func getMetricsContext(ctx context.Context) *metricsContext {
metricsCtx := ctx.Value(metricsCtxKey)
if metricsCtx == nil {
return nil
}

return metricsBaggage.(*metricspb.Baggage)
return metricsCtx.(*metricsContext)
}

func AddMetricsBaggageToContext(ctx context.Context) context.Context {
metricsBaggage := &metricspb.Baggage{}
return context.WithValue(ctx, baggageCtxKey, metricsBaggage)
func addMetricsContext(ctx context.Context) context.Context {
metricsCtx := &metricsContext{}
return context.WithValue(ctx, metricsCtxKey, metricsCtx)
}

// ContextCounterAdd adds value to counter within metrics context.
func ContextCounterAdd(ctx context.Context, name string, value int64) {
metricsBaggage := GetMetricsBaggageFromContext(ctx)
func ContextCounterAdd(ctx context.Context, name string, value int64) bool {
metricsCtx := getMetricsContext(ctx)

if metricsCtx == nil {
return false
}

if metricsBaggage == nil {
return
metricsCtx.Lock()
defer metricsCtx.Unlock()

if metricsCtx.CountersInt == nil {
metricsCtx.CountersInt = make(map[string]int64)
}

val := metricsCtx.CountersInt[name]
val += value
metricsCtx.CountersInt[name] = val

return true
}

// ContextCounterGet returns value and true if successfully retrieved value
func ContextCounterGet(ctx context.Context, name string) (int64, bool) {
metricsCtx := getMetricsContext(ctx)

if metricsCtx == nil {
return 0, false
}

if metricsBaggage.CountersInt == nil {
metricsBaggage.CountersInt = make(map[string]int64)
metricsCtx.Lock()
defer metricsCtx.Unlock()

if metricsCtx.CountersInt == nil {
return 0, false
}

metricValue, _ := metricsBaggage.CountersInt[name]
metricValue += value
metricsBaggage.CountersInt[name] = metricValue
result, _ := metricsCtx.CountersInt[name]
return result, true
}
13 changes: 6 additions & 7 deletions common/metrics/grpc_test.go
Expand Up @@ -89,7 +89,7 @@ func (s *grpcSuite) TestMetadataMetricInjection() {
s.Fail("failed to marshal values")
}
*trailer.TrailerAddr = metadata.MD{}
trailer.TrailerAddr.Append(baggageTrailerKey, string(data))
trailer.TrailerAddr.Append(metricsTrailerKey, string(data))
return nil
},
)
Expand All @@ -99,7 +99,7 @@ func (s *grpcSuite) TestMetadataMetricInjection() {

s.Nil(err)
s.Equal(len(ssts.trailers), 1)
propagationContextBlobs := ssts.trailers[0].Get(baggageTrailerKey)
propagationContextBlobs := ssts.trailers[0].Get(metricsTrailerKey)
s.NotNil(propagationContextBlobs)
s.Equal(1, len(propagationContextBlobs))
baggage := &metricspb.Baggage{}
Expand Down Expand Up @@ -144,7 +144,7 @@ func (s *grpcSuite) TestMetadataMetricInjection_NoMetricPresent() {
s.Fail("failed to marshal values")
}
trailer.TrailerAddr = &metadata.MD{}
trailer.TrailerAddr.Append(baggageTrailerKey, string(data))
trailer.TrailerAddr.Append(metricsTrailerKey, string(data))
return nil
},
)
Expand All @@ -154,7 +154,7 @@ func (s *grpcSuite) TestMetadataMetricInjection_NoMetricPresent() {

s.Nil(err)
s.Equal(len(ssts.trailers), 1)
propagationContextBlobs := ssts.trailers[0].Get(baggageTrailerKey)
propagationContextBlobs := ssts.trailers[0].Get(metricsTrailerKey)
s.NotNil(propagationContextBlobs)
s.Equal(1, len(propagationContextBlobs))
baggage := &metricspb.Baggage{}
Expand All @@ -171,15 +171,14 @@ func (s *grpcSuite) TestMetadataMetricInjection_NoMetricPresent() {
}

func (s *grpcSuite) TestContextCounterAdd() {
ctx := AddMetricsBaggageToContext(context.Background())
ctx := addMetricsContext(context.Background())

testCounterName := "test_counter"
ContextCounterAdd(ctx, testCounterName, 100)
ContextCounterAdd(ctx, testCounterName, 20)
ContextCounterAdd(ctx, testCounterName, 3)

metricsBaggage := GetMetricsBaggageFromContext(ctx)
value, ok := metricsBaggage.CountersInt[testCounterName]
value, ok := ContextCounterGet(ctx, testCounterName)
s.True(ok)
s.Equal(int64(123), value)
}
Expand Down
7 changes: 2 additions & 5 deletions common/rpc/interceptor/telemetry.go
Expand Up @@ -115,11 +115,8 @@ func (ti *TelemetryInterceptor) Intercept(

resp, err := handler(ctx, req)

metricsBaggage := metrics.GetMetricsBaggageFromContext(ctx)
if metricsBaggage != nil {
if val, ok := metricsBaggage.CountersInt[metrics.HistoryWorkflowExecutionCacheLatency]; ok {
timerNoUserLatency.Subtract(time.Duration(val))
}
if val, ok := metrics.ContextCounterGet(ctx, metrics.HistoryWorkflowExecutionCacheLatency); ok {
timerNoUserLatency.Subtract(time.Duration(val))
}

if err != nil {
Expand Down

0 comments on commit 3e19d63

Please sign in to comment.