Skip to content

Commit

Permalink
Merge pull request #1187 from sys-bio/implement-rateOf
Browse files Browse the repository at this point in the history
Implement the rateOf construct.
  • Loading branch information
luciansmith committed Feb 25, 2024
2 parents 3188d64 + af6bb3f commit 2bc22d7
Show file tree
Hide file tree
Showing 38 changed files with 996 additions and 18 deletions.
157 changes: 156 additions & 1 deletion source/llvm/ASTNodeCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ llvm::Value* ASTNodeCodeGen::codeGen(const libsbml::ASTNode* ast)
case AST_FUNCTION_DELAY:
result = delayExprCodeGen(ast);
break;

case AST_FUNCTION_RATE_OF:
result = rateOfCodeGen(ast);
break;
case AST_LAMBDA:
result = notImplemented(ast);
break;
Expand Down Expand Up @@ -637,6 +639,159 @@ llvm::Value* ASTNodeCodeGen::delayExprCodeGen(const libsbml::ASTNode* ast)
//return codeGen(ast->getChild(0));
}

llvm::Value* ASTNodeCodeGen::rateOfCodeGen(const libsbml::ASTNode* ast)
{
if (ast->getNumChildren() != 1) {
throw_llvm_exception("AST type 'rateOf' requires exactly one child.");
}

ASTNode* child = ast->getChild(0);
if (child->getType() != AST_NAME)
{
char* formula = SBML_formulaToL3String(ast);
std::stringstream err;
err << "The rateOf csymbol may only be used on individual symbols, i.e. 'rateOf(S1)'. The expression '" << formula << "' is illegal.";
free(formula);
throw_llvm_exception(err.str())
}
string name = child->getName();
const LLVMModelDataSymbols& dataSymbols = ctx.getModelDataSymbols();
//ModelDataIRBuilder mdbuilder(modelData, dataSymbols, builder);
Value* rate = NULL;
resolver.recursiveSymbolPush("rateOf(" + name + ")");
//Looking for a rate. Our options are: local parameter, floating species, rate rule target, assignment rule target (an error), and everything else is zero.
if (resolver.isLocalParameter(name))
{
rate = ConstantFP::get(builder.getContext(), APFloat(0.0));
}
else if (dataSymbols.isIndependentFloatingSpecies(name))
{
//NOTE: if we stored rates persistently, we could use the following instead:
//rate = mdbuilder.createFloatSpeciesAmtRateLoad(name, name + "_amtRate");
const Model* model = ctx.getModel();
const Species* species = model->getSpecies(name);
bool speciesIsAmount = species->getHasOnlySubstanceUnits();
string compId = species->getCompartment();
bool compIsConst = true;

if (!speciesIsAmount) {
//Determine if the compartment changes in time.
const Rule* compRule = model->getRule(compId);
if (compRule != NULL) {
if (compRule->getTypeCode() == SBML_ASSIGNMENT_RULE) {
std::stringstream err;
err << "Unable to calculate rateOf(" << name << "), because "
<< name << " is a concentration, not an amount, but its compartment ("
<< compId << ") chages due to an assignment rule. Since we cannot"
<< " calculate derivatives on the fly, we cannot determine the rate of"
<< " change of the species " << name << " concentration";
throw_llvm_exception(err.str());
}
else {
//The only other option is a rate rule, which means the compartment changes in time.
compIsConst = false;
}
}
}

//The functions are different depending on how complicated things get, so we need references. The final Value will be created from the overallRef ASTNode.
ASTNode amtRate(AST_PLUS);
ASTNode* amtRateRef = &amtRate;
ASTNode* overallRef = &amtRate;

//Deal with conversion factors:
string conversion_factor = "";
if (model->isSetConversionFactor()) {
conversion_factor = model->getConversionFactor();
}
if (species->isSetConversionFactor()) {
conversion_factor = species->getConversionFactor();
}
if (!conversion_factor.empty()) {
amtRate.setType(AST_TIMES); //rate * conversion factor
ASTNode* unscaledRate = new ASTNode(AST_PLUS);
amtRate.addChild(unscaledRate);
ASTNode* cf = new ASTNode(AST_NAME);
cf->setName(conversion_factor.c_str());
amtRate.addChild(cf);
amtRateRef = unscaledRate;
}

//Sum all the kinetic laws * stoichiometries in which this species appears:
for (int rxn = 0; rxn < model->getNumReactions(); rxn++) {
const Reaction* reaction = model->getReaction(rxn);
if (reaction->getReactant(name) == NULL &&
reaction->getProduct(name) == NULL) {
continue;
}
ASTNode* times = new ASTNode(AST_TIMES);
ASTNode* stoich = new ASTNode(AST_NAME);
stoich->setName((reaction->getId() + ":" + name).c_str());
times->addChild(stoich);
times->addChild(reaction->getKineticLaw()->getMath()->deepCopy());
amtRateRef->addChild(times);
}

//Now deal with the case where the species symbol is a concentration, not an amount.
ASTNode scaledConcentrationRate(AST_DIVIDE);
ASTNode combinedConcentrationRate(AST_MINUS);
if (!speciesIsAmount) {
//Need to divide by the compartment size (rate / comp)
scaledConcentrationRate.addChild(overallRef->deepCopy());
ASTNode* compsize = new ASTNode(AST_NAME);
compsize->setName(compId.c_str());
scaledConcentrationRate.addChild(compsize);
overallRef = &scaledConcentrationRate;

//If the compartment changes in time, it becomes even more complicated!
if (!compIsConst) {
//The equation is: rateOf([S1]) = rateOf(S1)/C - [S1]/C * rateOf(C)
combinedConcentrationRate.addChild(overallRef->deepCopy());
ASTNode* subdiv = new ASTNode(AST_DIVIDE);
ASTNode* mult = new ASTNode(AST_TIMES);
ASTNode* species = new ASTNode(AST_NAME);
species->setName(name.c_str());
ASTNode* rateOfComp = new ASTNode(AST_FUNCTION_RATE_OF);
ASTNode* comp = new ASTNode(AST_NAME);
comp->setName(compId.c_str());
rateOfComp->addChild(comp);
mult->addChild(rateOfComp);
mult->addChild(species);
subdiv->addChild(mult);
subdiv->addChild(comp->deepCopy());
combinedConcentrationRate.addChild(subdiv);
overallRef = &combinedConcentrationRate;
}

}
//string formula = SBML_formulaToL3String(overallRef);
rate = ASTNodeCodeGen(builder, resolver, ctx, modelData).codeGenDouble(overallRef);
}
else if (dataSymbols.hasRateRule(name))
{
//NOTE: if we stored rates persistently, we could use the following instead:
//rate = mdbuilder.createRateRuleRateLoad(name, name + "_rate");
const LLVMModelSymbols& modelSymbols = ctx.getModelSymbols();
SymbolForest::ConstIterator i = modelSymbols.getRateRules().find(
name);
if (i != modelSymbols.getRateRules().end())
{
rate = ASTNodeCodeGen(builder, resolver, ctx, modelData).codeGenDouble(i->second);
}
}
else if (dataSymbols.hasAssignmentRule(name))
{
throw_llvm_exception("Unable to define the rateOf for symbol '" + name + "' as it is changed by an assignment rule.");
}
else
{
rate = ConstantFP::get(builder.getContext(), APFloat(0.0));
}
assert(rate);
resolver.recursiveSymbolPop();
return rate;
}

