Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions api/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ type CreateEndpointInput struct {
WorkersMin int `json:"workersMin"`
WorkersMax int `json:"workersMax"`
FlashBootType string `json:"flashBootType"`
ModelReferences []string `json:"modelReferences"`
}

// Endpoint is the subset of endpoint fields the CLI reads from the endpoint
// query result; the query returns many more fields that are ignored here.
type Endpoint struct {
	Name string `json:"name"`
	// Id is declared exactly once, with an explicit JSON tag matching the
	// lowercase "id" key used by the API.
	Id string `json:"id"`
}
type EndpointOut struct {
Data *EndpointData `json:"data"`
Expand Down Expand Up @@ -225,6 +226,53 @@ func UpdateEndpointTemplate(endpointId string, templateId string) (err error) {
return
}

// UpdateEndpointModel sets the model references of an endpoint by sending a
// saveEndpoint GraphQL mutation whose input carries the endpoint id, its name
// and the new modelReferences list.
//
// It returns an error when the request fails, the response status is not 200,
// the body is not valid JSON, the response contains GraphQL errors, or the
// response has no "data" object.
func UpdateEndpointModel(endpointId string, endpointName string, modelReferences []string) (err error) {
	input := Input{
		Query: `
		mutation saveEndpoint($input: EndpointInput!) {
			saveEndpoint(input: $input) {
				id
				modelReferences
			}
		}
		`,
		Variables: map[string]interface{}{"input": map[string]interface{}{
			"id":              endpointId,
			"name":            endpointName,
			"modelReferences": modelReferences,
		}},
	}
	res, err := Query(input)
	if err != nil {
		return err
	}
	defer res.Body.Close()

	rawData, err := io.ReadAll(res.Body)
	if err != nil {
		return err
	}
	if res.StatusCode != 200 {
		return fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData))
	}

	data := make(map[string]interface{})
	if err = json.Unmarshal(rawData, &data); err != nil {
		return err
	}

	if gqlErrors, ok := data["errors"].([]interface{}); ok && len(gqlErrors) > 0 {
		// Use checked assertions: an error entry that is not an object, or
		// whose "message" is missing or not a string, must not panic.
		if firstErr, ok := gqlErrors[0].(map[string]interface{}); ok {
			if msg, ok := firstErr["message"].(string); ok {
				return errors.New(msg)
			}
		}
		// Fall back to the raw body so the caller still sees what the API said.
		return fmt.Errorf("graphql error: %s", string(rawData))
	}

	gqldata, ok := data["data"].(map[string]interface{})
	if !ok || gqldata == nil {
		return fmt.Errorf("data is nil: %s", string(rawData))
	}
	return nil
}

