diff --git a/core/stores/sqlx/readamplication.go b/core/stores/sqlx/readamplication.go new file mode 100644 index 000000000000..df14f0e45750 --- /dev/null +++ b/core/stores/sqlx/readamplication.go @@ -0,0 +1,70 @@ +package sqlx + +import ( + "sync" + + "github.com/zeromicro/go-zero/core/logx" +) + +const ( + concurrencyThreshold = 3 + logInterval = 60 * 1000 // 1 minute +) + +var logger = logx.NewLessLogger(logInterval) + +type ( + concurrentReads struct { + reads map[string]*queryReference + lock sync.Mutex + } + + queryReference struct { + concurrency uint32 + maxConcurrency uint32 + } +) + +func newConcurrentReads() *concurrentReads { + return &concurrentReads{ + reads: make(map[string]*queryReference), + } +} + +func (r *concurrentReads) add(query string) { + r.lock.Lock() + defer r.lock.Unlock() + + if ref, ok := r.reads[query]; ok { + ref.concurrency++ + if ref.maxConcurrency < ref.concurrency { + ref.maxConcurrency = ref.concurrency + } + } else { + r.reads[query] = &queryReference{ + concurrency: 1, + maxConcurrency: 1, + } + } +} + +func (r *concurrentReads) remove(query string) { + r.lock.Lock() + defer r.lock.Unlock() + ref, ok := r.reads[query] + if !ok { + return + } + + if ref.concurrency > 1 { + ref.concurrency-- + return + } + + // last reference to remove + delete(r.reads, query) + if ref.maxConcurrency >= concurrencyThreshold { + logger.Errorf("sql query amplified, query: %q, maxConcurrency: %d", + query, ref.maxConcurrency) + } +} diff --git a/core/stores/sqlx/readamplification_test.go b/core/stores/sqlx/readamplification_test.go new file mode 100644 index 000000000000..4435a8fd7a98 --- /dev/null +++ b/core/stores/sqlx/readamplification_test.go @@ -0,0 +1,33 @@ +package sqlx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestName(t *testing.T) { + const ( + query = "select foo" + times = 10 + ) + cr := newConcurrentReads() + assert.NotPanics(t, func() { + cr.remove(query) + }) + + for i := 0; i < times; i++ { + cr.add(query) + } + + ref := cr.reads[query] + assert.Equal(t, uint32(times), ref.concurrency) + + for i := 0; i < times; i++ { + cr.remove(query) + } + + // just removed, not decremented + assert.Equal(t, uint32(1), ref.concurrency) + assert.Equal(t, uint32(times), ref.maxConcurrency) +} diff --git a/core/stores/sqlx/stmt.go b/core/stores/sqlx/stmt.go index b66dcc7f58e2..86d3a69da9a6 100644 --- a/core/stores/sqlx/stmt.go +++ b/core/stores/sqlx/stmt.go @@ -15,9 +15,10 @@ import ( const defaultSlowThreshold = time.Millisecond * 500 var ( - slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) - logSql = syncx.ForAtomicBool(true) - logSlowSql = syncx.ForAtomicBool(true) + slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) + logSql = syncx.ForAtomicBool(true) + logSlowSql = syncx.ForAtomicBool(true) + concurrentQueries = newConcurrentReads() ) type ( @@ -266,6 +267,7 @@ func (n nilGuard) finish(_ context.Context, _ error) { } func (e *realSqlGuard) finish(ctx context.Context, err error) { + concurrentQueries.remove(e.stmt) duration := timex.Since(e.startTime) if duration > slowThreshold.Load() { logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] %s: slowcall - %s", e.command, e.stmt) @@ -289,6 +291,7 @@ func (e *realSqlGuard) start(q string, args ...any) error { e.stmt = stmt e.startTime = timex.Now() + concurrentQueries.add(stmt) return nil }