Skip to content

Improve CORS config #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -58,6 +58,7 @@ docker-run: docker-build
-e LLAMA_SERVER_PATH=/app/bin \
-e MODELS_PATH=/models \
-e LLAMA_ARGS="$(LLAMA_ARGS)" \
-e DMR_ORIGINS="$(DMR_ORIGINS)" \
$(DOCKER_IMAGE)

# Show help
13 changes: 9 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
@@ -47,10 +47,14 @@ func main() {
llamacpp.ShouldUpdateServerLock.Unlock()
}

modelManager := models.NewManager(log, models.ClientConfig{
StoreRootPath: modelPath,
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
})
modelManager := models.NewManager(
log,
models.ClientConfig{
StoreRootPath: modelPath,
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
},
nil,
)

llamaServerPath := os.Getenv("LLAMA_SERVER_PATH")
if llamaServerPath == "" {
@@ -85,6 +89,7 @@ func main() {
llamaCppBackend,
modelManager,
http.DefaultClient,
nil,
)

router := routing.NewNormalizedServeMux()
51 changes: 47 additions & 4 deletions pkg/inference/cors.go
Original file line number Diff line number Diff line change
@@ -2,13 +2,30 @@ package inference

import (
"net/http"
"os"
"strings"
)

// CorsMiddleware handles CORS and OPTIONS preflight requests and sets the necessary CORS headers.
func CorsMiddleware(next http.Handler) http.Handler {
// CorsMiddleware handles CORS and OPTIONS preflight requests with optional allowedOrigins.
// If allowedOrigins is nil or empty, it falls back to getAllowedOrigins().
func CorsMiddleware(allowedOrigins []string, next http.Handler) http.Handler {
if len(allowedOrigins) == 0 {
allowedOrigins = getAllowedOrigins()
}

// Explicitly disable all origins.
if allowedOrigins == nil {
return next
}

allowAll := len(allowedOrigins) == 1 && allowedOrigins[0] == "*"
allowedSet := make(map[string]struct{}, len(allowedOrigins))
for _, o := range allowedOrigins {
allowedSet[o] = struct{}{}
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers for all requests.
if origin := r.Header.Get("Origin"); origin != "" {
if origin := r.Header.Get("Origin"); origin != "" && (allowAll || originAllowed(origin, allowedSet)) {
w.Header().Set("Access-Control-Allow-Origin", origin)
}

@@ -24,3 +41,29 @@ func CorsMiddleware(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}

// originAllowed reports whether origin appears as a key in allowedSet.
// The comparison is an exact string match; a nil or empty set allows nothing.
func originAllowed(origin string, allowedSet map[string]struct{}) bool {
	if _, found := allowedSet[origin]; found {
		return true
	}
	return false
}

// getAllowedOrigins parses the DMR_ORIGINS environment variable as a
// comma-separated list of origins. Each entry is trimmed of surrounding
// whitespace and blank entries are dropped. It returns nil when the variable
// is empty or unset, or when no non-blank entry remains.
func getAllowedOrigins() (origins []string) {
	raw := os.Getenv("DMR_ORIGINS")
	if raw == "" {
		return nil
	}

	for _, entry := range strings.Split(raw, ",") {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		origins = append(origins, entry)
	}

	// origins is only allocated by append, so it is still nil here when
	// every entry was blank.
	return origins
}
10 changes: 5 additions & 5 deletions pkg/inference/models/manager.go
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ type ClientConfig struct {
}

// NewManager creates a new model manager.
func NewManager(log logging.Logger, c ClientConfig) *Manager {
func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Manager {
// Create the model distribution client.
distributionClient, err := distribution.NewClient(
distribution.WithStoreRootPath(c.StoreRootPath),
@@ -78,7 +78,7 @@ func NewManager(log logging.Logger, c ClientConfig) *Manager {
http.Error(w, "not found", http.StatusNotFound)
})

for route, handler := range m.routeHandlers() {
for route, handler := range m.routeHandlers(allowedOrigins) {
m.router.HandleFunc(route, handler)
}

@@ -91,7 +91,7 @@ func NewManager(log logging.Logger, c ClientConfig) *Manager {
return m
}

func (m *Manager) routeHandlers() map[string]http.HandlerFunc {
func (m *Manager) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc {
handlers := map[string]http.HandlerFunc{
"POST " + inference.ModelsPrefix + "/create": m.handleCreateModel,
"GET " + inference.ModelsPrefix: m.handleGetModels,
@@ -105,14 +105,14 @@ func (m *Manager) routeHandlers() map[string]http.HandlerFunc {
}
for route, handler := range handlers {
if strings.HasPrefix(route, "GET ") {
handlers[route] = inference.CorsMiddleware(handler).ServeHTTP
handlers[route] = inference.CorsMiddleware(allowedOrigins, handler).ServeHTTP
}
}
return handlers
}

func (m *Manager) GetRoutes() []string {
routeHandlers := m.routeHandlers()
routeHandlers := m.routeHandlers(nil)
routes := make([]string, 0, len(routeHandlers))
for route := range routeHandlers {
routes = append(routes, route)
11 changes: 6 additions & 5 deletions pkg/inference/scheduling/scheduler.go
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@ func NewScheduler(
defaultBackend inference.Backend,
modelManager *models.Manager,
httpClient *http.Client,
allowedOrigins []string,
) *Scheduler {
// Create the scheduler.
s := &Scheduler{
@@ -61,15 +62,15 @@ func NewScheduler(
http.Error(w, "not found", http.StatusNotFound)
})

for route, handler := range s.routeHandlers() {
for route, handler := range s.routeHandlers(allowedOrigins) {
s.router.HandleFunc(route, handler)
}

// Scheduler successfully initialized.
return s
}

func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc {
openAIRoutes := []string{
"POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions",
"POST " + inference.InferencePrefix + "/{backend}/v1/completions",
@@ -80,10 +81,10 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
}
m := make(map[string]http.HandlerFunc)
for _, route := range openAIRoutes {
m[route] = inference.CorsMiddleware(http.HandlerFunc(s.handleOpenAIInference)).ServeHTTP
m[route] = inference.CorsMiddleware(allowedOrigins, http.HandlerFunc(s.handleOpenAIInference)).ServeHTTP
// Register OPTIONS for CORS preflight.
optionsRoute := "OPTIONS " + route[strings.Index(route, " "):]
m[optionsRoute] = inference.CorsMiddleware(http.HandlerFunc(s.handleOpenAIInference)).ServeHTTP
m[optionsRoute] = inference.CorsMiddleware(allowedOrigins, http.HandlerFunc(s.handleOpenAIInference)).ServeHTTP
}
m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
@@ -93,7 +94,7 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
}

func (s *Scheduler) GetRoutes() []string {
routeHandlers := s.routeHandlers()
routeHandlers := s.routeHandlers(nil)
routes := make([]string, 0, len(routeHandlers))
for route := range routeHandlers {
routes = append(routes, route)