Skip to content

Commit ea17f5f

Browse files
committed
Allow requests with no body to pass through EPP
1 parent 9b96433 commit ea17f5f

File tree

9 files changed

+290
-207
lines changed

9 files changed

+290
-207
lines changed

config/manifests/inferencepool.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ spec:
4343
spec:
4444
containers:
4545
- name: epp
46-
image: us-east1-docker.pkg.dev/kfswain-gke-dev/test-repo/ext-proc:test-mar-51
46+
image: us-east1-docker.pkg.dev/kfswain-gke-dev/test-repo/ext-proc:test-mar-53
4747
imagePullPolicy: Always
4848
args:
4949
- -poolName

pkg/epp/datastore/datastore.go

+22-9
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ type Datastore interface {
6565
PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.InferencePool) bool
6666
PodDelete(namespacedName types.NamespacedName)
6767
PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool)
68+
// PodGetRandom will grab a random pod, used for selecting a passthrough endpoint.
69+
PodGetRandom() *backendmetrics.Pod
70+
71+
RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel) string
72+
// This func should only be called for tests, to set a deterministic seed.
73+
SetRand(seed int64)
6874

6975
// Clears the store state, happens when the pool gets deleted.
7076
Clear()
@@ -77,6 +83,7 @@ func NewDatastore(parentCtx context.Context, pmf *backendmetrics.PodMetricsFacto
7783
models: make(map[string]*v1alpha2.InferenceModel),
7884
pods: &sync.Map{},
7985
pmf: pmf,
86+
rand: rand.New(rand.NewSource(rand.Int63())),
8087
}
8188
return store
8289
}
@@ -92,6 +99,7 @@ type datastore struct {
9299
// key: types.NamespacedName, value: backendmetrics.PodMetrics
93100
pods *sync.Map
94101
pmf *backendmetrics.PodMetricsFactory
102+
rand *rand.Rand
95103
}
96104

97105
func (ds *datastore) Clear() {
@@ -292,6 +300,17 @@ func (ds *datastore) PodDelete(namespacedName types.NamespacedName) {
292300
}
293301
}
294302

303+
func (ds *datastore) PodGetRandom() *backendmetrics.Pod {
304+
pods := ds.PodGetAll()
305+
pod := pods[rand.Intn(len(pods))]
306+
return pod.GetPod()
307+
}
308+
309+
// This func should only be called for tests to set a deterministic seed
310+
func (ds *datastore) SetRand(seed int64) {
311+
ds.rand = rand.New(rand.NewSource(seed))
312+
}
313+
295314
func selectorFromInferencePoolSelector(selector map[v1alpha2.LabelKey]v1alpha2.LabelValue) labels.Selector {
296315
return labels.SelectorFromSet(stripLabelKeyAliasFromLabelMap(selector))
297316
}
@@ -304,16 +323,10 @@ func stripLabelKeyAliasFromLabelMap(labels map[v1alpha2.LabelKey]v1alpha2.LabelV
304323
return outMap
305324
}
306325

