Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/semantic-router/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
github.com/onsi/gomega v1.38.0
github.com/openai/openai-go v1.12.0
github.com/prometheus/client_golang v1.23.0
github.com/prometheus/client_model v0.6.2
github.com/vllm-project/semantic-router/candle-binding v0.0.0-00010101000000-000000000000
go.uber.org/zap v1.27.0
google.golang.org/grpc v1.71.1
Expand Down Expand Up @@ -47,7 +48,6 @@ require (
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.65.0 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
Expand Down
116 changes: 116 additions & 0 deletions src/semantic-router/pkg/extproc/metrics_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package extproc

import (
"encoding/json"
"time"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
)

// getHistogramSampleCount returns the cumulative sample count of the
// histogram metric named metricName for the series whose "model" label
// equals model. It returns 0 when the metric family does not exist, is
// not a histogram, no series matches, or gathering fails.
func getHistogramSampleCount(metricName, model string) uint64 {
	mfs, err := prometheus.DefaultGatherer.Gather()
	if err != nil {
		// Best-effort test helper: treat a gather failure as "no samples"
		// rather than silently discarding the error with `_`.
		return 0
	}
	for _, fam := range mfs {
		if fam.GetName() != metricName || fam.GetType() != dto.MetricType_HISTOGRAM {
			continue
		}
		for _, m := range fam.GetMetric() {
			for _, l := range m.GetLabel() {
				if l.GetName() == "model" && l.GetValue() == model {
					// The generated GetHistogram/GetSampleCount accessors
					// are nil-safe and return the zero value when unset,
					// so no explicit nil checks are needed.
					return m.GetHistogram().GetSampleCount()
				}
			}
		}
	}
	return 0
}

// Integration-style specs verifying that the response handlers update the
// TTFT and TPOT histograms for the model named in the request context.
var _ = Describe("Metrics recording", func() {
	var r *OpenAIRouter

	BeforeEach(func() {
		// A bare router suffices here; no external models are required.
		r = &OpenAIRouter{}
		// Prepare the internal maps the handlers rely on.
		r.InitializeForTesting()
	})

	It("records TTFT on response headers", func() {
		reqCtx := &RequestContext{
			RequestModel:        "model-a",
			ProcessingStartTime: time.Now().Add(-75 * time.Millisecond),
		}

		countBefore := getHistogramSampleCount("llm_model_ttft_seconds", reqCtx.RequestModel)

		headerMap := &core.HeaderMap{
			Headers: []*core.HeaderValue{{Key: "content-type", Value: "application/json"}},
		}
		respHeaders := &ext_proc.ProcessingRequest_ResponseHeaders{
			ResponseHeaders: &ext_proc.HttpHeaders{Headers: headerMap},
		}

		resp, err := r.handleResponseHeaders(respHeaders, reqCtx)
		Expect(err).NotTo(HaveOccurred())
		Expect(resp.GetResponseHeaders()).NotTo(BeNil())

		countAfter := getHistogramSampleCount("llm_model_ttft_seconds", reqCtx.RequestModel)
		Expect(countAfter).To(BeNumerically(">", countBefore))
		Expect(reqCtx.TTFTRecorded).To(BeTrue())
		Expect(reqCtx.TTFTSeconds).To(BeNumerically(">", 0))
	})

	It("records TPOT on response body", func() {
		reqCtx := &RequestContext{
			RequestID:    "tpot-test-1",
			RequestModel: "model-a",
			StartTime:    time.Now().Add(-1 * time.Second),
		}

		countBefore := getHistogramSampleCount("llm_model_tpot_seconds", reqCtx.RequestModel)

		// Minimal OpenAI-style completion payload carrying usage figures so
		// the handler can derive a time-per-output-token observation.
		payload := map[string]interface{}{
			"id":      "chatcmpl-xyz",
			"object":  "chat.completion",
			"created": time.Now().Unix(),
			"model":   reqCtx.RequestModel,
			"usage": map[string]interface{}{
				"prompt_tokens":     10,
				"completion_tokens": 5,
				"total_tokens":      15,
			},
			"choices": []map[string]interface{}{
				{
					"message":       map[string]interface{}{"role": "assistant", "content": "Hello"},
					"finish_reason": "stop",
				},
			},
		}
		bodyJSON, err := json.Marshal(payload)
		Expect(err).NotTo(HaveOccurred())

		respBody := &ext_proc.ProcessingRequest_ResponseBody{
			ResponseBody: &ext_proc.HttpBody{Body: bodyJSON},
		}

		resp, err := r.handleResponseBody(respBody, reqCtx)
		Expect(err).NotTo(HaveOccurred())
		Expect(resp.GetResponseBody()).NotTo(BeNil())

		countAfter := getHistogramSampleCount("llm_model_tpot_seconds", reqCtx.RequestModel)
		Expect(countAfter).To(BeNumerically(">", countBefore))
	})
})
2 changes: 1 addition & 1 deletion src/semantic-router/pkg/extproc/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (r *OpenAIRouter) Process(stream ext_proc.ExternalProcessor_ProcessServer)
}

case *ext_proc.ProcessingRequest_ResponseHeaders:
response, err := r.handleResponseHeaders(v)
response, err := r.handleResponseHeaders(v, ctx)
if err != nil {
return err
}
Expand Down
4 changes: 4 additions & 0 deletions src/semantic-router/pkg/extproc/request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ type RequestContext struct {
RequestQuery string
StartTime time.Time
ProcessingStartTime time.Time

// TTFT tracking
TTFTRecorded bool
TTFTSeconds float64
}

