Skip to content

Commit

Permalink
#6491: support fp32_dest_acc_en for moreh_softmax forward
Browse files Browse the repository at this point in the history
  • Loading branch information
hschoi4448 authored and tt-aho committed May 4, 2024
1 parent e517e56 commit 6a41973
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 62 deletions.
73 changes: 56 additions & 17 deletions tt_eager/tt_dnn/kernels/compute/moreh_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,36 @@ ALWI void REL() { release_dst(tt::DstMode::Half); }

namespace ckernel {

// Pack the tile held in DST register `ifrom_dst` into circular buffer `icb`.
// When the kernel is built with FP32 destination accumulation, the packer's
// output format is first reconfigured to match the target CB's data format.
ALWI void pack_tile_with_dt(uint32_t ifrom_dst, uint32_t icb) {
#ifdef FP32_DEST_ACC_EN
    // Packer-thread-only reconfigure of the output data format for `icb`.
    PACK((pack_reconfig_data_format(icb)));
#endif
    pack_tile(ifrom_dst, icb);
}

// Prepare a tile copy from circular buffer `icb` into DST.
// Under FP32 destination accumulation, the unpacker's SrcA data format is
// reconfigured to the CB's format before the standard copy init.
ALWI void copy_tile_init_with_dt(uint32_t icb) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format_srca(icb);
#endif
    copy_tile_init();
}

// Initialize elementwise tile addition on input CBs `icb0` and `icb1`.
// Under FP32 destination accumulation, the unpacker data formats for both
// operand CBs are reconfigured before the standard add init.
ALWI void add_tiles_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1)
{
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    add_tiles_init(icb0, icb1);
}

// Initialize elementwise tile multiplication on input CBs `icb0` and `icb1`.
// Under FP32 destination accumulation, the unpacker data formats for both
// operand CBs are reconfigured before the standard mul init.
ALWI void mul_tiles_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1)
{
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    mul_tiles_init(icb0, icb1);
}

