diff --git a/Makefile b/Makefile
index 639c369..065de7b 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,7 @@
 ENVTEST_K8S_VERSION = 1.27.1
 
 .PHONY: test
-test: test-unit test-race test-integration test-e2e
+test: test-unit test-integration test-e2e
 
 .PHONY: test-unit
 test-unit:
diff --git a/cmd/lingo/main.go b/cmd/lingo/main.go
index 2f9ae6f..113328a 100644
--- a/cmd/lingo/main.go
+++ b/cmd/lingo/main.go
@@ -67,6 +67,7 @@ func run() error {
 
 	concurrency := getEnvInt("CONCURRENCY", 100)
 	scaleDownDelay := getEnvInt("SCALE_DOWN_DELAY", 30)
+	backendRetries := getEnvInt("BACKEND_RETRIES", 1)
 
 	var metricsAddr string
 	var probeAddr string
@@ -154,6 +155,7 @@ func run() error {
 	proxy.MustRegister(metricsRegistry)
 
 	proxyHandler := proxy.NewHandler(deploymentManager, endpointManager, queueManager)
+	proxyHandler.MaxRetries = backendRetries
 	proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler}
 
 	statsHandler := &stats.Handler{
diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go
index a51d9d3..2c9d74d 100644
--- a/pkg/proxy/handler.go
+++ b/pkg/proxy/handler.go
@@ -1,153 +1,140 @@
 package proxy
 
 import (
-	"bytes"
 	"context"
-	"encoding/json"
 	"errors"
-	"fmt"
-	"io"
 	"log"
 	"net/http"
 	"net/http/httputil"
 	"net/url"
-	"strconv"
 
-	"github.com/google/uuid"
 	"github.com/prometheus/client_golang/prometheus"
-
-	"github.com/substratusai/lingo/pkg/deployments"
-	"github.com/substratusai/lingo/pkg/endpoints"
-	"github.com/substratusai/lingo/pkg/queue"
 )
 
+var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+	Name: "http_response_time_seconds",
+	Help: "Duration of HTTP requests.",
+	Buckets: prometheus.DefBuckets,
+}, []string{"model", "status_code"})
+
+func MustRegister(r prometheus.Registerer) {
+	r.MustRegister(httpDuration)
+}
+
+type DeploymentManager interface {
+	ResolveDeployment(model string) (string, bool)
+	AtLeastOne(model string)
+}
+
+type EndpointManager interface {
+	AwaitHostAddress(ctx context.Context, service, portName string) (string, error)
+}
+
+type QueueManager interface {
+	EnqueueAndWait(ctx context.Context, deploymentName, id string) func()
+}
+
 // Handler serves http requests for end-clients.
 // It is also responsible for triggering scale-from-zero.
 type Handler struct {
-	Deployments *deployments.Manager
-	Endpoints *endpoints.Manager
-	Queues *queue.Manager
+	Deployments DeploymentManager
+	Endpoints EndpointManager
+	Queues QueueManager
+
+	MaxRetries int
+	RetryCodes map[int]struct{}
 }
 
-func NewHandler(deployments *deployments.Manager, endpoints *endpoints.Manager, queues *queue.Manager) *Handler {
-	return &Handler{Deployments: deployments, Endpoints: endpoints, Queues: queues}
+func NewHandler(
+	deployments DeploymentManager,
+	endpoints EndpointManager,
+	queues QueueManager,
+) *Handler {
+	return &Handler{
+		Deployments: deployments,
+		Endpoints: endpoints,
+		Queues: queues,
+	}
+}
+
+var defaultRetryCodes = map[int]struct{}{
+	http.StatusInternalServerError: {},
+	http.StatusBadGateway: {},
+	http.StatusServiceUnavailable: {},
+	http.StatusGatewayTimeout: {},
 }
 
 func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	var modelName string
-	captureStatusRespWriter := newCaptureStatusCodeResponseWriter(w)
-	w = captureStatusRespWriter
-	timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
-		httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.statusCode)).Observe(v)
-	}))
-	defer timer.ObserveDuration()
-
-	id := uuid.New().String()
-	log.Printf("request: %v", r.URL)
+	log.Printf("url: %v", r.URL)
+	w.Header().Set("X-Proxy", "lingo")
 
-	var (
-		proxyRequest *http.Request
-		err error
-	)
+	pr := newProxyRequest(r)
+	defer pr.done()
 
+	// TODO: Only parse model for paths that would have a model.
-	modelName, proxyRequest, err = parseModel(r)
-	if err != nil || modelName == "" {
-		modelName = "unknown"
-		log.Printf("error reading model from request body: %v", err)
-		w.WriteHeader(http.StatusBadRequest)
-		w.Write([]byte("Bad request: unable to parse .model from JSON payload"))
+	if err := pr.parseModel(); err != nil {
+		pr.sendErrorResponse(w, http.StatusBadRequest, "unable to parse model: %v", err)
 		return
 	}
 
-	log.Println("model:", modelName)
-	deploy, found := h.Deployments.ResolveDeployment(modelName)
-	if !found {
-		log.Printf("deployment not found for model: %v", err)
-		w.WriteHeader(http.StatusNotFound)
-		w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName)))
+	log.Println("model:", pr.model)
+
+	var backendExists bool
+	pr.backendDeployment, backendExists = h.Deployments.ResolveDeployment(pr.model)
+	if !backendExists {
+		pr.sendErrorResponse(w, http.StatusNotFound, "model not found: %v", pr.model)
 		return
 	}
 
-	h.Deployments.AtLeastOne(deploy)
+	// Ensure the backend is scaled to at least one Pod.
+	h.Deployments.AtLeastOne(pr.backendDeployment)
 
-	log.Println("Entering queue", id)
-	complete := h.Queues.EnqueueAndWait(r.Context(), deploy, id)
-	log.Println("Admitted into queue", id)
+	log.Printf("Entering queue: %v", pr.id)
+
+	// Wait until the request is admitted into the queue before proceeding with
+	// serving the request.
+	complete := h.Queues.EnqueueAndWait(r.Context(), pr.backendDeployment, pr.id)
 	defer complete()
 
-	// abort when deployment was removed meanwhile
-	if _, exists := h.Deployments.ResolveDeployment(modelName); !exists {
-		log.Printf("deployment not active for model removed: %v", err)
-		w.WriteHeader(http.StatusNotFound)
-		_, _ = w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName)))
+	log.Printf("Admitted into queue: %v", pr.id)
+
+	// After waiting for the request to be admitted, double-check that the model
+	// still exists. It's possible that the model was deleted while waiting.
+	// This would lead to a long subsequent wait with the host lookup.
+	pr.backendDeployment, backendExists = h.Deployments.ResolveDeployment(pr.model)
+	if !backendExists {
+		pr.sendErrorResponse(w, http.StatusNotFound, "model not found after being dequeued: %v", pr.model)
 		return
 	}
 
-	log.Println("Waiting for IPs", id)
-	host, err := h.Endpoints.AwaitHostAddress(r.Context(), deploy, "http")
+	h.proxyHTTP(w, pr)
+}
+
+// AdditionalProxyRewrite is an injection point for modifying proxy requests.
+// Used in tests.
+var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {}
+
+func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) {
+	log.Printf("Waiting for host: %v", pr.id)
+
+	host, err := h.Endpoints.AwaitHostAddress(pr.r.Context(), pr.backendDeployment, "http")
 	if err != nil {
-		log.Printf("error while finding the host address %v", err)
 		switch {
 		case errors.Is(err, context.Canceled):
-			w.WriteHeader(http.StatusInternalServerError)
-			_, _ = w.Write([]byte("Request cancelled"))
+			pr.sendErrorResponse(w, http.StatusInternalServerError, "request cancelled while finding host: %v", err)
 			return
 		case errors.Is(err, context.DeadlineExceeded):
-			w.WriteHeader(http.StatusGatewayTimeout)
-			_, _ = w.Write([]byte(fmt.Sprintf("Request timed out for model: %v", modelName)))
+			pr.sendErrorResponse(w, http.StatusGatewayTimeout, "request timeout while finding host: %v", err)
 			return
 		default:
-			w.WriteHeader(http.StatusInternalServerError)
-			_, _ = w.Write([]byte("Internal server error"))
+			pr.sendErrorResponse(w, http.StatusGatewayTimeout, "unable to find host: %v", err)
 			return
 		}
 	}
 
-	log.Printf("Got host: %v, id: %v\n", host, id)
-
-	// TODO: Avoid creating new reverse proxies for each request.
-	// TODO: Consider implementing a round robin scheme.
-	log.Printf("Proxying request to host %v: %v\n", host, id)
-	newReverseProxy(host).ServeHTTP(w, proxyRequest)
-}
-
-// parseModel parses the model name from the request
-// returns empty string when none found or an error for failures on the proxy request object
-func parseModel(r *http.Request) (string, *http.Request, error) {
-	if model := r.Header.Get("X-Model"); model != "" {
-		return model, r, nil
-	}
-	// parse request body for model name, ignore errors
-	body, err := io.ReadAll(r.Body)
-	if err != nil {
-		return "", r, nil
-	}
-	var payload struct {
-		Model string `json:"model"`
-	}
-	var model string
-	if err := json.Unmarshal(body, &payload); err == nil {
-		model = payload.Model
-	}
+	log.Printf("Got host: %v, id: %v\n", host, pr.id)
 
-	// create new request object
-	proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), bytes.NewReader(body))
-	if err != nil {
-		return "", nil, fmt.Errorf("create proxy request: %w", err)
-	}
-	proxyReq.Header = r.Header
-	if err := proxyReq.ParseForm(); err != nil {
-		return "", nil, fmt.Errorf("parse proxy form: %w", err)
-	}
-	return model, proxyReq, nil
-}
-
-// AdditionalProxyRewrite is an injection point for modifying proxy requests.
-// Used in tests.
-var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {}
-
-func newReverseProxy(host string) *httputil.ReverseProxy {
 	proxy := &httputil.ReverseProxy{
 		Rewrite: func(r *httputil.ProxyRequest) {
 			r.SetURL(&url.URL{
@@ -158,5 +145,50 @@ func newReverseProxy(host string) *httputil.ReverseProxy {
 			AdditionalProxyRewrite(r)
 		},
 	}
-	return proxy
+
+	proxy.ModifyResponse = func(r *http.Response) error {
+		// Record the response for metrics.
+		pr.status = r.StatusCode
+
+		// This point is reached if a response code is received.
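+		// ErrRetry is a sentinel: it distinguishes a retryable backend status from
+		// a transport error when the ErrorHandler below decides how to respond.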
+		if h.isRetryCode(r.StatusCode) && pr.attempt < h.MaxRetries {
+			// Returning an error will trigger the ErrorHandler.
+			return ErrRetry
+		}
+
+		return nil
+	}
+
+	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
+		// This point could be reached if a bad response code was sent by the backend,
+		// or if there was an issue with the connection and no response was ever received.
+		if err != nil && r.Context().Err() == nil && pr.attempt < h.MaxRetries {
+			pr.attempt++
+
+			log.Printf("Retrying request (%v/%v): %v", pr.attempt, h.MaxRetries, pr.id)
+			h.proxyHTTP(w, pr)
+			return
+		}
+
+		if !errors.Is(err, ErrRetry) {
+			pr.sendErrorResponse(w, http.StatusBadGateway, "proxy: exceeded retries: %v/%v", pr.attempt, h.MaxRetries)
+		}
+	}
+
+	log.Printf("Proxying request to host %v: %v\n", host, pr.id)
+	proxy.ServeHTTP(w, pr.httpRequest())
+}
+
+var ErrRetry = errors.New("retry")
+
+func (h *Handler) isRetryCode(status int) bool {
+	var retry bool
+	// TODO: avoid the nil check here and set a default map in the constructor.
+	if h.RetryCodes != nil {
+		_, retry = h.RetryCodes[status]
+	} else {
+		_, retry = defaultRetryCodes[status]
+	}
+	return retry
 }
diff --git a/pkg/proxy/handler_test.go b/pkg/proxy/handler_test.go
new file mode 100644
index 0000000..4d15d77
--- /dev/null
+++ b/pkg/proxy/handler_test.go
@@ -0,0 +1,248 @@
+package proxy
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/testutil"
+	io_prometheus_client "github.com/prometheus/client_model/go"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestHandler(t *testing.T) {
+	const (
+		model1 = "model1"
+		model2 = "model2"
+
+		maxRetries = 3
+	)
+	models := map[string]string{
+		model1: "deploy1",
+		model2: "deploy2",
+	}
+
+	specs := map[string]struct {
+		reqBody string
+		reqHeaders map[string]string
+
+		backendPanic bool
+		backendCode int
+		backendBody string
+
+		expCode int
+		expBody string
+		expLabels map[string]string
+		expBackendRequestCount int
+	}{
+		"no model": {
+			reqBody: "{}",
+			expCode: http.StatusBadRequest,
+			expBody: `{"error":"unable to parse model: no model specified"}` + "\n",
+			expLabels: map[string]string{
+				"model": "",
+				"status_code": "400",
+			},
+			expBackendRequestCount: 0,
+		},
+		"model not found": {
+			reqBody: `{"model":"does-not-exist"}`,
+			expCode: http.StatusNotFound,
+			expBody: `{"error":"model not found: does-not-exist"}` + "\n",
+			expLabels: map[string]string{
+				"model": "does-not-exist",
+				"status_code": "404",
+			},
+			expBackendRequestCount: 0,
+		},
+		"happy 200 model in body": {
+			reqBody: fmt.Sprintf(`{"model":%q}`, model1),
+			backendCode: http.StatusOK,
+			backendBody: `{"result":"ok"}`,
+			expCode: http.StatusOK,
+			expBody: `{"result":"ok"}`,
+			expLabels: map[string]string{
+				"model": model1,
+				"status_code": "200",
+			},
+			expBackendRequestCount: 1,
+		},
+		"happy 200 model in header": {
+			reqBody: "{}",
+			reqHeaders: map[string]string{"X-Model": model1},
+			backendCode: http.StatusOK,
+			backendBody: `{"result":"ok"}`,
+			expCode: http.StatusOK,
+			expBody: `{"result":"ok"}`,
+			expLabels: map[string]string{
+				"model": model1,
+				"status_code": "200",
+			},
+			expBackendRequestCount: 1,
+		},
+		"retryable 500": {
+			reqBody: fmt.Sprintf(`{"model":%q}`, model1),
+			backendCode: http.StatusInternalServerError,
+			backendBody: `{"err":"oh no!"}`,
+			expCode: http.StatusInternalServerError,
+			expBody: `{"err":"oh no!"}`,
+			expLabels: map[string]string{
+				"model": model1,
+				"status_code": "500",
+			},
+			expBackendRequestCount: 1 + maxRetries,
+		},
+		"not retryable 400": {
+			reqBody: fmt.Sprintf(`{"model":%q}`, model1),
+			backendCode: http.StatusBadRequest,
+			backendBody: `{"err":"bad request"}`,
+			expCode: http.StatusBadRequest,
+			expBody: `{"err":"bad request"}`,
+			expLabels: map[string]string{
+				"model": model1,
+				"status_code": "400",
+			},
+			expBackendRequestCount: 1,
+		},
+		"good request but dropped connection": {
+			reqBody: fmt.Sprintf(`{"model":%q}`, model1),
+			backendPanic: true,
+			expCode: http.StatusBadGateway,
+			expBody: `{"error":"Bad Gateway"}` + "\n",
+			expLabels: map[string]string{
+				"model": model1,
+				"status_code": "502",
+			},
+			expBackendRequestCount: 1 + maxRetries,
+		},
+	}
+	for name, spec := range specs {
+		t.Run(name, func(t *testing.T) {
+			// Register metrics from a clean slate.
+			httpDuration.Reset()
+			metricsRegistry := prometheus.NewPedanticRegistry()
+			MustRegister(metricsRegistry)
+
+			// Mock backend.
+			var backendRequestCount int
+			backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				backendRequestCount++
+
+				bdy, err := io.ReadAll(r.Body)
+				assert.NoError(t, err)
+				assert.Equal(t, spec.reqBody, string(bdy), "The request body should reach the backend")
+
+				if spec.backendPanic {
+					// Panic should close connection.
+					// https://pkg.go.dev/net/http#Handler
+					panic("panicking on purpose")
+				}
+
+				if spec.backendCode != 0 {
+					w.WriteHeader(spec.backendCode)
+				}
+				if spec.backendBody != "" {
+					_, _ = w.Write([]byte(spec.backendBody))
+				}
+			}))
+
+			// Setup handler.
+			deploys := &testDeploymentManager{models: models}
+			endpoints := &testEndpointManager{address: backend.Listener.Addr().String()}
+			queues := &testQueueManager{}
+			h := NewHandler(deploys, endpoints, queues)
+			h.MaxRetries = maxRetries
+			server := httptest.NewServer(h)
+
+			// Issue request.
+			client := &http.Client{}
+			req, err := http.NewRequest(http.MethodPost, server.URL, strings.NewReader(spec.reqBody))
+			require.NoError(t, err)
+			for k, v := range spec.reqHeaders {
+				req.Header.Add(k, v)
+			}
+			resp, err := client.Do(req)
+			require.NoError(t, err, "The client request should not fail")
+			defer resp.Body.Close()
+			respBody, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+
+			// Assert on response.
+			assert.Equal(t, spec.expCode, resp.StatusCode, "Unexpected response code to client")
+			assert.Equal(t, spec.expBody, string(respBody), "Unexpected response body to client")
+			assert.Equal(t, spec.expBackendRequestCount, backendRequestCount, "Unexpected number of requests sent to backend")
+			assert.Equal(t, spec.expBackendRequestCount, endpoints.hostRequestCount, "Unexpected number of requests for backend hosts")
+
+			// Assert on metrics.
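+			// Exactly one histogram series should exist, labeled with the expected
+			// model and status code.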
+			gathered, err := metricsRegistry.Gather()
+			require.NoError(t, err)
+			require.Len(t, gathered, 1)
+			require.Len(t, gathered[0].Metric, 1)
+			assert.NotEmpty(t, gathered[0].Metric[0].GetHistogram().GetSampleCount())
+			assert.Equal(t, spec.expLabels, toMap(gathered[0].Metric[0].Label))
+		})
+	}
+}
+
+func TestMetricsViaLinter(t *testing.T) {
+	registry := prometheus.NewPedanticRegistry()
+	MustRegister(registry)
+
+	problems, err := testutil.GatherAndLint(registry)
+	require.NoError(t, err)
+	require.Empty(t, problems)
+}
+
+type testDeploymentManager struct {
+	models map[string]string
+}
+
+func (t *testDeploymentManager) ResolveDeployment(model string) (string, bool) {
+	deploy, ok := t.models[model]
+	return deploy, ok
+}
+
+func (t *testDeploymentManager) AtLeastOne(model string) {
+
+}
+
+type testEndpointManager struct {
+	address string
+
+	requestedService string
+	requestedPort string
+
+	hostRequestCount int
+}
+
+func (t *testEndpointManager) AwaitHostAddress(ctx context.Context, service, portName string) (string, error) {
+	t.hostRequestCount++
+	t.requestedService = service
+	t.requestedPort = portName
+	return t.address, nil
+}
+
+type testQueueManager struct {
+	requestedDeploymentName string
+	requestedID string
+}
+
+func (t *testQueueManager) EnqueueAndWait(ctx context.Context, deploymentName, id string) func() {
+	t.requestedDeploymentName = deploymentName
+	t.requestedID = id
+	return func() {}
+}
+
+func toMap(s []*io_prometheus_client.LabelPair) map[string]string {
+	r := make(map[string]string, len(s))
+	for _, v := range s {
+		r[v.GetName()] = v.GetValue()
+	}
+	return r
+}
diff --git a/pkg/proxy/metrics.go b/pkg/proxy/metrics.go
deleted file mode 100644
index 5cd229e..0000000
--- a/pkg/proxy/metrics.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package proxy
-
-import (
-	"net/http"
-
-	"github.com/prometheus/client_golang/prometheus"
-)
-
-var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
-	Name: "http_response_time_seconds",
-	Help: "Duration of HTTP requests.",
-	Buckets: prometheus.DefBuckets,
-}, []string{"model", "status_code"})
-
-func MustRegister(r prometheus.Registerer) {
-	r.MustRegister(httpDuration)
-}
-
-// captureStatusResponseWriter is a custom HTTP response writer that captures the status code.
-type captureStatusResponseWriter struct {
-	http.ResponseWriter
-	statusCode int
-}
-
-func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) *captureStatusResponseWriter {
-	return &captureStatusResponseWriter{ResponseWriter: responseWriter}
-}
-
-func (srw *captureStatusResponseWriter) WriteHeader(code int) {
-	srw.statusCode = code
-	srw.ResponseWriter.WriteHeader(code)
-}
diff --git a/pkg/proxy/metrics_test.go b/pkg/proxy/metrics_test.go
deleted file mode 100644
index 5940e44..0000000
--- a/pkg/proxy/metrics_test.go
+++ /dev/null
@@ -1,122 +0,0 @@
-package proxy
-
-import (
-	"net/http"
-	"net/http/httptest"
-	"strings"
-	"testing"
-
-	"github.com/go-logr/logr"
-	"github.com/prometheus/client_golang/prometheus"
-	"github.com/prometheus/client_golang/prometheus/testutil"
-	"github.com/prometheus/client_model/go"
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
-	"github.com/substratusai/lingo/pkg/deployments"
-	"k8s.io/apimachinery/pkg/runtime"
-	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
-	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
-
-	ctrl "sigs.k8s.io/controller-runtime"
-	"sigs.k8s.io/controller-runtime/pkg/cache"
-	"sigs.k8s.io/controller-runtime/pkg/client"
-	"sigs.k8s.io/controller-runtime/pkg/config"
-	"sigs.k8s.io/controller-runtime/pkg/manager"
-)
-
-func TestMetrics(t *testing.T) {
-	specs := map[string]struct {
-		request *http.Request
-		expCode int
-		expLabels map[string]string
-	}{
-		"with mode name": {
-			request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"model":"my_model"}`)),
-			expCode: http.StatusNotFound,
-			expLabels: map[string]string{
-				"model": "my_model",
-				"status_code": "404",
-			},
-		},
-		"unknown model name": {
-			request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader("{}")),
-			expCode: http.StatusBadRequest,
-			expLabels: map[string]string{
-				"model": "unknown",
-				"status_code": "400",
-			},
-		},
-	}
-	for name, spec := range specs {
-		t.Run(name, func(t *testing.T) {
-			httpDuration.Reset()
-			registry := prometheus.NewPedanticRegistry()
-			MustRegister(registry)
-
-			deplMgr, err := deployments.NewManager(&fakeManager{})
-			require.NoError(t, err)
-			h := NewHandler(deplMgr, nil, nil)
-			recorder := httptest.NewRecorder()
-
-			// when
-			h.ServeHTTP(recorder, spec.request)
-
-			// then
-			assert.Equal(t, spec.expCode, recorder.Code)
-			gathered, err := registry.Gather()
-			require.NoError(t, err)
-			require.Len(t, gathered, 1)
-			require.Len(t, gathered[0].Metric, 1)
-			assert.NotEmpty(t, gathered[0].Metric[0].GetHistogram().GetSampleCount())
-			assert.Equal(t, spec.expLabels, toMap(gathered[0].Metric[0].Label))
-		})
-	}
-}
-
-func TestMetricsViaLinter(t *testing.T) {
-	registry := prometheus.NewPedanticRegistry()
-	MustRegister(registry)
-
-	problems, err := testutil.GatherAndLint(registry)
-	require.NoError(t, err)
-	require.Empty(t, problems)
-}
-
-func toMap(s []*io_prometheus_client.LabelPair) map[string]string {
-	r := make(map[string]string, len(s))
-	for _, v := range s {
-		r[v.GetName()] = v.GetValue()
-	}
-	return r
-}
-
-// for test setup only
-type fakeManager struct {
-	ctrl.Manager
-}
-
-func (m *fakeManager) GetCache() cache.Cache {
-	return nil
-}
-
-func (m *fakeManager) GetScheme() *runtime.Scheme {
-	s := runtime.NewScheme()
-	utilruntime.Must(clientgoscheme.AddToScheme(s))
-	return s
-}
-
-func (m *fakeManager) Add(_ manager.Runnable) error {
-	return nil
-}
-
-func (m *fakeManager) GetLogger() logr.Logger {
-	return logr.Discard()
-}
-
-func (m *fakeManager) GetControllerOptions() config.Controller {
-	return config.Controller{}
-}
-
-func (m *fakeManager) GetClient() client.Client {
-	return nil
-}
diff --git a/pkg/proxy/request.go b/pkg/proxy/request.go
new file mode 100644
index 0000000..476f097
--- /dev/null
+++ b/pkg/proxy/request.go
@@ -0,0 +1,126 @@
+package proxy
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"strconv"
+
+	"github.com/google/uuid"
+	"github.com/prometheus/client_golang/prometheus"
+)
+
+// proxyRequest keeps track of the state of a request that is to be proxied.
+type proxyRequest struct {
+	// r is the original request. It is stored here so that it can be cloned
+	// and sent to the backend while preserving the original request body.
+	r *http.Request
+	// body will be stored here if the request body needed to be read
+	// in order to determine the model.
+	body []byte
+
+	// metadata:
+
+	id string
+	status int
+	model string
+	backendDeployment string
+	attempt int
+
+	// metrics:
+
+	timer *prometheus.Timer
+}
+
+func newProxyRequest(r *http.Request) *proxyRequest {
+	pr := &proxyRequest{
+		r: r,
+		id: uuid.New().String(),
+		status: http.StatusOK,
+	}
+
+	pr.timer = prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
+		httpDuration.WithLabelValues(pr.model, strconv.Itoa(pr.status)).Observe(v)
+	}))
+
+	return pr
+}
+
+// done should be called when the original client request is complete.
+func (pr *proxyRequest) done() {
+	pr.timer.ObserveDuration()
+}
+
+// parseModel attempts to determine the model from the request.
+// It first checks the "X-Model" header, and if that is not set, it
+// attempts to unmarshal the request body as JSON and extract the
+// .model field.
+func (pr *proxyRequest) parseModel() error {
+	pr.model = pr.r.Header.Get("X-Model")
+	if pr.model != "" {
+		return nil
+	}
+
+	var err error
+	pr.body, err = io.ReadAll(pr.r.Body)
+	if err != nil {
+		return fmt.Errorf("read: %w", err)
+	}
+
+	var payload struct {
+		Model string `json:"model"`
+	}
+	if err := json.Unmarshal(pr.body, &payload); err != nil {
+		return fmt.Errorf("unmarshal json: %w", err)
+	}
+	pr.model = payload.Model
+
+	if pr.model == "" {
+		return fmt.Errorf("no model specified")
+	}
+
+	return nil
+}
+
+// sendErrorResponse sends an error response to the client and
+// records the status code. If the status code is 5xx, the error
+// message is not included in the response body.
+func (pr *proxyRequest) sendErrorResponse(w http.ResponseWriter, status int, format string, args ...interface{}) {
+	msg := fmt.Sprintf(format, args...)
+	log.Printf("sending error response: %v: %v", status, msg)
+
+	pr.setStatus(w, status)
+
+	if status >= 500 {
+		// Don't leak internal error messages to the client.
+		msg = http.StatusText(status)
+	}
+
+	if err := json.NewEncoder(w).Encode(struct {
+		Error string `json:"error"`
+	}{
+		Error: msg,
+	}); err != nil {
+		log.Printf("error encoding error response: %v", err)
+	}
+}
+
+func (pr *proxyRequest) setStatus(w http.ResponseWriter, code int) {
+	pr.status = code
+	w.WriteHeader(code)
+}
+
+// httpRequest returns a new http.Request that is a clone of the original
+// request, preserving the original request body even if it was already
+// read (i.e. if the body was inspected to determine the model).
+func (pr *proxyRequest) httpRequest() *http.Request {
+	clone := pr.r.Clone(pr.r.Context())
+	if pr.body != nil {
+		clone.Body = io.NopCloser(bytes.NewReader(pr.body))
+	}
+	return clone
+}