Update on "[FX][2/2] Make docstrings pretty when rendered"
Differential Revision: [D25351588](https://our.internmc.facebook.com/intern/diff/D25351588)

[ghstack-poisoned]
James Reed committed Dec 8, 2020
2 parents 21881ff + 7ab9258 commit 594f1ec
Showing 55 changed files with 1,093 additions and 697 deletions.
2 changes: 2 additions & 0 deletions .jenkins/pytorch/codegen-test.sh
@@ -38,6 +38,8 @@ mkdir -p "$OUT"/pyi/torch/_C
mkdir -p "$OUT"/pyi/torch/nn
python -m tools.pyi.gen_pyi \
--declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \
+--native-functions-path aten/src/ATen/native/native_functions.yaml \
+--deprecated-functions-path tools/autograd/deprecated.yaml \
--out "$OUT"/pyi

# autograd codegen (called by torch codegen but can run independently)
20 changes: 10 additions & 10 deletions android/test_app/app/build.gradle
@@ -60,20 +60,20 @@ android {
//}
flavorDimensions "model", "build", "activity"
productFlavors {
-mbq {
+mnet {
dimension "model"
-applicationIdSuffix ".mbq"
-buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"")
-addManifestPlaceholders([APP_NAME: "MBQ"])
-buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbq\"")
+applicationIdSuffix ".mnet"
+buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet.pt\"")
+addManifestPlaceholders([APP_NAME: "MNET"])
+buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
}
-mbvulkan {
+mnetVulkan {
dimension "model"
-applicationIdSuffix ".mbvulkan"
-buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2-vulkan.pt\"")
+applicationIdSuffix ".mnet_vulkan"
+buildConfigField("String", "MODULE_ASSET_NAME", "\"mnet_vulkan.pt\"")
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
-addManifestPlaceholders([APP_NAME: "MBQ"])
-buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mbvulkan\"")
+addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
+buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
}
}
resnet18 {
dimension "model"
4 changes: 2 additions & 2 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -941,8 +941,8 @@ Tensor new_empty_strided_batching_rule(
size.size(), ") must match dimensionality of strides (",
stride.size(), ")");
auto storage_size = native::storage_size_for(size, stride);
-for (auto& physical_stride : physical_strides) {
-physical_stride *= storage_size;
+for (int64_t idx = 0; idx < physical_strides.size(); ++idx) {
+physical_strides[idx] *= storage_size;
}

// physical_strides = [B1 * B2 * S, B2 * S, S] + strides
4 changes: 2 additions & 2 deletions aten/src/ATen/NamedTensorUtils.cpp
@@ -264,11 +264,11 @@ static std::vector<Dimname> compute_dot_product_outnames(
}
std::vector<Dimname> outnames(num_outnames, Dimname::wildcard());
int64_t index = 0;
-for (size_t j = 0; j < tensor_names.size(); ++j) {
+for (int64_t j = 0; j < tensor_names.size(); ++j) {
if (j == tensor_dotted_dim) continue;
outnames[index++] = tensor_names[j];
}
-for (size_t j = 0; j < other_names.size(); ++j) {
+for (int64_t j = 0; j < other_names.size(); ++j) {
if (j == other_dotted_dim) continue;
outnames[index++] = other_names[j];
}
4 changes: 2 additions & 2 deletions aten/src/ATen/TensorIterator.cpp
@@ -939,8 +939,8 @@ TensorIterator TensorIterator::reduce_op(Tensor& out1, Tensor& out2, const Tenso
}

void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) {
-for (auto& tensor: config.tensors_) {
-operands_.emplace_back(std::move(tensor));
+for (int i = 0; i < config.tensors_.size(); i++) {
+operands_.emplace_back(std::move(config.tensors_[i]));
}
num_outputs_ = config.num_outputs_;
}
4 changes: 2 additions & 2 deletions aten/src/ATen/TensorNames.cpp
@@ -61,10 +61,10 @@ TensorNames::TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end) {
}

TensorNames& TensorNames::unifyFromRightInplace(const TensorNames& other, const char* op_name) {
-size_t size_diff = std::labs(names_.size() - other.names_.size());
+int64_t size_diff = std::labs(names_.size() - other.names_.size());

if (names_.size() > other.names_.size()) {
-for (size_t idx = size_diff; idx < names_.size(); ++idx) {
+for (int64_t idx = size_diff; idx < names_.size(); ++idx) {
names_[idx] = names_[idx].unify(other.names_[idx - size_diff], op_name);
}
} else {
8 changes: 4 additions & 4 deletions aten/src/ATen/cuda/CUDASolver.cpp
@@ -46,14 +46,14 @@ void getrf<c10::complex<double>>(
TORCH_CUSOLVER_CHECK(cusolverDnZgetrf_bufferSize(
handle, m, n, reinterpret_cast<cuDoubleComplex*>(dA), ldda, &lwork));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
-void* buffer = allocator.allocate(sizeof(cuDoubleComplex) * lwork).get();
+auto dataPtr = allocator.allocate(sizeof(cuDoubleComplex) * lwork);
TORCH_CUSOLVER_CHECK(cusolverDnZgetrf(
handle,
m,
n,
reinterpret_cast<cuDoubleComplex*>(dA),
ldda,
-static_cast<cuDoubleComplex*>(buffer),
+static_cast<cuDoubleComplex*>(dataPtr.get()),
ipiv,
info));
}
@@ -71,14 +71,14 @@ void getrf<c10::complex<float>>(
TORCH_CUSOLVER_CHECK(cusolverDnCgetrf_bufferSize(
handle, m, n, reinterpret_cast<cuComplex*>(dA), ldda, &lwork));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
-void* buffer = allocator.allocate(sizeof(cuComplex) * lwork).get();
+auto dataPtr = allocator.allocate(sizeof(cuComplex) * lwork);
TORCH_CUSOLVER_CHECK(cusolverDnCgetrf(
handle,
m,
n,
reinterpret_cast<cuComplex*>(dA),
ldda,
-static_cast<cuComplex*>(buffer),
+static_cast<cuComplex*>(dataPtr.get()),
ipiv,
info));
}
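
The getrf change above fixes a workspace-lifetime issue: calling .get() on the temporary DataPtr returned by allocator.allocate() lets the buffer be released as soon as the temporary is destroyed, i.e. potentially before cuSOLVER reads it. A minimal sketch of the pattern, using a plain std::unique_ptr as a stand-in for c10::DataPtr and hypothetical helper names (fakeAllocate, useWorkspace), not the actual allocator API:

    // Sketch only; std::unique_ptr stands in for the owning DataPtr.
    #include <cstring>
    #include <memory>

    std::unique_ptr<char[]> fakeAllocate(std::size_t nbytes) {
      return std::unique_ptr<char[]>(new char[nbytes]);
    }

    void useWorkspace(char* workspace, std::size_t nbytes) {
      std::memset(workspace, 0, nbytes);  // stands in for the library call that writes into the workspace
    }

    int main() {
      // Buggy pattern (old code): the owner is a temporary, so the buffer is
      // released at the end of the full expression and "buffer" dangles.
      //   char* buffer = fakeAllocate(1024).get();
      //   useWorkspace(buffer, 1024);  // use-after-free
      //
      // Fixed pattern (new code): keep the owner alive in a local until the
      // call that consumes the workspace has finished.
      auto dataPtr = fakeAllocate(1024);
      useWorkspace(dataPtr.get(), 1024);
      return 0;
    }
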
8 changes: 4 additions & 4 deletions aten/src/ATen/native/Convolution.cpp
@@ -177,13 +177,13 @@ auto ConvParams::needs_64bit_indexing_no_split(const at::Tensor& input, const at
int64_t outsize = 1;
if (transposed) {
std::vector<int64_t> o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups);
-for (const auto& e: o) {
-outsize *= e;
+for (int64_t i = 1; i < o.size(); i++) {
+outsize *= o[i];
}
} else {
std::vector<int64_t> o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation);
-for (const auto& e: o) {
-outsize *= e;
+for (int64_t i = 1; i < o.size(); i++) {
+outsize *= o[i];
}
}
return outsize > int_max;
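
In the loops above the product now starts at index 1, so the batch dimension is excluded from the element count that is compared against int_max (presumably because the convolution can be split along the batch dimension, so only the per-sample size matters for this check). A standalone sketch of that check, with a hypothetical helper name rather than the actual ConvParams method:

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <vector>

    // True when one sample of the convolution output already has more than
    // INT_MAX elements, i.e. 32-bit indexing would overflow.
    bool needsInt64Indexing(const std::vector<int64_t>& output_sizes) {
      constexpr int64_t int_max = std::numeric_limits<int>::max();
      int64_t outsize = 1;
      for (std::size_t i = 1; i < output_sizes.size(); i++) {  // skip the batch dim at index 0
        outsize *= output_sizes[i];
      }
      return outsize > int_max;
    }

    int main() {
      // N=64, C=256, H=W=4096: a single sample has 256*4096*4096 > 2^31 - 1 elements.
      std::cout << needsInt64Indexing({64, 256, 4096, 4096}) << std::endl;  // prints 1
      return 0;
    }
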
5 changes: 4 additions & 1 deletion aten/src/ATen/native/ForeachOpsKernels.cpp
@@ -188,6 +188,9 @@ FOREACH_UNARY_OP(sinh);
FOREACH_UNARY_OP(round);
FOREACH_UNARY_OP(lgamma);
FOREACH_UNARY_OP(frac);
+FOREACH_UNARY_OP(trunc);
+FOREACH_UNARY_OP(reciprocal);
+FOREACH_UNARY_OP(sigmoid);

FOREACH_POINTWISE_OP_SCALAR(addcdiv);
FOREACH_POINTWISE_OP_SCALAR(addcmul);
@@ -201,7 +204,7 @@ std::vector<Tensor> foreach_tensor_##NAME##_slow(TensorList tensors1, TensorList
\
std::vector<Tensor> result; \
result.reserve(tensors1.size()); \
-for (size_t i = 0; i < tensors1.size(); i++) { \
+for (int i = 0; i < tensors1.size(); i++) { \
result.emplace_back(at::NAME(tensors1[i], tensors2[i])); \
} \
\
4 changes: 2 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -343,8 +343,8 @@ Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float
Tensor cos(const Tensor& self) { return unary_op_impl_float(self, cos_stub); }
Tensor& cos_(Tensor& self) { return unary_op_impl_(self, at::cos_out); }

-Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sinh_stub); }
-Tensor sinh(const Tensor& self) { return unary_op_impl(self, at::sinh_out); }
+Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, sinh_stub); }
+Tensor sinh(const Tensor& self) { return unary_op_impl_float(self, sinh_stub); }
Tensor& sinh_(Tensor& self) { return unary_op_impl_(self, at::sinh_out); }

Tensor& cosh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cosh_stub); }
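
Routing sinh through the *_float variants makes integral inputs promote to a floating-point result, matching cos above. A minimal libtorch sketch of the observable effect (assumes a program built and linked against libtorch; this snippet is illustrative and not part of the diff):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto x = torch::tensor({0, 1, 2, 3}, torch::kLong);  // integer input
      auto y = torch::sinh(x);                              // the *_float path promotes to the default float dtype
      std::cout << y.dtype() << std::endl;                  // expected to print a floating-point dtype
      return 0;
    }
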
47 changes: 45 additions & 2 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -535,6 +535,28 @@ void magmaCholeskySolve<float>(
AT_CUDA_CHECK(cudaGetLastError());
}

+template<>
+void magmaCholeskySolve<c10::complex<double>>(
+magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, c10::complex<double>* dA, magma_int_t ldda,
+c10::complex<double>* dB, magma_int_t lddb, magma_int_t* info) {
+MagmaStreamSyncGuard guard;
+magma_zpotrs_gpu(uplo, n, nrhs,
+reinterpret_cast<magmaDoubleComplex*>(dA), ldda,
+reinterpret_cast<magmaDoubleComplex*>(dB), lddb, info);
+AT_CUDA_CHECK(cudaGetLastError());
+}
+
+template<>
+void magmaCholeskySolve<c10::complex<float>>(
+magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, c10::complex<float>* dA, magma_int_t ldda,
+c10::complex<float>* dB, magma_int_t lddb, magma_int_t* info) {
+MagmaStreamSyncGuard guard;
+magma_cpotrs_gpu(uplo, n, nrhs,
+reinterpret_cast<magmaFloatComplex*>(dA), ldda,
+reinterpret_cast<magmaFloatComplex*>(dB), lddb, info);
+AT_CUDA_CHECK(cudaGetLastError());
+}
+
template<>
void magmaCholeskySolveBatched<double>(
magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
@@ -551,6 +573,26 @@ void magmaCholeskySolveBatched<float>(
AT_CUDA_CHECK(cudaGetLastError());
}

+template<>
+void magmaCholeskySolveBatched<c10::complex<double>>(
+magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, c10::complex<double>** dA_array, magma_int_t ldda,
+c10::complex<double>** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+info = magma_zpotrs_batched(uplo, n, nrhs,
+reinterpret_cast<magmaDoubleComplex**>(dA_array), ldda,
+reinterpret_cast<magmaDoubleComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
+AT_CUDA_CHECK(cudaGetLastError());
+}
+
+template<>
+void magmaCholeskySolveBatched<c10::complex<float>>(
+magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, c10::complex<float>** dA_array, magma_int_t ldda,
+c10::complex<float>** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+info = magma_cpotrs_batched(uplo, n, nrhs,
+reinterpret_cast<magmaFloatComplex**>(dA_array), ldda,
+reinterpret_cast<magmaFloatComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
+AT_CUDA_CHECK(cudaGetLastError());
+}
+
template<>
void magmaCholesky<double>(
magma_uplo_t uplo, magma_int_t n, double* dA,
@@ -1376,7 +1418,7 @@ Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upp
int64_t info = 0;
auto self_working_copy = cloneBatchedColumnMajor(self);
auto A_working_copy = cloneBatchedColumnMajor(A);
-AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "cholesky_solve_cuda", [&]{
+AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "cholesky_solve_cuda", [&]{
apply_cholesky_solve<scalar_t>(self_working_copy, A_working_copy, upper, info);
});
TORCH_CHECK(info == 0, "MAGMA cholesky_solve : invalid argument: ", -info);
@@ -1418,10 +1460,11 @@ AT_ERROR("cholesky: MAGMA library not found in "

MAGMAQueue magma_queue(self.get_device());

-constexpr int64_t batch_limit = 262140;
+int64_t batch_limit = self.is_complex() ? 65535 : 262140;
// Compute as many batches of 262140 as possible
// 262140 is the size of the largest batch of matrices that can be run without
// violating the maximum kernel configuration
+// For complex input the batch limit is 65535 (determined experimentally, see https://github.com/pytorch/pytorch/pull/47047#discussion_r516086923 for more information)
// The number of "mini"-batches is floor(batch_size / batch_limit)
// and these cover floor(batch_size / batch_limit) * batch_limit cholesky calls
int64_t mini_batches = batch_size / batch_limit, mini_idx;
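
The comments above describe the mini-batching scheme: batch_size Cholesky factorizations are processed as floor(batch_size / batch_limit) full chunks of batch_limit matrices, followed by one call for the remainder. A small standalone sketch of that loop structure (applyCholeskyBatched is a hypothetical stand-in for the batched MAGMA call, not the actual kernel):

    #include <cstdint>
    #include <iostream>

    // Stand-in for the batched MAGMA call; just reports which slice it would process.
    void applyCholeskyBatched(int64_t first, int64_t count) {
      std::cout << "batched call on matrices [" << first << ", " << first + count << ")\n";
    }

    int main() {
      const int64_t batch_size = 600000;
      const bool is_complex = false;
      const int64_t batch_limit = is_complex ? 65535 : 262140;

      // Full mini-batches of exactly batch_limit matrices each.
      const int64_t mini_batches = batch_size / batch_limit;
      int64_t mini_idx = 0;
      for (; mini_idx < mini_batches * batch_limit; mini_idx += batch_limit) {
        applyCholeskyBatched(mini_idx, batch_limit);
      }

      // One final call covers the remaining batch_size % batch_limit matrices.
      if (mini_idx < batch_size) {
        applyCholeskyBatched(mini_idx, batch_size - mini_idx);
      }
      return 0;
    }
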
