Skip to content

Commit

Permalink
Merge pull request #10751 from trilinos/tcclevenger/import_and_fill_c…
Browse files Browse the repository at this point in the history
…omplete_for_bcrs

Tpetra: add BlockCrsMatrix::importAndFillComplete
  • Loading branch information
csiefer2 committed Aug 25, 2022
2 parents 78aac8e + 8e12ab3 commit 5c7e9d9
Show file tree
Hide file tree
Showing 3 changed files with 595 additions and 5 deletions.
55 changes: 55 additions & 0 deletions packages/tpetra/core/src/Tpetra_BlockCrsMatrix_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@

namespace Tpetra {

template<class BlockCrsMatrixType>
Teuchos::RCP<BlockCrsMatrixType>
importAndFillCompleteBlockCrsMatrix (const Teuchos::RCP<const BlockCrsMatrixType>& sourceMatrix,
const Import<typename BlockCrsMatrixType::local_ordinal_type,
typename BlockCrsMatrixType::global_ordinal_type,
typename BlockCrsMatrixType::node_type>& importer,
const Teuchos::RCP<const Map<typename BlockCrsMatrixType::local_ordinal_type,
typename BlockCrsMatrixType::global_ordinal_type,
typename BlockCrsMatrixType::node_type> >& domainMap = Teuchos::null,
const Teuchos::RCP<const Map<typename BlockCrsMatrixType::local_ordinal_type,
typename BlockCrsMatrixType::global_ordinal_type,
typename BlockCrsMatrixType::node_type> >& rangeMap = Teuchos::null,
const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null);

/// \class BlockCrsMatrix
/// \brief Sparse matrix whose entries are small dense square blocks,
/// all of the same dimensions.
Expand Down Expand Up @@ -378,6 +392,13 @@ class BlockCrsMatrix :
const Scalar alpha = Teuchos::ScalarTraits<Scalar>::one (),
const Scalar beta = Teuchos::ScalarTraits<Scalar>::zero ());

void
importAndFillComplete (Teuchos::RCP<BlockCrsMatrix<Scalar, LO, GO, Node> >& destMatrix,
const Import<LO, GO, Node>& importer,
const Teuchos::RCP<const map_type>& domainMap,
const Teuchos::RCP<const map_type>& rangeMap,
const Teuchos::RCP<Teuchos::ParameterList>& params = Teuchos::null) const;

/// \brief Replace values at the given (mesh, i.e., block) column
/// indices, in the given (mesh, i.e., block) row.
///
Expand Down Expand Up @@ -1195,8 +1216,42 @@ class BlockCrsMatrix :
virtual typename ::Tpetra::RowMatrix<Scalar, LO, GO, Node>::mag_type
getFrobeniusNorm () const override;
//@}

// Friend declaration for nonmember function.
template<class BlockCrsMatrixType>
friend Teuchos::RCP<BlockCrsMatrixType>
Tpetra::importAndFillCompleteBlockCrsMatrix (const Teuchos::RCP<const BlockCrsMatrixType>& sourceMatrix,
const Import<typename BlockCrsMatrixType::local_ordinal_type,
typename BlockCrsMatrixType::global_ordinal_type,
typename BlockCrsMatrixType::node_type>& importer,
const Teuchos::RCP<const Map<typename BlockCrsMatrixType::local_ordinal_type,
typename BlockCrsMatrixType::global_ordinal_type,
typename BlockCrsMatrixType::node_type> >& domainMap,
const Teuchos::RCP<const Map<typename BlockCrsMatrixType::local_ordinal_type,
typename BlockCrsMatrixType::global_ordinal_type,
typename BlockCrsMatrixType::node_type> >& rangeMap,
const Teuchos::RCP<Teuchos::ParameterList>& params);
};

template<class BlockCrsMatrixType>
Teuchos::RCP<BlockCrsMatrixType>
importAndFillCompleteBlockCrsMatrix (const Teuchos::RCP<const BlockCrsMatrixType>& sourceMatrix,
const Import<typename BlockCrsMatrixType::local_ordinal_type,
typename BlockCrsMatrixType::global_ordinal_type,
typename BlockCrsMatrixType::node_type>& importer,
const Teuchos::RCP<const Map<typename BlockCrsMatrixType::local_ordinal_type,
typename BlockCrsMatrixType::global_ordinal_type,
typename BlockCrsMatrixType::node_type> >& domainMap,
const Teuchos::RCP<const Map<typename BlockCrsMatrixType::local_ordinal_type,
typename BlockCrsMatrixType::global_ordinal_type,
typename BlockCrsMatrixType::node_type> >& rangeMap,
const Teuchos::RCP<Teuchos::ParameterList>& params)
{
Teuchos::RCP<BlockCrsMatrixType> destMatrix;
sourceMatrix->importAndFillComplete (destMatrix, importer, domainMap, rangeMap, params);
return destMatrix;
}

} // namespace Tpetra

