
PyTorch 2.1: automatic dynamic shape compilation, distributed checkpointing


PyTorch 2.1 Release Notes

  • Highlights
  • Backwards Incompatible Changes
  • Deprecations
  • New Features
  • Improvements
  • Bug fixes
  • Performance
  • Documentation
  • Developers
  • Security

Highlights

We are excited to announce the release of PyTorch® 2.1! PyTorch 2.1 offers automatic dynamic shape support in torch.compile, torch.distributed.checkpoint for saving/loading distributed training jobs on multiple ranks in parallel, and torch.compile support for the NumPy API.

In addition, this release offers numerous performance improvements (e.g. CPU inductor improvements, AVX512 support, scaled-dot-product-attention support) as well as a prototype release of torch.export, a sound full-graph capture mechanism, and torch.export-based quantization.

Along with 2.1, we are also releasing a series of updates to the PyTorch domain libraries. More details can be found in the library updates blog.

This release is composed of 6,682 commits and 784 contributors since 2.0. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.1. More information about how to get started with the PyTorch 2-series can be found at our Getting Started page.

Summary:

  • torch.compile now includes automatic support for detecting and minimizing recompilations due to tensor shape changes using automatic dynamic shapes.
  • torch.distributed.checkpoint enables saving and loading models from multiple ranks in parallel, as well as resharding due to changes in cluster topology.
  • torch.compile can now compile NumPy operations by translating them into PyTorch-equivalent operations; see the sketch after this list.
  • torch.compile now includes improved support for Python 3.11.
  • New CPU performance features include inductor improvements (e.g. bfloat16 support and dynamic shapes), AVX512 kernel support, and scaled-dot-product-attention kernels.
  • torch.export, a sound full-graph capture mechanism, is introduced as a prototype feature, as well as torch.export-based quantization.
  • torch.sparse now includes prototype support for semi-structured (2:4) sparsity on NVIDIA® GPUs.
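
For illustration, here is a minimal sketch (not taken from the release notes) of the NumPy and automatic dynamic shape support; the function names are placeholders and a working compiler backend is assumed:

import numpy as np
import torch

@torch.compile
def numpy_fn(x: np.ndarray) -> np.ndarray:
    # NumPy calls are translated into PyTorch-equivalent ops and compiled
    return np.sum(x * x, axis=0)

out = numpy_fn(np.random.randn(8, 8))

@torch.compile
def torch_fn(x):
    return (x.sin() + 1).sum()

# The first shape change triggers a recompile with dynamic shapes;
# subsequent shape changes should not force further recompilations.
for n in (8, 16, 32):
    torch_fn(torch.randn(n))
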
| Stable | Beta | Prototype | Performance Improvements |
| --- | --- | --- | --- |
|  | Automatic Dynamic Shapes | torch.export() | AVX512 kernel support |
|  | torch.distributed.checkpoint | torch.export-based Quantization | CPU optimizations for scaled-dot-product-attention (SDPA) |
|  | torch.compile + NumPy | semi-structured (2:4) sparsity | CPU optimizations for bfloat16 |
|  | torch.compile + Python 3.11 | cpp_wrapper for torchinductor |  |
|  | torch.compile + autograd.Function |  |  |
|  | third-party device integration: PrivateUse1 |  |  |

*To see a full list of public 2.1, 2.0, and 1.13 feature submissions click here.

For more details about these highlighted features, you can look at the release blogpost.
Below are the full release notes for this release.

Backwards Incompatible Changes

Building PyTorch from source now requires C++ 17 (#100557)

The PyTorch codebase has migrated from the C++14 to the C++17 standard, so a C++17 compatible compiler is now required to compile PyTorch, to integrate with libtorch, or to implement a C++ PyTorch extension.
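
As an illustration only, here is a minimal sketch of a setup.py for a C++ extension built against PyTorch 2.1; the extension and file names are placeholders, and recent versions of torch.utils.cpp_extension already pass an appropriate -std flag by default:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="my_ext",                              # placeholder package name
    ext_modules=[
        CppExtension(
            "my_ext",
            ["my_ext.cpp"],                     # placeholder source file
            extra_compile_args=["-std=c++17"],  # C++14 is no longer sufficient
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)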

Disable torch.autograd.{backward, grad} for complex scalar output (#92753)

Gradients are not defined for functions that don't return real outputs; we now raise an error if you try to call backward on complex outputs. Previously, the complex component of the output was implicitly ignored. If you wish to preserve this behavior, you must now explicitly call .real on your complex outputs before calling .grad() or .backward().

Example

def fn(x):
    return (x * 0.5j).sum()

x = torch.ones(1, dtype=torch.double, requires_grad=True)
o = fn(x)

2.0.1

o.backward()

2.1

o.real.backward()

Update non-reentrant checkpoint to allow nesting and support autograd.grad (#90105)

As a part of a larger refactor of torch.utils.checkpoint, we changed the interaction between activation checkpointing and retain_graph=True. Previously, in 2.0.1, recomputed activations were kept alive if retain_graph=True; in PyTorch 2.1, the non-reentrant implementation clears recomputed tensors on backward immediately upon unpack, even if retain_graph=True. This has the following additional implications: (1) Accessing ctx.saved_tensors twice in the same backward will now raise an error. (2) Accessing _saved_tensors multiple times will silently recompute the forward multiple times.

2.1

class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = x.exp()
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors
        # Calling ctx.saved_tensors again will raise in 2.1
        out, = ctx.saved_tensors
        return out

a = torch.tensor(1., requires_grad=True)

def fn(x):
    return Func.apply(x)


out = torch.utils.checkpoint.checkpoint(fn, a, use_reentrant=False)

def fn2(x):
    return x.exp()

out = torch.utils.checkpoint.checkpoint(fn2, a, use_reentrant=False)

out.grad_fn._saved_result
# Calling _saved_result will trigger another unpack, and lead to forward being
# recomputed again
out.grad_fn._saved_result

Only sync buffers when broadcast_buffers is True (#100729)

  • In PyTorch 2.0.1 and previous releases, when users use DistributedDataParallel (DDP), all buffers were synced automatically even if users set flag broadcast_buffers to be False:
from torch.nn.parallel import DistributedDataParallel as DDP
module = torch.nn.Linear(4, 8)
module = DDP(module) # Buffer is synchronized across all devices.
module = DDP(module, broadcast_buffers=False) # Buffer is synchronized across all devices.
...
  • Starting with PyTorch 2.1, if users specify the flag broadcast_buffers to be False, we don’t sync the buffer across devices:
from torch.nn.parallel import DistributedDataParallel as DDP
module = torch.nn.Linear(4, 8)
module = DDP(module) # Buffer is synchronized across all devices.
module = DDP(module, broadcast_buffers=False) # Buffer is NOT synchronized across all devices
...

Remove store barrier after PG init (#99937)

  • In PyTorch 2.0.1 and previous releases, after we initialize PG, we always call store based barrier:
from torch.distributed.distributed_c10d import init_process_group
init_process_group(...) # Will call _store_based_barrier in the end.
...
  • Starting with PyTorch 2.1, after we initialize PG, the environment variable TORCH_DIST_INIT_BARRIER controls whether we call store based barrier or not:
from torch.distributed.distributed_c10d import init_process_group
import os
os.environ["TORCH_DIST_INIT_BARRIER"] = "1" # This is the default behavior
init_process_group(...) # Will call _store_based_barrier in the end.
os.environ["TORCH_DIST_INIT_BARRIER"] = "0"
init_process_group(...) # Will not call _store_based_barrier in the end.
...

Disallow non-bool masks in torch.masked_{select, scatter, fill} (#96112, #97999, #96594)

Finish the deprecation cycle for non-bool masks. Functions now require the dtype of the mask to be torch.bool.

>>> # 2.0.1
>>> inp = torch.rand(3)
>>> mask = torch.tensor([0, 1, 0], dtype=torch.uint8)
>>> torch.masked_select(inp, mask)
UserWarning: masked_select received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/TensorAdvancedIndexing.cpp:1855.)
  torch.masked_select(inp, mask)

>>> torch.masked_select(inp, mask.to(dtype=torch.bool))
# Works fine

>>> correct_mask = torch.tensor([0, 1, 0], dtype=torch.bool)
>>> torch.masked_select(inp, correct_mask)
# Works fine

>>> # 2.1
>>> inp = torch.rand(3)
>>> mask = torch.tensor([0, 1, 0], dtype=torch.uint8)
>>> torch.masked_select(inp, mask)
RuntimeError: masked_select: expected BoolTensor for mask

>>> correct_mask = torch.tensor([0, 1, 0], dtype=torch.bool)
>>> torch.masked_select(inp, correct_mask)
# Works fine

>>> torch.masked_select(inp, mask.to(dtype=torch.bool))
# Works fine

Fix the result of torch.unique to make it consistent with NumPy when dim is specified (#101693)

The dim argument was clarified and its behavior aligned with NumPy's, to indicate which sub-tensor to treat as a single element when determining uniqueness. See the documentation for more details: https://pytorch.org/docs/stable/generated/torch.unique.html
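
For illustration (not part of the original notes), with dim=0 each row is treated as a single unit when determining uniqueness:

>>> x = torch.tensor([[1, 2], [1, 2], [3, 4]])
>>> torch.unique(x, dim=0)
tensor([[1, 2],
        [3, 4]])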

Make the Index Rounding Mode Consistent Between the 2D and 3D GridSample Nearest Neighbor Interpolations (#97000)

Prior to this change, for torch.nn.functional.grid_sample(mode='nearest') the forward 2D kernel used std::nearbyint whereas the forward 3D kernel used std::round in order to determine the nearest pixel locations after un-normalization of the grid. Additionally, the backward kernels for both used std::round. This PR fixes the inconsistencies to use std::nearbyint, which rounds values whose fractional part is exactly .5 to the nearest even integer, consistent with the behavior of torch.round. Unnormalized indices with a fractional part of exactly .5 will now be rounded to the nearest even integer instead of being rounded away from 0.
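
The round-half-to-even convention referred to above is the same one torch.round uses; a quick illustration (not from the original notes):

>>> torch.round(torch.tensor([0.5, 1.5, 2.5, 3.5]))
tensor([0., 2., 2., 4.])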

Turned input shapes (aka record_shapes) off by default for on-demand tracing (#97917)

Profiler traces collected by on-demand tracing via IPC Fabric will have record_shapes off by default.

  • In v2.0.1:
    By default, profiler trace files’ cpu_op activities will contain metadata fields: Input Dims, and Input type.

  • In v2.1.0:
    By default, profiler trace files’ cpu_op activities will no longer contain metadata fields for input shapes. If turned on via Kineto config, it will show metadata fields: Input Dims, Input type and Concrete Inputs.
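
For the in-process profiler API (as opposed to on-demand tracing), shape recording can still be requested explicitly; a minimal sketch:

import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(16, 16)
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    torch.mm(x, x)

# group results by the recorded input shapes
print(prof.key_averages(group_by_input_shape=True).table(row_limit=5))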

When called with a 0-dim tensor input, torch.aminmax would previously inconsistently return a 1D tensor output on CPU, but a 0D tensor output on CUDA. This has been fixed, so we consistently return a 0D tensor in both cases. (#96171).

In v2.0.1:

>>> torch.aminmax(torch.tensor(1, device='cpu'), dim=0, keepdim=True)
__main__:1: UserWarning: An output with one or more elements was resized since it had shape [], which does not match the required output shape [1]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (Triggered internally at ../aten/src/ATen/native/Resize.cpp:24.)
torch.return_types.aminmax(
min=tensor([1]),
max=tensor([1]))
>>> torch.aminmax(torch.tensor(1, device='cpu'), dim=0, keepdim=False)
torch.return_types.aminmax(
min=tensor(1),
max=tensor(1))

In v2.1.0:

>>> torch.aminmax(torch.tensor(1, device='cpu'), dim=0, keepdim=True)
torch.return_types.aminmax(
min=tensor(1),
max=tensor(1))
>>> torch.aminmax(torch.tensor(1, device='cpu'), dim=0, keepdim=False)
torch.return_types.aminmax(
min=tensor(1),
max=tensor(1))

Change to the default behavior for custom operators registered to the dispatcher, that do not have anything registered to an Autograd dispatch key

If you have a custom operator that has a CPU/CUDA kernel registered to the CPU/CUDA dispatch key, but has no implementation at the Autograd key, then:

Old behavior: When calling this operator with tensor inputs that require gradients, the tensor outputs would silently not require gradients.

New behavior: When calling this operator with tensor inputs that do require gradients, the tensor outputs will require gradients (as long as the outputs are floating-point or complex), and will error if you try to backpropagate through them.

There is more information on how to recover the old behavior in the PR: (#104481, #105078)
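
A minimal sketch of the situation described above, using the Python torch.library API (the namespace, operator name, and kernel are placeholders):

import torch

lib = torch.library.Library("my_ns", "DEF")
lib.define("my_mul(Tensor x) -> Tensor")

def my_mul_cpu(x):
    return x * 2

# CPU kernel only; nothing is registered to an Autograd key
lib.impl("my_mul", my_mul_cpu, "CPU")

x = torch.randn(3, requires_grad=True)
out = torch.ops.my_ns.my_mul(x)
# 2.0.1: out.requires_grad is silently False
# 2.1:   out.requires_grad is True, and backpropagating through it raises an error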

torch.autograd.Function: raise an error if an input is returned as-is and saved for forward or backward in setup_context (#98051)

If you have implemented a custom autograd Function using setup_context and your forward function returns an input as-is as an output, then saving that tensor for forward or backward now raises an error. You should return an alias of the input instead.

2.0.1

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x ** 3, x

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        cube, x = outputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output, grad_x):
        # NB: grad_x intentionally not used in computation
        x, = ctx.saved_tensors
        result = grad_output * 3 * x ** 2
        return result

2.1

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x ** 3, x.view_as(x)

    ...

Deprecations

Deprecate not specifying the use_reentrant flag explicitly when using torch.utils.checkpoint (#100551)

In PyTorch 2.1, if the use_reentrant flag is not explicitly passed, a warning is raised. To retain current behavior, pass use_reentrant=True. The default value will be updated to use_reentrant=False in the future. We recommend using use_reentrant=False.

2.1

torch.utils.checkpoint.checkpoint(fn, a) # Warns in 2.1

Deprecate torch.has_* attributes (#103279)

Use the version in the particular backend module at torch.backends.* to access these flags.
Also note that we now properly differentiate between is_built() (compile-time availability) and is_available() (runtime availability) in these modules.
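
For example, a sketch of the replacement pattern:

>>> torch.backends.cuda.is_built()   # was PyTorch compiled with CUDA support?
>>> torch.cuda.is_available()        # is a CUDA device usable at runtime?
>>> torch.backends.mps.is_built()
>>> torch.backends.mps.is_available()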

Deprecate check_sparse_nnz argument for torch.autograd.gradcheck (#97187)

2.0.1

torch.autograd.gradcheck(fn, inputs, check_sparse_nnz=True)

2.1

torch.autograd.gradcheck(fn, inputs, masked=True)

NVFuser integration with TorchScript is deprecated (#105185)

NVFuser replaced Neural Network Compiler (NNC) as the default GPU fuser for TorchScript in PyTorch 1.13. In PyTorch 2.1, TorchScript switched its default fuser back to NNC. Additionally, NVFuser for TorchScript is now deprecated. Currently, users can still manually choose to use NVFuser instead of NNC, see fuser options for details on how to do this.
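
For reference, a minimal sketch of manually selecting a TorchScript fuser (assumes a CUDA device; "fuser1" is NNC, "fuser2" is the deprecated NVFuser):

import torch

@torch.jit.script
def f(x):
    return torch.sin(x) + torch.cos(x)

with torch.jit.fuser("fuser2"):   # opt back into NVFuser (deprecated)
    f(torch.randn(8, device="cuda"))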

New features

Release Engineering

  • Adding AArch64 wheel builds (#104109)
  • CUDA 12.1 support for PyTorch binaries (#107295)
  • Compile PyTorch on M1 natively (instead of cross compiling from x86) (#95719)
  • Enable UCC Distributed Communication Backend in CI (#100395)

Python Frontend

  • Enable torch.device to be used as a context manager to change the default device (#106514); see the sketch after this list
  • Add torch.Tensor.dim_order field to access current dimension permutation (#106835)
  • Add torch.backends.cpu.get_cpu_capability() to expose cpu properties (#100164)
  • Add support for PrivateUse1 device (out of tree device) for untyped storage (#100868)
  • Add torch._foreach_pow (#92303)
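
A minimal sketch of the device context manager (assumes a CUDA device; "meta" works similarly for shape-only initialization):

import torch

with torch.device("cuda"):
    w = torch.randn(4, 4)              # allocated on cuda:0
    layer = torch.nn.Linear(4, 4)      # parameters created on cuda:0

print(w.device, layer.weight.device)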

optim

  • Provide NAdamW implementation through the decoupled_weight_decay flag (#103881, #107706); see the sketch after this list
  • Add xpu support for foreach kernels (#106021)
  • Add capturable API w/ tests + fix differentiable for NAdam (#106615)
  • Add pre hooks and post hooks for optimizer load_state_dict() and state_dict() (#105953, #106209)
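
A minimal sketch of the decoupled_weight_decay flag mentioned above (the model and hyperparameters are placeholders):

import torch

model = torch.nn.Linear(8, 2)
# decoupled_weight_decay=True selects the NAdamW-style decoupled weight decay
opt = torch.optim.NAdam(model.parameters(), lr=1e-3,
                        weight_decay=1e-2, decoupled_weight_decay=True)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
opt.step()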

torch.compile

Sparse Frontend

Autograd

  • Add backward support for out-of-place foreach functions (#93901)
  • Add backward support for in-place foreach functions (#96405)
  • Add backward support on _foreach_zero_ (#101149)
  • Add forward mode AD to out-place foreach functions (#106320)
  • Add forward over backward support for torch.nn.functional.logsigmoid (#99288)
  • Add forward mode AD to in-place foreach functions (#100695)
  • Add forward mode AD for torch.renorm (#100798)

torch.nn

  • Add an optional scale kwarg to scaled_dot_product_attention (#95259)
  • Add non-recursive nn.Module.to_empty option (#104197)
  • Add keyword argument to allow disabling bias for LayerNorm (#101683)
  • Add option to always call nn.Module global/non-global forward hooks (#104278)
  • Add keyword argument to allow disabling bias for Transformer (#101687)

torch.export

functorch

  • Add experimental support for functorch transforms and torch.compile composition (#98328, #106610, #107462, and others)
  • Add functorch.einops.rearrange (#101957)

Distributed

c10d

  • Add PrivateUse1 for dispatching PyTorch Distributed Collectives to support custom device. (#98137)
  • Add new Store methods: append, multi_get, multi_set. (#100379)
  • Implement coalesced all_gather_into_tensor (#101157)
  • Implement coalesced reduce_scatter_tensor (#103561)
  • Add back in reduce_scatter_tensor_coalesced (#104345)
  • Design a new fake process group aimed at running a single rank without needing multiple processes. A fake process group (not related to FakeTensor) is a process group which doesn't actually do any communication, but instead just hallucinates communication. (#102180, #102238, #104213, #104428)
  • Enable barrier to support the specified device (#99589)
  • Add xpu to the default device supported by user specified backend (#103410)
  • Support third-party devices to use the init_process_group method without specifying the backend (#107113)

Distributed Tensor

  • Add DTensor constructor function: ones/empty/full (#100933, #101022, #103165)
  • Enable DTensor based Native sequence parallelism (#94369)
  • Enable DDP + TP 2D parallelism (#106583)
  • Enable deviceMesh to use dispatchable PG to support custom backend (#102336)
  • Allow ONNX Runtime (ORT) backend for DTensor (#101914)

FullyShardedDataParallel:

  • Introduce CustomPolicy in FSDP wrapping (#104986)
  • Add FSDP support for creating hybrid-sharded process group for custom backend (#100622)

DTensor based Distributed Checkpoint

  • Add 1D DTensor based DCP (#94868)

Profiler

  • Add a global flag to record concrete shapes, which are Scalar lists in profiler traces (#101043, #101292)
  • Add export_memory_timeline to save memory timeline plot to file (#96137, #96535)
  • Reintroduce forward-backward links in profiler traces with a global flag (#102424, #102492)
  • Add Kineto synchronization events in profiler traces (#105187, #105144)
  • Add support for cuLaunchKernel in profiler traces for triton kernel launches including flow events in profiler traces (#99571)
  • Add CUDA runtime events up to CUDA 12.0 for traces, and added flow events for H100’s cudaLaunchKernelExC (#106293)

ONNX

New TorchDynamo ONNX Exporter

New torch.compile ONNX Runtime backend (#107973, #106929, #106589)

  • Usage: `torch.compile(..., backend="onnxrt")`
  • Available when `torch.onnx.is_onnxrt_backend_supported()` returns `True`
  • Additional Python package dependencies: `onnx`, `onnxscript`, `onnxruntime`

Additional TorchScript ONNX exporter operators:

Others

  • Add initial support for FP8 ONNX export (#107962)

MPS

  • Add support for MPSProfiler (#100635, #101002, #101692)
  • Enable saved models to be loaded directly to MPS through torch.jit.load (#102204)
  • Introduce torch.mps.Event() APIs (#102121)

torch.fx

  • Add attribute node matching in the subgraph rewriter (#98604)
  • Add variadic arg matching in the subgraph matcher (#99431)
  • Add a flag to ignore literals with subgraph matcher (#97683)
  • Add a prototype source_fn based partitioner to partition modules that were flattened in export (#98628, #101121)
  • Add aggressive_merge to CapabilityBasedPartitioner which merges independent subgraphs (#100195)

Quantization

  • Add various uninterpreted bit tensor data types (torch.{bits1x8,bits2x4,bits4x2,bits8,bits16}) (#95860)
  • Add basic cuda support for float8 dtypes (#105807)
  • Add Autocast Support for XLA (#96370)

Export Quantization:

JIT

  • Register ops for torch.get_cpu_capability, Tensor.is_xla, so they can be captured by torchscript (#100723)
  • Provide __prepare_scriptable__ on non-nn.Module classes as an escape hatch to provide a scriptable alternate implementation (#106229)
  • Introduce API to deepcopy a JIT module onto a new device (#106521)

Vulkan

  • Add Vulkan support for the following operators: aten::unsqueeze for 2d to 3d (#101719), aten::cat operator for 1d, 2d, 3d and 4d (#102128), aten::expand (#103930), aten::flip (#106628), gelu (#102762), aten::masked_fill (#104444), aten::pow (#105550), at::softmax 1,2,3 dimension tensors (#105012), at::softmax along all dimensions for 4-dim Tensors (#102988), sum.dim_IntList (#105612), sum.dim_IntList with keepdim (#106159), at::select.int operator, 4 dim/rank tensor case (#96228), aten::stack (#103344), aten::uniform (#102431), aten::unsqueeze, 1d->2d, 3d->4d (#102987), aten::repeat (#103255), aten::tile (#103944), aten::zero_ (#103042), aten::zeros (#103703), convert_qconv2d_context (#97714), "height" and "width" dimension for select operator (#94612), t and transpose operators for 2d, 3d and 4d tensors (#101808), unary ops (#104994), upsample_bilinear2d (#98022), upsample_nearest2d and quantized_upsample_nearest2d (#97467), quantize_per_tensor vulkan backend function (#106641), quantized binary ops (add/sub/mul/div), and adding graph rewrites for quantized add, mul, conv2d and conv2d_relu (#97468)
  • Add broadcast support for 4D tensors where the batch and channel of a tensor are different (#104718)
  • Templatize BinaryOp.cpp (#105380)

Improvements

Python Frontend

  • Support non-ASCII characters in model file paths for torch.{save,load} (#99453)
  • Enable registering fallthroughs via torch.library (#106086)
  • Add support for saving and loading with any Endianness to torch.{save,load} (#94503)
  • Add torch.storage.UntypedStorage.get_device method (#99818)
  • Add type hint for torch.__init__.py (#106214, #103807), torch.Tensor.retains_grad (#103528)
  • Add support for HPU device for serialization (#101680)
  • Add support for XPU device for old-style Tensor classes (#96656), storage resize_ (#105262)
  • Mark torch.bincount deterministic on CUDA if weights are not given (#105244)
  • Properly expose all constraints on the torch.distributions (#106458)
  • Add itemsize and nbytes properties to Tensor (#98322)
  • Add complex support for torch.expm1 (#96644)
  • Add nonzero_static op to pytorch to unblock export (#97417)
  • Tweak heuristic for Scaled Dot Product Attention (SDPA) selection based off of data (and a decision tree) (#99644)
  • Fix print tensor.type() issue. (#96381)
  • Improve error messages in THPVariable_set_grad (#100683)
  • Enable new_full's fill_value argument type to be complex, for more accurate type checking (#91345)
  • Add 0-dim (zero dimension) Tensor overload to _foreach_mul (#106677)
  • Add in-place _foreach_copy (#107226)
  • Support floating point correction value for std/var operators (#94073)

Dataloader and DataPipe

  • Fix validate_input_col for partial functions (#95067)
  • Add support for pin memory on custom device (#97621)
  • Fix collation logic (#97789)
  • Add context to NotImplementedErrors in dataset.py (#100667)
  • Add __getitems__ to description of Dataset API, and also better support within Subset (#100375)
  • Adding StackDataset (#101338)
  • Change DataPipe display name in profiler (#100042)
  • Add copy option to fork DataPipe (#96030)
  • Fix missing imports in DataPipe interface file (#97458)
  • Do not materialize entire randperm in RandomSampler (#103339)

torch.nn

  • Add check that embedding_bag's weight is 2D (#94931)
  • Improve error message for instance norm when channels is incorrect (#94624)
  • Add generator argument to nn.init.trunc_normal_ (#100810)
  • Improve clip_grad_norm to use torch.linalg.vector_norm (#102429)
  • Add uint8 support for CPU images in interpolate(mode='bicubic') (#103252)
  • Allow nn.ChannelShuffle to run without error on CUDA tensors (#105351)
  • Add nn.CircularPad{3/4/5}d and fix no_batch_dim support for CircularPad (#106632)
  • Add reset_parameters for torch.nn.PRelu (#106507)
  • Use accumulate type to improve accuracy of grid_sample on half precision inputs on CUDA (#96586)
  • Improve performance for vectorized bilinear interpolate cpu uint8 channels last (#96848)
  • Add differentiable mkldnn_rnn_layer_backward to support double backward of LSTM (#100627)
  • Add is_causal API for TransformerDecoder (#97166)
  • Add is_causal hints for Transformer (#106143)
  • Enable channels last for reflection padding on CPU (#102518, #102597)
  • Add bfloat16 support for reflection and replication padding (#102949)
  • Add SyncBatchNorm support for custom device (#104250)
  • Add channels last 3d support for BatchNorm on CPU (#97774)

functorch

  • Add torch.vmap support for torch.complex (#96032), overloads of float_power, where, and comparison ops. (#96744), linalg.lu_factor (#94328), ldl_factor (#97518), torch.copysign (#96018), torch.nn.functional.smooth_l1_loss (#98357), nn.functional.huber_loss (#99235, #99236), special bessel functions (#99543), torch.nn.functional.{max_pool1d, max_pool3d} batch_rule (#99517, #99522), Tensor.index_fill (#99229), torch.bucketize (#95783), smooth_l1_loss_backward (#99429)

optim

  • Merge and improve torch optim optimizer type stubs (#102593)
  • Allow fused optimizers to call _foreach_zero_ in zero_grad (#97159)
  • Add multi Stochastic Weight Averaging (SWA) support for custom device (#103297)
  • Use torch._foreach_lerp for SWA update (#103550)
  • Add XLA friendly codepath to single_tensor_adamw (#102858)

Linear Algebra

  • lerp(cpu): Add half support (#105607)
  • norm(cpu): Accumulate in float when inputs are half or bfloat16 (#95166)
  • matmul: Avoid unnecessary copies (#97355)
  • matmul backwards: Don’t create large intermediary tensors (#95261)
  • addmm: Call to mkldnn_matmul on AArch64 (#91763)
  • addmm(cuda): Enable addmm + GELU epilogue fusion (#103811)
  • dot/mv/gemm(cpu): Accumulate in float for bfloat16 inputs in the fallback path (#96074)
  • bmm: Heuristics for AArch64 (#107167)
  • baddbmm(cpu): Fix grain size setting (#98297)
  • mm(cuda): Expose cuBLAS int8@int8 -> int32 (#96685)
  • triu/tril: complete dtype support. (#101414)
  • Enable hipSOLVER in ROCm builds (#97370)
  • Improve error message in ADDMM_META(). (#105309)
  • Allow setting TORCH_LINALG_PREFER_CUSOLVER=1 to prefer cusolver as linear algebra library globally (#106226)
  • ldl_factor(cuda): Enable hipSOLVER backend in ROCM (#102665)
  • Add SymInt support for {tensordot,inner,linalg.{matrix_power,tensorinv}}. (#100356, #101940, #102465)
  • Add fake tensor support for SVD. (#100130)

Autograd

  • Improve torch.utils.checkpoint with use_reentrant=False:
    • Support recursive checkpointing; allow grad calls within checkpointed function (#90105)
    • Allow the specification of a pair of context functions via context_fn= (#96783)
    • Stop recomputation early if possible; enabled by default but also expose a way to disable (#96866)
    • Improve debuggability of activation checkpoint; expose debug= and determinism_check kwargs (#103859); see the sketch after this list
  • Allow torch.inference_mode, torch.no_grad, torch.enable_grad decorators to be used without parens (#107086)
  • Allow torch.autograd.set_multithreading_enabled to act as function and context manager (#105291)
  • Add materialize_grads parameter to torch.autograd.grad() (#97015)
  • Allow torch.autograd.Function to save non-input leaf tensors for backward (#104039)
  • sampled_addmm: backward performance improvements (#103544)
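
A minimal sketch of the new non-reentrant checkpoint kwargs listed above (the contexts here are deliberately no-ops):

import contextlib
import torch
from torch.utils.checkpoint import checkpoint

def context_fn():
    # returns (context for the original forward, context for recomputation)
    return contextlib.nullcontext(), contextlib.nullcontext()

x = torch.randn(8, requires_grad=True)
out = checkpoint(torch.sin, x, use_reentrant=False,
                 context_fn=context_fn,
                 determinism_check="default",
                 debug=False)
out.sum().backward()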

Sparse

  • Add rudimentary support for addmv(strided, CSR, strided) on CPUs without MKL support (#97353, #97730)
  • Implement sparse semantics support in gradcheck (#94714, #95405, #96095, #107150)
  • Add add(COO, COO) for BFloat16 on CPU (#96767)
  • Add support for negative dim to torch.sparse.softmax for COO (#102172)
  • Add support for dim to sum for CSR on CPU and CUDA (#99292)
  • Add integer overflow checks to sparse tensor invariant checker for large compressed tensor dimensions and large nnz (#102530)

Nested Tensor

  • Support zeros_like() and randn_like() for nested tensor (#96527, #96528)
  • Add backwards for layer norm for nested tensor (#94781)
  • Support elementwise add / mul for [B, *] nested, [B, 1] dense (CUDA only) (#95620)
  • Enable FlashAttention for SDPA when given NestedTensor (#95438)
  • Add sub, sgn, abs ops for nested tensor (#97837)
  • Implement last dim split_with_sizes for NestedTensor (forward only, non-SymInt-ified) (#97446)

Foreach Frontend

  • Move tensor grouping to ATen (#103912)
  • Disable grouping by dtype and device if compiling (#102771)
  • Add fused support for XPU devices (#104517)

Build Frontend

  • _mm_prefetch is for Intel, changed to __prefetch for ARM64 (#96638)
  • Build PyTorch with -Wnewline-eof (#99687)
  • conditional CMAKE_CUDA_STANDARD (#104240)
  • cmake: allow USE_SYSTEM_ZSTD (#104611)

CPU

  • Introduce fast path for equal and concat: (#100024, #106727)
  • Add channel last 3d support for MaxPool3d on CPU (#97775)
  • Add Half support for logsigmoid, threshold, elu, gelu, hardtanh, hardsigmoid, hardswish, hardshrink, softshrink, leakyrelu, softplus, glu, silu, mish, and prelu on CPU (#98745)
  • Make index_add_ error if input source shape is wrong (#100321)
  • Enable taskset core pinning in addition to numactl (#96011)
  • Add explicit vectorization for Half dtype on CPU (#96076)
  • Add Half support for sigmoid on CPU (#96077)
  • Add Half to cat fast path on CPU (#96078)
  • Use float as accumulate type for reduce Ops: min, max, minmax on CPU (#96079)

CUDA

  • Support bf16 dtype for conv_depthwise3d and searchsorted (#97819, #99426)
  • Support integer dtypes for padding (cpu and cuda) (#107755)
  • Support complex dtype for Sigmoid Linear Unit (SILU) (#106854)
  • Add additional stream priority for cuda streams (#101956)
  • Prevent grad scale from overflowing (#98876)
  • nn.EmbeddingBag bound check (#96022)
  • Hide set_device change (#94864)

MPS

torch.export

  • Change attributes of ExportedProgram to properties and add BC decorator (#106170)
  • Throw explicit error when constraining on static values (#101655)
  • Store the arguments used to trace the exported program in the program itself (#107906)
  • Add kwargs support for export. (#105337)
  • ExportedProgram.transform updates graph_signature automatically (#107792)
  • Support preserving calling convention to some modules so that they can be properly unflattened. (#106798)
  • Make pass base composable (#103701)
  • Remove unused flags in export (#106336)
  • Update the core Aten operator set:
    • Add 23 ops to core Aten set (#107766)
    • Remove split.Tensor from core Aten (#107938)
  • Allow registration of dataclass as pytree node (serialization of it not supported yet) (#106160)
  • Support re-exportability (#106531)

torch.fx

  • Rewrote graph traversal to avoid recursion (#95723)
  • Reorder the Fx execution order to in-time get_attr rather than putting all get_attr ahead (#95014)
  • Preserve node.meta for get_attr nodes in fx.Transformer (#95245)
  • Preserve output node metadata (#95426)
  • Copy nn_module_stack metadata whenever we create a new node when tracing (#95358)
  • Prettify assert expr in self.symbol_to_source failure (#95972)
  • Allow torch.fx to take Modules that return dataclass (#99576)
  • Add has_side_effect to add to a list of side effecting functions (#97288)
  • Change placeholder check instanceof PHBase (#102008)
  • Add metadata to PHBase placeholders (#102195)
  • Make fx.wrap idempotent (#104838)
  • Enable Python dispatcher when ShapeProp with fake mode (#103512)

Quantization

  • Add quantization support for pixel_shuffle, pixel_unshuffle, narrow, ConvTranspose, ConvTranspose3d (#94769, #96160, #97126, #97125, #101926)
  • Support static quantization for LSTM and MultiheadAttention (#96343, #96436, #101299, #95636)
  • Force weight observer/fake_quant to be on the same device as the weight tensor (#106755)
  • Add serialization method for quantized hardswish (#94486)
  • Enable quantized_max_pool3d (#101654)
  • Quantization oneDNN backend now only supports CPUs with VNNI (#103653)
  • Fix bug in fuse_modules (#105069)
  • Add torch.matmul in FloatFunctional/QFunctional (#106831)
  • Support quantized Sub, Multiply in XNNPACK (#104090, #104134)

Profiler

General Profiling

  • Profiler permitted CPU events with CUPTI Range Profiler mode (#97048)
  • Make Profiler API agnostic with respect to target hardware (#101554, #106142)
  • Improve on-demand profiling options for Input Shapes, Memory, Stack, Flops, and Modules (#97380, #97556)
  • When record_inputs=True, record scalar lists of length <= 30 (#100593)
  • Disable Kineto event profiler by default due to flakiness; fix thread sanitizer issue; and refactor stress_test (#105144)
  • Bump Kineto to C++17 (#106293)
  • tb_plugin to support HDFS and improved memory view (#106672)
  • Make on-demand update duration configurable, and improved start time for on-demand tracing (#101952)

Memory Profiling

  • Add support for HTML plot of memory profile via export_memory_timeline (#99751, #101316)
  • Include more uncategorized events in memory profiles (#101200)
  • Add export of raw memory events with timestamp via export_memory_timeline_raw (#105094)

ONNX

TorchScript ONNX exporter

  • Add Autocast support to MatMul through explicit cast (#98346)
  • Add missing spaces between sentences in warning text (#105527)
  • Improve shape inference for Slice (#105755)
  • Do not run deduplicate_initializers when keep_initializers_as_inputs is True (#96320)
  • Remove legacy diagnostic printing (#106498)
  • Re-purpose name field of GraphProto (#107408)
  • Add constant folding for Softmax op (#102861)
  • Add autograd_inlining flag to torch.onnx.export (#104067)
  • Update opset version warning text (#106830)

Distributed

Activation checkpointing

  • Enable checkpoint_wrapper to accept auto_wrap_policy (#102672)
  • Add warnings on reentrant use (#102890)

DistributedDataParallel (DDP)

  • Enable delayed all reduce in DDP (#96673)
  • Enable DDP native mixed precision (#92882)
  • Add an API to remove autograd hooks from DDP (#96490)
  • Enable fused optimizer for DDP (#98270)
  • Perform input casting in pre-forward (#100131)
  • Implement new Store methods in PrefixStore. (#100380)
  • Unify _cast_forward_inputs (#102680)
  • Multiple forward support for static graph (#103487)
  • Add methods to DDP to check for backward finalization. (#100773)
  • Support optim in backward after DDP init (#105991, #105995)

FullyShardedDataParallel (FSDP)

  • Add alignment padding for use_orig_params=True (#97667)
  • Allow non-uniform requires_grad for use_orig_params=True (#98221)
  • Move only current FSDP's states to GPU during init (#98319)
  • Reshard frozen params in backward (#101982)
  • Support unfreezing params for reshard-only hook (#104186)
  • Standardize meta device init within FSDP (#104189)
  • Annotate modules for fully_shard (#104363)
  • Make limit_all_gathers=True default for FSDP (#104900)
  • Add record_function for explicit prefetching (#105985)
  • Optimize away intermediate div_ for Hybrid Sharding Data Parallel (HSDP) (#106034)
  • Check valid param freezing for ModuleWrapPolicy (#104427)
  • Allow ModuleWrapPolicy to take Iterable (#104999)
  • Enable async all-reduce for Hybrid Sharding Data Parallel (HSDP) (#106080)
  • Relax sharded_grad assert to allow IDLE state (#96584)
  • Copy step tensor so that each parameter has its own step (#96313)
  • Make FSDP optim_state_dict aware of DDP prefix (#96415)
  • Consolidate the arguments and logic of optim_state_dict and optim_state_dict_to_load (#96534)
  • Make it full precision in eval mode (#97645)
  • Include duplicate parameters and modules when calling named_parameters and named_modules (#99448)
  • Make param_groups optional for FSDP optim.state_dict (#99117)
  • Support rank0_only when use_orig_params is True (#99624)
  • Consolidate rank0_only load logic (#99647)
  • Make fsdp device-agnostic for custom-backend which implements cuda-semantics (#99024)
  • Ensure that customized non-tensor optimizer state can be saved (#99214)
  • Avoid elementwise dispatch of gradient unscaling/validation ops in _foreach_non_finite_check_and_unscale_cpu_ (#100108)
  • Do not flatten states when use_orig_param is True and sharding is NO_SHARD (#100189)
  • Make set_state_type to SHARDED_STATE_DICT compatible with NO_SHARD sharding_strategy (#100208)
  • Allow each fully_shard unit to cast forward inputs for mixed precision config (#100290)
  • Restore the state_dict_config for NO_SHARD (#100855)
  • Skip unshard call during checkpointing for NO_SHARD sharding strategy (#101095)
  • Add ignored_states to FSDP/fully_shard (#102056)
  • Start to generalize modules to ignore for mixed precision (#102010)
  • Implement a workaround for FSDP init issue for 2D Parallel (#104398)
  • Improve support for CPU tensors. (#103171)
  • Avoid calling optim.state_dict() to get the initial empty states (#103609)
  • Use _get_module_fsdp_state_if_fully_sharded_module for state_dict (#103783)
  • Validate ignored_modules, ignored_states (#104273)
  • Check module.training for _root_cast_forward_inputs (#104223)
  • Add Option for eval in fp32/bf16 (#104682)
  • Correctly initialize optimizer states when the corresponding param is empty (#104765)
  • Make optim_state_dict_to_load work with use_orig_param=False + NO_SHARD (#107185)
  • Expose optimizer state_dict config (#105949)
  • Enable custom device support in fsdp checkpoint (#107289)

Distributed Tensor (Prototype Release)

  • Add _tensor.zeros function (#95863)
  • Enable the nn.Embedding op for DTensor (#96702, #104820)
  • Support creating DTensor in submesh (#95458)
  • Enable correct behavior of random ops in DTensor and Tensor Parallel (#98198, #98577, #103235, #103910, #106535)
  • Implement aten.equal sharding prop for DTensor (#97170)
  • Set cuda device automatically, and refactor error handling (#97583)
  • Set default value for DTensor ops on non-participating devices (#95852)
  • Change sharding algorithm to be in line with torch.chunk (#98722, #106250)
  • Add a new ColwiseParallel style when Pairwise cannot be used directly (#100137)
  • Enable more generic attention module sharding for PTD Tensor Parallelism (#100508)
  • Adopt strategy based sharding prop in DTensor ops (#100607, #101203)
  • Support torch.save/load with DTensor (#103106)
  • Allow DTensor to support cuda-like device (#102468)
  • Add an input resharding wrapper for TP and unit test for 2D + AC (#103334)
  • Add validate mesh flag to DeviceMesh (#104807)
  • Improve allgather unpadding logic (#103219)
  • Use stack to manage mesh resources (#101202)

Distributed (c10d)

  • Enable the handling of bool tensors in Gloo. (#105354)
  • Make avoiding recordStream calls in ProcessGroupNCCL an option (#89880)
  • Remove stack trace captures from import (#97274)
  • Update _store_based_barrier implementation to reduce load on rank 0 (#98000)
  • Remove lock for nccl collective launch for nccl 2.0+ (#97904)
  • Figure out the correct device to use for object collectives (#100954)
  • Start gloo sequence numbers at 0. (#101422)
  • Add missing torch.distributed.ReduceOp.AVG in type stubs (#101534)
  • Determine collective device from _get_pg_default_device rather than from explicit cuda (#101533)
  • Enable configuration of NCCL communicators (#97394)
  • Make the default backend check for NCCL availability (#102470)
  • Add is_backend_available for c10d backend. (#101945)
  • Add one flag value for direct teardown without comm abort (#102599)
  • Make it the default that PGs do not perform a barrier after init (#103033)
  • Avoid workEnqueue when capturing cuda graph for NCCL process group (#103503)
  • Ensure ncclCommAbort can abort stuck ncclCommInitRank (#103925)
  • Allow previously initialized MPI (#105023)
  • Increase socket buffer size to allow ProcessGroup init up to 12k ranks (#107878)
  • Change --standalone to bind to a random port (#107734)
  • Initial commit of collective_utils (#101037)

Distributed Checkpoint

  • Upstream fsspec storage read/write to PT (#98387)
  • Rewrote read slicing to use a wrapper. (#99167)
  • Consolidate OSS FsspecWriter/Reader and internal FsspecWriter/Reader (#104724)

Torch Elastic

  • Allow elastic agent to fail fast (#99051)

RPC

  • Add size check before calling .back() in rpc/script_call.cpp (#94297)

Dynamo

  • Support nn.Module forward hooks in Torch Dynamo (#92125)
  • Many graph break fixes (#94949, #94658, #102247, and others)
  • Translation validator for dynamo guards (#102563)
  • Update dynamo sum dtype handling to match eager (#103037)
  • Switch calling convention back to real tensors (#99320)

Inductor

  • Support more operators that previously fell back to eager: rad2deg, deg2rad, count_nonzero, bitwise_right_shift, quantized.max_pool2d, erfc, erfinv, all_reduce, squeeze_copy, aten.prod, softshrink, aten.unfold, diagonal, diagonal_copy, diagonal_scatter (#98994, #98995, #94997, #105906, #101416, #101863, #93111, #96039, #99484, #105603, #105165, #103755)
  • Add decomposition rules for: aten.normal_, lerp, aten.angle, unfold_copy, aminmax, nansum, fmin, fmax, narrow_copy, expand_copy, view_copy, smooth_l1_loss, full_like, affine_grid_generator, aten.dist (#91207, #104866, #105609, #96038, #96039, #102077, #101963, #104709, #105586)
  • cudagraph and cudagraph trees (#97440, #98254, #89146, #98529, #102273, #105148)
  • Add convolution triton template (#95556)
  • Pattern matcher infrastructure (#97740)
  • Use Welford algorithm to compute variance in a single pass (#104725)
  • Value range analysis (#102611)
  • Do IR validation recursively (#98887)
  • Merge consecutive splits (#100107)
  • Constant and index_expr propagation pass to simplify indexing expressions (#101077)
  • Fallback max_pool2d_with_indices to eager rather than fail an assertion if dilation is not 1 (#100531)
  • Don't fuse nodes with long distance to avoid increasing memory usage (#104024)
  • Easier to add third-party backend (#106874)
  • Improvements on the CPU backend
    • Support vertical reduction (#97644)
    • Support dynamic shape (#97230, #102068 )
    • Support masked load ( #107670 )
    • Enable accuracy testing in CI (#94898)
    • Enable Inductor to support BF16 atomic_add (#96620)
  • Improvements for AMD
    • tl.dot and MFMA support enabled in ROCm triton for conv/mm lowerings (#107600)
    • Remove ROCm miopen_batch_norm fallback, now lowering to triton (#100089)
    • Enable "reduce-overhead" compile mode with hipgraph support on ROCm5.6 (#103092)
  • Align inductor behavior with eager mode for split_with_sizes (#99702)
  • Avoid decomposing _unsafe_index in Inductor (#107882)
  • Make CI error on inductor fallback when decomp is available (#99473)
  • Enable weight prepack for LSTM (#103071)
  • Enable fused_attention pattern matcher (#107128)
  • Add fused_attention pattern matcher with additional clone (#108141)

JIT

  • Include more shape information on tensor values in jit.trace functions (#95544)
  • Allow annotations using generics directly, e.g. tuple instead of Tuple (#98703)
  • Improve source attribution for easier debugging (#95761, #96423, #98606, #100171)
  • Shape functions implemented for stack, squeeze.dims, cross_entropy_loss, conv_transpose (#92205, #93919, #98078, #97875, #102139)
  • Partially support ForwardRef type annotations for NamedTuple attributes (#96933)
  • Optionally ignore UTF-8 decoding error when converting std::string to python str. (#97282)
  • Improvements to flatbuffer serialization and deserialization (#97190, #97298, #99050)
  • Support serialization/deserialization of >4GB strings (#99104)
  • Enable torch.jit.load for custom device (#99535)
  • Allow C++ custom class to define __repr__ (#100724)

Misc

  • Better function annotations for nn.functional (#102918)
  • Add safe weights_only option to load_state_dict_from_url (#98479)
  • Enable Transparent Hugepages (THP) for buffer sizes >=2MB (#95963)
  • Automatic pulling of ExtraFilesMap without explicit mapping. (#99747)
  • Remove device transfers from Java Native Interface (JNI) (#105583)
  • profile_plot generates snapshot objects (#103497)
  • vmap Support for torch.tril and torch.triu (#94287)

Bug fixes

Python Frontend

  • Fix docstring setup to allow running PyTorch in python optimize mode (#100750)
  • Fix deserialization for UpsamplingBilinear2d (#101248)
  • Fix torch.distributions.Dirichlet.log_prob when x=0 and alpha=1 (#103605)
  • Fix torch.distributions.Gumbel.cdf (#91698)
  • Fix PEP 484 Violation (#105022)
  • Fix bitwise shift operations when shifting out of bounds (#97150)
  • Fix torch.asarray to use the default device (#106779)
  • Fix deepcopy on torch.Tensor on MTIA device (#107427)
  • Add deterministic path for Tensor.resize_ (#104300)
  • Fix torch.pow to handle real negative base and complex exponent (#95198)
  • Fix LayerNorm(bias=False) error (#108060)
  • Don't fastpath conj copy when conj/neg bit mismatch (#108881)

Autograd

  • Fix torch.autograd.graph.register_multi_grad_hook to not keep tensor alive in closure (#102859)
  • Fix autograd hooks being spuriously garbage collected by removing incorrect THP{Cpp,}Function_traverse PyObject traversals (#102860)
  • Fix Tensor::register_hook behavior on undefined tensors (#105587)
  • Handle undefined gradients in out-of-place foreach backward (#100256)
  • Fix codegen logic for foreach derivatives (#95263)
  • Bump version counter for torch.{resize_, resize_as_} (#96598)
  • Bump version counter for foreach functions (#93901)

optim

  • Fix and enable complex x amsgrad support for Adam and AdamW (#104989, #104990)
  • Fix unpicklable object in AveragedModel (#95979)
  • Fix parameter list used in weight_decay for Adam (#100973)
  • Fix optimizer state_dict casting to allow step to cast for fused/capturable (#102619)
  • Update lr_scheduler.py to check the type of eta_min (#97003)
  • Fix issue with lr_scheduler serialization containing bound methods (#102627)

torch.nn

  • Fix int() casting in torch.nn.RNN to have correctly traced JIT and ONNX graph. (#92970)
  • Fix device handling in nn.utils.rnn.unpad_sequence (#98042)
  • Fix torch.nn.FractionalMaxPool2d output_size error (#99507)
  • Fix inconsistent torch.nn.MaxPool1d output on cpu and gpu (#99843)
  • Fix device of lengths in pack_padded_sequence when the default device is GPU (#103967)
  • Fix bias overflow for memory efficient attention in scaled_dot_product_attention (#107968)
  • Update scaled_dot_product_attention dispatch logic to check for sm86 and head_size == 128 for flash attention (#94921)
  • Raise type error message for interpolate if size contains non-integer elements (#99243)
  • Fix a bug in interpolate uint8 AVX2 on non-contiguous input (#101136)
  • Fix bug in interpolate when interpolation size is larger than max (#101403)
  • Add error if stateless.functional_call is called with nn.DataParallel (#107403)
  • Fixing interpolate on uint8 unsqueezed 3D CL tensor (#100258)
  • Add check for 0 to 1 inclusive for elements of target tensor in BCE loss (#97814)

functorch

  • Fix torch.vmap support for torch.roll (#95048), nn.{PixelShuffle, PixelUnshuffle}(#96493)
  • Add better error message for mutating .data under functorch transforms (#94817)
  • Fix functorch.jacrev support for torch.take (#95772)
  • Fix functorch support for transforming over Tensor indexing (#98748)
  • Fix torch.vmap support for torch.searchsorted (#99698)
  • Fix torch.vmap support for Tensor.index_put (#100516)
  • Fix UB in functorch infrastructure (#101568)
  • C++ autograd.Function now raises an error with functorch transforms instead of being silently incorrect (#103957)
  • Fix nll_loss batch rule with negative ignore_idx (#106118)

Distributed

Distributed (c10d)

  • Fix kDefaultTimeout multiple definition build failure in Gloo (#97270)
  • Delete lengths offset checks (#98368)
  • Drop the GIL when creating a TCPSTore to avoid deadlocks. (#100555)
  • Fix bug in process_group_name when there are duplicate PGs (#100518)
  • Fix subprocess group handling in scatter_object_list (#100552)
  • Fix the check message of unsupported collectives ops. (#101775)
  • Fix netName assignment for NCCL Config (#105776)
  • Skip timeout in FileStore for Windows if the file path is invalid (#103247)

FullyShardedDataParallel

  • Use correct handle training state when prefetching (#98249)
  • Fix issue where fully_shard may determine compute device incorrectly (#98831)
  • Enable FSDP use_orig_params=True mixed precision training when some ranks have no (non-zero sized) parameter shards (#99175)
  • Fix use_orig_params=True, CPU offload, no_sync() (#100180)
  • Fix device_id when buffer-only module (#103504)
  • Fix skip-sharded-views + mixed precision (#105346)
  • Ignore buffer type casting in ignored modules (#106766)
  • Fix train -> EMA -> eval with mixed precision (#106858)
  • Unblock ignored_states + auto wrap (for now) (#104418)
  • Fix a memory leak in optim_state_dict (#96263)
  • Fix bug in determining whether parameters need to be materialized (#97488)
  • Fix typo when setting FSDP state dict config (#97110)
  • Fix osd rank0_only in fsdp (#99136)
  • Fix decision logic for should_cast_forward_inputs in _root_pre_forward() and _pre_forward() (#99546)
  • Fix ignored_states when they are passed as generators (#102575)
  • Fix for optim state dict (#102901)
  • Handle corner case of load with multi-backend PG for FSDP state_dict (#107172)

Distributed Tensor (Prototype Release)

  • Fix DeviceMesh logic for deciding which PG to use (#96861)
  • Remove non-generic asserts in _get_or_create_default_group() (#96961)
  • Fix the default PG condition for DeviceMesh (#97384)
  • Fix DTensor equal op (#99014)
  • Use Stride inferred from local tensor in to_local bwd (#102630)
  • Enable partial tensor add without redistribute (#105939)
  • Get rid of dim_groups attribute from DeviceMesh (#103105)
  • Fix requires_grad in distribute_tensor (#107606)
  • Fix new_empty_strided op’s crash on the shard placement (#108600)
  • Fix requires_grad callsite (#108358)

torch.compile

Dynamic Shapes

A lot of dynamic-shapes bugfixes, too many to enumerate one-by-one. Some important points:

  • Heavy work on our sympy-based symbolic reasoning system, including value ranges analysis for proving redundant constraints (#95174, #105877, #104968, #105138, #105139, #97963, #96121, #104557, #106644, #94944)
  • Improved symbolic tracing support for operators, including SymInt’ified schemas and SymInt-aware operator/backward implementations (#95543, #96100, #97362, #97675)
  • Avoid overspecializing in situations where it is not necessary (#96008)
  • Don't attempt to use fake tensor fallback to real tensor if there are symbolic sizes (#97148)
  • Make Tensor.__contains__ accept SymInt/Float/Bool. (#98933)
  • Change Dynamo to not duck-size unspecialized ints (#99010)
  • Improved mixed type handling for SymInts (#100008, #100328)
  • Support for annotating that SymInts have a constrained range (#103346)
  • Don't generate guards that refer to unbacked SymInts (#95732)
  • Automatically guard when SymInt is converted to int, instead of erroring (#95479)
  • Z3 based translation validation for debugging symbolic reasoning problems (#104827, #106643, #107523, #106645, #101307, #101607)
  • Improve constraint violation error reporting, including recommended constraints for export (#102729, #102198, #107470, #107790, #100745, #101636, #101815)
  • Improve logs for dynamic shapes using TORCH_LOGS=dynamic (#99277, #98941, #107439)
  • If we can't statically prove 32-bit indexing is OK, only add guard if hint exists (#106004)
  • Add expect_true for irrefutable guards, greatly improving overall support for error checking involving unbacked ints, and other unbacked symint improvements (#106720, #95216, #106788)
  • Support for torch.compile with FakeTensor that has SymInt sizes (#107662)

Other bug fixes

In addition, we have the following fixes broken down into roughly 4 parts:

  • Primtorch and decomposition bugfixes and improvements
  • FakeTensor and meta function bugfixes and improvements
  • AOTAutograd bugfixes and improvements
  • General “composability” bugfixes.

The first three cover a large number of general improvements to torch.compile, since torch.compile captures a graph internally by using these major components (fake tensor, prims and decomps, and AOTAutograd; see the docs at https://pytorch.org/get-started/pytorch-2.0/).

Primtorch and decompositions bugfixes

There were a large number of fixes to the primtorch and ref decompositions, which are used in torch.compile during graph capture. These all fixed quite a few bugs in torch.compile:

  • Sub.scalar decomp: fix primtorch handling with alpha and float64 arg (#95421)
  • Embedding_backward_dense decomp: broadcasting fix (#95499)
  • Upsample_bilinear decomp fix (#101682)
  • Batch_norm decomp reduce computation when weight or bias is none (#104616)
  • _unsafe_index decomp (#106814)
  • Hardshrink: make decomp composite implicit (#107039)
  • normal op decomposition for specializations of the normal op (#106792)
  • matmul decomp: update to match eager (#105850)
  • prims.collapse: make it a real prim (#91748)
  • Diagonal, linalg.diagonal: add refs (#95774)
  • addmv decomp (#96264)
  • Minimum_value: fix schema (#97327)
  • squeeze.dims decomp (#97020)
  • cumprod decomp: add ref (#98670)
  • Prims.unbind: fix ref if given dimension size is 0 (#100122)
  • Aten.arange.default: decompose to arange.start_step (#99739)
  • Philox_rand: add decomps (#100206)
  • Elu_backward: fix bad accuracy in decomp (#100284)
  • polar decomp: add ref (#100345)
  • Batch_norm decomp: fix decomp when weight/bias is not flattened (#101059)
  • aten.fill.Tensor decomp: don’t call .item() (#103880)
  • Torch.renorm: add decomp (#103858)
  • multi_margin_loss ops: add decomps (#104578)
  • aten.logspace decomp: bugfix (#105201)
  • multilabel_margin_loss_forward op: add decomps (#105302)
  • Torch.{stft,istft} decomps: add ref (#106400)
  • Aten.rrelu_with_noise decomp: add ref (#106812)
  • Misc fixes:
    • better error message when functionalization can't handle an op (#95392)
    • Simplify some decompositions. (#107038)
    • Make the glue compute short circuit only if possible (#94437)
    • Don't use PrimTorch decomposition for empty (#94512)
    • Remove unnecessary TensorMeta rewrap (#95004)

FakeTensor and Meta function fixes

Fake Tensors and meta functions are used internally to perform “shape inference” during graph capture when running torch.compile. In particular: when we capture a graph of pytorch operators, we’d like detailed information on the shapes of intermediate and output tensors in the graph. There were a large number of bugfixes and improvements to these two subsystems over the last release.
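
A minimal sketch of fake tensors in isolation (this relies on the private torch._subclasses.fake_tensor module, so treat it as illustrative only):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Ops on fake tensors allocate no real storage, but still report output
# shapes/dtypes/devices, which is the information graph capture needs.
with FakeTensorMode():
    a = torch.empty(8, 16)
    b = torch.empty(16, 32)
    out = a @ b

print(out.shape)  # torch.Size([8, 32])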

Operator bugfixes:

Increased operator coverage:

Other:

  • Support resize on meta storage (#101988)
  • [meta_tensor] polish error strings in meta registrations (#95052)
  • [meta] error checking for inplace ops (#101532)
  • Implement size checking for copy_ with meta tensors (#107779)
  • Use safe_is_leaf to test leafness (#102706)
  • [FakeTensor] Workaround FFT ops with incorrect meta strides (#106319)
  • Better complex support (#98869)
  • [pt2] remove unused meta_linalg_eigh (#100965)
  • [pt2] convert out params in register_meta (#101344)
  • Add missing decompositions/lowerings for logical/bitwise operators (#102566)
  • [pt2] bug fix: invert condition in checkFloatingOrComplex (#102944)
  • err on dot product for tensors of different sizes (#106572)

AOTAutograd bugfixes

AOTAutograd is a major component of the torch.compile stack, and received many bugfixes and improvements over the last release.
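
A minimal sketch of AOTAutograd in isolation via functorch.compile.aot_function; the "compiler" here just prints the captured forward/backward graphs and runs them unchanged:

import torch
from functorch.compile import aot_function

def inspect_graph(gm, example_inputs):
    gm.graph.print_tabular()   # show the captured FX graph
    return gm                  # return a callable to execute

def f(x):
    return (x.sin() + 1).sum()

f_aot = aot_function(f, fw_compiler=inspect_graph, bw_compiler=inspect_graph)
out = f_aot(torch.randn(4, requires_grad=True))
out.backward()                 # the backward graph is captured and compiled lazily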

  • AOTAutograd: fix 'Trying to backward through the graph a second time' error (#98960)
  • Handle tracing foreach ops in ProxyTorchDispatchMode. (#99724)
  • functionalization: error during mutations on mem overlap (#99919)
  • Functionalization of torch.rand/rand_like ops (#97377)
  • fix inference mode / PyDispatcher / Functionalize interaction (#103275)
  • Refactor (#95991, #96235)
  • Dynamic shapes improvements through AOTAutograd (#95975, #96219, #96300, #96653)
  • aot_autograd: don't set requires_grad on tangents (#96339)
  • [aot autograd] avoid cloning some inputs unnecessarily when they don't require grad (#96342)
  • [aot] disable inference view tracking (#96478)
  • aot autograd: consolidate metadata (#96340)
  • Add missing aot_autograd_arg_pos_to_source (#97487)
  • Disable logging in pattern matcher calls to AotAutograd (#98936)
  • aot_autograd: factor out runtime epilogue from aot_dispatch_base (#100586)
  • Disallow module forward input mutation in aot_export (#101834)
  • [aot_autograd][functional_rng] Change calling convention (#102344)
  • [AOTAutograd] add export entrypoints (#100587)
  • aotautograd: fix mutation bug when input is noncontiguous (#102767)
  • [AOTAutograd] perform comparisons with stride hints (#103342)
  • [AOTAutograd] make _unsafe_view() logic happen during the runtime epilogue (#103919)
  • Read out real strides from compilation result, rather than real args (#105010)
  • AOTAutograd: correctness fix when tracing custom autograd functions that alias inputs (#102992)
  • Add sequence_nr to aot_autograd to map forward ops to their corresponding backward ops (#103129)
  • AOTAutograd: allow input mutations on inputs that are non-contiguous (#106460)
  • Add some support for detecting false aliasing in AOTAutograd (#106461)
  • Add complex dtypes to partitioner (#96297)

Sparse

  • Fix an unexpected assertion error when nesting check_sparse_tensor_invariants context managers (#95372)
  • Fix silent nnz overflow for very large sparse compressed tensors. (#102523)
  • Fix CSR/BSR invariant validation on 0 sized batched inputs (#101180)
  • Fix zeros_like CSR and BSR tensors with batch dimensions. (#101215)
  • Fix autograd issue with identity conversions (#92022)
  • Set outputs of col_/crow_/ccol_/row_indices methods as non-differentiable. (#107447)
  • Fix silent index downcast from int64 to int32 for add/add_ on CSR/BSR (#95294)
  • Fix add/add_ device checks for CSR/BSR (#97520)
  • Fix incorrect sparse_dim in COO.zero_() and in binary operations with zero-sized COO operands (#98292)
  • Fix sparse.mm derivatives for non-contiguous inputs on CPU (#106127)

Linear Algebra

  • baddbmm: Fix when out has nan value for beta=0 (#96086)
  • Add same-dtype checks for tensordot, addmm on CPU (even when an input has zero numel), and baddbmm (#98938, #100274, #102659)

Profiler

  • Hand-bound CapturedTraceback (#107438)
  • Fix crash by initializing kineto_activity for each event for on-demand profiling (#97550)
  • Fix CUPTI lazy re-initialization and CUDA Graphs crash in CUDA 11 with workaround (#101879)
  • Fix CUDA IMA for CUPTI and CUDA 12 by disabling CUPTI lazy re-initialization (#107744)
  • Fix profiling PT2 w/ dynamic shapes & record_shapes (#104320)
  • Fix profiling shapes with PT2 + lists of dynamic shapes (#105893)
  • Fix an issue where running Kineto daemon and dynolog in docker fails and UUID generation for IPC fabric (#95535)
  • Fix static order deinit with LoggerCollector (#101952)
  • Fix issues in tb_plugin regarding Distributed View and NCCL events (#103031)
  • Fix test_profiler_tree for HIP and enable individual activity types for RocTracer (#106293)
  • Fix flaky test_memory_timeline_no_id in test_memory_profiler.py (#103441)

Quantization

  • Fixing quantized prelu workflow (#103455)
  • Fix issue of lowering weighted functional ops with kwargs (#95865)
  • Return zero_point from determine_qparams as an int64 (#98746)
  • Fix errors in QuantizeAvx512 (#104400)

CUDA

  • Add broadcastable check to index_put (#94849)
  • Fix uniform returning end point for BFloat16 and Half (#96962)
  • Fix "Cannot assign index like x[[1,2], :] = 2 when torch.use_deterministic_algorithms(True)" (#105833)
  • Fixing a bug where allocating a 4GB block results in using 8GB of memory (#95827)
  • Take CUDA_VISIBLE_DEVICES into account for nvml calls (#94568)

Intel

  • Avoid FPE when running batch norm with zero batch size. (#95324)
  • Fix CPU bitwise shifts for out-of-limit shift values (#96659)
  • Use unordered NEQ comparison for vec512 operator!= implementations (#97466)
  • Fix masked_scatter_: non-contiguous self (#100232)

MPS

  • Introduce xfail (#95045)
  • Error on unsupported types (#95982)
  • Add type promotion to torch.addcmul (#96164)
  • Add random_ overload (#98333)
  • Fix layer_norm_backward_mps key (#100295)
  • Make grid_sampler_2d available (#101108)
  • Fix bernoulli for int types (#100946)
  • Enable arange for int8 and uint8 dtypes (#101303)
  • Handle deserialization more permissively (#98834)
  • Fix mps unary op issue on non densely stored tensors (#105512)
  • Fix torch.std for negative dimensions (#107754)
  • Remove mps specialized path in BCE backward (#95220)
  • Fix type casting copy with storage offset (#95573)
  • Fix views with 3 or more sliced dimensions (#95762)
  • Fix bidirectional LSTM & small one-direction LSTM fix (#95563)
  • Fix in-place add and sub with alpha == 0.0 (#96184)
  • Fix flip where no dims need to be flipped (#96605)
  • Fix LSTM grad_y (#96601)
  • Fix the failure with ReplicatePad3D (#96988)
  • Fix torch.eye unsupported bool constant on macOS 12 (#97027)
  • Add linear inputs check (#99228)
  • Fix gelu exceptions not raised for error inputs (#99237)
  • Fix max_pool2d exceptions not raised for error inputs (#99238)
  • Fix trace exceptions not raised for error inputs (#99239)
  • Add dot input check (#100099)
  • Fix index_put with deterministic algorithm enabled (#97660)
  • Fix embedding cache key (#101857)
  • Fix softplus with f16 input (#101948)
  • Fix incorrect distribution of randperm with device mps (#104171)
  • Fix argmax and argmin clamp value on MPS (#104374)
  • Make torch.empty* deterministic by filling with NaN or max int (#104995)
  • Correct empty tensor mps all operation (#105218)
  • Fix upsample output size tensor (incorrect result in MacOS 14.0) (#105677)
  • Fix MPS clamp issue with different dtypes between input and min/max tensors (#105747)
  • Fix copy_ broadcast (#105617)
  • Fix clamp with strided outputs/inputs (#97858)
  • Restride output strides to contiguous format for inverse op (#102122)
  • Remove casts from reduction/cumsum/sort ops starting with macOS 13.3 (#95817)
  • Fix .item() for multi-dim scalar (#107913, #108410)

Vulkan

  • Ensure dim is size_t (#104201)
  • Fix divide-by-zero with padded tensors (#97698)
  • Ensure non-zero divisors in Vulkan API Tests (#100909, #100910)
  • Fix concat op in feature dimension (#101721)
  • Fix a bug in aten::cat when concatenating 3D tensors along the channel dim with the channel count a multiple of 4 (#103718)
  • Fix position computation to account for channel padding (#103908)
  • Fix quantized cpu to vulkan broken by padding (#97372)
  • Fix broadcasting in quantized elementwise ops (#97554)
  • Fix lint for at::softmax 1,2,3 dimension tensors (#105082)
  • Fix static analysis errors in vulkan_quantized_api_test.cpp (#97400)
  • Reuse broadcast checks instead of check_inputs (#105960)
  • Fix global and local sizes for image->bool copy (#106752)

Build

  • USE_FAST_NVCC Windows (#95206)
  • Enable CuDNN v8 frontend in RL (#102284)

ONNX

TorchScript ONNX exporter

  • Fixes for operators:
    • Add cast operator after reduce to match desired dtype (#100700)
    • Simplify repeat_interleave export for scalar-valued repeat (#100575)
    • Fix wrong type when exporting {zeros, ones, full, empty, rand, randn}_like ops to onnx (#103048)
    • Fix output_padding for quantized tconv (#104207)
    • Refactor AvgPool to support dynamic shapes (#105683)
    • Fix expand_as (#95962)
    • Add new aten::device variant to TorchScript (#97023)
    • Export dynamic step size for aten::slice (#104385)
    • STFT Support (#92087)
    • Fix aten::flatten conversion with 0d input to onnx Reshape and 1d to Identity (#104089)
    • Fix output shape mismatch issue of max_pool (#106270)
    • Add quantization support to reshape and size for the ONNX exporter (#106629)
    • Return input itself for non-fp inputs and support decimals for aten::round op (#107920)
  • Apply peephole optimizations for eval mode only when constant folding is enabled (#95801)
  • Detect None constant during jit scalar type analysis (#101608)
  • Fix onnx Gather constant folding (#101329)
  • Fix third-party custom operator support in torchscript exporter (#104785)
  • Fix memory leak when exporting models (#107244)
  • Fix typo scipt -> script (#97850)
  • Fix circular padding to support dynamic axes (#95647)
  • Perform Shape inference on added Cast node (#106093)
  • Cap opset version at 17 for torch.onnx.export (#107829)
  • Make torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp data_ptr-correct (#100681)

TorchDynamo ONNX exporter

  • Fixes for operators:
    • Fix scalar elements in op.Concat (#98509)
    • Fix aten::cat export when args include parameters (#105373)
  • Remove duplicated code from previous rebase (#99072)
  • Cover undiscoverable ops torch.ops.aten (#99682)
  • Fix type annotation for fx_to_onnxscript (#100050)
  • Set tracing_mode through options.dynamic_shapes and enable dynamic tests in test_fx_to_onnx_runtime.py (#100212)
  • Add RemoveConstantInputStep to adapt torch inputs to ONNX inputs (#100252)
  • Fix exported onnx initializer name (#104741)
  • Fix UnsupportedFxNodesAnalysis after onnx dispatcher changes (#105156)
  • Support torch.device in FX exporter (#105757)
  • Fix passes to reuse existing fake mode if possible (#105764)
  • Exclude FXSymbolicTracer from _assert_fake_tensor_mode (#107712)
  • Fix initializer naming at torch.onnx.ExportOutput.save_model_with_external_data (#105002)
  • Apply options.dynamic_shapes to dynamo API usage in fx exporter (#99199)

torch.fx

  • Fix split_module bug with unused keys (#95493)
  • Fix tabulate import error (#104468)
  • Fix issue with SubgraphMatcher when ignoring literals (#98458)
  • Update InternalMatch in subgraph_rewriter after repeated replacements (#99039)
  • Fix conv+bn folding issue for mixed dtype (#99696)
  • Fix submodules/parameters/buffers preservation when unpickling graph module (#104115)
  • Prevent interpreter from altering original node’s meta (#105880)
  • Fix split module’s interaction with dead code (#104554)
  • Fix split copying over node.meta (#107248)
  • Fix repr when arg is an OpOverload (#102547)

Dynamo

Misc TorchDynamo fixes

  • Correctly use PythonPrinter for generating wrapper code referencing SymPy (#96710)
  • Fail fast when dynamo attempts to add unspecialized int/float as additional graph inputs (#96786)
  • Simplify module_key creation logic (#94945)
  • Generalize summary script to work with more CSV names (#98500)
  • Add support for nonzero, some improvements to reduce guards (#95387)
  • Update Dynamo.export to preserve names of args & kwargs (#95851)
  • Slight cleanup of VariableBuilder giant if condition (#95471)
  • Add guards for deterministic algos (#96695)
  • Add signpost_event to dynamic_shapes (#103882)
  • Add some missing disabled functions (#103662)
  • Add support for dictionary with torch object keys. (#103158)
  • Add timeout for translation validation instances. (#104654)
  • Add Wav2Vec2 HuggingFace support (#103009)
  • Add dynamo backend based on ONNXRuntime (#106589)
  • Allow for torch.sym_int to return int while tracing (#104837)
  • Allow NumPy code in torch.compile to run on cuda (#104699)
  • Avoid cond prefix when naming subgraph of HigherOrderOperators (#101439)
  • Avoid graph break on repeat_interleave.self_int (#99528)
  • Change dimension constraint summary to log.info (#101584)
  • Debug shape guards (#95848)
  • Disable dynamo on some opt methods and differentiable optimizer tests (#103066)
  • Disable fused adam op compile (#105256)
  • Don't apply automatic_dynamic_shapes if we force tensor to be static (#103673)
  • Don't specialize torch.Size with specialize_int = False (#96419)
  • Dynamo size dim kwargs (#97450)
  • Dynamo stride dim kwargs (#97444)
  • Enable fused foreach Adam compilation (#104121)
  • Enable torch._C._get_privateuse1_backend_name in Dynamo tracing (#103141)
  • Ensure optimizer state references are cleared (#100282)
  • Equality assertions (#102256)
  • Explicitly fall back to eager with GraphModule with no output for onnx&tvm backends (#99805)
  • Extend assert statement to include ListVariable (#100841)
  • Fix disable_saved_tensors_hooks - graph break (#106875)
  • Fix for tuple construction from tuple iterators (#97862)
  • Fix graph break on boolean mask better (#103052)
  • Fix incorrectly getting the name of OrderedDict's index in dynamo (#96940)
  • Fix isinstance on SymInt in dynamo (#99393)
  • Fix lineinfo generation on PY3.11+ (#103525)
  • fix module buffers call (#102251)
  • Fix number of inputs in onnxrt and tvm backend (#95429)
  • Fix optimizer cuda health check graph break (can be done in the compiler) (#102765)
  • Fix optimizer grad mode state interaction with dynamo (#103952)
  • Fix OrderedDict reconstruction bytecode (#95800)
  • Fix a compatibility issue between Dynamo and the PyDev.Debugger (#96721)
  • Fix torch.compile issue with torch.tensor (#96299)
  • fix torch.distributions lazy_attribute failure (#103208)
  • Fix usages of contextmanager without finally (#96170)
  • Flatten exceptions in dynamo (#100779)
  • Full default dict support in dynamo (#102202)
  • Generate type match guard for torch.Size input (#96421)
  • Graph break on differentiable boolean mask setitem (#102843)
  • Graph break on operators that fake tensor doesn't support (#97708)
  • Guard on default device (#99551)
  • Handle calls to typing.cast (#104799)
  • Handle dim in size kwargs (#96992, #97098)
  • Initialize optimizer in dynamo to avoid graph break and tracing slowness (#102640)
  • Keep submodule's name for nn.Sequential when unrolling (#94913)
  • Make _CURRENT_TRACING_CONTEXT thread local (#105942)
  • Make int unspecialization actually work (#95621)
  • Make openxla and openxla_eval backends show up in list_backends (#107905)
  • Make Openxla dynamo backend take boxed input (#107260)
  • Manually generate guards for optimizer (#103121)
  • Node.stack_trace should have innermost frame last (#95592)
  • Normalize builtin types to dtypes (#106074)
  • Add a flag that allows breaking on NumPy ops (#107687)
  • Fix ndarray.__pow__ (#107746)
  • Return NotImplemented for `np.sort(complex)` (#107710)
  • Support linalg, random and fft module (#105320)
  • torch._numpy: remove noops and half-implemented nan-functions (#107596)
  • Wrap ndarray dunder methods (#107689)
  • Pass torch.compile mode/options to all backends (#99645)
  • Update pre_dispatch tracing: support autocast and no_grad/enable_grad ctx managers, add a pre_dispatch_eager dynamo backend (#103024)
  • Preserve CreationMeta when metafying views (#103152)
  • Preserve mark_dynamic when cloning inputs (#99617)
  • Prevent GraphArg from keeping real tensors live (#100515)
  • Propagate mark_dynamic in dynamo compiled outputs (#99634)
  • Properly avoid wrapping numbers as tensors before backend (#96193)
  • Properly parenthesize dynamo_dynamic_indices test (#99823)
  • Properly respect automatic dynamic config for unspec int (#103321)
  • Raise warning if user has hooks installed on the module (#94848)
  • Record caller frame instead of function frame (#96882)
  • Resolve InlinedClosureVariable in InstructionTranslator stack (#106491)
  • Rewrite size/stride/numel TensorVariable handling (#103438)
  • Simulate torch function enablement state (#105091)
  • Simulate tracing tree_map_only (#104815)
  • Simulate treespec flattening/unflattening (#101896)
  • Skip if curr_size is None (#101170)
  • Support CUDA stream passed from outside of torch.compile decorator (#94627)
  • Support getattr for ConstantVariable when compiling with Dynamo (#98153)
  • Support module dict iter (#99503)
  • Support threading.local getattr (#104292)
  • Support unary not on lists (#102210)
  • Support wrapping + returning tensor subclasses (#104802)
  • Trace through Tensor slots (#107159)
  • Turn on add_runtime_assertion by default (#102671)
  • Tweak dynamic=False behavior (#105715)
  • Update XLA dynamo backend name (#106489)
  • Update exir.pass_base to use export.pass_base (#106647)

Misc dynamic shapes fixes

  • Add API to mark input tensors static for cudagraphs (#107154)
  • Add invariant that all symbolic shapes must be bound in graph (#99089)
  • Add support for Inductor + symbolic shapes + training (#93059)
  • Add symbolic tracing support to torch._dynamo.export (fake input + weights) (#100017)
  • Add unbacked symbol support (#98877)
  • Always create ShapeEnv, always apply unspec logic (#103302)
  • Do not mutate SymNode expressions. (#107492)
  • Do not track parameters, do not generate guards (#98350)
  • Add dynamic range constraint API (#98779)
  • Enable dynamic shapes of torch.nn.Parameter (#105855)
  • Further improve symbolic shapes logging (#99159)
  • Group constraints by arg (#102096)
  • Guard static shapes alongside tensors, instead of from shape_env, in dynamic_shapes=True (#99566)
  • Make hash_storage work with size 0/1 storage (#100467)
  • Make unspecified ints range over both negative and positive values (#104658)
  • Propagate dynamic int on __setitem__ (#105923)
  • Remove redundant dynamic_dim (#107815)
  • Support bit shifting SymInts (#104318)
  • Switch dynamic_shapes to True by default (#103597)
  • Warn if guards are added to ShapeEnv after we produced guards (#97820)
  • Don't specialize when indexing by SymInt (#99123)
  • Fix specialization when you pass an unspec int into slicing on a Python list. (#104142)
  • Flag guard unbacked SymInt/SymFloat support (#94987)
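Several of the fixes above interact with the dimension-marking APIs; a minimal sketch (assuming the default configuration) of opting a dimension into dynamic compilation:

```python
import torch
import torch._dynamo as dynamo

@torch.compile
def f(x):
    return (x * 2).sum()

x = torch.randn(8)
# Mark dim 0 as dynamic so a single graph with a symbolic size is compiled,
# instead of recompiling each time the size along that dimension changes.
dynamo.mark_dynamic(x, 0)
f(x)
f(torch.randn(16))  # can reuse the dynamic graph rather than recompiling
```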

Benchmark related bug fixes

  • Add a flag to benchmarks script to keep the test report directory (#96398)
  • Fix amp in inference in benchmarking suite (#103220)
  • Handle new inference csv from CI for benchmarking (#98294)
  • Small operatorbench improvements (#103110)

Export related bug fixes

  • Add aot_export (#101490)
  • Add get buffer from exported program (#107809)
  • Add support for edge dialect ops in exir/serde (#106371)
  • Change torch._dynamo.export(aten_graph=...) to allow pre_autograd tracing (#98031)
  • Error on closed over variables (#99367)
  • Enable dynamo export to export identity function (#94962)
  • Error when constraining on static values (#101655)
  • ExportedProgram (#102259)
  • Fix soundness bug with unsupported constraints (#102897)
  • Fix specify_constraints signature for exporting module (#101831)
  • Improve error message for IO mismatch (#107907)
  • Make serializer more composable (#104816)
  • Persist torch.assert in aten graph (#100101)
  • Preserve `meta["val"]` on export (#95314)
  • Raise error on 3.11 dynamo export (#95088)
  • Refactor and add same_signature flag to dynamo.export (#106569)
  • Refactor dynamic dims api, stateless internals, higher level export API (#96699)
  • Remove eliminate_dead_code (#105875)
  • Remove fake_mode arg from torch._dynamo.export API (#106345)
  • Remove setter for graph_module (#106651)
  • Suggest constraints to specify for export based on generated shape guards (#98463)
  • Support list output for HigherOrderOperators (#101986)
  • Support runtime assertion for inline constraints (#100763)
  • Integrate torch.ops.call_delegate into the delegate workflow (#92562)
  • Wrap more constraint violation cases to UserError (#100897)

Logger bug fixes

  • Rename sym_shapes logger to dynamic (#99335)
  • Raise a NameError when accessing non-existent variable (#96418)
  • Convert logging f-strings to use % format, part five (#98765)
  • Enable passing a dict of module name: log level pairs to the set_logs Python API (#98989)
  • Expose function to retrieve list of registered loggers (#100776)
  • Remove unnecessary check when logging artifacts (#99260)
  • Revamp guard debug logging (#107505)
  • Add assert + test for artifact log booleans (#104907)
  • Add fast traceback utilities (#107358)
  • Add graph break logging option instead of config flag (#103202)
  • Add verbose_guards logging artifact (#107388)
  • Do not use unicode quotes (#99446)
  • Elevate cudagraphs failure to warning, added lineno to recompiles (#105081)
  • Generate error on bad input to equality constraint (#107311)
  • Fix outdated log settings in doc (#102285, #102286)
  • Make DimConstraints create actionable message (#100103)
  • Make sure log tests are run in non-verbose mode (#106496)
  • Report guard failures with recompiles logging (#105500)
  • Update error message with torch logging instructions (#102892)
  • Fix typo in settings regex logging (#97245)
  • Improve TORCH_LOGS settings error msg (#97264)

Minifier related bug fixes

  • Add --check-str support to after_aot minifier (#104758)
  • Teach requires_bwd_pass how to interpret int (#98312)
  • Add --offload-to-disk support to minifier (#100546)
  • Improve minifier printing to be more chatty when it makes sense (#100486)
  • Make run_fwd_maybe_bwd work with int inputs (#99365)
  • Misc accuracy improvements on minifier (#100447)
  • Print AOT Autograd graph name when accuracy failed (#99366)
  • Relax after_aot restriction on no buffers, serialize small constants (#100472)
  • Cast copied model rather than update the original model (#101901)

Inductor

  • Skip triton configs for mm_plus_mm that may crash triton (#96385)
  • Avoid fusion with indirect indexing (#96273)
  • Use 64-bit indexing for large tensors in triton codegen (#97447)
  • Make aten.constant_pad_nd always do a copy even when padding is 0 to have consistent behavior (#100082)
  • Make argmin/max handle duplicate values in a way consistent with eager (#99920)
  • Handle negative padding in reflect_pad_backward (#100923)
  • Make torch.sign return the same type as input (#101346)
  • Only reuse aliased buffers if there are no more users (#100332)
  • Fix a number of issues with divs in ValueRangeAnalysis (#100547)
  • Prevent pattern matches across mutation ops in inductor pre-grad FX passes (#101144)
  • Avoid caching stale inner_fn_str/ReadWrites objects (#106502)
  • Correctly infer dtype of full (#95593)
  • Avoid zero division error for dropout (#100222)
  • Fix multi output layout error in indexing dtype calculation (#108085)
  • Bug fixes for the CPU backend
    • Fix compilation issues on pre clang-10 (#103347)
    • Fix compiler error when trying to vectorize logit_and and logit_or (#95361)
    • Properly handle 3D tensor for Conv2d ( #99601)
    • Fix reduction crash caused by storing float value to bfloat16 (#102719)
    • Properly handle broadcasting for bfloat16 (#104319)
    • Fix compilation for TIMM mobilevit_s model (#100230)
  • Bug fixes for the AMD backend
    • Triton wheel support enabled in non-ROCm environments (#95142)
    • Conditionalise triton mm/conv configs on ROCm to mitigate crashes (#107584)
  • Dynamic shape related bug fixes
    • Disable persistent reductions with dynamic shapes since persistent reduction relies on tensor shapes (#98405)
    • Turn off divisible_by_16 for dynamic shapes (#98471)
    • Make philox_rand_like work with dynamic shapes (#95461)
    • Handle int/float arguments for cpp codegen in inductor (#95533)

JIT

  • Mark torch.cuda._exchange_device op as having side effects (#96364)
  • Fix jit.trace codegen for out-variants on ops with more than one output (#101563)
  • Make NNC compatible with LLVM 15-17 (#96762, #98811, #101396, #103824)
  • Fix errors found by fuzzing and sanitizers (#94815, #101400, #102156, #103667, #103969, #106041, #103327, #94300)
  • Fix handling of >32-bit scalars on 32-bit platforms in NNC (#97669)
  • Fixes for NNC’s variable handling and serialization on big-endian systems (#96951, #95881, #104249)
  • Ignore meta-device tensors instead of erroring when loading a model with a target device (#100495)
  • Add overloads for _unsafe_index_put, _unsafe_index (#104127)
  • Match eager result from torch.round in NNC codegen (#104430)
  • Fix lifetime of JITException binding (#106401)
  • Fix annotation handling for subclasses in python >= 3.10 (#104485)

Misc

  • Stride bugfix: add overflow check for stride calculation (#94900)
  • Set SavedVariable.is_output to true for grad_fn->result_ (#105504)
  • Handle tail 0-size tensor appropriately in MultiTensorApply (#100811)
  • Fix UntypedStorage pin error (#104355)
  • Fix validate_input_col for nn.Module or Callable (#96213)
  • Fix segmentation fault in flatbuffers when parsing malformed modules (#95221)
  • Fix TorchScript support in as_nested_tensor (#97960)
  • Reintroduce s390x SIMD support (#99057)

Performance

General

  • Avoid copies in matmul (#76828)
  • Improve the precision of abs() and sign() for large values (#99550)
  • Fuse ops in eager cosine_similarity while keeping the stability and the gradients (#104771)
  • Add scalar conversion using avx instructions for half (#102140)
  • enable Half for cat serial kernel (#96021)
  • Re-enable AVX512 ATen kernels for compute-intensive ops (#104165)

torch.optim

torch.nn

  • Optimize reflection padding performance on CPU (#102254)
  • Optimize replication padding performance on CPU (#102255)
  • Improve precision and performance for BFloat16 upsampling (#91169)

Sparse

Improved performance in the following:

torch.compile

  • Implement CSE for guards (#98488)

Distributed

Distributed (c10d)

  • Enable store_barrier only on the ranks that are part of the process group, rather than the whole world, to make process group initialization scalable (#99931)

Distributed Tensor (Prototype Release)

FullyShardedDataParallel:

CUDA

  • Speed up bincount and histc on CUDA (#97090)
  • Speed up indexing_backward_kernel with duplicates (#100505)
  • Speed up torch.cat on contiguous tensors with wide loads (#102815)
  • Speed up LossCTC (#97269)
  • Speed up prefix scan algorithm (#103314, #103435, #103502)
  • Speed up vectorized_layer_norm (#107287)

Intel

  • Improve mkldnn matmul performance when one input is a contiguous tensor but its strides are not the default contiguous strides (#99511)

MPS

  • Add a NEON-accelerated implementation of erf() (#105610)
  • Add encoder coalescing support for native kernels (#99810)
  • Add PipelineStateObject caching for advanced indexing kernels (#99855)
  • Squeeze last dimensions, if possible, for 5D (or bigger) reductions to map them to optimal 4D implementation (#99856)

Vulkan

  • Pad channels when using texture storage instead of "tight packing" (#95251)
  • Introduce a GPU Memory Layout qualifier to allow more efficient memory layouts when storing Tensors (#106978)

ONNX

  • Improve diagnostics performance (#99936, #96348)
  • Don't duplicate model weights in ONNX export (#101134)
  • Reduce exporter memory usage by removing intermediate values (#101148)
  • TorchScript ONNX exporter:
    • aten::relu6: avoid unnecessary Relu operation (#99022)

Inductor

  • Match decomposed attention patterns and replace them with the eager implementation; this improves performance since the eager implementation may use flash attention, which performs comprehensive fusions (see the sketch after this list) (#97741, #100609, #107578)
  • Matmul padding (#101913, #102200, #103600)
  • Fuse type casting with triton matmul kernel (#106443, #106516, #107495)
  • Improve loop ordering to generate more coalesced memory access (#106827)
  • Enable persistent reductions (#94847, #102444)
  • Layout optimization for convolution (#99773)
  • Improve max-autotune (#95554, #95555, #96410, #97219)
  • Coordinate descent tuning: perform a coordinate descent search to find promising triton configs that improve performance (#99594, #99403, #103660)
  • Inductor Freezing (#100652)
  • Horizontally Fuse Addmm for inference (#100746)
  • Avoid unnecessary copy (#102089)
  • Convert layout of conv weight to channels last ahead of time for inference (#103642)
  • Performance improvement for CPU backend: Support mkldnn packed linear to improve bfloat16 performance (#96954)
  • Performance improvement for dynamic shapes:
    • Support multilayer reduction (#99475, #101915, #106747)
    • Apply divisible_by_16 flag in more cases for vectorized load and store (#105743)
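As referenced in the first item of this list, here is a minimal sketch of the kind of hand-written (“decomposed”) attention pattern that can be rewritten to the fused eager SDPA kernel; shapes and tolerances are illustrative:

```python
import math
import torch
import torch.nn.functional as F

def decomposed_attention(q, k, v):
    # The decomposed pattern: matmul -> scale -> softmax -> matmul.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return scores.softmax(dim=-1) @ v

q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))

# Compiled, the pattern above may be replaced with the fused SDPA kernel,
# i.e. the equivalent of the direct call below.
out = torch.compile(decomposed_attention)(q, k, v)
ref = F.scaled_dot_product_attention(q, k, v)
torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
```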

Release Engineering

  • Add workflow for quick perf comparison for inductor (#96166)
  • Run inference accuracy and performance tests with bfloat16 for inductor (#103535)
  • Add DALLE2_pytorch to inductor benchmarking workflow with AMP fallback (#104283)
  • Run the inductor benchmark suite with dynamic batch only (#97912)

torch.export

  • Speed up export time by avoiding calling the callable during export. (#107249)

JIT

  • Improve load times by reducing the number of times re-indexing occurs (#102312)
  • Skip the source info in the error report if the source code is too large (#105608)

Documentation

CUDA

  • Fix torch.cuda.mem_get_info doc (#96621)

DataPipe

  • Add generated docstring to functional form DataPipe (#100503)
  • Update docstring for functional form of DataPipes (#100446)

torch.fx

  • Update fx.pass.graph_drawer usage doc to draw fx graph (#95534)
  • Add doc test in graph_drawer.py (#95919)
  • Add docs for writing ATen IR passes + FX Pattern matching (#100577)
  • Update torch.fx docs (#97058)
  • Fix typos under torch/fx directory (#97596)

torch.export

  • torch.export landing page (#108783)

Intel

  • Add best practices doc for CPU backend (#105051)

Linear Algebra

  • Fix docs rendering in linalg.{matrix_exp, ldl_factor}. (#101363, #99777)
  • Fix examples in linalg.tensorinv. (#105911)
  • Improve error message for crashes related to linalg.eigh when input matrix is ill-conditioned, in some cusolver versions (#107082)

optim

  • Document optimizer state_dict() better with an example (#105958)
  • Have SGD summary show up in optimizer overview (#107738)

Python Frontend

Quantization

  • Fix `disbale` and other typos (#95322)
  • Fix return values of _get_name() in quantized ConvTranspose (#97678)
  • Fix docs for prepare_fx/prepare_qat_fx (#105979)
  • Error when someone calls train/eval on pre_autograd graph (#108143)
  • Move dropout replacement to move_model_to_eval (#108255)
  • Fix and rename move_model_to_eval to move_exported_model_to_eval (#109027)

Inductor

  • Improve Discoverability of Inductor Optimizations (#95824)

Release Engineering

  • Fix doc-rendering error and deprecate CircleCI docs scripts (#105678)

Dynamo

  • Add a RST doc for the performance dashboard (#100592)
  • Small doc update for torch_compile_debug (#95809)
  • Logging documentation updates (#100595)
  • Move Dynamo IPEX backend to training/inference category (#108643)

nn_frontend

  • Remove future deprecation warning from kl_div docs (#96541)
  • Fix the docs for cosine_similarity (#104772)
  • Correct HingeEmbeddingLoss documentation (#95140)
  • Fix docstring for shape of target for MultiLabelSoftMarginLoss (#107817)
  • Document differing behavior of RReLU between training and evaluation (#95624)

ONNX

  • Remove API reference for TorchScript export diagnostics (#107979)
  • Refactor torch.onnx documentation (#108379)

Distributed

FullyShardedDataParallel

  • Update the doc to clarify that the per-device NCCL stream is per process group (#95705)
  • Re-add the explanation of why we register the post-backward hook only on the first forward in the case of multiple forwards (#95326)
  • Clarify CPU offload implicitly in reshard_doc (#98666)
  • Document optim_state_dict_config in method (#102657)
  • Document get_state_dict_type (#102658)

Distributed (c10d)

  • Fix typos in comments under torch/csrc/distributed (#96062)
  • Update isend/irecv warning messages for nccl (#95236)
  • Add warning about object-based collectives for GPU tensors to docs. (#97702)

Distributed Checkpoint

  • Fix documentation for distributed checkpointing for optimizers (#95264)
  • Add fsdp checkpoint example (#95258)
  • Update DCP doc to use the updated FSDP optim state_dict APIs (#95303)
  • Update documentation to read FileSystemReader instead of FileSystemLoader (#102795)
  • Add documentation for HSDP saving using DCP (#104810)

RPC

  • Add missing RRef docs for RPC (#106902)

Sparse Frontend

  • Improve error message when expand is called on sparse tensor (#98365)

Composability

Dynamic Shapes

  • Update dynamic shapes documentation (#109764)

Dynamo

  • Add docs for torch.compile(numpy) (#109710)
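As a quick illustration of what those docs cover, a minimal sketch of compiling NumPy code (the function and array names are arbitrary):

```python
import numpy as np
import torch

@torch.compile
def numpy_fn(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    # Traced NumPy calls are translated into their PyTorch equivalents.
    return np.sum(x * y, axis=1)

x = np.random.randn(8, 4).astype(np.float32)
y = np.random.randn(8, 4).astype(np.float32)
print(numpy_fn(x, y).shape)  # (8,)
```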

Developers

torch.fx

  • Fix typos in torch/fx/_compatibility.py (#97618)
  • Add torch/utils/_stats.py to stack frame skiplist (#98117)
  • Add pre_autograd kwarg to make_fx (#97559)
  • Revert torch.fx.interpreter error printing change (#101462)
  • Fix pytree error formatting string (#105935)
  • Assume SymPy is always installed (#94903)
  • Add more error checking to the minifier (#103057)
  • Refactor unwrap_proxy() for proxy tensor tracing (#104667)
  • Enable ruff's UP rules and autoformat dynamo / functorch and refs (#105432)
  • Enable flake8-simplify checks (#97984)

Inductor

  • Allow overriding the decomposition table in the compile_fx API (#95468)
  • Allow saving the parameters for compiling a graph and replaying them later to improve development efficiency (#106952)
  • Support benchmarking kernel perf to gather metrics like latency and memory bandwidth (#95355, #95506, #95845, #96458, #96461, #97057, #103547)
  • Tracking operator count (#100329)
  • Print the path to the generated wrapper code with TORCH_LOGS=output_code (#99038)
  • Provenance tracking for wrapper code (#105717, #95901)
  • Support inductor OSS perf dashboard (#95685, #99387, #99754, #105221)

Composability

  • A number of improvements make it easier for custom backends to integrate out-of-tree as a PyTorch eager-mode backend through the PrivateUse1 DispatchKey (see the sketch at the end of this section):

    • Allow privateuse1 key to be used with legacy constructor (#95748)
    • Add Generator register for the privateuse1 backend (#93920)
    • Optimize the AMP func name in custom_device_mod (#98052)
    • Enable dispatch stub for backend PrivateUse1 (#99611)
    • Support random for custom device (#97420)
  • Nvfuser python API import fix (#94036)

  • Add ability to create library fragments (#98439)

  • Core aten IR:

    • Tag functions to core IR in native_functions.yaml (#105849)
    • Add _native_batch_norm_legit_no_training to core IR (#107732)
    • Make python decomp for native_batch_norm CompositeImplicitAutograd, remove native_batch_norm from core aten opset (#107791)
    • Avoid extra copies in batchnorm decomposition inference by introducing a new op, _native_batch_norm_legit_no_training (#94946)
    • Add aten.smooth_l1_loss_backward to core_aten_decompositions (#100267)
    • Add empty/empty_like to core aten decomps (#105158)
  • Fixed missing-prototypes warnings in torch_cpu (Part 1) (#100053)

  • Fix typos in checkFloatingOrComplex errors (#102456)

  • Allow existing "Python RAII guards" to be used as context managers (#102579)

  • Replace _prims_common.check with torch._check* (#103240)

  • Update core aten decomp table (#105673)

  • Generate mypy hints for torch.Tag, add a couple of pointwise ops (#106910)

  • aot_autograd: avoid using intermediate_base logic unnecessarily (#97786)

  • Fix disable amp for runtime wrapper (#97864)

  • aot_autograd: more logging on metadata asserts (#99177)

  • Proper handling when outputs are aliased but have identical size/stride/offset metadata (#100430)

  • Fix de-dupping metadata computation bug (#100431)
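
Returning to the PrivateUse1 item above, a minimal sketch of the Python-side hookup for such an out-of-tree backend; the backend name "my_device" is hypothetical, and a real integration would also register kernels for the PrivateUse1 key from C++:

```python
import torch
import torch.utils

# Rename the PrivateUse1 dispatch key so the backend shows up as "my_device"
# in device strings and in generated Tensor/Module helper methods.
torch.utils.rename_privateuse1_backend("my_device")

# Device strings now parse against the renamed backend; running actual ops
# still requires kernels registered for PrivateUse1 in a C++ extension.
print(torch.device("my_device", 0))  # my_device:0
```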

Release Engineering

  • Rename default branch to main (2418b94)
  • Use PyTorch wheel in Windows CI (#94958)
  • Use GPU machine and run GPU tests with Bazel builds (#95721)
  • Enable simpler C++ test discovery + running workflow on CI with run_test.py (#99956, #99559)
  • Run C++ test_api binary directly in CI slow jobs (#101088)

Autograd Frontend

  • Add a macro for derivatives formulas that returns multiple outputs and can be specified to save certain tensors conditionally (#103750)
  • Fix torch._C.get_current_graph_task_execution_order accumulate_grads ordering (#105353)
  • Make torch.autograd._force_original_view_tracking work as both a context manager and a function (#106706)
  • Enable autograd to be compiled (#103822, #104316)

JIT

  • Create public interface for torch.jit to reduce pyright errors (#101678)

optim

  • Change step from 1D to singleton tensor in Adam (#96994)

ONNX

  • Remove torch dependencies in _beartype (#98958)
  • Delay torch.onnx import to after all dynamo subcomponents (#99070)
  • Enable xdoctests in CI (#98546)
  • Update ONNX submodule from ONNX 1.13.1 with Protobuf 4.21 updates (#96138)
  • Run ONNX tests as part of standard run_test script (#99215)
  • Skip flaky dynamic tests before ORT==1.15 in fx exporter (#98856)
  • Add additional_test_kwargs into test_fx_to_onnx_with_onnxruntime.py (#99434)
  • Bump onnx-script version with imported module renaming (#99926)
  • Add test_fx_op_consistency.py (#99465)
  • Refactor test_op_consistenct.py and test_fx_op_consistency.py (#100172)
  • Add xfail into subtests of op consistency and retire fixme (#100173)
  • Skip flaky dynamic test in CI (#100297)
  • Add supported ops into test_fx_op_consistency - 1st batch (#100265)
  • Bump onnx submodule to release 1.14.0 (#101809)
  • Bump ORT version to 1.15.0 (#102248)
  • Add FX exporter MaxPool tests (#102773)
  • FX Dispatcher Test (#103971)
  • Bench torch.onnx.dynamo_export and torch.onnx.export under dynamo bench (#103135)
  • Separate fx _type_utils from torchscript exporter (#103942)
  • Use load_model_from_string (#104533)
  • Enable ruff's UP rules and autoformat onnx/ (#105427)
  • Suppress ORT warnings in unit tests (#105624)
  • Add comment on test_view_dynamic_zero_dim (#105950)
  • Bump numpy from 1.21.6 to 1.22.0 in /benchmarks/dynamo/_onnx (ab9ea0d)
  • Enable skipped gpt2 test (#94930)
  • Clean up outdated skip ort < 1.15 decorator in tests (#105951)
  • Add test support for dynamic shapes for torch.onnx.dynamo_export (#106495)
  • Update xfail reasons in fx runtime tests (#107257)
  • Add unittest for exporting embedding_bag (#105862)
  • Add huggingface models into CI tests (#107247)

Distributed

FullyShardedDataParallel

  • Log FSDP mixed precision (#97367)
  • Add logging of when modules’ FSDP hooks fire (#102508)
  • Print out more useful error message for optim_state_dict (#96860)
  • Use INFO instead of DETAIL for warning logs (#102639)
  • Add a summary log when finishing state_dict (#103784)

Distributed (c10d)

  • Fix sandcastle_skip_if decorator name is confusing (#95649)
  • Add sequence number in PG Wrapper (#97462)
  • Print collective in PG Wrapper (#97544)
  • Add diff capability in PG Wrapper (#100214)
  • Add TDD, NCCL_DEBUG log (#97692)
  • Don't crash on retrieve NCCL DesyncReport (#98470)
  • Print stacktrace on collectFullMesh for Gloo (#98810)
  • Add fully_shard debug function to print sharded tree structure, module names and managed param fqns (#99133)
  • Enhance error msg in PG Wrapper (#100213)
  • Make ProcessGroupNCCL work.wait() respect timeout (#100162)
  • Add size info to collective logs (#100413)
  • Add the logic and interface to log ProcessGroup comms configuration (#104373)
  • Make NCCL default logging more friendly. (#105695)
  • Add OnCompletion Hook to ProcessGroup (#106988, #107233)
  • Improve input mismatch error msg (#107281)

DistributedDataParallel

  • Add debug logging around DDP mixed precision copies (#96438)

Distributed Tensor (Prototype Release)

  • Add necessary logging to APIs and components for PTD use cases such as DTensor, TP and DCP (#101994, #102209, #102278)

Sparse Frontend

  • Expand sparse.softmax zero nnz tests to cover cases of previously reported FPE (#95646)
  • Use nested namespaces in sparse (#97581)
  • Fix cuSparse CSR SPMM when using nullptr in csrRowOffsets (#105957)
  • Remove CUTLASS extensions merged upstream (#107612)
  • Fixes for reference and move (#95942)
  • Triton kernels without public API
    • Use missing-prototypes in torch_cpu (#104138)
    • SDPA: Support frontend for BSR masks (#104042)
    • sampled_addmm: Support BSR (#101163)
    • softmax: Support Triton kernel for BSR inputs (#102095)

Security

Release Engineering

  • Move mergebot and other CI/CD workflows to its own secure environment (#107060)