func GetEndpoints() (endpoints []*Endpoint, err error) {
input := Input{
Query: `
Expand All @@ -248,6 +296,7 @@ func GetEndpoints() (endpoints []*Endpoint, err error) {
workersMin
workersStandby
gpuCount
modelReferences
env {
key
value
Expand Down Expand Up @@ -289,4 +338,4 @@ func GetEndpoints() (endpoints []*Endpoint, err error) {
}
endpoints = data.Data.Myself.Endpoints
return
}
}
15 changes: 15 additions & 0 deletions cmd/endpoint/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package endpoint

import (
"github.com/spf13/cobra"
)

// Cmd is the parent "endpoint" command; its subcommands are attached in init.
var Cmd = &cobra.Command{
	Long:  "manage serverless endpoints on runpod",
	Short: "manage serverless endpoints",
	Use:   "endpoint",
}

// init registers the subcommands of Cmd.
func init() { Cmd.AddCommand(modelCmd) }
44 changes: 44 additions & 0 deletions cmd/endpoint/update_model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package endpoint

import (
"fmt"

"github.com/runpod/runpodctl/api"
"github.com/spf13/cobra"
)

// modelCmd is the "model" subcommand. It requires at least two positional
// arguments (an endpoint id followed by one or more model references) and
// delegates execution to runModel.
var modelCmd = &cobra.Command{
	RunE:  runModel,
	Args:  cobra.MinimumNArgs(2),
	Long:  "set the model references (cached models) for a serverless endpoint",
	Short: "update model references on an endpoint",
	Use:   "model <endpoint-id> <model-ref> [model-ref...]",
}

// runModel sets the model references of the endpoint whose id is args[0] to
// the references given in args[1:].
//
// The endpoint is first looked up in the list returned by api.GetEndpoints so
// that its name can be passed to api.UpdateEndpointModel. An error is returned
// when listing fails, when no endpoint has the given id, or when the update
// fails.
func runModel(cmd *cobra.Command, args []string) error {
	endpointID, modelRefs := args[0], args[1:]

	endpoints, err := api.GetEndpoints()
	if err != nil {
		return fmt.Errorf("failed to list endpoints: %w", err)
	}

	// Track the match explicitly instead of treating an empty name as
	// "not found", so an endpoint with an empty name can still be updated.
	var (
		endpointName string
		found        bool
	)
	for _, ep := range endpoints {
		// The list holds pointers; skip nil entries rather than dereferencing them.
		if ep != nil && ep.Id == endpointID {
			endpointName, found = ep.Name, true
			break
		}
	}
	if !found {
		return fmt.Errorf("endpoint %s not found", endpointID)
	}

	if err := api.UpdateEndpointModel(endpointID, endpointName, modelRefs); err != nil {
		return fmt.Errorf("failed to update endpoint model: %w", err)
	}

	fmt.Printf("updated model references for endpoint %s (%s)\n", endpointName, endpointID)
	return nil
}
10 changes: 8 additions & 2 deletions cmd/project/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,7 @@ func startProject(networkVolumeId string) error {
echo "Process $1 has been killed with SIGKILL."
return
fi
sleep 1
done
done

echo "Failed to kill process with PID: $1 after SIGKILL attempt."
exit 1
Expand Down Expand Up @@ -580,6 +579,7 @@ func deployProject(networkVolumeId string) (endpointId string, err error) {
flashboot := true
flashBootType := "FLASHBOOT"
idleTimeout := 5
var modelRefs []string
endpointConfig, ok := config.Get("endpoint").(*toml.Tree)
if ok {
if min, ok := endpointConfig.Get("active_workers").(int64); ok {
Expand All @@ -597,6 +597,11 @@ func deployProject(networkVolumeId string) (endpointId string, err error) {
if idle, ok := endpointConfig.Get("idle_timeout").(int64); ok {
idleTimeout = int(idle)
}
// Collect model_refs from the endpoint config. Each entry is checked before
// use so that a non-string value produces an error instead of a panic.
if refs, ok := endpointConfig.Get("model_refs").([]interface{}); ok {
	for _, r := range refs {
		ref, isString := r.(string)
		if !isString {
			return "", fmt.Errorf("model_refs entries must be strings, got %T", r)
		}
		modelRefs = append(modelRefs, ref)
	}
}
}
if err != nil {
deployedEndpointId, err = api.CreateEndpoint(&api.CreateEndpointInput{
Expand All @@ -610,6 +615,7 @@ func deployProject(networkVolumeId string) (endpointId string, err error) {
WorkersMin: minWorkers,
WorkersMax: maxWorkers,
FlashBootType: flashBootType,
ModelReferences: modelRefs,
})
if err != nil {
fmt.Println("error making endpoint")
Expand Down
5 changes: 5 additions & 0 deletions cmd/project/tomlBuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ active_workers = 0
max_workers = 3
flashboot = true

# model_refs - List of model references to cache on the endpoint workers.
# Format: "owner/model-name" or "owner/model-name:branch".
# Example: ["runpod/stable-diffusion-v1-5", "meta-llama/Llama-2-7b-chat-hf"]
# model_refs = []

[runtime]
# python_version - Python version to use for the project.
#
Expand Down
4 changes: 3 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/runpod/runpodctl/cmd/billing"
"github.com/runpod/runpodctl/cmd/config"
"github.com/runpod/runpodctl/cmd/datacenter"
"github.com/runpod/runpodctl/cmd/endpoint"
"github.com/runpod/runpodctl/cmd/doctor"
"github.com/runpod/runpodctl/cmd/gpu"
"github.com/runpod/runpodctl/cmd/hub"
Expand Down Expand Up @@ -86,6 +87,7 @@ func registerCommands() {
rootCmd.AddCommand(pod.Cmd)
rootCmd.AddCommand(serverless.Cmd)
rootCmd.AddCommand(template.Cmd)
rootCmd.AddCommand(endpoint.Cmd)
rootCmd.AddCommand(model.Cmd)
rootCmd.AddCommand(volume.Cmd)
rootCmd.AddCommand(registry.Cmd)
Expand Down Expand Up @@ -147,7 +149,7 @@ func registerCommands() {
// Version flag
rootCmd.Version = version
rootCmd.Flags().BoolP("version", "v", false, "print the version of runpodctl")
rootCmd.SetVersionTemplate(`runpodctl {{ .Version }}
rootCmd.SetVersionTemplate(`runpodctl {{ .version }}
`)
}

Expand Down
Loading