
Commit

Merge branch 'master' of github.com:pytorch/pytorch into sparsedocs1
cpuhrsch committed Sep 13, 2022
2 parents 2d172e7 + 53c71e2 commit 534d929
Showing 142 changed files with 4,055 additions and 3,798 deletions.
2 changes: 1 addition & 1 deletion .circleci/docker/build.sh
@@ -379,7 +379,7 @@ docker build \
--build-arg "NINJA_VERSION=${NINJA_VERSION:-}" \
--build-arg "KATEX=${KATEX:-}" \
--build-arg "ROCM_VERSION=${ROCM_VERSION:-}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx900;gfx906}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx906}" \
--build-arg "IMAGE_NAME=${IMAGE_NAME}" \
--build-arg "UCX_COMMIT=${UCX_COMMIT}" \
--build-arg "UCC_COMMIT=${UCC_COMMIT}" \
8 changes: 1 addition & 7 deletions .circleci/docker/common/install_cudnn.sh
@@ -4,13 +4,7 @@ if [[ ${CUDNN_VERSION} == 8 ]]; then
# cuDNN license: https://developer.nvidia.com/cudnn/license_agreement
mkdir tmp_cudnn && cd tmp_cudnn
CUDNN_NAME="cudnn-linux-x86_64-8.3.2.44_cuda11.5-archive"
if [[ ${CUDA_VERSION:0:4} == "11.7" ]]; then
CUDNN_NAME="cudnn-linux-x86_64-8.5.0.96_cuda11-archive"
curl -OLs https://ossci-linux.s3.amazonaws.com/${CUDNN_NAME}.tar.xz
else
curl -OLs https://developer.download.nvidia.com/compute/redist/cudnn/v8.3.2/local_installers/11.5/${CUDNN_NAME}.tar.xz
fi

curl -OLs https://developer.download.nvidia.com/compute/redist/cudnn/v8.3.2/local_installers/11.5/${CUDNN_NAME}.tar.xz
tar xf ${CUDNN_NAME}.tar.xz
cp -a ${CUDNN_NAME}/include/* /usr/include/
cp -a ${CUDNN_NAME}/include/* /usr/local/cuda/include/
1 change: 0 additions & 1 deletion .circleci/docker/ubuntu-cuda/Dockerfile
@@ -118,7 +118,6 @@ COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm

# Install CUDNN
ARG CUDNN_VERSION
ARG CUDA_VERSION
COPY ./common/install_cudnn.sh install_cudnn.sh
RUN if [ "${CUDNN_VERSION}" -eq 8 ]; then bash install_cudnn.sh; fi
RUN rm install_cudnn.sh
2 changes: 1 addition & 1 deletion .circleci/scripts/windows_cudnn_install.sh
@@ -18,7 +18,7 @@ case ${CUDA_VERSION} in
;;
11.7)
# Use cudnn8.3 with hard-coded cuda11.5 version
cudnn_file_name="cudnn-windows-x86_64-8.5.0.96_cuda11-archive"
cudnn_file_name="cudnn-windows-x86_64-8.3.2.44_cuda11.5-archive"
;;
*)
echo "CUDA_VERSION: ${CUDA_VERSION} not supported yet"
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/vision.txt
@@ -1 +1 @@
a67cc87a33a3f713aebf5299bdeb2672c98e0bc5
a89b1957a62e2f68f001d5d60268743edbe164d8
2 changes: 2 additions & 0 deletions .github/workflows/_linux-test.yml
@@ -117,6 +117,7 @@ jobs:
NUM_TEST_SHARDS: ${{ matrix.num_shards }}
PR_BODY: ${{ github.event.pull_request.body }}
SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2
SCCACHE_S3_KEY_PREFIX: ${{ github.workflow }}
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
DOCKER_IMAGE: ${{ inputs.docker-image }}
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
@@ -171,6 +172,7 @@ jobs:
-e PR_LABELS \
-e MAX_JOBS="$(nproc --ignore=2)" \
-e SCCACHE_BUCKET \
-e SCCACHE_S3_KEY_PREFIX \
-e XLA_CUDA \
-e XLA_CLANG_CACHE_S3_BUCKET_NAME \
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
6 changes: 3 additions & 3 deletions .jenkins/pytorch/build.sh
@@ -146,9 +146,9 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
fi

if [[ -n "$CI" && -z "$PYTORCH_ROCM_ARCH" ]]; then
# Set ROCM_ARCH to gfx900 and gfx906 for CI builds, if user doesn't override.
echo "Limiting PYTORCH_ROCM_ARCH to gfx90[06] for CI builds"
export PYTORCH_ROCM_ARCH="gfx900;gfx906"
# Set ROCM_ARCH to gfx906 for CI builds, if user doesn't override.
echo "Limiting PYTORCH_ROCM_ARCH to gfx906 for CI builds"
export PYTORCH_ROCM_ARCH="gfx906"
fi

# hipify sources
2 changes: 2 additions & 0 deletions .lintrunner.toml
@@ -318,6 +318,7 @@ exclude_patterns = [
'aten/src/ATen/native/vulkan/api/vk_mem_alloc.h',
'test/cpp/jit/upgrader_models/*.ptl',
'test/cpp/jit/upgrader_models/*.ptl.ff',
'cmake/External/nccl.patch',
]
command = [
'python3',
@@ -347,6 +348,7 @@ exclude_patterns = [
'test/cpp/jit/upgrader_models/*.ptl',
'test/cpp/jit/upgrader_models/*.ptl.ff',
'.lintrunner.toml',
'cmake/External/nccl.patch',
]
command = [
'python3',
5 changes: 5 additions & 0 deletions aten/src/ATen/FunctionalInverses.cpp
@@ -228,6 +228,11 @@ Tensor FunctionalInverses::transpose_copy_int_inverse(const Tensor& base, const
}
}

Tensor FunctionalInverses::_nested_view_from_buffer_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, const Tensor& nested_size_tensor, const Tensor& nested_stride_tensor, IntArrayRef offsets) {
TORCH_INTERNAL_ASSERT(false, "Attempted to call _nested_view_from_buffer() during the functionalization pass. For now, nested tensors aren't supported during functionalization");
return Tensor();
}

Tensor FunctionalInverses::unsqueeze_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim) {
if (reapply_views) {
return at::squeeze(mutated_view, dim);
21 changes: 18 additions & 3 deletions aten/src/ATen/InferSize.h
@@ -2,6 +2,8 @@

#include <ATen/DimVector.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/DimVector.h>
#include <c10/util/Optional.h>
#include <sstream>
#include <vector>
@@ -14,9 +16,13 @@ namespace at {
// templated to handle std::vector<int64_t> and DimVector use cases, see
// below
//
template <typename ResultVec>
inline void infer_size_impl(IntArrayRef shape, int64_t numel, ResultVec& res) {
int64_t newsize = 1;
template <typename InputArrayRef, typename NumelType, typename ResultVec>
inline void infer_size_impl(
InputArrayRef shape,
NumelType numel,
ResultVec& res) {
NumelType newsize = 1;
// N.B. this is an index, not a sym dim!
auto infer_dim = c10::optional<int64_t>();
for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
if (shape[dim] == -1) {
@@ -69,4 +75,13 @@ inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
return res;
}

inline at::SymDimVector infer_size_dv(
c10::SymIntArrayRef shape,
c10::SymInt numel) {
auto res = at::SymDimVector(shape);
infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
shape, numel, res);
return res;
}

} // namespace at
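
For reference, a minimal standalone sketch of the size-inference rule that infer_size_impl templates over: at most one dimension may be -1, and it is filled in from the total element count. This is not part of the commit; the NumelType parameter is pinned to plain int64_t here, and the name infer_size_sketch and its error messages are illustrative only.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Sketch only: infer the dimension marked -1 so that the product of the
// shape equals numel, mirroring the rule infer_size_impl implements.
std::vector<int64_t> infer_size_sketch(std::vector<int64_t> shape, int64_t numel) {
  int64_t known = 1;      // product of the explicitly given sizes
  int64_t infer_dim = -1; // index of the -1 entry, if any
  for (int64_t d = 0; d < (int64_t)shape.size(); ++d) {
    if (shape[d] == -1) {
      if (infer_dim != -1) throw std::runtime_error("only one dimension can be inferred");
      infer_dim = d;
    } else {
      known *= shape[d];
    }
  }
  if (infer_dim != -1) {
    if (known == 0 || numel % known != 0)
      throw std::runtime_error("shape is invalid for input of size " + std::to_string(numel));
    shape[infer_dim] = numel / known;
  } else if (known != numel) {
    throw std::runtime_error("shape is invalid for input of size " + std::to_string(numel));
  }
  return shape;
}

int main() {
  for (int64_t s : infer_size_sketch({2, -1, 4}, 24)) std::cout << s << ' ';  // prints: 2 3 4
  std::cout << '\n';
}
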
59 changes: 43 additions & 16 deletions aten/src/ATen/NestedTensorImpl.cpp
@@ -4,6 +4,7 @@
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Exception.h>
#include <c10/core/TensorImpl.h>

@@ -25,6 +26,46 @@ inline void validate_nested_tensor_metadata(
(size_dim == 0 && (int64_t)offsets.empty()) ||
(size_dim == 2 && nested_sizes.size(0) == (int64_t)offsets.size()));
}

/**
* Generates a nested key_set from a non-nested tensor.
*
* When creating a nested tensor from a non-nested tensor
* We want to maintain the same keyset as the buffer but
* swap non nested keys for nested ones
*
* @return Appropriate key set for nested tensor
*/
inline c10::DispatchKeySet generate_nested_key_set_from_buffer(
const at::Tensor& buffer) {
auto nested_key_set = buffer.key_set();
const bool has_autograd = nested_key_set.has_any(c10::autograd_dispatch_keyset);
// Remove non_nested tensor specific keys
nested_key_set = nested_key_set -
c10::DispatchKeySet{c10::DispatchKey::Dense, c10::DispatchKey::Autograd};

// Add nested tensor specific keys
nested_key_set =
nested_key_set | c10::DispatchKeySet{c10::DispatchKey::NestedTensor};
nested_key_set =
has_autograd ? nested_key_set | c10::autograd_nested : nested_key_set;
return nested_key_set;
}

/**
* Generates the correct view keyset.
*
* When creating a nested tensor view of base
* The appropriate keyset will be dependent on the nested
* status of the base
*
* @return Appropriate key set for nested tensor
*/
c10::DispatchKeySet get_view_key_set(const at::Tensor& base) {
return base.is_nested() ? base.key_set()
: generate_nested_key_set_from_buffer(base);
}

} // namespace
namespace at {
namespace native {
@@ -119,19 +160,6 @@ inline std::vector<int64_t> construct_offsets(const at::Tensor& sizes) {
return offsets;
}

// [Note: Nested Tensor Autograd] The Nested Tensor key is a functionality
// key and therefore getAutogradRelatedKeySetFromBackend will return the
// wrong autograd key. For this specific impl we make sure to register the
// correct Autograd key which is AutogradNestedTensor
c10::DispatchKeySet generate_nested_key_set(at::Tensor buffer) {
c10::DispatchKeySet key_set =
c10::DispatchKeySet(DispatchKey::NestedTensor) | c10::DispatchKeySet{buffer.key_set().highestBackendKey()};

// Add AutogradNestedTensor specific keys
key_set = key_set | inplace_or_view_ks | autograd_nested;
return key_set;
}

NestedTensorImpl::NestedTensorImpl(
Storage storage,
c10::DispatchKeySet key_set,
@@ -164,7 +192,7 @@ NestedTensorImpl::NestedTensorImpl(
std::vector<int64_t>&& offsets)
: NestedTensorImpl(
buffer.storage(),
generate_nested_key_set(buffer),
generate_nested_key_set_from_buffer(buffer),
buffer.dtype(),
nested_size_tensor,
nested_stride_tensor,
@@ -195,12 +223,11 @@ NestedTensorImpl::NestedTensorImpl(
at::Tensor nested_size_tensor,
at::Tensor nested_stride_tensor,
std::vector<int64_t>&& offsets)
: TensorImpl(impl_type, Storage(base_tensor.storage()), base_tensor.key_set(), base_tensor.dtype()),
: TensorImpl(impl_type, Storage(base_tensor.storage()), get_view_key_set(base_tensor), base_tensor.dtype()),
nested_size_tensor_(std::move(nested_size_tensor)),
nested_stride_tensor_(std::move(nested_stride_tensor)),
offsets_(std::move(offsets)),
opt_sizes_(construct_opt_sizes(nested_size_tensor_)) {
TORCH_INTERNAL_ASSERT(base_tensor.is_nested());
validate_nested_tensor_metadata(nested_size_tensor_, nested_stride_tensor_, offsets_);
refresh_dim();
set_custom_sizes_strides(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
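
A toy illustration of the key-set swap that generate_nested_key_set_from_buffer performs above: drop the buffer's dense and autograd keys, add the nested-tensor key, and re-add autograd in its nested flavor only if the buffer carried autograd. This is not part of the commit and does not use the real c10::DispatchKeySet API; the bit flags and function name are invented for the sketch.

#include <cstdint>
#include <cstdio>

// Toy model of a dispatch key set as a bit mask (illustration only).
enum Key : uint32_t {
  kDense          = 1u << 0,
  kAutograd       = 1u << 1,
  kNestedTensor   = 1u << 2,
  kAutogradNested = 1u << 3,
};

// Mirror of the swap performed when a nested tensor is built from a buffer:
// strip the non-nested keys, add the nested one, preserve "has autograd".
uint32_t nested_key_set_from_buffer(uint32_t buffer_keys) {
  const bool has_autograd = (buffer_keys & kAutograd) != 0;
  uint32_t keys = buffer_keys & ~(kDense | kAutograd);  // remove non-nested keys
  keys |= kNestedTensor;                                // add the nested key
  if (has_autograd) keys |= kAutogradNested;            // keep autograd, nested flavor
  return keys;
}

int main() {
  uint32_t buffer = kDense | kAutograd;  // typical dense buffer that requires grad
  uint32_t nested = nested_key_set_from_buffer(buffer);
  std::printf("nested=%u dense=%d autograd_nested=%d\n",
              (unsigned)nested, (nested & kDense) != 0, (nested & kAutogradNested) != 0);
}
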
3 changes: 2 additions & 1 deletion aten/src/ATen/NestedTensorImpl.h
@@ -167,7 +167,7 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
* is generated and redispatched to a non-nested kernel this function
* generates the key set used by that buffer tensor
*
* @return A newly constructed view tensor
* @return Appropriate key set for non-nested tensor
*/
inline c10::DispatchKeySet generate_buffer_key_set() const {
auto buffer_key_set = this->key_set();
@@ -184,6 +184,7 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
buffer_key_set = Autograd
? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
: buffer_key_set;

return buffer_key_set;
}
};
30 changes: 19 additions & 11 deletions aten/src/ATen/TensorUtils.cpp
@@ -310,12 +310,12 @@ std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
// templatized for DimVector and IntArrayRef use cases,
// see overloads of computeStride() below.
//
template <typename ResultVec, typename NewShapeVec>
template <typename ResultVec, typename NewShapeVec, typename Numel>
inline c10::optional<ResultVec> computeStride_impl(
IntArrayRef oldshape,
IntArrayRef oldstride,
const NewShapeVec& oldshape,
const NewShapeVec& oldstride,
const NewShapeVec& newshape,
ResultVec toResult(const IntArrayRef&)
ResultVec toResult(const NewShapeVec&)
) {
if (oldshape.empty()) {
return ResultVec(newshape.size(), 1);
@@ -326,7 +326,7 @@ inline c10::optional<ResultVec> computeStride_impl(
// we use the stride as if it were computed via resize.
// This could perhaps be combined with the below code, but the complexity
// didn't seem worth it.
const int64_t numel = c10::multiply_integers(oldshape);
const Numel numel = c10::multiply_integers(oldshape);
if (numel == 0 && oldshape.equals(newshape)) {
return toResult(oldstride);
}
@@ -338,18 +338,18 @@ inline c10::optional<ResultVec> computeStride_impl(
newstride[view_d] = 1;
} else {
newstride[view_d] =
std::max<int64_t>(newshape[view_d+1], 1) * newstride[view_d+1];
std::max<Numel>(newshape[view_d+1], Numel(1)) * newstride[view_d+1];
}
}
return newstride;
}

int64_t view_d = (int64_t)newshape.size() - 1;
// stride for each subspace in the chunk
int64_t chunk_base_stride = oldstride.back();
Numel chunk_base_stride = oldstride.back();
// numel in current chunk
int64_t tensor_numel = 1;
int64_t view_numel = 1;
Numel tensor_numel = 1;
Numel view_numel = 1;
for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
tensor_numel *= oldshape[tensor_d];
// if end of tensor size chunk, check view
@@ -383,15 +383,23 @@ c10::optional<std::vector<int64_t>> computeStride(
IntArrayRef oldstride,
IntArrayRef newshape) {
auto toResult = [](const IntArrayRef& a) { return a.vec(); };
return computeStride_impl<std::vector<int64_t>, IntArrayRef>(oldshape, oldstride, newshape, toResult);
return computeStride_impl<std::vector<int64_t>, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
}

c10::optional<SymDimVector> computeStride(
c10::SymIntArrayRef oldshape,
c10::SymIntArrayRef oldstride,
c10::SymIntArrayRef newshape) {
auto toResult = [](const SymIntArrayRef& a) { return SymDimVector(a); };
return computeStride_impl<SymDimVector, c10::SymIntArrayRef, c10::SymInt>(oldshape, oldstride, newshape, toResult);
}

c10::optional<DimVector> computeStride(
IntArrayRef oldshape,
IntArrayRef oldstride,
const DimVector& newshape) {
auto toResult = [](const IntArrayRef& a) { return DimVector(a); };
return computeStride_impl<DimVector, DimVector>(oldshape, oldstride, newshape, toResult);
return computeStride_impl<DimVector, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
}

} // namespace detail
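
The change above templates computeStride_impl over a Numel type so the same code runs with int64_t or c10::SymInt. Its zero-numel branch fills strides right-to-left as if the tensor had been resized, clamping each trailing size to at least 1. Below is a standalone sketch of that fill rule only, not part of the commit, with Numel pinned to int64_t and an invented function name.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the stride rule used by the zero-numel branch of
// computeStride_impl: fill strides right-to-left as if contiguous, clamping
// each size to at least 1 so empty dims don't zero out strides to their left.
std::vector<int64_t> resize_like_strides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> stride(shape.size());
  for (int64_t d = (int64_t)shape.size() - 1; d >= 0; --d) {
    if (d == (int64_t)shape.size() - 1) {
      stride[d] = 1;
    } else {
      stride[d] = std::max<int64_t>(shape[d + 1], 1) * stride[d + 1];
    }
  }
  return stride;
}

int main() {
  for (int64_t s : resize_like_strides({2, 0, 3})) std::cout << s << ' ';  // prints: 3 3 1
  std::cout << '\n';
}
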
5 changes: 5 additions & 0 deletions aten/src/ATen/TensorUtils.h
@@ -157,6 +157,11 @@ TORCH_API c10::optional<std::vector<int64_t>> computeStride(
IntArrayRef oldstride,
IntArrayRef newshape);

TORCH_API c10::optional<SymDimVector> computeStride(
c10::SymIntArrayRef oldshape,
c10::SymIntArrayRef oldstride,
c10::SymIntArrayRef newshape);

TORCH_API c10::optional<DimVector> computeStride(
IntArrayRef oldshape,
IntArrayRef oldstride,
4 changes: 1 addition & 3 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -329,10 +329,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// to let the original CompositeImplicitAutograd handle Undefined
if (dispatch_key != DispatchKey::Undefined && isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutogradNestedTensor)) {
if (auto nested_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutogradNestedTensor)) {
if (!has_backend_kernel) {
return {*nested_registration, "nested kernel"};
return {*nested_registration, "nested kernel"};
}
}
}

if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutograd)) {
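
The two-line change above removes the has_backend_kernel guard, so for dispatch keys covered by the CompositeImplicitAutogradNestedTensor alias the nested kernel is now returned even when a backend kernel is registered. A toy sketch of that resolution order, not part of the commit, using invented types rather than the dispatcher's real data structures:

#include <iostream>
#include <optional>
#include <string>

// Toy resolution order (illustration only): after this change, a registered
// "nested" composite kernel wins for a nested dispatch key even when a
// backend kernel is also registered for the operator.
struct ToyOpEntry {
  std::optional<std::string> backend_kernel;
  std::optional<std::string> nested_composite_kernel;

  std::string compute(bool key_is_nested) const {
    if (key_is_nested && nested_composite_kernel)
      return *nested_composite_kernel;  // no has_backend_kernel guard anymore
    if (backend_kernel) return *backend_kernel;
    return "missing";
  }
};

int main() {
  ToyOpEntry op{std::string("cuda kernel"), std::string("nested kernel")};
  std::cout << op.compute(/*key_is_nested=*/true) << '\n';   // nested kernel
  std::cout << op.compute(/*key_is_nested=*/false) << '\n';  // cuda kernel
}
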
8 changes: 3 additions & 5 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -709,11 +709,9 @@ void gemm_and_bias(
CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);

CuBlasLtMatmulPreference preference;
// See https://github.com/pytorch/pytorch/issues/73328.
// Check https://docs.nvidia.com/cuda/cublas/index.html#cublassetworkspace .
// Recommended size of user-provided workspace is at least 4MiB (to match
// cuBLAS' default workspace pool).
size_t workspaceSize = 4 * 1024 * 1024;
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = 1024 * 1024;
TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(
preference.descriptor(),
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSFallback.mm
@@ -14,7 +14,7 @@ void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)

void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack)
{
TORCH_CHECK_NOT_IMPLEMENTED(false, "The operator '", op.schema().operator_name(), "' is not current implemented ",
TORCH_CHECK_NOT_IMPLEMENTED(false, "The operator '", op.schema().operator_name(), "' is not currently implemented ",
"for the MPS device. If you want this op to be added in priority during the prototype ",
"phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. ",
"As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ",
