
Commit

Merge branch 'master' of https://github.com/pytorch/pytorch into sbhokare/batch-norm-running-stats
shubhambhokare1 committed Nov 4, 2020
2 parents c8d9893 + 6b3802a commit 9c671ce
Showing 13 changed files with 195 additions and 165 deletions.
6 changes: 3 additions & 3 deletions .jenkins/caffe2/common.sh
@@ -2,9 +2,9 @@ set -ex

LOCAL_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
ROOT_DIR=$(cd "$LOCAL_DIR"/../.. && pwd)
TEST_DIR="$ROOT_DIR/caffe2_tests"
gtest_reports_dir="${TEST_DIR}/cpp"
pytest_reports_dir="${TEST_DIR}/python"
TEST_DIR="$ROOT_DIR/test"
gtest_reports_dir="${TEST_DIR}/test-reports/cpp"
pytest_reports_dir="${TEST_DIR}/test-reports/python"

# Figure out which Python to use
PYTHON="$(which python)"
14 changes: 7 additions & 7 deletions aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -163,7 +163,7 @@ struct BinaryOpScalarFunctor {
opmath_t scalar) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int n = tl.numel_for_tensor[tensor_loc];

T* args[depth];
bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
@@ -183,7 +183,7 @@ struct BinaryOpScalarListFunctor {
Op op) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int n = tl.numel_for_tensor[tensor_loc];

T* args[depth];
bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
@@ -205,7 +205,7 @@ struct BinaryOpListAlphaFunctor {
opmath_t alpha) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int n = tl.numel_for_tensor[tensor_loc];

T* args[depth];
bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
@@ -254,7 +254,7 @@ struct UnaryOpFunctor {
Op op) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int n = tl.numel_for_tensor[tensor_loc];

T* args[depth];
bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
@@ -301,7 +301,7 @@ struct PointwiseOpScalarFunctor {
opmath_t scalar) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int n = tl.numel_for_tensor[tensor_loc];

T* args[depth];
bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
@@ -321,7 +321,7 @@ struct PointwiseOpScalarListFunctor {
Op op) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int n = tl.numel_for_tensor[tensor_loc];

T* args[depth];
bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
@@ -342,7 +342,7 @@ struct PointwiseOpListFunctor {
Op op) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
int n = tl.numel_for_tensor[tensor_loc];

T* args[depth];
bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
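The rename from sizes to numel_for_tensor makes it explicit that the field stores each tensor's element count rather than its shape. As a rough, hypothetical illustration of the user-facing API these functors back (not part of this diff), tensors in a foreach call may all have different shapes, which is why the kernel metadata needs one element count per tensor:

    import torch

    # Sketch only: torch._foreach_add applies the binary op across a whole tensor list.
    # Each tensor can have a different shape, so the CUDA metadata carries a per-tensor
    # element count (the field renamed to numel_for_tensor above).
    xs = [torch.ones(2, 3), torch.ones(5), torch.ones(4, 4)]   # numels 6, 5, 16
    ys = [torch.full_like(x, 2.0) for x in xs]
    out = torch._foreach_add(xs, ys)
    print([t.numel() for t in out])                            # [6, 5, 16]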
12 changes: 6 additions & 6 deletions aten/src/ATen/native/cuda/MultiTensorApply.cuh
@@ -31,15 +31,15 @@ static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
template<int n> struct TensorListMetadata
{
void* addresses[n][depth_to_max_tensors[n-1]];
int sizes[depth_to_max_tensors[n-1]];
int numel_for_tensor[depth_to_max_tensors[n-1]];
unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
int block_to_chunk[depth_to_max_blocks[n-1]];
};

template<typename scalar_vals_t, int n> struct TensorListScalarListMetadata
{
void* addresses[n][depth_to_max_tensors_scalarlist[n-1]];
int sizes[depth_to_max_tensors_scalarlist[n-1]];
int numel_for_tensor[depth_to_max_tensors_scalarlist[n-1]];
scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n-1]];
unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
int block_to_chunk[depth_to_max_blocks[n-1]];
@@ -74,7 +74,7 @@ void multi_tensor_apply(

tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t];

tensorListMeta.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
}
@@ -105,7 +105,7 @@ void multi_tensor_apply(
loc_tensor_info = 0;
}
else {
tensorListMeta.sizes[0] = tensorListMeta.sizes[loc_tensor_info-1];
tensorListMeta.numel_for_tensor[0] = tensorListMeta.numel_for_tensor[loc_tensor_info-1];
tensorListMeta.scalar_vals[0] = tensorListMeta.scalar_vals[loc_tensor_info-1];
for(int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1];
@@ -131,7 +131,7 @@ void multi_tensor_apply(
int loc_block_info = 0;
int loc_tensor_info = 0;
for(size_t t = 0; t < n_tensors; t++) {
tensorListMeta.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
}
@@ -162,7 +162,7 @@ void multi_tensor_apply(
loc_tensor_info = 0;
}
else {
tensorListMeta.sizes[0] = tensorListMeta.sizes[loc_tensor_info-1];
tensorListMeta.numel_for_tensor[0] = tensorListMeta.numel_for_tensor[loc_tensor_info-1];
for(int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1];
}
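multi_tensor_apply packs tensor pointers, each tensor's element count, and block-to-tensor / block-to-chunk maps into fixed-size metadata structs before launching the kernel. The following is a simplified Python sketch of that bookkeeping, an illustration under stated assumptions rather than the CUDA implementation (the real code also flushes a kernel launch whenever the fixed-size structs fill up):

    # Simplified sketch of the chunk bookkeeping done by multi_tensor_apply (illustration only).
    def build_chunk_metadata(numels, chunk_size):
        meta = {"numel_for_tensor": [], "block_to_tensor": [], "block_to_chunk": []}
        for tensor_loc, numel in enumerate(numels):
            meta["numel_for_tensor"].append(numel)
            n_chunks = (numel + chunk_size - 1) // chunk_size   # ceil division
            for chunk_idx in range(n_chunks):
                meta["block_to_tensor"].append(tensor_loc)      # which tensor this block works on
                meta["block_to_chunk"].append(chunk_idx)        # which chunk of that tensor
        return meta

    meta = build_chunk_metadata([6, 5, 16], chunk_size=4)
    # A block with index b would then read: n = meta["numel_for_tensor"][meta["block_to_tensor"][b]]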
2 changes: 1 addition & 1 deletion scripts/onnx/test.sh
@@ -23,7 +23,7 @@ do
done
set -- "${UNKNOWN[@]}" # leave UNKNOWN

pip install pytest scipy hypothesis
pip install pytest scipy hypothesis # these are all already satisfied in CI

if [[ $PARALLEL == 1 ]]; then
pip install pytest-xdist
24 changes: 24 additions & 0 deletions test/quantization/test_quantize_fx.py
@@ -1035,6 +1035,30 @@ def forward(self, x):
# quantize, should run with no errors
quantized = convert_fx(prepared_copy)

def test_dequantize(self):
r""" Test to make sure dequantize node are placed before
non-quantizable node
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.act = torch.nn.GELU()

def forward(self, x):
x = self.conv(x)
return self.act(x)

data = torch.rand(5, 1, 3, 3, dtype=torch.float)
for quant_type in self.static_quant_types:
node_list = [
ns.call_module(nnq.Conv2d),
ns.call_method("dequantize"),
ns.call_module(nn.GELU),
]
self.checkGraphModeFxOp(
M().eval(), (data,), quant_type, expected_node_list=node_list)

@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
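The new test_dequantize test asserts that convert_fx places a dequantize call between the quantized conv and the GELU, which FX graph mode quantization leaves in floating point. A hedged sketch of the same check through the public prepare_fx/convert_fx API (assuming the default fbgemm qconfig; the checkGraphModeFxOp helper used above wraps roughly these steps):

    import torch
    from torch.quantization import get_default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)
            self.act = torch.nn.GELU()

        def forward(self, x):
            return self.act(self.conv(x))

    model = M().eval()
    qconfig_dict = {"": get_default_qconfig("fbgemm")}   # assumed default static qconfig
    prepared = prepare_fx(model, qconfig_dict)           # insert observers
    prepared(torch.rand(5, 1, 3, 3))                     # calibrate with sample data
    quantized = convert_fx(prepared)
    print(quantized.graph)  # expect quantized conv -> dequantize -> gelu in the printed graph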
128 changes: 51 additions & 77 deletions test/test_foreach.py
@@ -703,56 +703,57 @@ def test_bin_op_scalar_with_different_tensor_dtypes(self, device):
#
# Ops with list
#
def test_add_list_error_cases(self, device):
tensors1 = []
tensors2 = []

# Empty lists
with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"):
torch._foreach_add(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"):
torch._foreach_add_(tensors1, tensors2)

# One empty list
tensors1.append(torch.tensor([1], device=device))
with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."):
torch._foreach_add(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."):
torch._foreach_add_(tensors1, tensors2)

# Lists have different amount of tensors
tensors2.append(torch.tensor([1], device=device))
tensors2.append(torch.tensor([1], device=device))
with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"):
torch._foreach_add(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"):
torch._foreach_add_(tensors1, tensors2)

# Different dtypes
tensors1 = [torch.zeros(10, 10, device=device, dtype=torch.float) for _ in range(10)]
tensors2 = [torch.ones(10, 10, device=device, dtype=torch.int) for _ in range(10)]

with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."):
torch._foreach_add(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."):
torch._foreach_add_(tensors1, tensors2)

# different devices
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
tensor1 = torch.zeros(10, 10, device="cuda:0")
tensor2 = torch.ones(10, 10, device="cuda:1")
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
torch._foreach_add([tensor1], [tensor2])
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
torch._foreach_add_([tensor1], [tensor2])

# Coresponding tensors with different sizes
tensors1 = [torch.zeros(10, 10, device=device) for _ in range(10)]
tensors2 = [torch.ones(11, 11, device=device) for _ in range(10)]
with self.assertRaisesRegex(RuntimeError, "Corresponding tensors in lists must have the same size"):
torch._foreach_add(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, r", got \[10, 10\] and \[11, 11\]"):
torch._foreach_add_(tensors1, tensors2)
def test_bin_op_list_error_cases(self, device):
for bin_op, bin_op_ in zip(self.foreach_bin_ops, self.foreach_bin_ops_):
tensors1 = []
tensors2 = []

# Empty lists
with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"):
bin_op(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"):
bin_op_(tensors1, tensors2)

# One empty list
tensors1.append(torch.tensor([1], device=device))
with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."):
bin_op(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."):
bin_op_(tensors1, tensors2)

# Lists have different amount of tensors
tensors2.append(torch.tensor([1], device=device))
tensors2.append(torch.tensor([1], device=device))
with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"):
bin_op(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"):
bin_op_(tensors1, tensors2)

# Different dtypes
tensors1 = [torch.zeros(10, 10, device=device, dtype=torch.float) for _ in range(10)]
tensors2 = [torch.ones(10, 10, device=device, dtype=torch.int) for _ in range(10)]

with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."):
bin_op(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, "All tensors in the tensor list must have the same dtype."):
bin_op_(tensors1, tensors2)

# different devices
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
tensor1 = torch.zeros(10, 10, device="cuda:0")
tensor2 = torch.ones(10, 10, device="cuda:1")
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
bin_op([tensor1], [tensor2])
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
bin_op_([tensor1], [tensor2])

# Corresponding tensors with different sizes
tensors1 = [torch.zeros(10, 10, device=device) for _ in range(10)]
tensors2 = [torch.ones(11, 11, device=device) for _ in range(10)]
with self.assertRaisesRegex(RuntimeError, "Corresponding tensors in lists must have the same size"):
bin_op(tensors1, tensors2)
with self.assertRaisesRegex(RuntimeError, r", got \[10, 10\] and \[11, 11\]"):
bin_op_(tensors1, tensors2)

@dtypes(*torch.testing.get_all_dtypes())
def test_add_list(self, device, dtype):
@@ -799,33 +800,6 @@ def test_div_list(self, device, dtype):
self.assertEqual(res, tensors1)
self.assertEqual(tensors1, res)

def test_bin_op_list_error_cases(self, device):
tensors1 = []
tensors2 = []

for bin_op in self.foreach_bin_ops + self.foreach_bin_ops_:
# Empty lists
with self.assertRaises(RuntimeError):
bin_op(tensors1, tensors2)

# One empty list
tensors1.append(torch.tensor([1], device=device))
with self.assertRaises(RuntimeError):
bin_op(tensors1, tensors2)

# Lists have different amount of tensors
tensors2.append(torch.tensor([1], device=device))
tensors2.append(torch.tensor([1], device=device))
with self.assertRaises(RuntimeError):
bin_op(tensors1, tensors2)

# Different dtypes
tensors1 = [torch.zeros(2, 2, device=device, dtype=torch.float) for _ in range(2)]
tensors2 = [torch.ones(2, 2, device=device, dtype=torch.int) for _ in range(2)]

with self.assertRaises(RuntimeError):
bin_op(tensors1, tensors2)

@dtypes(*torch.testing.get_all_dtypes())
def test_add_list_different_sizes(self, device, dtype):
tensors1 = [torch.zeros(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)]
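The overlapping test_add_list_error_cases and test_bin_op_list_error_cases are folded into a single test that loops over every foreach binary op and checks the exact error messages. As a hedged illustration of the behavior under test (a sketch only, run on CPU):

    import torch

    xs = [torch.zeros(2, 2)]
    ys = [torch.ones(2, 2), torch.ones(2, 2)]   # one tensor too many

    try:
        torch._foreach_add(xs, ys)
    except RuntimeError as err:
        print(err)   # "Tensor lists must have the same number of tensors, got 1 and 2"

    zs = [torch.ones(3, 3)]                     # same length, mismatched sizes
    try:
        torch._foreach_add_(xs, zs)
    except RuntimeError as err:
        print(err)   # per the test above, the message names the mismatched sizes, e.g. [2, 2] and [3, 3]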
5 changes: 4 additions & 1 deletion torch/_torch_docs.py
@@ -6566,20 +6566,23 @@ def merge_dicts(*dicts):

add_docstr(torch.randperm,
r"""
randperm(n, *, out=None, dtype=torch.int64, layout=torch.strided, device=None, requires_grad=False) -> LongTensor
randperm(n, \*, generator=None, out=None, dtype=torch.int64, layout=torch.strided, device=None, requires_grad=False,
pin_memory=False) -> LongTensor
Returns a random permutation of integers from ``0`` to ``n - 1``.
Args:
n (int): the upper bound (exclusive)
Keyword args:
{generator}
{out}
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: ``torch.int64``.
{layout}
{device}
{requires_grad}
{pin_memory}
Example::
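The updated randperm docstring now documents the generator and pin_memory keyword arguments alongside the existing ones. A small usage sketch (the pin_memory call assumes a CUDA-capable machine, so it is left commented out):

    import torch

    g = torch.Generator().manual_seed(0)
    perm = torch.randperm(5, generator=g)   # reproducible permutation of 0..4
    print(perm, perm.dtype)                 # dtype defaults to torch.int64

    # pin_memory=True is only meaningful with CUDA available (assumption):
    # pinned = torch.randperm(5, pin_memory=True)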
