Skip to content

Commit

Permalink
Merge pull request #1180 from sys-bio/issue-piecewise-fix
Browse files Browse the repository at this point in the history
Add rootfinding breaks for when piecewise conditions change.
  • Loading branch information
luciansmith committed Feb 2, 2024
2 parents 1ce4ee9 + 5f188c2 commit 812eb79
Show file tree
Hide file tree
Showing 31 changed files with 925 additions and 661 deletions.
2 changes: 2 additions & 0 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ if (BUILD_LLVM)
llvm/FunctionResolver.cpp
llvm/GetEventValuesCodeGen.cpp
llvm/GetInitialValuesCodeGen.cpp
llvm/GetPiecewiseTriggerCodeGen.cpp
llvm/GetValuesCodeGen.cpp
llvm/Jit.cpp
llvm/JitFactory.cpp
Expand Down Expand Up @@ -270,6 +271,7 @@ if (BUILD_LLVM)
llvm/GetEventValuesCodeGen.h
llvm/GetInitialValueCodeGenBase.h
llvm/GetInitialValuesCodeGen.h
llvm/GetPiecewiseTriggerCodeGen.h
llvm/GetValueCodeGenBase.h
llvm/GetValuesCodeGen.h
llvm/Jit.h
Expand Down
32 changes: 16 additions & 16 deletions source/CVODEIntegrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ namespace rr {

int cvodeDyDtFcn(realtype t, N_Vector cv_y, N_Vector cv_ydot, void *userData);

int cvodeRootFcn(realtype t, N_Vector y, realtype *gout, void *userData);

int cvodeEventAndPiecewiseRootFcn(realtype t, N_Vector y, realtype *gout, void *userData);

// Sets the value of an element in a N_Vector object
inline void SetVector(N_Vector v, int Index, double Value) {
Expand Down Expand Up @@ -657,11 +656,6 @@ namespace rr {
mStateVector = N_VNew_Serial(allocStateVectorSize);
variableStepPostEventState.resize(allocStateVectorSize);


// for (int i = 0; i < allocStateVectorSize; i++) {
// SetVector(mStateVector, i, 0.);
// }
// set mStateVector to the values that are currently in the model
auto states = new double[allocStateVectorSize];
mModel->getStateVector(states);

Expand Down Expand Up @@ -698,12 +692,12 @@ namespace rr {
handleCVODEError(err);
}

if (mModel->getNumEvents() > 0) {
if ((err = CVodeRootInit(mCVODE_Memory, mModel->getNumEvents(),
cvodeRootFcn)) != CV_SUCCESS) {
if (mModel->getNumEvents() + mModel->getNumPiecewiseTriggers() > 0) {
if ((err = CVodeRootInit(mCVODE_Memory, mModel->getNumEvents() + mModel->getNumPiecewiseTriggers(),
cvodeEventAndPiecewiseRootFcn)) != CV_SUCCESS) {
handleCVODEError(err);
}
rrLog(Logger::LOG_TRACE) << "CVRootInit executed.....";
rrLog(Logger::LOG_TRACE) << "CVRootInit executed for events.....";
}

/**
Expand Down Expand Up @@ -1034,16 +1028,22 @@ namespace rr {

// int (*CVRootFn)(realtype t, N_Vector y, realtype *gout, void *user_data)
// Cvode calls this to check for event changes
int cvodeRootFcn(realtype time, N_Vector y_vector, realtype *gout, void *user_data) {
CVODEIntegrator *cvInstance = (CVODEIntegrator *) user_data;
int cvodeEventAndPiecewiseRootFcn(realtype time, N_Vector y_vector, realtype* gout, void* user_data) {
CVODEIntegrator* cvInstance = (CVODEIntegrator*)user_data;

assert(cvInstance && "user data pointer is NULL on CVODE root callback");

ExecutableModel *model = cvInstance->mModel;
ExecutableModel* model = cvInstance->mModel;

double *y = NV_DATA_S(y_vector);
double* y = NV_DATA_S(y_vector);

model->getEventRoots(time, y, gout);
if (model->getNumEvents() > 0) {
model->getEventRoots(time, y, gout);
}

if (model->getNumPiecewiseTriggers() > 0) {
model->getPiecewiseTriggerRoots(time, y, gout + model->getNumEvents());
}

return CV_SUCCESS;
}
Expand Down
2 changes: 1 addition & 1 deletion source/CVODEIntegrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ namespace rr {

friend int cvodeDyDtFcn(double t, N_Vector cv_y, N_Vector cv_ydot, void *f_data);

friend int cvodeRootFcn(double t, N_Vector y, double *gout, void *g_data);
friend int cvodeEventAndPiecewiseRootFcn(double t, N_Vector y, double *gout, void *g_data);

};

Expand Down
1 change: 0 additions & 1 deletion source/llvm/CodeGenBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class CodeGenBase
builder(*mgc.getJitNonOwning()->getBuilderNonOwning()),
options(mgc.getOptions()),
function(0)
// functionPassManager(mgc.getJitNonOwning().getFunctionPassManager())
{

};
Expand Down
110 changes: 110 additions & 0 deletions source/llvm/GetPiecewiseTriggerCodeGen.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* GetPiecewiseTriggerCodeGen.cpp
*
* Created on: Aug 10, 2013
* Author: andy
*/
#pragma hdrstop
#include "GetPiecewiseTriggerCodeGen.h"
#include "LLVMException.h"
#include "ModelDataSymbolResolver.h"
#include "rrLogger.h"

#include <Poco/Logger.h>
#include <llvm/GetPiecewiseTriggerCodeGen.h>
#include <vector>

using namespace llvm;

using namespace libsbml;

namespace rrllvm
{
const char* GetPiecewiseTriggerCodeGen::FunctionName = "getPiecewiseTrigger";
const char* GetPiecewiseTriggerCodeGen::IndexArgName = "triggerIndx";

GetPiecewiseTriggerCodeGen::GetPiecewiseTriggerCodeGen(const ModelGeneratorContext& mgc)
: CodeGenBase<GetPiecewiseTriggerCodeGen_FunctionPtr>(mgc)
, piecewiseTriggers(mgc.getPiecewiseTriggers())
{
};

llvm::Value* GetPiecewiseTriggerCodeGen::codeGen()
{
// make the set init value function
llvm::Type* argTypes[] = {
llvm::PointerType::get(ModelDataIRBuilder::getStructType(this->module), 0),
llvm::Type::getInt32Ty(this->context)
};

const char* argNames[] = {
"modelData", IndexArgName
};

llvm::Value* args[] = { 0, 0 };

llvm::Type* retType = getRetType();

llvm::BasicBlock* entry = this->codeGenHeader(FunctionName, retType,
argTypes, argNames, args);

ModelDataLoadSymbolResolver resolver(args[0], this->modelGenContext);

ASTNodeCodeGen astCodeGen(this->builder, resolver, this->modelGenContext, args[0]);

// default, return NaN
llvm::BasicBlock* def = llvm::BasicBlock::Create(this->context, "default", this->function);
this->builder.SetInsertPoint(def);

llvm::Value* defRet = createRet(0);
this->builder.CreateRet(defRet);

// write the switch at the func entry point, the switch is also the
// entry block terminator
this->builder.SetInsertPoint(entry);

llvm::SwitchInst* s = this->builder.CreateSwitch(args[1], def, piecewiseTriggers->size());

for (uint i = 0; i < piecewiseTriggers->size(); ++i)
{
char block_name[64];
std::sprintf(block_name, "piecewiseTrigger_%i_block", i);
llvm::BasicBlock* block = llvm::BasicBlock::Create(this->context, block_name, this->function);
this->builder.SetInsertPoint(block);
resolver.flushCache();

llvm::Value* value = astCodeGen.codeGenBoolean((*piecewiseTriggers)[i]);

// convert type to return type
value = createRet(value);

this->builder.CreateRet(value);
s->addCase(llvm::ConstantInt::get(llvm::Type::getInt32Ty(this->context), i), block);
}

return this->verifyFunction();
}

llvm::Type* GetPiecewiseTriggerCodeGen::getRetType()
{
return llvm::Type::getInt8Ty(context);
};

llvm::Value* GetPiecewiseTriggerCodeGen::createRet(llvm::Value* value)
{
if (!value)
{
return llvm::ConstantInt::get(getRetType(), 0xff, false);
}
else
{
return builder.CreateIntCast(value, getRetType(), false);
}
}


} /* namespace rr */




59 changes: 59 additions & 0 deletions source/llvm/GetPiecewiseTriggerCodeGen.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* GetPiecewiseTriggerCodeGen.h
*
* Created on: Aug 10, 2013
* Author: andy
*/

#ifndef RRLLVMGetPiecewiseTriggerCodeGen_H_
#define RRLLVMGetPiecewiseTriggerCodeGen_H_

#include "CodeGenBase.h"
#include "ModelGeneratorContext.h"
#include "SymbolForest.h"
#include "ASTNodeCodeGen.h"
#include "ASTNodeFactory.h"
#include "ModelDataIRBuilder.h"
#include "ModelDataSymbolResolver.h"
#include "LLVMException.h"
#include "rrLogger.h"
#include <sbml/Model.h>
#include <Poco/Logger.h>
#include <vector>
#include <cstdio>

namespace rrllvm
{
//Based on GetEventTriggerCodeGen (-LS)

typedef unsigned char (*GetPiecewiseTriggerCodeGen_FunctionPtr)(LLVMModelData*, size_t);

class GetPiecewiseTriggerCodeGen :
public CodeGenBase<GetPiecewiseTriggerCodeGen_FunctionPtr>
{
public:
GetPiecewiseTriggerCodeGen(const ModelGeneratorContext& mgc);
virtual ~GetPiecewiseTriggerCodeGen() {};

llvm::Value* codeGen();

typedef GetPiecewiseTriggerCodeGen_FunctionPtr FunctionPtr;

static const char* FunctionName;
static const char* IndexArgName;

llvm::Type* getRetType();

llvm::Value* createRet(llvm::Value*);

private:
const std::vector<libsbml::ASTNode*>* piecewiseTriggers;
};


} /* namespace rr */




#endif /* RRLLVMGETVALUECODEGENBASE_H_ */
2 changes: 2 additions & 0 deletions source/llvm/Jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ namespace rrllvm {
"eventTrigger"));
modelResources->eventAssignPtr = reinterpret_cast<EventAssignCodeGen::FunctionPtr>(lookupFunctionAddress(
"eventAssign"));
modelResources->getPiecewiseTriggerPtr = reinterpret_cast<GetPiecewiseTriggerCodeGen::FunctionPtr>(lookupFunctionAddress(
"getPiecewiseTrigger"));
modelResources->evalVolatileStoichPtr = reinterpret_cast<EvalVolatileStoichCodeGen::FunctionPtr>(lookupFunctionAddress(
"evalVolatileStoich"));
modelResources->evalConversionFactorPtr = reinterpret_cast<EvalConversionFactorCodeGen::FunctionPtr>(lookupFunctionAddress(
Expand Down
42 changes: 36 additions & 6 deletions source/llvm/LLVMExecutableModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ LLVMExecutableModel::LLVMExecutableModel() :
getEventDelayPtr(0),
eventTriggerPtr(0),
eventAssignPtr(0),
getPiecewiseTriggerPtr(0),
evalVolatileStoichPtr(0),
evalConversionFactorPtr(0),
setBoundarySpeciesAmountPtr(0),
Expand Down Expand Up @@ -229,6 +230,7 @@ LLVMExecutableModel::LLVMExecutableModel(
getEventDelayPtr(modelResources->getEventDelayPtr),
eventTriggerPtr(modelResources->eventTriggerPtr),
eventAssignPtr(modelResources->eventAssignPtr),
getPiecewiseTriggerPtr(modelResources->getPiecewiseTriggerPtr),
evalVolatileStoichPtr(modelResources->evalVolatileStoichPtr),
evalConversionFactorPtr(modelResources->evalConversionFactorPtr),
setBoundarySpeciesAmountPtr(modelResources->setBoundarySpeciesAmountPtr),
Expand Down Expand Up @@ -294,6 +296,7 @@ LLVMExecutableModel::LLVMExecutableModel(std::istream& in, uint modelGeneratorOp
getEventDelayPtr = resources->getEventDelayPtr;
eventTriggerPtr = resources->eventTriggerPtr;
eventAssignPtr = resources->eventAssignPtr;
getPiecewiseTriggerPtr = resources->getPiecewiseTriggerPtr;
evalVolatileStoichPtr = resources->evalVolatileStoichPtr;
evalConversionFactorPtr = resources->evalConversionFactorPtr;
setBoundarySpeciesAmountPtr = resources->setBoundarySpeciesAmountPtr;
Expand Down Expand Up @@ -1871,6 +1874,11 @@ void LLVMExecutableModel::getEventIds(std::list<std::string>& out)
std::copy(eventIds.begin(), eventIds.end(), std::back_inserter(out));
}

int LLVMExecutableModel::getNumPiecewiseTriggers()
{
return modelData->numPiecewiseTriggers;
}

void LLVMExecutableModel::getAssignmentRuleIds(std::list<std::string>& out)
{
std::vector<std::string> arIds = symbols->getAssignmentRuleIds();
Expand Down Expand Up @@ -2462,12 +2470,6 @@ void LLVMExecutableModel::getEventRoots(double time, const double* y, double* g

if (y)
{
//memcpy(modelData->rateRuleValues, y,
// modelData->numRateRules * sizeof(double));

//memcpy(modelData->floatingSpeciesAmounts, y + modelData->numRateRules,
// modelData->numIndFloatingSpecies * sizeof(double));

modelData->rateRuleValuesAlias = const_cast<double*>(y);
modelData->floatingSpeciesAmountsAlias = const_cast<double*>(y + modelData->numRateRules);

Expand All @@ -2487,6 +2489,34 @@ void LLVMExecutableModel::getEventRoots(double time, const double* y, double* g
return;
}

void LLVMExecutableModel::getPiecewiseTriggerRoots(double time, const double* y, double* gdot)
{
modelData->time = time;

double* savedRateRules = modelData->rateRuleValuesAlias;
double* savedFloatingSpeciesAmounts = modelData->floatingSpeciesAmountsAlias;

if (y)
{
modelData->rateRuleValuesAlias = const_cast<double*>(y);
modelData->floatingSpeciesAmountsAlias = const_cast<double*>(y + modelData->numRateRules);

evalVolatileStoichPtr(modelData);
}

for (uint i = 0; i < modelData->numPiecewiseTriggers; ++i)
{
unsigned char triggered = getPiecewiseTriggerPtr(modelData, i);

gdot[i] = triggered ? 1.0 : -1.0;
}

modelData->rateRuleValuesAlias = savedRateRules;
modelData->floatingSpeciesAmountsAlias = savedFloatingSpeciesAmounts;

return;
}

double LLVMExecutableModel::getNextPendingEventTime(bool pop)
{
return pendingEvents.getNextPendingEventTime();
Expand Down

0 comments on commit 812eb79

Please sign in to comment.