From 051156a0b519ef8d1565d082bc57f0e4274bff0e Mon Sep 17 00:00:00 2001 From: hschoi Date: Fri, 3 May 2024 08:50:40 +0000 Subject: [PATCH] #6491: fix moreh logsoftmax --- .../tt_dnn/kernels/compute/moreh_common.hpp | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/tt_eager/tt_dnn/kernels/compute/moreh_common.hpp b/tt_eager/tt_dnn/kernels/compute/moreh_common.hpp index a3f2b0b2beb..ec304e1321d 100644 --- a/tt_eager/tt_dnn/kernels/compute/moreh_common.hpp +++ b/tt_eager/tt_dnn/kernels/compute/moreh_common.hpp @@ -68,13 +68,6 @@ ALWI void mul_tiles_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) { mul_tiles_init(icb0, icb1); } -ALWI void mask_tile_init_with_dt(uint32_t icb0 = 0, uint32_t icb1 = 1) { - #if defined FP32_DEST_ACC_EN - unpack_reconfig_data_format(icb0, icb1); - #endif - mask_tile_init(); -} - class ArgFetcher { private: int arg_idx = 0; @@ -180,7 +173,7 @@ ALWI void mul_tiles_and_mask_tile_to_cb( copy_tile_init_with_dt(maskcb); copy_tile(maskcb, mtile, dst_mask); - mask_tile_init_with_dt(dst0, dst_mask); + mask_tile_init(); mask_tile(dst0, dst_mask); tile_regs_commit(); @@ -214,7 +207,7 @@ ALWI void mul_tiles_log_to_cb( cb_wait_front(icb1, itile1 + 1); tile_regs_acquire(); - mul_tiles_init(); + mul_tiles_init_with_dt(icb0, icb1); mul_tiles(icb0, icb1, itile0, itile1, dst0); log_tile_init(); @@ -222,7 +215,7 @@ ALWI void mul_tiles_log_to_cb( tile_regs_commit(); tile_regs_wait(); - pack_tile(dst0, ocb); + pack_tile_with_dt(dst0, ocb); tile_regs_release(); if (pop0) @@ -286,6 +279,9 @@ ALWI void mul_tiles_bcast_rows_log_to_cb( cb_wait_front(icb1, itile1 + 1); tile_regs_acquire(); + #if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format(icb0, icb1); + #endif mul_bcast_rows_init_short(); mul_tiles_bcast_rows(icb0, icb1, itile0, itile1, dst0); @@ -294,7 +290,7 @@ ALWI void mul_tiles_bcast_rows_log_to_cb( tile_regs_commit(); tile_regs_wait(); - pack_tile(dst0, ocb); + pack_tile_with_dt(dst0, ocb); tile_regs_release(); if (pop0) @@ -358,6 +354,9 @@ ALWI void mul_tiles_bcast_cols_log_to_cb( cb_wait_front(icb1, itile1 + 1); tile_regs_acquire(); + #if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format(icb0, icb1); + #endif mul_bcast_cols_init_short(); mul_tiles_bcast_cols(icb0, icb1, itile0, itile1, dst0); @@ -366,7 +365,7 @@ ALWI void mul_tiles_bcast_cols_log_to_cb( tile_regs_commit(); tile_regs_wait(); - pack_tile(dst0, ocb); + pack_tile_with_dt(dst0, ocb); tile_regs_release(); if (pop0) @@ -446,7 +445,7 @@ ALWI void mask_tile_to_cb(uint32_t icb, uint32_t maskcb, uint32_t ocb, uint32_t copy_tile_init_with_dt(maskcb); copy_tile(maskcb, mtile, dst_mask); - mask_tile_init_with_dt(icb, maskcb); + mask_tile_init(); mask_tile(dst0, dst_mask); tile_regs_commit(); @@ -654,7 +653,7 @@ ALWI void rexp_tile_to_cb(uint32_t icb, uint32_t ocb, uint32_t itile = 0, uint32 tile_regs_commit(); tile_regs_wait(); - pack_tile(dst, ocb); + pack_tile_with_dt(dst, ocb); tile_regs_release(); if (pop) @@ -691,7 +690,7 @@ ALWI void exp_tile_and_mask_tile_to_cb( copy_tile_init_with_dt(maskcb); copy_tile(maskcb, mtile, dst_mask); - mask_tile_init_with_dt(dst, dst_mask); + mask_tile_init(); mask_tile(dst, dst_mask); tile_regs_commit(); @@ -722,7 +721,7 @@ ALWI void rexp_tile_and_mask_tile_to_cb( cb_wait_front(maskcb, mtile + 1); tile_regs_acquire(); - copy_tile_init(); + copy_tile_init_with_dt(icb); copy_tile(icb, itile, dst); if (pop) @@ -734,7 +733,7 @@ ALWI void rexp_tile_and_mask_tile_to_cb( exp_tile_init(); exp_tile(dst); - copy_tile_init(); + copy_tile_init_with_dt(maskcb); copy_tile(maskcb, mtile, dst_mask); mask_tile_init(); @@ -745,7 +744,7 @@ ALWI void rexp_tile_and_mask_tile_to_cb( cb_pop_front(maskcb, popm); tile_regs_wait(); - pack_tile(dst, ocb); + pack_tile_with_dt(dst, ocb); tile_regs_release(); cb_push_back(ocb, onetile);