|
| 1 | +# Owner(s): ["oncall: pt2"] |
| 2 | +# flake8: noqa: B950 |
| 3 | + |
| 4 | +import unittest |
| 5 | +import warnings |
| 6 | +from functools import partial |
| 7 | + |
| 8 | +import torch |
| 9 | +import torch._functorch.config |
| 10 | +from functorch.compile import ( |
| 11 | + aot_function, |
| 12 | + default_decompositions, |
| 13 | + min_cut_rematerialization_partition, |
| 14 | +) |
| 15 | +from torch._dynamo.graph_bytecode_inputs import reset_user_object_tracking |
| 16 | +from torch._inductor.utils import run_fw_bw_and_get_code |
| 17 | +from torch.testing import FileCheck |
| 18 | +from torch.testing._internal.common_utils import run_tests, TestCase |
| 19 | +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU |
| 20 | + |
| 21 | + |
| 22 | +USE_NETWORKX = False |
| 23 | +try: |
| 24 | + import networkx # noqa: F401 |
| 25 | + |
| 26 | + USE_NETWORKX = True |
| 27 | +except ImportError: |
| 28 | + warnings.warn("Some tests use networkx but it was not installed", UserWarning) |
| 29 | + |
| 30 | + |
| 31 | +def extract_graph(fx_g, _, graph_cell): |
| 32 | + graph_cell[0] = fx_g |
| 33 | + return fx_g |
| 34 | + |
| 35 | + |
| 36 | +def get_fw_bw_graph( |
| 37 | + f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False |
| 38 | +): |
| 39 | + fw_graph_cell = [None] |
| 40 | + bw_graph_cell = [None] |
| 41 | + aot_function( |
| 42 | + f, |
| 43 | + fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| 44 | + bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| 45 | + partition_fn=partitioner, |
| 46 | + decompositions=default_decompositions, |
| 47 | + dynamic=dynamic, |
| 48 | + )(*inps).sum().backward() |
| 49 | + return (fw_graph_cell[0], bw_graph_cell[0]) |
| 50 | + |
| 51 | + |
| 52 | +class ActivationOffloadingTests(TestCase): |
| 53 | + """Tests activation offloading functionality""" |
| 54 | + |
| 55 | + def setUp(self): |
| 56 | + super().setUp() |
| 57 | + |
| 58 | + def fn(x): |
| 59 | + return (x[0] + x[1]).sin() + (x[2] + x[3]).sin() + (x[4] + x[5]).sin() |
| 60 | + |
| 61 | + def mark_one_cos_for_offloading(gm, joint_inputs): |
| 62 | + for node in gm.graph.nodes: |
| 63 | + if node.name == "cos_1": |
| 64 | + node.meta["should_offload"] = True |
| 65 | + return gm |
| 66 | + |
| 67 | + dim = 10 |
| 68 | + self.x = [ |
| 69 | + torch.randn(dim, dim, requires_grad=True, device=GPU_TYPE) for _ in range(6) |
| 70 | + ] |
| 71 | + self.fn = fn |
| 72 | + self.joint_custom_pass = mark_one_cos_for_offloading |
| 73 | + |
| 74 | + """ |
| 75 | + The first set of tests are for the case of adding offload nodes to the fwd and bwd graphs. |
| 76 | + """ |
| 77 | + |
| 78 | + @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| 79 | + @torch._functorch.config.patch(enable_activation_offloading=True) |
| 80 | + def test_partitioner_offload(self): |
| 81 | + torch._dynamo.reset() |
| 82 | + torch._functorch.config.joint_custom_pass = self.joint_custom_pass |
| 83 | + fw_graph, bw_graph = get_fw_bw_graph(self.fn, [self.x]) |
| 84 | + |
| 85 | + self.assertExpectedInline( |
| 86 | + fw_graph.code.strip(), |
| 87 | + """\ |
| 88 | +def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6): |
| 89 | + add = torch.ops.aten.add.Tensor(primals_1, primals_2); primals_1 = primals_2 = None |
| 90 | + sin = torch.ops.aten.sin.default(add) |
| 91 | + add_1 = torch.ops.aten.add.Tensor(primals_3, primals_4); primals_3 = primals_4 = None |
| 92 | + sin_1 = torch.ops.aten.sin.default(add_1) |
| 93 | + add_2 = torch.ops.aten.add.Tensor(sin, sin_1); sin = sin_1 = None |
| 94 | + add_3 = torch.ops.aten.add.Tensor(primals_5, primals_6); primals_5 = primals_6 = None |
| 95 | + sin_2 = torch.ops.aten.sin.default(add_3) |
| 96 | + add_4 = torch.ops.aten.add.Tensor(add_2, sin_2); add_2 = sin_2 = None |
| 97 | + cos = torch.ops.aten.cos.default(add_3); add_3 = None |
| 98 | + cos_1 = torch.ops.aten.cos.default(add_1); add_1 = None |
| 99 | + cpu_offload_cos_1 = torch.ops.prims.device_put.default(cos_1, device(type='cpu'), non_blocking = True); cos_1 = None |
| 100 | + cos_2 = torch.ops.aten.cos.default(add); add = None |
| 101 | + return (add_4, cos, cpu_offload_cos_1, cos_2)""", |
| 102 | + ) |
| 103 | + |
| 104 | + self.assertExpectedInline( |
| 105 | + bw_graph.code.strip(), |
| 106 | + """\ |
| 107 | +def forward(self, cos, cpu_offload_cos_1, cos_2, tangents_1): |
| 108 | + mul = torch.ops.aten.mul.Tensor(tangents_1, cos); cos = None |
| 109 | + gpu_reload_cos_1 = torch.ops.prims.device_put.default(cpu_offload_cos_1, device(type='cuda', index=0), non_blocking = True); cpu_offload_cos_1 = None |
| 110 | + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, gpu_reload_cos_1); gpu_reload_cos_1 = None |
| 111 | + mul_2 = torch.ops.aten.mul.Tensor(tangents_1, cos_2); tangents_1 = cos_2 = None |
| 112 | + return (mul_2, mul_2, mul_1, mul_1, mul, mul)""", |
| 113 | + ) |
| 114 | + |
| 115 | + def test_inductor_offload(self): |
| 116 | + torch._dynamo.reset() |
| 117 | + |
| 118 | + def run_compiled(): |
| 119 | + torch._functorch.config.enable_activation_offloading = True |
| 120 | + torch._functorch.config.joint_custom_pass = self.joint_custom_pass |
| 121 | + return torch.compile(self.fn)(self.x) |
| 122 | + |
| 123 | + _, (fw_code, bw_code) = run_fw_bw_and_get_code(run_compiled) |
| 124 | + |
| 125 | + ( |
| 126 | + FileCheck() |
| 127 | + .check("buf3 = empty_strided_cpu_pinned(") |
| 128 | + .check("buf3.copy_(buf2, True)") |
| 129 | + .run(fw_code) |
| 130 | + ) |
| 131 | + |
| 132 | + ( |
| 133 | + FileCheck() |
| 134 | + .check("buf1 = empty_strided_cuda(") |
| 135 | + .check("buf1.copy_(cpu_offload_cos_1, True)") |
| 136 | + .check("del cpu_offload_cos_1") |
| 137 | + .run(bw_code) |
| 138 | + ) |
| 139 | + |
| 140 | + @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| 141 | + @torch._functorch.config.patch( |
| 142 | + enable_activation_offloading=True, |
| 143 | + activation_offload_separate_stream=True, |
| 144 | + ) |
| 145 | + def test_partitioner_offload_sep_stream(self): |
| 146 | + reset_user_object_tracking() |
| 147 | + torch._dynamo.reset() |
| 148 | + torch._functorch.config.joint_custom_pass = self.joint_custom_pass |
| 149 | + fw_graph, bw_graph = get_fw_bw_graph(self.fn, [self.x]) |
| 150 | + |
| 151 | + self.assertExpectedInline( |
| 152 | + fw_graph.code.strip(), |
| 153 | + """\ |
| 154 | +def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6): |
| 155 | + add = torch.ops.aten.add.Tensor(primals_1, primals_2); primals_1 = primals_2 = None |
| 156 | + sin = torch.ops.aten.sin.default(add) |
| 157 | + add_1 = torch.ops.aten.add.Tensor(primals_3, primals_4); primals_3 = primals_4 = None |
| 158 | + sin_1 = torch.ops.aten.sin.default(add_1) |
| 159 | + add_2 = torch.ops.aten.add.Tensor(sin, sin_1); sin = sin_1 = None |
| 160 | + add_3 = torch.ops.aten.add.Tensor(primals_5, primals_6); primals_5 = primals_6 = None |
| 161 | + sin_2 = torch.ops.aten.sin.default(add_3) |
| 162 | + add_4 = torch.ops.aten.add.Tensor(add_2, sin_2); add_2 = sin_2 = None |
| 163 | + cos = torch.ops.aten.cos.default(add_3); add_3 = None |
| 164 | + cos_1 = torch.ops.aten.cos.default(add_1); add_1 = None |
| 165 | + record_event_default = torch.ops.streams.record_event.default(2, 0); record_event_default = None |
| 166 | + stream_in_cpu_offload_cos_1 = torch.ops.streams.fork.default(0, 1); stream_in_cpu_offload_cos_1 = None |
| 167 | + wait_event_default = torch.ops.streams.wait_event.default(2, 1); wait_event_default = None |
| 168 | + cpu_offload_cos_1 = torch.ops.prims.device_put.default(cos_1, device(type='cpu'), non_blocking = True); cos_1 = None |
| 169 | + stream_out_cpu_offload_cos_1 = torch.ops.streams.join.default(1, 0); stream_out_cpu_offload_cos_1 = None |
| 170 | + cos_2 = torch.ops.aten.cos.default(add); add = None |
| 171 | + return (add_4, cos, cpu_offload_cos_1, cos_2)""", |
| 172 | + ) |
| 173 | + |
| 174 | + self.assertExpectedInline( |
| 175 | + bw_graph.code.strip(), |
| 176 | + """\ |
| 177 | +def forward(self, cos, cpu_offload_cos_1, cos_2, tangents_1): |
| 178 | + mul = torch.ops.aten.mul.Tensor(tangents_1, cos); cos = None |
| 179 | + stream_in_gpu_reload_cos_1 = torch.ops.streams.fork.default(3, 4); stream_in_gpu_reload_cos_1 = None |
| 180 | + wait_stream_default = torch.ops.streams.wait_stream.default(4, 3); wait_stream_default = None |
| 181 | + gpu_reload_cos_1 = torch.ops.prims.device_put.default(cpu_offload_cos_1, device(type='cuda', index=0), non_blocking = True); cpu_offload_cos_1 = None |
| 182 | + record_event_default = torch.ops.streams.record_event.default(5, 4); record_event_default = None |
| 183 | + stream_out_gpu_reload_cos_1 = torch.ops.streams.join.default(4, 3); stream_out_gpu_reload_cos_1 = None |
| 184 | + wait_event_default = torch.ops.streams.wait_event.default(5, 3); wait_event_default = None |
| 185 | + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, gpu_reload_cos_1); gpu_reload_cos_1 = None |
| 186 | + mul_2 = torch.ops.aten.mul.Tensor(tangents_1, cos_2); tangents_1 = cos_2 = None |
| 187 | + return (mul_2, mul_2, mul_1, mul_1, mul, mul)""", |
| 188 | + ) |
| 189 | + |
| 190 | + @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| 191 | + @torch._functorch.config.patch( |
| 192 | + enable_activation_offloading=True, |
| 193 | + activation_offload_separate_stream=True, |
| 194 | + ) |
| 195 | + def test_partitioner_offload_sep_stream_accuracy(self): |
| 196 | + # Run without compilation to get reference gradients |
| 197 | + x_ref = [x.detach().clone().requires_grad_(True) for x in self.x] |
| 198 | + out_ref = self.fn(x_ref) |
| 199 | + out_ref.sum().backward() |
| 200 | + grads_ref = [inp.grad for inp in x_ref] |
| 201 | + |
| 202 | + # Run with aot_eager compilation and offloading enabled |
| 203 | + reset_user_object_tracking() |
| 204 | + torch._dynamo.reset() |
| 205 | + torch._functorch.config.joint_custom_pass = self.joint_custom_pass |
| 206 | + x_compile = [x.detach().clone().requires_grad_(True) for x in self.x] |
| 207 | + compiled_fn = torch.compile(self.fn, backend="aot_eager") |
| 208 | + out_compiled = compiled_fn(x_compile) |
| 209 | + out_compiled.sum().backward() |
| 210 | + grads_compiled = [inp.grad for inp in x_compile] |
| 211 | + |
| 212 | + # Verify gradients match between reference and compiled versions |
| 213 | + for grad_ref, grad_compiled in zip(grads_ref, grads_compiled): |
| 214 | + torch.testing.assert_close( |
| 215 | + grad_compiled, |
| 216 | + grad_ref, |
| 217 | + rtol=1e-5, |
| 218 | + atol=1e-5, |
| 219 | + ) |
| 220 | + |
| 221 | + |
| 222 | +if __name__ == "__main__": |
| 223 | + if HAS_GPU: |
| 224 | + run_tests() |
0 commit comments