Skip to content

Commit

Permalink
Merge pull request #3 from Shixiaowei02/dev/pass_2
Browse files Browse the repository at this point in the history
add analysis components
  • Loading branch information
jiweibo committed Apr 4, 2023
2 parents a45a037 + 53437b1 commit 6e2139a
Show file tree
Hide file tree
Showing 5 changed files with 355 additions and 24 deletions.
141 changes: 141 additions & 0 deletions paddle/infra/Analysis/DataFlow/DenseAnalysis.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "Analysis/DataFlow/DenseAnalysis.h"

#include <cassert>

namespace infra {
namespace dataflow {

// Visits `top`, then recursively initializes every block and operation
// nested under it. Returns false if any nested initialization failed.
bool AbstractDenseAnalysis::Initialize(Operation* top) {
  VisitOperation(top);
  bool ret = true;
  for (auto& region : top->getRegions()) {
    for (auto& block : region) {
      VisitBlock(&block);
      for (auto& op : block) {
        // Evaluate the recursive call first: the original
        // `ret = ret && Initialize(&op)` short-circuited once `ret` was
        // false, silently skipping initialization of all remaining ops.
        const bool ok = Initialize(&op);
        ret = ret && ok;
      }
    }
  }
  return ret;
}

// Dispatches a program point to the matching visitor. Returns false for
// point kinds the dense analysis does not handle.
bool AbstractDenseAnalysis::Visit(ProgramPoint point) {
  if (auto* op = point.dyn_cast<Operation*>()) {
    VisitOperation(op);
    return true;
  }
  if (auto* block = point.dyn_cast<Block*>()) {
    VisitBlock(block);
    return true;
  }
  return false;
}

// Region-branch and call operations get their state from control flow;
// every other operation goes through the user-supplied transfer function.
void AbstractDenseAnalysis::VisitOperation(Operation* op) {
  if (auto branch = ::mlir::dyn_cast<::mlir::RegionBranchOpInterface>(op)) {
    VisitRegionBranchOperation(op, branch);
  } else if (auto call = ::mlir::dyn_cast<::mlir::CallOpInterface>(op)) {
    VisitCallOperation(op, call);
  } else {
    // The "before" state is the state after the previous operation, or the
    // entry state of the enclosing block if `op` is the first operation.
    // The original left `before` uninitialized, which is UB if an unlinked
    // op (no previous node, no parent block) ever reaches this path.
    const AbstractDenseLattice* before = nullptr;
    if (auto* prev = op->getPrevNode()) {
      before = GetLatticeFor(op, prev);
    } else if (auto* prev = op->getBlock()) {
      before = GetLatticeFor(op, prev);
    }
    assert(before && "expected `op` to have a predecessor or parent block");
    VisitOperationImpl(op, *before, GetLattice(op));
  }
}

// Computes the entry state of `block`: non-entry blocks join the states
// after each predecessor's terminator; entry blocks take their state from
// the parent operation.
void AbstractDenseAnalysis::VisitBlock(Block* block) {
  if (!block->isEntryBlock()) {
    for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
      Operation* terminator = (*it)->getTerminator();
      Join(GetLattice(block), *GetLatticeFor(block, terminator));
    }
    return;
  }
  Operation* parent = block->getParentOp();
  if (auto callable = ::mlir::dyn_cast<CallableOpInterface>(parent)) {
    VisitCallableOperation(block, callable);
  } else if (auto branch = ::mlir::dyn_cast<RegionBranchOpInterface>(parent)) {
    VisitRegionBranchOperation(block, branch);
  } else {
    // No known control-flow source: fall back to the entry state.
    SetToEntryState(GetLattice(block));
  }
}

// Joins the state flowing into `point` from every known control-flow
// predecessor of the region branch (the branch op itself, or region
// terminators transferring control back).
void AbstractDenseAnalysis::VisitRegionBranchOperation(
    ProgramPoint point, RegionBranchOpInterface branch) {
  auto* after = GetLattice(point);
  const auto* predecessors = GetOrCreateFor<PredecessorState>(point, point);
  // Region control flow is assumed to be fully resolvable, unlike calls.
  assert(predecessors->allPredecessorsKnown());
  for (Operation* op : predecessors->getKnownPredecessors()) {
    // The original left `before` uninitialized; if `op == branch` and the
    // branch had neither a previous node nor a parent block, the later
    // dereference would read an uninitialized pointer (UB).
    const AbstractDenseLattice* before = nullptr;
    if (op == branch) {
      // Entering the region from the branch op: use the state just
      // before the branch itself.
      if (auto* prev = op->getPrevNode()) {
        before = GetLatticeFor(op, prev);
      } else if (auto* prev = op->getBlock()) {
        before = GetLatticeFor(op, prev);
      }
    } else {
      // Coming from a region terminator: use the state after it.
      before = GetLatticeFor(point, op);
    }
    assert(before && "predecessor state must be available");
    Join(after, *before);
  }
}

// Joins the exit states of all known callees into the state after the call.
// Falls back to the entry state when the callees cannot be fully resolved.
void AbstractDenseAnalysis::VisitCallOperation(ProgramPoint op,
                                               CallOpInterface call) {
  AbstractDenseLattice* after = GetLattice(op);
  const auto* predecessors = GetOrCreateFor<PredecessorState>(op, call);
  if (!predecessors->allPredecessorsKnown()) {
    // Unknown callees: be conservative.
    SetToEntryState(after);
    return;
  }
  for (auto* callee : predecessors->getKnownPredecessors()) {
    Join(after, *GetLatticeFor(op, callee));
  }
}

// Joins into the callable's entry-block state the state just before each
// known call site. Resets to the entry state when call sites are unknown.
void AbstractDenseAnalysis::VisitCallableOperation(
    ProgramPoint block, CallableOpInterface callable) {
  auto* after = GetLattice(block);
  // Only meaningful for the entry block of the callable's body region.
  assert(callable.getCallableRegion() == block.get<Block*>()->getParent());
  const auto* callsites = GetOrCreateFor<PredecessorState>(block, callable);
  if (!callsites->allPredecessorsKnown()) {
    return SetToEntryState(after);
  }
  for (Operation* op : callsites->getKnownPredecessors()) {
    // State flowing in is the state just before the call site: after the
    // call's previous op, or at its block entry. The original left
    // `before` uninitialized, risking an uninitialized read (UB) if a call
    // site had neither a previous node nor a parent block.
    const AbstractDenseLattice* before = nullptr;
    if (auto* prev = op->getPrevNode()) {
      before = GetLatticeFor(op, prev);
    } else if (auto* prev = op->getBlock()) {
      before = GetLatticeFor(op, prev);
    }
    assert(before && "call site must have a predecessor or parent block");
    Join(after, *before);
  }
}

// Fetches the lattice attached to `point` and records that `dependent`
// must be re-visited whenever that lattice changes.
const AbstractDenseLattice* AbstractDenseAnalysis::GetLatticeFor(
    ProgramPoint dependent, ProgramPoint point) {
  auto* lattice = GetLattice(point);
  AddDependency(lattice, dependent);
  return lattice;
}

} // namespace dataflow
} // namespace infra
99 changes: 99 additions & 0 deletions paddle/infra/Analysis/DataFlow/DenseAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "Analysis/DataFlow/Framework.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"