llvm::Value* ASTNodeCodeGen::nameExprCodeGen(const libsbml::ASTNode* ast)
{
switch(ast->getType())
Expand Down
2 changes: 2 additions & 0 deletions source/llvm/ASTNodeCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class ASTNodeCodeGen

llvm::Value *delayExprCodeGen(const libsbml::ASTNode *ast);

llvm::Value* rateOfCodeGen(const libsbml::ASTNode* ast);

llvm::Value *nameExprCodeGen(const libsbml::ASTNode *ast);

llvm::Value *realExprCodeGen(const libsbml::ASTNode *ast);
Expand Down
2 changes: 2 additions & 0 deletions source/llvm/CodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ namespace rrllvm

virtual void recursiveSymbolPop() = 0;

virtual bool isLocalParameter(const std::string& symbol) = 0;

/**
* nested conditionals (or functions?) can push a local cache block, where
* symbols would be chached. These need to be popped as these symbols are
Expand Down
6 changes: 4 additions & 2 deletions source/llvm/EvalVolatileStoichCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ namespace rrllvm {
ModelDataLoadSymbolResolver resolver(modelData, modelGenContext);
ModelDataIRBuilder mdbuilder(modelData, dataSymbols, builder);

LLVMModelDataSymbols* dataSymbolsPtr = const_cast<LLVMModelDataSymbols*>(&dataSymbols);

ASTNodeCodeGen astCodeGen(builder, resolver, modelGenContext, modelData);

const ListOfReactions *reactions = model->getListOfReactions();
Expand Down Expand Up @@ -79,7 +81,7 @@ namespace rrllvm {
assert(value && "value for species reference stoichiometry is 0");

const LLVMModelDataSymbols::SpeciesReferenceInfo &info =
dataSymbols.getNamedSpeciesReferenceInfo(p->getId());
dataSymbolsPtr->getNamedSpeciesReferenceInfo(p->getId());

mdbuilder.createStoichiometryStore(info.row, info.column,
value, p->getId());
Expand Down Expand Up @@ -117,7 +119,7 @@ namespace rrllvm {
value = builder.CreateFMul(negOne, value, "neg_" + r->getId());

const LLVMModelDataSymbols::SpeciesReferenceInfo &info =
dataSymbols.getNamedSpeciesReferenceInfo(r->getId());
dataSymbolsPtr->getNamedSpeciesReferenceInfo(r->getId());

mdbuilder.createStoichiometryStore(info.row, info.column, value,
r->getId());
Expand Down
5 changes: 5 additions & 0 deletions source/llvm/FunctionResolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,9 @@ void FunctionResolver::recursiveSymbolPop()
parentResolver.recursiveSymbolPop();
}

bool FunctionResolver::isLocalParameter(const std::string& symbol)
{
return false;
}

} /* namespace rr */
2 changes: 2 additions & 0 deletions source/llvm/FunctionResolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class FunctionResolver: public LoadSymbolResolver

virtual void recursiveSymbolPop();

virtual bool isLocalParameter(const std::string& symbol);

private:
LoadSymbolResolver& parentResolver;
const ModelGeneratorContext& modelGenContext;
Expand Down
8 changes: 8 additions & 0 deletions source/llvm/KineticLawParameterResolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,12 @@ void KineticLawParameterResolver::recursiveSymbolPop()
parentResolver.recursiveSymbolPop();
}

