Add /models/prune #47

Open · wants to merge 4 commits into main
2 changes: 1 addition & 1 deletion go.mod
@@ -5,7 +5,7 @@ go 1.23.7
require (
github.com/containerd/containerd/v2 v2.0.4
github.com/containerd/platforms v1.0.0-rc.1
github.com/docker/model-distribution v0.0.0-20250512190053-b3792c042d57
github.com/docker/model-distribution v0.0.0-20250521125643-a9b8592eff18
github.com/jaypipes/ghw v0.16.0
github.com/opencontainers/go-digest v1.0.0
github.com/opencontainers/image-spec v1.1.1
6 changes: 6 additions & 0 deletions go.sum
@@ -39,6 +39,12 @@ github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZ
github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M=
github.com/docker/model-distribution v0.0.0-20250512190053-b3792c042d57 h1:ZqfKknb+0/uJid8XLFwSl/osjE+WuS6o6I3dh3ZqO4U=
github.com/docker/model-distribution v0.0.0-20250512190053-b3792c042d57/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
github.com/docker/model-distribution v0.0.0-20250521121637-af0fc7f16ad1 h1:akgUvCRqic2fLyq5zhWF1I6xunWXHwKaQBZYRStpqf0=
github.com/docker/model-distribution v0.0.0-20250521121637-af0fc7f16ad1/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
github.com/docker/model-distribution v0.0.0-20250521123835-b72b1c87354a h1:VMwswLhzJVPhovYLlYEyzVYMjEqWDewcyKoKA8q89PY=
github.com/docker/model-distribution v0.0.0-20250521123835-b72b1c87354a/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
github.com/docker/model-distribution v0.0.0-20250521125643-a9b8592eff18 h1:tB4cBxmfR35osqXeKrqUbxBMARCUv+YRKJfqyrb9Qg0=
github.com/docker/model-distribution v0.0.0-20250521125643-a9b8592eff18/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
7 changes: 6 additions & 1 deletion main.go
@@ -55,7 +55,12 @@ func main() {
modelManager,
log.WithFields(logrus.Fields{"component": "llama.cpp"}),
llamaServerPath,
func() string { wd, _ := os.Getwd(); return wd }(),
func() string {
// Store updated llama.cpp server binaries in a dedicated
// "updated-inference" directory under the working directory,
// creating it if necessary.
wd, _ := os.Getwd()
d := filepath.Join(wd, "updated-inference")
_ = os.MkdirAll(d, 0o755)
return d
}(),
)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
24 changes: 24 additions & 0 deletions pkg/diskusage/diskusage.go
@@ -0,0 +1,24 @@
// Package diskusage provides a helper for measuring the on-disk size of a directory tree.
package diskusage

import (
"io/fs"
"path/filepath"
)

// Size returns the total size, in bytes, of all regular files under path.
func Size(path string) (float64, error) {
var size int64
err := filepath.WalkDir(path, func(_ string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.Type().IsRegular() {
info, err := d.Info()
if err != nil {
return err
}
size += info.Size()
}
return nil
})
return float64(size), err
}
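For illustration, a minimal sketch of calling this helper; the measured path is hypothetical. Note that Size only counts regular files, skipping directories and symlinks.

```go
package main

import (
	"fmt"
	"log"

	"github.com/docker/model-runner/pkg/diskusage"
)

func main() {
	// Hypothetical directory; substitute the store path you want to measure.
	size, err := diskusage.Size("/var/lib/model-runner/models")
	if err != nil {
		log.Fatalf("measuring disk usage: %v", err)
	}
	fmt.Printf("total size: %.0f bytes\n", size)
}
```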
2 changes: 2 additions & 0 deletions pkg/inference/backend.go
@@ -69,4 +69,6 @@ type Backend interface {
Run(ctx context.Context, socket, model string, mode BackendMode) error
// Status returns a description of the backend's state.
Status() string
// GetDiskUsage returns the disk usage of the backend.
GetDiskUsage() (float64, error)
}
9 changes: 9 additions & 0 deletions pkg/inference/backends/llamacpp/llamacpp.go
@@ -12,6 +12,7 @@ import (
"runtime"
"strconv"

"github.com/docker/model-runner/pkg/diskusage"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/logging"
@@ -199,3 +200,11 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
func (l *llamaCpp) Status() string {
return l.status
}

func (l *llamaCpp) GetDiskUsage() (float64, error) {
size, err := diskusage.Size(l.updatedServerStoragePath)
if err != nil {
return 0, fmt.Errorf("error while getting store size: %v", err)
}
return size, nil
}
4 changes: 4 additions & 0 deletions pkg/inference/backends/mlx/mlx.go
@@ -58,3 +58,7 @@ func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.Back
func (m *mlx) Status() string {
return "not running"
}

// GetDiskUsage reports no disk usage; the mlx backend is currently a stub.
func (m *mlx) GetDiskUsage() (float64, error) {
return 0, nil
}
4 changes: 4 additions & 0 deletions pkg/inference/backends/vllm/vllm.go
@@ -58,3 +58,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.Bac
func (v *vLLM) Status() string {
return "not running"
}

// GetDiskUsage reports no disk usage; the vLLM backend is currently a stub.
func (v *vLLM) GetDiskUsage() (float64, error) {
return 0, nil
}
31 changes: 31 additions & 0 deletions pkg/inference/models/manager.go
@@ -14,6 +14,7 @@ import (
"github.com/docker/model-distribution/distribution"
"github.com/docker/model-distribution/registry"
"github.com/docker/model-distribution/types"
"github.com/docker/model-runner/pkg/diskusage"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/logging"
"github.com/sirupsen/logrus"
@@ -97,6 +98,7 @@ func (m *Manager) routeHandlers() map[string]http.HandlerFunc {
"GET " + inference.ModelsPrefix + "/{name...}": m.handleGetModel,
"DELETE " + inference.ModelsPrefix + "/{name...}": m.handleDeleteModel,
"POST " + inference.ModelsPrefix + "/{nameAndAction...}": m.handleModelAction,
"DELETE " + inference.ModelsPrefix + "/prune": m.handlePrune,
"GET " + inference.InferencePrefix + "/{backend}/v1/models": m.handleOpenAIGetModels,
"GET " + inference.InferencePrefix + "/{backend}/v1/models/{name...}": m.handleOpenAIGetModel,
"GET " + inference.InferencePrefix + "/v1/models": m.handleOpenAIGetModels,
@@ -399,6 +401,35 @@ func (m *Manager) handlePushModel(w http.ResponseWriter, r *http.Request, model
}
}

// handlePrune handles DELETE <inference-prefix>/models/prune requests.
func (m *Manager) handlePrune(w http.ResponseWriter, _ *http.Request) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
return
}

if err := m.distributionClient.ResetStore(); err != nil {
m.log.Warnf("Failed to prune models: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

// GetDiskUsage returns the disk usage of the model store.
func (m *Manager) GetDiskUsage() (float64, error, int) {
if m.distributionClient == nil {
return 0, errors.New("model distribution service unavailable"), http.StatusServiceUnavailable
}

storePath := m.distributionClient.GetStorePath()
size, err := diskusage.Size(storePath)
if err != nil {
return 0, fmt.Errorf("error while getting store size: %v", err), http.StatusInternalServerError
}

return size, nil, http.StatusOK
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.router.ServeHTTP(w, r)
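For reference, a minimal client-side sketch of the new prune call. Per the handler comment the route is DELETE <inference-prefix>/models/prune; the base URL below is an assumption for illustration:

```go
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	// Assumed address and inference prefix; adjust for your deployment.
	url := "http://localhost:12434/engines/models/prune"
	req, err := http.NewRequest(http.MethodDelete, url, nil)
	if err != nil {
		log.Fatal(err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	// 200 means the store was reset; 503 means model distribution is unavailable.
	fmt.Println("prune:", resp.Status)
}
```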
31 changes: 31 additions & 0 deletions pkg/inference/scheduling/api.go
@@ -2,6 +2,7 @@ package scheduling

import (
"strings"
"time"

"github.com/docker/model-runner/pkg/inference"
)
@@ -42,3 +43,33 @@ type OpenAIInferenceRequest struct {
// Model is the requested model name.
Model string `json:"model"`
}

// BackendStatus represents information about a running backend.
type BackendStatus struct {
// BackendName is the name of the backend.
BackendName string `json:"backend_name"`
// ModelName is the name of the model loaded in the backend.
ModelName string `json:"model_name"`
// Mode is the mode the backend is operating in.
Mode string `json:"mode"`
// LastUsed represents when this (backend, model, mode) tuple was last used.
LastUsed time.Time `json:"last_used,omitempty"`
}

// DiskUsage represents the disk usage of the models and default backend.
type DiskUsage struct {
ModelsDiskUsage float64 `json:"models_disk_usage"`
DefaultBackendDiskUsage float64 `json:"default_backend_disk_usage"`
}

// UnloadRequest is used to specify which models to unload.
type UnloadRequest struct {
All bool `json:"all"`
Backend string `json:"backend"`
Model string `json:"model"`
}

// UnloadResponse is used to return the number of unloaded runners (backend, model).
type UnloadResponse struct {
UnloadedRunners int `json:"unloaded_runners"`
}
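To make the new payload concrete, a short sketch that marshals an UnloadRequest; the struct mirrors the one above and the model name is hypothetical:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// UnloadRequest mirrors scheduling.UnloadRequest from this diff.
type UnloadRequest struct {
	All     bool   `json:"all"`
	Backend string `json:"backend"`
	Model   string `json:"model"`
}

func main() {
	// Unload a single (hypothetical) model from the llama.cpp backend.
	body, err := json.Marshal(UnloadRequest{Backend: "llama.cpp", Model: "ai/example"})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
	// Output: {"all":false,"backend":"llama.cpp","model":"ai/example"}
}
```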
36 changes: 36 additions & 0 deletions pkg/inference/scheduling/loader.go
@@ -177,6 +177,42 @@ func (l *loader) evict(idleOnly bool) int {
return len(l.runners)
}

// evictRunner evicts a specific runner. The caller must hold the loader lock.
// It returns the number of remaining runners.
func (l *loader) evictRunner(backend, model string) int {
allBackends := backend == ""
for r, slot := range l.runners {
if (allBackends || r.backend == backend) && r.model == model {
l.log.Infof("Evicting %s backend runner with model %s in %s mode",
r.backend, r.model, r.mode,
)
l.slots[slot].terminate()
l.slots[slot] = nil
l.availableMemory += l.allocations[slot]
l.allocations[slot] = 0
l.timestamps[slot] = time.Time{}
delete(l.runners, r)
}
}
return len(l.runners)
}

// Unload unloads runners and returns the number of unloaded runners.
func (l *loader) Unload(ctx context.Context, unload UnloadRequest) int {
if !l.lock(ctx) {
return 0
}
defer l.unlock()

before := len(l.runners)
var remaining int
if unload.All {
remaining = l.evict(false)
} else {
remaining = l.evictRunner(unload.Backend, unload.Model)
}
return before - remaining
}

// stopAndDrainTimer stops and drains a timer without knowing if it was running.
func stopAndDrainTimer(timer *time.Timer) {
timer.Stop()
93 changes: 93 additions & 0 deletions pkg/inference/scheduling/scheduler.go
@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"time"

"github.com/docker/model-distribution/distribution"
"github.com/docker/model-runner/pkg/inference"
@@ -81,6 +82,9 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
m[route] = s.handleOpenAIInference
}
m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
m["GET "+inference.InferencePrefix+"/df"] = s.GetDiskUsage
m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
return m
}

@@ -224,6 +228,95 @@ func (s *Scheduler) ResetInstaller(httpClient *http.Client) {
s.installer = newInstaller(s.log, s.backends, httpClient)
}

// GetRunningBackends returns information about all running backends.
func (s *Scheduler) GetRunningBackends(w http.ResponseWriter, r *http.Request) {
runningBackends := s.getLoaderStatus()

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(runningBackends); err != nil {
http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}

// getLoaderStatus returns information about all running backends managed by the loader.
func (s *Scheduler) getLoaderStatus() []BackendStatus {
if !s.loader.lock(context.Background()) {
return []BackendStatus{}
}
defer s.loader.unlock()

result := make([]BackendStatus, 0, len(s.loader.runners))

for key, slot := range s.loader.runners {
if s.loader.slots[slot] != nil {
status := BackendStatus{
BackendName: key.backend,
ModelName: key.model,
Mode: key.mode.String(),
LastUsed: time.Time{},
}

if s.loader.references[slot] == 0 {
status.LastUsed = s.loader.timestamps[slot]
}

result = append(result, status)
}
}

return result
}

// GetDiskUsage returns the disk usage of the model store and the default backend.
func (s *Scheduler) GetDiskUsage(w http.ResponseWriter, _ *http.Request) {
modelsDiskUsage, err, httpCode := s.modelManager.GetDiskUsage()
if err != nil {
http.Error(w, fmt.Sprintf("Failed to get models disk usage: %v", err), httpCode)
return
}

// TODO: Get disk usage for each backend once the backends are implemented.
defaultBackendDiskUsage, err := s.defaultBackend.GetDiskUsage()
if err != nil {
http.Error(w, fmt.Sprintf("Failed to get disk usage for %s: %v", s.defaultBackend.Name(), err), http.StatusInternalServerError)
return
}

diskUsage := DiskUsage{modelsDiskUsage, defaultBackendDiskUsage}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(diskUsage); err != nil {
http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}

// Unload unloads the specified runners (backend, model) from the backend.
// Currently, this doesn't work for runners that are handling an OpenAI request.
func (s *Scheduler) Unload(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
if err != nil {
if _, ok := err.(*http.MaxBytesError); ok {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "unknown error", http.StatusInternalServerError)
}
return
}

var unloadRequest UnloadRequest
if err := json.Unmarshal(body, &unloadRequest); err != nil {
http.Error(w, "invalid request", http.StatusBadRequest)
return
}

unloadedRunners := UnloadResponse{s.loader.Unload(r.Context(), unloadRequest)}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(unloadedRunners); err != nil {
http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
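And a matching sketch for the new disk-usage route; the address and inference prefix are again assumptions for illustration:

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// DiskUsage mirrors scheduling.DiskUsage from this diff.
type DiskUsage struct {
	ModelsDiskUsage         float64 `json:"models_disk_usage"`
	DefaultBackendDiskUsage float64 `json:"default_backend_disk_usage"`
}

func main() {
	// Assumed address and inference prefix; adjust for your deployment.
	resp, err := http.Get("http://localhost:12434/engines/df")
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var du DiskUsage
	if err := json.NewDecoder(resp.Body).Decode(&du); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("models: %.0f bytes, default backend: %.0f bytes\n",
		du.ModelsDiskUsage, du.DefaultBackendDiskUsage)
}
```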