namespace infra {

class RegionBranchOpInterface;

namespace dataflow {

// A dense lattice is attached to operations to represent the program
// state after execution, or to blocks to represent the program state
// at the beginning of the block. It is propagated through the analysis.
// A dense lattice is attached to operations to represent the program
// state after execution, or to blocks to represent the program state
// at the beginning of the block. It is propagated through the analysis.
class AbstractDenseLattice : public AnalysisState {
 public:
  using AnalysisState::AnalysisState;

  // Merges `rhs` into this lattice and reports whether the value changed.
  virtual ChangeStatus Join(const AbstractDenseLattice& rhs) = 0;
};

// Implements a transfer function from the lattice between operations.
// Base class for dense forward dataflow analyses: propagates lattice state
// through operations and blocks via a transfer function.
class AbstractDenseAnalysis : public DataFlowAnalysis {
 public:
  using DataFlowAnalysis::DataFlowAnalysis;
  using Operation = ::mlir::Operation;
  using Block = ::mlir::Block;
  using RegionBranchOpInterface = ::mlir::RegionBranchOpInterface;
  using CallOpInterface = ::mlir::CallOpInterface;
  using CallableOpInterface = ::mlir::CallableOpInterface;

  // Traverses every operation and block under `top` and initializes them.
  bool Initialize(Operation* top) override;

