In [1]:
import json
import os
import re
from itertools import islice
from typing import Optional, Union

In [3]:
# Matches one `key={...}` argument (inner braces may nest one level deep) together
# with an optional trailing comma, so the same pattern both extracts and removes it.
_DICT_ARG_RE = re.compile(r"(\w+)=(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}),?")

# Matches either a complete double-quoted string (left untouched) or a bare word
# followed by a colon (an unquoted key that still needs quotes to be valid JSON).
_BARE_KEY_RE = re.compile(r'"(?:[^"\\]|\\.)*"|(\w+)(?=\s*:)')

# Dict-valued arguments that are decoded and returned, listed in output order.
_DICT_ARG_NAMES = ("generation_parameters", "hf_overrides")


def _loads_relaxed_json(text: str):
    """Decode a JSON object whose keys may be written without quotes.

    Quoted strings (keys or values) are passed through unchanged, so already
    valid JSON and string values that contain ``:`` are not corrupted.
    """

    def _quote_bare_key(match):
        bare_key = match.group(1)
        # Group 1 is only set by the bare-key alternative of the pattern.
        return match.group(0) if bare_key is None else f'"{bare_key}"'

    return json.loads(_BARE_KEY_RE.sub(_quote_bare_key, text))


def parse_args(args: str):
    """Parse a comma-separated model-argument string into a config dict.

    Three kinds of items are recognised:

    * ``key=value``   -> ``{"key": "value"}`` (the value is kept as a string)
    * ``flag``        -> ``{"flag": True}``
    * ``key={...}``   -> for ``generation_parameters`` and ``hf_overrides`` only,
      the braces are decoded as JSON (keys may be unquoted) into a dict.

    Any other ``key={...}`` item is removed from the string and not returned.
    Simple values are split on ``,`` and therefore cannot contain commas.

    Args:
        args: The raw argument string, e.g.
            ``"model_name=gpt2,use_cache,generation_parameters={temperature:0.7}"``.

    Returns:
        A dict with the simple parameters first, followed by
        ``generation_parameters`` and ``hf_overrides`` when they were supplied.

    Raises:
        json.JSONDecodeError: If a recognised dict argument is not valid JSON
            after its bare keys have been quoted.
    """
    # Decode the recognised dict-valued arguments; a repeated key keeps the last one.
    parsed_dicts = {}
    for key, value in _DICT_ARG_RE.findall(args):
        if key in _DICT_ARG_NAMES:
            parsed_dicts[key] = _loads_relaxed_json(value)

    # Drop every dict argument first: their bodies contain commas that would
    # otherwise be mistaken for separators between simple parameters.
    remaining = _DICT_ARG_RE.sub("", args).strip(",")

    model_config = {}
    for pair in remaining.split(","):
        pair = pair.strip()
        if not pair:
            # Empty input or stray/double commas.
            continue
        if "=" in pair:
            k, v = pair.split("=", 1)
            # Strip so "name =x" does not produce the key "name ".
            model_config[k.strip()] = v
        else:
            # Bare flag without a value.
            model_config[pair] = True

    # Append the decoded dicts after the simple parameters, in a fixed order.
    for name in _DICT_ARG_NAMES:
        if name in parsed_dicts:
            model_config[name] = parsed_dicts[name]

    return model_config

In [7]:
# Demo: one string mixing a bare flag (`use_cache`), simple key=value pairs, and
# both dict-valued arguments. Simple values stay strings (e.g. 'true'), while the
# `{...}` bodies are decoded into dicts with typed values.
parse_args("model_name=gpt2,use_cache,generation_parameters={temperature:0.7},use_chat_template=true,hf_overrides={num_experts_per_tok:6}")

{'model_name': 'gpt2',
 'use_cache': True,
 'use_chat_template': 'true',
 'generation_parameters': {'temperature': 0.7},
 'hf_overrides': {'num_experts_per_tok': 6}}

In [9]:
# Demo with a longer argument string: `$MODEL` / `$NUM_GPUS` are returned verbatim
# as strings (parse_args does no variable expansion), numeric simple values stay
# strings, and the JSON literal `false` inside generation_parameters becomes False.
parse_args("model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:30000,temperature:0.6,top_p:0.95,returns_logits:false},use_chat_template=true,tensor_parallel_size=$NUM_GPUS,enable_thinking=true")

{'model_name': '$MODEL',
 'dtype': 'bfloat16',
 'max_model_length': '32768',
 'gpu_memory_utilization': '0.8',
 'use_chat_template': 'true',
 'tensor_parallel_size': '$NUM_GPUS',
 'enable_thinking': 'true',
 'generation_parameters': {'max_new_tokens': 30000,
  'temperature': 0.6,
  'top_p': 0.95,
  'returns_logits': False}}

In [11]:
# Scratch check of the dict-removal step on its own: substitute away every
# `key={...}` argument (plus its trailing comma) and strip leftover edge commas,
# showing the simple key=value pairs that remain.
# Note: this binds the notebook-level name `args`.
args = "model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:30000,temperature:0.6,top_p:0.95,returns_logits:false},use_chat_template=true,tensor_parallel_size=$NUM_GPUS,enable_thinking=true"
re.sub(r"\w+=\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\},?", "", args).strip(",")

'model_name=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,use_chat_template=true,tensor_parallel_size=$NUM_GPUS,enable_thinking=true'