Skip to content

Commit

Permalink
[NB] update record function handling (OpenModelica#12577)
Browse files Browse the repository at this point in the history
- [NF] correctly identify all non default constructor functions
  - inline record constructors after inlining functions in replacements
  - add tuples to the check for inlining record elemetns
  • Loading branch information
kabdelhak committed Jun 14, 2024
1 parent 5cb4181 commit aeb305e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 22 deletions.
36 changes: 22 additions & 14 deletions OMCompiler/Compiler/NBackEnd/Modules/2_Pre/NBInline.mo
Original file line number Diff line number Diff line change
Expand Up @@ -185,23 +185,24 @@ protected
protected
UnorderedMap<Absyn.Path, Function> replacements "rules for replacements are stored inside here";
UnorderedSet<VariablePointer> set "new iterators from function bodies";
VariablePointers variables = VarData.getVariables(varData);
algorithm
// collect functions
replacements := UnorderedMap.new<Function>(AbsynUtil.pathHash, AbsynUtil.pathEqual);
replacements := FunctionTree.fold(funcTree, function collectInlineFunctions(inline_types = inline_types), replacements);

// apply replacements
eqData := Replacements.replaceFunctions(eqData, replacements);
eqData := Replacements.replaceFunctions(eqData, variables, replacements);

// replace record constucters after functions because record operator
// functions will produce record constructors once inlined
eqData := inlineRecordsTuples(eqData, VarData.getVariables(varData));
eqData := inlineRecordsTuples(eqData, variables);

// collect new iterators from replaced function bodies
set := UnorderedSet.new(BVariable.hash, BVariable.equalName);
eqData := EqData.map(eqData, function BackendDAE.lowerEquationIterators(variables = VarData.getVariables(varData), set = set));
eqData := EqData.map(eqData, function BackendDAE.lowerEquationIterators(variables = variables, set = set));
varData := VarData.addTypedList(varData, UnorderedSet.toList(set), NBVariable.VarData.VarType.ITERATOR);
eqData := EqData.mapExp(eqData, function BackendDAE.lowerComponentReferenceExp(variables = VarData.getVariables(varData)));
eqData := EqData.mapExp(eqData, function BackendDAE.lowerComponentReferenceExp(variables = variables));
end inline;

function collectInlineFunctions
Expand Down Expand Up @@ -313,16 +314,8 @@ protected
algorithm
tmp_eqns := Pointer.access(record_eqns);
for i in 1:recordSize loop
new_lhs := Expression.nthRecordElement(i, lhs);
new_rhs := Expression.nthRecordElement(i, rhs);

// lower indexed record constructor elements
new_lhs := Expression.map(new_lhs, inlineRecordConstructorElements);
new_rhs := Expression.map(new_rhs, inlineRecordConstructorElements);

// lower the new component references of record attributes
new_lhs := Expression.map(new_lhs, function BackendDAE.lowerComponentReferenceExp(variables = variables));
new_rhs := Expression.map(new_rhs, function BackendDAE.lowerComponentReferenceExp(variables = variables));
new_lhs := inlineRecordExp(lhs, i, variables);
new_rhs := inlineRecordExp(rhs, i, variables);

// create new equation
tmp_eqn := Equation.makeAssignment(new_lhs, new_rhs, index, NBEquation.SIMULATION_STR, iter, attr);
Expand All @@ -343,6 +336,21 @@ protected
new_eqn := Equation.DUMMY_EQUATION();
end inlineRecordEquationWork;

public
function inlineRecordExp
"inlines record constructors in a single expression"
input output Expression exp;
input Integer index;
input VariablePointers variables;
algorithm
exp := Expression.nthRecordElement(index, exp);
// lower indexed record constructor elements
exp := Expression.map(exp, inlineRecordConstructorElements);
// lower the new component references of record attributes
exp := Expression.map(exp, function BackendDAE.lowerComponentReferenceExp(variables = variables));
end inlineRecordExp;

protected
function inlineRecordConstructorElements
"removes indexed constructor element calls
Constructor(a,b,c)[2] --> b"
Expand Down
31 changes: 29 additions & 2 deletions OMCompiler/Compiler/NBackEnd/Util/NBReplacements.mo
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ protected
// Backend imports
import BVariable = NBVariable;
import NBEquation.{EqData, Equation, EquationPointers};
import Inline = NBInline;
import Solve = NBSolve;
import StrongComponent = NBStrongComponent;
import NBVariable.{VarData, VariablePointers};
Expand Down Expand Up @@ -265,17 +266,19 @@ public
"replaces all function calls in the replacements map with their body expressions,
if possible."
input output EqData eqData;
input VariablePointers variables;
input UnorderedMap<Absyn.Path, Function> replacements;
algorithm
// do nothing if replacements are empty
if UnorderedMap.isEmpty(replacements) then return; end if;
eqData := EqData.mapExp(eqData, function applyFuncExp(replacements = replacements));
eqData := EqData.mapExp(eqData, function applyFuncExp(replacements = replacements, variables = variables));
end replaceFunctions;

function applyFuncExp
"Needs to be mapped with Expression.map()"
input output Expression exp "Replacement happens inside this expression";
input UnorderedMap<Absyn.Path, Function> replacements "rules for replacements are stored inside here";
input VariablePointers variables;
algorithm
exp := match exp
local
Expand Down Expand Up @@ -320,6 +323,9 @@ public
body_exp := Expression.map(body_exp, function applySimpleExp(replacements = local_replacements));
body_exp := SimplifyExp.combineBinaries(body_exp);
body_exp := SimplifyExp.simplifyDump(body_exp, true, getInstanceName(), "\n");
// inline possible record constructors
//body_exp := Expression.map(body_exp, function applyFuncTupleExp(variables = variables));


if Flags.isSet(Flags.DUMPBACKENDINLINE) then
print("[" + getInstanceName() + "] Inlining: " + Expression.toString(exp) + "\n");
Expand All @@ -331,6 +337,26 @@ public
end match;
end applyFuncExp;

function applyFuncTupleExp
input output Expression exp;
input VariablePointers variables;
protected
Type ty;
Option<Integer> sz;
list<Expression> inlined_record = {};
algorithm
ty := Expression.typeOf(exp);
sz := Type.complexSize(ty);

// if the call returns a record constructor, it has to be inlined
if Util.isSome(sz) then
for i in Util.getOption(sz):-1:1 loop
inlined_record := Inline.inlineRecordExp(exp, i, variables) :: inlined_record;
end for;
exp := Expression.TUPLE(ty, inlined_record);
end if;
end applyFuncTupleExp;

function addInputArgTpl
"adds an input to argument replacement and also adds
all record children replacements."
Expand Down Expand Up @@ -359,7 +385,8 @@ public
then list(Expression.fromCref(BVariable.getVarName(child)) for child in arg_children);

// if it is a basic record, take its elements
case Expression.RECORD() then arg.elements;
case Expression.RECORD() then arg.elements;
case Expression.TUPLE() then arg.elements;

// if the argument is a record constructor, map it to its attributes
case Expression.CALL(call = call as Call.TYPED_CALL(fn = fn)) algorithm
Expand Down
17 changes: 11 additions & 6 deletions OMCompiler/Compiler/NFFrontEnd/NFFunction.mo
Original file line number Diff line number Diff line change
Expand Up @@ -1990,12 +1990,17 @@ uniontype Function

function isNonDefaultRecordConstructor
input Function fn;
output Boolean isConstructor;
algorithm
isConstructor := match fn.path
case Absyn.Path.QUALIFIED(path = Absyn.Path.QUALIFIED(name = "'constructor'")) then true;
else false;
end match;
output Boolean b = isNonDefaultRecordConstructorPath(fn.path);
function isNonDefaultRecordConstructorPath
input Absyn.Path path;
output Boolean b;
algorithm
b := match path
case Absyn.Path.QUALIFIED(name = "'constructor'") then true;
case Absyn.Path.QUALIFIED() then isNonDefaultRecordConstructorPath(path.path);
else false;
end match;
end isNonDefaultRecordConstructorPath;
end isNonDefaultRecordConstructor;

function toDAE
Expand Down

0 comments on commit aeb305e

Please sign in to comment.