-
Notifications
You must be signed in to change notification settings - Fork 117
/
query.go
198 lines (174 loc) · 6.68 KB
/
query.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
package runtime
import (
"context"
"fmt"
"io"
"strings"
"github.com/dgraph-io/ristretto"
runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1"
"github.com/rilldata/rill/runtime/pkg/observability"
"github.com/rilldata/rill/runtime/pkg/singleflight"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
var (
	meter = otel.Meter("github.com/rilldata/rill/runtime")

	// Query-cache telemetry instruments. The observable counters and gauges
	// are reported by the callback registered in newQueryCache; the entry-size
	// histogram is recorded directly when Query inserts a result into the cache.
	queryCacheHitsCounter        = observability.Must(meter.Int64ObservableCounter("query_cache.hits"))
	queryCacheMissesCounter      = observability.Must(meter.Int64ObservableCounter("query_cache.misses"))
	queryCacheItemCountGauge     = observability.Must(meter.Int64ObservableGauge("query_cache.items"))
	queryCacheSizeBytesGauge     = observability.Must(meter.Int64ObservableGauge("query_cache.size", metric.WithUnit("bytes")))
	queryCacheEntrySizeHistogram = observability.Must(meter.Int64Histogram("query_cache.entry_size", metric.WithUnit("bytes")))
)
// QueryResult wraps a resolved query value together with its estimated size.
// It is produced by Query.MarshalResult; Bytes is used as the entry's cost
// when the value is inserted into the query cache.
type QueryResult struct {
	Value any   // the marshaled query result stored in the cache
	Bytes int64 // estimated size of Value in bytes (cache cost)
}
// ExportOptions configures a call to Query.Export.
type ExportOptions struct {
	// Format is the serialization format for the exported result.
	Format runtimev1.ExportFormat
	// Priority is the scheduling priority for resolving the query.
	Priority int
	// PreWriteHook is invoked with a filename before the result is written.
	// NOTE(review): the exact contract (when it runs, what filename means)
	// is not visible in this file — confirm against implementers of Export.
	PreWriteHook func(filename string) error
}
// Query is implemented by queries that can be resolved against a runtime
// instance and cached by the runtime's query cache (see Runtime.Query).
type Query interface {
	// Key should return a cache key that uniquely identifies the query.
	// An empty key disables caching for the query.
	Key() string
	// Deps should return the source and model names that the query targets.
	// It's used to invalidate cached queries when the underlying data changes.
	Deps() []string
	// MarshalResult should return the query result and estimated cost in bytes for caching.
	MarshalResult() *QueryResult
	// UnmarshalResult should populate a query with a cached result.
	UnmarshalResult(v any) error
	// Resolve should execute the query against the instance's infra.
	// Error can be nil along with a nil result in general, i.e. when a model contains no rows aggregation results can be nil.
	Resolve(ctx context.Context, rt *Runtime, instanceID string, priority int) error
	// Export resolves the query and serializes the result to the writer.
	Export(ctx context.Context, rt *Runtime, instanceID string, w io.Writer, opts *ExportOptions) error
}
// Query resolves a query against an instance, caching the result when feasible.
//
// Caching is skipped when the query reports an empty key, when the instance's
// OLAP driver is "druid", or when none of the query's dependencies resolve to
// catalog entries. Cache entries are keyed on instance ID, query key, and the
// refresh timestamps of the query's dependencies, so refreshing a dependency
// naturally invalidates stale entries. Concurrent identical queries are
// coalesced with a singleflight group so the underlying query runs only once.
func (r *Runtime) Query(ctx context.Context, instanceID string, query Query, priority int) error {
	// If key is empty, skip caching
	qk := query.Key()
	if qk == "" {
		return query.Resolve(ctx, r, instanceID, priority)
	}

	// Skip caching for specific named drivers.
	// TODO: Make this configurable with a default provided by the driver.
	inst, err := r.FindInstance(ctx, instanceID)
	if err != nil {
		return err
	}
	if inst.OLAPDriver == "druid" {
		return query.Resolve(ctx, r, instanceID, priority)
	}

	// Get dependency cache keys. Only dependencies that resolve to a catalog
	// entry contribute to the cache key (collected with append so skipped
	// dependencies don't leave empty placeholder slots in the key).
	deps := query.Deps()
	depKeys := make([]string, 0, len(deps))
	for _, dep := range deps {
		entry, err := r.GetCatalogEntry(ctx, instanceID, dep)
		if err != nil {
			// This err usually means the query has a dependency that does not exist in the catalog.
			// Returning the error is not critical, it just saves a redundant subsequent query to the OLAP, which would likely fail.
			// However, for dependencies created in the OLAP DB directly (and are hence not tracked in the catalog), the query would actually succeed.
			// For read-only Druid dashboards on existing tables, we specifically need the ColumnTimeRange to succeed.
			// TODO: Remove this horrible hack when discovery of existing tables is implemented. Then we can safely return an error in all cases.
			if strings.HasPrefix(qk, "ColumnTimeRange") {
				continue
			}
			return fmt.Errorf("query dependency %q not found", dep)
		}
		depKeys = append(depKeys, entry.Name+":"+entry.RefreshedOn.String())
	}

	// If there were no known dependencies, skip caching. This also covers the
	// hack above: if every dependency was skipped, we must not cache, because
	// the resulting key would never be invalidated by a refresh.
	if len(depKeys) == 0 {
		return query.Resolve(ctx, r, instanceID, priority)
	}

	// Build cache key (reusing the query key computed at the top)
	depKey := strings.Join(depKeys, ";")
	key := queryCacheKey{
		instanceID:    instanceID,
		queryKey:      qk,
		dependencyKey: depKey,
	}.String()

	// Try to get from cache
	if val, ok := r.queryCache.cache.Get(key); ok {
		observability.AddRequestAttributes(ctx, attribute.Bool("query.cache_hit", true))
		return query.UnmarshalResult(val)
	}
	observability.AddRequestAttributes(ctx, attribute.Bool("query.cache_hit", false))

	// Load with singleflight
	owner := false
	val, err := r.queryCache.singleflight.Do(ctx, key, func(ctx context.Context) (any, error) {
		// Try cache again: another goroutine may have populated it while we
		// waited to enter the singleflight.
		if val, ok := r.queryCache.cache.Get(key); ok {
			return val, nil
		}

		// Load
		err := query.Resolve(ctx, r, instanceID, priority)
		if err != nil {
			return nil, err
		}

		owner = true
		res := query.MarshalResult()
		r.queryCache.cache.Set(key, res.Value, res.Bytes)
		queryCacheEntrySizeHistogram.Record(ctx, res.Bytes, metric.WithAttributes(attribute.String("query", queryName(query))))
		return res.Value, nil
	})
	if err != nil {
		return err
	}

	// Only non-owners unmarshal: the owner's query object already holds the
	// resolved result from the Resolve call above.
	if !owner {
		return query.UnmarshalResult(val)
	}
	return nil
}
// queryCacheKey uniquely identifies a cached query result. It combines the
// instance, the query's own key, and the refresh state of its dependencies.
type queryCacheKey struct {
	instanceID    string
	queryKey      string
	dependencyKey string
}

// String renders the key as "inst:<id> deps:<deps> qry:<key>". Note that the
// dependency key is placed before the query key in the rendered string.
func (k queryCacheKey) String() string {
	fields := []any{k.instanceID, k.dependencyKey, k.queryKey}
	return fmt.Sprintf("inst:%s deps:%s qry:%s", fields...)
}
// queryCache holds the runtime's query result cache and its supporting
// machinery for load deduplication and metrics reporting.
type queryCache struct {
	cache        *ristretto.Cache                 // size-bounded result cache, keyed by queryCacheKey strings
	singleflight *singleflight.Group[string, any] // coalesces concurrent loads of the same key
	metrics      metric.Registration              // handle for the metrics callback; unregistered in close
}
// newQueryCache creates a ristretto-backed query cache bounded by sizeInBytes
// and registers a callback that exports the cache's hit/miss/item/size metrics
// via OpenTelemetry. It panics on an invalid size or if cache creation fails;
// both indicate a programmer error at startup.
func newQueryCache(sizeInBytes int64) *queryCache {
	if sizeInBytes <= 100 {
		panic(fmt.Sprintf("invalid cache size %d: must be greater than 100 bytes", sizeInBytes))
	}

	cache, err := ristretto.NewCache(&ristretto.Config{
		// Use 5% of cache memory for storing counters. Each counter takes roughly 3 bytes.
		// Recommended value is 10x the number of items in cache when full.
		// Tune this again based on metrics.
		NumCounters: int64(float64(sizeInBytes) * 0.05 / 3),
		MaxCost:     int64(float64(sizeInBytes) * 0.95),
		BufferItems: 64,
		Metrics:     true,
	})
	if err != nil {
		panic(err)
	}

	// Item count and size are derived as added-minus-evicted because ristretto
	// exposes monotonic counters rather than point-in-time gauges.
	metrics := observability.Must(meter.RegisterCallback(func(ctx context.Context, observer metric.Observer) error {
		observer.ObserveInt64(queryCacheHitsCounter, int64(cache.Metrics.Hits()))
		observer.ObserveInt64(queryCacheMissesCounter, int64(cache.Metrics.Misses()))
		observer.ObserveInt64(queryCacheItemCountGauge, int64(cache.Metrics.KeysAdded()-cache.Metrics.KeysEvicted()))
		observer.ObserveInt64(queryCacheSizeBytesGauge, int64(cache.Metrics.CostAdded()-cache.Metrics.CostEvicted()))
		return nil
	}, queryCacheHitsCounter, queryCacheMissesCounter, queryCacheItemCountGauge, queryCacheSizeBytesGauge))

	return &queryCache{
		cache:        cache,
		singleflight: &singleflight.Group[string, any]{},
		metrics:      metrics,
	}
}
// close shuts down the cache and then unregisters its metrics callback,
// returning any error from the unregistration.
// NOTE(review): the cache is closed before the callback is unregistered, so
// the callback may still read cache.Metrics after Close — assumed safe with
// ristretto; confirm against its documentation.
func (c *queryCache) close() error {
	c.cache.Close()
	return c.metrics.Unregister()
}
// queryName returns the concrete type name of a query with its package
// qualifier (and leading "*") stripped, e.g. "*queries.ColumnTimeRange"
// becomes "ColumnTimeRange". It is used as a metric attribute value.
func queryName(q Query) string {
	nameWithPkg := fmt.Sprintf("%T", q)
	// Cut at the first "." to drop the package prefix. If the type name has
	// no package qualifier, fall back to the full name rather than emitting
	// an empty attribute (strings.Cut yields "" when the separator is absent).
	_, name, found := strings.Cut(nameWithPkg, ".")
	if !found {
		return nameWithPkg
	}
	return name
}