62 changes: 61 additions & 1 deletion test/distributed/checkpoint/test_hf_safetensor_e2e.py
@@ -1,12 +1,14 @@
# Owner(s): ["oncall: distributed checkpointing"]

import importlib
import os

import torch
import torch.distributed.checkpoint as dist_cp
from torch import distributed as dist
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros
+from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
@@ -116,6 +118,64 @@ def test_load_into_empty_dict(self) -> None:
        )


class TestDistributedHFSafetensorsConsolidation(DTensorTestBase):
    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_consolidate_to_one_file(self) -> None:
        if importlib.util.find_spec("safetensors") is None:
            print("safetensors not installed")
            return

        import safetensors.torch

        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
        mesh_shape = (self.world_size,)
        mesh_1d = init_device_mesh(self.device_type, mesh_shape)

        # Create local tensor with row-wise sharding
        rows_per_rank = global_tensor.shape[0] // self.world_size
        start_row = self.rank * rows_per_rank
        end_row = start_row + rows_per_rank
        local_tensor = global_tensor[start_row:end_row].clone()

        # Create DTensor with row-wise sharding
        dtensor = DTensor.from_local(
            local_tensor,
            device_mesh=mesh_1d,
            placements=[Shard(0)],
            shape=global_tensor.shape,
            stride=(4, 1),
        )

        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)

        checkpoint_dir = self.temp_dir
        consolidated_output_dir = os.path.join(checkpoint_dir, "consolidated")
        os.makedirs(consolidated_output_dir, exist_ok=True)

        state_dict_to_save = {"dtensor": dtensor}
        dist_cp.save(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.HuggingFaceStorageWriter(
                path=checkpoint_dir,
                save_sharded=True,
                consolidated_output_path=consolidated_output_dir,
            ),
        )
        dist.barrier()

        if self.rank == 0:
            file_path = os.path.join(
                consolidated_output_dir, "model-00001-of-00001.safetensors"
            )
            loaded_dict = safetensors.torch.load_file(file_path)
            self.assertEqual(loaded_dict.keys(), {"dtensor"})
            self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))

        dist.barrier()


ONE_D_PLACEMENTS = [
    [Shard(0)],
    [Replicate()],
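
As a side note, the consolidated file produced in the test above can also be inspected lazily with safetensors' safe_open API instead of loading every tensor into memory. A minimal sketch; the checkpoint path is a hypothetical stand-in for the test's temp directory:

from safetensors import safe_open

# Lazily open the consolidated checkpoint produced by HuggingFaceStorageWriter.
with safe_open("consolidated/model-00001-of-00001.safetensors", framework="pt", device="cpu") as f:
    print(f.keys())                           # e.g. ['dtensor']
    full = f.get_tensor("dtensor")            # materializes the full 4x4 tensor
    first_rows = f.get_slice("dtensor")[:2]   # reads only the first two rows from disk
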
39 changes: 32 additions & 7 deletions torch/distributed/checkpoint/hf_storage.py
@@ -6,6 +6,9 @@

import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
    consolidate_safetensors_files,
)
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torch.distributed.checkpoint._hf_utils import (
    _gen_file_name,
@@ -57,8 +60,11 @@ def __init__(
        self,
        path: str,
        fqn_to_index_mapping: Optional[dict[str, int]] = None,
        thread_count: int = 1,
        token: Optional[str] = None,
        save_sharded: bool = False,
        consolidated_output_path: Optional[str] = None,
        num_threads_consolidation: Optional[int] = None,
    ) -> None:
"""
Initialize the huggingface writer pointing to path.
@@ -67,14 +73,18 @@ def __init__(
            path: hf directory where the checkpoint will be written to.
                The checkpoint is written as .safetensors files, and the directory can be on any
                fsspec-supported storage, including localFS and hf://.
                This needs to be a remote path if you want to enable consolidation after saving.
            fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to.
                Indices are from 1 to N, where N is the number of files. If not provided,
                all the tensors on the same rank will be written to the same file.
            thread_count: Number of threads to use to write the checkpoint files. Default is 1.
            token: The token to use to authenticate with huggingface hub.
            save_sharded: If True, save the checkpoint as a sharded checkpoint where every rank saves its own shard.
                Default is False which assumes full tensors are being saved.
            consolidated_output_path: If provided, the output path where the consolidated files will be written in the finish step.
                This needs to be a local fs path right now.
            num_threads_consolidation: Number of threads to use for parallel processing of saving data to output files.
                If not provided, the default value is the number of output files.
        """

        if token is not None:
@@ -88,16 +98,25 @@ def __init__(
            path=path,
            serialization_format=SerializationFormat.SAFETENSORS,
        )
-        self._fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping
-        self._save_sharded = save_sharded
+        self.fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping
+        self.save_sharded: bool = save_sharded
        self.consolidated_output_path: Optional[str] = consolidated_output_path

        self.num_threads_consolidation: int = 1
        if num_threads_consolidation:
            self.num_threads_consolidation = num_threads_consolidation
        elif self.fqn_to_index_mapping:
            self.num_threads_consolidation = max(self.fqn_to_index_mapping.values())

        self.thread_count: int = thread_count

    def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
        new_plans = []
        for i, plan in enumerate(plans, start=1):
            storage_data: dict[str, Any] = {}
-            if self._fqn_to_index_mapping is not None:
-                storage_data["fqn_to_index_mapping"] = self._fqn_to_index_mapping
-            if self._save_sharded:
+            if self.fqn_to_index_mapping is not None:
+                storage_data["fqn_to_index_mapping"] = self.fqn_to_index_mapping
+            if self.save_sharded:
                storage_data["shard_index"] = i

            new_plans.append(dataclasses.replace(plan, storage_data=storage_data))
@@ -136,8 +155,14 @@ def write_data(
        return super()._write_data(planner, file_queue)

    def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
-        if self._save_sharded:
+        if self.save_sharded and not self.consolidated_output_path:
            return
        if self.save_sharded:
            return consolidate_safetensors_files(
                input_dir=str(self.path),
                output_dir=self.consolidated_output_path,  # type: ignore[arg-type]
                num_threads=self.num_threads_consolidation,
            )

        metadata_to_write = {}
        storage_md = {}
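
Putting the two files together, a sharded save with consolidation follows the pattern of the test added in this PR. A minimal usage sketch, assuming torch.distributed and a device mesh are already initialized and that state_dict holds DTensors sharded across ranks; the directory paths and the commented-out two-file fqn_to_index_mapping are hypothetical examples, not values from the PR:

import torch.distributed.checkpoint as dist_cp

writer = dist_cp.HuggingFaceStorageWriter(
    path="/tmp/ckpt",  # hypothetical: each rank writes its own shard files here
    save_sharded=True,
    consolidated_output_path="/tmp/ckpt/consolidated",  # finish() merges the shards here
    # fqn_to_index_mapping={"embedding.weight": 1, "lm_head.weight": 2},  # hypothetical: split output into 2 files
    # num_threads_consolidation=2,  # defaults to the number of output files when a mapping is given
)
dist_cp.save(state_dict=state_dict, storage_writer=writer)
# The coordinator rank's finish() calls consolidate_safetensors_files(), producing
# "model-00001-of-0000N.safetensors"-style files in the consolidated output directory,
# as checked by the new test above.
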