-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
mkldnn.py
346 lines (311 loc) · 11.7 KB
/
mkldnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import copy
from functools import reduce
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._dynamo.utils import detect_fake_mode
from torch.fx.experimental.optimization import replace_node_module
from torch.fx.experimental.symbolic_shapes import free_symbols
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn.modules.utils import _pair
from . import config
def is_group_depthwise_conv_transpose(m):
    """Return True if ``m`` is a grouped, depthwise ``nn.ConvTranspose2d``.

    The type check is exact (subclasses of ``nn.ConvTranspose2d`` are not
    accepted). Beyond that, the module must use more than one group, and
    ``groups`` must equal ``in_channels``.
    """
    if type(m) is not nn.ConvTranspose2d:
        return False
    return m.groups > 1 and m.groups == m.in_channels
class PackedConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weight is stored as an MKLDNN tensor.

    The weight is converted with ``to_mkldnn()``. When ``input_size`` is
    given, it is additionally reordered for that size through
    ``torch._C._nn.mkldnn_reorder_conv2d_weight``. ``forward`` calls
    ``torch.ops.mkldnn._convolution_pointwise`` with the ``"none"`` post-op
    attribute.
    """

    def __init__(
        self,
        conv: nn.Module,
        input_size: Optional[list],
    ):
        # Build a module with matching hyper-parameters; its state is then
        # replaced wholesale by a deep copy of ``conv``'s state.
        super().__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, input_size)

    def _update_module_params(self, conv, input_size):
        """Copy ``conv``'s state and swap in the MKLDNN form of its weight.

        ``input_size`` is a sequence of ints (pack_module passes a tuple) or
        None, in which case no size-specific reorder is performed.
        """
        self.__dict__ = copy.deepcopy(conv.__dict__)
        mkldnn_weight = self.weight.to_mkldnn()
        if input_size is not None:
            mkldnn_weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
                mkldnn_weight,
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                input_size,
            )
        self.weight = torch.nn.Parameter(
            mkldnn_weight,
            requires_grad=self.weight.requires_grad,
        )

    def _conv_forward(self, input, weight, bias):
        """Run the MKLDNN pointwise convolution on ``input``."""
        if self.padding_mode == "zeros":
            conv_input = input
            conv_padding = self.padding
        else:
            # Non-"zeros" padding modes are applied explicitly with F.pad, and
            # the convolution itself is then given zero padding.
            conv_input = F.pad(
                input, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            conv_padding = _pair(0)
        return torch.ops.mkldnn._convolution_pointwise(
            conv_input,
            weight,
            bias,
            conv_padding,
            self.stride,
            self.dilation,
            self.groups,
            "none",
            [],
            "",
        )

    def forward(self, input):
        return self._conv_forward(input, self.weight, self.bias)
class PackedLinearFP32(nn.Linear):
    """``nn.Linear`` using an MKL-packed weight specialised for a batch size.

    The batch size is the product of every dimension of ``input_size`` except
    the last. ``input_size`` must therefore be non-None and have at least two
    entries; otherwise construction raises ``TypeError``.
    """

    def __init__(self, linear: nn.Module, input_size: Optional[list]):
        super().__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, input_size)

    def _update_module_params(self, linear, input_size):
        """Copy ``linear``'s state and attach an MKL-packed copy of its weight."""
        self.__dict__ = copy.deepcopy(linear.__dict__)
        leading_dims = input_size[:-1]
        # reduce() without an initial value raises on an empty sequence.
        self.batch_size = reduce(lambda acc, dim: acc * dim, leading_dims)
        packed = torch.ops.mkl._mkl_reorder_linear_weight(
            self.weight.to_mkldnn(), self.batch_size
        )
        # The dense weight stays on the module too: forward passes both the
        # packed and the original weight to _mkl_linear.
        self.packed_weight = torch.nn.Parameter(
            packed, requires_grad=self.weight.requires_grad
        )

    def forward(self, input):
        return torch.ops.mkl._mkl_linear(
            input, self.packed_weight, self.weight, self.bias, self.batch_size
        )
