Update on "[Inductor] support vertical reduction in cpp"
cc soumith voznesenskym penguinwu anijain2305 EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
jgong5 committed Apr 2, 2023
2 parents 67f335a + 34e5202 commit e897e89
Showing 24 changed files with 1,638 additions and 1,385 deletions.
17 changes: 5 additions & 12 deletions aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -329,9 +329,8 @@ void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
});
}

-template <typename scalar_t, typename mask_t>
+template <typename scalar_t>
void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
-auto is_mask_bool = std::is_same<mask_t, bool>::value;
std::ptrdiff_t source_cntr = 0;
scalar_t* source_ptr = source.data_ptr<scalar_t>();
auto numel = source.numel();
@@ -342,10 +341,7 @@ void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
char* mask = data[1];
const int64_t mask_stride = strides[1];
for (const auto i : c10::irange(n)) {
-mask_t mask_value = *(mask_t*)(mask + mask_stride * i);
-if (!is_mask_bool) {
-TORCH_CHECK(mask_value <= static_cast<mask_t>(1), "Mask tensor can take 0 and 1 values only");
-}
+auto mask_value = *reinterpret_cast<bool*>(mask + mask_stride * i);
if (mask_value) {
TORCH_CHECK(source_cntr < numel, "Number of elements of source < number of ones in mask");
*(scalar_t*)(dst + dst_stride * i) = *(source_ptr);
@@ -358,19 +354,16 @@ void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
}

void masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
+TORCH_CHECK(iter.input_dtype() == ScalarType::Bool, "masked_scatter_ only supports boolean masks, "
+"but got mask with dtype ", iter.input_dtype());
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
ScalarType::Bool,
ScalarType::BFloat16,
ScalarType::Half,
iter.dtype(),
"masked_scatter",
[&] {
-auto mask_dtype = iter.input_dtype(0);
-if (mask_dtype == ScalarType::Bool) {
-cpu_masked_scatter_kernel<scalar_t, bool>(iter, source);
-} else {
-cpu_masked_scatter_kernel<scalar_t, unsigned char>(iter, source);
-}
+cpu_masked_scatter_kernel<scalar_t>(iter, source);
});
}

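The CPU kernel above (and the CUDA changes below) drop the `mask_t` template parameter: `masked_scatter_` now accepts only boolean masks, and the long-deprecated `torch.uint8` masks are rejected up front by the new `TORCH_CHECK`. A minimal sketch of the user-visible change, assuming the public ATen C++ API (illustrative, not part of this diff):

```cpp
#include <ATen/ATen.h>

int main() {
  auto self = at::zeros({5});
  auto source = at::arange(5, self.options());

  // Boolean masks remain supported: masked slots are filled from source
  // in order.
  auto mask = at::ones({5}, at::kBool);
  self.masked_scatter_(mask, source);

  // A uint8 mask, previously accepted with a deprecation warning, now
  // trips the new check: "masked_scatter_ only supports boolean masks".
  auto byte_mask = at::ones({5}, at::kByte);
  // self.masked_scatter_(byte_mask, source);  // would throw c10::Error
  return 0;
}
```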
9 changes: 3 additions & 6 deletions aten/src/ATen/native/cuda/IndexKernel.cpp
@@ -62,18 +62,15 @@ Tensor & masked_scatter__cuda(Tensor& self, const Tensor& mask, const Tensor& source) {
at::assert_no_internal_overlap(self);
TORCH_CHECK(
self.scalar_type() == source.scalar_type(),
"masked_scatter: expected self and source to have same dtypes but got",
"masked_scatter_: expected self and source to have same dtypes but got",
self.scalar_type(),
" and ",
source.scalar_type());
+TORCH_CHECK(mask.dtype() == ScalarType::Bool, "masked_scatter_ only supports boolean masks, "
+"but got mask with dtype ", mask.dtype());

c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_scatter_");

-if (b_mask->dtype() == ScalarType::Byte) {
-TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
-"please use a mask with dtype torch.bool instead.");
-}

if (self.numel() == 0) {
return self;
}
24 changes: 6 additions & 18 deletions aten/src/ATen/native/cuda/IndexKernel.cu
@@ -344,15 +344,15 @@ void take_kernel(

namespace {

-template <typename mask_t>
-__global__ void masked_scatter_size_check(int64_t *mask_exclusive_sum, mask_t *mask, int64_t srcSize) {
+__global__ void masked_scatter_size_check(int64_t *mask_exclusive_sum, bool *mask, int64_t srcSize) {
// Convert exclusive sum to inclusive sum
auto totalElements = *mask_exclusive_sum + *mask;
CUDA_KERNEL_ASSERT(totalElements <= srcSize);
}

-template <typename mask_t>
-void masked_scatter_cuda_impl(
+} // anonymous namespace
+
+void launch_masked_scatter_kernel(
const TensorBase &self, const TensorBase &mask,
const TensorBase &maskPrefixSum, const TensorBase &source) {
auto srcSize = source.numel();
@@ -361,7 +361,7 @@ void masked_scatter_cuda_impl(

// Use a prefix sum to determine the output locations of the masked elements
auto maskPrefixSum_data = maskPrefixSum.data_ptr<int64_t>();
-auto mask_data = mask_cont.data_ptr<mask_t>();
+auto mask_data = mask_cont.data_ptr<bool>();

at::cuda::cub::mask_exclusive_sum(
mask_data, maskPrefixSum_data, mask_numel);
@@ -395,7 +395,7 @@
[&]() {
auto source_ptr = source_contig.data_ptr<scalar_t>();
gpu_kernel(
-iter, [=] GPU_LAMBDA(scalar_t a, mask_t mask, int64_t maskPrefixSum) -> scalar_t {
+iter, [=] GPU_LAMBDA(scalar_t a, bool mask, int64_t maskPrefixSum) -> scalar_t {
if (mask) {
return source_ptr[maskPrefixSum];
}
Expand All @@ -405,18 +405,6 @@ void masked_scatter_cuda_impl(
});
}

-} // anonymous namespace
-
-void launch_masked_scatter_kernel(
-const TensorBase &self, const TensorBase &mask,
-const TensorBase &maskPrefixSum, const TensorBase &source) {
-if (mask.scalar_type() == kBool) {
-masked_scatter_cuda_impl<bool>(self, mask, maskPrefixSum, source);
-} else {
-masked_scatter_cuda_impl<uint8_t>(self, mask, maskPrefixSum, source);
-}
-}

template <typename scalar_t>
void flip_kernel_impl(TensorIterator& iter) {
if (!iter.can_use_32bit_indexing()) {
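The comment in the hunk above ("Use a prefix sum to determine the output locations of the masked elements") is the core of the CUDA path: an exclusive prefix sum over the mask gives every masked position its index into the densely packed `source`, and `masked_scatter_size_check` verifies that the inclusive sum at the last element does not exceed `source.numel()`. A serial CPU sketch of that idea (illustrative only, not the actual kernel):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Exclusive prefix sum: prefix[i] counts the ones strictly before i,
  // i.e. the index into the packed source for a masked position i.
  std::vector<bool> mask = {true, false, true, true, false};
  std::vector<int64_t> prefix(mask.size());
  int64_t running = 0;
  for (size_t i = 0; i < mask.size(); ++i) {
    prefix[i] = running;
    running += mask[i] ? 1 : 0;
  }

  // Size check: the total count of ones (inclusive sum at the end) must
  // not exceed the number of source elements.
  std::vector<float> source = {10.f, 20.f, 30.f};
  // running == 3 <= source.size(), so the scatter is in bounds.

  // Each masked position gathers source[prefix[i]], as in the GPU_LAMBDA.
  std::vector<float> out(mask.size(), 0.f);
  for (size_t i = 0; i < mask.size(); ++i) {
    if (mask[i]) out[i] = source[prefix[i]];
  }
  for (float v : out) std::cout << v << ' ';  // prints: 10 0 20 30 0
}
```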
12 changes: 8 additions & 4 deletions benchmarks/dynamo/common.py
@@ -1521,6 +1521,9 @@ def run_one_model(
f"{stats['graph_breaks']} graph breaks ({stats['unique_graph_breaks']} unique)"
)

+        if self.args.stats:
+            Stats.print_summary()


def help(fn):
return fn.__doc__
@@ -1757,23 +1760,24 @@ def get_example_inputs(self):
"--profiler_trace_name",
help="Overwrites exported trace name",
)

parser.add_argument(
"--diff-branch",
default=diff_branch_default,
help="delta current branch against given branch.",
)

parser.add_argument(
"--tag", default=None, help="Specify a tag to be included in csv files."
)

parser.add_argument(
"--explain",
action="store_true",
help="print some graph/op statistics during the run, similar to .explain()",
)

+    parser.add_argument(
+        "--stats",
+        action="store_true",
+        help="print graph counter stats",
+    )
parser.add_argument(
"--cold-start-latency",
"--cold_start_latency",
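The new `--stats` flag pairs with the `run_one_model` change above: when set, the run finishes by calling `Stats.print_summary()` to print the accumulated graph counter stats. As an assumed usage example (the exact entry point and companion flags depend on the benchmark suite), an invocation along the lines of `python benchmarks/dynamo/torchbench.py --performance --stats` would end with the counter summary.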
3 changes: 2 additions & 1 deletion c10/util/irange.h
@@ -3,6 +3,7 @@
#pragma once

#include <c10/util/Exception.h>
+#include <c10/util/TypeSafeSignMath.h>

#include <algorithm>
#include <iterator>
@@ -51,7 +52,7 @@ struct integer_iterator {
// end`. To handle `c10::irange(n)` where n < 0 (which should be
// empty), we just make `begin != end` fail whenever `end` is
// negative.
-return other.value < 0 || value == other.value;
+return is_negative(other.value) || value == other.value;
} else {
return value == other.value;
}
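For context on the one-line change above: `c10::irange(n)` must be an empty range when `n` is negative, but `other.value < 0` is tautologically false for unsigned value types and draws compiler warnings; `c10::is_negative` (from the newly included `TypeSafeSignMath.h`) expresses the same test sign-safely. A small sketch of the intended semantics (assuming this header's existing API):

```cpp
#include <c10/util/irange.h>
#include <iostream>

int main() {
  // A negative bound yields an empty range: begin compares equal to end.
  for (const auto i : c10::irange(-3)) {
    std::cout << i << '\n';  // never executes
  }
  // With an unsigned bound, `other.value < 0` would have been the
  // warning-prone comparison that is_negative now replaces.
  for (const auto i : c10::irange(3u)) {
    std::cout << i << ' ';  // prints: 0 1 2
  }
  return 0;
}
```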
30 changes: 22 additions & 8 deletions test/distributed/_spmd/test_tracing.py
@@ -558,17 +558,14 @@ def train_step(mod, opt, inp):
torch.manual_seed(1)
# FIXME(@mrshenli): gradients for bias is missing
mod = nn.Linear(10, 10, bias=True).cuda(rank)
-        # FIXME(@mrshenli): we have to enable foreach to get better perf
opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=True)
inp = torch.randn(2, 10).cuda(rank)

ddp_mod = DDP(deepcopy(mod), device_ids=[rank])
ddp_opt = torch.optim.SGD(ddp_mod.parameters(), lr=0.01, foreach=True)
self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step)

-    @skip_if_lt_x_gpu(2)
-    @with_comms
-    def test_adam(self):
+    def _test_adam(self, *, foreach: bool, fused: bool):
@compile()
def train_step(mod, opt, inp):
mod(inp).sum().backward()
@@ -580,16 +577,31 @@ def train_step(mod, opt, inp):
torch.manual_seed(0)
# FIXME(@mrshenli): gradients for bias is missing
mod = nn.Linear(10, 10, bias=False).cuda(rank)
-        # FIXME(@mrshenli): we have to enable foreach to get better perf
opt = torch.optim.Adam(
-            mod.parameters(), lr=0.01, foreach=True, capturable=True
+            mod.parameters(),
+            lr=0.01,
+            foreach=foreach,
+            fused=fused,
+            capturable=True,
)
inp = torch.randn(2, 10).cuda(rank)

ddp_mod = DDP(deepcopy(mod), device_ids=[rank])
-        ddp_opt = torch.optim.Adam(ddp_mod.parameters(), lr=0.01, foreach=True)
+        ddp_opt = torch.optim.Adam(
+            ddp_mod.parameters(), lr=0.01, foreach=foreach, fused=fused
+        )
self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step)

+    @skip_if_lt_x_gpu(2)
+    @with_comms
+    def test_adam_foreach(self):
+        self._test_adam(foreach=True, fused=False)
+
+    @skip_if_lt_x_gpu(2)
+    @with_comms
+    def test_adam_fused(self):
+        self._test_adam(foreach=False, fused=True)

@skip_if_lt_x_gpu(2)
@with_comms
def test_train_step_override(self):
@@ -678,7 +690,9 @@ def train_step(mod, opt, inp):
self.assertEqual(graph_optimization.call_count, 1)
gm = train_step.__dict__[COMPILED_OBJECT_KEY].gm
train_step(mod, opt, inp)
-        self.assertEqual(id(gm), id(train_step.__dict__[COMPILED_OBJECT_KEY].gm))
+        self.assertEqual(
+            id(gm), id(train_step.__dict__[COMPILED_OBJECT_KEY].gm)
+        )
self.assertEqual(graph_optimization.call_count, 1)


2 changes: 2 additions & 0 deletions test/dynamo/test_model_output.py
@@ -27,6 +27,8 @@ class TestHFPretrained(torch._dynamo.test_case.TestCase):
@maybe_skip
def test_pretrained(self):
def fn(a, tmp):
+            if hasattr(tmp, "somekey"):
+                a = a + 1
if tmp.return_dict:
return a + torch.ones(2) * tmp.max_length
return a
24 changes: 24 additions & 0 deletions test/dynamo/test_repros.py
@@ -1434,6 +1434,30 @@ def fn(args):

self.assertTrue(same(ref, res))

+    def test_nullcontext1(self):
+        @torch.compile(fullgraph=True, backend="eager")
+        def fn(x, ctx):
+            x = x.sin()
+            with ctx:
+                x = x.cos()
+            x = x.sin()
+            return x
+
+        y = torch.randn(10)
+        self.assertTrue(same(fn(y, contextlib.nullcontext()), y.sin().cos().sin()))
+
+    def test_nullcontext2(self):
+        @torch.compile(fullgraph=True, backend="eager")
+        def fn(x, ctx):
+            x = x.sin()
+            with ctx():
+                x = x.cos()
+            x = x.sin()
+            return x
+
+        y = torch.randn(10)
+        self.assertTrue(same(fn(y, contextlib.nullcontext), y.sin().cos().sin()))

# AssertionError: ABCMeta
@unittest.expectedFailure
def test_numpy_list(self):