#endif // TPETRA_BLOCKCRSMATRIX_DECL_HPP
46 changes: 41 additions & 5 deletions packages/tpetra/core/src/Tpetra_BlockCrsMatrix_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,40 @@ class GetLocalDiagCopy {
}
}

template<class Scalar, class LO, class GO, class Node>
void
BlockCrsMatrix<Scalar, LO, GO, Node>::
importAndFillComplete (Teuchos::RCP<BlockCrsMatrix<Scalar, LO, GO, Node> >& destMatrix,
const Import<LO, GO, Node>& importer,
const Teuchos::RCP<const map_type>& domainMap,
const Teuchos::RCP<const map_type>& rangeMap,
const Teuchos::RCP<Teuchos::ParameterList>& params) const
{
using Teuchos::RCP;
using Teuchos::rcp;
using this_type = BlockCrsMatrix<Scalar, LO, GO, Node>;

// Right now, we make many assumptions...
TEUCHOS_TEST_FOR_EXCEPTION(!destMatrix.is_null(), std::invalid_argument,
"Right now, assuming destMatrix is null.");
TEUCHOS_TEST_FOR_EXCEPTION(!domainMap.is_null(), std::invalid_argument,
"Right now, assuming domainMap is null.");
TEUCHOS_TEST_FOR_EXCEPTION(!rangeMap.is_null(), std::invalid_argument,
"Right now, assuming rangeMap is null.");
TEUCHOS_TEST_FOR_EXCEPTION(!params.is_null(), std::invalid_argument,
"Right now, assuming params is null.");

// BlockCrsMatrix requires a complete graph at construction.
// So first step is to import and fill complete the destGraph.
RCP<crs_graph_type> destGraph = rcp (new crs_graph_type (importer.getTargetMap(), 0));
destGraph->doImport(this->getCrsGraph(), importer, Tpetra::INSERT);
destGraph->fillComplete();

// Final step, create and import the destMatrix.
destMatrix = rcp (new this_type (*destGraph, getBlockSize()));
destMatrix->doImport(*this, importer, Tpetra::INSERT);
}

template<class Scalar, class LO, class GO, class Node>
void
BlockCrsMatrix<Scalar, LO, GO, Node>::
Expand Down Expand Up @@ -2738,23 +2772,25 @@ class GetLocalDiagCopy {
errorDuringUnpack () = 0;
{
using policy_type = Kokkos::TeamPolicy<host_exec>;
const auto policy = policy_type (numImportLIDs, 1, 1)
.set_scratch_size (0, Kokkos::PerTeam (sizeof (GO) * maxRowNumEnt +
sizeof (LO) * maxRowNumEnt +
numBytesPerValue * maxRowNumScalarEnt));
size_t scratch_per_row = sizeof(GO) * maxRowNumEnt + sizeof (LO) * maxRowNumEnt + numBytesPerValue * maxRowNumScalarEnt
+ 2 * sizeof(GO); // Yeah, this is a fudge factor

const auto policy = policy_type (numImportLIDs, 1, 1)
.set_scratch_size (0, Kokkos::PerTeam (scratch_per_row));
using host_scratch_space = typename host_exec::scratch_memory_space;

using pair_type = Kokkos::pair<size_t, size_t>;
Kokkos::parallel_for
("Tpetra::BlockCrsMatrix::unpackAndCombine: unpack", policy,
[=] (const typename policy_type::member_type& member) {
const size_t i = member.league_rank();

Kokkos::View<GO*, host_scratch_space> gblColInds
(member.team_scratch (0), maxRowNumEnt);
Kokkos::View<LO*, host_scratch_space> lclColInds
(member.team_scratch (0), maxRowNumEnt);
Kokkos::View<impl_scalar_type*, host_scratch_space> vals
(member.team_scratch (0), maxRowNumScalarEnt);


const size_t offval = offset(i);
const LO lclRow = importLIDsHost(i);
Expand Down
Loading

0 comments on commit 5c7e9d9

Please sign in to comment.