#include <torch/csrc/jit/argument_spec.h>
namespace torch {
namespace jit {
void ArgumentSpecCreator::scan(
    const TypePtr& typ,
    size_t depth,
    const WrittenSlots& written_slots) {
  auto finishAggregate = [&](size_t pos) {
    // It is possible that, after all the work we did to scan this aggregate,
    // we found no tensors to specialize. In that case, just generate
    // a SKIP for the whole aggregate.
    bool any_spec = std::any_of(
        instructions_.begin() + pos, instructions_.end(), [](Inst i) {
          return i == SPECIALIZE_TENSOR;
        });
    if (!any_spec) {
      instructions_[pos] = SKIP;
      instructions_.resize(pos + 1);
    } else {
      instructions_.emplace_back(LEAVE);
    }
  };
  // The simple VM that scans instructions_ has a limited stack depth;
  // this prevents going deeper than that. Emit a single SKIP for the
  // element and stop recursing so the instruction stream stays balanced.
  if (depth >= DEPTH_LIMIT) {
    instructions_.emplace_back(SKIP);
    return;
  }
  if (typ->isSubtypeOf(TensorType::get())) {
    num_tensors_++;
    instructions_.emplace_back(SPECIALIZE_TENSOR);
  } else if (auto tup = typ->cast<TupleType>()) {
    size_t pos = instructions_.size();
    instructions_.emplace_back(ENTER_TUPLE);
    for (const auto& elem : tup->containedTypes()) {
      scan(elem, depth + 1, written_slots);
    }
    finishAggregate(pos);
  } else if (auto cls = typ->cast<ClassType>()) {
    size_t pos = instructions_.size();
    instructions_.emplace_back(ENTER_OBJECT);
    for (size_t i = 0; i < cls->numAttributes(); ++i) {
      auto key = cls->name() + cls->attributeNames().at(i);
      // It is only safe to specialize a slot if no one writes to it,
      // since a write could invalidate the specialized properties.
      if (!written_slots.count(key)) {
        scan(cls->containedTypes().at(i), depth + 1, written_slots);
      } else {
        instructions_.emplace_back(SKIP);
      }
    }
    finishAggregate(pos);
  } else {
    instructions_.emplace_back(SKIP);
  }
}
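// As an illustration (a sketch, not part of the original file): for a graph
// whose inputs are (Tensor, Tuple[Tensor, int]), scan() would emit roughly
//   SPECIALIZE_TENSOR ENTER_TUPLE SPECIALIZE_TENSOR SKIP LEAVE
// whereas for (int, Tuple[int, int]) the tuple contains no tensors, so
// finishAggregate collapses it and the stream becomes just
//   SKIP SKIP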
// This is a coarse-grained guarantee that the slots of a class will not be
// modified by the function. It works fine for things that used to be
// read-only modules, but will be overly conservative when some classes are
// written to. Doing alias analysis and looking for writes to the class
// would be more accurate.
static void scanWrittenSlots(
    Block* block,
    ArgumentSpecCreator::WrittenSlots& written_slots) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::SetAttr) {
      if (auto cls = n->inputs().at(0)->type()->cast<ClassType>()) {
        written_slots.insert(cls->name() + n->s(attr::name));
      }
    }
    for (Block* subblock : n->blocks()) {
      scanWrittenSlots(subblock, written_slots);
    }
    if (n->hasAttribute(attr::Subgraph)) {
      scanWrittenSlots(n->g(attr::Subgraph)->block(), written_slots);
    }
  }
}
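// For example (an assumption based on the key construction above): a write
// `self.weight = ...` inside a method of a class named `Linear` records the
// key "Linear" + "weight", i.e. "Linearweight", which scan() then checks
// before specializing that slot.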
ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph)
    : num_inputs_(graph.inputs().size()) {
  WrittenSlots written_slots;
  scanWrittenSlots(graph.block(), written_slots);
  for (Value* input : graph.inputs()) {
    scan(input->type(), 0, written_slots);
  }
}
void ArgumentSpecCreator::dump() const {
  for (Inst inst : instructions_) {
    switch (inst) {
      case LEAVE:
        std::cout << "] ";
        break;
      case ENTER_TUPLE:
        std::cout << "Tuple[";
        break;
      case ENTER_OBJECT:
        std::cout << "Object[";
        break;
      case SKIP:
        std::cout << "Skip ";
        break;
      case SPECIALIZE_TENSOR:
        std::cout << "SpecializeTensor ";
        break;
    }
  }
  std::cout << "\n";
}
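// A sketch of the dump() output for the (Tensor, Tuple[Tensor, int])
// example above:
//   SpecializeTensor Tuple[SpecializeTensor Skip ]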
ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
    const {
  ArgumentSpec spec(num_tensors_);
  const IValue* stack[DEPTH_LIMIT]; // The stack of IValue lists
  // The stack gets initialized with the input list
  stack[0] = last(input, num_inputs_).begin();
  size_t stack_top = 0; // offset to the top of the stack
  for (Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_TENSOR:
        // consume a tensor and add it to the argspec
        spec.addTensor(*stack[stack_top]++, with_grad);
        break;
      case ENTER_TUPLE: {
        // consume a tuple
        const IValue* iv = stack[stack_top]++;
        AT_ASSERT(iv->isTuple());
        // see [argspec refcounting]
        auto p = *reinterpret_cast<const at::ivalue::Tuple* const*>(iv);
        auto tup_ptr = &p->elements()[0];
        // push the list of tuple elements onto the stack
        stack[++stack_top] = tup_ptr;
      } break;
      case ENTER_OBJECT: {
        // consume an object
        const IValue* iv = stack[stack_top]++;
        AT_ASSERT(iv->isObject());
        // see [argspec refcounting]
        auto p = *reinterpret_cast<const at::ivalue::Object* const*>(iv);
        auto obj_ptr = &p->slots()[0];
        // push the list of object slots onto the stack
        stack[++stack_top] = obj_ptr;
      } break;
      case SKIP:
        // consume and skip an element
        stack[stack_top]++;
        break;
      case LEAVE:
        --stack_top;
        break;
    }
  }
  return spec;
}
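// A worked trace (a sketch) of create() on the stream
//   SPECIALIZE_TENSOR ENTER_TUPLE SPECIALIZE_TENSOR SKIP LEAVE
// with inputs (t0, (t1, 3)):
//   SPECIALIZE_TENSOR: records t0's properties, advances stack[0]
//   ENTER_TUPLE:       consumes the tuple, pushes its element list; stack_top=1
//   SPECIALIZE_TENSOR: records t1's properties, advances stack[1]
//   SKIP:              steps over the int 3
//   LEAVE:             pops back to the input list; stack_top=0
// The resulting ArgumentSpec holds the properties of t0 and t1.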
// For every input of a given graph, returns the most detailed type that can
// be inferred for it based on this ArgumentSpec.
std::vector<TypePtr> ArgumentSpecCreator::getSpecializedTypes(
    Graph& graph,
    const ArgumentSpec& spec) const {
  auto input_types =
      fmap(graph.inputs(), [](Value* input) { return input->type(); });
  std::vector<std::vector<TypePtr>> result_stack;
  result_stack.emplace_back();
  std::vector<const TypePtr*> input_stack = {input_types.data()};
  std::vector<std::function<TypePtr()>> aggregate_creators;
  size_t arg_spec_offset = 0; // number of specialized tensors seen so far
  for (Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_TENSOR: {
        input_stack.back()++;
        auto& arg = spec.at(arg_spec_offset++);
        if (!arg.defined()) {
          result_stack.back().emplace_back(AutogradZeroTensorType::get());
        } else {
          result_stack.back().emplace_back(DimensionedTensorType::create(
              arg.type(),
              ConvertIntToCPUOrCUDA(arg.device()),
              arg.dim(),
              arg.requires_grad()));
        }
      } break;
      case ENTER_TUPLE: {
        auto tup = (*input_stack.back()++)->expect<TupleType>();
        input_stack.emplace_back(tup->elements().data());
        result_stack.emplace_back();
        aggregate_creators.emplace_back(
            [&] { return TupleType::create(result_stack.back()); });
      } break;
      case ENTER_OBJECT: {
        auto cls = (*input_stack.back()++)->expect<ClassType>();
        input_stack.emplace_back(cls->containedTypes().data());
        result_stack.emplace_back();
        aggregate_creators.emplace_back(
            [&result_stack, cls] { return cls->refine(result_stack.back()); });
      } break;
      case SKIP:
        result_stack.back().emplace_back(*input_stack.back()++);
        break;
      case LEAVE: {
        TypePtr result = aggregate_creators.back()();
        result_stack.pop_back();
        aggregate_creators.pop_back();
        input_stack.pop_back();
        result_stack.back().emplace_back(std::move(result));
      } break;
    }
  }
  AT_ASSERT(result_stack.size() == 1);
  return result_stack.back();
}
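// Continuing the sketch above (the type notation here is informal): for
// inputs (t0, (t1, 3)) with both tensors defined, 2-dim, CPU floats, the
// SPECIALIZE_TENSOR entries become DimensionedTensorType values while the
// SKIPped int keeps its original type, so the inputs are retyped roughly as
//   (Float(dim=2, cpu), Tuple[Float(dim=2, cpu), int])
// with AutogradZeroTensorType substituted for any undefined tensor.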
void ArgumentSpecCreator::setInputTypes(Graph& g, const ArgumentSpec& spec)
    const {
  auto input_types = getSpecializedTypes(g, spec);
  auto inputs = g.inputs();
  for (size_t i = 0; i < inputs.size(); ++i) {
    inputs[i]->setType(input_types[i]);
  }
}
} // namespace jit
} // namespace torch
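// Typical usage (a hedged sketch; `graph` and `stack` are assumed to come
// from the surrounding JIT machinery, e.g. the graph executor):
//
//   torch::jit::ArgumentSpecCreator creator(*graph); // built once per graph
//   torch::jit::ArgumentSpec spec =
//       creator.create(/*with_grad=*/true, stack);   // once per call
//   creator.setInputTypes(*graph, spec);             // retype inputs for
//                                                    // specialized compilation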