307-
func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed int64) string {
308-
source := rand.NewSource(rand.Int63())
309-
if seed > 0 {
310-
source = rand.NewSource(seed)
311-
}
312-
r := rand.New(source)
313-
326+
func (ds *datastore) RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel) string {
314327
// all the weight values are nil, then we should return random model name
315328
if model.Spec.TargetModels[0].Weight == nil {
316-
index := r.Int31n(int32(len(model.Spec.TargetModels)))
329+
index := ds.rand.Int31n(int32(len(model.Spec.TargetModels)))
317330
return model.Spec.TargetModels[index].Name
318331
}
319332

@@ -322,7 +335,7 @@ func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed
322335
weights += *model.Weight
323336
}
324337
logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
325-
randomVal := r.Int31n(weights)
338+
randomVal := ds.rand.Int31n(weights)
326339
// TODO: optimize this without using loop
327340
for _, model := range model.Spec.TargetModels {
328341
if randomVal < *model.Weight {

pkg/epp/datastore/datastore_test.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -313,10 +313,12 @@ func TestRandomWeightedDraw(t *testing.T) {
313313
},
314314
}
315315
var seedVal int64 = 420
316+
var ds *datastore = &datastore{}
317+
ds.SetRand(seedVal)
316318
for _, test := range tests {
317319
t.Run(test.name, func(t *testing.T) {
318320
for range 10000 {
319-
model := RandomWeightedDraw(logger, test.model, seedVal)
321+
model := ds.RandomWeightedDraw(logger, test.model)
320322
if model != test.want {
321323
t.Errorf("Model returned: %v != %v", model, test.want)
322324
break

pkg/epp/handlers/request.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (s *Server) HandleRequestBody(
6969
return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)}
7070
}
7171
if len(modelObj.Spec.TargetModels) > 0 {
72-
modelName = datastore.RandomWeightedDraw(logger, modelObj, 0)
72+
modelName = s.datastore.RandomWeightedDraw(logger, modelObj)
7373
if modelName == "" {
7474
return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
7575
}

pkg/epp/handlers/response.go

+4-6
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ func (s *Server) HandleResponseHeaders(
8585
if header.Key == "content-type" {
8686
contentType := header.RawValue
8787
if strings.Contains(string(contentType), "text/event-stream") {
88-
reqCtx.Streaming = true
89-
} else {
90-
reqCtx.Streaming = false
88+
reqCtx.modelServerStreaming = true
9189
}
9290
typeFound = true
9391
}
@@ -155,7 +153,7 @@ func (s *Server) HandleResponseBody(
155153
loggerVerbose := logger.V(logutil.VERBOSE)
156154
body := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)
157155

158-
if reqCtx.Streaming {
156+
if reqCtx.modelServerStreaming {
159157
logger.V(logutil.DEBUG).Info("Processing HandleResponseBody")
160158
if err := s.HandleStreaming(ctx, reqCtx, body, loggerVerbose); err != nil {
161159
return nil, err
@@ -189,7 +187,7 @@ func (s *Server) HandleNonStreaming(
189187
if err := json.Unmarshal(body.ResponseBody.Body, &res); err != nil {
190188
return errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("unmarshaling response body: %v", err)}
191189
}
192-
reqCtx.Response = res
190+
reqCtx.Usage = res.Usage
193191
reqCtx.ResponseSize = len(body.ResponseBody.Body)
194192
reqCtx.ResponseComplete = true
195193
loggerVerbose.Info("Response generated", "response", res)
@@ -205,7 +203,7 @@ func (s *Server) HandleStreaming(
205203
responseText := string(body.ResponseBody.Body)
206204
if strings.Contains(responseText, streamingEndMsg) {
207205
parsedResp := ParseRespForUsage(ctx, responseText, loggerVerbose)
208-
reqCtx.Response = parsedResp
206+
reqCtx.Usage = parsedResp.Usage
209207
}
210208

211209
if body.ResponseBody.EndOfStream {

pkg/epp/handlers/response_test.go

+12-16
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func TestHandleResponseBody(t *testing.T) {
6565
name string
6666
req *extProcPb.ProcessingRequest_ResponseBody
6767
reqCtx *RequestContext
68-
want Response
68+
want Usage
6969
wantErr bool
7070
}{
7171
{
@@ -75,12 +75,10 @@ func TestHandleResponseBody(t *testing.T) {
7575
Body: []byte(body),
7676
},
7777
},
78-
want: Response{
79-
Usage: Usage{
80-
PromptTokens: 11,
81-
TotalTokens: 111,
82-
CompletionTokens: 100,
83-
},
78+
want: Usage{
79+
PromptTokens: 11,
80+
TotalTokens: 111,
81+
CompletionTokens: 100,
8482
},
8583
},
8684
{
@@ -100,7 +98,7 @@ func TestHandleResponseBody(t *testing.T) {
10098
},
10199
},
102100
reqCtx: &RequestContext{
103-
Streaming: true,
101+
modelServerStreaming: true,
104102
},
105103
wantErr: false,
106104
// In the middle of streaming response, so request context response is not set yet.
@@ -113,15 +111,13 @@ func TestHandleResponseBody(t *testing.T) {
113111
},
114112
},
115113
reqCtx: &RequestContext{
116-
Streaming: true,
114+
modelServerStreaming: true,
117115
},
118116
wantErr: false,
119-
want: Response{
120-
Usage: Usage{
121-
PromptTokens: 7,
122-
TotalTokens: 17,
123-
CompletionTokens: 10,
124-
},
117+
want: Usage{
118+
PromptTokens: 7,
119+
TotalTokens: 17,
120+
CompletionTokens: 10,
125121
},
126122
},
127123
}
@@ -141,7 +137,7 @@ func TestHandleResponseBody(t *testing.T) {
141137
return
142138
}
143139

144-
if diff := cmp.Diff(test.want, reqCtx.Response); diff != "" {
140+
if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" {
145141
t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)
146142
}
147143
})

pkg/epp/handlers/server.go

+29-6
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
128128
reqCtx.ResponseCompleteTimestamp = time.Now()
129129
metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
130130
metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
131-
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Response.Usage.PromptTokens)
132-
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Response.Usage.CompletionTokens)
131+
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
132+
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
133133
}
134-
if reqCtx.Streaming {
134+
if reqCtx.modelServerStreaming {
135135
logger.V(logutil.DEBUG).Info("Request context after HandleResponseBody", "context", reqCtx)
136136
} else {
137137
loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
@@ -149,7 +149,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
149149
}
150150
}
151151

152-
if !reqCtx.Streaming {
152+
if !reqCtx.modelServerStreaming {
153153
loggerVerbose.Info("Response generated", "response", resp)
154154
} else {
155155
logger.V(logutil.DEBUG).Info("Response generated", "response", resp)
@@ -224,9 +224,32 @@ type RequestContext struct {
224224
RequestReceivedTimestamp time.Time
225225
ResponseCompleteTimestamp time.Time
226226
RequestSize int
227-
Response Response
227+
Usage Usage
228228
ResponseSize int
229229
ResponseComplete bool
230230
ResponseStatusCode string
231-
Streaming bool
231+
232+
RequestState StreamRequestState
233+
modelServerStreaming bool
234+
235+
reqHeaderResp *extProcPb.ProcessingResponse
236+
reqBodyResp *extProcPb.ProcessingResponse
237+
reqTrailerResp *extProcPb.ProcessingResponse
238+
239+
respHeaderResp *extProcPb.ProcessingResponse
240+
respBodyResp *extProcPb.ProcessingResponse
241+
respTrailerResp *extProcPb.ProcessingResponse
232242
}
243+
244+
// StreamRequestState identifies a stage in the handling of a single request.
// RequestContext stores the current stage in its RequestState field.
type StreamRequestState int

// The states are numbered 0 through 7 in declaration order.
// NOTE(review): the names suggest the order request headers, body, trailers,
// then response headers, body, trailers; the code that advances the state is
// not visible here, so confirm against the Process loop.
const (
	// RequestReceived is the zero value of StreamRequestState, so it is the
	// state of a RequestContext whose RequestState was never assigned.
	RequestReceived StreamRequestState = 0

	HeaderRequestResponseComplete   StreamRequestState = 1
	BodyRequestResponsesComplete    StreamRequestState = 2
	TrailerRequestResponsesComplete StreamRequestState = 3

	// ResponseReceived is state 4, the first of the response-side states.
	ResponseReceived StreamRequestState = 4
	// ResponseRecieved is the original, misspelled name of ResponseReceived.
	// It is kept with the same value so existing callers keep compiling.
	//
	// Deprecated: use ResponseReceived.
	ResponseRecieved = ResponseReceived

	HeaderResponseResponseComplete   StreamRequestState = 5
	BodyResponseResponsesComplete    StreamRequestState = 6
	TrailerResponseResponsesComplete StreamRequestState = 7
)

0 commit comments

Comments
 (0)