Skip to content

Commit

Permalink
more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jun 11, 2024
1 parent 6269413 commit 8af428a
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions operators/cuda/add_mul.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ inline void _FillOutputShape3Op(std::vector<int64_t>& dimsA,
}
}

/**
* AddOrMulSharedInput(A, B, C) = A + B, A + C ifaddition is true
* AddOrMulSharedInput(A, B, C) = A * B, A * C ifaddition is false
*
* The operator supports broadcast on first dimensions.
* A[1, J] + B[I, J] is supported,
* A[1, J, 1] + B[I, J, K] is not supported,
*/
template <typename T, bool addition>
struct AddOrMulSharedInput {
template <typename TDict>
Expand Down Expand Up @@ -61,6 +69,14 @@ struct AddOrMulSharedInput {
}
};

/**
* AddOrMulTwice(A, B, C) = A + B + C ifaddition is true
* AddOrMulTwice(A, B, C) = A * B * C ifaddition is false
*
* The operator supports broadcast on first dimensions.
* A[1, J] + B[I, J] is supported,
* A[1, J, 1] + B[I, J, K] is not supported,
*/
template <typename T, bool addition>
struct AddOrMulTwice {
template <typename TDict>
Expand Down Expand Up @@ -97,6 +113,17 @@ struct AddOrMulTwice {
}
};

/**
* AddAndMul(A, B, C) = (A + B) * C if addition_first is true
* AddAndMul(A, B, C) = A * B + C if addition_first is false
*
* The operator supports broadcast on first dimensions.
* A[1, J] + B[I, J] is supported,
* A[1, J, 1] + B[I, J, K] is not supported,
*
* If switchMiddleAxis is true, then the output is transposed, then
* AddAndMul(A, B, C, switchMiddleAxis=1) = Transpose((A + B) * C, perm=[0, 2, 1, 3])
*/
template <typename T, bool addition_first>
struct AddAndMul {
template <typename TDict>
Expand Down Expand Up @@ -154,6 +181,17 @@ struct AddAndMul {
bool switchMiddelAxis_;
};

/**
* SubAndMul(A, B, C) = (A - B) * C if subtract_first is true
* SubAndMul(A, B, C) = A * B - C if subtract_first is false
*
* The operator supports broadcast on first dimensions.
* A[1, J] + B[I, J] is supported,
* A[1, J, 1] + B[I, J, K] is not supported,
*
* If negative is true, then the output is transposed, then
* SubAndMul(A, B, C, negative=1) = (B - A) * C
*/
template <typename T, bool subtract_first>
struct SubAndMul {
template <typename TDict>
Expand Down

0 comments on commit 8af428a

Please sign in to comment.