#7516: Remove reshape on to_memory_config, make it a composite, and provide support for automatically using launch_op
eyonland authored and arakhmati committed May 10, 2024
1 parent ba2f81d commit 28712e7
Showing 20 changed files with 594 additions and 145 deletions.
@@ -41,6 +41,8 @@ def get_expected_times(functional_vit):
@pytest.mark.parametrize("sequence_size", [196]) ## padded from 197 to 224
@pytest.mark.parametrize("functional_vit", [ttnn_optimized_sharded_vit])
def test_performance_vit_encoder(device, use_program_cache, model_name, batch_size, sequence_size, functional_vit):
ttnn.dump_stack_trace_on_segfault()

config = transformers.ViTConfig.from_pretrained(model_name)
config.num_hidden_layers = 12
model = transformers.ViTForImageClassification.from_pretrained(
@@ -127,6 +129,8 @@ def test_performance_vit_encoder(device, use_program_cache, model_name, batch_si
def test_performance_vit_e2e(
device, use_program_cache, model_name, batch_size, image_size, sequence_size, functional_vit
):
ttnn.dump_stack_trace_on_segfault()

config = transformers.ViTConfig.from_pretrained(model_name)
config.num_hidden_layers = 12
model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config)
46 changes: 46 additions & 0 deletions tests/ttnn/unit_tests/test_async.py
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize("scalar", [3])
@pytest.mark.parametrize("size", [64])
def test_add_1D_tensor_and_scalar(device, scalar, size):
device.enable_async(True)

torch.manual_seed(0)

torch_input_tensor = torch.rand((size,), dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor + scalar

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = input_tensor + scalar
output_tensor = ttnn.to_torch(output_tensor, torch_rank=1)

assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.99988
assert output_tensor.shape == (size,)


@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [64])
def test_add_2D_tensors(device, h, w):
device.enable_async(True)

torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
output = ttnn.add(input_tensor_a, input_tensor_b)
output = ttnn.to_torch(output)

assert_with_pcc(torch_output_tensor, output, 0.9999)
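The two tests above exercise the element-wise add path with device-side async mode switched on. Below is a minimal sketch of the same pattern outside pytest; the explicit device.enable_async(False) at the end, and the claim that ttnn.to_torch synchronizes with the worker thread, are assumptions of this sketch rather than something the tests above spell out:

import torch
import ttnn

def run_async_add(device):
    # Push op dispatch onto the device's worker thread.
    device.enable_async(True)
    a = ttnn.from_torch(torch.rand((32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    b = ttnn.from_torch(torch.rand((32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    out = ttnn.add(a, b)         # dispatched asynchronously
    result = ttnn.to_torch(out)  # readback waits for the worker thread to finish
    device.enable_async(False)
    return result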
3 changes: 3 additions & 0 deletions tt_eager/tt_dnn/op_library/operation.hpp
@@ -85,6 +85,9 @@ struct function_traits<fn<TReturn, TArgs...> T::*> {
template <class FnPtr>
using last_arg_of_function_t = typename function_traits<FnPtr>::last_arg_t;

template <class FnPtr>
using return_arg_of_function_t = typename function_traits<FnPtr>::return_t;

template<typename, typename = std::void_t<>>
struct has_create_program : std::false_type {};

30 changes: 20 additions & 10 deletions tt_eager/tt_dnn/op_library/run_operation.cpp
@@ -25,7 +25,7 @@ inline bool any_tensor_on_multi_device(const Tensors& tensors) {
return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& tensor) { return tensor.storage_type() == StorageType::MULTI_DEVICE; });
}

static Device* get_device(const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors = {}) {
Device* get_device(const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors) {
for (auto& input_tensor : input_tensors) {
if (input_tensor.storage_type() == StorageType::DEVICE) {
return input_tensor.device();
@@ -657,12 +657,18 @@ void launch_op(
std::function<std::vector<Tensor>(const std::vector<Tensor>&, const std::vector<std::optional<const Tensor>>&)>&& op_func,
const std::vector<Tensor> input_tensors,
std::vector<Tensor>& output_tensors,
const std::vector<std::optional<const Tensor>> optional_input_tensors
const std::vector<std::optional<const Tensor>> optional_input_tensors,
bool enable_autoformat_device
) {
// Send host side op compile and run to the worker queue
// Assert to ensure that worker threads are specified.
ZoneScopedN("LaunchOp");
auto& workers = output_tensors.at(0).workers;
if (not enable_autoformat_device and workers.empty()) {
// Run on the host
output_tensors = op_func(input_tensors, optional_input_tensors);
return;
}
for (auto& output_tensor : output_tensors) {
TT_FATAL(output_tensor.workers.size(), "Worker threads must be specified for outputs populated by launch_op. This API can only be used for creating output tensors on device.");
TT_FATAL(output_tensor.workers == workers, "Worker threads must be consistent across all outputs populated by launch_op.");
@@ -815,7 +821,10 @@ void validate_workers_and_storage(const std::vector<Tensor>& inputs, const std::
}
}

std::vector<Device*> get_workers_for_op_output(const std::vector<Tensor>&& inputs, const std::vector<std::optional<const Tensor>>&& optional_inputs) {
std::vector<Device*> get_workers_for_op_output(
const std::vector<Tensor>&& inputs,
const std::vector<std::optional<const Tensor>>&& optional_inputs,
bool enable_autoformat_device) {
std::vector<Device*> workers_for_op = {};
// Infer output workers from inputs. For multi-device tensors the number
of workers used for the op (and assigned to the output) is the minimum
@@ -841,14 +850,15 @@ std::vector<Device*> get_workers_for_op_output(const std::vector<Tensor>&& input
}
}
}
validate_workers_and_storage(inputs, optional_inputs, workers_for_op);
// Workers not specified - inputs are on host and not multi-device.
// Use the default device from autoformat.
if (not workers_for_op.size()) {
TT_FATAL(AutoFormat::GetDefaultDevice(), "Default device must be specified using AutoFormat::SetDefaultDevice, if workers are not specified for inputs to op.");
workers_for_op = {AutoFormat::GetDefaultDevice()};
if (enable_autoformat_device) {
validate_workers_and_storage(inputs, optional_inputs, workers_for_op);
// Workers not specified - inputs are on host and not multi-device.
// Use the default device from autoformat.
if (not workers_for_op.size()) {
TT_FATAL(AutoFormat::GetDefaultDevice(), "Default device must be specified using AutoFormat::SetDefaultDevice, if workers are not specified for inputs to op.");
workers_for_op = {AutoFormat::GetDefaultDevice()};
}
}
return workers_for_op;
}

}
15 changes: 12 additions & 3 deletions tt_eager/tt_dnn/op_library/run_operation.hpp
@@ -43,6 +43,7 @@ auto generic_create_output_tensors(

namespace run_operation_state {
namespace detail {

struct RunOperationState {

RunOperationState() {}
@@ -96,6 +97,7 @@ inline const auto& get_composite_parent_names() {


namespace detail {

template<typename ReturnType, typename... Args>
struct CompositeOperation {

@@ -387,8 +389,8 @@ void launch_op(
std::function<std::vector<Tensor>(const Tensors&, const OptionalConstTensors&)>&& op_func,
const std::vector<Tensor> input_tensors,
std::vector<Tensor>& output_tensors,
const std::vector<std::optional<const Tensor>> optional_input_tensors = {}
);
const std::vector<std::optional<const Tensor>> optional_input_tensors = {},
bool enable_autoformat_device = true);

void launch_with_autoformat(
std::function<std::vector<Tensor>(const std::vector<Tensor>&, const std::vector<std::optional<const Tensor>>&)>&& op_func,
@@ -397,7 +399,14 @@ void launch_with_autoformat(
const std::vector<std::optional<const Tensor>> optional_input_tensors = {}
);

std::vector<Device*> get_workers_for_op_output(const std::vector<Tensor>&& inputs, const std::vector<std::optional<const Tensor>>&& optional_inputs = {});
std::vector<Device*> get_workers_for_op_output(
const std::vector<Tensor>&& inputs,
const std::vector<std::optional<const Tensor>>&& optional_inputs = {},
bool enable_autoformat_device = true);

namespace detail{
Device* get_device(const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors = {});
}

} //namespace operation

1 change: 1 addition & 0 deletions ttnn/cpp/pybind11/core.hpp
@@ -33,6 +33,7 @@ void py_module(py::module& module) {

module.def("get_memory_config", &ttnn::get_memory_config);
module.def("set_printoptions", &ttnn::set_printoptions, py::kw_only(), py::arg("profile"));
module.def("dump_stack_trace_on_segfault", &ttnn::core::dump_stack_trace_on_segfault);
}

} // namespace core
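This binding exposes the new helper to Python; the ViT performance tests at the top of this diff call it at the start of each test. A minimal standalone sketch of the intended use (the handler itself is added to ttnn/cpp/ttnn/core.hpp further down in this diff):

import ttnn

# Install the SIGSEGV handler once, before any ops run: on a segfault the C++ backtrace
# is printed to stderr and the process exits, instead of dying silently.
ttnn.dump_stack_trace_on_segfault()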
37 changes: 30 additions & 7 deletions ttnn/cpp/pybind11/operations/core.hpp
@@ -87,12 +87,35 @@ void py_module(py::module& module) {

module.def("deallocate", &ttnn::operations::core::deallocate, py::arg("tensor"), py::arg("force") = true);

module.def(
"to_memory_config",
&ttnn::operations::core::to_memory_config,
py::arg("tensor"),
py::arg("memory_config"),
py::arg("dtype") = std::nullopt);
bind_registered_operation(
module,
ttnn::to_memory_config,
R"doc(to_memory_config(tensor: ttnn.Tensor, memory_config: MemoryConfig, dtype: Optional[DataType] = None) -> ttnn.Tensor
Converts a tensor to the desired mem_config, used for converting tensors to sharded tensors or interleaved, and to convert DRAM to L1 and vice versa
Args:
* :attr:`tensor`: the ttnn.Tensor
* :attr:`memory_config`: the desired MemoryConfig
* :attr:`dtype`: the optional `ttnn` data type.
Example::
>>> device_id = 0
>>> device = ttnn.open_device(device_id=device_id)
>>> tensor = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor = ttnn.to_memory_config(tensor, memory_config)
)doc",
ttnn::pybind_overload_t{
[](const std::decay_t<decltype(ttnn::to_memory_config)> self,
const ttnn::Tensor& tensor,
const ttnn::MemoryConfig& memory_config,
const std::optional<ttnn::DataType>& dtype) -> ttnn::Tensor {
return self(tensor, memory_config, dtype);
},
py::arg("tensor"),
py::arg("memory_config"),
py::arg("dtype") = std::nullopt});

module.def(
"reallocate",
@@ -158,7 +181,7 @@ Deallocates device tensor and returns a reallocated tensor
py::arg("dtype") = std::nullopt,
py::arg("memory_config") = std::nullopt,
py::arg("device") = std::nullopt});

}

} // namespace core
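A slightly fuller sketch of the newly registered ttnn.to_memory_config than the docstring example above, moving an interleaved DRAM tensor into L1 and back. The ttnn.L1_MEMORY_CONFIG / ttnn.DRAM_MEMORY_CONFIG constants and the shape are illustrative assumptions, not part of the diff:

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# from_torch with a device argument places the tensor in interleaved DRAM by default.
tensor = ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

# Move the tensor into L1, run other ops on it here if desired, then move it back to DRAM.
tensor = ttnn.to_memory_config(tensor, ttnn.L1_MEMORY_CONFIG)
tensor = ttnn.to_memory_config(tensor, ttnn.DRAM_MEMORY_CONFIG)

ttnn.close_device(device)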
13 changes: 13 additions & 0 deletions ttnn/cpp/ttnn/core.hpp
@@ -10,6 +10,7 @@
#include "tt_eager/tensor/tensor_impl.hpp" // TTNN_TENSOR_PRINT_PROFILE
#include "tt_eager/tensor/types.hpp"
#include "ttnn/types.hpp"
#include <csignal>

namespace ttnn {
using tt::tt_metal::any_tensor_on_multi_device;
@@ -97,6 +98,18 @@ inline void set_printoptions(const std::string& profile) {
}).value();
}

inline void segfault_handler(int sig) {
std::cerr << tt::assert::backtrace_to_string() << std::endl;
exit(EXIT_FAILURE);
}

inline void dump_stack_trace_on_segfault() {
if (std::signal(SIGSEGV, segfault_handler) == SIG_ERR) {
std::cerr << "Error: cannot handle SIGSEGV" << std::endl;
exit(EXIT_FAILURE);
}
}

} // namespace core

using core::CONFIG;