Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Implement sharded state loader #4690

Merged
merged 22 commits into from
May 16, 2024
67 changes: 67 additions & 0 deletions examples/save_sharded_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import argparse
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs

"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
--model /path/to/load \
--quantization deepspeedfp \
--tensor-parallel-size 8 \
--output /path/to/save

Then, the model can be loaded with

llm = LLM(
model="/path/to/save",
load_format="sharded_state",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""

parser = argparse.ArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
"-o",
required=True,
type=str,
help="path to output checkpoint")
parser.add_argument("--pattern",
type=str,
help="string pattern of saved filenames")


def main(args):
engine_args = EngineArgs.from_cli_args(args)
model_path = engine_args.model
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = LLM(**dataclasses.asdict(engine_args))
# Prepare output directory
Path(args.output).mkdir(exist_ok=True)
# Dump worker states to output directory
model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=args.output,
pattern=args.pattern,
max_size=5 * 1024**3)
aurickq marked this conversation as resolved.
Show resolved Hide resolved
# Copy metadata files to output directory
for file in os.listdir(model_path):
if not any(
file.endswith(ext) for ext in (".bin", ".pt", ".safetensors")):
aurickq marked this conversation as resolved.
Show resolved Hide resolved
shutil.copy(f"{model_path}/{file}", args.output)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this set the config.json or quant_config.json next to the model weights to inform vLLM loading the model what type of quantization the model checkpoint is in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, it's just copying the config.json and quant_config.json from the input checkpoint that's being converted, which works for the use cases we've tested. Actually, I am not sure if it's correct to override these configs (or add a quant_config.json where it wasn't there previously) because then the config may mismatch the final loaded states?



if __name__ == "__main__":
args = parser.parse_args()
main(args)
1 change: 1 addition & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ class LoadFormat(str, enum.Enum):
NPCACHE = "npcache"
DUMMY = "dummy"
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"


@dataclass
Expand Down
10 changes: 10 additions & 0 deletions vllm/executor/distributed_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def remove_lora(self, lora_id: int) -> bool:
def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")

def save_sharded_state(self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self._run_workers("save_sharded_state",
path=path,
pattern=pattern,
max_size=max_size)

@abstractmethod
def _run_workers(
self,
Expand Down
91 changes: 91 additions & 0 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,94 @@ def load_model(self, *, model_config: ModelConfig,
vision_language_config)


class ShardedStateLoader(BaseModelLoader):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/save_sharded_states.py` for creating a sharded checkpoint.
"""

DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
extra_config = ({} if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy())
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
if extra_config:
raise ValueError(f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{load_config.model_loader_extra_config.keys()}")

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
from safetensors.torch import load_file

from vllm.distributed import get_tensor_model_parallel_rank
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
model_config.model,
self.pattern.format(rank=rank, part="*"),
)
filepaths = glob.glob(pattern)
if not filepaths:
# TODO: support un-sharded checkpoints too
raise ValueError(
f"Could not find checkpoint files '{pattern}', only "
f"pre-sharded checkpoints are currently supported!"
)
state_dict = dict(model.state_dict())
for path in filepaths:
for key, val in load_file(path).items():
state_dict[key].copy_(val)
state_dict.pop(key)
assert not state_dict
return model.eval()

@staticmethod
def save_model(
model: torch.nn.Module,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
from safetensors.torch import save_file

from vllm.distributed import get_tensor_model_parallel_rank
if pattern is None:
pattern = ShardedStateLoader.DEFAULT_PATTERN
rank = get_tensor_model_parallel_rank()
part = 0
total_size = 0
state_dict: Dict[str, torch.Tensor] = {}
for name, tensor in model.state_dict().items():
param_size = tensor.nelement() * tensor.element_size()
if max_size is not None and total_size + param_size > max_size:
save_file(
state_dict,
os.path.join(path, pattern.format(rank=rank, part=part)),
)
part += 1
total_size = 0
state_dict = {}
state_dict[name] = tensor
total_size += param_size
if len(state_dict) > 0:
save_file(
state_dict,
os.path.join(path, pattern.format(rank=rank, part=part)),
)


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""

Expand All @@ -359,4 +447,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.TENSORIZER:
return TensorizerLoader(load_config)

if load_config.load_format == LoadFormat.SHARDED_STATE:
return ShardedStateLoader(load_config)

return DefaultModelLoader(load_config)
14 changes: 14 additions & 0 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,20 @@ def load_model(self) -> None:
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")

def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
from vllm.model_executor.model_loader.loader import ShardedStateLoader
ShardedStateLoader.save_model(
self.model,
path,
pattern=pattern,
max_size=max_size,
)

def set_block_size(self, block_size: int) -> None:
self.block_size = block_size

Expand Down
12 changes: 12 additions & 0 deletions vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ def init_device(self) -> None:
def load_model(self):
self.model_runner.load_model()

def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.model_runner.save_sharded_state(
path,
pattern=pattern,
max_size=max_size,
)

@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
Expand Down