Skip to content

Commit

Permalink
#6491: fix moreh logsoftmax
Browse files Browse the repository at this point in the history
  • Loading branch information
hschoi4448 authored and HariniMohan0102 committed May 7, 2024
1 parent db9930a commit 051156a
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions tt_eager/tt_dnn/kernels/compute/moreh_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,6 @@ ALWI void mul_tiles_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
mul_tiles_init(icb0, icb1);
}

// Prepare for a subsequent mask_tile() call.
//
// When the build defines FP32_DEST_ACC_EN, unpack_reconfig_data_format() is
// invoked with (icb0, icb1) before mask_tile_init(); otherwise only
// mask_tile_init() runs and both arguments are unused.
//
// NOTE(review): the parameter names suggest input circular-buffer ids, yet
// some call sites pass values named like destination-register indices
// (e.g. dst0, dst_mask) -- confirm which one unpack_reconfig_data_format()
// actually expects.
ALWI void mask_tile_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) {
#ifdef FP32_DEST_ACC_EN
    unpack_reconfig_data_format(icb0, icb1);
#endif
    mask_tile_init();
}

class ArgFetcher {
private:
int arg_idx = 0;
Expand Down Expand Up @@ -180,7 +173,7 @@ ALWI void mul_tiles_and_mask_tile_to_cb(
copy_tile_init_with_dt(maskcb);
copy_tile(maskcb, mtile, dst_mask);

mask_tile_init_with_dt(dst0, dst_mask);
mask_tile_init();
mask_tile(dst0, dst_mask);
tile_regs_commit();

Expand Down Expand Up @@ -214,15 +207,15 @@ ALWI void mul_tiles_log_to_cb(
cb_wait_front(icb1, itile1 + 1);

tile_regs_acquire();
mul_tiles_init();
mul_tiles_init_with_dt(icb0, icb1);
mul_tiles(icb0, icb1, itile0, itile1, dst0);

log_tile_init();
log_tile(dst0);
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, ocb);
pack_tile_with_dt(dst0, ocb);
tile_regs_release();

if (pop0)
Expand Down Expand Up @@ -286,6 +279,9 @@ ALWI void mul_tiles_bcast_rows_log_to_cb(
cb_wait_front(icb1, itile1 + 1);

tile_regs_acquire();
#if defined FP32_DEST_ACC_EN
unpack_reconfig_data_format(icb0, icb1);
#endif
mul_bcast_rows_init_short();
mul_tiles_bcast_rows(icb0, icb1, itile0, itile1, dst0);

Expand All @@ -294,7 +290,7 @@ ALWI void mul_tiles_bcast_rows_log_to_cb(
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, ocb);
pack_tile_with_dt(dst0, ocb);
tile_regs_release();

if (pop0)
Expand Down Expand Up @@ -358,6 +354,9 @@ ALWI void mul_tiles_bcast_cols_log_to_cb(
cb_wait_front(icb1, itile1 + 1);

tile_regs_acquire();
#if defined FP32_DEST_ACC_EN
unpack_reconfig_data_format(icb0, icb1);
#endif
mul_bcast_cols_init_short();
mul_tiles_bcast_cols(icb0, icb1, itile0, itile1, dst0);

Expand All @@ -366,7 +365,7 @@ ALWI void mul_tiles_bcast_cols_log_to_cb(
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, ocb);
pack_tile_with_dt(dst0, ocb);
tile_regs_release();

if (pop0)
Expand Down Expand Up @@ -446,7 +445,7 @@ ALWI void mask_tile_to_cb(uint32_t icb, uint32_t maskcb, uint32_t ocb, uint32_t
copy_tile_init_with_dt(maskcb);
copy_tile(maskcb, mtile, dst_mask);

mask_tile_init_with_dt(icb, maskcb);
mask_tile_init();
mask_tile(dst0, dst_mask);

tile_regs_commit();
Expand Down Expand Up @@ -654,7 +653,7 @@ ALWI void rexp_tile_to_cb(uint32_t icb, uint32_t ocb, uint32_t itile = 0, uint32
tile_regs_commit();

tile_regs_wait();
pack_tile(dst, ocb);
pack_tile_with_dt(dst, ocb);
tile_regs_release();

if (pop)
Expand Down Expand Up @@ -691,7 +690,7 @@ ALWI void exp_tile_and_mask_tile_to_cb(
copy_tile_init_with_dt(maskcb);
copy_tile(maskcb, mtile, dst_mask);

mask_tile_init_with_dt(dst, dst_mask);
mask_tile_init();
mask_tile(dst, dst_mask);
tile_regs_commit();

Expand Down Expand Up @@ -722,7 +721,7 @@ ALWI void rexp_tile_and_mask_tile_to_cb(
cb_wait_front(maskcb, mtile + 1);

tile_regs_acquire();
copy_tile_init();
copy_tile_init_with_dt(icb);
copy_tile(icb, itile, dst);

if (pop)
Expand All @@ -734,7 +733,7 @@ ALWI void rexp_tile_and_mask_tile_to_cb(
exp_tile_init();
exp_tile(dst);

copy_tile_init();
copy_tile_init_with_dt(maskcb);
copy_tile(maskcb, mtile, dst_mask);

mask_tile_init();
Expand All @@ -745,7 +744,7 @@ ALWI void rexp_tile_and_mask_tile_to_cb(
cb_pop_front(maskcb, popm);

tile_regs_wait();
pack_tile(dst, ocb);
pack_tile_with_dt(dst, ocb);
tile_regs_release();

cb_push_back(ocb, onetile);
Expand Down

0 comments on commit 051156a

Please sign in to comment.