Skip to content

Commit f061d0c

Browse files
committed
activation offloading implementation
ghstack-source-id: 89e2cbe Pull Request resolved: #167880
1 parent a6a0379 commit f061d0c

File tree

4 files changed

+684
-0
lines changed

4 files changed

+684
-0
lines changed
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Owner(s): ["oncall: pt2"]
2+
# flake8: noqa: B950
3+
4+
import unittest
5+
import warnings
6+
from functools import partial
7+
8+
import torch
9+
import torch._functorch.config
10+
from functorch.compile import (
11+
aot_function,
12+
default_decompositions,
13+
min_cut_rematerialization_partition,
14+
)
15+
from torch._dynamo.graph_bytecode_inputs import reset_user_object_tracking
16+
from torch._inductor.utils import run_fw_bw_and_get_code
17+
from torch.testing import FileCheck
18+
from torch.testing._internal.common_utils import run_tests, TestCase
19+
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
20+
21+
22+
# Probe for the optional networkx dependency. USE_NETWORKX records whether the
# import worked so that tests depending on it can be skipped instead of failing.
try:
    import networkx  # noqa: F401
except ImportError:
    USE_NETWORKX = False
    warnings.warn("Some tests use networkx but it was not installed", UserWarning)
else:
    USE_NETWORKX = True
29+
30+
31+
def extract_graph(fx_g, _, graph_cell):
    """Capture the graph handed to a compiler hook and return it unchanged.

    Args:
        fx_g: The graph object passed in as the first positional argument.
        _: Second positional argument of the hook; it is ignored.
        graph_cell: Mutable sequence whose slot 0 is overwritten with ``fx_g``
            so the caller can inspect the graph afterwards.

    Returns:
        ``fx_g`` itself, so this works as a pass-through compiler.
    """
    # Write into the caller-owned cell rather than returning a new container:
    # the caller keeps a reference to the cell and reads slot 0 later.
    graph_cell[0] = fx_g
    return fx_g
34+
35+
36+
def get_fw_bw_graph(
    f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False
):
    """Run ``f`` through ``aot_function`` and return the captured graphs.

    Args:
        f: Callable to trace; it is invoked as ``f(*inps)``.
        inps: Positional arguments for ``f``.
        partitioner: Passed to ``aot_function`` as ``partition_fn``.
        dynamic: Forwarded unchanged to ``aot_function``.

    Returns:
        A ``(forward_graph, backward_graph)`` tuple holding whatever the
        forward and backward compiler hooks were given. An entry is ``None``
        if the corresponding hook was never invoked.
    """
    captured_fw = [None]
    captured_bw = [None]

    compiled = aot_function(
        f,
        fw_compiler=partial(extract_graph, graph_cell=captured_fw),
        bw_compiler=partial(extract_graph, graph_cell=captured_bw),
        partition_fn=partitioner,
        decompositions=default_decompositions,
        dynamic=dynamic,
    )

    # Execute one forward and one backward pass so that both compiler hooks
    # get a chance to fill their cells before the cells are read below.
    loss = compiled(*inps).sum()
    loss.backward()

    return (captured_fw[0], captured_bw[0])
50+
51+
52+
class ActivationOffloadingTests(TestCase):
53+
"""Tests activation offloading functionality"""
54+
55+
def setUp(self):
    """Build the shared fixture and guard global config against leaks.

    Creates six random ``dim x dim`` inputs on ``GPU_TYPE`` that require
    grad, the function under test, and a joint-graph pass that tags the node
    named ``cos_1`` for offloading.

    Tests in this class assign ``torch._functorch.config`` entries directly,
    so the current values are captured here and restored by cleanups. Without
    that, one test's offloading configuration would stay active for every
    test that runs after it.
    """
    super().setUp()

    functorch_config = torch._functorch.config
    for key in ("joint_custom_pass", "enable_activation_offloading"):
        # addCleanup callbacks run even when the test body raises, which a
        # restore placed at the end of the test method would not.
        self.addCleanup(
            setattr, functorch_config, key, getattr(functorch_config, key)
        )

    def fn(x):
        # x is indexed 0..5: six tensors consumed as three summed pairs.
        return (x[0] + x[1]).sin() + (x[2] + x[3]).sin() + (x[4] + x[5]).sin()

    def mark_one_cos_for_offloading(gm, joint_inputs):
        # Tag only the node named "cos_1"; all other nodes are left untouched.
        for node in gm.graph.nodes:
            if node.name == "cos_1":
                node.meta["should_offload"] = True
        return gm

    dim = 10
    self.x = [
        torch.randn(dim, dim, requires_grad=True, device=GPU_TYPE) for _ in range(6)
    ]
    self.fn = fn
    self.joint_custom_pass = mark_one_cos_for_offloading
