We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8c12dd0 commit 84b8661 (Copy full SHA for 84b8661)
torch_xla/distributed/fsdp/utils.py
@@ -120,7 +120,7 @@ def backward(ctx, grad_output):
120
121
122
def _xla_patched_nn_linear_forward(m, input):
123
- return XLAPatchedLinear.apply(input, m.weight, m.bias)
+ return XLAPatchedMatmul.apply(input, m.weight.t())
124
125
126
def apply_xla_patch_to_nn_linear(module):
0 commit comments