Persist copy_ in training graph for inputs that don't require grad (#111046)

In this PR, we keep an input mutation in the forward graph if and only if it is a data mutation (not a metadata mutation) and the mutated input doesn't require grad. This is for optimizing inductor training graphs. (For more details, see #109240.)
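
For context, a minimal hand-written illustration of the three cases (these snippets are illustrative only, not from the PR):

```python
import torch

# Eligible: an in-place *data* mutation on an input that doesn't require grad.
buf = torch.zeros(2)
buf.add_(1)

# Not eligible: a *metadata* mutation (shape/stride change rather than a data write).
x = torch.zeros(4)
x.resize_(2, 2)

# Not eligible: the mutated input requires grad, so the mutation stays outside the graph.
p = torch.zeros(2, requires_grad=True)
with torch.no_grad():
    p.add_(1)
```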

We keep the input mutation in the graph by wrapping the original callable in a wrapper function that appends an input.copy_(updated_input) call at the end; the wrapper is then traced via make_fx. Previously, this was only enabled for the forward-only path and unconditionally disabled for the joint graph.
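
As a rough sketch of the wrapping idea (hand-rolled for illustration; `f`, `traced_wrapper`, and the single-buffer setup are hypothetical, not the actual AOTAutograd helpers):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(buf, x):
    # Functionalized body: compute the updated buffer value instead of mutating in place.
    updated_buf = buf + 1
    out = (x * 2).sum() + updated_buf.sum()
    return out, updated_buf

def traced_wrapper(buf, x):
    out, updated_buf = f(buf, x)
    # Re-apply the input mutation at the very end so it is captured as an
    # explicit copy_ node in the traced graph.
    buf.copy_(updated_buf)
    return out

gm = make_fx(traced_wrapper)(torch.zeros(3), torch.ones(3))
print(gm.code)  # the traced graph should contain a torch.ops.aten.copy_ call on the buffer input
```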

Another caveat: when we are tracing through tensor subclasses, we don't allow any input mutations to be preserved in the graph, because supporting that would make the code logic quite ugly for no obvious performance improvement.

Most of the changes in this PR are mechanical, and I didn't have to make any changes to the partitioner. Previously, the forward/backward logic relied heavily on the metadata field `num_mutated_inps` to figure out whether something is returned as an extra output or not. But now that we keep some mutations in the graph, we need to propagate something like `num_mutated_inps - num_graph_handled_inps` instead.
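
For illustration, the bookkeeping amounts to something like this (made-up numbers; the variable names just mirror the metadata fields mentioned above):

```python
# Hypothetical accounting sketch, not the actual AOTAutograd/partitioner code.
num_mutated_inps = 3        # inputs mutated by the traced function
num_graph_handled_inps = 2  # mutations kept in the graph as copy_ nodes
                            # (data-only mutations on inputs that don't require grad)

# Only the mutations NOT handled inside the graph are still returned as extra
# outputs and applied by the runtime epilogue.
num_extra_mutated_outputs = num_mutated_inps - num_graph_handled_inps
assert num_extra_mutated_outputs == 1
```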

Pull Request resolved: #111046
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
tugsbayasgalan authored and pytorchmergebot committed Nov 9, 2023
1 parent 2c4be77 commit 84d64d7
Showing 5 changed files with 483 additions and 53 deletions.
193 changes: 193 additions & 0 deletions test/functorch/test_aotdispatch.py
@@ -19,6 +19,7 @@
skipIfRocm,
)
from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode
import copy
import torch
import torch.nn as nn
import torch.utils._pytree as pytree
@@ -2192,6 +2193,198 @@ def f(a):

self.assertEqual(inp_ref.grad, inp_test.grad)

def test_buffer_copied_in_graph(self):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.zeros(1))
self.w1 = torch.nn.Parameter(torch.zeros(1))
self.w2 = torch.nn.Parameter(torch.zeros(1))

def forward(self, x):
self.buf.add_(1)
return (self.w1 * x * self.w2).sum() + self.buf.sum()

model_for_eager = MyModel()
model_for_compile = copy.deepcopy(model_for_eager)

fw_graph_cell = [None]
compiled_f = aot_module(
model_for_compile,
fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)),
bw_compiler=nop,
keep_inference_input_mutations=True,
)
inp_ref = torch.ones(1, requires_grad=True)
inp_test = torch.ones(1, requires_grad=True)

out_ref = model_for_eager(inp_ref.clone())
out_test = compiled_f(inp_test.clone())

self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\
def forward(self, primals_1, primals_2, primals_3, primals_4):
add = torch.ops.aten.add.Tensor(primals_3, 1)
mul = torch.ops.aten.mul.Tensor(primals_1, primals_4)
mul_1 = torch.ops.aten.mul.Tensor(mul, primals_2)
sum_1 = torch.ops.aten.sum.default(mul_1); mul_1 = None
sum_2 = torch.ops.aten.sum.default(add)
add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
copy_ = torch.ops.aten.copy_.default(primals_3, add); primals_3 = add = None
return [add_1, primals_1, primals_2, primals_4, mul]""")

self.assertEqual(out_ref, out_test)

out_ref.sum().backward()
out_test.sum().backward()

eager_grads = [p.grad for _, p in model_for_eager.named_parameters()]
compile_grads = [p.grad for _, p in model_for_compile.named_parameters()]

self.assertEqual(eager_grads, compile_grads)
self.assertEqual(inp_ref.grad, inp_test.grad)

def test_buffer_copied_in_graph_with_different_shapes(self):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.ones(4, 4))
self.w = torch.nn.Parameter(torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]]))

def forward(self, x):
self.buf.add_(1)
return (self.w @ x).sum() + self.buf.sum()

model_for_eager = MyModel()
model_for_compile = copy.deepcopy(model_for_eager)

fw_graph_cell = [None]
compiled_f = aot_module(
model_for_compile,
fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)),
bw_compiler=nop,
keep_inference_input_mutations=True,
)
inp_ref = torch.ones(2, 4, requires_grad=True)
inp_test = torch.ones(2, 4, requires_grad=True)

out_ref = model_for_eager(inp_ref.clone())
out_test = compiled_f(inp_test.clone())

self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\
def forward(self, primals_1, primals_2, primals_3):
add = torch.ops.aten.add.Tensor(primals_2, 1)
mm = torch.ops.aten.mm.default(primals_1, primals_3)
sum_1 = torch.ops.aten.sum.default(mm); mm = None
sum_2 = torch.ops.aten.sum.default(add)
add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = None
return [add_1, primals_1, primals_3]""")
self.assertEqual(out_ref, out_test)

out_ref.sum().backward()
out_test.sum().backward()

eager_grads = [p.grad for _, p in model_for_eager.named_parameters()]
compile_grads = [p.grad for _, p in model_for_compile.named_parameters()]

self.assertEqual(eager_grads, compile_grads)

self.assertEqual(inp_ref.grad, inp_test.grad)

def test_buffer_batch_norm(self):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.m = torch.nn.BatchNorm1d(100)

def forward(self, x):
return self.m(x)

model_for_eager = MyModel()
model_for_compile = copy.deepcopy(model_for_eager)

fw_graph_cell = [None]
bw_graph_cell = [None]
compiled_f = aot_module(
model_for_compile,
fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)),
bw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=bw_graph_cell)),
keep_inference_input_mutations=True,
)
inp_ref = torch.ones(20, 100, requires_grad=True)
inp_test = torch.ones(20, 100, requires_grad=True)

out_ref = model_for_eager(inp_ref.clone())
out_test = compiled_f(inp_test.clone())

self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6):
add = torch.ops.aten.add.Tensor(primals_5, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(primals_6, primals_1, primals_2, primals_3, primals_4, True, 0.1, 1e-05); primals_2 = None
getitem = _native_batch_norm_legit_functional[0]
getitem_1 = _native_batch_norm_legit_functional[1]
getitem_2 = _native_batch_norm_legit_functional[2]
getitem_3 = _native_batch_norm_legit_functional[3]
getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3); primals_3 = None
copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4); primals_4 = None
copy__2 = torch.ops.aten.copy_.default(primals_5, add); primals_5 = add = None
return [getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4]""") # noqa: B950

self.assertEqual(out_ref, out_test)

out_ref.sum().backward()
out_test.sum().backward()

eager_grads = [p.grad for _, p in model_for_eager.named_parameters()]
compile_grads = [p.grad for _, p in model_for_compile.named_parameters()]
self.assertEqual(eager_grads, compile_grads)

self.assertExpectedInline(bw_graph_cell[0].code.strip(), """\
def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4, tangents_1):
native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(tangents_1, primals_6, primals_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); tangents_1 = primals_6 = primals_1 = getitem_3 = getitem_4 = getitem_1 = getitem_2 = None
getitem_5 = native_batch_norm_backward[0]
getitem_6 = native_batch_norm_backward[1]
getitem_7 = native_batch_norm_backward[2]; native_batch_norm_backward = None
return [getitem_6, getitem_7, None, None, None, getitem_5]""") # noqa: B950

self.assertEqual(inp_ref.grad, inp_test.grad)

def test_new_inp_requires_grad_now(self):
def f(x, y):
return x.add_(y)

fw_graph_cell = [None]
bw_graph_cell = [None]
compiled_f = aot_function(
f,
fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)),
bw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=bw_graph_cell)),
keep_inference_input_mutations=True,
)

inp_ref = (torch.ones(20, 100, requires_grad=False), torch.ones(20, 100, requires_grad=True))
inp_test = (torch.ones(20, 100, requires_grad=False), torch.ones(20, 100, requires_grad=True))

out_ref = f(*inp_ref)
out_test = compiled_f(*inp_test)

# No copy_ node is kept in the graph, since the mutated value now requires grad
self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\
def forward(self, primals_1, primals_2):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
add = torch.ops.aten.add.Tensor(clone, primals_2); clone = primals_2 = None
return [add, add]""") # noqa: B950

self.assertEqual(out_ref, out_test)

out_ref.sum().backward()
out_test.sum().backward()

self.assertExpectedInline(bw_graph_cell[0].code.strip(), """\
def forward(self, tangents_1):
return [None, tangents_1]""") # noqa: B950

def test_real_weights_in_symbolic_mode(self):
from functorch.experimental import functionalize

137 changes: 137 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -7,6 +7,7 @@
import importlib
import itertools
import math
import operator
import os
import random
import re
@@ -2647,6 +2648,142 @@ def forward(self, x):
expected = mod(x)
self.assertTrue(torch.allclose(res, expected))

def test_buffer_copied_in_graph(self):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.zeros(1))
self.w1 = torch.nn.Parameter(torch.zeros(1))
self.w2 = torch.nn.Parameter(torch.zeros(1))

def forward(self, x):
self.buf.add_(1)
return (self.w1 * x * self.w2).sum() + self.buf.sum()

model_for_eager = MyModel()
model_for_compile = copy.deepcopy(model_for_eager)

eager_version_counters = [
buffer._version for _, buffer in model_for_eager.named_buffers()
]
compile_version_counters = [
buffer._version for _, buffer in model_for_compile.named_buffers()
]

compiled_f = torch.compile(model_for_compile, backend="inductor")

inp_ref = torch.ones(1, requires_grad=True)
inp_test = torch.ones(1, requires_grad=True)

out_ref = model_for_eager(inp_ref.clone())
out_test = compiled_f(inp_test.clone())

eager_version_counters_after = [
buffer._version for _, buffer in model_for_eager.named_buffers()
]
compile_version_counters_after = [
buffer._version for _, buffer in model_for_compile.named_buffers()
]

eager_delta = list(
map(operator.sub, eager_version_counters_after, eager_version_counters)
)
compile_delta = list(
map(operator.sub, compile_version_counters_after, compile_version_counters)
)

self.assertEqual(eager_delta, compile_delta)

def test_buffer_copied_in_graph_with_different_shapes(self):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.ones(4, 4))
self.w = torch.nn.Parameter(
torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]])
)

def forward(self, x):
self.buf.add_(1)
return (self.w @ x).sum() + self.buf.sum()

model_for_eager = MyModel()
model_for_compile = copy.deepcopy(model_for_eager)

eager_version_counters = [
buffer._version for _, buffer in model_for_eager.named_buffers()
]
compile_version_counters = [
buffer._version for _, buffer in model_for_compile.named_buffers()
]

compiled_f = torch.compile(model_for_compile, backend="inductor")

inp_ref = torch.ones(2, 4, requires_grad=True)
inp_test = torch.ones(2, 4, requires_grad=True)

out_ref = model_for_eager(inp_ref.clone())
out_test = compiled_f(inp_test.clone())

eager_version_counters_after = [
buffer._version for _, buffer in model_for_eager.named_buffers()
]
compile_version_counters_after = [
buffer._version for _, buffer in model_for_compile.named_buffers()
]

eager_delta = list(
map(operator.sub, eager_version_counters_after, eager_version_counters)
)
compile_delta = list(
map(operator.sub, compile_version_counters_after, compile_version_counters)
)

self.assertEqual(eager_delta, compile_delta)

def test_buffer_batch_norm(self):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.m = torch.nn.BatchNorm1d(100)

def forward(self, x):
return self.m(x)

model_for_eager = MyModel()
model_for_compile = copy.deepcopy(model_for_eager)

eager_version_counters = [
buffer._version for _, buffer in model_for_eager.named_buffers()
]
compile_version_counters = [
buffer._version for _, buffer in model_for_compile.named_buffers()
]

compiled_f = torch.compile(model_for_compile, backend="inductor")

inp_ref = torch.ones(20, 100, requires_grad=True)
inp_test = torch.ones(20, 100, requires_grad=True)

out_ref = model_for_eager(inp_ref.clone())
out_test = compiled_f(inp_test.clone())

eager_version_counters_after = [
buffer._version for _, buffer in model_for_eager.named_buffers()
]
compile_version_counters_after = [
buffer._version for _, buffer in model_for_compile.named_buffers()
]

eager_delta = list(
map(operator.sub, eager_version_counters_after, eager_version_counters)
)
compile_delta = list(
map(operator.sub, compile_version_counters_after, compile_version_counters)
)

self.assertEqual(eager_delta, compile_delta)

def test_adaptive_avg_pool_with_output_size_0(self):
m1 = nn.AdaptiveAvgPool1d(0)
self.common(m1, (torch.randn(1, 2),))
