Skip to content

Commit

Permalink
Variable Tensor API for TF Lite.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 200457602
  • Loading branch information
miaout17 authored and tensorflower-gardener committed Jun 13, 2018
1 parent 2f7f04a commit a3273e0
Show file tree
Hide file tree
Showing 16 changed files with 299 additions and 53 deletions.
58 changes: 53 additions & 5 deletions tensorflow/contrib/lite/arena_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct AllocationInfo {
// The tensor index to be allocated or deallocated.
int tensor;
// Whether to allocate or deallocate
enum { ALLOC, DEALLOC } type;
enum Type { ALLOC, DEALLOC } type;
};

ArenaPlanner::ArenaPlanner(TfLiteContext* context,
Expand Down Expand Up @@ -67,6 +67,33 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {

// Keeps track of references to each tensor.
std::vector<int> refcounts(graph_info_->num_tensors(), 0);
// `allocated` and `deallocated` are technically lists of boolean values.
// We're saving the compiled binary size by using `vector<int>`.
std::vector<int> allocated(graph_info_->num_tensors(), false);
std::vector<int> deallocated(graph_info_->num_tensors(), false);

auto allocate = [this, &allocated, &deallocated](int node,
int tensor) -> TfLiteStatus {
if (allocated[tensor]) {
return kTfLiteOk;
}
TF_LITE_ENSURE(context_, !deallocated[tensor]);
alloc_queue_.push_back({node, tensor, AllocationInfo::ALLOC});
allocated[tensor] = true;
return kTfLiteOk;
};

auto deallocate = [this, &allocated, &deallocated](
int node, int tensor) -> TfLiteStatus {
if (!allocated[tensor]) {
// Do not enqueue a DEALLOC if the tensor was never allocated.
// This happens with constant tensors.
return kTfLiteOk;
}
TF_LITE_ENSURE(context_, !deallocated[tensor]);
alloc_queue_.push_back({node, tensor, AllocationInfo::DEALLOC});
return kTfLiteOk;
};

// There will be an entry in alloc_queue_ for the allocation of each tensor
// and another for their deallocation.
Expand All @@ -79,6 +106,28 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
refcounts[tensor_index]++;
}

// Variable tensors are also never overwritten and need to be alive all
// the time.
for (int tensor_index : graph_info_->variables()) {
refcounts[tensor_index]++;
}

// Queue all graph inputs for allocation.
for (int tensor_index : graph_info_->inputs()) {
if (tensor_index != kOptionalTensor) {
TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
}
}

// Queue all graph variable tensors for allocation.
for (int tensor_index : graph_info_->variables()) {
if (tensor_index != kOptionalTensor) {
// Variable tensors already had their reference counts incremented
// above, so they are never deallocated; they only need allocation here.
TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
}
}

// Count references to node input tensors.
for (int i = 0; i < graph_info_->num_nodes(); ++i) {
const TfLiteNode& node = graph_info_->node(i);
Expand All @@ -94,10 +143,9 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
// Queue all graph inputs for allocation.
for (int tensor_index : graph_info_->inputs()) {
if (tensor_index != kOptionalTensor) {
alloc_queue_.push_back({0, tensor_index, AllocationInfo::ALLOC});
TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
}
}

// Go through the graph in execution order.
for (int i = 0; i < graph_info_->num_nodes(); ++i) {
const TfLiteNode& node = graph_info_->node(i);
Expand All @@ -106,7 +154,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
TfLiteIntArray* node_outputs = node.outputs;
for (int j = 0; j < node_outputs->size; ++j) {
int tensor_index = node_outputs->data[j];
alloc_queue_.push_back({i, tensor_index, AllocationInfo::ALLOC});
TF_LITE_ENSURE_STATUS(allocate(i, tensor_index));
}

// Then update the ref-counts of the node's inputs, and if necessary queue
Expand All @@ -117,7 +165,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
if (tensor_index != kOptionalTensor) {
refcounts[tensor_index]--;
if (refcounts[tensor_index] == 0) {
alloc_queue_.push_back({i, tensor_index, AllocationInfo::DEALLOC});
TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index));
}
}
}
Expand Down
13 changes: 12 additions & 1 deletion tensorflow/contrib/lite/arena_planner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,18 @@ class TestGraph {
std::vector<TfLiteTensor>* tensors() { return &tensors_; }
const std::vector<int>& inputs() { return inputs_; }
const std::vector<int>& outputs() { return outputs_; }
const std::vector<int>& variables() { return variables_; }

  // Records the indices of the tensors that act as variable tensors in
  // this test graph; they are exposed afterwards through variables().
  void SetVariables(const std::vector<int>& variables) {
    variables_ = variables;
  }

private:
std::vector<TfLiteNode> nodes_;
std::vector<TfLiteTensor> tensors_;
std::vector<int> inputs_;
std::vector<int> outputs_;
std::vector<int> variables_;
};

