@@ -49,6 +49,7 @@ def __new__(
49
49
tensor : torch .Tensor ,
50
50
dtype : torch .dtype ,
51
51
):
52
+ logger .debug (f"__new__: Creating ScaledGroupedMMTensor with dtype={ dtype } " )
52
53
return torch .Tensor ._make_wrapper_subclass (
53
54
cls ,
54
55
tensor .size (),
@@ -67,11 +68,14 @@ def __init__(
67
68
tensor : torch .Tensor ,
68
69
dtype : torch .dtype ,
69
70
):
70
- self ._data = tensor
71
+ self ._data = tensor . to ( dtype )
71
72
self ._dtype = dtype
73
+ logger .debug (f"__init__: ScaledGroupedMMTensor with self._data.dtype={ self ._data .dtype } and dtype={ dtype } " )
72
74
73
75
@classmethod
74
76
def __torch_function__ (cls , func , types , args , kwargs = {}):
77
+ logger .debug (f"func: { func .__name__ } , args={ args } , kwargs={ kwargs } " )
78
+
75
79
# override the grouped mm op to use the differentiable _scaled_grouped_mm
76
80
if func .__name__ in cls .grouped_mm_func_names :
77
81
# Use torchao scaled grouped mm with dynamic quant for
@@ -97,22 +101,13 @@ def __torch_function__(cls, func, types, args, kwargs={}):
97
101
98
102
@classmethod
99
103
def __torch_dispatch__ (cls , func , types , args , kwargs = {}):
100
- logger .debug (f"{ func .__name__ } , args={ args } , kwargs={ kwargs } " )
104
+ logger .debug (f"dispatch: { func .__name__ } , args={ args } , kwargs={ kwargs } " )
101
105
# detach is special case
102
106
if func == torch .ops .aten .detach .default :
103
107
return ScaledGroupedMMTensor (args [0 ]._data , args [0 ]._dtype )
104
108
105
109
# unwrap args and kwargs
106
- dtype : Optional [torch .dtype ] = None
107
-
108
- def unwrap (t ):
109
- nonlocal dtype
110
- if dtype is None :
111
- dtype = t ._dtype
112
- else :
113
- assert t ._dtype == dtype
114
- return t ._data
115
-
110
+ unwrap = lambda x : x ._data
116
111
args , kwargs = pytree .tree_map_only (
117
112
ScaledGroupedMMTensor , unwrap , (args , kwargs or {})
118
113
)
@@ -127,7 +122,7 @@ def unwrap(t):
127
122
# wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
128
123
return pytree .tree_map_only (
129
124
torch .Tensor ,
130
- lambda x : ScaledGroupedMMTensor (x , dtype ),
125
+ lambda x : ScaledGroupedMMTensor (x , x . dtype ),
131
126
out ,
132
127
)
133
128
@@ -154,7 +149,7 @@ def fsdp_pre_all_gather(
154
149
module : nn .Module ,
155
150
mp_policy : MixedPrecisionPolicy ,
156
151
):
157
- all_gather_inputs = (self ._data . to ( mp_policy . param_dtype ) ,)
152
+ all_gather_inputs = (self ._data ,)
158
153
all_gather_metadata = ()
159
154
logger .debug (f"fsdp_pre_all_gather: self._data.dtype={ self ._data .dtype } , param_dtype: { mp_policy .param_dtype } " )
160
155
return all_gather_inputs , all_gather_metadata
@@ -171,11 +166,10 @@ def fsdp_post_all_gather(
171
166
logger .debug (f"fsdp_post_all_gather: data.dtype={ data .dtype } , param_dtype: { param_dtype } " )
172
167
173
168
if out is not None :
174
- with torch .no_grad ():
175
- out .copy_ (data )
169
+ # with torch.no_grad():
170
+ # out.copy_(data)
176
171
return
177
172
178
- upcast_data = data .to (param_dtype )
179
- output = ScaledGroupedMMTensor (upcast_data , param_dtype )
180
- inner_tensors = (upcast_data ,)
173
+ output = ScaledGroupedMMTensor (data , param_dtype )
174
+ inner_tensors = (data ,)
181
175
return output , inner_tensors
0 commit comments