Skip to content

Commit

Permalink
Add missing documentation for fused kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jun 11, 2024
1 parent f505546 commit d3b6f5d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 0 deletions.
3 changes: 3 additions & 0 deletions operators/cuda/negxplus1.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

namespace contrib {

/**
* NegXPlus1(X) = 1 - X
*/
template <typename T>
struct NegXPlus1 {
template <typename TDict>
Expand Down
9 changes: 9 additions & 0 deletions operators/cuda/scatter_nd_of_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

namespace contrib {

/**
* ScatterNDOfShape(shape, indices, updates) = ScatterND(ConstantOfShape(shape, value=0), indices, updates)
*/
template <typename T>
struct ScatterNDOfShape {
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
Expand Down Expand Up @@ -71,6 +74,12 @@ struct ScatterNDOfShape {
};


/**
* MaskedScatterNDOfShape(shape, indices, updates) = ScatterND(ConstantOfShape(shape, value=0),
* indices[indices != maskedValue],
* updates[indices != maskedValue])
*
*/
template <typename T>
struct MaskedScatterNDOfShape {
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
Expand Down
3 changes: 3 additions & 0 deletions operators/cuda/transpose_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

namespace contrib {

/**
* Transpose2DCast(X, to=to) = Cast(Transpose(X, perm=[1, 0]), to=to)
*/
template <typename TIN, typename TOUT>
struct Transpose2DCast {
template <typename TDict>
Expand Down

0 comments on commit d3b6f5d

Please sign in to comment.