bool KineticLawParameterResolver::isLocalParameter(const std::string& symbol)
{
if (kineticLaw.getLocalParameter(symbol) != NULL) {
return true;
}
return kineticLaw.getParameter(symbol) != NULL;
}

} /* namespace rr */
2 changes: 2 additions & 0 deletions source/llvm/KineticLawParameterResolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ namespace rrllvm

virtual void recursiveSymbolPop();

virtual bool isLocalParameter(const std::string& symbol);

private:
LoadSymbolResolver& parentResolver;
const libsbml::KineticLaw& kineticLaw;
Expand Down
29 changes: 27 additions & 2 deletions source/llvm/LLVMModelDataSymbols.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1625,17 +1625,42 @@ size_t LLVMModelDataSymbols::getEventBufferSize(size_t eventId) const

bool LLVMModelDataSymbols::isNamedSpeciesReference(const std::string& id) const
{
return namedSpeciesReferenceInfo.find(id) != namedSpeciesReferenceInfo.end();
if (namedSpeciesReferenceInfo.find(id) != namedSpeciesReferenceInfo.end()) {
return true;
}
if (id.find(":") != string::npos) {
//LS: This whole section is for when we ask for rateOf(specid).
string rxnid = id.substr(0, id.find(":"));
string specid = id.substr(id.find(":") + 1, id.length());
if (getReactionIndex(rxnid) != -1 && getFloatingSpeciesIndex(specid) != -1) {
return true;
}
}
return false;
}