73+
74+
"""
75+
The first set of tests covers adding offload nodes to the forward and backward graphs.
76+
"""
77+
78+
# Checks the forward and backward graph code returned by get_fw_bw_graph when
# enable_activation_offloading is patched to True for this test and the joint
# pass built in setUp tags the node named "cos_1" with meta["should_offload"].
# Skipped when the module-level networkx import failed (USE_NETWORKX is False).
@unittest.skipIf(not USE_NETWORKX, "networkx not available")
79+
@torch._functorch.config.patch(enable_activation_offloading=True)
80+
def test_partitioner_offload(self):
81+
# Reset dynamo before tracing.
torch._dynamo.reset()
82+
# NOTE(review): this assigns the global config entry directly and nothing in
# this method restores it afterwards; confirm that is intended.
torch._functorch.config.joint_custom_pass = self.joint_custom_pass
83+
# self.x (a list of six tensors) is passed as the single positional input.
fw_graph, bw_graph = get_fw_bw_graph(self.fn, [self.x])
84+
85+
# Expected forward code: cos_1 is fed to prims.device_put targeting
# device(type='cpu') with non_blocking = True, producing cpu_offload_cos_1,
# and cpu_offload_cos_1 (not cos_1) appears among the returned values.
self.assertExpectedInline(
86+
fw_graph.code.strip(),
87+
"""\
88+
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6):
89+
add = torch.ops.aten.add.Tensor(primals_1, primals_2); primals_1 = primals_2 = None
90+
sin = torch.ops.aten.sin.default(add)
91+
add_1 = torch.ops.aten.add.Tensor(primals_3, primals_4); primals_3 = primals_4 = None
92+
sin_1 = torch.ops.aten.sin.default(add_1)
93+
add_2 = torch.ops.aten.add.Tensor(sin, sin_1); sin = sin_1 = None
94+
add_3 = torch.ops.aten.add.Tensor(primals_5, primals_6); primals_5 = primals_6 = None
95+
sin_2 = torch.ops.aten.sin.default(add_3)
96+
add_4 = torch.ops.aten.add.Tensor(add_2, sin_2); add_2 = sin_2 = None
97+
cos = torch.ops.aten.cos.default(add_3); add_3 = None
98+
cos_1 = torch.ops.aten.cos.default(add_1); add_1 = None
99+
cpu_offload_cos_1 = torch.ops.prims.device_put.default(cos_1, device(type='cpu'), non_blocking = True); cos_1 = None
100+
cos_2 = torch.ops.aten.cos.default(add); add = None
101+
return (add_4, cos, cpu_offload_cos_1, cos_2)""",
102+
)
103+
104+
# Expected backward code: cpu_offload_cos_1 goes through prims.device_put to
# device(type='cuda', index=0), producing gpu_reload_cos_1, which then feeds mul_1.
# NOTE(review): the expected text hard-codes 'cuda' while the inputs are
# created on GPU_TYPE; presumably this only matches when GPU_TYPE is "cuda" - confirm.
self.assertExpectedInline(
105+
bw_graph.code.strip(),
106+
"""\
107+
def forward(self, cos, cpu_offload_cos_1, cos_2, tangents_1):
108+
mul = torch.ops.aten.mul.Tensor(tangents_1, cos); cos = None
109+
gpu_reload_cos_1 = torch.ops.prims.device_put.default(cpu_offload_cos_1, device(type='cuda', index=0), non_blocking = True); cpu_offload_cos_1 = None
110+
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, gpu_reload_cos_1); gpu_reload_cos_1 = None
111+
mul_2 = torch.ops.aten.mul.Tensor(tangents_1, cos_2); tangents_1 = cos_2 = None
112+
return (mul_2, mul_2, mul_1, mul_1, mul, mul)""",
113+
)
114+
115+
@unittest.skipIf(not USE_NETWORKX, "networkx not available")
def test_inductor_offload(self):
    """Check the code generated by ``torch.compile`` for the offloaded activation.

    Compiles ``self.fn`` with the default ``torch.compile`` backend while
    activation offloading and the ``cos_1``-tagging joint pass are enabled.
    It then checks that the forward code contains a pinned-CPU buffer
    allocation followed by a copy into it, and that the backward code
    contains a CUDA buffer allocation, a copy from ``cpu_offload_cos_1``,
    and the deletion of ``cpu_offload_cos_1``.

    The networkx skip guard matches the other tests in this class.
    """
    torch._dynamo.reset()

    def run_compiled():
        return torch.compile(self.fn)(self.x)

    # Scope the config overrides to this block. Assigning them directly would
    # leave them set for every test that runs afterwards. The whole
    # forward+backward run stays inside the block so both passes see them.
    with torch._functorch.config.patch(
        enable_activation_offloading=True,
        joint_custom_pass=self.joint_custom_pass,
    ):
        _, (fw_code, bw_code) = run_fw_bw_and_get_code(run_compiled)

    # Forward: a pinned CPU buffer is allocated and filled by a copy.
    (
        FileCheck()
        .check("buf3 = empty_strided_cpu_pinned(")
        .check("buf3.copy_(buf2, True)")
        .run(fw_code)
    )

    # Backward: a CUDA buffer is allocated, filled from the offloaded tensor,
    # and the offloaded tensor is then deleted.
    (
        FileCheck()
        .check("buf1 = empty_strided_cuda(")
        .check("buf1.copy_(cpu_offload_cos_1, True)")
        .check("del cpu_offload_cos_1")
        .run(bw_code)
    )
