24 changes: 19 additions & 5 deletions torchtitan/components/checkpoint.py
@@ -19,10 +19,7 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import (
HuggingFaceStorageReader,
HuggingFaceStorageWriter,
)
from torch.distributed.checkpoint import HuggingFaceStorageWriter
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
consolidate_safetensors_files_on_every_rank,
)
@@ -249,6 +246,9 @@ def load_state_dict(state_dict):
self.initial_load_model_only = checkpoint_config.initial_load_model_only
self.initial_load_in_hf = checkpoint_config.initial_load_in_hf
self.initial_load_path = checkpoint_config.initial_load_path
self.initial_load_in_hf_quantized = (
checkpoint_config.initial_load_in_hf_quantized
)
self.last_save_model_only = checkpoint_config.last_save_model_only
self.last_save_in_hf = checkpoint_config.last_save_in_hf
if self.last_save_in_hf:
@@ -418,6 +418,7 @@ def dcp_load(
state_dict: dict[str, Any],
checkpoint_id: str,
from_hf: bool,
from_quantized: bool,
) -> None:
"""Load the checkpoint with dcp.
Args:
@@ -432,10 +433,13 @@ def dcp_load(
self.sd_adapter is not None
), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
hf_state_dict = self.sd_adapter.to_hf(state_dict)
hf_storage_reader = self.sd_adapter.get_hf_storage_reader(
checkpoint_id, from_quantized
)

dcp.load(
hf_state_dict,
storage_reader=HuggingFaceStorageReader(path=checkpoint_id),
storage_reader=hf_storage_reader,
)

state_dict = self.sd_adapter.from_hf(hf_state_dict)
@@ -544,13 +548,21 @@ def load(self, step: int = -1) -> bool:

model_only = False
from_hf = False
from_quantized = False
if not os.path.exists(self.folder):
model_only = self.initial_load_model_only
from_hf = self.initial_load_in_hf
from_quantized = self.initial_load_in_hf_quantized
if from_hf:
assert (
model_only
), "Only model can be loaded when loading from HF's safetensors checkpoint."

if from_quantized:
assert (
from_hf
), "Quantized checkpoint can only be loaded from HuggingFace format."

if self.initial_load_path:
checkpoint_id = self.initial_load_path
if not os.path.isdir(checkpoint_id):
@@ -602,6 +614,7 @@ def load(self, step: int = -1) -> bool:
states,
checkpoint_id=checkpoint_id,
from_hf=from_hf,
from_quantized=from_quantized,
)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
@@ -679,6 +692,7 @@ def _ft_load(self) -> None:
checkpoint_id=checkpoint_id,
# FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader.
from_hf=False,
from_quantized=False,
)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
8 changes: 8 additions & 0 deletions torchtitan/config/job_config.py
@@ -453,6 +453,14 @@ class Checkpoint:
non-tensors. The default value is False.
"""

initial_load_in_hf_quantized: bool = False
"""
Enable loading of HuggingFace's safetensors format with quantized state dict keys. This option
is only used when `initial_load_path` and `initial_load_in_hf` are specified. It loads
checkpoints in HF's model definition and dequantizes the model weights where necessary. To support
this option, the model needs to define a proper HuggingFaceStorageReader to perform the dequantization.
"""

last_save_model_only: bool = True
"""
When last_save_model_only=True, only the model will be saved at the end of training,
2 changes: 1 addition & 1 deletion torchtitan/models/deepseek_v3/__init__.py
Reviewer:
Was this intended to stay here? It looks like a debugging change that's been left in this PR by mistake? The correct number of layers looks like 61 to me from here.

Contributor Author:
My bad, yes you are right and let me fix this configuration. Thanks for pointing it out.

@@ -134,7 +134,7 @@
dim=7168,
inter_dim=18432,
moe_inter_dim=2048,
n_layers=61,
n_layers=4,
n_dense_layers=3,
n_heads=128,
moe_args=MoEArgs(
73 changes: 0 additions & 73 deletions torchtitan/models/deepseek_v3/model/quantization.py

This file was deleted.

86 changes: 26 additions & 60 deletions torchtitan/models/deepseek_v3/model/state_dict_adapter.py
@@ -8,13 +8,14 @@
import re
from typing import Any

import torch
from torch.distributed.checkpoint import HuggingFaceStorageReader

from torch.distributed.tensor import DTensor
from torchtitan.models.utils import MoEStateDictAdapter

from .args import DeepSeekV3ModelArgs

from .quantization import calculate_scale_shape, dequantize_from_fp8


class DeepSeekV3StateDictAdapter(MoEStateDictAdapter):
"""
@@ -70,60 +71,33 @@ def __init__(
}
)

def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
def get_hf_storage_reader(
self, path: str, from_quantized: bool = False
) -> HuggingFaceStorageReader:
"""
Dequantize the weights from float8 to float32.
Override the default get_hf_storage_reader function to return a QuantizedHuggingFaceStorageReader.
"""
if from_quantized:
from torch.distributed.checkpoint.quantized_hf_storage import (
QuantizedHuggingFaceStorageReader,
)

scale_inv_keys = []
for key, weight in state_dict.items():
if key.endswith(".weight") and key + "_scale_inv" in state_dict:
scale_inv = state_dict[key + "_scale_inv"]
dequantized_weight = dequantize_from_fp8(
weight, scale_inv, dtype=torch.float32
)
# update the weight and remove the scale_inv tensor
state_dict[key] = dequantized_weight
scale_inv_keys.append(key + "_scale_inv")

for key in scale_inv_keys:
state_dict.pop(key)

return state_dict

def _add_quantization_scale_inv_tensors(
self, state_dict: dict[str, Any]
) -> dict[str, Any]:
"""
Add quantization scale tensors the state_dict.
"""
non_quantized_keys = [
"input_layernorm.weight",
"post_attention_layernorm.weight",
"norm.weight",
"lm_head.weight",
"embed_tokens.weight",
"mlp.gate.weight",
]

weight_scale_inv_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".weight") and not any(
non_quantized_key in key for non_quantized_key in non_quantized_keys
):
expected_scale_shape = calculate_scale_shape(value)
# add weight_scale_inv to the state_dict
weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones(
expected_scale_shape, dtype=torch.float32
)

state_dict.update(weight_scale_inv_state_dict)
return state_dict
# NOTE: We now use the quantized HF storage reader to read the DeepSeek-V3 671B model.
# When loading checkpoints without quantization, HuggingFaceStorageReader is used instead.
BLOCK_SIZE = 128
return QuantizedHuggingFaceStorageReader(
path=path,
target_dtype=torch.float32,
block_size=BLOCK_SIZE,
Comment on lines +90 to +91

Contributor:
Should these two be configurable? If not, we can remove these two lines to use the defaults.

Contributor Author:
Do you mean block_size and target_dtype? The PyTorch default values are

thread_count: int = 1,
target_dtype: torch.dtype = torch.float32,
block_size: int = 128,

I explicitly leave block_size here to make the dequantization algorithm less mysterious - the user can easily see that it is block-wise dequantized with block size 128.

thread_count=4,
)
else:
return HuggingFaceStorageReader(path)
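
As context for the block_size discussion in the review thread above, here is a minimal sketch (not the PyTorch implementation, and not part of this PR) of the block-wise FP8 dequantization that QuantizedHuggingFaceStorageReader performs while reading: every block_size x block_size tile of a stored float8 weight is multiplied by the matching entry of its `weight_scale_inv` tensor.

# Illustrative sketch of block-wise dequantization; the storage reader applies
# the same idea while loading safetensors shards.
import math
import torch

def dequantize_fp8_blockwise(
    weight: torch.Tensor,      # float8 weight as stored in the checkpoint
    scale_inv: torch.Tensor,   # shape (ceil(rows/block_size), ceil(cols/block_size))
    block_size: int = 128,
    target_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    rows, cols = weight.shape
    out = weight.to(target_dtype)
    for bi in range(math.ceil(rows / block_size)):
        for bj in range(math.ceil(cols / block_size)):
            r0, r1 = bi * block_size, min((bi + 1) * block_size, rows)
            c0, c1 = bj * block_size, min((bj + 1) * block_size, cols)
            out[r0:r1, c0:c1] *= scale_inv[bi, bj]
    return out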

def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
"""
1. Convert between the HF shape and the torchtitan shape.
2. Split the GroupedExperts' weight into separate expert's wegiht.
2. Split the GroupedExperts' weight into separate expert's weight.
"""
to_hf_map = {v: k for k, v in self.from_hf_map.items()}

@@ -172,24 +146,16 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
new_key = to_hf_map[key]
hf_state_dict[new_key] = value

# Prepare for dequantization
hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors(
hf_state_dict
)
return hf_state_dict_with_scale_inv
return hf_state_dict

def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
"""
1. When loading from HF checkpoint, dequantize the weights from float8 to float32.
2. Convert between the HF shape and the torchtitan shape.
3. Concate separate expert's wegiht into GroupedExperts' weight.
3. Concat separate expert's weight into GroupedExperts' weight.
"""

# dequantize the tensor in state_dict and remove the scale_inv tensor

hf_state_dict = self._dequantize(hf_state_dict)
state_dict = {}

expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}}

for key, value in hf_state_dict.items():
@@ -215,7 +181,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
layer_num,
value.device_mesh,
)
else: # keep this path to be compatibile with offline conversion
else: # keep this path to be compatible with offline conversion
stacked_value = self._concatenate_expert_weights(
expert_weights_by_layer,
titan_abstract_key,
@@ -65,7 +65,7 @@ mode = "selective" # ["none", "selective", "full"]
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
enable=true
enable = true
components = ["loss"] # ["model", "loss"]

[quantize.linear.float8]
29 changes: 27 additions & 2 deletions torchtitan/protocols/state_dict_adapter.py
@@ -5,13 +5,14 @@
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
import re
from abc import ABC, abstractmethod
from typing import Any

logger = logging.getLogger()
from torch.distributed.checkpoint import HuggingFaceStorageReader

from torchtitan.tools.logging import logger

from .model import BaseModelArgs

@@ -58,6 +59,21 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
"""
pass

@abstractmethod
def get_hf_storage_reader(
self, path: str, from_quantized: bool = False
) -> HuggingFaceStorageReader:
"""Returns hf storage reader to read HF checkpoint

Args:
path: the path to read HF checkpoint

Returns:
The HuggingFace storage reader to read from HF checkpoint

"""
pass


class StateDictAdapter(BaseStateDictAdapter):
"""State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping"""
@@ -86,3 +102,12 @@ def __init__(
self.fqn_to_index_mapping[hf_key] = int(indx)
else:
self.fqn_to_index_mapping = None

def get_hf_storage_reader(
self, path: str, from_quantized: bool = False
) -> HuggingFaceStorageReader:
if from_quantized:
logger.warning(
"Loading from quantized checkpoint format is not supported for this model."
)
return HuggingFaceStorageReader(path)