Skip to content

Commit

Permalink
[flang] Fold NORM2() (llvm#66240)
Browse files Browse the repository at this point in the history
Fold references to the (relatively new) intrinsic function NORM2 at
compilation time when the argument(s) are all constants. (Getting this
done right involved some changes to the API of the accumulator function
objects used by the DoReduction<> template, which rippled through some
other reduction function folding code.)
  • Loading branch information
klausler authored and zahiraam committed Oct 24, 2023
1 parent b48858e commit 5290835
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 68 deletions.
36 changes: 23 additions & 13 deletions flang/lib/Evaluate/fold-integer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,26 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
}

// COUNT()
// Accumulator function object used when folding COUNT(): for every subscript
// tuple visited, the running count in "element" is incremented when the
// corresponding element of the LOGICAL mask constant is true.  Signed
// overflow of any increment is remembered and can be queried afterwards
// through overflow().
template <typename T, int MASK_KIND> class CountAccumulator {
  using MaskT = Type<TypeCategory::Logical, MASK_KIND>;

public:
  // Keeps a reference to "mask"; the constant must outlive this object.
  CountAccumulator(const Constant<MaskT> &mask) : mask_{mask} {}

  // True once any increment has reported overflow.
  bool overflow() const { return overflow_; }

  // Bumps "element" by one when the mask element at "at" is true.
  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
    if (!mask_.At(at).IsTrue()) {
      return; // false mask elements do not contribute to the count
    }
    auto bumped{element.AddSigned(Scalar<T>{1})};
    element = bumped.value;
    overflow_ |= bumped.overflow;
  }

  // No finalization is needed for a count; the accumulated value is final.
  void Done(Scalar<T> &) const {}

private:
  const Constant<MaskT> &mask_;
  bool overflow_{false};
};