139+
140+
# Same graph-text check as test_partitioner_offload, but with both
# enable_activation_offloading and activation_offload_separate_stream patched
# to True for this test. The expected code therefore also contains
# torch.ops.streams.* calls around each device_put.
# Skipped when the module-level networkx import failed (USE_NETWORKX is False).
@unittest.skipIf(not USE_NETWORKX, "networkx not available")
141+
@torch._functorch.config.patch(
142+
enable_activation_offloading=True,
143+
activation_offload_separate_stream=True,
144+
)
145+
def test_partitioner_offload_sep_stream(self):
146+
# NOTE(review): presumably this resets state that determines the integer
# arguments of the streams ops in the expected text below - TODO confirm.
reset_user_object_tracking()
147+
torch._dynamo.reset()
148+
# NOTE(review): this assigns the global config entry directly and nothing in
# this method restores it afterwards; confirm that is intended.
torch._functorch.config.joint_custom_pass = self.joint_custom_pass
149+
fw_graph, bw_graph = get_fw_bw_graph(self.fn, [self.x])
150+
151+
# Expected forward code: after cos_1, the text has streams.record_event,
# streams.fork and streams.wait_event calls, then the device_put of cos_1 to
# device(type='cpu') producing cpu_offload_cos_1, then a streams.join call.
self.assertExpectedInline(
152+
fw_graph.code.strip(),
153+
"""\
154+
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6):
155+
add = torch.ops.aten.add.Tensor(primals_1, primals_2); primals_1 = primals_2 = None
156+
sin = torch.ops.aten.sin.default(add)
157+
add_1 = torch.ops.aten.add.Tensor(primals_3, primals_4); primals_3 = primals_4 = None
158+
sin_1 = torch.ops.aten.sin.default(add_1)
159+
add_2 = torch.ops.aten.add.Tensor(sin, sin_1); sin = sin_1 = None
160+
add_3 = torch.ops.aten.add.Tensor(primals_5, primals_6); primals_5 = primals_6 = None
161+
sin_2 = torch.ops.aten.sin.default(add_3)
162+
add_4 = torch.ops.aten.add.Tensor(add_2, sin_2); add_2 = sin_2 = None
163+
cos = torch.ops.aten.cos.default(add_3); add_3 = None
164+
cos_1 = torch.ops.aten.cos.default(add_1); add_1 = None
165+
record_event_default = torch.ops.streams.record_event.default(2, 0); record_event_default = None
166+
stream_in_cpu_offload_cos_1 = torch.ops.streams.fork.default(0, 1); stream_in_cpu_offload_cos_1 = None
167+
wait_event_default = torch.ops.streams.wait_event.default(2, 1); wait_event_default = None
168+
cpu_offload_cos_1 = torch.ops.prims.device_put.default(cos_1, device(type='cpu'), non_blocking = True); cos_1 = None
169+
stream_out_cpu_offload_cos_1 = torch.ops.streams.join.default(1, 0); stream_out_cpu_offload_cos_1 = None
170+
cos_2 = torch.ops.aten.cos.default(add); add = None
171+
return (add_4, cos, cpu_offload_cos_1, cos_2)""",
172+
)
173+
174+
# Expected backward code: streams.fork and streams.wait_stream precede the
# device_put of cpu_offload_cos_1 to device(type='cuda', index=0); then
# streams.record_event, streams.join and streams.wait_event follow before
# gpu_reload_cos_1 is consumed by mul_1.
# NOTE(review): the expected text hard-codes 'cuda' while the inputs are
# created on GPU_TYPE; presumably this only matches when GPU_TYPE is "cuda" - confirm.
self.assertExpectedInline(
175+
bw_graph.code.strip(),
176+
"""\
177+
def forward(self, cos, cpu_offload_cos_1, cos_2, tangents_1):
178+
mul = torch.ops.aten.mul.Tensor(tangents_1, cos); cos = None
179+
stream_in_gpu_reload_cos_1 = torch.ops.streams.fork.default(3, 4); stream_in_gpu_reload_cos_1 = None
180+
wait_stream_default = torch.ops.streams.wait_stream.default(4, 3); wait_stream_default = None
181+
gpu_reload_cos_1 = torch.ops.prims.device_put.default(cpu_offload_cos_1, device(type='cuda', index=0), non_blocking = True); cpu_offload_cos_1 = None
182+
record_event_default = torch.ops.streams.record_event.default(5, 4); record_event_default = None
183+
stream_out_gpu_reload_cos_1 = torch.ops.streams.join.default(4, 3); stream_out_gpu_reload_cos_1 = None
184+
wait_event_default = torch.ops.streams.wait_event.default(5, 3); wait_event_default = None
185+
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, gpu_reload_cos_1); gpu_reload_cos_1 = None
186+
mul_2 = torch.ops.aten.mul.Tensor(tangents_1, cos_2); tangents_1 = cos_2 = None
187+
return (mul_2, mul_2, mul_1, mul_1, mul, mul)""",
188+
)
189+
190+
@unittest.skipIf(not USE_NETWORKX, "networkx not available")
@torch._functorch.config.patch(
    enable_activation_offloading=True,
    activation_offload_separate_stream=True,
)
def test_partitioner_offload_sep_stream_accuracy(self):
    """Compare compiled gradients with eager gradients under offloading.

    Runs ``self.fn`` eagerly on detached clones of the fixture inputs to get
    reference gradients. It then runs it through
    ``torch.compile(backend="aot_eager")`` on a second set of clones, with the
    ``cos_1``-tagging joint pass installed, and asserts the per-input
    gradients agree within ``rtol=1e-5`` / ``atol=1e-5``.
    """
    # Run without compilation to get reference gradients
    x_ref = [x.detach().clone().requires_grad_(True) for x in self.x]
    out_ref = self.fn(x_ref)
    out_ref.sum().backward()
    grads_ref = [inp.grad for inp in x_ref]

    # Run with aot_eager compilation and offloading enabled
    reset_user_object_tracking()
    torch._dynamo.reset()
    x_compile = [x.detach().clone().requires_grad_(True) for x in self.x]

    # Install the joint pass only for the compiled forward and backward.
    # Assigning it directly to the config module would leave it set for
    # every test that runs afterwards.
    with torch._functorch.config.patch(joint_custom_pass=self.joint_custom_pass):
        compiled_fn = torch.compile(self.fn, backend="aot_eager")
        out_compiled = compiled_fn(x_compile)
        out_compiled.sum().backward()
    grads_compiled = [inp.grad for inp in x_compile]

    # Verify gradients match between reference and compiled versions
    for grad_ref, grad_compiled in zip(grads_ref, grads_compiled):
        torch.testing.assert_close(
            grad_compiled,
            grad_ref,
            rtol=1e-5,
            atol=1e-5,
        )
220+
221+
222+
# Only hand control to the test runner when executed as a script and HAS_GPU
# is truthy; otherwise running the file is a no-op.
if __name__ == "__main__" and HAS_GPU:
    run_tests()

0 commit comments

Comments
 (0)