From 85b71fd700957be36b8fc047757fab4ba06241c2 Mon Sep 17 00:00:00 2001 From: Anoop Kapoor Date: Wed, 15 Oct 2025 14:37:40 -0700 Subject: [PATCH 1/2] @FIR-999 - Create SOFT_MAX for tsavorite-backend for GGML --- ggml-tsi-kernel | 2 +- ggml/include/ggml-tsavorite.h | 5 +- ggml/src/ggml-tsavorite/ggml-tsavorite.cpp | 178 ++++++++++++++++----- 3 files changed, 139 insertions(+), 46 deletions(-) diff --git a/ggml-tsi-kernel b/ggml-tsi-kernel index 7dd3227e1f8b1..ab6aecb365ebf 160000 --- a/ggml-tsi-kernel +++ b/ggml-tsi-kernel @@ -1 +1 @@ -Subproject commit 7dd3227e1f8b16d58245ea433e26048736e5d6f0 +Subproject commit ab6aecb365ebf7ee8cd707d44ef4af6f10dc18af diff --git a/ggml/include/ggml-tsavorite.h b/ggml/include/ggml-tsavorite.h index eb2e27d19b3c6..e455c2e6784de 100644 --- a/ggml/include/ggml-tsavorite.h +++ b/ggml/include/ggml-tsavorite.h @@ -140,6 +140,8 @@ enum ggml_tsavorite_kernel_type { GGML_TSAVORITE_KERNEL_TYPE_GEGLU_ERF, GGML_TSAVORITE_KERNEL_TYPE_GEGLU_QUICK, + GGML_TSAVORITE_KERNEL_TYPE_SOFT_MAX, + GGML_TSAVORITE_KERNEL_TYPE_COUNT }; @@ -156,7 +158,7 @@ typedef struct tensor_log_ { uint32_t leaf2_len; uint32_t node_len; enum ggml_tsavorite_tensor_data_type data_type; - enum ggml_tsavorite_kernel_type kernel_type; + enum ggml_op kernel_type; uint64_t num_of_op; FILE *log_file; const ggml_tensor *tensor; @@ -185,6 +187,7 @@ extern void _mlir_ciface_txe_sin_host(void *a, void *res); extern void _mlir_ciface_txe_sigmoid_host(void *a, void *res); extern void _mlir_ciface_txe_silu_host(void *a, void *res); extern void _mlir_ciface_txe_swiglu_host(void *a, void *b, void *res); +extern void _mlir_ciface_txe_soft_max_host(void *a, void *b, void *res, void *buf); extern void _mlir_ciface_txe_rms_norm_host(void *a, void *res, void *buf); /* diff --git a/ggml/src/ggml-tsavorite/ggml-tsavorite.cpp b/ggml/src/ggml-tsavorite/ggml-tsavorite.cpp index 430e1894c5015..8e5e575f69ac5 100644 --- a/ggml/src/ggml-tsavorite/ggml-tsavorite.cpp +++ 
b/ggml/src/ggml-tsavorite/ggml-tsavorite.cpp @@ -73,6 +73,7 @@ struct _txe_device_t { }; struct _txe_compute_pipeline_state_t { + void (*_mlir_fptr_3_input[DATA_TYPE_MAX_INDEX])(void *, void *, void *, void *); void (*_mlir_fptr_2_input[DATA_TYPE_MAX_INDEX])(void *, void *, void *); void (*_mlir_fptr_1_input[DATA_TYPE_MAX_INDEX])(void *, void *); std::string kernel_name; @@ -256,8 +257,8 @@ void ggml_tsi_log_tensor_data(tensor_log log_data) { fprintf(log_data.log_file, "\n\n"); fprintf(log_data.log_file, "#############################################################\n"); fprintf(log_data.log_file, - "Tensor Number %ld and Type %d \n leaf1 len %d, leaf2 len %d, Node len %d\n", - log_data.num_of_op, log_data.kernel_type, log_data.leaf1_len, log_data.leaf2_len, + "Tensor Number %ld and Type %s \n leaf1 len %d, leaf2 len %d, Node len %d\n", + log_data.num_of_op, ggml_op_name(log_data.kernel_type), log_data.leaf1_len, log_data.leaf2_len, log_data.node_len); fprintf(log_data.log_file, "############################################################\n"); fprintf(log_data.log_file, "\n\n"); @@ -485,6 +486,13 @@ static txe_compute_pipeline_state_s tsi_kernel_setup(enum ggml_tsavorite_kernel_ flag = true; break; } + case GGML_TSAVORITE_KERNEL_TYPE_SOFT_MAX: + { + kernel_pipeline->_mlir_fptr_3_input[DATA_TYPE_F32_INDEX] = &_mlir_ciface_txe_soft_max_host; + kernel_pipeline->kernel_name = "TXE_SOFTMAX"; + flag = true; + break; + } default: break; } @@ -634,6 +642,7 @@ static struct ggml_backend_tsavorite_context *ggml_tsavorite_init(ggml_backend_d GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SILU, true); GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_RMS_NORM, true); GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SWIGLU, true); + GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SOFT_MAX, true); } GGML_TSAVORITE_LOG_INFO("End %s\n", __func__); @@ -746,6 +755,7 @@ static bool ggml_tsavorite_supports_op(const struct ggml_backend_tsavorite_devic case GGML_OP_SQR: case 
GGML_OP_SIN: case GGML_OP_RMS_NORM: + case GGML_OP_SOFT_MAX: break; case GGML_OP_GLU: { @@ -927,6 +937,15 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, enum ggml_tsavorite_input_tensors_count num_of_input_tensors; tensor_log log_data; + MemRefDescriptor* buf = create_mlir_buf(96); + + if (!buf) { + GGML_TSAVORITE_LOG_ERROR("tsi_alloc failied for creating memory for buf \n"); + return GGML_STATUS_ABORTED; + } + buf->offset = 0; + buf->data = buf->base = (void *)(buf+1); + for (int i = 0; i < cgraph->n_nodes; i++) { int32_t kernel_sub_type=-1; #if defined(GGML_PERF) || defined(GGML_PERF_RELEASE) || defined(GGML_PERF_DETAIL) @@ -982,6 +1001,10 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, kernel_type = GGML_TSAVORITE_KERNEL_TYPE_RMS_NORM; num_of_input_tensors = TSAVORITE_UNARY_INPUT_TENSORS; break; + case GGML_OP_SOFT_MAX: + kernel_type = GGML_TSAVORITE_KERNEL_TYPE_SOFT_MAX; + num_of_input_tensors = TSAVORITE_TWO_INPUT_TENSORS; + break; case GGML_OP_GLU: kernel_type = tsi_glu_kernel_type(node); if (!src1) @@ -1023,7 +1046,8 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, } if (!ctx->kernels[kernel_type].pipeline || - (!ctx->kernels[kernel_type].pipeline->_mlir_fptr_2_input[kernel_sub_type] && + (!ctx->kernels[kernel_type].pipeline->_mlir_fptr_3_input[kernel_sub_type] && + !ctx->kernels[kernel_type].pipeline->_mlir_fptr_2_input[kernel_sub_type] && !ctx->kernels[kernel_type].pipeline->_mlir_fptr_1_input[kernel_sub_type])) { GGML_TSAVORITE_LOG_ERROR("Kernel Type %d, not supported \n", kernel_type); return GGML_STATUS_ABORTED; @@ -1091,7 +1115,7 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, log_data.node_len = num_elem_node; log_data.log_file = tsi_op_log_file; log_data.num_of_op = num_of_op; - log_data.kernel_type = kernel_type; + log_data.kernel_type = node->op; log_data.data_type = GGML_TSAVORITE_TENSOR_HEADER; 
ggml_tsi_log_tensor_data(log_data); @@ -1108,36 +1132,108 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, ggml_tensor *dst = node; const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS - - for (int ir = 0; ir < nr; ++ir) { - const int64_t i03 = ir / (ne02 * ne01); - const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; - const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float *dst_ptr = (float *)((char *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); - float *src0_ptr = (float *)((char *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); - float *src1_ptr = (float *)((char *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); - - // The following below code operates exclusively on Rank 0 - // (i.e., the first dimension) for all blob-related processing. - - for (int64_t r = 0; r < nr0; ++r) { - srcP0->shape[0] = ne10; - srcP1->shape[0] = ne10; - nodeP->shape[0] = ne10; - srcP1->data = srcP1->base = (void *)(src1_ptr); - srcP0->data = srcP0->base = (void *)(src0_ptr + r * ne10); - nodeP->data = nodeP->base = (void *)(dst_ptr + r * ne10); - // kernel call - ctx->kernels[kernel_type].pipeline->_mlir_fptr_2_input[kernel_sub_type](srcP0, srcP1, nodeP); - ++device->stats.op_run_count[kernel_type].num_of_kernel_call; - } + /* The current SoftMax implementation does not consider the src2 input, + * as none of the popular models we currently use require it. + * However, for future enhancements to SOFT_MAX, we plan to support src2 + * for sinking-based maximization. In that case, src2 will be used to + * recalculate the maximum value. 
+ */ + if( kernel_type == GGML_TSAVORITE_KERNEL_TYPE_SOFT_MAX) { + const ggml_tensor * src2 = dst->src[2]; + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; + + // TODO: is this supposed to be ceil instead of floor? + // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 + const uint32_t n_head = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + // sinks + const float * sk = src2 ? (float *)((char *) src2->data) : nullptr; + //here src2 is NULL for particular model hence u can ignore this for now + if (src2) { + printf("\n ANOOP src2 is not null\n"); + } + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01 += 1) { + const int64_t i11 = i01; + const int64_t i12 = i02%ne12; + const int64_t i13 = i03%ne13; + + // ALiBi + const uint32_t h = i02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + // broadcast the mask across rows + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + float * mp_f32 = src1 ? 
(float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + + srcP0->shape[0] = ne00; + srcP1->shape[0] = ne00; + nodeP->shape[0] = ne00; + srcP1->data = srcP1->base = (void *)(mp_f32); + srcP0->data = srcP0->base = (void *)(sp); + nodeP->data = nodeP->base = (void *)(dp); + + float *val = (float *)buf->data; + val[0] = scale; + ctx->kernels[kernel_type].pipeline->_mlir_fptr_3_input[kernel_sub_type](srcP0, srcP1, nodeP, buf); + } + } + } + } else { + GGML_TENSOR_BINARY_OP_LOCALS + + for (int ir = 0; ir < nr; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float *dst_ptr = (float *)((char *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + float *src0_ptr = (float *)((char *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + float *src1_ptr = (float *)((char *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + // The following below code operates exclusively on Rank 0 + // (i.e., the first dimension) for all blob-related processing. 
+ + for (int64_t r = 0; r < nr0; ++r) { + srcP0->shape[0] = ne10; + srcP1->shape[0] = ne10; + nodeP->shape[0] = ne10; + srcP1->data = srcP1->base = (void *)(src1_ptr); + srcP0->data = srcP0->base = (void *)(src0_ptr + r * ne10); + nodeP->data = nodeP->base = (void *)(dst_ptr + r * ne10); + // kernel call + ctx->kernels[kernel_type].pipeline->_mlir_fptr_2_input[kernel_sub_type](srcP0, srcP1, nodeP); + ++device->stats.op_run_count[kernel_type].num_of_kernel_call; + } + } } if (ggml_tsavorite_log_type_val == GGML_TSAVORITE_LOG_DEBUG) { @@ -1184,7 +1280,7 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, log_data.node_len = num_elem_src0; log_data.log_file = tsi_op_log_file; log_data.num_of_op = num_of_op; - log_data.kernel_type = kernel_type; + log_data.kernel_type = node->op; log_data.data_type = GGML_TSAVORITE_TENSOR_HEADER; ggml_tsi_log_tensor_data(log_data); @@ -1214,15 +1310,6 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, // Although only 32 elements are strictly necessary, reducing this would require changes to the RMS kernel. // The remaining 32 elements are used to store src0->ne[0], replicated across each of the last 32 entries. 
- MemRefDescriptor* buf = create_mlir_buf(96); - - if (!buf) { - GGML_TSAVORITE_LOG_ERROR("tsi_alloc failied for creating memory for buf \n"); - return GGML_STATUS_ABORTED; - } - buf->offset = 0; - buf->data = buf->base = (void *)(buf+1); - float *val = (float *)buf->data; int i; for(i=64; i <= 95; ++i) @@ -1250,6 +1337,7 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, } ctx->kernels[kernel_type].pipeline->_mlir_fptr_2_input[kernel_sub_type](srcP0, nodeP, buf); + } else { // kernel call @@ -1460,6 +1548,7 @@ static void ggml_backend_tsavorite_log_allocated_size(txe_device_s device, size_ static ggml_backend_buffer_t ggml_backend_tsavorite_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { GGML_TSAVORITE_LOG_INFO("Start %s\n", __func__); + tsi_log_setup(); struct ggml_backend_tsavorite_buffer_context *ctx = (struct ggml_backend_tsavorite_buffer_context *)calloc( 1, sizeof(struct ggml_backend_tsavorite_buffer_context)); @@ -1984,6 +2073,7 @@ static bool ggml_backend_tsavorite_device_offload_op(ggml_backend_dev_t dev, case GGML_OP_SQR: case GGML_OP_SIN: case GGML_OP_RMS_NORM: + case GGML_OP_SOFT_MAX: break; case GGML_OP_GLU: { From 41137ce544999f4f5efa320ec50624515bb64d07 Mon Sep 17 00:00:00 2001 From: Anoop Kapoor Date: Fri, 17 Oct 2025 12:10:19 -0700 Subject: [PATCH 2/2] Added memory Alignment for 128 Bytes --- ggml/include/ggml-tsavorite.h | 2 +- ggml/src/ggml-tsavorite/ggml-tsavorite.cpp | 125 +++++++++++++-------- tsi-pkg-build.sh | 6 +- 3 files changed, 82 insertions(+), 51 deletions(-) diff --git a/ggml/include/ggml-tsavorite.h b/ggml/include/ggml-tsavorite.h index e455c2e6784de..79a1ff155ec4c 100644 --- a/ggml/include/ggml-tsavorite.h +++ b/ggml/include/ggml-tsavorite.h @@ -213,7 +213,7 @@ extern void ggml_tsi_log_tensor_data(tensor_log log_data); // GGML supports tensors with a maximum rank of 4 #define MEM_REF_DESCRIPTOR_RANK 4 -#define TSI_TVU_LOAD_SIZE 32 +#define TSI_TVU_MEM_ALIGN 128 // // backend API 
diff --git a/ggml/src/ggml-tsavorite/ggml-tsavorite.cpp b/ggml/src/ggml-tsavorite/ggml-tsavorite.cpp index 8e5e575f69ac5..c2cf6d5b5141b 100644 --- a/ggml/src/ggml-tsavorite/ggml-tsavorite.cpp +++ b/ggml/src/ggml-tsavorite/ggml-tsavorite.cpp @@ -52,6 +52,38 @@ typedef struct _txe_command_buffer_t *txe_command_buffer_s; #endif /* USE_COMMAND_BUFFERS */ typedef struct ggml_backend_tsavorite_buffer ggml_backend_tsavorite_buffer_s; +const int Rank = MEM_REF_DESCRIPTOR_RANK; +MemRefDescriptor* glob_buf; + +template +// Assumes tsi_alloc is available and returns a pointer to allocated memory +static MemRefDescriptor* create_mlir_buf(int K) { + // Allocation granularity in bytes; the total size below is rounded up to a multiple of this + const int32_t mem_align = TSI_TVU_MEM_ALIGN; + // we are supporting only float or F32 + int data_type_len = 4; + // MemRef Header also added + int total_bytes = (sizeof(MemRefDescriptor) + 4*K); + + // Round total_bytes up to the next multiple of mem_align + int32_t total_align_bytes = ((total_bytes % mem_align) != 0) ? 
((total_bytes / mem_align) + 1) * mem_align : total_bytes; + + // Allocate memory dynamically: space for header + data + MemRefDescriptor* header = (MemRefDescriptor*) tsi_alloc(total_align_bytes); + + if (!header) { + return header; + } + // Advance pointer to skip header and get to data + int32_t* data = (int32_t*)(header + 1); + + for (int32_t i = 0; i < K; ++i) { + data[i] = 0; + } + return header; +} + + struct _txe_device_t { char name[100]; uint32_t max_buf_len; @@ -343,7 +375,6 @@ static void _mlir_ciface_txe_add_test (void *src0, void *src1, void *res) if (!src0 || !src1 || !res) return; - const int Rank = MEM_REF_DESCRIPTOR_RANK; MemRefDescriptor *srcP0, *srcP1, *nodeP; srcP0 = (MemRefDescriptor *)src0; srcP1 = (MemRefDescriptor *)src1; @@ -368,7 +399,6 @@ static void _mlir_ciface_txe_mult_test (void *src0, void *src1, void *res) if (!src0 || !src1 || !res) return; - const int Rank = MEM_REF_DESCRIPTOR_RANK; MemRefDescriptor *srcP0, *srcP1, *nodeP; srcP0 = (MemRefDescriptor *)src0; srcP1 = (MemRefDescriptor *)src1; @@ -489,6 +519,7 @@ static txe_compute_pipeline_state_s tsi_kernel_setup(enum ggml_tsavorite_kernel_ case GGML_TSAVORITE_KERNEL_TYPE_SOFT_MAX: { kernel_pipeline->_mlir_fptr_3_input[DATA_TYPE_F32_INDEX] = &_mlir_ciface_txe_soft_max_host; + //kernel_pipeline->_mlir_fptr_2_input[DATA_TYPE_F16_INDEX] = &_mlir_ciface_txe_soft_max_16_host; kernel_pipeline->kernel_name = "TXE_SOFTMAX"; flag = true; break; @@ -553,7 +584,11 @@ static void *ggml_tsavorite_host_malloc(size_t n) { void *data = NULL; GGML_TSAVORITE_LOG_INFO("Start %s\n", __func__); GGML_TSAVORITE_LOG_INFO("\n Allocating memory from tsi_alloc with size %ld \n", n); - data = tsi_alloc(n); + + const int32_t mem_align = TSI_TVU_MEM_ALIGN; + int total_align_bytes = (n/mem_align +1)*mem_align; + data = tsi_alloc(total_align_bytes); + GGML_TSAVORITE_LOG_CONT("\n Allocating memory from tsi_alloc with size %ld starting memory %p\n", n, data); @@ -644,6 +679,12 @@ static struct 
ggml_backend_tsavorite_context *ggml_tsavorite_init(ggml_backend_d GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SWIGLU, true); GGML_TSAVORITE_KERNEL(GGML_TSAVORITE_KERNEL_TYPE_SOFT_MAX, true); } + glob_buf = create_mlir_buf(96); + if (!glob_buf) { + GGML_TSAVORITE_LOG_ERROR("tsi_alloc failied for creating memory for buf \n"); + free(ctx); + return NULL; + } GGML_TSAVORITE_LOG_INFO("End %s\n", __func__); return ctx; @@ -755,7 +796,9 @@ static bool ggml_tsavorite_supports_op(const struct ggml_backend_tsavorite_devic case GGML_OP_SQR: case GGML_OP_SIN: case GGML_OP_RMS_NORM: - case GGML_OP_SOFT_MAX: + #ifdef GGML_TARGET_POSIX + case GGML_OP_SOFT_MAX: + #endif /* GGML_TARGET_POSIX */ break; case GGML_OP_GLU: { @@ -811,31 +854,6 @@ static void ggml_tsavorite_decompose_unary_kernel(uint32_t num_elem, ggml_tensor return; } -template -// Assumes tsi_alloc is available and returns a pointer to allocated memory -static MemRefDescriptor* create_mlir_buf(int K) { - // TVU load size (e.g., 32 for 1024-bit vector with 32-bit elements) - const int32_t tvu_size = TSI_TVU_LOAD_SIZE; - - // Round up K to the next multiple of tvu_size - int32_t num_of_elem = ((K % tvu_size) != 0) ? 
((K / tvu_size) + 1) * tvu_size : K; - - // Allocate memory dynamically: space for header + data - MemRefDescriptor* header = (MemRefDescriptor*) tsi_alloc( - sizeof(MemRefDescriptor) + num_of_elem * sizeof(float) - ); - - if (!header) { - return header; - } - // Advance pointer to skip header and get to data - int32_t* data = (int32_t*)(header + 1); - - for (int32_t i = 0; i < num_of_elem; ++i) { - data[i] = 0; - } - return header; -} static enum ggml_tsavorite_kernel_type tsi_glu_kernel_type(struct ggml_tensor *node) { const ggml_glu_op op = ggml_get_glu_op(node); @@ -926,7 +944,6 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, return GGML_STATUS_FAILED; } // MemRefDescriptor - const int Rank = MEM_REF_DESCRIPTOR_RANK; MemRefDescriptor *srcP0, *srcP1, *nodeP; struct ggml_tensor *src0, *src1, *node; uint32_t num_elem_src0, num_elem_src1, num_elem_node; @@ -937,14 +954,6 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, enum ggml_tsavorite_input_tensors_count num_of_input_tensors; tensor_log log_data; - MemRefDescriptor* buf = create_mlir_buf(96); - - if (!buf) { - GGML_TSAVORITE_LOG_ERROR("tsi_alloc failied for creating memory for buf \n"); - return GGML_STATUS_ABORTED; - } - buf->offset = 0; - buf->data = buf->base = (void *)(buf+1); for (int i = 0; i < cgraph->n_nodes; i++) { int32_t kernel_sub_type=-1; @@ -968,6 +977,21 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, printf("\n kernel_sub_type not suppored\n"); return GGML_STATUS_ABORTED; } + + if (node->op == GGML_OP_RMS_NORM || node->op == GGML_OP_SOFT_MAX) { + if (!glob_buf) { + GGML_TSAVORITE_LOG_ERROR("tsi_alloc failied for creating memory for buf \n"); + return GGML_STATUS_ABORTED; + } + glob_buf->offset = 0; + glob_buf->data = glob_buf->base = (void *)(glob_buf+1); + + float *vall = (float *)glob_buf->data; + int ii; + for(ii=0; ii <= 95; ++ii) + vall[ii] = 0; + } + switch (node->op) { case GGML_OP_ADD: 
kernel_type = GGML_TSAVORITE_KERNEL_TYPE_ADD; @@ -1115,6 +1139,7 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, log_data.node_len = num_elem_node; log_data.log_file = tsi_op_log_file; log_data.num_of_op = num_of_op; + //log_data.kernel_type = kernel_type; log_data.kernel_type = node->op; log_data.data_type = GGML_TSAVORITE_TENSOR_HEADER; @@ -1169,7 +1194,7 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, const float * sk = src2 ? (float *)((char *) src2->data) : nullptr; //here src2 is NULL for particular model hence u can ignore this for now if (src2) { - printf("\n ANOOP src2 is not null\n"); + printf("\n src2 is not null for SOFT_MAX\n"); } for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -1196,9 +1221,10 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, srcP0->data = srcP0->base = (void *)(sp); nodeP->data = nodeP->base = (void *)(dp); - float *val = (float *)buf->data; + float *val = (float *)glob_buf->data; val[0] = scale; - ctx->kernels[kernel_type].pipeline->_mlir_fptr_3_input[kernel_sub_type](srcP0, srcP1, nodeP, buf); + ctx->kernels[kernel_type].pipeline->_mlir_fptr_3_input[kernel_sub_type](srcP0, srcP1, nodeP, glob_buf); + ++device->stats.op_run_count[kernel_type].num_of_kernel_call; } } } @@ -1280,6 +1306,7 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, log_data.node_len = num_elem_src0; log_data.log_file = tsi_op_log_file; log_data.num_of_op = num_of_op; + //log_data.kernel_type = kernel_type; log_data.kernel_type = node->op; log_data.data_type = GGML_TSAVORITE_TENSOR_HEADER; @@ -1310,7 +1337,8 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, // Although only 32 elements are strictly necessary, reducing this would require changes to the RMS kernel. // The remaining 32 elements are used to store src0->ne[0], replicated across each of the last 32 entries. 
- float *val = (float *)buf->data; + + float *val = (float *)glob_buf->data; int i; for(i=64; i <= 95; ++i) val[i] = node->ne[0]; @@ -1336,7 +1364,7 @@ static enum ggml_status ggml_tsavorite_graph_compute(ggml_backend_t backend, strides = strides * src0->ne[i]; } - ctx->kernels[kernel_type].pipeline->_mlir_fptr_2_input[kernel_sub_type](srcP0, nodeP, buf); + ctx->kernels[kernel_type].pipeline->_mlir_fptr_2_input[kernel_sub_type](srcP0, nodeP, glob_buf); } else { @@ -1442,7 +1470,6 @@ static void *ggml_backend_tsavorite_buffer_get_base(ggml_backend_buffer_t buffer static ggml_status ggml_backend_tsavorite_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor *tensor) { GGML_TSAVORITE_LOG_INFO("Start %s\n", __func__); - const int Rank = MEM_REF_DESCRIPTOR_RANK; MemRefDescriptor tensor_data_header; tensor->data = (void *)(sizeof(tensor_data_header) + (char *)tensor->data); GGML_TSAVORITE_LOG_INFO("End %s\n", __func__); @@ -1633,7 +1660,6 @@ static size_t ggml_backend_tsavorite_buffer_type_get_alloc_size(ggml_backend_buf GGML_TSAVORITE_LOG_ERROR("\n tsavorite device is NULL \n"); return 0; } - const int Rank = MEM_REF_DESCRIPTOR_RANK; MemRefDescriptor tensor_data_header; ggml_backend_tsavorite_device_rel( (struct ggml_backend_tsavorite_device_context *)buft->device->context); @@ -1645,7 +1671,10 @@ static size_t ggml_backend_tsavorite_buffer_type_get_alloc_size(ggml_backend_buf // Add 128-byte buffer to avoid crossing memory boundaries during TVU 1024-bit operations. // TVU processes data in 1024-bit chunks, so the last elements may exceed allocated space without this padding. 
- return (sizeof(tensor_data_header) + ggml_nbytes(tensor) + 128); + const int32_t mem_align = TSI_TVU_MEM_ALIGN; + // I also added extra Padding buffer + size_t n = (((sizeof(tensor_data_header) + ggml_nbytes(tensor))/mem_align +1)*mem_align + mem_align); + return (n); TSI_UNUSED(buft); } @@ -2073,7 +2102,9 @@ static bool ggml_backend_tsavorite_device_offload_op(ggml_backend_dev_t dev, case GGML_OP_SQR: case GGML_OP_SIN: case GGML_OP_RMS_NORM: - case GGML_OP_SOFT_MAX: + #ifdef GGML_TARGET_POSIX + case GGML_OP_SOFT_MAX: + #endif /* GGML_TARGET_POSIX */ break; case GGML_OP_GLU: { diff --git a/tsi-pkg-build.sh b/tsi-pkg-build.sh index 38da84e3826ec..a436189629097 100755 --- a/tsi-pkg-build.sh +++ b/tsi-pkg-build.sh @@ -38,11 +38,11 @@ cd ../../ echo 'building llama.cp, ggml for tsavorite and other binary for posix' if [ "$(echo "$1" | tr '[:upper:]' '[:lower:]')" = "release" ]; then - cmake -B build-posix -DGGML_TSAVORITE=ON -DGGML_TSAVORITE_TARGET=posix -DCMAKE_C_FLAGS="-DGGML_PERF_RELEASE" -DCMAKE_CXX_FLAGS="-DGGML_PERF_RELEASE" + cmake -B build-posix -DGGML_TSAVORITE=ON -DGGML_TSAVORITE_TARGET=posix -DCMAKE_C_FLAGS="-DGGML_PERF_RELEASE -DGGML_TARGET_POSIX" -DCMAKE_CXX_FLAGS="-DGGML_PERF_RELEASE -DGGML_TARGET_POSIX" elif [ "$(echo "$1" | tr '[:upper:]' '[:lower:]')" = "debug" ]; then - cmake -B build-posix -DGGML_TSAVORITE=ON -DGGML_TSAVORITE_TARGET=posix -DCMAKE_C_FLAGS="-DGGML_PERF_DETAIL" -DCMAKE_CXX_FLAGS="-DGGML_PERF_DETAIL" + cmake -B build-posix -DGGML_TSAVORITE=ON -DGGML_TSAVORITE_TARGET=posix -DCMAKE_C_FLAGS="-DGGML_PERF_DETAIL -DGGML_TARGET_POSIX" -DCMAKE_CXX_FLAGS="-DGGML_PERF_DETAIL -DGGML_TARGET_POSIX" else - cmake -B build-posix -DGGML_TSAVORITE=ON -DGGML_TSAVORITE_TARGET=posix -DCMAKE_C_FLAGS="-DGGML_PERF" -DCMAKE_CXX_FLAGS="-DGGML_PERF" + cmake -B build-posix -DGGML_TSAVORITE=ON -DGGML_TSAVORITE_TARGET=posix -DCMAKE_C_FLAGS="-DGGML_PERF -DGGML_TARGET_POSIX" -DCMAKE_CXX_FLAGS="-DGGML_PERF -DGGML_TARGET_POSIX" fi cmake --build build-posix --config 
Release