
Commit

Update on "[inductor] convert layout of conv weight ahead of time for inference"


This PR handles inference. Will do a similar thing for training later.

Some manual testing results show this can improve inference perf by 2-3% (an absolute improvement, not a relative one):
- convmixer: 4.285x -> 4.309x
- resnet50: 2.170x -> 2.203x

The PR is built upon freezing: without freezing, the weight input for a conv node may not be a parameter directly but rather the output of precision-converting ops, so it is much easier to implement this PR after freezing.
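
A minimal eager-mode sketch of the idea (not the actual inductor pass; the toy model and the `channels_last` choice are illustrative assumptions): once the conv weight is effectively a constant, its layout can be converted a single time up front instead of on every inference call.

```
import torch
import torch.nn as nn

# Toy model standing in for a frozen inference graph (hypothetical example).
model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU()).eval()

with torch.no_grad():
    for mod in model.modules():
        if isinstance(mod, nn.Conv2d):
            # Convert the (now constant) weight once, ahead of time, so the
            # conv kernel does not have to re-layout it on every forward call.
            mod.weight.data = mod.weight.data.to(memory_format=torch.channels_last)

x = torch.randn(8, 3, 32, 32).to(memory_format=torch.channels_last)
with torch.no_grad():
    out = model(x)
print(out.shape)
```

Freezing is what makes the weight a plain constant in the graph rather than the output of other ops, which is the property this conversion relies on.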


Commands
```
TORCHINDUCTOR_FREEZING=1 python benchmarks/dynamo/timm_models.py --backend inductor --amp --performance --only convmixer_768_32 --inference  
```


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78

[ghstack-poisoned]
shunting314 committed Jun 28, 2023
2 parents 24b118b + 434729d commit b8c4556
Showing 48 changed files with 786 additions and 138 deletions.
2 changes: 1 addition & 1 deletion .ci/pytorch/build.sh
@@ -169,7 +169,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then
export USE_CUDA=0
export USE_ASAN=1
export USE_MKLDNN=0
export UBSAN_FLAGS="-fno-sanitize-recover=all"
export UBSAN_FLAGS="-fno-sanitize-recover=all;-fno-sanitize=float-divide-by-zero;-fno-sanitize=float-cast-overflow"
unset USE_LLVM
fi

2 changes: 1 addition & 1 deletion .ci/pytorch/test.sh
@@ -187,7 +187,7 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then
(cd test && python -c "import torch; print(torch.__version__, torch.version.git_version)")
echo "The next four invocations are expected to crash; if they don't that means ASAN/UBSAN is misconfigured"
(cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_csrc_asan(3)")
(cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_csrc_ubsan(0)")
#(cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_csrc_ubsan(0)")
(cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_vptr_ubsan()")
(cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_aten_asan(3)")
fi
1 change: 1 addition & 0 deletions benchmarks/dynamo/torchbench.py
@@ -211,6 +211,7 @@ def setup_torchbench_cwd():
}

FORCE_AMP_FOR_FP16_BF16_MODELS = {
"DALLE2_pytorch",
"doctr_det_predictor",
"doctr_reco_predictor",
"Super_SloMo",
5 changes: 1 addition & 4 deletions test/ao/sparsity/test_data_sparsifier.py
@@ -4,8 +4,7 @@
import logging
import torch
from torch.nn.utils.parametrize import is_parametrized
import unittest
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ASAN
from torch.testing._internal.common_utils import TestCase

from typing import Tuple
from torch import nn
@@ -511,7 +510,6 @@ def __init__(self):


class TestQuantizationUtils(TestCase):
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN due to address sanitization")
def test_ptq_sparsify_first(self):
"""The expectation is post_training_sparse_quantize function
1. Takes in a model
@@ -551,7 +549,6 @@ def test_ptq_sparsify_first(self):
assert abs(sl_emb1 - 0.80) <= 0.05 # +- 5% leeway
assert abs(sl_embbag1 - 0.80) <= 0.05 # +- 5% leeway

@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN due to address sanitization")
def test_ptq_quantize_first(self):
"""The expectation is post_training_sparse_quantize function
1. Takes in a model
153 changes: 153 additions & 0 deletions test/cpp/c10d/ProcessGroupNCCLTest.cpp
@@ -165,6 +165,29 @@ class NCCLTest : public NCCLTestBase {
}
}

at::Tensor to_sparse_row_indices_format(at::Tensor& tensor) {
// Get the indices of all non-zero elements in the dense tensor
// Get the unique row indices of the non-zero elements
auto row_indices = std::get<0>(
at::_unique(tensor.nonzero().select(/*dim=*/1, /*index=*/0)));
at::Tensor sparse_values = tensor.index_select(
/*dim=*/0, row_indices); // get the values at the non-zero indices
return at::sparse_coo_tensor(
row_indices.unsqueeze(0), sparse_values, tensor.sizes())
.to(tensor.device());
}

// Launches value initialization for every sparse tensor
void valueInitializationForSparse() {
at::cuda::OptionalCUDAGuard deviceGuard;
for (const auto i : c10::irange(numDevices_)) {
deviceGuard.set_index(i);
tensors_[i].fill_(pg_->getRank() * numDevices_ + i + 1);
// Convert the dense tensor to a sparse tensor in COO row format
tensors_[i] = to_sparse_row_indices_format(tensors_[i]);
}
}

const int numDevices_;
int worldSize_;
std::vector<at::Tensor> tensors_;
@@ -196,6 +219,21 @@ class AllreduceNCCLTest : public NCCLTest {
}
};

class SparseAllreduceNCCLTest : public NCCLTest {
public:
SparseAllreduceNCCLTest(const std::string& path, int worldSize, int inputDim)
: NCCLTest(path, worldSize, kBackendDefaultTimeout, inputDim) {}

c10::intrusive_ptr<c10d::Work> run() {
// For the duration of this function, make THC use our streams
c10::cuda::CUDAMultiStreamGuard guard(streams_);
launchDeviceSleep();
valueInitializationForSparse();
auto results = pg_->allreduce_sparse(tensors_);
return results;
}
};

class BroadcastNCCLTest : public NCCLTest {
public:
BroadcastNCCLTest(const std::string& path, int worldSize)
@@ -361,6 +399,108 @@ void testAllreduce(const std::string& path, int rank, int size) {
}
}

void testSparseAllreduce(const std::string& path, int rank, int size) {
const int inputDim = 3;
auto test = SparseAllreduceNCCLTest(path, size, inputDim);
test.initialize(rank, size);
auto work = test.run();
// Wait for work to finish
test.wait(work);

const auto input_tensors = test.getTensors();

// validate the work output is same as tensor
auto output_tensor = work->result();
// Validation
int totalNumGPUs = test.numDevices() * size;
// Add one since we are seeding with an additional 1 to prevent empty tensors
totalNumGPUs++;
const auto expected = (totalNumGPUs * (totalNumGPUs - 1)) / 2;
for (const auto i : c10::irange(input_tensors.size())) {
const auto& tensor = input_tensors[i];

// validate the tensor is sparse
EXPECT_EQ(tensor.is_sparse(), true);

auto indices = tensor._indices();
auto values = tensor._values();

// validate indices are expected size
auto sizes = indices.sizes();
EXPECT_EQ(sizes.size(), 2);
if (sizes[0] == 1) {
// row indices
EXPECT_EQ(sizes[1], inputDim);
} else if (sizes[0] == 2) {
// coordinate indices
EXPECT_EQ(sizes[1], inputDim * inputDim);
}

// validate all tensor values are expected value
const auto* const data = values.data_ptr<float>();
for (const auto k : c10::irange(values.numel())) {
EXPECT_EQ(data[k], expected)
<< "Allreduce outputs do not match expected outputs";
}

// expect the input and output tensors should be the same
auto input_dense = tensor.to_dense();
auto output_dense = output_tensor[i].to(input_dense.device()).to_dense();
EXPECT_TRUE(input_dense.allclose(output_dense));
}
}

void testSparseAllreduceLarge(const std::string& path, int rank, int size) {
const int inputDim = 2500;
auto test = SparseAllreduceNCCLTest(path, size, inputDim);
test.initialize(rank, size);
auto work = test.run();
// Wait for work to finish
test.wait(work);

const auto input_tensors = test.getTensors();

// validate the work output is same as tensor
auto output_tensor = work->result();
// Validation
int totalNumGPUs = test.numDevices() * size;
// Add one since we are seeding with an additional 1 to prevent empty tensors
totalNumGPUs++;
const auto expected = (totalNumGPUs * (totalNumGPUs - 1)) / 2;
for (const auto i : c10::irange(input_tensors.size())) {
const auto& tensor = input_tensors[i];

// validate the tensor is sparse
EXPECT_EQ(tensor.is_sparse(), true);

auto indices = tensor._indices();
auto values = tensor._values();

// validate indices are expected size
auto sizes = indices.sizes();
EXPECT_EQ(sizes.size(), 2);
if (sizes[0] == 1) {
// row indices
EXPECT_EQ(sizes[1], inputDim);
} else if (sizes[0] == 2) {
// coordinate indices
EXPECT_EQ(sizes[1], inputDim * inputDim);
}

// validate all tensor values are expected value
const auto* const data = values.data_ptr<float>();
for (const auto k : c10::irange(values.numel())) {
EXPECT_EQ(data[k], expected)
<< "Allreduce outputs do not match expected outputs";
}

// expect the input and output tensors should be the same
auto input_dense = tensor.to_dense();
auto output_dense = output_tensor[i].to(input_dense.device()).to_dense();
EXPECT_TRUE(input_dense.allclose(output_dense));
}
}

void testBroadcast(const std::string& path, int rank, int size) {
auto test = BroadcastNCCLTest(path, size);
test.initialize(rank, size);
@@ -731,3 +871,16 @@ TEST_F(ProcessGroupNCCLTest, testBackendName) {
std::string(c10d::NCCL_BACKEND_NAME));
}
}

#ifdef IS_NCCL_EXP
TEST_F(ProcessGroupNCCLTest, testSparseAllreduce) {
if (skipTest()) {
return;
}
{
TemporaryFile file;
testSparseAllreduce(file.path, rank_, size_);
testSparseAllreduceLarge(file.path, rank_, size_);
}
}
#endif
93 changes: 93 additions & 0 deletions test/distributed/test_c10d_nccl.py
@@ -305,6 +305,30 @@ def broadcast(xs, rootRank, rootTensor):
for tensor in xs:
self.assertEqual(tensor, expected_tensor)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
def test_sparse_allreduce_ops(self):
store = c10d.FileStore(self.file_name, self.world_size)
pg = self._create_process_group_nccl(store, self.opts())

indices = torch.tensor([[0, 1]])
values = torch.tensor([[1, 2, 0], [4, 0, 6]])
sparse_tensor = torch.sparse_coo_tensor(indices, values, size=(2, 3)).to(self.rank)

# sparse allreduce call is wrapped in a try catch since the c10d API is only available in the nccl experimental branch
try:
work = pg.allreduce([sparse_tensor])
work.wait()

# work.result() returns a list of size 1, with the allreduce output as a dense tensor
a = torch.tensor([[2, 4, 0], [8, 0, 12]]).to(self.rank)
self.assertEqual(work.result()[0], a)
except RuntimeError as e:
if "allreduce_sparse is only available in the NCCL experimental branch." in str(e):
pass
else:
# Rethrow the exception if it's a different error
raise

@requires_nccl()
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@@ -3049,6 +3073,75 @@ def test_new_group_local_sync_duplicated_pg(self):
self._test_new_group_local_sync_duplicate_pg(backend="nccl")


class SparseCollective(MultiProcessTestCase):
@property
def world_size(self):
return 1

def setUp(self):
super().setUp()
# NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
# that use NCCL_BLOCKING_WAIT will test it as expected.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
# self.num_gpus = torch.cuda.device_count()
self._spawn_processes()

def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass

class ToyModel(nn.Module):
def __init__(self, rank, vocab_size, embedding_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim, sparse=True).to(rank)
self.linear = nn.Linear(embedding_dim, 1).to(rank)

def forward(self, inputs):
embedded = self.embedding(inputs)
# embedded shape: (batch_size, sequence_length, embedding_dim)
flattened = torch.mean(embedded, dim=1)
# flattened shape: (batch_size, embedding_dim)
output = self.linear(flattened)
# output shape: (batch_size, 1)
return output

@requires_nccl()
@skip_if_lt_x_gpu(1)
def test_ddp_set_sparse_metadata(self):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)

vocab_size = 5

model = SparseCollective.ToyModel(self.rank, vocab_size=vocab_size, embedding_dim=10)
ddp_model = DistributedDataParallel(model)
inputs = torch.tensor([[1, 0, 0], [0, 0, 0], [0, 0, 0]]).to(self.rank)
# set sparse metadata on the DDP model
indices = torch.Tensor(list(range(vocab_size)))
ddp_model._set_sparse_metadata({"embedding.weight" : indices})
# forward pass
try:
output = ddp_model(inputs)
loss = output.sum()

# backward pass
loss.backward()
self.assertTrue(ddp_model.module.embedding.weight.grad.indices, indices)
except RuntimeError as e:
if "allreduce_sparse is only available in the NCCL experimental branch." in str(e):
pass
else:
# Rethrow the exception if it's a different error
raise


if __name__ == "__main__":
assert (
2 changes: 1 addition & 1 deletion test/dynamo/test_export.py
@@ -2877,7 +2877,7 @@ def f_branch_return_non_tensor(x):
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"HigherOrderOperator can't return non-tensor scalar output",
"HigherOrderOperator body's output must consist of tensors only",
):
torch._dynamo.export(
f_branch_return_non_tensor,
22 changes: 12 additions & 10 deletions test/dynamo/test_higher_order_ops.py
@@ -699,24 +699,26 @@ def f(x):
)

def test_fallback_on_nested_tuple_output(self):
# We can likely support this in the future, I just don't want to deal
# with it right now
counters.clear()
cnt = CompileCounter()

backend = EagerAndRecordGraphs()
cnt = CompileCounterWithBackend(backend)

@torch.compile(backend=cnt)
def f(x):
return wrap(lambda x: ((x.sin(), x.cos()),), x)
((a, b),) = wrap(lambda x: ((x.sin(), x.cos()),), x)
return a + b

x = torch.randn(2, 3)
result = f(x)

self.assertEqual(result, ((x.sin(), x.cos()),))
self.assertEqual(cnt.frame_count, 0)
self.assertEqual(
dict(counters["graph_break"]),
{"HigherOrderOperator body's output must consist of tensors only": 1},
)
self.assertEqual(result, x.sin() + x.cos())
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(len(backend.graphs), 1)
wrap_node = find_first_node(backend.graphs[0], wrap)
self.assertTrue(len(wrap_node.args), 1)
body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
self.assertEqual(op_count(body_function), 2)

def test_fallback_on_output_with_dict(self):
# We can likely support this in the future, I just don't want to deal