Skip to content

Commit

Permalink
Rename range.stridable to range.strides by vass (chapel-lang#22441)
Browse files Browse the repository at this point in the history
This PR switches the boolean field `range.stridable` to `range.strides`
of the type `enum strideKind { one, negOne, positive, negative, any }`
according to the design discussion concluded in
  chapel-lang#17131 (comment)

The `by` operator infers the `strides` of the resulting range
based on its `step` argument.
Ex. `r by 1` produces `r` rather than a range with `stridable = true`;
`1..n by anUnsignedInteger` produces a range with `strideKind.positive`; etc.

`stridable` queries on ranges and domains are still supported without
a warning. The warnings have been added to the module and compiler code.
In this PR they are commented out and marked with "RSDW" for
"range.stridable deprecation warning". They will be enabled in a future PR.

I modified Chapel code outside of modules/internal as little
as possible. My goal was make sure that the deprecation code
works correctly. To support domain maps that have not been converted
from stridable to strideKind, the compiler converts accesses
to the former field `stridable` to the new field `strides` in
BaseRectangularDom and BaseArrOverRectangularDom.

The bulk of updates to test code caters for a mix of un-converted uses
of `stridable` that produce ranges with `strideKind.any` and
`strides`-base code produces ranges with `strideKind.positive` or other
more specific strideKinds.

The bulk of updates to test .good files updates the `strides` components
of the range and domain types from `false` to `one` and from `true` to
`positive`, `any`, etc. Some tests get compile-time warnings or errors
instead of, or in addition to, the runtime warnings or errors in those
cases where the compiler can now determine the corresponding condition
statically based on the `strides` parameter of the range or domain.

While there:

* add a future test library/standard/Reflection/primitives/ResolvesDmap
* overhaul range.displayRepresentation() to display all the range fields
* disallow a range assignment when the corresponding idxType assignment
  is illegal
* support safeCast between ranges of identical enum types, because
  domain assignment+initialization depends on it via
  chpl_assignDomainWithGetSetIndices()
* simplify some codes by merging the two branches of `if stridable`

Development history compressed into 36425cd: 42fa8bf..b8c3c86

Post-merge TODOs:

* uncomment the deprecation warnings
* switch the rest of module and test code to `strides`
* add deprecation tests
* update dyno resolver to handle `range.strides` field in Resolver::exit()
* rename `hasPosNegUnitStride` to a better name
* add tests of new assignment / initialization behaviors
  • Loading branch information
vasslitvinov authored Jun 2, 2023
2 parents f66f7f1 + 6ec17bd commit 1634178
Show file tree
Hide file tree
Showing 311 changed files with 3,565 additions and 1,792 deletions.
196 changes: 190 additions & 6 deletions compiler/AST/AggregateType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,14 +911,44 @@ static Symbol* substitutionForField(Symbol* field, SymbolMap& subs) {
return retval;
}

// Deprecated by Vass in 1.31: given `range(boundedType=...),
// redirect it to `range(bounds=...)`, with a deprecation warning.
static void replaceStridesWithStridableSE(SymExpr* se) {
if (se->symbol() == gTrue) se->replace(new SymExpr(gStrideAny));
else if (se->symbol() == gFalse) se->replace(new SymExpr(gStrideOne));
else INT_FATAL(se, "need to handle a non-param boolean");
}

// Deprecated by Vass in 1.31: with a deprecation warning, redirect
// `range(boundedType=...)` or `range(stridable=...)` to
// `range(bounds=...)` or `range(strides=...)`;
// `rect dom or arr class(stridable=...)` to `(strides=...)`
static void checkRangeDeprecations(AggregateType* at, NamedExpr* ne,
Symbol*& field) {
if (!strcmp(ne->name, "boundedType") && at->symbol->hasFlag(FLAG_RANGE)) {
USR_WARN(ne,
"range.boundedType is deprecated; please use '.bounds' instead");
field = at->getField("bounds", false);
bool isBoundedType = !strcmp(ne->name, "boundedType");
bool isStridable = !strcmp(ne->name, "stridable");
if ((isBoundedType || isStridable) && at->symbol->hasFlag(FLAG_RANGE))
{
if (isBoundedType) {
USR_WARN(ne,
"range.boundedType is deprecated; please use '.bounds' instead");
field = at->getField("bounds");
}
else { // "stridable"
#if 0 //RSDW
USR_WARN(ne,
"range.stridable is deprecated; please use '.strides' instead");
#endif
field = at->getField("strides");
replaceStridesWithStridableSE(toSymExpr(ne->actual));
}
} else if (isStridable) {
if (AggregateType* base = baseRectDsiParent(at)) {
#if 0 //RSDW
USR_WARN(ne,
"domain.stridable is deprecated; please use '.strides' instead");
#endif
field = base->getField("strides");
replaceStridesWithStridableSE(toSymExpr(ne->actual));
}
}
}

Expand Down Expand Up @@ -1198,6 +1228,31 @@ static Type* resolveFieldTypeForInstantiation(Symbol* field, CallExpr* call, con
return ret;
}

static bool hasStrideFieldToAdjust(TypeSymbol* ts) {
if (ts->hasFlag(FLAG_RANGE)) return true;
if (!strcmp(ts->name, "BaseRectangularDom") ||
!strcmp(ts->name, "BaseArrOverRectangularDom"))
return ts->getModule()->modTag == MOD_INTERNAL;
return false;
}

// Deprecated by Vass in 1.31: given `range(..., aBoolValue)`,
// redirect it to `range(..., boundKind.one|any)`, with a deprecation warning.
static void checkRangeDeprecations(AggregateType* at, CallExpr* call,
Symbol* field, Symbol*& val) {
if (hasStrideFieldToAdjust(at->symbol) && !strcmp(field->name, "strides")
&& (val->type == dtBool)) {
#if 0 //RSDW
USR_WARN(call, "%s(..., s) is deprecated when s is a boolean;"
" please use values of the type 'enum strideKind' for s instead",
at->symbol->hasFlag(FLAG_RANGE) : "range" : "domain");
#endif
if (val == gTrue) val = gStrideAny;
else if (val == gFalse) val = gStrideOne;
else INT_FATAL(call, "need to handle a non-param boolean");
}
}

static void checkTypesForInstantiation(AggregateType* at, CallExpr* call, const char* callString, Symbol* field, Symbol* val) {
const char* typeSignature = at->typeSignature;
if (field->hasFlag(FLAG_PARAM)) {
Expand Down Expand Up @@ -1274,6 +1329,7 @@ AggregateType* AggregateType::generateType(SymbolMap& subs, CallExpr* call, cons
if (val != gUninstantiated) {
retval->genericField = index;

checkRangeDeprecations(this, call, field, val); // may update 'val'
checkTypesForInstantiation(this, call, callString, field, val);

retval = retval->getInstantiation(val, index, insnPoint);
Expand Down Expand Up @@ -2156,6 +2212,130 @@ bool AggregateType::isFieldInThisClass(const char* name) const {
return retval;
}


// support for deprecation by Vass in 1.31 to implement #17131
// here through addRangeDeprecationClone(...)

// Return the parent class if it is BaseRectangularDom or
// BaseArrOverRectangularDom, otherwise return nil.
AggregateType* baseRectDsiParent(AggregateType* ag) {
if (ag->aggregateTag != AGGREGATE_CLASS)
return nullptr;

while (ag->dispatchParents.n > 0) {
AggregateType* parentAG = ag->dispatchParents.v[0];
TypeSymbol* psym = parentAG->symbol;

if (psym->hasFlag(FLAG_OBJECT_CLASS))
return nullptr;

if ( (!strcmp(psym->name, "BaseArrOverRectangularDom") ||
!strcmp(psym->name, "BaseRectangularDom") )
&& psym->getModule()->modTag == MOD_INTERNAL)
return parentAG;

ag = parentAG; // continue searching
}

return nullptr; // dummy
}


// Traverse parents until we get to class BaseRectangularDom or
// BaseArrOverRectangularDom, then return its field 'strides'.
static Symbol* stridesFieldOfBaseRectParent(AggregateType* ag) {
if (AggregateType* parent = baseRectDsiParent(ag))
return parent->getField("strides");
else
return nullptr;
}

// Returns stridesFieldOfBaseRectParent() if 'use' is in the context
// of a DSI class
Symbol* stridesFieldInDsiContext(Expr* use) {
if (TypeSymbol* parentSym = toTypeSymbol(use->parentSymbol))
if (AggregateType* ag = toAggregateType(parentSym->type))
return stridesFieldOfBaseRectParent(ag);
return nullptr;
}

AggregateType* dsiTypeBeingConstructed(CallExpr* parentCall) {
if (SymExpr* callee = toSymExpr(parentCall->baseExpr))
if (TypeSymbol* calleeTS = toTypeSymbol(callee->symbol()))
if (AggregateType* calleeAG =
toAggregateType(canonicalClassType(calleeTS->type)))
return baseRectDsiParent(calleeAG);
return nullptr;
}

// Same as baseRectDsiParent(), plus returns 'ag'
// if it is BaseRectangularDom or BaseArrOverRectangularDom,
static AggregateType* baseRectDsiParentOrSelf(AggregateType* ag) {
TypeSymbol* psym = ag->symbol;
if ( (!strcmp(psym->name, "BaseArrOverRectangularDom") ||
!strcmp(psym->name, "BaseRectangularDom") )
&& psym->getModule()->modTag == MOD_INTERNAL)
return ag;
else
return baseRectDsiParent(ag);
}

// Replaces all uses of stridesFml2 with sblFormal.
static void replaceStridesWithStridableArg(ArgSymbol* stridesFml2,
ArgSymbol* sblFormal, bool isBase) {
if (isBase) {
// For BaseRectangularDom and BaseArrOverRectangularDom, there must be
// a single use of stridesFml2 in a call(":", stridesFml2, strideKind).
// We replace it with a call("chpl_strideKind", sblFormal).
SymExpr* use = stridesFml2->getSingleUse();
CallExpr* parent = toCallExpr(use->parentExpr);
INT_ASSERT(use, parent->isNamedAstr(astrScolon));
parent->replace(new CallExpr("chpl_strideKind", sblFormal));

} else {
for_SymbolSymExprs(se, stridesFml2) {
SymExpr* replSE = new SymExpr(sblFormal);
// if 'chpl_stridable(strides)', replace all of it with 'stridable'
// if 'strides=strides', replace it with 'stridable=stridable'
CallExpr* pCall = toCallExpr(se->parentExpr);
NamedExpr* pNamed = toNamedExpr(se->parentExpr);
if (pCall != nullptr && pCall->isNamed("chpl_stridable"))
pCall->replace(replSE);
else if (pNamed != nullptr && !strcmp(pNamed->name, "strides"))
pNamed->replace(new NamedExpr("stridable", replSE));
else
se->replace(replSE);
}
}
}

// Supports deprecation by Vass in 1.31 to implement #17131:
// for BaseRectangularDom, BaseArrOverRectangularDom, and their children,
// given fn1=init(..., strides, ...), adds fn2=init(..., stridable:bool, ...).
static void addRangeDeprecationClone(AggregateType* base, AggregateType* cur,
FnSymbol* fn1) {
// the position of the 'stride' field
int stridesPos = strcmp(base->symbol->name, "BaseRectangularDom") ? 10 : 5;
ArgSymbol* stridesFml1 = fn1->getFormal(stridesPos);

// something is unexpected, bail out
if (strcmp(stridesFml1->name, "strides") ||
stridesFml1->type != gStrideAny->type ) return;

FnSymbol* fn2 = fn1->copy();
fn1->defPoint->insertAfter(new DefExpr(fn2));

ArgSymbol* stridesFml2 = fn2->getFormal(stridesPos);
INT_ASSERT(!strcmp(stridesFml2->name, "strides") &&
stridesFml2->type == gStrideAny->type);

// replace 'strides' with 'stridable' throughout
ArgSymbol* sblFormal = new ArgSymbol(INTENT_PARAM, "stridable", dtBool);
stridesFml2->defPoint->replace(new DefExpr(sblFormal));
replaceStridesWithStridableArg(stridesFml2, sblFormal, base==cur);
cur->methods.add(fn2);
}

void AggregateType::buildDefaultInitializer() {
if (builtDefaultInit == false &&
symbol->hasFlag(FLAG_REF) == false) {
Expand Down Expand Up @@ -2207,6 +2387,10 @@ void AggregateType::buildDefaultInitializer() {
checkUseBeforeDefs(fn);

methods.add(fn);

if (AggregateType* base = baseRectDsiParentOrSelf(this))
addRangeDeprecationClone(base, this, fn);

} else {
USR_FATAL(this, "Unable to generate initializer for type '%s'", this->symbol->name);
}
Expand Down
1 change: 1 addition & 0 deletions compiler/AST/symbol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Symbol *gDummyRef = NULL;
Symbol *gFixupRequiredToken = NULL;
Symbol *gTypeDefaultToken = NULL;
Symbol *gLeaderTag = NULL, *gFollowerTag = NULL, *gStandaloneTag = NULL;
Symbol *gStrideOne = NULL, *gStrideAny = NULL; // deprecated by Vass in 1.31
Symbol *gModuleToken = NULL;
Symbol *gNoInit = NULL;
Symbol *gSplitInit = NULL;
Expand Down
9 changes: 8 additions & 1 deletion compiler/AST/wellknown.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ FnSymbol *gChplBuildLocaleId;

void gatherIteratorTags() {
forv_Vec(TypeSymbol, ts, gTypeSymbols) {
if (strcmp(ts->name, iterKindTypename) == 0) {
if (strcmp(ts->name, iterKindTypename) == 0
|| strcmp(ts->name, strideKindTypename) == 0) {
if (EnumType* enumType = toEnumType(ts->type)) {
for_alist(expr, enumType->constants) {
if (DefExpr* def = toDefExpr(expr)) {
Expand All @@ -109,6 +110,12 @@ void gatherIteratorTags() {

} else if (strcmp(name, iterKindStandaloneTagname) == 0) {
gStandaloneTag = def->sym;

} else if (strcmp(name, strideKindOneTagname) == 0) {
gStrideOne = def->sym;

} else if (strcmp(name, strideKindAnyTagname) == 0) {
gStrideAny = def->sym;
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions compiler/include/AggregateType.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,11 @@ class AggregateType final : public Type {
bool mIsGenericWithDefaults;
};

// support for deprecation by Vass in 1.31 to implement #17131
AggregateType* dsiTypeBeingConstructed(CallExpr* parentCall);
AggregateType* baseRectDsiParent(AggregateType* ag);
Symbol* stridesFieldInDsiContext(Expr* use);

extern AggregateType* dtObject;

extern AggregateType* dtBytes;
Expand Down
3 changes: 3 additions & 0 deletions compiler/include/misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@
#define iterKindFollowerTagname "follower"
#define iterKindStandaloneTagname "standalone"
#define iterFollowthisArgname "followThis"
#define strideKindTypename "strideKind"
#define strideKindOneTagname "one"
#define strideKindAnyTagname "any"

#define tupleInitName "chpl__init_tuple"

Expand Down
3 changes: 3 additions & 0 deletions compiler/include/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ void normalize(Expr* expr);
void checkUseBeforeDefs(FnSymbol* fn);
void addMentionToEndOfStatement(Expr* node, CallExpr* existingEndOfStatement);
Expr* partOfNonNormalizableExpr(Expr* expr);
// support for deprecation by Vass in 1.31 to implement #17131
bool tryReplaceStridable(CallExpr* parentCall, const char* name,
UnresolvedSymExpr* use);

// parallel.cpp
Type* getOrMakeRefTypeDuringCodegen(Type* type);
Expand Down
1 change: 1 addition & 0 deletions compiler/include/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ extern Symbol *gUnknown;
extern Symbol *gMethodToken;
extern Symbol *gTypeDefaultToken;
extern Symbol *gLeaderTag, *gFollowerTag, *gStandaloneTag;
extern Symbol *gStrideOne, *gStrideAny; //deprecation in 1.31 for #17131
extern Symbol *gModuleToken;
extern Symbol *gNoInit;
extern Symbol *gSplitInit;
Expand Down
47 changes: 47 additions & 0 deletions compiler/passes/normalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,8 @@ void checkUseBeforeDefs(FnSymbol* fn) {

} else if (UnresolvedSymExpr* use = toUnresolvedSymExpr(ast)) {
CallExpr* call = toCallExpr(use->parentExpr);
if (call == nullptr && isNamedExpr(use->parentExpr))
call = toCallExpr(use->parentExpr->parentExpr);

if (call == NULL ||
(call->baseExpr != use &&
Expand All @@ -762,6 +764,9 @@ void checkUseBeforeDefs(FnSymbol* fn) {
if (isFnSymbol(fn->defPoint->parentSymbol) == false) {
const char* name = use->unresolved;

if (tryReplaceStridable(call, name, use))
continue;

// Only complain one time
if (undeclared.find(name) == undeclared.end()) {
USR_FATAL_CONT(use,
Expand Down Expand Up @@ -866,6 +871,48 @@ static Symbol* theDefinedSymbol(BaseAST* ast) {
return retval;
}

static bool replaceWithStridesField(Symbol* stridesField, Expr* use) {
SET_LINENO(use);
NamedExpr* ne = toNamedExpr(use->parentExpr);
if (ne != nullptr && !strcmp(ne->name, "stridable"))
ne->replace(new NamedExpr("strides", new SymExpr(stridesField)));
else
use->replace(new SymExpr(stridesField));
return true;
}

static bool replaceWithStridableCall(Symbol* stridesField, Expr* use) {
SET_LINENO(use);
use->replace(new CallExpr("chpl_stridable", stridesField));
return true;
}

// Supports deprecation by Vass in 1.31 to implement #17131.
// chpl__buildDomainRuntimeType(..., stridable) -->
// chpl__buildDomainRuntimeType(..., strides) if we are in a DSI class.
bool tryReplaceStridable(CallExpr* parentCall, const char* name,
UnresolvedSymExpr* use)
{
if (strcmp(name, "stridable")) return false;

Symbol* stridesField = stridesFieldInDsiContext(use);
if (stridesField == nullptr) return false;

if (parentCall == nullptr)
return replaceWithStridableCall(stridesField, use);

if (UnresolvedSymExpr* callee = toUnresolvedSymExpr(parentCall->baseExpr)) {
if (!strcmp(callee->unresolved, "chpl__buildDomainRuntimeType"))
return replaceWithStridesField(stridesField, use);
}
else if (dsiTypeBeingConstructed(parentCall) != nullptr) {
return replaceWithStridesField(stridesField, use);
}

// cannot use 'strides' directly because the context may not take it
return replaceWithStridableCall(stridesField, use);
}

/************************************* | **************************************
* *
* *
Expand Down
Loading

0 comments on commit 1634178

Please sign in to comment.