|
338 | 338 | # `AOTI documentation <https://pytorch.org/docs/main/torch.compiler_aot_inductor.html>`_: |
339 | 339 | # |
340 | 340 |
|
341 | | -from tempfile import TemporaryDirectory |
342 | | - |
343 | | -from torch._inductor import aoti_compile_and_package, aoti_load_package |
344 | | - |
345 | | -with TemporaryDirectory() as tmpdir: |
346 | | - path = str(Path(tmpdir) / "model.pt2") |
347 | | - with torch.no_grad(): |
348 | | - pkg_path = aoti_compile_and_package( |
349 | | - exported_policy, |
350 | | - args=(), |
351 | | - kwargs={"pixels": pixels}, |
352 | | - # Specify the generated shared library path |
353 | | - package_path=path, |
354 | | - ) |
355 | | - print("pkg_path", pkg_path) |
356 | | - |
357 | | - compiled_module = aoti_load_package(pkg_path) |
358 | | - |
359 | | -print(compiled_module(pixels=pixels)) |
360 | | - |
361 | | -##################################### |
362 | | -# An extra feature of AOTInductor is its capacity of dealing with dynamic shapes. This can be useful if you don't know |
363 | | -# the shape of your input data ahead of time. For instance, we may want to run our policy for one, two or more |
364 | | -# observations at a time. For this, let us re-export our policy, marking a new unsqueezed batch dimension as dynamic: |
365 | | - |
366 | | -batch_dim = torch.export.Dim("batch", min=1, max=32) |
367 | | -pixels_unsqueeze = pixels.unsqueeze(0) |
368 | | -exported_dynamic_policy = torch.export.export( |
369 | | - policy_transform, |
370 | | - args=(), |
371 | | - kwargs={"pixels": pixels_unsqueeze}, |
372 | | - strict=False, |
373 | | - dynamic_shapes={"pixels": {0: batch_dim}}, |
374 | | -) |
375 | | -# Then recompile and export |
376 | | -pkg_path = aoti_compile_and_package( |
377 | | - exported_dynamic_policy, |
378 | | - args=(), |
379 | | - kwargs={"pixels": pixels_unsqueeze}, |
380 | | - package_path=path, |
381 | | -) |
| 341 | +# from tempfile import TemporaryDirectory |
| 342 | +# |
| 343 | +# from torch._inductor import aoti_compile_and_package, aoti_load_package |
| 344 | +# |
| 345 | +# with TemporaryDirectory() as tmpdir: |
| 346 | +# path = str(Path(tmpdir) / "model.pt2") |
| 347 | +# with torch.no_grad(): |
| 348 | +# pkg_path = aoti_compile_and_package( |
| 349 | +# exported_policy, |
| 350 | +# args=(), |
| 351 | +# kwargs={"pixels": pixels}, |
| 352 | +# # Specify the generated shared library path |
| 353 | +# package_path=path, |
| 354 | +# ) |
| 355 | +# print("pkg_path", pkg_path) |
| 356 | +# |
| 357 | +# compiled_module = aoti_load_package(pkg_path) |
| 358 | +# |
| 359 | +# print(compiled_module(pixels=pixels)) |
382 | 360 |
|
383 | 361 | ##################################### |
384 | | -# More information about this can be found in the |
385 | | -# `AOTInductor tutorial <https://pytorch.org/tutorials/recipes/torch_export_aoti_python.html>`_. |
386 | 362 | # |
387 | 363 | # Exporting TorchRL models with ONNX |
388 | 364 | # ---------------------------------- |
|
0 commit comments