Linear autodiff revert revert (#51613)
Summary:
Patch PR #50856 and roll back the revert D26105797 (e488e3c).

Pull Request resolved: #51613

Reviewed By: mruberry

Differential Revision: D26253999

Pulled By: ngimel

fbshipit-source-id: a20b1591de06dd277e4cd95542e3291a2f5a252c
jjsjann123 authored and facebook-github-bot committed Feb 5, 2021
1 parent 6dcbf39 commit 4d703d0
Showing 12 changed files with 190 additions and 57 deletions.
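For context on the change: torch.nn.functional.linear now dispatches to the single aten::linear op (see torch/nn/functional.py below) instead of decomposing in Python, so the JIT can attach an autodiff formula, and the ONNX preprocess pass re-introduces the decomposition where it is still needed. A minimal sketch, outside this commit, of the equivalence the decomposition relies on:

import torch

x = torch.randn(4, 2)
w = torch.randn(3, 2)
b = torch.randn(3)

ref = torch.nn.functional.linear(x, w, b)   # aten::linear
fused = torch.addmm(b, x, w.t())            # 2-D fast path: bias + x @ w.t()
decomposed = x.matmul(w.t()) + b            # general path: matmul, then add
assert torch.allclose(ref, fused) and torch.allclose(ref, decomposed)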
14 changes: 0 additions & 14 deletions aten/src/ATen/native/Linear.cpp
@@ -13,20 +13,6 @@

namespace at { namespace native {

// in order to dispatch mkldnn linear to addmm which can be removed after linear ported.
Tensor& mkldnn_addmm_wraper_out(Tensor& result, const Tensor& bias,
const Tensor& input, const Tensor& weight, Scalar beta, Scalar alpha) {
TORCH_CHECK(false,
"mkldnn_addmm_wraper_out: in-place mkldnn operations are not supported yet");
}

Tensor mkldnn_addmm_wraper(const Tensor& bias, const Tensor& input,
const Tensor& weight, Scalar beta, Scalar alpha) {
TORCH_CHECK(input.dim() == 2,
"mkldnn_addmm_wraper: input needs to has dim 2, input dim ", input.dim());
return at::mkldnn_linear(input, weight.t(), bias);
}

Tensor linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
if (input.is_mkldnn()) {
return at::mkldnn_linear(input, weight, bias);
4 changes: 1 addition & 3 deletions aten/src/ATen/native/native_functions.yaml
@@ -2172,7 +2172,7 @@
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
python_module: nn

- func: mkldnn_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
- func: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
python_module: nn
dispatch:
@@ -4331,7 +4331,6 @@
CUDA: addmm_out_cuda
SparseCPU: addmm_out_sparse_dense_cpu
SparseCUDA: addmm_out_sparse_dense_cuda
MkldnnCPU: mkldnn_addmm_wraper_out

- func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: function, method
@@ -4340,7 +4339,6 @@
CUDA: addmm_cuda
SparseCPU: addmm_sparse_dense_cpu
SparseCUDA: addmm_sparse_dense_cuda
MkldnnCPU: mkldnn_addmm_wraper

- func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
variants: method
@@ -70,6 +70,7 @@
("aten::_foreach_div", datetime.date(2021, 2, 25)),
("aten::_foreach_div_", datetime.date(2021, 2, 25)),
("aten::_foreach_addcdiv", datetime.date(2021, 2, 25)),
("aten::mkldnn_linear", datetime.date(2021, 3, 2)),
]

def allow_listed(schema, allow_list):
31 changes: 31 additions & 0 deletions test/jit/test_onnx_export.py
@@ -334,6 +334,37 @@ def transpose(x):
(torch.ones(1, 10, dtype=torch.float), ),
None, verbose=False, example_outputs=outputs_f2)

def test_onnx_export_preprocess_decompose_linear(self):
def t(x, weight, bias):
return torch.nn.functional.linear(x, weight, bias)

foo = torch.jit.script(t)
foo(torch.zeros(2, 4), torch.randn(3, 4), torch.randn(3))
# run it twice in case we need to remove profiling nodes
graph = foo.graph_for(
torch.zeros(2, 4), torch.randn(3, 4), torch.randn(3))

nodes = []
for n in graph.nodes():
nodes.append(n.kind())
self.assertEqual(nodes, ['aten::linear'])
torch._C._jit_pass_onnx_preprocess(graph)

nodes = []
for n in graph.nodes():
nodes.append(n.kind())
for b in n.blocks():
nodes_b = []
for n_n in b.nodes():
nodes_b.append(n_n.kind())
nodes.append(nodes_b)

self.assertEqual(
nodes,
['aten::dim', 'prim::Constant', 'aten::eq', 'prim::If',
['prim::Constant', 'aten::t', 'aten::addmm'],
['prim::Constant', 'aten::t', 'aten::matmul', 'aten::add']])

def test_onnx_export_shape_reshape(self):
class Foo(torch.nn.Module):
def forward(self, x):
23 changes: 1 addition & 22 deletions test/jit/test_remove_mutation.py
@@ -2,40 +2,19 @@
import sys

import torch
from torch.nn import functional as F
from torch.testing import FileCheck

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, freeze_rng_state
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")

class TestRemoveMutation(JitTestCase):
def test_lower_linear(self):
# linear is one of main use cases of removing mutation so add test so it doesnt regress
@torch.jit.script
def foo(x):
return F.linear(x, torch.randn(20, 20), torch.randn(20))

self.run_pass('inline', foo.graph)
self.run_pass('peephole', foo.graph)
self.run_pass('constant_propagation', foo.graph)
FileCheck().check("aten::add_").run(foo.graph)
input = torch.randn(20, 20)
with freeze_rng_state():
out1 = foo(input)

self.run_pass('remove_mutation', foo.graph)
FileCheck().check_not("aten::add_").run(foo.graph)
with freeze_rng_state():
out2 = foo(input)
self.assertEqual(out1, out2)

def test_aten_inplace(self):
def test_not_new_alias(x):
y = x[0]
6 changes: 2 additions & 4 deletions test/test_autograd.py
@@ -3191,10 +3191,8 @@ def test_profiler_shapes(self):
print(prof.function_events)

top_level_expected_events_and_shapes = [
(None, [[30, 20]]),
('aten::addmm', [[30], [128, 20], [20, 30], [], []]),
(None, [[40, 30]]),
('aten::addmm', [[40], [128, 30], [30, 40], [], []])
('aten::linear', [[128, 20], [30, 20], [30]]),
('aten::linear', [[128, 30], [40, 30], [40]])
]

expected_iter = iter(top_level_expected_events_and_shapes)
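The updated expectations reflect that F.linear now shows up as a single top-level aten::linear event (with the separate bias expansion gone) rather than an aten::addmm. A minimal sketch of inspecting recorded shapes with the same profiler API the test uses:

import torch

with torch.autograd.profiler.profile(record_shapes=True) as prof:
    torch.nn.functional.linear(torch.randn(128, 20), torch.randn(30, 20), torch.randn(30))

for evt in prof.function_events:
    print(evt.name, evt.input_shapes)
# expected to include a top-level ('aten::linear', [[128, 20], [30, 20], [30]]) entry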
53 changes: 53 additions & 0 deletions test/test_jit.py
@@ -10665,6 +10665,59 @@ def randint():
FileCheck().check("Double(*, *, requires_grad=0, device=cpu)") \
.check_not("Float(*, *, requires_grad=0, device=cpu)").run(randint.graph_for())

def test_linear_grad(self):
with enable_profiling_mode_for_profiling_tests():
def t(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]):
return torch.nn.functional.linear(x, w, b)

x_init = torch.randn(4, 2)
w_init = torch.randn(3, 2)
b_init = torch.randn(3)
grad = torch.randn(4, 3)

with disable_autodiff_subgraph_inlining():
# script module
jit_t = torch.jit.script(t)

x = x_init.detach().requires_grad_()
w = w_init.detach().requires_grad_()
b = b_init.detach().requires_grad_()
x_ref = x_init.detach().requires_grad_()
w_ref = w_init.detach().requires_grad_()
b_ref = b_init.detach().requires_grad_()

# profiling/optimization runs
jit_o = jit_t(x, w, b)
jit_o.backward(grad)
jit_o = jit_t(x, w, b)
jit_o.backward(grad)

x.grad.zero_()
w.grad.zero_()
b.grad.zero_()
jit_o = jit_t(x, w, b)
jit_o.backward(grad)
o = t(x_ref, w_ref, b_ref)
o.backward(grad)

self.assertEqual(jit_o, o)
self.assertEqual(x.grad, x_ref.grad)
self.assertEqual(w.grad, w_ref.grad)
self.assertEqual(b.grad, b_ref.grad)

x.grad.zero_()
w.grad.zero_()
x_ref.grad.zero_()
w_ref.grad.zero_()
jit_o = jit_t(x, w, None)
jit_o.backward(grad)
o = t(x_ref, w_ref, None)
o.backward(grad)

self.assertEqual(jit_o, o)
self.assertEqual(x.grad, x_ref.grad)
self.assertEqual(w.grad, w_ref.grad)

@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "the profiling version of test_rand")
def test_rand_profiling(self):
def test_rand():
9 changes: 6 additions & 3 deletions tools/autograd/derivatives.yaml
@@ -187,9 +187,9 @@
tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj())

- name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
self: "grad.is_mkldnn() ? grad.to_dense() : maybe_multiply(grad, beta.conj())"
mat1: "grad.is_mkldnn() ? mkldnn_linear_backward_input(mat1.sizes(), grad, mat2.t()) : mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)"
mat2: "grad.is_mkldnn() ? (std::get<0>(mkldnn_linear_backward_weights(grad, mat1, mat2.t(), true))).t() : mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)"
self: maybe_multiply(grad, beta.conj())
mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)
mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)

- name: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor
self: maybe_multiply(grad, beta)
@@ -1888,6 +1888,9 @@
- name: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, false, false, false, false, grad_input_mask)

- name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask)

# fft
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
self: fft_r2c_backward(grad, dim, normalization, onesided, self.size(dim.back()))
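With MKL-DNN no longer routed through addmm, the addmm derivatives reduce to the plain dense formulas, and mkldnn_linear gets its own backward entry. A hedged numeric check of the simplified addmm formulas against autograd, as a sketch rather than part of the commit:

import torch

a = torch.randn(3, 4, requires_grad=True)    # self
m1 = torch.randn(3, 5, requires_grad=True)   # mat1
m2 = torch.randn(5, 4, requires_grad=True)   # mat2
beta, alpha = 0.5, 2.0

out = torch.addmm(a, m1, m2, beta=beta, alpha=alpha)
grad = torch.randn(3, 4)
out.backward(grad)

assert torch.allclose(a.grad, beta * grad)                    # maybe_multiply(grad, beta)
assert torch.allclose(m1.grad, alpha * grad.matmul(m2.t()))   # mm_mat1_backward
assert torch.allclose(m2.grad, alpha * m1.t().matmul(grad))   # mm_mat2_backward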
75 changes: 75 additions & 0 deletions torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp
@@ -217,9 +217,84 @@ static void fuseListAndListUnpack(Block* b) {
}
}

static void decomposeLinear(Block* b) {
std::vector<Node*> linear_nodes;
for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
for (auto* child_block : it->blocks()) {
decomposeLinear(child_block);
}
if (it->kind() == aten::linear) {
linear_nodes.push_back(*it);
}
}
for (Node* node : linear_nodes) {
auto g = b->owningGraph();

if (node->inputs()[2]->mustBeNone()) {
auto t_weight_n =
g->create(aten::t, {node->inputs()[1]}, 1)->insertBefore(node);
auto matmul_n =
g->create(aten::matmul, {node->inputs()[0], t_weight_n->output()}, 1)
->insertBefore(node);
node->output()->replaceAllUsesWith(matmul_n->output());
node->destroy();
} else {
auto dim_n =
g->create(aten::dim, {node->inputs()[0]}, 1)->insertBefore(node);
auto const_2 = g->insertConstant(IValue(2));
const_2->node()->moveBefore(node);
auto eq_n = g->create(aten::eq, {dim_n->output(), const_2}, 1)
->insertBefore(node);

auto if_n = g->create(prim::If, {eq_n->output()}, 1)->insertBefore(node);

auto true_block = if_n->addBlock();
auto false_block = if_n->addBlock();

{
WithInsertPoint guard(true_block->return_node());
auto const_1 = g->insertConstant(IValue(1.0));
auto t_weight_n = g->create(aten::t, {node->inputs()[1]}, 1)
->insertBefore(true_block->return_node());
auto addmm_n = g->create(
aten::addmm,
{node->inputs()[2],
node->inputs()[0],
t_weight_n->output(),
const_1,
const_1},
1)
->insertBefore(true_block->return_node());
true_block->registerOutput(addmm_n->output());
}

{
WithInsertPoint guard(false_block->return_node());
auto const_1 = g->insertConstant(IValue(1.0));
auto t_weight_n = g->create(aten::t, {node->inputs()[1]}, 1)
->insertBefore(false_block->return_node());
auto matmul_n =
g->create(
aten::matmul, {node->inputs()[0], t_weight_n->output()}, 1)
->insertBefore(false_block->return_node());
auto add_n =
g->create(
aten::add, {matmul_n->output(), node->inputs()[2], const_1}, 1)
->insertBefore(false_block->return_node());
false_block->registerOutput(add_n->output());
}
node->output()->replaceAllUsesWith(if_n->output());
node->destroy();
}
}
}

} // namespace

void PreprocessForONNX(std::shared_ptr<Graph>& graph) {
GRAPH_DEBUG("priot to decompose linear", graph);
decomposeLinear(graph->block());
GRAPH_DEBUG("after decompose linear", graph);
FuseWithListUnpack(graph->block());
ReplaceAddWithConcat(graph->block());
fuseListAndListUnpack(graph->block());
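decomposeLinear mirrors the decomposition removed from torch/nn/functional.py: with no bias it emits t + matmul, otherwise it guards on aten::dim == 2 via a prim::If to pick either a fused addmm or matmul followed by add. A Python sketch of the same branching, for illustration only (the pass itself builds these nodes in the graph):

import torch

def decomposed_linear(x, weight, bias=None):
    if bias is None:
        return torch.matmul(x, weight.t())
    if x.dim() == 2:
        return torch.addmm(bias, x, weight.t())   # true block: fused addmm
    return torch.matmul(x, weight.t()) + bias     # false block: matmul + add

x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)
assert torch.allclose(decomposed_linear(x, w, b), torch.nn.functional.linear(x, w, b))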
17 changes: 17 additions & 0 deletions torch/csrc/jit/runtime/symbolic_script.cpp
@@ -403,6 +403,23 @@ const std::vector<std::string> functions = {
return grad_self, grad_other
return torch.matmul(self, other), backward
def linear(input : Tensor,
weight : Tensor,
bias : Optional[Tensor]):
result = torch.linear(input, weight, bias)
def backward(grad_output):
if bias is not None:
grad_bias = grad_output._grad_sum_to_size(bias.size())
else:
grad_bias = None
weight_size = weight.size()
grad_input = torch.matmul(grad_output, weight)
grad_weight = torch.matmul(grad_output.reshape(-1, weight_size[0]).t(), input.reshape(-1, weight_size[1]))
return grad_input, grad_weight, grad_bias
return result, backward
)",
R"(
def addcmul(self,
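The backward defined above is the standard linear gradient: grad_input = grad_output @ W, grad_weight = grad_output^T @ input after flattening leading dimensions, and grad_bias summed down to the bias shape. A minimal numeric check of those formulas against autograd, as a sketch:

import torch

x = torch.randn(4, 2, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
b = torch.randn(3, requires_grad=True)
grad_out = torch.randn(4, 3)

torch.nn.functional.linear(x, w, b).backward(grad_out)

grad_input = grad_out.matmul(w)
grad_weight = grad_out.reshape(-1, w.size(0)).t().matmul(x.reshape(-1, w.size(1)))
grad_bias = grad_out.sum(0)   # _grad_sum_to_size(bias.size()) for a 2-D grad_output

assert torch.allclose(x.grad, grad_input)
assert torch.allclose(w.grad, grad_weight)
assert torch.allclose(b.grad, grad_bias)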
10 changes: 1 addition & 9 deletions torch/nn/functional.py
@@ -1750,15 +1750,7 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tens
"""
if has_torch_function_variadic(input, weight):
return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
if input.dim() == 2 and bias is not None:
# fused op is marginally faster
ret = torch.addmm(bias, input, weight.t())
else:
output = input.matmul(weight.t())
if bias is not None:
output += bias
ret = output
return ret
return torch._C._nn.linear(input, weight, bias)


def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
4 changes: 2 additions & 2 deletions torch/testing/_internal/jit_metaprogramming_utils.py
@@ -103,8 +103,8 @@
('tanh', (S, S, S), (), '', (True,)),
('sigmoid', (S, S, S), (), '', (True,)),
('log_softmax', (S, S, S), (0,), '', (True,)),
('linear', (S, S), ((M, S),), '', (True, ['aten::t', 'aten::matmul'])),
('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::add', 'aten::mm'])),
('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