// The GraphInfo for a TestGraph.
Expand All @@ -123,6 +129,9 @@ class TestGraphInfo : public GraphInfo {
}
const std::vector<int>& inputs() const override { return graph_->inputs(); }
const std::vector<int>& outputs() const override { return graph_->outputs(); }
const std::vector<int>& variables() const override {
return graph_->variables();
}

private:
TestGraph* graph_;
Expand Down Expand Up @@ -306,13 +315,15 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentTensor) {
{
/* in, out, tmp */
{{0, 1}, {2}, {}}, // First op
{{2, 0}, {4}, {5}}, // Second op, with temporary
{{2, 0}, {4}, {5}}, // Second op, with persistent
{{4, -1}, {3}, {}} // Third op, with optional
},
{3});

// Make #1 persistent so it goes into its own arena.
(*graph.tensors())[1].allocation_type = kTfLiteArenaRwPersistent;
// The only use case for kTfLiteArenaRwPersistent is variable tensor now.
graph.SetVariables({1});

SetGraph(&graph);
Execute(0, 10);
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/contrib/lite/context.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void TfLiteTensorFree(TfLiteTensor* t) {
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, TfLiteTensor* tensor) {
const void* allocation, bool is_variable, TfLiteTensor* tensor) {
TfLiteTensorFree(tensor);
tensor->type = type;
tensor->name = name;
Expand All @@ -86,6 +86,7 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
tensor->bytes = size;
tensor->allocation_type = allocation_type;
tensor->allocation = allocation;
tensor->is_variable = is_variable;
}

void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
Expand Down
6 changes: 5 additions & 1 deletion tensorflow/contrib/lite/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ typedef struct {
// delegate buffer.
// WARNING: This is an experimental interface that is subject to change.
bool data_is_stale;

// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;

// Free data memory of tensor `t`;
Expand All @@ -237,7 +240,8 @@ void TfLiteTensorFree(TfLiteTensor* t);
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, TfLiteTensor* tensor);
const void* allocation, bool is_variable,
TfLiteTensor* tensor);

// Resize the allocated data of a (dynamic) tensor.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/contrib/lite/graph_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class GraphInfo {

// Returns the indices of the output tensors.
virtual const std::vector<int>& outputs() const = 0;

// Returns the indices of the variable tensors.
virtual const std::vector<int>& variables() const = 0;
};

// Represents a subgraph of a TensorFlow Lite graph.
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/contrib/lite/graph_info_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class SimpleTestGraph : public GraphInfo {
TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; }
const std::vector<int>& inputs() const override { return inputs_; }
const std::vector<int>& outputs() const override { return outputs_; }
const std::vector<int>& variables() const override { return variables_; }

void AddNode(const std::vector<int>& inputs,
const std::vector<int>& outputs) {
Expand All @@ -67,6 +68,7 @@ class SimpleTestGraph : public GraphInfo {
std::vector<TfLiteTensor> tensors_;
std::vector<int> inputs_;
std::vector<int> outputs_;
std::vector<int> variables_;
};

// Partition a graph to generate a list of subgraphs. This wraps the API call
Expand Down
55 changes: 49 additions & 6 deletions tensorflow/contrib/lite/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class InterpreterInfo : public GraphInfo {
const std::vector<int>& outputs() const override {
return interpreter_->outputs();
}
const std::vector<int>& variables() const override {
return interpreter_->variables();
}

public:
Interpreter* interpreter_;
Expand Down Expand Up @@ -302,6 +305,13 @@ TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
return kTfLiteOk;
}

// Registers `variables` as the interpreter's list of variable tensor
// indices. Each index is first validated with CheckTensorIndices();
// returns an error status if validation fails, otherwise the list is
// stored (moved) into variables_ and kTfLiteOk is returned.
TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
  TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("variables", variables.data(),
                                                  variables.size()));
  variables_ = std::move(variables);
  return kTfLiteOk;
}

TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
const int* indices, int length) {
// Making sure kOptionalTensor is not re-defined to something other than -1.
Expand Down Expand Up @@ -370,6 +380,7 @@ TfLiteStatus Interpreter::AllocateTensors() {
}

TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());

if (state_ == kStateUninvokable) {
state_ = kStateInvokable;
}
Expand All @@ -378,6 +389,25 @@ TfLiteStatus Interpreter::AllocateTensors() {
return kTfLiteOk;
}

