
PyTorch 1.12: TorchArrow, Functional API for Modules and nvFuser, are now available


PyTorch 1.12 Release Notes

  • Highlights
  • Backwards Incompatible Changes
  • New Features
  • Improvements
  • Performance
  • Documentation

Highlights

We are excited to announce the release of PyTorch 1.12! This release is composed of over 3124 commits made by 433 contributors. Along with 1.12, we are releasing beta versions of AWS S3 Integration, PyTorch Vision Models on Channels Last on CPU, Empowering PyTorch on Intel® Xeon® Scalable processors with Bfloat16 and FSDP API. We want to sincerely thank our dedicated community for your contributions.

Summary:

  • Functional Module API to functionally apply module computation with a given set of parameters
  • Complex32 and Complex Convolutions in PyTorch
  • DataPipes from TorchData fully backward compatible with DataLoader
  • Functorch with improved coverage for APIs
  • nvFuser, a deep learning compiler for PyTorch
  • Changes to float32 matrix multiplication precision on Ampere and later CUDA hardware
  • TorchArrow, a new beta library for machine learning preprocessing over batch data

Backwards Incompatible changes

Python API

Updated type promotion for torch.clamp (#77035)

In 1.11, the ‘min’ and ‘max’ arguments in torch.clamp did not participate in type promotion, which made it inconsistent with minimum and maximum operations. In 1.12, the ‘min’ and ‘max’ arguments participate in type promotion.

1.11

>>> import torch
>>> a = torch.tensor([1., 2., 3., 4.], dtype=torch.float32)
>>> b = torch.tensor([2., 2., 2., 2.], dtype=torch.float64)
>>> c = torch.tensor([3., 3., 3., 3.], dtype=torch.float64)
>>> torch.clamp(a, b, c).dtype
torch.float32

1.12

>>> import torch
>>> a = torch.tensor([1., 2., 3., 4.], dtype=torch.float32)
>>> b = torch.tensor([2., 2., 2., 2.], dtype=torch.float64)
>>> c = torch.tensor([3., 3., 3., 3.], dtype=torch.float64)
>>> torch.clamp(a, b, c).dtype
torch.float64

Complex Numbers

Fix complex type promotion (#77524)

Updates the type promotion rule so that, given a complex scalar and a real tensor, the value type of the real tensor is preserved.

1.11

>>> a = torch.randn((2, 2), dtype=torch.float)
>>> b = torch.tensor(1, dtype=torch.cdouble)
>>> (a + b).dtype
torch.complex128

1.12

>>> a = torch.randn((2, 2), dtype=torch.float)
>>> b = torch.tensor(1, dtype=torch.cdouble)
>>> (a + b).dtype
torch.complex64

LinAlg

Disable TF32 for matmul by default and add high-level control of fp32 matmul precision (#76509)

PyTorch 1.12 makes the default math mode for fp32 matrix multiplications more precise and consistent across hardware. This may affect users on Ampere or later CUDA devices and TPUs. See the PyTorch blog for more details.
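
If you want the previous, faster TF32 behavior back on Ampere or later GPUs, the release adds a high-level precision knob alongside the existing backend flag. The snippet below is a minimal sketch, assuming the torch.set_float32_matmul_precision helper and the torch.backends.cuda.matmul.allow_tf32 flag described in the linked blog post:

# Minimal sketch (assumes PyTorch 1.12 on a TF32-capable GPU).
import torch

# High-level control: "highest" (default, full fp32 precision);
# "high" or "medium" allow lower-precision math (e.g. TF32) for fp32 matmuls.
torch.set_float32_matmul_precision("high")

# Backend-specific flag for CUDA matmuls; now defaults to False.
torch.backends.cuda.matmul.allow_tf32 = True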

Sparse

Use ScatterGatherKernel for scatter_reduce (CPU-only) (#74226, #74608)

In 1.11.0, unlike scatter which takes a reduce kwarg or scatter_add, scatter_reduce was not an in-place function. That is, it did not allow the user to pass an output tensor which contains data that is reduced together with the scattered data. Instead, the scatter reduction took place on an output tensor initialized under the hood. Indices of the output that were not scattered to were filled with reduction inits (or 0 for options ‘amin’ and ‘amax’).

In 1.12.0, scatter_reduce (which is in beta) is in-place to align with the API of the related existing functions scatter/scatter_add. For this reason, the argument input in 1.11.0 has been renamed src in 1.12.0 and the new self argument now takes a destination tensor to be scattered onto. Since the destination tensor is no longer initialized under the hood, the output_size kwarg in 1.11.0 that allowed users to specify the size of the output at dimension dim has been removed. Further, in 1.12.0 we introduce an include_self kwarg which determines whether values in the self (destination) tensor are included in the reduction. Setting include_self=True could, for example, allow users to provide special reduction inits for the scatter_reduction operation. Otherwise, if include_self=False, indices scattered to are treated as if they were filled with reduction inits.

In the snippet below, we illustrate how the behavior of scatter_reduce in 1.11.0 can be achieved with the function released in 1.12.0.

Example:

>>> src = torch.arange(6, dtype=torch.float).reshape(3, 2)
>>> index = torch.tensor([[0, 2], [1, 1], [0, 0]])
>>> dim = 1
>>> output_size = 4
>>> reduce = "prod"

1.11

>>> torch.scatter_reduce(src, dim, index, reduce, output_size=output_size)
tensor([[ 0., 1., 1., 1.],
        [ 1., 6., 1., 1.],
        [20., 1., 1., 1.]])

1.12

>>> output_shape = list(src.shape)
>>> output_shape[dim] = output_size
# reduction init for prod is 1
# filling the output with 1 is only necessary if the user wants to preserve the behavior in 1.11
# where indices not scattered to are filled with reduction inits
>>> output = src.new_empty(output_shape).fill_(1)
>>> output.scatter_reduce_(dim, index, src, reduce)
tensor([[ 0., 1., 1., 1.],
        [ 1., 6., 1., 1.],
        [20., 1., 1., 1.]])

torch.nn

nn.GroupNorm: Report an error if num_channels is not divisible by num_groups (#74293)

Previously, nn.GroupNorm would error out during the forward pass if num_channels is not divisible by num_groups. Now, the error is thrown for this case during module construction instead.

1.11

m = torch.nn.GroupNorm(3, 7)
m(...)  # errors during forward pass

1.12

m = torch.nn.GroupNorm(3, 7)  # errors during construction

nn.Dropout2d: Return to 1.10 behavior: perform 1D channel-wise dropout for 3D inputs

In PyTorch 1.10 and older, passing a 3D input to nn.Dropout2d resulted in 1D channel-wise dropout behavior; i.e., such inputs were interpreted as having shape (N, C, L) with N = batch size and C = # channels, and channel-wise dropout was performed along the second dimension.

1.10

x = torch.randn(2, 3, 4)
m = nn.Dropout2d(p=0.5)
out = m(x)  # input is assumed to be shape (N, C, L); dropout along the second dim.

With the introduction of no-batch-dim input support in 1.11, 3D inputs were reinterpreted as having shape (C, H, W); i.e. an input without a batch dimension, and dropout behavior was changed to drop along the first dimension. This was a silent breaking change.

1.11

x = torch.randn(2, 3, 4)
m = nn.Dropout2d(p=0.5)
out = m(x)  # input is assumed to be shape (C, H, W); dropout along the first dim.

The breaking change in 1.11 resulted in a lack of support for 1D channel-wise dropout behavior, so Dropout2d in PyTorch 1.12 returns to 1.10 behavior with a warning to give some time to adapt before the no-batch-dim interpretation goes back into effect.

1.12

x = torch.randn(2, 3, 4)
m = nn.Dropout2d(p=0.5)
out = m(x)  # input is assumed to be shape (N, C, L); dropout along the second dim.
            # throws a warning suggesting nn.Dropout1d for 1D channel-wise dropout.

If you want 1D channel-wise dropout behavior, please switch to use of the newly-added nn.Dropout1d module instead of nn.Dropout2d. If you want no-batch-dim input behavior, please note that while this is not supported in 1.12, a future release will reinstate the interpretation of 3D inputs to nn.Dropout2d as those without a batch dimension.
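
For reference, the snippet below is a minimal sketch of the recommended nn.Dropout1d replacement for 1D channel-wise dropout:

import torch
import torch.nn as nn

x = torch.randn(2, 3, 4)     # interpreted as (N, C, L)
m = nn.Dropout1d(p=0.5)
out = m(x)                   # channel-wise dropout along dim 1, no warning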

F.cosine_similarity: Improve numerical stability (#31378)

Previously, we first compute the inner product, then normalize. After this change, we first normalize, then compute inner product. This should be more numerically stable because it avoids losing precision in inner product for inputs with large norms. Because of this change, outputs may be different in some cases.

Composability

Functions in torch.ops.aten.{foo} no longer accept self as a kwarg

torch.ops.aten.{foo} objects are now instances of OpOverloadPacket (instead of functions), with their __call__ method implemented in Python. This means that you can no longer pass self as a kwarg; pass it as a positional argument instead.

1.11

>>> torch.ops.aten.sin(self=torch.ones(2))
    tensor([0.8415, 0.8415])

1.12

# this now fails
>>> torch.ops.aten.sin(self=torch.ones(2))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: __call__() got multiple values for argument 'self'
# this works
>>> torch.ops.aten.sin(torch.ones(2))
tensor([0.8415, 0.8415])

torch_dispatch now traces individual op overloads instead of op overload packets (#72673)

torch.ops.aten.add actually corresponds to a bundle of functions from C++, covering all of the overloads of the add operator (specifically, add.Tensor, add.Scalar and add.out). Now, __torch_dispatch__ will directly take in an overload corresponding to a single aten function.

1.11

class MyTensor(torch.Tensor):
    ...
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Before, func refers to a "packet" of all overloads
        # for a given operator, e.g. "add"
        assert func == torch.ops.aten.add

1.12

class MyTensor(torch.Tensor):
    ...
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # After, func refers to an individual operator overload,
        # e.g. "add.Tensor"
        assert func == torch.ops.aten.add.Tensor
        # you can recover the old behavior with func.overloadpacket
        assert func.overloadpacket == torch.ops.aten.add

Profiler

Disable forward-backward correlation (#72904)

The forward-backward correlation is no longer captured, in order to work around a profiler crash. This feature may be re-enabled in a future release once the underlying issue is fixed.

with torch.profiler.profile() as p:
    loss = model(inputs)
    loss.backward()  # Invoke autograd

# The exported chrome trace will not have forward-backward flow events. (arrows)
p.export_chrome_trace(...)

Mobile

Remove support for bytecode version 3 (#57775)

The minimum supported bytecode version is being bumped from 3 to 4. We no longer support version 3 bytecode models because the bytecode version was bumped from 3 to 4 more than half a year ago, and there was code in operator loading that behaved differently for one operator under global bytecode version 3.

If your model was generated before Oct 5, 2020, please use the following code to update it to the latest version:

1.12

import torch
from torch.jit.mobile import _get_model_bytecode_version, _load_for_lite_interpreter

old_model_path = "old_model.ptl"
new_model_path = "new_model.ptl"

# Load full jit model
jit_model = torch.jit.load(old_model_path)
# Save model for mobile 
jit_model._save_for_lite_interpreter(new_model_path)
# Verify the model can be loaded
mobile_m = _load_for_lite_interpreter(new_model_path)

# Get bytecode version from the new model
bytecode_version = _get_model_bytecode_version(new_model_path)
print(f"bytecode version is {bytecode_version}")

Distributed

Remove redundant FSDP prefix and change default auto wrap policy name to avoid confusion (#76858, #73791)

FullyShardedDataParallel's optional parameter fsdp_auto_wrap_policy (1.11) has been renamed to auto_wrap_policy (1.12), and default_auto_wrap_policy (1.11) has been renamed to size_based_auto_wrap_policy (1.12).

For example, when wrapping a model with FSDP:

1.11

model = MyModel()
wrapped_model = FullyShardedDataParallel(
    model,
    fsdp_auto_wrap_policy=functools.partial(
        default_auto_wrap_policy,
        min_num_params=0,  # wrap all modules
    ),
    ...
)

1.12

model = MyModel()
wrapped_model = FullyShardedDataParallel(
    model,
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=0,  # wrap all modules
    ),
    ...
)

Quantization

TorchScript models exported prior to PyTorch 1.6 using quantized Linear, GRU and LSTM operators will no longer work (#72680, #72522)

TorchScript models created with PyTorch 1.5 or earlier and using the operators quantized::linear_prepack_legacy, linear_prepack_fp16_legacy, quantized::linear_unpack.legacy, or quantized::linear_unpack_fp16.legacy will no longer work and need to be re-exported. Please use PyTorch Quantization to quantize the Linear module instead.
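
As a rough migration sketch (not taken from the release notes), a Linear module can be dynamically quantized with the current eager-mode API and then re-exported with TorchScript:

import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 8))
# Dynamically quantize the Linear module with the supported API.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
# Re-export the model instead of loading a pre-1.6 TorchScript archive.
scripted = torch.jit.script(quantized_model)
scripted.save("quantized_linear.pt")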

ONNX

Infra (Releng)

  • Bump minimum CMake version to 3.13 (#76312)

Deprecations

Python API

Deprecated torch.testing.make_non_contiguous (#72705)

torch.testing.make_non_contiguous is being deprecated and will be removed in a future release. Depending on the use case, there are different replacement options. If you are using make_non_contiguous in the PyTorch test suite, you can use torch.testing._internal.common_utils.noncontiguous_like instead.

1.11

a = torch.randn(1, 2, 3)
torch.testing.make_non_contiguous(a)

1.12

a = torch.randn(1, 2, 3)
torch.testing._internal.common_utils.noncontiguous_like(a)

If you are using make_non_contiguous in combination with a creation function to create a noncontiguous tensor with random values, you can use make_tensor.

1.11

a = torch.randn(1, 2, 3)
torch.testing.make_non_contiguous(a)

1.12

torch.testing.make_tensor(..., noncontiguous=True)

If you are using make_non_contiguous with a specific tensor, you can use torch.repeat_interleave

1.11

a = torch.tensor([[1., 2.], [1., 2.]])
torch.testing.make_non_contiguous(a)

1.12

a = torch.tensor([[1., 2.], [1., 2.]])
torch.repeat_interleave(a, 2, dim=-1)[..., ::2]

Build

LinAlg

Deprecate torch.lu (#73804)

torch.lu() is deprecated in favor of torch.linalg.lu_factor() and torch.linalg.lu_factor_ex(). torch.lu() will be removed in a future PyTorch release. If you were previously using get_infos=False (this is the default), you should use torch.linalg.lu_factor instead:

1.11

LU, pivots = torch.lu(A, compute_pivots) 

1.12

LU, pivots = torch.linalg.lu_factor(A, compute_pivots) 

If you were previously using get_infos=True you should use torch.linalg.lu_factor_ex:

1.11

LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)

1.12

LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) 

Deprecate torch.lu_solve (#73806)

torch.lu_solve() is deprecated in favor of torch.linalg.lu_solve(). torch.lu_solve() will be removed in a future PyTorch release.

1.11

X = torch.lu_solve(B, LU, pivots)

1.12

X = torch.linalg.lu_solve(LU, pivots, B) 

Remove deprecated torch.solve (#70986)

torch.solve, which was deprecated in a previous release, is now being removed. You should use torch.linalg.solve instead. Note that torch.linalg.solve has its arguments reversed and does not return the LU factorization. To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.

1.11

X = torch.solve(B, A).solution

1.12

X = torch.linalg.solve(A, B)

torch.nn

nn.Module: Deprecate positional args for state_dict() (#72780)

state_dict can currently be called in two ways: destination, prefix, and keep_vars can be passed as positional arguments, or as kwargs. The ability to do the former is being deprecated and will be removed in a future release. You should pass the arguments in as kwargs only.
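
A minimal sketch of the two calling conventions:

import torch

m = torch.nn.Linear(2, 2)

# Deprecated: positional arguments, e.g. m.state_dict(destination, "net.", False)
# Preferred: keyword arguments only
sd = m.state_dict(prefix="net.", keep_vars=False)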

Composability

Deprecated __torch_function__ as instance method for more functions (#74829)

__torch_function__ should be defined as a classmethod. Defining __torch_function__ as a plain method was already deprecated for the functions that handle __torch_function__ in Python. This change makes that also the case for functions that handle __torch_function__ in C++.

1.11

class Bad():
    def __torch_function__(self, *args, **kwargs):
        pass
t = Bad()
torch.abs(t)

1.12

class Good():
    @classmethod
    def __torch_function__(cls, *args, **kwargs):
        pass
t = Good()
torch.abs(t)

Quantization

Deprecate torch.jit.quantized (#72690)

Instead of using functions defined in torch.jit.quantized, please use PyTorch Quantization to dynamically quantize Linear/RNNCell/LSTMCell/GRUCell/LSTM modules. This is supported in both Eager Mode Quantization and FX Graph Mode Quantization.

1.11

>>> torch.jit.quantized.QuantizedLSTMCell(...)

1.12

>>> torch.jit.quantized.QuantizedLSTMCell(...)
    "torch.jit.QuantizedLSTMCell is deprecated and will be removed in an upcoming
     PyTorch release. Please use the torch.nn.quantized.dynamic.LSTMCell instead."

Infra (Releng)

  • Removed CUDA 11.1 binary builds (#73376)
  • Removed CUDA 11.5 binary builds (#76257)

New features

Python API

  • Added a new device mps that can be used to leverage GPU acceleration on the macOS platform with Apple Silicon (M1) or discrete AMD GPUs; see the sketch after this list. (blogpost with details)
  • Added torch.special.log_ndtr (#74795)
  • Added torch.distributions.transforms.{SoftplusTransform,CumulativeDistributionTransform} (#52300, #72495)
  • Promoted torch.testing to stable (#73348)
  • Added maximize flag for optim.Adadelta (#75330)
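
The snippet below is a minimal sketch of using the new mps device; it assumes a macOS build of PyTorch 1.12 with MPS support:

import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
    x = torch.ones(5, device=device)   # tensor allocated on the Apple GPU
    y = (x * 2).to("cpu")              # move results back to the CPU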

Build

  • Distributed torchgen as part of PyTorch package (#76306)
  • Added BUILD_LAZY_CUDA_LINALG option (#73447)
  • Introduced an environment variable to change c10 log level (#71746)

Complex Numbers

  • Added a new data-type torch.complex32 to enable computation with complex numbers at lower memory usage, at the cost of lower precision. Note that this is an experimental feature (#78245); the major focus in this release was supporting operators under torch.fft on CUDA. Besides those operators, we have added support and testing for the following limited set of ops (note: a few operators are only supported on CUDA): Tensor.copy_, torch.complex, torch.testing.make_tensor, cat, Tensor.fill_, Tensor.item, torch.atleast_1d, torch.atleast_2d, torch.atleast_3d, torch.dsplit, torch.vsplit, torch.hsplit, torch.hstack, torch.dstack, torch.vstack, Tensor.conj, torch.add, torch.sub, torch.mul, torch.sub, torch.div, torch.view, torch.view_as, torch.real, torch.imag, torch.neg, Tensor.__getitem__, torch.sum, torch.prod, torch.abs, torch.sgn, torch.exp, torch.log, torch.eq, torch.masked_fill, torch.index_put, torch.rand, torch.randn, torch.full, torch.empty, torch.ones, torch.zeros, torch.block_diag, Tensor.chunk, Tensor.clone, Tensor.contiguous, torch.diag_embed, torch.diagonal, torch.as_strided, torch.column_stack, Tensor.T, Tensor.H, Tensor.mT, Tensor.mH, Tensor.narrow, torch.isfinite, torch.isinf, torch.isreal, torch.flatten, Tensor.chalf, torch.empty_like, torch.movedim (#73847, #74667, #74854, #75010, #75156, #75311, #75498, #76132, #76158, #75592, #76615, #77179, #77339, #77446, #77483, #77479, #77192, #76724, #77404).
  • Operators in torch.fft now support tensors with torch.complex32 dtype (CUDA only) (#74857).
  • torch.complex32 tensors now participate in type promotion (#76893)
  • Added torch.chalf alias for torch.complex32 and Tensor.chalf method (#75320).
  • Added proper print support for torch.chalf tensors (#76614).
  • Added support for complex convolution (data-types supported: torch.complex32, torch.complex64, torch.complex128); see the sketch after this list
    • torch.nn.functional.conv1d and torch.nn.Conv1d (#75310)
    • torch.nn.functional.conv2d and torch.nn.Conv2d (#75412)
    • torch.nn.functional.conv3d and torch.nn.Conv3d (#75581)
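
The snippet below is a minimal sketch of a complex-valued 1D convolution, assuming complex support is available in your build:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 8, dtype=torch.complex64)   # (N, C_in, L)
w = torch.randn(4, 2, 3, dtype=torch.complex64)   # (C_out, C_in, kernel_size)
out = F.conv1d(x, w)                              # complex64 output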

LinAlg

  • Added torch.linalg.ldl_factor_ex and torch.linalg.ldl_solve (#69828)
  • Added linalg.vander (#76303)
  • Added linalg.lu (#67833)
  • Added linalg.lu_solve (#72935)

Meta API

  • Added meta tensor kernels for the following operators:
  • Enabled the ability to register Python decompositions for operators as meta kernels; added meta support for where and huber_loss (#77353)
  • Registered meta functions through Python for dot/group_norm/instance_norm/var_mean/index_reduce/matmul/bernoulli/adaptive_avg_pool (#77499), index_select/abs/min/max (#76916), reflection_pad2d (#77681), square (#77682), log_sigmoid_forward (#77739), several more ops (#77362)

torch.nn

  • nn.Dropout1d: New module for 1D channel-wise dropout (#79545)
  • nn.Module: Public API for stateless / functional module computation (#75834); see the sketch after this list
  • nn.Module: Support for hooks that run after state dict loading (#76823, #77392)
  • Added support for tensor subclasses as parameters (#73459, #77655)
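
The snippet below is a minimal sketch of the stateless / functional module API mentioned above (torch.nn.utils.stateless.functional_call):

import torch
from torch.nn.utils.stateless import functional_call

m = torch.nn.Linear(3, 3)
# Parameters to use for this call only; m's own parameters are left untouched.
new_params = {"weight": torch.zeros(3, 3), "bias": torch.zeros(3)}
x = torch.randn(1, 3)
out = functional_call(m, new_params, (x,))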

torch.fx

  • Core
    • Allowed Tracer to record usages of Buffers (#73612)
    • Introduced experimental MetaTensorTracer (#76003)
    • Gave Tracer the ability to trace different forward functions (#77502)

Composability

  • Many features, improvements and fixes to Python tensor subclasses based on __torch_function__ and __torch_dispatch__
    • Added __torch_function__ mode, which allows you to override the meaning of all __torch_function__ overrideable functions within a dynamic scope. (#75154)
    • Added enable_torch_dispatch_mode, which allows nesting of different __torch_dispatch__ modes. (#75965)
    • Added a default implementation of __torch_dispatch__ (#73684)
    • Added support for super().__torch_dispatch__ with an arguments list (#74509, #74720)
    • Miscellaneous __torch_function__ fixes (#75484, #75110)
    • Added __torch_function__ override protocol support to some factory functions (#75639)
    • Fixed propagation of warnings when using __torch_dispatch__. (#74357)
    • Removed spurious warning when using disabled torch function (#75826)
    • Added the ability to snapshot TLS for “has-a” use cases of __torch_dispatch__ (#72623, #74577)
    • Fixed serialization and deep copying for wrapper subclasses (#73078)
    • Allowed is_contiguous() to be overridden in __torch_dispatch__ (#77906)
  • Added a "functionalization" program transform that can be used to remove mutation and aliasing ops from PyTorch programs while maintaining program semantics. Currently, while most of the logic for the pass lives in core, the pass is exposed as an API through functorch. You can run it with functorch.experimental.functionalize(). Example usages can be found here; a short sketch also follows this list. (#75913, #76083, #76084, #73442, #77285, #73441, #75302, #75818, #75819, #76125, #76318, #77358)
  • Added a new torch.library API to allow users to override kernels for existing C++ ops through Python (#75905, #76892)
  • Allowed creating new libraries and defining new operators from Python (#76250, #77690)
  • Added experimental API’s for registering and looking up Python decompositions for many aten operators: from torch._decomp import register_decomposition, get_decompositions. (#76311, #76814)
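
The snippet below is a minimal sketch of the functionalization transform, assuming the functorch version paired with PyTorch 1.12 is installed:

import torch
from functorch.experimental import functionalize

def f(x):
    y = x.clone()
    y.add_(1)        # in-place op on an intermediate
    return y

f_functional = functionalize(f)     # mutation-free version of f
out = f_functional(torch.zeros(3))  # same result, no in-place ops underneath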

Sparse

  • Added factory functions for sparse CSC, BSR, and BSC tensors (#76634, #76623, #75946, #75961, #75831, #76651)
  • Added ccol_indices and row_indices methods for CSC and BSC tensors. (#77503)
  • Added to_sparse_csc with support for 2D Strided and 2D CSC input (#77521); see the sketch after this list
  • Added to_sparse_bsr with support for 2D CSR input (#77366)
  • Added index_reduce (#76997, #75981, #76296)
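
The snippet below is a minimal sketch of the new CSC layout, assuming to_sparse_csc is exposed as a Tensor method as described above:

import torch

dense = torch.tensor([[0., 1.], [2., 0.]])
csc = dense.to_sparse_csc()
# Compressed column indices, row indices and values of the CSC tensor.
print(csc.ccol_indices(), csc.row_indices(), csc.values())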

CUDA

  • Added Jiterator support when dtype is complex for sigmoid, exp, sqrt, rsqrt, log, log10, log2, addcmul, abs, addcdiv, sgn, neg, logical_and, angle (#73643, #73776, #73781, #74160, #74161, #74533, #74455, #74827, #74814, #74863, #75123, #76692)
  • Added Jiterator support when dtype is complex for the backward of sigmoid and tanh (#76289, #74948)
  • Added Jiterator support for kaiser_window and prod (#73734, #75231)
  • Enabled simple reductions with Jiterator (#75231)
  • Updated to cuDNN v8 API with cuDNN benchmark, convolution bwd / transposed convolution fwd, bfloat16, conv-bias-activation fusion (#60755)
  • Added Python Interface for Jiterator (#76394)
  • Added Jiterator with Python Registration (#77121)
  • Prepared Jiterator code template for multiple outputs (#77902)
  • For CUDA graphs, added torch.cuda.is_current_stream_capturing (#77789)

Vulkan

  • Added Vulkan support for Gated Recurrent Units (torch.nn.GRU) (#72692, #73599)
  • Added Vulkan support for the linear interpolation op (torch.lerp) (#76544)

Profiler

  • Added support for both global (experimental) and thread-local profiling (#75525, #76078, #76239)

Mobile

  • Added support for different memory formats of Tensors in NNC (#72873)
  • Upgraded mobile model bytecode to V9 and provided backporting to previous versions (#71662)

Distributed

JIT/TorchScript

Quantization

  • Added oneDNN quantization backend (#69820)
  • Added oneDNN quant backend (#74137)

ONNX

Infra (Releng)

  • Added support for ROCm 5.0 (#72895)
  • Added LibTorch builds for ROCm (#57506)
  • Added support for CUDA 11.6 (#75518)

Improvements

Python API

  • Improved numerical stability of torch.distributions.wishart.Wishart (#72993)
  • Added mode property to torch.distributions.Distribution (#76690)
  • Added foreach flag for torch.optim.{Adadelta, Adagrad, Adamax, Adam, ASGD, NAdam, RAdam, SGD, RMSprop, Rprop, AdamW} (#69980, #69981, #69982, #70295, #70481, #70229, #70230, #70231, #70482, #70483, #70484)
  • Added out variant for torch.softmax and torch.log_softmax (#75833)
  • Added handling for r=0 case for torch.combinations (#70270)
  • Added XPU support for torch.autocast (#75250)
  • Added support for Tensor source for .set_(storage, offset, size, strides) (#77007)
  • Changed to register torch.return_types.* as pytree nodes (#75915)
  • Added typing for torch.return_type (#74199)
  • Set correct module for APIs in the torch module (#75801)
  • Improved NotImplementedError verbosity for torch.distributions.kl_divergence (#72845)
  • Added maximize flag to torch.optim.Adagrad (#75968)
  • optim.{Adagrad, Adam, Adamax, AdamW, RAdam}: Updated step in functional optimizers and pass state_steps instead of state (#71333)
  • Improved torch.lerp numerical precision by doing intermediate math in opmath_t (#76062)
  • Changed to alias torch.finfo.tiny to torch.finfo.smallest_normal (#76292)

C++ API

  • Added catch for overflows in calculating storage byte size for col2im (#73719)
  • Implemented center padding for stft (#73432)

Autograd

  • Added forward AD support for torch.{atan2, dist, logsumexp, log_softmax, norm, polar, put, softmax} (#73741, #74205, #75027, #75326, #77421)
  • Added forward AD support for torch.nn.functional.{cross_entropy, pairwise_dist, nll_loss, normalize} (#73741, #74205)
  • Added forward AD support for torch.cholesky_inverse (#75033)
  • Added forward AD and forward-over-reverse support for FFTs (#75326)
  • Added forward AD support for torch.nn.functional.{embedding,prelu, bilinear, rrelu, logsigmoid} (#77421)
  • Added forward AD support for torch.nn.BCELoss (#77755)
  • Added forward AD support for Tensor.__rsub__ (#75326)
  • Added forward AD support for torch.clamp when bounds are tensors (#74042)
  • Added forward AD support for torch.nn.functional.{dropout, glu}(#75288, #77186)
  • Added forward-over-reverse for torch.nn.functional.{leaky_relu, glu, elu, selu, celu} (#75294, #77309, #75297)
  • Improved forward and backward derivatives of torch.{linalg.cholesky, cholesky} (#76032)
  • Improved forward and backward derivative of torch.linalg.qr (#76115)
  • Added complex autograd support for torch.cholesky_inverse (#75033)
  • Added double backward support for torch.nn.functional.binary_cross_entropy wrt target (#77416)
  • Improved error message for torch.nn.functional.batch_norm when running_{mean,var} have forward grad defined (#73655)
  • Improved error message when forward AD is not supported (#75105)
  • Added forward AD and forward-over-reverse support for torch.nn.functional.max_unpool (#68625)
  • Added autograd support for masked_softmax (#71502)

Build

  • Fixed pybind deprecation warnings (#72376)
  • Enabled win-arm64 (#72424)
  • Moved magma utils to its own header (#73058)
  • Turned on -Wsign-compare (#74996)
  • Made all .pyi.in files exportable from torch/_C/ folder (#74962)
  • Updated Jinja2 for the docs/cpp build to version 3.0 (#74718)
  • Added CMake option for using static MKL libraries (#73069)
  • CPU Kernel: Changed to use per-operator headers (#71137)
  • CUDA Kernels: Changed to use per-operator headers (#71212)

Dataloader

  • Added pin_memory_device to DataLoader to pin Tensors to the corresponding GPU device (#65402); see the sketch below
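
A minimal sketch of the new argument, assuming a CUDA build (pinning only happens once the loader is iterated):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(8, 3))
loader = DataLoader(ds, batch_size=4, pin_memory=True, pin_memory_device="cuda")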

ForEach

  • Improved numerical precision for ForEach L1 and L2 norm by using OpMathType tensor for intermediate results (#68107)

Meta API

  • Changed to skip superfluous storage allocations while constructing meta tensors (#65331)

torch.nn

  • Made nn.init.orthogonal_ no-op for empty input (#75553)
  • nn.{Conv1d, Conv2d, Conv3d}: Added support for complex datatypes (#75310, #75412, #75581)
  • nn.Conv2d: Added bfloat16 support for mkl-dnn backend (#55864)
  • nn.Conv2d: Added support for channels last memory format on CPU for mkl-dnn backend, naive algorithm, and dilated algorithm (#55584, #68101, #70665)
  • nn.EmbeddingBag: Added half precision support on CPU (#74844)
  • nn.FractionalMaxPool*d: Added support for 0s in out_size (#73634)
  • nn.Module: Changed to throw error for non-dict inputs to load_state_dict() (#77197)
  • nn.{PixelShuffle, PixelUnshuffle}: Added support for channels last memory format (#50573)
  • nn.PReLU: Enabled fp32/bfloat16 forward and backward for mkl-dnn backend (#60427)
  • F.elu: Improved numerical precision by using opmath and expm1 (#77062)
  • F.{hardshrink, hardsigmoid, hardswish, logsigmoid, smooth_l1_loss, softplus, softshrink}, nn.{BatchNorm, GLU, Upsample}: Added bfloat16 support on CPU (#62558, #63134, #77496, #61944, #76935)

torch.fx

  • FX/graph_drawer
    • Added args/kwargs and users (#73464)
    • Added skip_node_names_in_args option, default to True (#73815)
  • Core
    • Refactored FX codegen into an extensible Codegen object (#72566)
    • Modified replace_all_uses_with to allow filtering of nodes to update (#73763)
    • Made immutable_list and immutable_dict work with pytrees (#73766)
    • Added an assert for None concrete_args and improved error messages (#74662)
  • In minimizer, made args work in the uru10x10_to_trt_eval script (#74707)
  • For split_module, changed to return mapping of qualified names from split_module() (#73564)
  • For shape propagation, made shapes and args/kwargs concrete for minimizer (#75291)

Sparse

  • Added CUDA support for scatter_reduce (#74606,#74607)
  • Added 2D Strided, 2D CSR, 2D CSC, 2D COO support to to_sparse_csr (#77521)
  • Added ND Strided, 2D CSC support to to_dense (#74486, #77521)
  • Added 2D CSC support to to_sparse (#73642, #77521)
  • Added support for batched CSR to sparse_csr_tensor (#74542)
  • Added support for __str__ for CSC, BSR, and BSC tensors (#77530, #76650)
  • Updated transpose to return CSC when given CSR (#77615)
  • Added support for CSR gradients for CSR tensors (#75435)
  • Added CSC support to addmm, addmv, mm (#77615)
  • Added autograd for CSR inputs to torch.sparse.sampled_addmm (#68084)
  • Added autograd for CSR inputs to torch.sparse.addmm and torch.sparse.mm (#76591)
  • Added Half/BFloat16 support for to_dense and coalesce methods. (#72397)
  • Added CSR support to mul (#74266, #77177)
  • Added CSR support to sum (#74766)
  • Added BSR support to addmm, addmv, triangular_solve (#77255)
  • Added batched CSR support to torch.sparse.sampled_addmm on CUDA (#77243)
  • Added CSR support for torch.sparse.sampled_addmm on CPU (#76589)
  • Added CSR support to torch.select (#76228)
  • Added CSR support to Tensor.to (#76400)
  • Added CSC support to torch.empty (#77508)
  • Added CSC, BSR, BSC support to torch.clone (#77512)
  • Added CSC, BSR, BSC support for copy_ (#77605)
  • Added (Strided, CSR) input support to torch.mm (#73686)
  • Added CSR support to torch.sparse.mm (#73075)
  • Added (Strided, CSR, CSR) support to addmm on CPU (#73076)
  • Added runtime beta support warning to CSR, CSC, BSR, BSC tensors (#75594, #75865)
  • Added bool support to coalesce and to_dense (#74495)
  • Added half support to sparse_mask (#76862)
  • Added AMD Navi 21 support to coalesce (#73548)

AMD

  • Enabled atomicAddNoRet() for all gfx targets. (#75451)
  • Enabled miopen for RNNs with dropout. (#75429)
  • Used ncclAllToAll for ROCm (#75128)
  • Navi21 Enablement: fix TI num_threads for ROCm, Depthwise kernels, Embedding kernels, Normalization kernels, Softmax kernels, Tensor kernels, Index, Repeat and Sort kernels, Range and Multinomial Kernels (#69942, #72682, #72809, #73543, #73545, #73546, #73549, #73550)
  • Added ROCm version api within CMake (#69481)
  • Enabled sort operator BF16 support (#72854)
  • Enabled HIP IPC (#74383)
  • Enabled topk operator for bfloat16 dtype (#71913)
  • Added HIP_HOME/include.lib in cpp_extensions (#75548)

CUDA

  • Added support for NVTX range_start and range_end (#70030)
  • Show a friendly error message when init is forgotten in torch.cuda (#72404)
  • PyTorch GPU Allocator: better use of blocks with rounding of allocation sizes (#74213)
  • CUDACachingAlloc/GPUInference: implemented garbage collection without GPU sync (#74261)
  • CUBLAS/TF32: added environment variable to allow override of allow_tf32_cublas (#77114)

Intel

  • Bfloat16
    • Added BFloat16 support for torch.{nn.PReLU, nn.Upsample, nn.GLU, randperm, multinomial, poisson, nn.ELU, nn.SELU, nn.CELU, nn.LogSigmoid, nn.Hardsigmoid, nn.Hardshrink, nn.Softshrink, nn.Hardswish, nn.Softplus, nn.SmoothL1Loss, histc, atan2, logcumsumexp, diag, fmod, cumsum, cumprod, nn.utils.weight_norm, nn.BatchNorm2d} and allowed these ops with autocast enabled (#63634, #58297, #61944, #63215, #62546, #63134, #72694, #61897, #73845, #74410, #68725)
      • Improved torch.nn.functional.log_softmax on CPU when dim != -1 on both float32 and bfloat16 (#64726)
      • Improved torch.nn.functional.layer_norm bfloat16 performance on CPU (#71376)
      • Improved autocast cpu documentation (#68567)
  • Channels last
    • Added channels-last support for torch.nn.{Conv2d (kernels slow_conv_dilated2d and thnn_conv2d, mkldnn as backend), GroupNorm, PixelShuffle, PixelUnshuffle} (#70665, #68101, #55584, #50573, #555864)
  • OneDNN
    • Upgraded oneDNN to v2.6.0, (#75398)
    • Added JIT graph fuser for oneDNN Graph API (v0.5) (#76622)
  • Quantization
    • Improved {qcat_nhwc, qupsample_bilinear2d, qupsample_nearest2d, qbatch_norm2d, qmax_pool2d, qavg_pool2d} performance on multi-core (#69667, #69601, #69600, #69599, #69598, #69517)
    • Added oneDNN as a backend for quantization (#69820)
  • Improved torch.{norm, argmax, argmin, scatter, gather} performance on CPU (#64479, #64478)
  • Improved torch.nn.functional.{log_softmax, softmax} performance on CPU (#73953)
  • Expanded graph rewrite to handle conv_transpose3d (#76888)
  • Expanded coverage of convolution folding in conv→mul→add→bn (#75724)
  • Added MKLDNN support for PReLU (#60427)

Composability

  • Added torch.nn.init to list of functions overridable by __torch_function__ (#76014)
  • Relaxed dtype restrictions on torch.Tensor(#73850)

Profiler

  • Enabled iteration tracking for kineto (#72292)
  • Added support for input sequence ID tracking for NVTX profiler (#70264)
  • Re-enabled user-annotations in PyTorch (#75601)
  • Added support to configure Kineto CUPTI profiler from PyTorch profiler interface (#75616)

Vulkan

  • Added an interface to obtain execution time data for GPU shader kernels when executing Vulkan operators (#75829)

Mobile

  • Improved Android instrumentation test and update README (#72736)
  • Improved unsupported scalar type error message for Android (#74660)

JIT/TorchScript

  • torch.jit.trace now treats tensor.numel() as aten::numel, instead of a constant value (#74081)
  • When printing out the types of a JIT Dict, with a tuple key, we now print out the types of the tuple if it is simple (#76164)
  • Added basic op support for complex numbers in JIT. We now support op(complex, Tensor) for the following: add (+), mul (*), eq (==), ne (!=), sub (-), div (/) (#73286)
  • TorchScript now preserves the original exception message when rethrowing a Python-based exception (#77093)
  • Modified the conditions for conv folding in torch.jit.freeze to allow for folding arguments that can be promoted to floating point (eg integer tensor arguments) (#73278)
  • Reduced size of JIT debug.pkl files by only storing unique traces (#76688)
  • torch.jit.save and torch.jit.load are now supported for meta tensors (i.e., tensors on the "meta" device) (#73435)

Architecture Optimization

  • Added default symmetric qconfig for QNNPACK (#74396)

Quantization

  • Core (Quantized Tensor, Operator, Modules)
    • Added QAT fused Linear-Bn1d (#72431, #72796)
    • Added 4 bit support for embedding quantized module (re-land PR 69769) (#72276)
    • Enabled slicing on per-channel quantized tensors (support is limited to contiguous sliced tensors) and corresponding test case (#71269)
    • Added qint32 quantization support (#72472)
    • Added explicit entries for functional and module conv and linear support into get_default_qconfig_dict and get_default_qat_qconfig_dict (#73528)
    • Added default symmetric QAT qconfig for QNNPACK (#74507)
    • Added Quantized Matmul Op (Naive Implementation) (#71783)
    • Added Quantized Softmax Op (Naive Implementation) (#75415)
    • Using QNNPACK in Quantized Softmax Op (#75799)
  • Eager Mode Quantization
    • Added 4 bit support for eager mode quantization flow (#72277)
  • FX Graph Mode Quantization
    • Added workflow support for torch.matmul quantization (#72444)
    • Added support conv1d and its fusion variants in QAT (#74506)
    • Decoupled prepare_*fx from training/eval modes (#75401)
    • Added quantized Softmax workflow integration (#75106)
    • Renamed default_affine_fixed_qparams_observer and default_symmetric_fixed_qparams_observer (#76637)

ONNX

  • Updated default opset_version to 13. The previous default was 9. To get the old behavior, just specify opset_version=9 when calling torch.onnx.export (a short sketch follows this list). Going forward we plan to update the default regularly to "latest as of 18 months ago". (#73898)
  • De-duplicated initializers to reduce ONNX model size for shared parameters (#69547, #74247)
  • Changed to capture annotated attributes for local function (#72883)
  • Improved error and warning messages (#71342, #73255, #73770, #73265)
  • Added support for exporting torch.minimum with different dtype combinations (#76022)
  • Improved Expand shape inference (#72985)
  • Added broadcast to matmul shape inference (#72990)
  • Rewrote linspace symbolic to improve numerical stability (#73610)
  • Enabled topk export with non-int64 k (#73761)
  • Enabled numel tracing (#74081)
  • Added constant folding for onnx::ReduceProd (#74082)
  • Added support for equality checks on devices (#77203)
  • Added support for dynamic dimensions in Squeeze and Unsqueeze (#73104)
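
A minimal sketch of pinning the export back to the old default opset, as mentioned in the first item above:

import torch

model = torch.nn.Linear(4, 2)
dummy_input = torch.randn(1, 4)
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=9)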

torch.package

  • Added Python Version to Torch.Package metadata (#74610)
  • Added utility for determining where bad modules may come from (#74998)

Distributed

  • torch.distributed
    • Refactored TORCH_DISTRIBUTED_DEBUG implementation (#73166)
    • Set default value of TCPStore world_size to None in pybind definition (#77277)
    • Added orthogonalization with QR factorization (#72043)
    • Added pickling support for WorkerInfo (#73371)
    • Added support for RRefs that contain threading.Thread (#74462)
    • Added check for mismatch in number of parameters in verify_params_across_processes (#74113)
    • Added support for backend to register reducer timer (#71700)
    • Made ProcessGroupNCCL load torch_ucc.so when TORCH_UCC_LIBRARY_PATH is set (#69552)
    • Added support for non-contiguous inputs for nn.functional.all_gather/reducescatter/gather (#75276)
    • Added the use of batched operations for PowerSGD (#76041)
    • Changed to create UCC ProcessGroup when ucc_lib available (#69564)
    • Changed to generalize param verification and broadcast (#76374)
    • Changed to use a more reliable signaling mechanism to stop TCPStore background threads (#76973)
    • Added support to disabling post-local gradient sync (#76723)
    • Removed call into Python API without GIL being held in c10d (#72928)
  • FullyShardedDataParallel
    • Fixed summon_full_params when not sharded (#72572)
    • Fixed 0-dim tensor optim state device (#75243)
    • Fixed the synchronization of all_gather stream in summon_full_params (#73314)
    • Added state_dict() save/reload in parity test (#73366)
    • Changed to use unflatten_parameter in _summon_full_parameters (#72467)
    • Changed to use summon_full_params in get_full_params (#73242)
    • Added generic arguments for state_dict (#73323)
    • Added generic argument forward for load_local_state_dict (#73325)
    • Made summon_full_params a public method (#73116)
    • Generalized fsdp_modules() (#73553)
    • Introduced a utility API to allow users easily to set state_dict_type (#73716)
    • Added an option to summon on rank 0 only in summon_full_params (#73903)
    • Enabled offload full params to CPU in summon_full_params (#73904)
    • Removed _lazy_init() in rebuild full params (#74263)
    • Changed to override named_parameters() for clean names in summon_full_params() (#74333)
    • Changed to strip FSDP info in summon_full_params context, similar to named_params in named_buffers (#74517)
    • Change to use param name as key in full_optim_state_dict (#74879)
    • Enabled re-key between param names/IDs for full_optim_state_dict (#74912)
    • Changed to register state_dict hooks for FlatParamsWrapper even if params_list is empty (#74860)
    • Made apply_to_tensors support OrderedDict type (#75560)
    • Added rank0_only to full_optim_state_dict() (#75516)
    • Made summon_full_params a static method (#75423)
    • Added support for PackedSequence type for apply_for_tensors (#76265)
    • Made mixed precision API configurable (#76423)
    • Validated exec order using compute_device (#76664)
    • Improved dict inversion in _get_param_name_to_param to be faster(#76665)
    • Changed to ignore params if not in Optim state dict (#76671)
    • Changed to include buffers in ignored_modules (#76784)
    • Moved param/buffer name computation to constructor for ignored_modules (#76994)
    • Changed to not clone buffers and ensure that we offload buffers to CPU if specified (#77000)
    • Profiling range for FSDP.forward (#76899)
    • Disabled the default behavior of moving CPU module to GPU (#77720)
    • Fixed _get_param_to_unflat_param_names() for shared params (#75430)
  • ShardedTensor (prototype)
    • Changed to use absolute imports for ShardMetadata instead (#73678)
    • Fixed the metadata error in init_from_local_shards with deepcopy (#73400)
    • Fixed view op and matrix ops unit test (#77706)
  • torch.distributed.rpc
    • Improved logging from 'unknown destination worker' (#75811)
    • Improved logging for store.wait error (#76548)
    • Added support for RPC Meta device (#76882)
    • Changed to keep stacktrace when rewriting AttributeError (#73720)
  • DistributedDataParallel
    • Improved debug level and logging (#72455)
    • Removed bucket replicas (#73567)
    • Made HierarchicalModelAverager a subclass of averagers.ModelAverager (#74564)
    • Made code simplification for _find_process_group function (#75007)
    • Made distributed raise ImportError when not available (#75975)
  • torch.distributed.elastic
    • Created a final agent barrier to shutdown process properly (#74931)

Bug fixes

Python API

  • Fixed type promotion for torch.where (#76691)
  • Fixed torch.clamp to correctly propagate nans (#77306)
  • Fixed torch.unique to preserve input size when dim is zero-length (#75764)
  • Fixed torch.ravel to also return contiguous outputs for non-contiguous inputs(#71771)
  • Fixed CosineAnnealingLR to resume last learning rate on restart (#60339)
  • Fixed invalid shape error for torch.fft.{irfft2,irfft2} (#73012)
  • Fixed torch.set_default_dtype to no longer crash with invalid dtype (#72405)
  • Fixed torch.tril edge case (#75335)
  • Fixed torch.broadcast_shapes to not handle shapes with negative dimensions. (#72999)
  • Fixed torch.logsumexp integral to float type promotion (#77480)
  • Fixed torch.amax and torch.amin for empty tensors if dim arg not provided. (#73914)
  • Disallowed calling .tolist on tensors with nullptr storage (#75990)
  • Fixed .tolist to work correctly work for 0 element tensors (#76335)
  • Adjusted the stubs for PyCharm autocompletion of the Tensor methods. (#76712)
  • Fixed Optimizer.zero_grad type annotation (#76998)
  • Fixed torch.distributions.lkj_cholesky device error (#73980)
  • Fixed misplaced type annotation for torch.distributions.transforms.CatTransform (#73747)
  • Fixed torch.clamp scalar overloads to propagate nan (#77371)
  • Fixed advanced indexing assignment when use_deterministic_algorithms(True) for non-contiguous tensors (#76220)
  • Fixed **= operator (#76900)
  • Fixed Tensor.to to properly support permutation (#77610)

C++ API

  • Used the same checks in all grid_sampler functions (#75164)
  • Fixed mean bug for integral tensors (#76584)
  • Added missing import to fix crash on loading cpp extension (#75736)

Autograd

  • Fixed forward AD formula for torch.angle (#77267)
  • Fixed torch.{minimum, maximum} forward AD formula for float32 (#75277)
  • Fixed forward-mode AD formula for torch.nn.functional.binary_cross_entropy_with_logits (#76322)
  • Fixed gradients for norm related ops at zero when p < 1 to mask out nans (#75103)
  • Fixed forward-over-reverse for convolution to no longer fail in some cases (#75298)
  • Fixed torch.autograd.gradcheck to run with requires_grad=False when check_forward_ad=True (#72309)
  • Fixed requires_grad-ness to be propagated for all backends when tensors are deep-copied (#76256)
  • Fixed torch.autograd.grad handling of single outputs when is_grads_batched=True (#75779)
  • Updated forward AD metadata check to skip stride check when size is 0 (#77269)
  • Fixed a deadlock in an edge case in autograd (#73961)
  • Allowed forking until a worker thread is created in the autograd engine (#72689)
  • Removed some spurious warnings in the autograd engine (#72542)
  • Fixed issue with torch.utils.checkpoint.checkpoint when both use_reentrant and preserve_rng_state set to False (#76890)
  • Fixed Python indexing set-item with a scalar tensor to preserve the autograd graph (#78746)

Build

  • Added TORCH_CUDA_CU_API to CUDABlas functions (#72340)
  • Fixed doc build for release branches (#72567)
  • Moved AndroidNightly to GHA (#74243)
  • Changed numModules type to unsigned (#74978)
  • In Kineto, Changed to not search for CUPTI in default paths (#76188)
  • Changed to use TensorPipe libuv in Gloo (#77312)

Complex Numbers

  • Fixed segmentation fault when real and imaginary attributes of a tensor are set to a number (#73867)
  • Fixed complex-to-real casting warning in the backward pass for Real→Complex copy (#75805)
  • Made torch.addcmul and torch.addcdiv support different complex and non-complex type args together (#74234)
  • Fixed torch.isfinite for complex to avoid overflow when real and imaginary values are finite but abs is infinite (#76606).
  • Fixed complex abs/angle output format (#77585)

Dataloader

  • Reset worker cycle for persistent DataLoader to ensure determinism across epochs (#73675)

LinAlg

  • Fixed SVD error code handling for OpenBLAS 0.3.15+ and MKL 2022+(#72357)
  • Fixed addmm_cpu for int64 (#75200)

Meta API

  • Fixed meta kernel for normal_ when std is equal to 0 (#70085)
  • Fixed torch.kaiser_window : meta for window_length > 1 (#73733)
  • Fixed meta kernel for normal (#77740)

torch.nn

  • F.pad: Silence error when unused fill value is zero (#76307)
  • nn.{Conv1d, Conv2d, Conv3d}: Properly initialize grad_weight in raw_cudnn_convolution_backward_weight_out (#72157)
  • nn.Conv2d: Fix channels last propagation for naive algorithm (#77347)
  • nn.ConvTranspose*d: Fix to support no-batch-dim inputs with output_size (#76151)
  • nn.CrossEntropyLoss: Support no-batch-dim input with probability target (#77653)
  • nn.CrossEntropyLoss: Fix to avoid floating point exception for zero-size inputs (#73837)
  • nn.GroupNorm: Ensure num_groups > 0 in native_group_norm (#75270)
  • nn.MaxPool2d: Properly support dilation in channels last kernel (#76597)
  • nn.ParameterList: Fix __dir__ implementation (#74997)
  • nn.{ParameterList, ParameterDict}: Support containing any kind of object (#70499)
  • nn.RReLU: Fix to support empty tensor inputs (#70496)
  • nn.utils.rnn.pad_sequence: Fix regression; support tensor input for sequences (#72436)
  • nn.utils.stateless.functional_call: Properly support setting attributes during forward (#77137)

torch.fx

  • Core
    • Made map_aggregate/map_arg work for NamedTuple (#73198)
    • Fixed tracing for OpOverload (#73940)
    • Fixed codegen for bare generic type annotations (#74135)
    • Modified __deepcopy__ to also copy _codegen (#75851)
    • Fixed unnecessary recursion in GraphModule.__call__ (#76068)
    • Changed to prevent infinite recursion in GraphModule (#73866)
    • Changed to preserve codegen on FX graph in transformer (#74189)
  • operator_schemas
    • Added back check for OpOverload (#73978)
    • Fixed normalize_function to consider OpOverloads (#76469)
    • Fixed for normalizing signature for op overloads (#77182)
  • For testing, added super() calls for FX TestCases (#74216)
  • For split_module, made split_module preserve proper placeholder names (#74736)

Sparse

  • Fixed ignored beta value for sparse inputs to torch.addmm with non-MKL build (#72430)
  • Fixed float16/bf16 support for sparse inputs to torch.addmm (#72559)
  • Fixed CUDA error for torch.mul when given COO Tensors with zero sized dense dimensions (#73428)
  • Fixed incorrect results of torch.sparse.sampled_addmm for noncontiguous inputs (#76590)
  • Fixed runtime generation of doc strings for torch._masked functions by making them static instead (#72865)

CUDA

  • Created jiterator cache dirs recursively (#74592)
  • Fixed bincount to use acc scalar for the bounds (#76979)
  • Avoid collections deprecation warning (#72239)
  • Disabled cuBLASLt when batch is too large. (#73533)
  • Abated spurious resize warnings in MultiMarginLoss on CUDA (#75000)
  • Added missing AT_CUDA_CHECK in CUDAGraph.cpp (#74392)
  • CUDA graphs
    • Fixed OOM inside graph capture_begin (#76247)
    • Changed to allow Adam and AdamW to be capture-safe (#77862)

Intel

  • Fixed Caffe2 convolution issue in AVX512 when using oneDNN v2.5.2 (#73290)

Composability

  • Fixed formatting of scalar tensors for the meta device (don't call item) (#74376)
  • Fixed metadata preservation for Python tensor subclasses: preserve Python dispatch keys when copying tensor metadata (#75644)
  • Fixed data race on TensorImpl::wns_pyobj_ accesses with non-GIL protected threads (#75563)
  • Fixed an issue where the Python garbage collector could sometimes deallocate a tensor even when C++ still held strong references to it (#75933)
  • Added better error checking to TensorImpl::size_between_dim_. (#76719)
  • Changed to ensure that torch.memory_format instances are singletons (#77543)

Profiler

  • Avoided picking up old CUPTI headers (#72761)
  • Kineto submodule update and fixes (#75206)
  • Fixed segfault in AppendOnlyList (#78084)

Vulkan

  • Fixed a bug in the Vulkan implementation of aten::tanh where inputs of large magnitudes would result in numerically unstable results (#73107)
  • Fixed a bug in the Vulkan implementation of aten::add, aten::sub, aten::mul, and aten::div where passing in a single element tensor as a second argument would result in an assertion error (#73108)

Mobile

  • Changed to protect against threading errors when tracing models with parallel operators (#73327)
  • Changed to ensure error messages are preserved from Metal and CoreML Backend (#77430, #76236)
  • Changed to ensure the iOS test app is working correctly (#74090)
  • Fixed off-by-one error in tupleIndex (#72447)
  • Fixed error in export of models containing nested NamedTuple (#75996)

Distributed

  • torch.distributed
    • Fixed process group wrapper check for Gloo (#72657)
    • Changed to catch CUDA library runtime error (driver shutting down) during the exit of ProcessGroup (#74258)
    • Fixed NCCL version string (#73333)
    • Added retry for DNS lookup failures (#74641)
    • Validated that tensors are contiguous in ProcessGroupNCCL (#77809)
    • Fixed sign-compare in c10d/Utils.hpp (#75081)
    • Fixed NCCL gather outputs on non-root ranks (#75535)
    • Fixed batch_isend_irecv (#74701)
    • Disabled RPC profiling for kineto profilers (#76234)
    • Typo fix in generated module name (#76880)
    • Fixed broadcast for channels-last tensors (#79071)
  • DistributedDataParallel
    • Disabled bucketing for the first iteration (#72843)
    • Fixed SyncBatchNorm for empty inputs (#74944)
    • Added a guard for non CPU/CUDA devices (#75247)
    • Fixed bug where getstate of DDP looks for self._replicated_tensor_module when not using ReplicatedTensor (#76349)
    • Fixed post_localSGD_optimizer by calling optim.step only once when there are multiple param groups or params (#74737)
    • Fixed PostLocalSGDOptimizer and ModelAverager average (#74894)
  • ShardedTensor (prototype)
    • Fixed sharding spec inference to avoid invalid chunk sharding being inferred as ChunkShardingSpec (#75296)
  • FullyShardedDataParallel
    • Fixed no_sync() + FULL_SHARD root all-gather behavior (#75901)
    • Fixed exec order validation (static variable issue) (#76273)
    • Fixed local_state_dict and state_dict_type bugs (#77101)
    • Fixed FSDP wrapping for batchnorm when mixed precision enabled (#77234)
    • Fixed CheckpointWrapper state_dict to enable wrapped modules loaded into non-checkpointed wrapped module (#77224)
    • Changed to relax exec order validation to only the forward pass (#76556)
    • Changed to not check forward order in eval mode (#77195)
    • Changed to pass device_id into recursive_wrap for FSDP (#77491)

JIT/TorchScript

  • torch.jit.fuser("fuser1") is supposed to enable NNC fusion, but it previously only enabled GPU fusion; it now enables CPU fusion as well. (#74078)
  • Fixed a bug where a Python ternary expression (x if y else z) was not parsed into TorchScript by torch.jit.script as right-associative (#68416)
  • Got rid of the "TorchScript sparse tensor is experimental" warning. (#73874)
  • Custom post-processing passes registered through torch::jit::RegisterPass now have access to profiled Tensor Type Specializations (#71748)
  • When registering a custom print handler for prim::print() inside torch.deploy, we restore the default print handler when all Python environments are destroyed to prevent errors from not having a Python environment. (#74513)
  • When running torch.jit.freeze on the backward passes of conv (conv_bn) with reduced precision (eg bfloat16) , fusions will respect the precision of the original op, instead of promoting to float32 (#77042)
  • Loosened torch.jit.script type checks that were too strict for the torch.nn.LPPool2D and torch.nn.functional.lp_pool2d functions (#73287)
  • torch.nn.ParameterList is now subscriptable in TorchScript (#75479)

Quantization

  • Fixed get_module_type for fusion (#72735)
  • Fixed bug in QuantWrapper with DeQuant qconfig (#73671)
  • Fixed observer insertion through dtype propagation (#73274)
  • Only do reference module swapping for floating point fused modules (#74231)
  • Fixed dynamic weighted op lowering when input is used multiple times (#74364)
  • Fixed get_default_qconfig_dict for fused modules (#75838)
  • Fixed bug for average pooling in FX quantization (#73054)
  • Fixed FX QAT for untraceable modules (#74277)
  • Fixed qmin/qmax when using customized ‘qrange’ (#74717)

ONNX

  • Fixed repeat interleave when repeats and dim is 1 (#73760)
  • Fixed ONNX gather shape inference (#73607)
  • Fixed 1d case flatten export (#74595)
  • Fixed opset_version checked before set (#76928)
  • Fixed an assertion failure involving Slice (#72989)
  • Fixed LSTM reshape shape inference regression (#72532)
  • Fixed Caffe2 ONNX export for environment with newer ONNX (#75718)
  • Refactored test/onnx/test_onnx_export.py for better code reuse (#76851)
  • Fixed aten::to("cpu") and aten::to(device="cpu") (#76498)
  • Fixed BatchNormalization for invalid dtype (#74875)
  • Added Autocast support for einsum (#71916)

torch.package

  • Deploy: added dummy metadata for builtin packages (#76211)
  • Enabled module modification during repackaging (#71520)
  • Added test case for repackaging parent module (#72367)
  • Fixed orderedimporter dummy package check (#72533)
  • Improved error message for module detection on saving pass (#73106)
  • Changed to allow torch/csrc/deploy/interpreter/Optional.hpp to be allowed into the wheel distribution (#74643)

Performance

Python API

  • Improved torch.topk performance on CUDA (#74267)
  • Added SIMD horizontal reduce to improve torch.log_softmax and torch.softmax performance on CPU (#73953)
  • Made small optimizations for torch.view (#72626)
  • Optimized dim reduce performance on torch.{norm, argmax, argmin} (#72083)
  • Improved CPU performance for torch.log_softmax when dim != -1 on both float32 and bfloat16 (#72163)
  • Improved torch.softmax dim=-1 performance on bfloat16 by adding more fusion (#76278)
  • Removed duplicate call to objective function in strong Wolfe line search in L-BFGS optimizer. (#72773)

Autograd

  • Optimized code-generated in-place forward AD formulas (#74017)
  • Added a fast path for torch.{stack, cat} forward AD computation when tangents are zero-filled (#75590)
  • Reduced forward AD recomputation for linalg.{eig,eigh,svd} when function returns multiple outputs (#75583)

Sparse

  • Improved performance of index_select for COO inputs on CPU (#72710)
  • Improved performance of index_add on CUDA (#76996)

Dataloader

  • Improved the performance of BatchSampler (#76951)

AMD

  • Enabled foreach fast path (#74417)
  • Reverted cat operator performance work-around (#74129)

CUDA

  • Removed sync in embedding (#70943)
  • Added fused addmm path in linear for contiguous 3D input (#72728)
  • Changed to use cub 1.15's latest scan-by-key algorithm to replace thrust for Embedding.cu and EmbeddingBag.cu (#66580)
  • Changed to use cub::DeviceSelect::UniqueByKey for EmbeddingBackward (#68376)
  • Changed to use cuBLASLt interface for bias fusion (#72148)
  • Set the workspace size for the cuBLASLt interface to 1M (#73439)
  • Added fastAtomicAdd to scatter_add [v2] (#75545)
  • Added a new optimized cuDNN RNN algorithm for small RNN hidden_size (#73211)
  • Avoided CPU sync in SyncBatchNorm when capturing CUDA graphs (#78810)
  • Added Autocast CPU doc (#68567)
  • Documented CUDA 11.5 Windows issue (#73013)
  • Added __all__ for torch.cuda.memory (#76490)

Composability

  • Improved performance for forward-mode AD with at::sub: added ZeroTensor fast-path (#75587)

torch.nn

  • nn.EmbeddingBag: Removed out-of-bounds check to improve CUDA performance (#74767)
  • nn.GELU: Added support for a tanh-based approximation (see the sketch after this list) (#61439)
  • nn.GroupNorm: Improved channels last performance on CPU (#69067)
  • nn.LayerNorm: Improved bfloat16 performance on CPU (#71376)
  • nn.LayerNorm: Added mixed data type mode for forward path (#73844)
  • nn.MultiheadAttention: Fast path using nested tensors for inference under specific conditions (#77924, #77761)
  • nn.MultiheadAttention: Fuse the attn_mask addition (#73219, #72871)
  • nn.MultiheadAttention: Native fast path under specific conditions (#75809, #76333, #72944, #72941, #72671, #72375, #72458, #72464, #72463)
  • nn.MultiheadAttention: Preserve identity relationships among query, key, and value for batch_first=True (#73053)
  • nn.utils.weight_norm: Added native CPU kernel (#73845)
  • F.grid_sample: Improved backward pass scaling with input size for 3d implementation (#71759)
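
A minimal sketch of the new tanh-based GELU approximation mentioned above; the input shape is arbitrary and chosen only for illustration:

import torch
import torch.nn as nn

x = torch.randn(8, 16)

exact = nn.GELU()                       # default: exact (erf-based) GELU
approx = nn.GELU(approximate='tanh')    # new tanh-based approximation

# The two variants agree closely, but not bit-for-bit
print((exact(x) - approx(x)).abs().max())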

Benchmark

  • Added binary to benchmark model load speed (#74700)

Profiler

Mobile

  • Reduced unnecessary reference count bumps while parsing ByteCode. (#72523)

Quantization

  • Improved multi-core performance of qavg_pool2d (#69517)
  • Improved multi-core performance of qmax_pool2d (#69598)
  • Improved multi-core performance of qbatch_norm2d (#69599)
  • Improved multi-core performance of qupsample_nearest2d (#69600)
  • Improved multi-core performance of qupsample_bilinear2d (#69601)
  • Improved qcat_nhwc performance on both multi-core and single-core (#69667)
  • Added Optimized QInt8 Quantize Tensor Arm (#76245)

Documentation

Python API

  • Updated the torch.amp document with CPU Training/Inference examples (see the sketch after this list) (#77244)
  • Updated torch.utils.dlpack.from_dlpack documentation (#70543)
  • Fixed indexing of class names in docs for torch.{device, dtype, layout, memory_format} (#73632)
  • Fixed torch.asarray docs and add test case (#73736)
  • Removed misleading statement in optim.Optimizer docs (#76967)
  • Fixed Nesterov momentum equation for torch.optim.SGD (#76639)
  • Added missing zero-ing step in torch.optim.Rprop algorithm (#75555)
  • Fixed docs about type promotion of torch.{bitwise_left_shift,bitwise_right_shift} (#77613)
  • Fixed docstring for torch.roll (#74880)
  • Added docs for torch.scatter_reduce (#73125)
  • Automatically generated docstrings for torch.distributions.kl_divergence (#72845)
  • Miscellaneous documentation improvements (#74796, #76369)
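
The torch.amp item above points to the updated CPU examples; a minimal inference-only sketch in that spirit (the model and shapes are placeholders, not taken from the docs):

import torch

model = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)

# CPU autocast runs eligible ops (e.g. linear/matmul) in bfloat16
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(y.dtype)  # torch.bfloat16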

C++ API

  • Exposed documentation for unfold (#74224)

Autograd

  • Fixed error in “Autograd Mechanics” doc’s eval mode section (#74807)
  • Added “Gradients for non-differentiable functions” section in "Autograd Mechanics" doc to explain how gradients are chosen in edge cases (#76898)
  • Added link to "Custom function double backward tutorial" from "Extending PyTorch" page (#72584)
  • Documented forward AD interaction with grad mode (#72216)
  • Fixed code examples to run successfully (#74044)

Dataloader

  • Updated DataLoader docstring about prefetch_factor to reflect the correct number of batches prefetched by DataLoader (#74558)
  • Fixed docstring for collate_fn (#76594)

LinAlg

  • Expanded on the equivalence between @ (matmul) and solve in the linalg docs (see the example after this list) (#71769)
  • Updated torch.lu_unpack docs (#73803)
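
As a quick illustration of that equivalence (values chosen arbitrarily), solving a linear system directly is preferred over forming the inverse:

>>> import torch
>>> A = torch.tensor([[4., 1.], [2., 3.]])
>>> b = torch.tensor([[1.], [2.]])
>>> torch.linalg.solve(A, b)          # preferred: faster and more numerically stable
tensor([[0.1000],
        [0.6000]])
>>> torch.linalg.inv(A) @ b           # mathematically equivalent, but discouraged
tensor([[0.1000],
        [0.6000]])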

torch.nn

  • nn.CosineEmbeddingLoss: Use correct cosine similarity term instead of cosine distance (#75188)
  • nn.Hardtanh: Use min_val and max_val in function definition (#75789)
  • nn.KLDivLoss: Fixed log_target example (#74945)
  • nn.LazyModuleMixin: Fixed typo in docs (#76269)
  • nn.LSTM: Clarified docs for outputs vs. hidden states (#74291)
  • nn.Module: Fixed docs by moving _version class variable after docstring (#72912)
  • nn.Module: Fixed docstring typo for get_submodule() (#73018)
  • nn.Module: Fixed URL for creating GitHub issues (#73411)
  • nn.RNN: Fixed math notation for linear projections (#77082)
  • nn.Transformer: Detailed 3D tensor shape for masks (#75552)
  • nn.TripletMarginLoss: Fixed formatting error (#76629)
  • F.{conv3d, conv_transpose3d, fold, linear}, nn.{AdaptiveAvgPool3d, AvgPool1d, MultiMarginLoss, PairwiseDistance, TripletMarginLoss}: Fixed doc formatting regressions (#73014)
  • F.multi_head_attention_forward: Added to functional rst (#72675)
  • F.multi_head_attention_forward: Fixed math formatting, misc edit (#74181)
  • F.pad: Fixed supported input shapes in docs (#76117)
  • nn.init.trunc_normal_: Added to nn.init docs (#76896)
  • nn.utils.clip_grad_norm_: Fixed return value description (#76230)
  • nn.Convolution: Added note on complex support (#78351)

torch.fx

  • Added better error message for FX when using concrete_args (#76600)

Composability

  • Added docs for Python Registration (#79481)

Sparse

  • Added missing entry for torch.sparse.sampled_addmm on website (#72312)

Mobile

  • Documentation improvement in test_backend_with_compiler (52c516e)
  • Added README for mobile model test (#76385, #76409)

Distributed

  • torch.distributed
    • Clarified the input of PostLocalSGDState (#72792)
    • Added a reference to hierarchical SGD for Model Averaging (#73823)
    • Updated documentation about NCCL environment variables (#74006)
    • Added TORCH_CPP_LOG_LEVEL to the docs (#76625)
  • FullyShardedDataParallel
    • Improved the documentation of state_dict (#73453)
    • Updated full_optim_state_dict warning (#75109)
    • Added warning when fail to clone (#74946)
    • Added mixed precision doc (#76130)
    • Added warnings for shared params and updated doc (#77726)
    • Fixed state_dict_type() example (#77848)
    • Reworded device placement warning (#77850)
    • Updated state_dict() docstring (#77853)
  • torch.distributed.rpc
    • Added note in RPC docs about retries. (#73601)
  • DistributedDataParallel
    • Updated the comment for Forward and Backward Hook (#74063)
    • Added documentation for c10d log levels (#73361)
  • torch.distributed.elastic
    • Added documentation clarifying that torchrun is a console script for torch.distributed.run (#73598)

TorchScript

  • Corrected torch.jit.Attribute docs to say that it needs to be used in subclasses of torch.jit.ScriptModule, not torch.nn.Module (#74653)
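
A minimal sketch of the pattern the corrected docs describe, with torch.jit.Attribute used inside a torch.jit.ScriptModule subclass rather than a plain torch.nn.Module:

import torch
from typing import Dict, List

class AttributeModule(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        # Annotate container attributes with explicit TorchScript types
        self.names = torch.jit.Attribute([], List[str])
        self.ages = torch.jit.Attribute({}, Dict[str, int])

    @torch.jit.script_method
    def forward(self) -> List[str]:
        return self.names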

Quantization

  • Added docs for torch.quantize_per_tensor_dynamic (see the sketch after this list) (#72311)
  • Fixed typo in quantization docs (#73511)
  • Fixed grammar in the quantization tech doc (#74436)
  • Added best practices for quantization accuracy debugging (#77536)
  • Improved rendered documentation for backend_config_dict (#77535)
  • Autogenerated quantization backend configs for documentation (#75126)
  • Added more docs for quantization.rst (#75998)
  • Fixed formatting for quantization.rst (#76223)
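
A small sketch of the newly documented op, assuming the 1.12 signature torch.quantize_per_tensor_dynamic(input, dtype, reduce_range):

>>> import torch
>>> x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
>>> q = torch.quantize_per_tensor_dynamic(x, torch.quint8, False)  # scale/zero_point derived from x's range
>>> q.dtype
torch.quint8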

ONNX

  • Added the developing PyTorch ONNX exporter wiki doc link (#72663)
  • Added list of supported ATen ops to torch.onnx page (#74397)

Visualization

  • torch.utils.tensorboard.writer: Added missing 'dataformats' argument to 'add_image' docs. (#48834)
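
A minimal sketch of the documented dataformats argument (random image data, default log directory):

import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()                      # logs to ./runs/... by default

img_hwc = np.random.rand(64, 64, 3)           # height x width x channels
# dataformats tells add_image how to interpret the array layout (the default is 'CHW')
writer.add_image("random_image", img_hwc, global_step=0, dataformats="HWC")
writer.close()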