Skip to content

Commit 84b8661

Browse files
committed
remove reshape in linear backward
1 parent 8c12dd0 commit 84b8661

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch_xla/distributed/fsdp/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def backward(ctx, grad_output):
120120

121121

122122
def _xla_patched_nn_linear_forward(m, input):
123-
return XLAPatchedLinear.apply(input, m.weight, m.bias)
123+
return XLAPatchedMatmul.apply(input, m.weight.t())
124124

125125

126126
def apply_xla_patch_to_nn_linear(module):

0 commit comments

Comments
 (0)