[DSV3] Offload dequantization process to DCP QuantizedHFReader #1804
Changes from all commits: 199519c, f816eb1, ceb1411, bcd786b, 643bfb6, 0d5fdba, f2a2011, ae5e9d5, fa1824c
@@ -8,13 +8,14 @@
 import re
 from typing import Any

 import torch
+from torch.distributed.checkpoint import HuggingFaceStorageReader
 from torch.distributed.tensor import DTensor
 from torchtitan.models.utils import MoEStateDictAdapter

 from .args import DeepSeekV3ModelArgs
 from .quantization import calculate_scale_shape, dequantize_from_fp8


 class DeepSeekV3StateDictAdapter(MoEStateDictAdapter):
     """
@@ -70,60 +71,33 @@ def __init__(
             }
         )

-    def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
+    def get_hf_storage_reader(
+        self, path: str, from_quantized: bool = False
+    ) -> HuggingFaceStorageReader:
         """
-        Dequantize the weights from float8 to float32.
+        Override default get_hf_storage_reader function to return QuantizedHFStorageReader.
         """
+        if from_quantized:
+            from torch.distributed.checkpoint.quantized_hf_storage import (
+                QuantizedHuggingFaceStorageReader,
+            )
-        scale_inv_keys = []
-        for key, weight in state_dict.items():
-            if key.endswith(".weight") and key + "_scale_inv" in state_dict:
-                scale_inv = state_dict[key + "_scale_inv"]
-                dequantized_weight = dequantize_from_fp8(
-                    weight, scale_inv, dtype=torch.float32
-                )
-                # update the weight and remove the scale_inv tensor
-                state_dict[key] = dequantized_weight
-                scale_inv_keys.append(key + "_scale_inv")
-
-        for key in scale_inv_keys:
-            state_dict.pop(key)
-
-        return state_dict
-
-    def _add_quantization_scale_inv_tensors(
-        self, state_dict: dict[str, Any]
-    ) -> dict[str, Any]:
-        """
-        Add quantization scale tensors the state_dict.
-        """
-        non_quantized_keys = [
-            "input_layernorm.weight",
-            "post_attention_layernorm.weight",
-            "norm.weight",
-            "lm_head.weight",
-            "embed_tokens.weight",
-            "mlp.gate.weight",
-        ]
-
-        weight_scale_inv_state_dict = {}
-        for key, value in state_dict.items():
-            if key.endswith(".weight") and not any(
-                non_quantized_key in key for non_quantized_key in non_quantized_keys
-            ):
-                expected_scale_shape = calculate_scale_shape(value)
-                # add weight_scale_inv to the state_dict
-                weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones(
-                    expected_scale_shape, dtype=torch.float32
-                )
-
-        state_dict.update(weight_scale_inv_state_dict)
-        return state_dict
+            # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model.
+            # If loading checkpoints without quantization, use HuggingFaceStorageReader instead.
+            BLOCK_SIZE = 128
+            return QuantizedHuggingFaceStorageReader(
+                path=path,
+                target_dtype=torch.float32,
+                block_size=BLOCK_SIZE,
+                thread_count=4,
+            )
+        else:
+            return HuggingFaceStorageReader(path)

     def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
         """
         1. Convert between the HF shape and the torchtitan shape.
-        2. Split the GroupedExperts' weight into separate expert's wegiht.
+        2. Split the GroupedExperts' weight into separate expert's weight.
         """
         to_hf_map = {v: k for k, v in self.from_hf_map.items()}

Review comment on lines +90 to +91:

Reviewer: Should these two be configurable? If not, we can remove these two lines and use the defaults.

Author: Do you mean …? I explicitly leave block_size here to make the dequantize algorithm less mysterious - the user can easily see that it is block-wise dequantized with block size 128.
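To unpack the block-wise scheme mentioned in that exchange: the FP8 checkpoint carries one `weight_scale_inv` entry per 128x128 tile of each quantized weight, and dequantization scales every tile back to full precision. The sketch below only illustrates that idea under those assumptions; it is not the actual `dequantize_from_fp8` or `QuantizedHuggingFaceStorageReader` implementation.

```python
import torch


def blockwise_dequantize(
    weight_fp8: torch.Tensor,
    scale_inv: torch.Tensor,
    block_size: int = 128,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Multiply each (block_size x block_size) tile by its per-block scale entry."""
    rows, cols = weight_fp8.shape
    out = weight_fp8.to(dtype)
    n_row_blocks, n_col_blocks = scale_inv.shape
    for bi in range(n_row_blocks):
        for bj in range(n_col_blocks):
            r0, r1 = bi * block_size, min((bi + 1) * block_size, rows)
            c0, c1 = bj * block_size, min((bj + 1) * block_size, cols)
            out[r0:r1, c0:c1] *= scale_inv[bi, bj].to(dtype)
    return out


# A 300x300 weight needs a ceil(300/128) x ceil(300/128) = 3x3 scale grid,
# which is presumably what calculate_scale_shape computed on the saving path.
w = torch.randn(300, 300).to(torch.float8_e4m3fn)
s = torch.ones(3, 3)
full = blockwise_dequantize(w, s)  # float32, shape (300, 300)
```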
@@ -172,24 +146,16 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                 new_key = to_hf_map[key]
                 hf_state_dict[new_key] = value

-        # Prepare for dequantization
-        hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors(
-            hf_state_dict
-        )
-        return hf_state_dict_with_scale_inv
+        return hf_state_dict

     def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
         """
-        1. When loading from HF checkpoint, dequantize the weights from float8 to float32.
         2. Convert between the HF shape and the torchtitan shape.
-        3. Concate separate expert's wegiht into GroupedExperts' weight.
+        3. Concat separate expert's weight into GroupedExperts' weight.
         """

-        # dequantize the tensor in state_dict and remove the scale_inv tensor
-
-        hf_state_dict = self._dequantize(hf_state_dict)
         state_dict = {}

         expert_weights_by_layer = {}  # {layer: {abstract_key: {expert_id: tensor}}}

         for key, value in hf_state_dict.items():
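As a side note on the expert handling in the surrounding hunks: HF checkpoints store each expert's weight under its own key, and `expert_weights_by_layer` collects them per layer so they can be stacked into one GroupedExperts tensor. Below is a minimal sketch of that bookkeeping, with made-up key names and a plain `torch.stack` standing in for torchtitan's `_concatenate_expert_weights` helper.

```python
import re

import torch

# Hypothetical per-expert keys, roughly the shape of what an HF shard contains.
hf_weights = {
    "model.layers.3.mlp.experts.0.gate_proj.weight": torch.randn(16, 32),
    "model.layers.3.mlp.experts.1.gate_proj.weight": torch.randn(16, 32),
    "model.layers.3.mlp.experts.2.gate_proj.weight": torch.randn(16, 32),
}

# {layer: {abstract_key: {expert_id: tensor}}}, mirroring the comment in the diff.
expert_weights_by_layer: dict[str, dict[str, dict[int, torch.Tensor]]] = {}

pattern = re.compile(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.(.+)")
for key, value in hf_weights.items():
    layer, expert_id, suffix = pattern.match(key).groups()
    abstract_key = f"model.layers.{layer}.mlp.experts.{{}}.{suffix}"
    expert_weights_by_layer.setdefault(layer, {}).setdefault(abstract_key, {})[
        int(expert_id)
    ] = value

# Once every expert of a layer has arrived, stack them into a single grouped
# tensor of shape (num_experts, out_features, in_features).
for layer, groups in expert_weights_by_layer.items():
    for abstract_key, experts in groups.items():
        grouped = torch.stack([experts[i] for i in sorted(experts)], dim=0)
        print(abstract_key, tuple(grouped.shape))  # (3, 16, 32)
```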
@@ -215,7 +181,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
                         layer_num,
                         value.device_mesh,
                     )
-                else:  # keep this path to be compatibile with offline conversion
+                else:  # keep this path to be compatible with offline conversion
                     stacked_value = self._concatenate_expert_weights(
                         expert_weights_by_layer,
                         titan_abstract_key,
Review comment:

Reviewer: Was this intended to stay here? It looks like a debugging change that has been left in this PR by mistake. The correct number of layers looks like 61 to me from here.

Author: My bad, yes you are right; let me fix this configuration. Thanks for pointing it out.
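Stepping back from the review threads: the net effect of this PR is that FP8 dequantization happens inside the DCP storage reader while the checkpoint is being read, instead of in `from_hf` afterwards. The sketch below shows where such a reader would plug into `torch.distributed.checkpoint.load`; the checkpoint path and the empty `state_dict` are placeholders, and only the reader arguments are taken from this diff.

```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.quantized_hf_storage import (
    QuantizedHuggingFaceStorageReader,
)

# Placeholder path; in torchtitan this would come from the checkpoint config.
HF_CHECKPOINT_DIR = "/path/to/deepseek-v3-fp8-checkpoint"

# Same arguments the adapter passes in this PR: dequantize FP8 shards to
# float32 on the fly, using 128x128 blocks and 4 reader threads.
reader = QuantizedHuggingFaceStorageReader(
    path=HF_CHECKPOINT_DIR,
    target_dtype=torch.float32,
    block_size=128,
    thread_count=4,
)

# In practice this dict would be pre-populated with HF-keyed tensors (e.g. the
# output of the adapter's to_hf()); dcp.load fills it in place from the shards.
state_dict: dict[str, torch.Tensor] = {}
dcp.load(state_dict, storage_reader=reader)

# For a checkpoint that is not FP8-quantized, the plain reader suffices,
# as in the else branch of the diff: HuggingFaceStorageReader(path).
```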