class PackedLinearBF16(nn.Linear):
def __init__(self, linear: nn.Module, input_size: Optional[list]):
super().__init__(
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
linear.weight.dtype,
)
self._update_module_params(linear, input_size)
def _update_module_params(self, linear, input_size):
self.__dict__ = copy.deepcopy(linear.__dict__)
self.batch_size = (
reduce(lambda x, y: x * y, input_size[:-1])
if input_size is not None
else None
)
self.packed_weight = torch.nn.Parameter(
torch.ops.mkldnn._reorder_linear_weight(
self.weight.to_mkldnn(),
self.batch_size,
),
requires_grad=self.weight.requires_grad,
)
def forward(self, input):
y = torch.ops.mkldnn._linear_pointwise(
input,
self.packed_weight,
self.bias,
"none",
[],
"",
)
return y
class PackedConvTranspose2d(nn.ConvTranspose2d):
    """``nn.ConvTranspose2d`` whose weight is held as an MKLDNN tensor.

    With a known ``input_size``, the weight is reordered through
    ``torch.ops.mkldnn._reorder_convolution_transpose_weight``. Without one,
    the weight only has dims 0 and 1 swapped before ``to_mkldnn()``.
    ``forward`` calls ``torch.ops.mkldnn._convolution_transpose_pointwise``
    with the ``"none"`` post-op attribute and accepts only the ``"zeros"``
    padding mode.
    """

    def __init__(
        self,
        conv_transpose: nn.Module,
        input_size: Optional[list],
    ):
        # Build a module with matching hyper-parameters; its state is then
        # replaced wholesale by a deep copy of ``conv_transpose``'s state.
        super().__init__(
            conv_transpose.in_channels,
            conv_transpose.out_channels,
            conv_transpose.kernel_size,
            conv_transpose.stride,
            conv_transpose.padding,
            conv_transpose.output_padding,
            conv_transpose.groups,
            conv_transpose.bias is not None,
            conv_transpose.dilation,
            conv_transpose.padding_mode,
            conv_transpose.weight.device,
            conv_transpose.weight.dtype,
        )
        self._update_module_params(conv_transpose, input_size)

    def _update_module_params(self, conv_transpose, input_size):
        """Copy ``conv_transpose``'s state and swap in the packed weight."""
        self.__dict__ = copy.deepcopy(conv_transpose.__dict__)
        if input_size is None:
            packed_weight = self.weight.transpose(0, 1).to_mkldnn()
        else:
            packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.output_padding,
                self.stride,
                self.dilation,
                self.groups,
                input_size,
            )
        self.weight = torch.nn.Parameter(
            packed_weight, requires_grad=self.weight.requires_grad
        )

    def _conv_transpose_forward(self, input, weight, bias):
        """Run the MKLDNN transposed convolution; raise for non-zero padding modes."""
        if self.padding_mode == "zeros":
            return torch.ops.mkldnn._convolution_transpose_pointwise(
                input,
                weight,
                bias,
                self.padding,
                self.output_padding,
                self.stride,
                self.dilation,
                self.groups,
                "none",
                [],
                "",
            )
        raise ValueError(
            "Only `zeros` padding mode is supported for PackedConvTranspose2d"
        )

    def forward(self, input):
        return self._conv_transpose_forward(input, self.weight, self.bias)
def packed_conv_eval(conv: nn.Module, input_size: Optional[list]):
    """Wrap an eval-mode ``conv`` in a ``PackedConv2d``.

    Asserts that ``conv`` is not in training mode.
    """
    assert not conv.training, "Fusion only for eval!"
    return PackedConv2d(conv=conv, input_size=input_size)
def packed_conv_transpose_eval(conv_transpose: nn.Module, input_size: Optional[list]):
    """Wrap an eval-mode ``conv_transpose`` in a ``PackedConvTranspose2d``.

    Asserts that ``conv_transpose`` is not in training mode.
    """
    assert not conv_transpose.training, "Fusion only for eval!"
    return PackedConvTranspose2d(conv_transpose=conv_transpose, input_size=input_size)
