Changes from all commits (48 commits)
2245d7d
Improve char printing (#167899)
cyyever Nov 16, 2025
5d99a79
[xpu][test] Migrated two test files to XPU (#166684)
shangerxin Nov 16, 2025
e2e1075
Allow same triton kernels in export (#167862)
minjang Nov 16, 2025
363385a
s/Stragety/Strategy/ (#167916)
ezyang Nov 16, 2025
4322354
[Inductor] optimize scalar welford_reduce (#162709)
jiayisunx Nov 14, 2025
d8ce6f8
Enable PyTorch OSS numerics changes, inductor heuristics (#167799)
PaulZhang12 Nov 17, 2025
aa504d4
[audio hash update] update the pinned audio hash (#167914)
pytorchupdatebot Nov 17, 2025
f2e6f94
deprecate check_is_size and guard_size_oblivious (#167198)
laithsakka Nov 15, 2025
ca3aaef
Fix clamp broadcasting on MPS (Fixes #160734) (#165058)
roei-shlezinger Nov 17, 2025
b9bccec
Revert "[ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (#167734)"
pytorchmergebot Nov 17, 2025
99117c1
Remove old NVTX interface (#167637)
Aidyn-A Nov 17, 2025
5804408
[1/3][XPU][feature] The implementation of memory private pool in XPU …
majing921201 Nov 17, 2025
93ddd38
Re-land#2 "Fix thread safety in getCurrentCUDABlasHandle and getCUDAB…
t-ivan-gr Nov 17, 2025
53809f9
[ARM] Improve LLM performance & mem usage using int4-bf16 KleidiAI ke…
usamahz Nov 17, 2025
661d165
[xla hash update] update the pinned xla hash (#167968)
pytorchupdatebot Nov 17, 2025
6fdb974
Update torch-xpu-ops commit pin (#167698)
CuiYifeng Nov 17, 2025
9ff95f6
[inductor] Expose config for fx bucket all_reduces (#167634)
IvanKobzarev Nov 12, 2025
2b5eabc
Rework PyObject preservation (v2) (#167564)
colesbury Nov 17, 2025
2f74916
Do not hardfail on use nccl estimations for non-nccl (#167827)
IvanKobzarev Nov 17, 2025
2b69673
[CD] Add libopenblas to dep list for AArch64+CPU whl (#167841)
robert-hardwick Nov 17, 2025
1b43d6c
[ROCm] enable fastSpecializedAtomicAdd for gfx950 (#167661)
jeffdaily Nov 17, 2025
4c152a7
Revert "add device generalization support for distributed tests (#165…
pytorchmergebot Nov 17, 2025
39ebab1
Revert "Remove python workaround for ContextDecorator (#167049)"
pytorchmergebot Nov 17, 2025
22ccd44
Revert "Improve char printing (#167899)"
pytorchmergebot Nov 17, 2025
a4c7bf7
Revert "Use c10::filesystem (#167821)"
pytorchmergebot Nov 17, 2025
094e529
[MPS] Fix repeat_interleave with slices (#167961)
malfet Nov 17, 2025
95d1df7
Disable CUDA MXFP4 on non-B200 GPUs (#167857)
slayton58 Nov 17, 2025
77acc66
[ROCm][CI] Upgrade ROCm CI to 7.1 (#166743)
xinyazhang Nov 17, 2025
567dcdb
Fix longstanding race condition around getAllOperatorsFor (#167860)
swolchok Nov 14, 2025
2f3bb74
Improve benchmarks/dynamo:check_perf_csv output and failure summary (…
adabeyta Nov 17, 2025
ae3ce54
Revert "[ROCm] Enable StaticCudaLauncher for ROCm (#166492)"
pytorchmergebot Nov 17, 2025
02b55c3
Move isQIntType to headeronly (#167772)
pearu Nov 16, 2025
1233be0
[STABLE ABI] Add mutable_data_ptr() and const_data_ptr() methods to t…
pearu Nov 16, 2025
01deee2
Fix dataloader tests failing on python 3.14 (#167429)
divyanshk Nov 17, 2025
694f9b9
Revert "[ROCm][CI] Upgrade ROCm CI to 7.1 (#166743)"
pytorchmergebot Nov 17, 2025
4414e1b
Cleanup in inductor usage of nccl estimator after its fix (#167633)
IvanKobzarev Nov 17, 2025
b288d00
[inductor] unittest for run2run determinism (#167482)
shunting314 Nov 15, 2025
689d731
[inductor] fix the decision of inner reduction (#167697)
shunting314 Nov 15, 2025
2ddcf53
Logaddexp complex inconsistent bw cpu and cuda (#163509)
cleonard530 Nov 17, 2025
a892f76
[MPS] mm out sparse (#167908)
Isalia20 Nov 17, 2025
927899d
fixes a few issues with out_dtype overload for addmm/baddbmm (#167931)
ngimel Nov 17, 2025
9d8ceaa
Revert "[ARM] Improve LLM performance & mem usage using int4-bf16 Kle…
pytorchmergebot Nov 17, 2025
bdd3c3a
Support SymInt placeholder in wrapper fxir (#167757)
nandesuka Nov 17, 2025
4e1b772
Fix: Improve fallback behavior in `deserialize_torch_artifact` and re…
abhitorch81 Nov 17, 2025
661fb53
Revert "Remove old NVTX interface (#167637)"
pytorchmergebot Nov 17, 2025
1c04a43
Revert "Tiling bug fix (#167771)"
pytorchmergebot Nov 17, 2025
151fae4
Update base for Update on "Test that TORCH_FEATURE_VERSION guards are…
mikaylagawarecki Nov 18, 2025
4d1947b
Update base for Update on "Test that TORCH_FEATURE_VERSION guards are…
mikaylagawarecki Nov 18, 2025
2 changes: 2 additions & 0 deletions .ci/manywheel/build_cpu.sh
@@ -75,9 +75,11 @@ if [[ "$ARCH" == "aarch64" ]]; then
# ARM system libraries
DEPS_LIST+=(
"/usr/lib64/libgfortran.so.5"
"/opt/OpenBLAS/lib/libopenblas.so.0"
)
DEPS_SONAME+=(
"libgfortran.so.5"
"libopenblas.so.0"
)
fi

2 changes: 1 addition & 1 deletion .github/ci_commit_pins/audio.txt
@@ -1 +1 @@
07b6cbde121417a70e4dc871adb6d27030e0ce3f
ee1a1350eb37804b94334768f328144f058f14e9
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
94631807d22c09723dd006f7be5beb649d5f88d0
3 changes: 3 additions & 0 deletions aten/src/ATen/core/TensorBase.h
@@ -245,6 +245,9 @@ class TORCH_API TensorBase {
size_t weak_use_count() const noexcept {
return impl_.weak_use_count();
}
bool is_uniquely_owned() const noexcept {
return impl_.is_uniquely_owned();
}

std::string toString() const;

10 changes: 8 additions & 2 deletions aten/src/ATen/cuda/CUDAContextLight.h
@@ -3,6 +3,7 @@

#include <cstdint>
#include <map>
#include <shared_mutex>

#include <cuda_runtime_api.h>
#include <cusparse.h>
@@ -88,8 +89,13 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();

TORCH_CUDA_CPP_API void clearCublasWorkspaces();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
struct WorkspaceMapWithMutex {
std::map<std::tuple<void*, void*>, at::DataPtr> map;
std::shared_mutex mutex;
};

TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();
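The header change above bundles each workspace map with a std::shared_mutex so callers can take shared locks for lookups and an exclusive lock for mutation. A minimal standalone sketch of that locking discipline follows; the int payload is a stand-in for at::DataPtr and none of this is the actual PyTorch code.

```cpp
// Sketch only: a map guarded by std::shared_mutex, read with shared_lock,
// mutated with unique_lock. Toy int payload instead of at::DataPtr.
#include <map>
#include <mutex>
#include <optional>
#include <shared_mutex>

struct MapWithMutex {
  std::map<int, int> map;
  std::shared_mutex mutex;
};

std::optional<int> lookup(MapWithMutex& ws, int key) {
  std::shared_lock<std::shared_mutex> lock(ws.mutex);  // concurrent readers OK
  auto it = ws.map.find(key);
  return it != ws.map.end() ? std::optional<int>(it->second) : std::nullopt;
}

void clear_all(MapWithMutex& ws) {
  std::unique_lock<std::shared_mutex> lock(ws.mutex);  // exclusive writer
  ws.map.clear();
}
```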
95 changes: 75 additions & 20 deletions aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -99,27 +99,35 @@ void destroyCublasHandle(cublasHandle_t handle) {
// - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
cublasDestroy(handle);
cublasDestroy(handle);
#endif
}

using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle, destroyCublasHandle>;

} // namespace

std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
return instance;
}

std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
return instance;
}

void clearCublasWorkspaces() {
cublas_handle_stream_to_workspace().clear();
cublaslt_handle_stream_to_workspace().clear();
{
auto& workspace = cublas_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
{
auto& workspace = cublaslt_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
}

size_t parseChosenWorkspaceSize() {
@@ -233,6 +241,38 @@ at::DataPtr getNewCUDABlasLtWorkspace() {
return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize());
}

void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) {
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));

auto& workspace = cublas_handle_stream_to_workspace();

size_t workspace_size = getChosenWorkspaceSize();

// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
handle, workspace_it->second.get(), workspace_size));
return;
}
}

// Slow path: allocate workspace outside the lock
auto new_workspace = getNewWorkspace();

// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first;
TORCH_CUDABLAS_CHECK(
cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
}
}

void* getCUDABlasLtWorkspace() {
#ifndef USE_ROCM
static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
@@ -241,20 +281,40 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
return workspace_it->second.mutable_get();
}
#endif
cublasLtHandle_t handle = getCurrentCUDABlasLtHandle();
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});

auto& workspace = cublaslt_handle_stream_to_workspace();

// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
return workspace_it->second.mutable_get();
}
}

// Slow path: allocate workspace outside the lock
auto new_workspace = getNewCUDABlasLtWorkspace();

// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it =
workspace.map.try_emplace(key, std::move(new_workspace)).first;
return workspace_it->second.mutable_get();
}
return workspace_it->second.mutable_get();
}

cublasHandle_t getCurrentCUDABlasHandle() {
@@ -298,13 +358,8 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// will allocate memory dynamically (even if they're cheap) outside
// PyTorch's CUDA caching allocator. It's possible that CCA used up
// all the memory and cublas's cudaMallocAsync will return OOM
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublas_handle_stream_to_workspace().find(key);
if (workspace_it == cublas_handle_stream_to_workspace().end()) {
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
}
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
setWorkspaceForHandle(handle, stream);

#if !defined(USE_ROCM)
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.
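The workspace getters above (setWorkspaceForHandle and getCUDABlasLtWorkspace) follow a double-checked pattern: look up under a shared lock, allocate with no lock held, then insert with try_emplace under a unique lock so a thread that loses the race simply discards its own allocation and uses whichever entry won. A self-contained sketch of the same idiom with placeholder types, not the actual workspace or allocator code:

```cpp
// Sketch only: double-checked get-or-create guarded by std::shared_mutex.
// unique_ptr<int> stands in for at::DataPtr; make_unique for the allocator.
#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>

struct Cache {
  std::map<int, std::unique_ptr<int>> map;
  std::shared_mutex mutex;
};

int* get_or_create(Cache& cache, int key) {
  {
    // Fast path: shared lock, no allocation, return existing entry if present.
    std::shared_lock<std::shared_mutex> lock(cache.mutex);
    auto it = cache.map.find(key);
    if (it != cache.map.end()) {
      return it->second.get();
    }
  }
  // Slow path: allocate outside the lock, since allocation may be expensive.
  auto fresh = std::make_unique<int>(42);
  {
    // Double-check on insert: try_emplace leaves the map untouched if another
    // thread inserted first, and we return whichever entry ended up in the map.
    std::unique_lock<std::shared_mutex> lock(cache.mutex);
    auto it = cache.map.try_emplace(key, std::move(fresh)).first;
    return it->second.get();
  }
}
```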
83 changes: 38 additions & 45 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -296,20 +296,16 @@ template <typename scalar_t, typename res_scalar_t = scalar_t>
bool launchGemmAndBiasCublasLt(
// args contains result which is modified
cublasCommonArgs& args,
const Tensor& self,
const std::optional<Tensor>& self,
const Scalar& alpha,
Activation activation = Activation::None
) {
// We apply bias in the epilogue only when it is 1D,
// or when it can be squeezed to 1D.
// self_ptr == nullptr implies ignore bias epilogue
// and use standard gemm-like API.
const auto* self_ptr = [&]() -> auto {
if (self.dim() == 1 || self.squeeze().dim() == 1) {
return self.const_data_ptr<scalar_t>();
}
return static_cast<const scalar_t*>(nullptr);
}();
const auto* self_ptr = self.has_value() ? self.value().const_data_ptr<scalar_t>() : static_cast<const scalar_t*>(nullptr);


const auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
@@ -392,35 +388,30 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
bool disable_addmm_cuda_lt = persistent_disable_addmm_cuda_lt || disable_addmm_cuda_lt_override;
#ifdef USE_ROCM
// Conditioned on the device index, which is not persistent
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
disable_addmm_cuda_lt = disable_addmm_cuda_lt || isGloballyDisabledAddmmCudaLt(self.device());
#endif
// Condition on the input
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt;
// }
disable_addmm_cuda_lt = disable_addmm_cuda_lt || !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation);

at::ScalarType scalar_type = mat1.scalar_type();
bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;

#ifdef USE_ROCM
disable_addmm_cuda_lt = disable_addmm_cuda_lt || is_float_output_with_half_input;
#endif

bool use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
// for float output with half input cublasLT with bias produces wrong results
use_bias_ptr_lt &= !is_float_output_with_half_input;

// Handle result/self shapes
if (!result.is_same(self)) {
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});

// We use bias ptr in the Lt path only when bias is 1D
const auto use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
if (!use_bias_ptr_lt) {
// We do expand self even before
// check for beta != 0.0 to make sure that
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
// runs green.
return expand_size(self, result.sizes(), "addmm");
}
return c10::MaybeOwned<Tensor>::borrowed(self);
}();
// We do not copy bias only when we need the bias ptr
// We do not copy bias only when we need the bias ptr
if (beta.toComplexDouble() != 0.0 && !use_bias_ptr_lt) {
// NOTE: self should broadcast over result
at::native::copy_(result, *self_maybe_expanded);
at::native::copy_(result, *expand_size(self, result.sizes(), "addmm"));
}
}

@@ -468,7 +459,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, self, alpha, activation);
lt_success = launchGemmAndBiasCublasLt<scalar_t, float>(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation);
}
);
#endif
@@ -480,7 +471,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type,
"addmm_cuda_lt",
[&] {
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, self, alpha, activation);
lt_success = launchGemmAndBiasCublasLt<scalar_t>(args, use_bias_ptr_lt ? std::make_optional(self) : std::nullopt, alpha, activation);
}
);
} // end is_float_output_with_half_input
@@ -936,7 +927,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
return _int_mm_out_cuda(self, mat2, result);
}

static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const at::ScalarType out_dtype, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
// ref ATen/native/LinearAlgebra.cpp common_checks_baddbmm_bmm
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
@@ -960,23 +951,20 @@ static void baddbmm_bmm_out_dtype_checks(const Tensor& batch1, const Tensor& bat
(out_dtype == at::ScalarType::Float && (batch1.scalar_type() == at::ScalarType::Half || batch1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");

if (!is_bmm && self_baddbmm.has_value()) {
if (self_baddbmm.has_value()) {
const auto& self = self_baddbmm.value();
TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
TORCH_CHECK(self.sizes() == output_size, "self must have the same shape as the output");
}
}

Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
IntArrayRef batch1_sizes = batch1.sizes();
IntArrayRef batch2_sizes = batch2.sizes();

Tensor out = at::empty({batch1_sizes[0], batch1_sizes[1], batch2_sizes[2]}, batch1.options().dtype(out_dtype));
Tensor out = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, batch1.options().dtype(out_dtype));
return _bmm_out_dtype_cuda(batch1, batch2, out_dtype, out);
}

Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, Tensor &out) {
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype, true);
baddbmm_bmm_out_dtype_checks(batch1, batch2, 0.0, 1.0, out_dtype);
Scalar beta(0.0);
Scalar alpha(1.0);
{
@@ -988,14 +976,16 @@ Tensor& _bmm_out_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at
}

Tensor _baddbmm_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {
// We need to copy the tensor
Tensor out = self.clone().to(self.options().dtype(out_dtype));

return _baddbmm_out_dtype_cuda(out, batch1, batch2, out_dtype, beta, alpha, out);
TORCH_CHECK(self.scalar_type() == out_dtype || self.scalar_type() == batch1.dtype(),
"self dtype must match either out_dtype or batch1 dtype");
Tensor out = at::empty({batch1.size(0), batch1.size(1), batch2.size(2)}, batch1.options().dtype(out_dtype));
return _baddbmm_out_dtype_cuda(self, batch1, batch2, out_dtype, beta, alpha, out);
}

Tensor& _baddbmm_out_dtype_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, false, self);
baddbmm_bmm_out_dtype_checks(batch1, batch2, beta, alpha, out_dtype, out);
// We need to copy the tensor
out.copy_(self);
{
NoNamesGuard guard;
baddbmm_out_cuda_impl(out, out, batch1, batch2, beta, alpha);
@@ -1030,24 +1020,27 @@ Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::Sca
}

Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {
Tensor result = at::empty(self.sizes(), self.options().dtype(out_dtype));
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
Tensor result = at::empty({mat1.size(0), mat2.size(1)}, self.options().dtype(out_dtype));
return _addmm_dtype_out_cuda(self, mat1, mat2, out_dtype, beta, alpha, result);
}

Tensor& _addmm_dtype_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
// repeat dimensionality checks for direct calls to `out` overload
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(out_dtype == mat1.scalar_type() ||
(out_dtype == at::ScalarType::Float && (mat1.scalar_type() == at::ScalarType::Half || mat1.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");

TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(out_dtype == self.scalar_type() ||
(out_dtype == at::ScalarType::Float && (self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::BFloat16)),
"out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs");
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
TORCH_CHECK(out_dtype == self.scalar_type() || self.scalar_type() == mat1.scalar_type(),
"self dtype must match either out_dtype or mat1 dtype");

addmm_out_cuda_impl(out, self, mat1, mat2, beta, alpha);

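In the Blas.cpp change above, launchGemmAndBiasCublasLt now receives the bias as std::optional<Tensor>: the caller passes std::nullopt when the Lt bias epilogue should be skipped, and a null data pointer inside the helper means "use the standard gemm-like path". A tiny hedged sketch of that optional-to-pointer handoff; BiasView and both helper names are made up for illustration.

```cpp
// Sketch only: optional bias handed to the Lt launcher; BiasView and the
// helpers are illustrative, not the PyTorch types or functions.
#include <optional>

struct BiasView { const float* data; };

// nullptr here means "skip the bias epilogue and run the plain GEMM-like path".
inline const float* bias_ptr_for_lt(const std::optional<BiasView>& bias) {
  return bias.has_value() ? bias->data : nullptr;
}

// Caller side mirrors the addmm_out_cuda_impl change: the bias is only passed
// when the 1D-bias epilogue is actually eligible.
inline const float* select_bias(bool use_bias_ptr_lt, const BiasView& self) {
  return bias_ptr_for_lt(use_bias_ptr_lt ? std::make_optional(self)
                                         : std::nullopt);
}
```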