// TODO(ycling): Consider to provide other functions to initialize variable
// tensors to non-zero values.
// Zero-fills the data buffer of every tensor whose `is_variable` flag is
// set. Returns an error if any variable tensor is not allocated as
// kTfLiteArenaRwPersistent or has a null data pointer (i.e. it has
// presumably not been allocated yet — see the note below).
TfLiteStatus Interpreter::ResetVariableTensorsToZero() {
  for (auto& tensor : tensors_) {
    // Non-variable tensors are left untouched.
    if (!tensor.is_variable) {
      continue;
    }

    // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be
    // allocated after the initial `PrepareOpsAndTensors()` is called.
    TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type,
                      kTfLiteArenaRwPersistent);
    TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr);

    memset(tensor.data.raw, 0, tensor.bytes);
  }
  return kTfLiteOk;
}

TfLiteStatus Interpreter::AddNodeWithParameters(
const std::vector<int>& inputs, const std::vector<int>& outputs,
const char* init_data, size_t init_data_size, void* builtin_data,
Expand Down Expand Up @@ -690,7 +720,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
state_ = kStateUninvokable;
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization, const_cast<char*>(buffer), bytes,
kTfLiteMmapRo, allocation, &tensor);
kTfLiteMmapRo, allocation, false, &tensor);
}
return kTfLiteOk;
}
Expand All @@ -701,7 +731,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// to Interpreter.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
const int* dims, TfLiteQuantizationParams quantization) {
const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
if (state_ == kStateInvokableAndImmutable) {
ReportError(
&context_,
Expand All @@ -719,11 +749,23 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
TF_LITE_ENSURE_OK(&context_,
BytesRequired(type, dims, rank, &required_bytes));
}

TfLiteAllocationType allocation_type = kTfLiteArenaRw;
if (type == kTfLiteString) {
if (is_variable) {
// We don't have a real use case for string variable tensor.
ReportError(&context_, "String variable tensor isn't supported.");
return kTfLiteError;
}
allocation_type = kTfLiteDynamic;
} else if (is_variable) {
allocation_type = kTfLiteArenaRwPersistent;
}

TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization,
/*buffer=*/nullptr, required_bytes,
type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
nullptr, &context_.tensors[tensor_index]);
/*buffer=*/nullptr, required_bytes, allocation_type,
nullptr, is_variable, &context_.tensors[tensor_index]);
return kTfLiteOk;
}

Expand All @@ -739,7 +781,8 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
TfLiteIntArray* new_size) {
// Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
if (tensor->allocation_type == kTfLiteArenaRw ||
tensor->allocation_type == kTfLiteDynamic) {
tensor->allocation_type == kTfLiteDynamic ||
tensor->allocation_type == kTfLiteArenaRwPersistent) {
if (tensor->type != kTfLiteString) {
size_t bytesRequired;
TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
Expand Down
23 changes: 20 additions & 3 deletions tensorflow/contrib/lite/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ class Interpreter {
// interpreter.
TfLiteStatus SetOutputs(std::vector<int> outputs);

// Provide a list of tensor indexes that are variable tensors.
// Each index is bounds-checked, and this modifies the consistent_ flag of
// the interpreter.
TfLiteStatus SetVariables(std::vector<int> variables);

// Adds a node with the given parameters and returns the index of the new
// node in `node_index` (optionally). Interpreter will take ownership of
// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
Expand Down Expand Up @@ -160,13 +165,15 @@ class Interpreter {
// to Interpreter.
inline TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
const std::vector<int>& dims, TfLiteQuantizationParams quantization,
bool is_variable = false) {
return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
dims.data(), quantization);
dims.data(), quantization, is_variable);
}
TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
const int* dims, TfLiteQuantizationParams quantization);
const int* dims, TfLiteQuantizationParams quantization,
bool is_variable = false);

// Functions to access tensor data

Expand All @@ -182,6 +189,9 @@ class Interpreter {
// Read only access to list of outputs.
const std::vector<int>& outputs() const { return outputs_; }

// Read only access to list of variable tensors.
const std::vector<int>& variables() const { return variables_; }

// Return the name of a given output. The given index must be between 0 and
// outputs().size().
const char* GetOutputName(int index) const {
Expand Down Expand Up @@ -379,6 +389,10 @@ class Interpreter {
allow_buffer_handle_output_ = allow_buffer_handle_output;
}

// Reset all variable tensors to zero.
// WARNING: This is an experimental API and subject to change.
TfLiteStatus ResetVariableTensorsToZero();

private:
// Give 'op_reg' a chance to initialize itself using the contents of
// 'buffer'.
Expand Down Expand Up @@ -541,6 +555,9 @@ class Interpreter {
// interpreter.
std::vector<int> outputs_;

// Array of indices representing the tensors that are variable tensors.
std::vector<int> variables_;

// The error reporter delegate that tflite will forward queries errors to.
ErrorReporter* error_reporter_;

Expand Down
Loading

0 comments on commit a3273e0

Please sign in to comment.