template <typename T, int maskKind>
static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
using LogicalResult = Type<TypeCategory::Logical, maskKind>;
Expand All @@ -274,17 +294,9 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
: Folder<LogicalResult>{context}.Folding(arg[0])}) {
std::optional<int> dim;
if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
bool overflow{false};
auto accumulator{
[&mask, &overflow](Scalar<T> &element, const ConstantSubscripts &at) {
if (mask->At(at).IsTrue()) {
auto incremented{element.AddSigned(Scalar<T>{1})};
overflow |= incremented.overflow;
element = incremented.value;
}
}};
CountAccumulator<T, maskKind> accumulator{*mask};
Constant<T> result{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
if (overflow) {
if (accumulator.overflow()) {
context.messages().Say(
"Result of intrinsic function COUNT overflows its result type"_warn_en_US);
}
Expand Down Expand Up @@ -513,9 +525,7 @@ static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
element = (element.*operation)(array->At(at));
}};
OperationAccumulator<T> accumulator{*array, operation};
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};
Expand Down
5 changes: 1 addition & 4 deletions flang/lib/Evaluate/fold-logical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,11 @@ static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
Scalar<T> identity) {
static_assert(T::category == TypeCategory::Logical);
using Element = Scalar<T>;
std::optional<int> dim;
if (std::optional<Constant<T>> array{
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
/*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
auto accumulator{[&](Element &element, const ConstantSubscripts &at) {
element = (element.*operation)(array->At(at));
}};
OperationAccumulator accumulator{*array, operation};
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
}
return Expr<T>{std::move(ref)};
Expand Down
78 changes: 77 additions & 1 deletion flang/lib/Evaluate/fold-real.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,80 @@ static Expr<T> FoldTransformationalBessel(
return Expr<T>{std::move(funcRef)};
}

// NORM2
// Accumulator function object used when folding NORM2().
//
// Mathematically, for any s > 0,
//   NORM2(A) = SQRT(SUM(A**2)) = s * SQRT(SUM((A/s)**2))
// so each element is divided by a scale factor before being squared and
// summed, and the scale factor is reapplied only after the square root has
// been taken.  The scale factors are read from "maxAbs", one per result
// element, in array element order; when a scale is the largest magnitude
// among the elements being reduced, every squared term is at most one and
// the intermediate sum cannot overflow merely because the raw squares would.
//
// The scaled squares are added with Kahan (compensated) summation.
//
// Both "array" and "maxAbs" are held by reference and must outlive this
// object.  NOTE(review): Done() is presumably invoked by DoReduction<> once
// per result element, after all of that element's contributions have been
// accumulated -- it is what advances to the next scale factor.
template <int KIND> class Norm2Accumulator {
  using T = Type<TypeCategory::Real, KIND>;

public:
  Norm2Accumulator(
      const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding)
      : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {}

  // Adds (array(at)/scale)**2 to the running sum in "element".
  void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
    auto scale{maxAbs_.At(maxAbsAt_)};
    if (scale.IsZero()) {
      // A zero scale factor means the result is zero; also avoids the
      // division by zero below.
      element = scale;
    } else {
      auto item{array_.At(at)};
      auto scaled{item.Divide(scale).value};
      // Square the *scaled* item so that the scale factor can be reapplied
      // after the square root in Done().
      auto square{scaled.Multiply(scaled).value};
      // Kahan summation: correction_ holds (sum - element) - next from the
      // previous step, i.e. the excess that rounding added to the sum (the
      // negated lost low-order bits), so it must be subtracted from the
      // next term to feed the lost bits back in.
      auto next{square.Subtract(correction_, rounding_)};
      overflow_ |= next.flags.test(RealFlag::Overflow);
      auto sum{element.Add(next.value, rounding_)};
      overflow_ |= sum.flags.test(RealFlag::Overflow);
      // Algebraically zero; numerically the rounding error of the addition.
      correction_ = sum.value.Subtract(element, rounding_)
                        .value.Subtract(next.value, rounding_)
                        .value;
      element = sum.value;
    }
  }

  // True once any intermediate operation has raised the overflow flag.
  bool overflow() const { return overflow_; }

  // Converts the accumulated SUM((A/scale)**2) in "result" into the norm
  // scale * SQRT(sum), then resets the compensation term and advances to the
  // next scale factor for the following result element.
  void Done(Scalar<T> &result) {
    // Apply the outstanding compensation (same sign convention as above).
    auto corrected{result.Subtract(correction_, rounding_)};
    overflow_ |= corrected.flags.test(RealFlag::Overflow);
    correction_ = Scalar<T>{};
    // Take the root first: multiplying the sum by the scale before SQRT
    // would rebuild the raw sum of squares and could overflow or underflow
    // even when the norm itself is representable.
    auto root{corrected.value.SQRT().value};
    auto product{root.Multiply(maxAbs_.At(maxAbsAt_))};
    maxAbs_.IncrementSubscripts(maxAbsAt_);
    overflow_ |= product.flags.test(RealFlag::Overflow);
    result = product.value;
  }

private:
  const Constant<T> &array_;
  const Constant<T> &maxAbs_;
  const Rounding rounding_;
  bool overflow_{false};
  Scalar<T> correction_{}; // Kahan compensation term
  // Subscripts of the scale factor in maxAbs_ for the current result element;
  // initialized after maxAbs_ (declaration order matters here).
  ConstantSubscripts maxAbsAt_{maxAbs_.lbounds()};
};

// Folds a reference to the intrinsic function NORM2 when its arguments can
// be processed as constants by ProcessReductionArgs; otherwise the original
// function reference is returned unchanged.  Two reductions are run over the
// array: the first collects maximum absolute values, which the second
// (Norm2Accumulator) uses as scale factors.  A warning is emitted when the
// NORM2 accumulation reports overflow.
template <int KIND>
static Expr<Type<TypeCategory::Real, KIND>> FoldNorm2(FoldingContext &context,
    FunctionRef<Type<TypeCategory::Real, KIND>> &&funcRef) {
  using T = Type<TypeCategory::Real, KIND>;
  using Element = typename Constant<T>::Element;
  const Element identity{};
  std::optional<int> dim;
  std::optional<Constant<T>> array{
      ProcessReductionArgs<T>(context, funcRef.arguments(), dim, identity,
          /*X=*/0, /*DIM=*/1)};
  if (!array) {
    // Arguments are not foldable: leave the reference as it was.
    return Expr<T>{std::move(funcRef)};
  }
  // Pass 1: largest magnitude along the reduced dimension(s).
  MaxvalMinvalAccumulator<T, /*ABS=*/true> maxAbsAccumulator{
      RelationalOperator::GT, context, *array};
  Constant<T> maxAbs{DoReduction<T>(*array, dim, identity, maxAbsAccumulator)};
  // Pass 2: the scaled sum of squares and final square root.
  Norm2Accumulator<KIND> norm2Accumulator{
      *array, maxAbs, context.targetCharacteristics().roundingMode()};
  Constant<T> result{DoReduction<T>(*array, dim, identity, norm2Accumulator)};
  if (norm2Accumulator.overflow()) {
    context.messages().Say(
        "NORM2() of REAL(%d) data overflowed"_warn_en_US, KIND);
  }
  return Expr<T>{std::move(result)};
}

template <int KIND>
Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
FoldingContext &context,
Expand Down Expand Up @@ -238,6 +312,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
},
sExpr->u);
}
} else if (name == "norm2") {
return FoldNorm2<T::kind>(context, std::move(funcRef));
} else if (name == "product") {
auto one{Scalar<T>::FromInteger(value::Integer<8>{1}).value};
return FoldProduct<T>(context, std::move(funcRef), one);
Expand Down Expand Up @@ -354,7 +430,7 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
return result.value;
}));
}
// TODO: dot_product, matmul, norm2
// TODO: matmul
return Expr<T>{std::move(funcRef)};
}

Expand Down
Loading

0 comments on commit 5290835

Please sign in to comment.