  // Visits a program point (an operation or a block) and modifies the
  // state of the program; returns false for unhandled point kinds.
  bool Visit(ProgramPoint point) override;

 protected:
  // Transfer function: computes `after` from `op` and the `before` state.
  // Implemented by concrete analyses.
  virtual void VisitOperationImpl(Operation* op,
                                  const AbstractDenseLattice& before,
                                  AbstractDenseLattice* after) = 0;

  // Returns the lattice attached to `point`.
  virtual AbstractDenseLattice* GetLattice(ProgramPoint point) = 0;

  // Resets `lattice` to the conservative entry state.
  virtual void SetToEntryState(AbstractDenseLattice* lattice) = 0;

  // Like GetLattice, but also records that `dependent` must be re-visited
  // whenever the lattice on `point` changes.
  const AbstractDenseLattice* GetLatticeFor(ProgramPoint dependent,
                                            ProgramPoint point);

  // Joins `rhs` into `lhs` and propagates if the lattice value changed.
  void Join(AbstractDenseLattice* lhs, const AbstractDenseLattice& rhs) {
    PropagateIfChanged(lhs, lhs->Join(rhs));
  }

 protected:
  // If the operation is a call or region branch, the state is set by
  // control-flow. Otherwise it calls the transfer function.
  virtual void VisitOperation(Operation* op);

  // Joins the state from all known control-flow predecessors of `point`.
  void VisitRegionBranchOperation(ProgramPoint point,
                                  RegionBranchOpInterface branch);

  // Joins the exit states of all known callees of `call`.
  void VisitCallOperation(ProgramPoint point, CallOpInterface call);

  // Joins the states flowing in from the callable's known call sites.
  void VisitCallableOperation(ProgramPoint point, CallableOpInterface callable);

  // Computes the entry state of `block` from its predecessors or parent op.
  void VisitBlock(Block* block);
};

template <typename LatticeT>
class DenseAnalysis : public AbstractDenseAnalysis {
static_assert(
std::is_base_of<AbstractDenseLattice, LatticeT>::value,
"The class `LatticeT` must derive from `AbstractDenseLattice`.");

public:
using AbstractDenseAnalysis::AbstractDenseAnalysis;

virtual void VisitOperation(Operation* op,
const LatticeT& before,
LatticeT* after) = 0;
};

} // namespace dataflow
} // namespace infra
7 changes: 3 additions & 4 deletions paddle/infra/Analysis/DataFlow/Framework.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ void DataFlowSolver::InitializeAndRun(Operation* top) {

// Constructs an analysis bound to `solver`, which all dependency and
// propagation requests are forwarded to.
DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver& solver) : solver_{solver} {}

void DataFlowAnalysis::AddDependency(AnalysisState* state,
DataFlowAnalysis* analysis,
ProgramPoint point) {
// Records in the solver that this analysis must re-visit `point` whenever
// `state` changes.
void DataFlowAnalysis::AddDependency(AnalysisState* state, ProgramPoint point) {
  solver_.AddDependency(state, this, point);
}

void DataFlowAnalysis::PropagateIfChanged(AnalysisState* state, bool changed) {
// Forwards to the solver, which re-enqueues dependents of `state` when
// `changed` indicates the lattice value was modified.
void DataFlowAnalysis::PropagateIfChanged(AnalysisState* state,
                                          ChangeStatus changed) {
  solver_.PropagateIfChanged(state, changed);
}

Expand Down

0 comments on commit 6e2139a

Please sign in to comment.