Skip to content

Commit

Permalink
Added optimisation for matrix times vector to avoid allocating large …
Browse files Browse the repository at this point in the history
…arrays.
  • Loading branch information
Gareth Aneurin Tribello authored and Gareth Aneurin Tribello committed Jul 19, 2024
1 parent 94d33a9 commit 033d957
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
20 changes: 10 additions & 10 deletions src/adjmat/AdjacencyMatrixBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,10 @@ void AdjacencyMatrixBase::performTask( const std::string& controller, const unsi
unsigned base = 3*getNumberOfAtoms(); for(unsigned j=0; j<9; ++j) myvals.updateIndex( w_ind, base+j );
// And the indices for the derivatives of the row of the matrix
if( chainContinuesAfterThisAction() ) {
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
}
}

Expand Down Expand Up @@ -355,12 +355,12 @@ void AdjacencyMatrixBase::performTask( const std::string& controller, const unsi
myvals.addDerivative( z_index, base+2, -atom[0] ); myvals.addDerivative( z_index, base+5, -atom[1] ); myvals.addDerivative( z_index, base+8, -atom[2] );
for(unsigned k=0; k<9; ++k) { myvals.updateIndex( x_index, base+k ); myvals.updateIndex( y_index, base+k ); myvals.updateIndex( z_index, base+k ); }
if( chainContinuesAfterThisAction() ) {
for(unsigned k=1; k<4; ++k) {
unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
}
for(unsigned k=1; k<4; ++k) {
unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
}
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/core/ActionWithVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class ActionWithVector:
/// This is overridden in ActionWithMatrix
virtual void getAllActionLabelsInMatrixChain( std::vector<std::string>& matchain ) const {}
/// Get the number of derivatives in the stream
void getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat );
virtual void getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat );
/// Get every the label of every value that is calculated in this chain
void getAllActionLabelsInChain( std::vector<std::string>& mylabels ) const ;
/// We override clearInputForces here to ensure that forces are deleted from all values
Expand Down Expand Up @@ -186,7 +186,7 @@ bool ActionWithVector::actionInChain() const {
return (action_to_do_before!=NULL);
}

inline
inline
bool ActionWithVector::chainContinuesAfterThisAction() const {
return (action_to_do_after!=NULL);
}
Expand Down
13 changes: 12 additions & 1 deletion src/matrixtools/MatrixTimesVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class MatrixTimesVector : public ActionWithMatrix {
explicit MatrixTimesVector(const ActionOptions&);
std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
unsigned getNumberOfColumns() const override { plumed_error(); }
unsigned getNumberOfDerivatives();
void getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat ) override ;
unsigned getNumberOfDerivatives() override ;
void prepare() override ;
void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
bool isInSubChain( unsigned& nder ) override { nder = arg_deriv_starts[0]; return true; }
Expand Down Expand Up @@ -162,6 +163,16 @@ void MatrixTimesVector::prepare() {
std::vector<unsigned> shape(1); shape[0] = getPntrToArgument(0)->getShape()[0]; myval->setShape(shape);
}

void MatrixTimesVector::getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat ) {
if( actionInChain() ) { ActionWithVector::getNumberOfStreamedDerivatives( nderivatives, stopat ); return; }

nderivatives = 0;
for(unsigned i=0; i<getNumberOfArguments(); ++i) {
arg_deriv_starts[i] = nderivatives;
nderivatives += getPntrToArgument(i)->getNumberOfStoredValues();
}
}

void MatrixTimesVector::performTask( const unsigned& task_index, MultiValue& myvals ) const {
if( actionInChain() ) { ActionWithMatrix::performTask( task_index, myvals ); return; }

Expand Down

1 comment on commit 033d957

@PlumedBot
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found broken examples in automatic/ANGLES.tmp
Found broken examples in automatic/ANN.tmp
Found broken examples in automatic/CAVITY.tmp
Found broken examples in automatic/CLASSICAL_MDS.tmp
Found broken examples in automatic/CLUSTER_DIAMETER.tmp
Found broken examples in automatic/CLUSTER_DISTRIBUTION.tmp
Found broken examples in automatic/CLUSTER_PROPERTIES.tmp
Found broken examples in automatic/CONSTANT.tmp
Found broken examples in automatic/CONTACT_MATRIX.tmp
Found broken examples in automatic/CONTACT_MATRIX_PROPER.tmp
Found broken examples in automatic/COORDINATIONNUMBER.tmp
Found broken examples in automatic/DFSCLUSTERING.tmp
Found broken examples in automatic/DISTANCE_FROM_CONTOUR.tmp
Found broken examples in automatic/EDS.tmp
Found broken examples in automatic/EMMI.tmp
Found broken examples in automatic/ENVIRONMENTSIMILARITY.tmp
Found broken examples in automatic/FIND_CONTOUR.tmp
Found broken examples in automatic/FIND_CONTOUR_SURFACE.tmp
Found broken examples in automatic/FIND_SPHERICAL_CONTOUR.tmp
Found broken examples in automatic/FOURIER_TRANSFORM.tmp
Found broken examples in automatic/FUNCPATHGENERAL.tmp
Found broken examples in automatic/FUNCPATHMSD.tmp
Found broken examples in automatic/FUNNEL.tmp
Found broken examples in automatic/FUNNEL_PS.tmp
Found broken examples in automatic/GHBFIX.tmp
Found broken examples in automatic/GPROPERTYMAP.tmp
Found broken examples in automatic/HBOND_MATRIX.tmp
Found broken examples in automatic/INCLUDE.tmp
Found broken examples in automatic/INCYLINDER.tmp
Found broken examples in automatic/INENVELOPE.tmp
Found broken examples in automatic/INTERPOLATE_GRID.tmp
Found broken examples in automatic/LOCAL_AVERAGE.tmp
Found broken examples in automatic/MAZE_OPTIMIZER_BIAS.tmp
Found broken examples in automatic/MAZE_RANDOM_ACCELERATION_MD.tmp
Found broken examples in automatic/MAZE_SIMULATED_ANNEALING.tmp
Found broken examples in automatic/MAZE_STEERED_MD.tmp
Found broken examples in automatic/METATENSOR.tmp
Found broken examples in automatic/MULTICOLVARDENS.tmp
Found broken examples in automatic/OUTPUT_CLUSTER.tmp
Found broken examples in automatic/PAMM.tmp
Found broken examples in automatic/PCA.tmp
Found broken examples in automatic/PCAVARS.tmp
Found broken examples in automatic/PIV.tmp
Found broken examples in automatic/PLUMED.tmp
Found broken examples in automatic/PYCVINTERFACE.tmp
Found broken examples in automatic/PYTHONFUNCTION.tmp
Found broken examples in automatic/Q3.tmp
Found broken examples in automatic/Q4.tmp
Found broken examples in automatic/Q6.tmp
Found broken examples in automatic/QUATERNION.tmp
Found broken examples in automatic/SIZESHAPE_POSITION_LINEAR_PROJ.tmp
Found broken examples in automatic/SIZESHAPE_POSITION_MAHA_DIST.tmp
Found broken examples in automatic/SPRINT.tmp
Found broken examples in automatic/TETRAHEDRALPORE.tmp
Found broken examples in automatic/TORSIONS.tmp
Found broken examples in automatic/WHAM_WEIGHTS.tmp
Found broken examples in AnalysisPP.md
Found broken examples in CollectiveVariablesPP.md
Found broken examples in MiscelaneousPP.md

Please sign in to comment.