Skip to content

Commit 16822cc

Browse files
Merge pull request #266 from vast-ai/workergroup-cold-workers
adding cold workers for workergroups
2 parents d06ec44 + 953edad commit 16822cc

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

vast.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2109,6 +2109,7 @@ def generate_ssh_key(auto_yes=False):
21092109
argument("--min_load", help="[NOTE: this field isn't currently used at the workergroup level] minimum floor load in perf units/s (token/s for LLMs)", type=float),
21102110
argument("--target_util", help="[NOTE: this field isn't currently used at the workergroup level] target capacity utilization (fraction, max 1.0, default 0.9)", type=float),
21112111
argument("--cold_mult", help="[NOTE: this field isn't currently used at the workergroup level] cold/stopped instance capacity target as multiple of hot capacity target (default 2.0)", type=float),
2112+
argument("--cold_workers", help="min number of workers to keep 'cold' for this workergroup", type=int),
21122113
argument("--auto_instance", help="unused", type=str, default="prod"),
21132114
usage="vastai workergroup create [OPTIONS]",
21142115
help="Create a new autoscale group",
@@ -2131,7 +2132,7 @@ def create__workergroup(args):
21312132
#query = {"verified": {"eq": True}, "external": {"eq": False}, "rentable": {"eq": True}, "rented": {"eq": False}}
21322133
search_params = (args.search_params if args.search_params is not None else "" + query).strip()
21332134

2134-
json_blob = {"client_id": "me", "min_load": args.min_load, "target_util": args.target_util, "cold_mult": args.cold_mult, "test_workers" : args.test_workers, "template_hash": args.template_hash, "template_id": args.template_id, "search_params": search_params, "launch_args": args.launch_args, "gpu_ram": args.gpu_ram, "endpoint_name": args.endpoint_name, "endpoint_id": args.endpoint_id, "autoscaler_instance": args.auto_instance}
2135+
json_blob = {"client_id": "me", "min_load": args.min_load, "target_util": args.target_util, "cold_mult": args.cold_mult, "cold_workers" : args.cold_workers, "test_workers" : args.test_workers, "template_hash": args.template_hash, "template_id": args.template_id, "search_params": search_params, "launch_args": args.launch_args, "gpu_ram": args.gpu_ram, "endpoint_name": args.endpoint_name, "endpoint_id": args.endpoint_id, "autoscaler_instance": args.auto_instance}
21352136

21362137
if (args.explain):
21372138
print("request json: ")
@@ -5432,6 +5433,7 @@ def transfer__credit(args: argparse.Namespace):
54325433
argument("--min_load", help="minimum floor load in perf units/s (token/s for LLMs)", type=float),
54335434
argument("--target_util", help="target capacity utilization (fraction, max 1.0, default 0.9)", type=float),
54345435
argument("--cold_mult", help="cold/stopped instance capacity target as multiple of hot capacity target (default 2.5)", type=float),
5436+
argument("--cold_workers", help="min number of workers to keep 'cold' for this workergroup", type=int),
54355437
argument("--test_workers",help="number of workers to create to get a performance estimate for while initializing workergroup (default 3)", type=int),
54365438
argument("--gpu_ram", help="estimated GPU RAM req (independent of search string)", type=float),
54375439
argument("--template_hash", help="template hash (**Note**: if you use this field, you can skip search_params, as they are automatically inferred from the template)", type=str),
@@ -5456,7 +5458,7 @@ def update__workergroup(args):
54565458
query = " verified=True rentable=True rented=False"
54575459
if args.search_params is not None:
54585460
query = args.search_params + query
5459-
json_blob = {"client_id": "me", "autojob_id": args.id, "min_load": args.min_load, "target_util": args.target_util, "cold_mult": args.cold_mult, "test_workers" : args.test_workers, "template_hash": args.template_hash, "template_id": args.template_id, "search_params": query, "launch_args": args.launch_args, "gpu_ram": args.gpu_ram, "endpoint_name": args.endpoint_name, "endpoint_id": args.endpoint_id}
5461+
json_blob = {"client_id": "me", "autojob_id": args.id, "min_load": args.min_load, "target_util": args.target_util, "cold_mult": args.cold_mult, "cold_workers": args.cold_workers, "test_workers" : args.test_workers, "template_hash": args.template_hash, "template_id": args.template_id, "search_params": query, "launch_args": args.launch_args, "gpu_ram": args.gpu_ram, "endpoint_name": args.endpoint_name, "endpoint_id": args.endpoint_id}
54605462
if (args.explain):
54615463
print("request json: ")
54625464
print(json_blob)

0 commit comments

Comments
 (0)