class ArgFetcher {
private:
int arg_idx = 0;
Expand Down Expand Up @@ -58,12 +88,12 @@ ALWI void mul_tiles_to_cb(
cb_wait_front(icb1, itile1 + 1);

tile_regs_acquire();
mul_tiles_init();
mul_tiles_init_with_dt(icb0, icb1);
mul_tiles(icb0, icb1, itile0, itile1, dst0);
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, ocb);
pack_tile_with_dt(dst0, ocb);
tile_regs_release();

if (pop0)
Expand Down Expand Up @@ -206,12 +236,15 @@ ALWI void mul_tiles_bcast_rows_to_cb(
cb_wait_front(icb1, itile1 + 1);

tile_regs_acquire();
#if defined FP32_DEST_ACC_EN
unpack_reconfig_data_format(icb0, icb1);
#endif
mul_bcast_rows_init_short();
mul_tiles_bcast_rows(icb0, icb1, itile0, itile1, dst0);
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, ocb);
pack_tile_with_dt(dst0, ocb);
tile_regs_release();

if (pop0)
Expand Down Expand Up @@ -275,12 +308,15 @@ ALWI void mul_tiles_bcast_cols_to_cb(
cb_wait_front(icb1, itile1 + 1);

tile_regs_acquire();
#if defined FP32_DEST_ACC_EN
unpack_reconfig_data_format(icb0, icb1);
#endif
mul_bcast_cols_init_short();
mul_tiles_bcast_cols(icb0, icb1, itile0, itile1, dst0);
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, ocb);
pack_tile_with_dt(dst0, ocb);
tile_regs_release();

if (pop0)
Expand Down Expand Up @@ -335,12 +371,12 @@ ALWI void copy_tile_to_cb(uint32_t icb, uint32_t ocb, uint32_t itile = 0, uint32
cb_wait_front(icb, itile + 1);

tile_regs_acquire();
copy_tile_init();
copy_tile_init_with_dt(icb);
copy_tile(icb, itile, dst0);
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, ocb);
pack_tile_with_dt(dst0, ocb);
tile_regs_release();

if (pop)
Expand All @@ -364,12 +400,12 @@ ALWI void add_tiles_to_cb(
cb_wait_front(icb1, itile1 + 1);

tile_regs_acquire();
add_tiles_init();
add_tiles_init_with_dt(icb0, icb1);
add_tiles(icb0, icb1, itile0, itile1, dst0);
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, ocb);
pack_tile_with_dt(dst0, ocb);
tile_regs_release();

if (pop0)
Expand Down Expand Up @@ -562,15 +598,15 @@ ALWI void exp_tile_to_cb(uint32_t icb, uint32_t ocb, uint32_t itile = 0, uint32_
cb_wait_front(icb, itile + 1);

tile_regs_acquire();
copy_tile_init();
copy_tile_init_with_dt(icb);
copy_tile(icb, itile, dst);

exp_tile_init();
exp_tile(dst);
tile_regs_commit();

tile_regs_wait();
pack_tile(dst, ocb);
pack_tile_with_dt(dst, ocb);
tile_regs_release();

if (pop)
Expand All @@ -585,7 +621,7 @@ ALWI void rexp_tile_to_cb(uint32_t icb, uint32_t ocb, uint32_t itile = 0, uint32
cb_wait_front(icb, itile + 1);

tile_regs_acquire();
copy_tile_init();
copy_tile_init_with_dt(icb);
copy_tile(icb, itile, dst);

negative_tile_init();
Expand Down Expand Up @@ -621,7 +657,7 @@ ALWI void exp_tile_and_mask_tile_to_cb(
cb_wait_front(maskcb, mtile + 1);

tile_regs_acquire();
copy_tile_init();
copy_tile_init_with_dt(icb);
copy_tile(icb, itile, dst);

if (pop)
Expand All @@ -630,7 +666,7 @@ ALWI void exp_tile_and_mask_tile_to_cb(
exp_tile_init();
exp_tile(dst);

copy_tile_init();
copy_tile_init_with_dt(maskcb);
copy_tile(maskcb, mtile, dst_mask);

mask_tile_init();
Expand All @@ -641,7 +677,7 @@ ALWI void exp_tile_and_mask_tile_to_cb(
cb_pop_front(maskcb, popm);

tile_regs_wait();
pack_tile(dst, ocb);
pack_tile_with_dt(dst, ocb);
tile_regs_release();

cb_push_back(ocb, onetile);
Expand Down Expand Up @@ -701,15 +737,15 @@ ALWI void recip_tile_to_cb(uint32_t icb, uint32_t ocb, uint32_t itile = 0, uint3
cb_wait_front(icb, itile + 1);

tile_regs_acquire();
copy_tile_init();
copy_tile_init_with_dt(icb);
copy_tile(icb, itile, dst0);

recip_tile_init();
recip_tile(dst0);
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, ocb);
pack_tile_with_dt(dst0, ocb);
tile_regs_release();

if (pop)
Expand All @@ -733,6 +769,9 @@ ALWI void reduce_tile_and_recip_tile_to_cb(
cb_wait_front(icb1, onetile);

tile_regs_acquire();
#if defined FP32_DEST_ACC_EN
unpack_reconfig_data_format(icb0, icb1);
#endif
reduce_init_delta<false>(reduce_op, dim);
for (uint32_t x = 0; x < size; ++x) {
cb_wait_front(icb0, x + 1); // must be a cumulative wait for correctness
Expand All @@ -752,7 +791,7 @@ ALWI void reduce_tile_and_recip_tile_to_cb(
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, ocb);
pack_tile_with_dt(dst0, ocb);
tile_regs_release();

cb_push_back(ocb, onetile);
Expand Down
2 changes: 2 additions & 0 deletions tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ struct CircularBufferArg {
CircularBufferArg(uint32_t buffer_index, uint32_t num_tiles) : buffer_index(buffer_index), num_tiles(num_tiles) {
data_format = tt::DataFormat::Invalid;
}
// Construct a CB argument with an explicit data format, overriding the
// two-arg constructor's default of tt::DataFormat::Invalid (which presumably
// means "use the op's default format" — confirm against CreateCircularBuffer).
CircularBufferArg(uint32_t buffer_index, uint32_t num_tiles, tt::DataFormat data_format) : buffer_index(buffer_index), num_tiles(num_tiles), data_format(data_format) {
}
};

[[maybe_unused]] std::vector<CBHandle> CreateCircularBuffer(
Expand Down
33 changes: 21 additions & 12 deletions tt_eager/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ operation::ProgramWithCallbacks MorehSoftmax::create_program(const std::vector<T

switch (parallelization_strategy) {
case MorehSoftmaxOpParallelizationStrategy::SMALL_W:
return {moreh_softmax_w_small(input, output, this->core_range, this->op)};
return {moreh_softmax_w_small(input, output, this->core_range, this->op, this->compute_kernel_config)};
case MorehSoftmaxOpParallelizationStrategy::SMALL_H:
return {moreh_softmax_h_small(input, output, this->core_range, this->op)};
return {moreh_softmax_h_small(input, output, this->core_range, this->op, this->compute_kernel_config)};
case MorehSoftmaxOpParallelizationStrategy::LARGE_W:
return {moreh_softmax_w_large(input, output, this->core_range, this->op)};
return {moreh_softmax_w_large(input, output, this->core_range, this->op, this->compute_kernel_config)};
case MorehSoftmaxOpParallelizationStrategy::LARGE_H:
return {moreh_softmax_h_large(input, output, this->core_range, this->op)};
return {moreh_softmax_h_large(input, output, this->core_range, this->op, this->compute_kernel_config)};
case MorehSoftmaxOpParallelizationStrategy::LARGE_C:
return {moreh_softmax_c_large(input, output, this->dim, this->core_range, this->op)};
return {moreh_softmax_c_large(input, output, this->dim, this->core_range, this->op, this->compute_kernel_config)};
case MorehSoftmaxOpParallelizationStrategy::NONE:
default: break;
}

return {moreh_softmax_h_large(input, output, this->core_range, this->op)};
return {moreh_softmax_h_large(input, output, this->core_range, this->op, this->compute_kernel_config)};
}

MorehSoftmaxOpParallelizationStrategy MorehSoftmax::get_parallelization_strategy(
Expand Down Expand Up @@ -138,19 +138,22 @@ Tensor moreh_softmax(
uint32_t dim,
std::optional<Tensor> output_tensor,
const MorehSoftmaxOpParallelizationStrategy strategy,
const MemoryConfig &output_mem_config) {
const MemoryConfig &output_mem_config,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {

auto device = input_tensor.device();
auto grid_coord = device->compute_with_storage_grid_size();
const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1});

auto kernel_config_val = init_device_compute_kernel_config(device->arch(), compute_kernel_config);
output_tensor = operation::run(
MorehSoftmax{
.dim = dim,
.core_range = all_cores,
.op = MorehSoftmaxOp::SOFTMAX,
.strategy = strategy,
.output_mem_config = output_mem_config},
.output_mem_config = output_mem_config,
.compute_kernel_config = kernel_config_val},
{input_tensor},
{},
{output_tensor}).at(0);
Expand All @@ -163,19 +166,22 @@ Tensor moreh_softmin(
uint32_t dim,
std::optional<Tensor> output_tensor,
const MorehSoftmaxOpParallelizationStrategy strategy,
const MemoryConfig &output_mem_config) {
const MemoryConfig &output_mem_config,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {

auto device = input_tensor.device();
auto grid_coord = device->compute_with_storage_grid_size();
const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1});

auto kernel_config_val = init_device_compute_kernel_config(device->arch(), compute_kernel_config);
output_tensor = operation::run(
MorehSoftmax{
.dim = dim,
.core_range = all_cores,
.op = MorehSoftmaxOp::SOFTMIN,
.strategy = strategy,
.output_mem_config = output_mem_config},
.output_mem_config = output_mem_config,
.compute_kernel_config = kernel_config_val},
{input_tensor},
{},
{output_tensor}).at(0);
Expand All @@ -188,19 +194,22 @@ Tensor moreh_logsoftmax(
uint32_t dim,
std::optional<Tensor> output_tensor,
const MorehSoftmaxOpParallelizationStrategy strategy,
const MemoryConfig &output_mem_config) {
const MemoryConfig &output_mem_config,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {

auto device = input_tensor.device();
auto grid_coord = device->compute_with_storage_grid_size();
const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1});

auto kernel_config_val = init_device_compute_kernel_config(device->arch(), compute_kernel_config);
output_tensor = operation::run(
MorehSoftmax{
.dim = dim,
.core_range = all_cores,
.op = MorehSoftmaxOp::LOGSOFTMAX,
.strategy = strategy,
.output_mem_config = output_mem_config},
.output_mem_config = output_mem_config,
.compute_kernel_config = kernel_config_val},
{input_tensor},
{},
{output_tensor}).at(0);
Expand Down
25 changes: 15 additions & 10 deletions tt_eager/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "tt_dnn/op_library/operation.hpp"
#include "tt_eager/tensor/tensor.hpp"
#include "tt_dnn/op_library/compute_kernel_config.hpp"
#include <optional>

namespace tt {
Expand Down Expand Up @@ -35,31 +36,32 @@ bool is_moreh_softmax_w_small_available(const Tensor &tensor);
bool is_moreh_softmax_h_small_available(const Tensor &tensor);

operation::ProgramWithCallbacks moreh_softmax_w_small(
const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op);
const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const DeviceComputeKernelConfig compute_kernel_config);
operation::ProgramWithCallbacks moreh_softmax_w_large(
const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op);
const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const DeviceComputeKernelConfig compute_kernel_config);
operation::ProgramWithCallbacks moreh_softmax_h_small(
const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op);
const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const DeviceComputeKernelConfig compute_kernel_config);
operation::ProgramWithCallbacks moreh_softmax_h_large(
const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op);
const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const DeviceComputeKernelConfig compute_kernel_config);
operation::ProgramWithCallbacks moreh_softmax_c_large(
const Tensor &input, const Tensor &output, uint32_t dim, const CoreRange core_range, const MorehSoftmaxOp op);
const Tensor &input, const Tensor &output, uint32_t dim, const CoreRange core_range, const MorehSoftmaxOp op, const DeviceComputeKernelConfig compute_kernel_config);

struct MorehSoftmax {
const uint32_t dim;
const CoreRange core_range; // unused for now
const MorehSoftmaxOp op;
const MorehSoftmaxOpParallelizationStrategy strategy;
const MemoryConfig output_mem_config;
const DeviceComputeKernelConfig compute_kernel_config;

void validate_with_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
std::vector<Shape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
MorehSoftmaxOpParallelizationStrategy get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const;
static constexpr auto attribute_names = std::make_tuple("dim", "op", "strategy", "output_mem_config");
static constexpr auto attribute_names = std::make_tuple("dim", "op", "strategy", "output_mem_config", "compute_kernel_config");
const auto attribute_values() const {
return std::make_tuple(std::cref(this->dim), std::cref(this->op), std::cref(this->strategy), std::cref(this->output_mem_config));
return std::make_tuple(std::cref(this->dim), std::cref(this->op), std::cref(this->strategy), std::cref(this->output_mem_config), std::cref(this->compute_kernel_config));
}
};

Expand All @@ -69,21 +71,24 @@ Tensor moreh_softmax(
uint32_t dim,
std::optional<Tensor> output_tensor = std::nullopt,
const MorehSoftmaxOpParallelizationStrategy strategy = MorehSoftmaxOpParallelizationStrategy::NONE,
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);

Tensor moreh_softmin(
const Tensor &input_tensor,
uint32_t dim,
std::optional<Tensor> output_tensor = std::nullopt,
const MorehSoftmaxOpParallelizationStrategy strategy = MorehSoftmaxOpParallelizationStrategy::NONE,
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);

Tensor moreh_logsoftmax(
const Tensor &input_tensor,
uint32_t dim,
std::optional<Tensor> output_tensor = std::nullopt,
const MorehSoftmaxOpParallelizationStrategy strategy = MorehSoftmaxOpParallelizationStrategy::NONE,
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);

} // namespace primary
} // namespace operations
Expand Down
Loading

0 comments on commit 6a41973

Please sign in to comment.