def packed_linear_eval(linear: nn.Module, input_size: Optional[list]):
    """Wrap an eval-mode ``linear`` in the packed class matching its weight dtype.

    bfloat16 weights use ``PackedLinearBF16``; every other dtype uses
    ``PackedLinearFP32``. Asserts that ``linear`` is not in training mode.
    """
    assert not linear.training, "Fusion only for eval!"
    packed_cls = (
        PackedLinearBF16 if linear.weight.dtype == torch.bfloat16 else PackedLinearFP32
    )
    return packed_cls(linear, input_size)
def mkldnn_fuse_fx(gm: torch.fx.GraphModule, example_inputs):
is_cpu = all(
example_input.device == torch.device("cpu")
for example_input in example_inputs
if isinstance(example_input, torch.Tensor)
)
# make sure the autograd and autocast are disabled.
if torch.is_grad_enabled() or torch.is_autocast_cpu_enabled():
return gm
if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
return gm
if not is_cpu:
return gm
fake_mode = detect_fake_mode(example_inputs)
# NB: free_symbols test here is a BIG hammer. ShapeProp doesn't
# work with symbolic shapes though, see
# https://github.com/pytorch/pytorch/pull/103512
if config.cpp.weight_prepack and not any(free_symbols(e) for e in example_inputs):
ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
gm = pack_module(gm)
return gm
def pack_module(gm: torch.fx.GraphModule):
modules = dict(gm.named_modules())
for node in gm.graph.nodes:
if node.op == "call_module":
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in computation_op_packed_map:
if (
cur_module.weight.device != torch.device("cpu")
or cur_module.weight.dtype not in [torch.bfloat16, torch.float32]
or any(size == 0 for size in cur_module.weight.shape)
):
continue
if cur_module.training:
continue
if (
cur_module.weight.dtype == torch.bfloat16
and not torch.ops.mkldnn._is_mkldnn_bf16_supported()
):
continue
if free_symbols(node.args[0].meta.get("tensor_meta").shape):
computation_node_input_size = None
# Conv2d and ConvTranspose2d weight format are dependent on input size,
# but ShapeProp may be failed to get the input size, so we skip them.
if not (
type(cur_module) in [torch.nn.Linear]
and cur_module.weight.dtype == torch.bfloat16
):
continue
else:
computation_node_input_size = tuple(
int(x) for x in node.args[0].meta.get("tensor_meta").shape
)
if any(size == 0 for size in computation_node_input_size):
continue
if type(cur_module) in [torch.nn.Linear]:
# for fp32 linear, only packed when has mkl.
if (
cur_module.weight.dtype == torch.float32
and (not torch._C.has_mkl)
) or len(computation_node_input_size) < 2:
continue
else:
if len(computation_node_input_size) != 4:
continue
if type(cur_module) in [nn.Conv2d] and isinstance(
cur_module.padding, str
):
continue
# TODO: remove this when group depthwise ConvTranspose is supported
if type(cur_module) in [nn.ConvTranspose2d] and (
is_group_depthwise_conv_transpose(cur_module)
or len(node.args) > 1
or len(node.kwargs) > 0
or any(
not isinstance(output_padding, int)
or not isinstance(stride, int)
or output_padding >= stride
for output_padding, stride in zip(
cur_module.output_padding, cur_module.stride
)
) # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big
):
continue
new_module = computation_op_packed_map[type(cur_module)](
cur_module, computation_node_input_size
)
assert isinstance(new_module, nn.Module)
replace_node_module(node, modules, new_module)
gm.graph.lint()
gm.recompile()
return gm
# Exact module type -> factory that builds its packed replacement.
# pack_module looks entries up by ``type(module)``, so subclasses of these
# types are not matched.
computation_op_packed_map = dict(
    [
        (nn.Linear, packed_linear_eval),
        (nn.Conv2d, packed_conv_eval),
        (nn.ConvTranspose2d, packed_conv_transpose_eval),
    ]
)