const LLVMModelDataSymbols::SpeciesReferenceInfo&
LLVMModelDataSymbols::getNamedSpeciesReferenceInfo(const std::string& id) const
LLVMModelDataSymbols::getNamedSpeciesReferenceInfo(const std::string& id)
{
StringRefInfoMap::const_iterator i = namedSpeciesReferenceInfo.find(id);
if (i != namedSpeciesReferenceInfo.end())
{
return i->second;
}
else if (id.find(":") != string::npos) {
//LS: This whole section is for when we ask for rateOf(specid).
string rxnid = id.substr(0, id.find(":"));
string specid = id.substr(id.find(":") + 1, id.length());
int rxnIdx = getReactionIndex(rxnid);
int speciesIdx = getFloatingSpeciesIndex(specid);
if (rxnIdx == -1 || speciesIdx == -1) {
throw_llvm_exception(id + " is not a named SpeciesReference: '" + rxnid + "' and '" + specid + "' are not a valid combination of reaction and species.");
}
SpeciesReferenceInfo info =
{ static_cast<uint>(speciesIdx), static_cast<uint>(rxnIdx), Product, rxnid };
namedSpeciesReferenceInfo[id] = info;
return namedSpeciesReferenceInfo[id];
}
else
{
throw_llvm_exception(id + " is not a named SpeciesReference");
Expand Down
2 changes: 1 addition & 1 deletion source/llvm/LLVMModelDataSymbols.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class LLVMModelDataSymbols
bool isNamedSpeciesReference(const std::string& id) const;

const SpeciesReferenceInfo& getNamedSpeciesReferenceInfo(
const std::string& id) const;
const std::string& id);


/******* Conserved Moiety Section ********************************************/
Expand Down
5 changes: 5 additions & 0 deletions source/llvm/LoadSymbolResolverBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ void LoadSymbolResolverBase::recursiveSymbolPop()
symbolStack.pop_back();
}

bool LoadSymbolResolverBase::isLocalParameter(const std::string& symbol)
{
return false;
}

void LoadSymbolResolverBase::flushCache()
{
symbolCache.clear();
Expand Down
2 changes: 2 additions & 0 deletions source/llvm/LoadSymbolResolverBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class LoadSymbolResolverBase: public LoadSymbolResolver

void recursiveSymbolPop() override;

virtual bool isLocalParameter(const std::string& symbol) override;

/**
* Flush the symbol cache. This is required in branches and switch blocks as
* a symbol used in a previous block can not be re-used in the current block.
Expand Down
8 changes: 6 additions & 2 deletions source/llvm/ModelDataSymbolResolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,12 @@ namespace rrllvm
return cacheValue(symbol, args, mdbuilder.createRateRuleValueLoad(symbol));
}

LLVMModelDataSymbols* modelDataSymbolsPtr = const_cast<LLVMModelDataSymbols*>(&modelDataSymbols);

if (modelDataSymbols.isNamedSpeciesReference(symbol))
{
const LLVMModelDataSymbols::SpeciesReferenceInfo& info =
modelDataSymbols.getNamedSpeciesReferenceInfo(symbol);
modelDataSymbolsPtr->getNamedSpeciesReferenceInfo(symbol);

Value* value = mdbuilder.createStoichiometryLoad(info.row, info.column, symbol);

Expand Down Expand Up @@ -265,6 +267,8 @@ namespace rrllvm
// at this point, we have already taken care of the species amount /
// conc conversion, rest are just plain stores.

LLVMModelDataSymbols* modelDataSymbolsPtr = const_cast<LLVMModelDataSymbols*>(&modelDataSymbols);

if (modelDataSymbols.hasRateRule(symbol))
{
return mdbuilder.createRateRuleValueStore(symbol, value);
Expand All @@ -283,7 +287,7 @@ namespace rrllvm
else if (modelDataSymbols.isNamedSpeciesReference(symbol))
{
const LLVMModelDataSymbols::SpeciesReferenceInfo& info =
modelDataSymbols.getNamedSpeciesReferenceInfo(symbol);
modelDataSymbolsPtr->getNamedSpeciesReferenceInfo(symbol);

if (info.type == LLVMModelDataSymbols::MultiReactantProduct)
{
Expand Down

0 comments on commit 2bc22d7

Please sign in to comment.