// handleRequestHeaders processes the request headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,13 @@ var _ = Describe("Request Processing", func() {
},
}

response, err := router.HandleResponseHeaders(responseHeaders)
ctx := &extproc.RequestContext{
Headers: make(map[string]string),
RequestModel: "model-a",
ProcessingStartTime: time.Now().Add(-50 * time.Millisecond),
}

response, err := router.HandleResponseHeaders(responseHeaders, ctx)
Expect(err).NotTo(HaveOccurred())
Expect(response).NotTo(BeNil())

Expand Down
17 changes: 16 additions & 1 deletion src/semantic-router/pkg/extproc/response_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@ import (
)

// handleResponseHeaders processes the response headers
func (r *OpenAIRouter) handleResponseHeaders(_ *ext_proc.ProcessingRequest_ResponseHeaders) (*ext_proc.ProcessingResponse, error) {
func (r *OpenAIRouter) handleResponseHeaders(_ *ext_proc.ProcessingRequest_ResponseHeaders, ctx *RequestContext) (*ext_proc.ProcessingResponse, error) {
// Best-effort TTFT measurement: record on first response headers if we have a start time and model
Review thread on this line:

Collaborator: for now we haven't tried streaming mode. In the buffered mode, the response from LLM has to be fully received before the response is received. If you can add an issue to track TTFT in streaming mode, that'll be great.

Contributor (author): Sure

Contributor (author): #128 for tracking

if ctx != nil && !ctx.TTFTRecorded && !ctx.ProcessingStartTime.IsZero() && ctx.RequestModel != "" {
ttft := time.Since(ctx.ProcessingStartTime).Seconds()
if ttft > 0 {
metrics.RecordModelTTFT(ctx.RequestModel, ttft)
ctx.TTFTSeconds = ttft
ctx.TTFTRecorded = true
}
}

// Allow the response to continue without modification
response := &ext_proc.ProcessingResponse{
Expand Down Expand Up @@ -53,6 +62,12 @@ func (r *OpenAIRouter) handleResponseBody(v *ext_proc.ProcessingRequest_Response
)
metrics.RecordModelCompletionLatency(ctx.RequestModel, completionLatency.Seconds())

// Record TPOT (time per output token) if completion tokens are available
if completionTokens > 0 {
timePerToken := completionLatency.Seconds() / float64(completionTokens)
metrics.RecordModelTPOT(ctx.RequestModel, timePerToken)
}

// Compute and record cost if pricing is configured
if r.Config != nil {
promptRatePer1M, completionRatePer1M, currency, ok := r.Config.GetModelPricing(ctx.RequestModel)
Expand Down
4 changes: 2 additions & 2 deletions src/semantic-router/pkg/extproc/testing_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ func (r *OpenAIRouter) HandleRequestBody(v *ext_proc.ProcessingRequest_RequestBo
}

// HandleResponseHeaders exposes handleResponseHeaders for testing
func (r *OpenAIRouter) HandleResponseHeaders(v *ext_proc.ProcessingRequest_ResponseHeaders) (*ext_proc.ProcessingResponse, error) {
return r.handleResponseHeaders(v)
func (r *OpenAIRouter) HandleResponseHeaders(v *ext_proc.ProcessingRequest_ResponseHeaders, ctx *RequestContext) (*ext_proc.ProcessingResponse, error) {
return r.handleResponseHeaders(v, ctx)
}

// HandleResponseBody exposes handleResponseBody for testing
Expand Down
42 changes: 42 additions & 0 deletions src/semantic-router/pkg/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,26 @@ var (
[]string{"model"},
)

// ModelTTFT tracks time to first token by model
ModelTTFT = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "llm_model_ttft_seconds",
Help: "Time to first token for LLM model responses in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"model"},
)

// ModelTPOT tracks time per output token by model
ModelTPOT = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "llm_model_tpot_seconds",
Help: "Time per output token (completion latency / completion tokens) for LLM model responses in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"model"},
)

// ModelRoutingLatency tracks the latency of model routing
ModelRoutingLatency = promauto.NewHistogram(
prometheus.HistogramOpts{
Expand Down Expand Up @@ -384,6 +404,28 @@ func RecordModelCompletionLatency(model string, seconds float64) {
ModelCompletionLatency.WithLabelValues(model).Observe(seconds)
}

// RecordModelTTFT records time to first token for a model
// RecordModelTTFT records time to first token for a model.
// Non-positive durations are ignored, and an empty model name is
// normalized to "unknown" to keep the label set well-formed.
func RecordModelTTFT(model string, seconds float64) {
	if seconds <= 0 {
		return
	}
	label := model
	if label == "" {
		label = "unknown"
	}
	ModelTTFT.WithLabelValues(label).Observe(seconds)
}

// RecordModelTPOT records time per output token (seconds per token) for a model
// RecordModelTPOT records time per output token (seconds per token)
// for a model. Non-positive values are ignored, and an empty model
// name is normalized to "unknown" to keep the label set well-formed.
func RecordModelTPOT(model string, secondsPerToken float64) {
	if secondsPerToken <= 0 {
		return
	}
	label := model
	if label == "" {
		label = "unknown"
	}
	ModelTPOT.WithLabelValues(label).Observe(secondsPerToken)
}

// RecordModelRoutingLatency records the latency of model routing
func RecordModelRoutingLatency(seconds float64) {
ModelRoutingLatency.Observe(seconds)
Expand Down
Loading