|
9 | 9 | import typing |
10 | 10 | import unittest |
11 | 11 | from contextlib import contextmanager |
| 12 | +from copy import deepcopy |
12 | 13 | from typing import List, Optional, Tuple |
13 | 14 |
|
14 | 15 | import executorch.exir as exir |
|
31 | 32 | from executorch.exir.error import InternalError |
32 | 33 | from executorch.exir.passes import MemoryPlanningPass |
33 | 34 | from executorch.exir.passes.constant_prop_pass import constant_prop_pass |
| 35 | +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass |
34 | 36 | from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass |
35 | 37 | from executorch.exir.print_program import pretty_print, print_program # noqa |
36 | 38 | from executorch.exir.schema import ( |
|
56 | 58 | from executorch.extension.pybindings.portable_lib import ( |
57 | 59 | _load_for_executorch_from_buffer, |
58 | 60 | ) |
| 61 | +from executorch.runtime import Runtime |
59 | 62 |
|
60 | 63 | from functorch.experimental import control_flow |
61 | 64 | from torch import nn |
@@ -243,6 +246,56 @@ def forward(self, x): |
243 | 246 | ) |
244 | 247 | self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null) |
245 | 248 |
|
| 249 | + def test_initialized_mutable_buffer(self): |
| 250 | + """Test that mutable buffers can hold meaningful initialized state.""" |
| 251 | + |
| 252 | + class TestModule(torch.nn.Module): |
| 253 | + def __init__(self): |
| 254 | + super().__init__() |
| 255 | + # Mutable buffer with non-empty initial state. |
| 256 | + self.register_buffer("cache_pos", torch.arange(0, 10)) |
| 257 | + |
| 258 | + def forward(self, x): |
| 259 | + self.cache_pos.add_(1) |
| 260 | + return self.cache_pos |
| 261 | + |
| 262 | + m = TestModule() |
| 263 | + example_inputs = (torch.ones(10),) |
| 264 | + ep = torch.export.export(m, example_inputs) |
| 265 | + edge = to_edge( |
| 266 | + ep, |
| 267 | + compile_config=EdgeCompileConfig( |
| 268 | + _check_ir_validity=False, |
| 269 | + ), |
| 270 | + ) |
| 271 | + |
| 272 | + # Save a copy of the edge program since to_executorch is |
| 273 | + # stateful to some degree. |
| 274 | + edge_copy = deepcopy(edge) |
| 275 | + et_config = ExecutorchBackendConfig( |
| 276 | + passes=[InitializedMutableBufferPass(["cache_pos"])], |
| 277 | + ) |
| 278 | + et_program_init_pass = edge.to_executorch(config=et_config) |
| 279 | + et_program_regular = edge_copy.to_executorch() |
| 280 | + |
| 281 | + runtime = Runtime.get() |
| 282 | + program_init_pass = runtime.load_program(et_program_init_pass.buffer) |
| 283 | + method_init_pass = program_init_pass.load_method("forward") |
| 284 | + |
| 285 | + program_regular = runtime.load_program(et_program_regular.buffer) |
| 286 | + method_regular = program_regular.load_method("forward") |
| 287 | + |
| 288 | + # Test that the mutable buffer is initialized. |
| 289 | + torch.allclose( |
| 290 | + method_init_pass.execute((example_inputs))[0], torch.arange(1, 11) |
| 291 | + ) |
| 292 | + # Test that the mutable buffer is uninitialized and starts with default zeros, |
| 293 | + # we test equality with torch.ones because of the mutation += 1 in the model forward. |
| 294 | + torch.allclose( |
| 295 | + method_regular.execute((example_inputs))[0], |
| 296 | + torch.ones(10, dtype=torch.int64), |
| 297 | + ) |
| 298 | + |
246 | 299 | def test_int_list_input(self): |
247 | 300 | class M(torch.nn.Module): |
248 | 301 | def forward(self, x, y, z): |
|
0 commit comments