Merge branch 'main' into export-D48927532
muchulee8 committed Sep 7, 2023
2 parents 9077584 + c458fa0 commit 1868687
Showing 42 changed files with 482 additions and 182 deletions.
2 changes: 2 additions & 0 deletions .devcontainer/scripts/install-dev-tools.sh
@@ -9,3 +9,5 @@ make setup_lint

# Add CMAKE_PREFIX_PATH to bashrc
echo 'export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}' >> ~/.bashrc
+# Add linker path so that cuda-related libraries can be found
+echo 'export LDFLAGS="-L${CONDA_PREFIX}/lib/ $LDFLAGS"' >> ~/.bashrc
41 changes: 19 additions & 22 deletions aten/src/ATen/native/TensorConversions.cpp
@@ -695,15 +695,15 @@ Tensor sparse_compressed_to_dense(

// Computes the strides for view_dtype output when the view dtype is
// smaller than the original dtype
-inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
+inline SymDimVector compute_strides_for_view_dtype_downsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
const int64_t ndim = old_strides.size();

TORCH_CHECK(
old_strides[ndim - 1] == 1,
"self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
" (different element sizes), but got ", old_strides[ndim - 1]);

-DimVector new_strides(ndim);
+SymDimVector new_strides(ndim);
for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
new_strides[dim_idx] = old_strides[dim_idx] * size_ratio;
}
@@ -713,14 +713,14 @@ inline DimVector compute_strides_for_view_dtype_downsize(IntArrayRef old_strides

// Computes the strides for view_dtype output when the view dtype is
// larger than the original dtype
-inline DimVector compute_strides_for_view_dtype_upsize(IntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
+inline SymDimVector compute_strides_for_view_dtype_upsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
const int64_t ndim = old_strides.size();
TORCH_CHECK(
old_strides[ndim - 1] == 1,
"self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
" (different element sizes), but got ", old_strides[ndim - 1]);

-DimVector new_strides(ndim);
+SymDimVector new_strides(ndim);
for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
TORCH_CHECK(
(old_strides[dim_idx] % size_ratio) == 0,
@@ -753,8 +753,7 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {
auto* impl = new_tensor.unsafeGetTensorImpl();

if (self_element_size == new_element_size) {
-impl->set_storage_offset(self.storage_offset());
-impl->set_sizes_and_strides(self.sizes(), self.strides());
+impl->set_sizes_and_strides(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());

} else if (self.dim() == 0) {
TORCH_CHECK(false,
@@ -766,47 +765,45 @@ Tensor view_dtype(const Tensor& self, ScalarType dtype) {

int64_t size_ratio = self_element_size / new_element_size;
auto new_strides = compute_strides_for_view_dtype_downsize(
-self.strides(), size_ratio, self.scalar_type(), dtype);
+self.sym_strides(), size_ratio, self.scalar_type(), dtype);

-auto old_sizes = self.sizes();
-DimVector new_sizes(self.dim());
+auto old_sizes = self.sym_sizes();
+SymDimVector new_sizes(self.dim());
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
new_sizes[self.dim() - 1] *= size_ratio;

-auto new_storage_offset = size_ratio * self.storage_offset();
+auto new_storage_offset = size_ratio * self.sym_storage_offset();

-impl->set_storage_offset(new_storage_offset);
-impl->set_sizes_and_strides(new_sizes, new_strides);
+impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);

} else {
// Upsizing element size

int64_t size_ratio = new_element_size / self_element_size;

TORCH_CHECK(
-(self.size(-1) % size_ratio) == 0,
+(self.sym_size(-1) % size_ratio) == 0,
"self.size(-1) must be divisible by ", size_ratio, " to view ",
self.scalar_type(), " as ", dtype, " (different element sizes), ",
"but got ", self.size(-1));
"but got ", self.sym_size(-1));

TORCH_CHECK(
-(self.storage_offset() % size_ratio) == 0,
+(self.sym_storage_offset() % size_ratio) == 0,
"self.storage_offset() must be divisible by ", size_ratio, " to view ",
self.scalar_type(), " as ", dtype, " (different element sizes), but got ",
-self.storage_offset());
+self.sym_storage_offset());

auto new_strides = compute_strides_for_view_dtype_upsize(
-self.strides(), size_ratio, self.scalar_type(), dtype);
+self.sym_strides(), size_ratio, self.scalar_type(), dtype);

-auto old_sizes = self.sizes();
-DimVector new_sizes(self.dim());
+auto old_sizes = self.sym_sizes();
+SymDimVector new_sizes(self.dim());
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
new_sizes[self.dim() - 1] /= size_ratio;

-auto new_storage_offset = self.storage_offset() / size_ratio;
+auto new_storage_offset = self.sym_storage_offset() / size_ratio;

-impl->set_storage_offset(new_storage_offset);
-impl->set_sizes_and_strides(new_sizes, new_strides);
+impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
}

return new_tensor;
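
The view_dtype changes above replace concrete sizes, strides, and storage offsets with their SymInt counterparts (SymDimVector, sym_sizes(), sym_strides(), sym_storage_offset()), so the same arithmetic also works when shapes are symbolic. As a rough illustration of what that arithmetic produces at the Python level, here is a minimal sketch, assuming a recent PyTorch build where Tensor.view(dtype) supports differing element sizes and a contiguous input (the shapes are illustrative only):

import torch

# Downsize: float32 (4 bytes) viewed as int8 (1 byte), size_ratio = 4.
# The last dimension and all non-last strides are multiplied by the ratio.
x = torch.randn(2, 3)                     # sizes (2, 3), strides (3, 1)
y = x.view(torch.int8)                    # sizes (2, 12), strides (12, 1)
assert y.shape == (2, 12)

# Upsize: int8 viewed as float32, size_ratio = 4.
# Requires stride(-1) == 1 and the last dimension divisible by the ratio.
z = torch.zeros(2, 12, dtype=torch.int8)
w = z.view(torch.float32)                 # sizes (2, 3), strides (3, 1)
assert w.shape == (2, 3)
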
2 changes: 2 additions & 0 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -723,6 +723,8 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
const c10::optional<at::Tensor>& scale_result,
Tensor& out, Tensor& amax) {
// Check sizes
+auto dprops = at::cuda::getCurrentDeviceProperties();
+TORCH_CHECK(dprops->major >= 9, "torch._scaled_mm is only supported on devices with compute capability >= 9.0)");
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
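
The two added lines gate torch._scaled_mm on hardware support: the FP8 GEMM path requires a device with compute capability 9.0 or newer. A caller-side sketch of the same check using the public torch.cuda API (the helper name below is hypothetical, not part of this diff):

import torch

def can_use_scaled_mm(device: int = 0) -> bool:
    # Mirrors the new TORCH_CHECK: scaled matmul needs compute capability >= 9.0.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability(device) >= (9, 0)

Callers can consult such a helper up front instead of relying on the runtime TORCH_CHECK to fail.
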
16 changes: 2 additions & 14 deletions benchmarks/dynamo/common.py
@@ -41,7 +41,7 @@
from torch._dynamo.utils import clone_inputs, graph_break_reasons
from torch._functorch.aot_autograd import set_model_name
from torch._inductor import config as inductor_config
-from torch._inductor.utils import fresh_inductor_cache
+from torch._inductor.utils import aot_inductor_launcher, fresh_inductor_cache
from torch._subclasses.fake_tensor import FakeTensorMode

from torch.utils import _pytree as pytree
@@ -1146,21 +1146,9 @@ def load(cls, model, example_inputs, eager_forward):
for node in output_node.args[0]
]

# Use a utility function for easier benchmarking
source = """
#include <torch/csrc/inductor/aot_runtime/model_container.h>
torch::aot_inductor::AOTInductorModelContainer model(1);
void run(
const std::vector<at::Tensor>& input_tensors,
std::vector<at::Tensor>& output_tensors) {
model.run(input_tensors, output_tensors, at::cuda::getCurrentCUDAStream(), nullptr);
}
"""
module = torch.utils.cpp_extension.load_inline(
name="aot_inductor",
-cpp_sources=[source],
+cpp_sources=[aot_inductor_launcher],
functions=["run"],
extra_ldflags=[so_path],
with_cuda=True,
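
The benchmark now builds its AOTInductor wrapper extension from the shared aot_inductor_launcher source in torch._inductor.utils instead of an inline string, but the loading mechanism is the same torch.utils.cpp_extension.load_inline call. A minimal sketch of that mechanism under stated assumptions (the wrapper body and the .so path below are placeholders, not the benchmark's actual values):

import torch
from torch.utils import cpp_extension

# Hypothetical wrapper source; in the benchmark this string now comes from
# torch._inductor.utils.aot_inductor_launcher.
wrapper_source = """
void run(
    const std::vector<at::Tensor>& input_tensors,
    std::vector<at::Tensor>& output_tensors) {
  // ... invoke the AOTInductor-compiled model here ...
}
"""

module = cpp_extension.load_inline(
    name="aot_inductor_example",
    cpp_sources=[wrapper_source],
    functions=["run"],                    # exposes `run` to Python
    extra_ldflags=["/path/to/model.so"],  # placeholder: compiled model library
    with_cuda=True,
)
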
5 changes: 5 additions & 0 deletions test/distributed/test_store.py
@@ -616,6 +616,11 @@ def test_tcp_store_timeout_doest_break_client(self):
time_diff = end - start
self.assertGreater(test_store_timeout.seconds * 10, time_diff)

+def test_tcp_store_url_with_libuv(self):
+url = self.create_tcp_url()
+gen0 = dist.rendezvous(url + "&rank=0&use_libuv=1")
+store0, rank0, size0 = next(gen0)
+self.assertTrue(store0.libuvBackend)

class DummyStore(dist.Store):
def __init__(self):
7 changes: 7 additions & 0 deletions test/dynamo/test_comptime.py
@@ -223,6 +223,13 @@ def _(ctx):
'obj_weakref': None
'guarded_class': None
}
+global '' BACKEND_MATCH
+{
+'guard_types': None,
+'code': None,
+'obj_weakref': None
+'guarded_class': None
+}
shape_env '' SHAPE_ENV
{
'guard_types': None,
39 changes: 38 additions & 1 deletion test/dynamo/test_export.py
@@ -2378,6 +2378,44 @@ def foo(x, y):
foo, (a, {"k": b}), constraints=[dynamic_dim(a, 0), dynamic_dim(b, 0)]
)

+def test_enforce_equalities(self):
+def bar(x, y):
+return torch.matmul(x, y)
+
+def specify_constraints(x, y):
+return [
+dynamic_dim(x, 0) == dynamic_dim(y, 0),
+dynamic_dim(x, 1) == dynamic_dim(x, 2),
+dynamic_dim(x, 2) == dynamic_dim(y, 1),
+dynamic_dim(y, 1) == dynamic_dim(y, 2),
+]
+
+x = torch.randn(10, 3, 3)
+y = torch.randn(10, 3, 4)
+with self.assertRaisesRegex(
+torch._dynamo.exc.UserError,
+".*y.*size.*1.* = 3 is not equal to .*y.*size.*2.* = 4",
+):
+torch._export.export(
+bar,
+(x, y),
+constraints=specify_constraints(x, y),
+)
+y = torch.randn(10, 3, 3)
+ebar = torch._export.export(
+bar,
+(x, y),
+constraints=specify_constraints(x, y),
+)
+self.assertEqual(
+[
+str(node.meta["val"].shape)
+for node in ebar.graph_module.graph.nodes
+if node.op == "placeholder"
+],
+["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
+)
+
@config.patch(
capture_dynamic_output_shape_ops=True,
specialize_int=True,
@@ -3755,7 +3793,6 @@ def forward(self, x):

self.assertTrue(torch.allclose(m(x), gm(x)))

-@unittest.expectedFailure
def test_predispatch_with_for_out_dtype_nested(self):
class M(torch.nn.Module):
def __init__(self, weight):
2 changes: 1 addition & 1 deletion test/dynamo/test_logging.py
@@ -102,7 +102,7 @@ def fn(x, y):
fn_opt(torch.ones(1000, 1000), 1)
self.assertGreater(len(records), 0)

-test_dynamo_debug = within_range_record_test(30, 50, dynamo=logging.DEBUG)
+test_dynamo_debug = within_range_record_test(30, 55, dynamo=logging.DEBUG)
test_dynamo_info = within_range_record_test(2, 10, dynamo=logging.INFO)

@make_logging_test(dynamo=logging.DEBUG)
