
Commit af793d7

cast in __init__
1 parent b6e08aa commit af793d7

File tree

1 file changed: +13 −19 lines changed

torchao/prototype/moe_training/tensor.py

Lines changed: 13 additions & 19 deletions
@@ -49,6 +49,7 @@ def __new__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
+        logger.debug(f"__new__: Creating ScaledGroupedMMTensor with dtype={dtype}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
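
Note: the __new__/__init__ pair touched by this commit follows PyTorch's wrapper-subclass pattern, where torch.Tensor._make_wrapper_subclass builds the outer tensor and the payload is stashed on the instance. A minimal sketch of that pattern, with the class name and construction details below being illustrative stand-ins rather than the file's actual definitions:

import torch

class DemoWrapperTensor(torch.Tensor):
    """Illustrative wrapper subclass; a stand-in for ScaledGroupedMMTensor."""

    @staticmethod
    def __new__(cls, tensor: torch.Tensor, dtype: torch.dtype):
        # The outer wrapper advertises `dtype` to dispatch/autograd machinery.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            dtype=dtype,
            device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor, dtype: torch.dtype):
        # Mirrors the cast this commit adds: keep the payload in `dtype`
        # so the inner and outer dtypes agree.
        self._data = tensor.to(dtype)
        self._dtype = dtype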
@@ -67,11 +68,14 @@ def __init__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
-        self._data = tensor
+        self._data = tensor.to(dtype)
         self._dtype = dtype
+        logger.debug(f"__init__: ScaledGroupedMMTensor with self._data.dtype={self._data.dtype} and dtype={dtype}")
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
+        logger.debug(f"func: {func.__name__}, args={args}, kwargs={kwargs}")
+
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ in cls.grouped_mm_func_names:
             # Use torchao scaled grouped mm with dynamic quant for
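
The __torch_function__ hook above reroutes the grouped-mm op into torchao's differentiable _scaled_grouped_mm. A hedged sketch of that interception pattern; the op name and the replacement function below are illustrative placeholders, not torchao's actual helpers:

import torch

def demo_scaled_grouped_mm(*args, **kwargs):
    # Placeholder for the rerouted implementation; torchao substitutes its
    # differentiable _scaled_grouped_mm here.
    raise NotImplementedError

class DemoInterceptingTensor(torch.Tensor):
    grouped_mm_func_names = {"_grouped_mm"}  # illustrative op name

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func.__name__ in cls.grouped_mm_func_names:
            # Reroute the matched op to the custom implementation.
            return demo_scaled_grouped_mm(*args, **kwargs)
        # Everything else falls through to default torch behavior.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)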
@@ -97,22 +101,13 @@ def __torch_function__(cls, func, types, args, kwargs={}):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
-        logger.debug(f"{func.__name__}, args={args}, kwargs={kwargs}")
+        logger.debug(f"dispatch: {func.__name__}, args={args}, kwargs={kwargs}")
         # detach is special case
         if func == torch.ops.aten.detach.default:
             return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
 
         # unwrap args and kwargs
-        dtype: Optional[torch.dtype] = None
-
-        def unwrap(t):
-            nonlocal dtype
-            if dtype is None:
-                dtype = t._dtype
-            else:
-                assert t._dtype == dtype
-            return t._data
-
+        unwrap = lambda x: x._data
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -127,7 +122,7 @@ def unwrap(t):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, dtype),
+            lambda x: ScaledGroupedMMTensor(x, x.dtype),
             out,
         )
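
Taken together, the two __torch_dispatch__ hunks simplify the unwrap step to a plain `lambda x: x._data` and re-wrap outputs using each output's own dtype (x.dtype) rather than a dtype inferred from the inputs. A sketch of that unwrap/re-wrap pattern; the function and parameter names are illustrative stand-ins:

import torch
import torch.utils._pytree as pytree

def dispatch_on_payload(subclass_type, func, args, kwargs):
    # Strip the wrapper so the op runs on the plain payload tensors.
    unwrap = lambda t: t._data
    args, kwargs = pytree.tree_map_only(subclass_type, unwrap, (args, kwargs or {}))
    out = func(*args, **kwargs)
    # Re-wrap outputs; the wrapper dtype now comes from the output itself.
    return pytree.tree_map_only(
        torch.Tensor, lambda x: subclass_type(x, x.dtype), out
    )
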
@@ -154,7 +149,7 @@ def fsdp_pre_all_gather(
         module: nn.Module,
         mp_policy: MixedPrecisionPolicy,
     ):
-        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
+        all_gather_inputs = (self._data,)
         all_gather_metadata = ()
         logger.debug(f"fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata
@@ -171,11 +166,10 @@ def fsdp_post_all_gather(
         logger.debug(f"fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
 
         if out is not None:
-            with torch.no_grad():
-                out.copy_(data)
+            # with torch.no_grad():
+            #     out.copy_(data)
             return
 
-        upcast_data = data.to(param_dtype)
-        output = ScaledGroupedMMTensor(upcast_data, param_dtype)
-        inner_tensors = (upcast_data,)
+        output = ScaledGroupedMMTensor(data, param_dtype)
+        inner_tensors = (data,)
         return output, inner_tensors
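
The last two hunks touch the FSDP2 all-gather extension hooks: fsdp_pre_all_gather now hands the payload to all-gather unchanged (no cast to mp_policy.param_dtype), and fsdp_post_all_gather wraps the gathered data without upcasting, with the copy into a preallocated `out` commented out. A skeleton of how this hook pair fits together; the leading mesh/outer_size/outer_stride arguments are assumed from FSDP2's extension protocol (only module and mp_policy appear in the hunk above), and this is not the file's full code:

class DemoFsdpHooks:
    # Stand-in illustrating the hook pair; the real methods live on
    # ScaledGroupedMMTensor.

    def fsdp_pre_all_gather(self, mesh, outer_size, outer_stride, module, mp_policy):
        # Payload is handed to all-gather as-is after this commit.
        all_gather_inputs = (self._data,)
        all_gather_metadata = ()
        return all_gather_inputs, all_gather_metadata

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (data,) = all_gather_outputs
        if out is not None:
            # Preallocated-output path; this commit skips the copy into `out`.
            return
        # Wrap the gathered payload back into the subclass and expose it as
        # the inner tensor FSDP should track.
        output = ScaledGroupedMMTensor(data, param_dtype)
        inner_tensors = (data,)
        return output, inner_tensors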
