TensorFlow: Upstream changes from afternoon.
Changes:
- Ptrdiff -> DenseIndex change by @jiayq

- Fix to scoping the logging in logging.py by @dga

- Improvement to Conv2DBackpropFilter on CPU by Andy

- Remove lookup table wrappers for the time being (they weren't in our
  public API yet) by Yukata

- Add a check similar to numpy to make sure the user isn't in the
  tensorflow src directory by @vrv

- More changes for Python 3 compatibility by @girving

- Make dropout preserve shape info from input (@mrry)

- Significant speed improvements by @zheng-xq to the BFC allocator,
  bringing it on par (CPU overhead-wise) with the region allocator.
  Make the BFC allocator the default now that it's working well for a
  variety of models.

- Fix a bunch of typos reported by users (@vrv)

- Enable concat for bfloat16 on GPU by Ashish.

Base CL: 107733123
Vijay Vasudevan committed Nov 13, 2015
1 parent 72a5a60 commit 6b12d08
Showing 43 changed files with 397 additions and 580 deletions.
5 changes: 5 additions & 0 deletions tensorflow/core/common_runtime/function.cc
@@ -495,6 +495,11 @@ static void SimplifyGraph(Graph* g) {

void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
DumpGraph("Initial", *g);

// Run SimplifyGraph at least once to rewrite away ops such as
// _ListToArray, _ArrayToList, etc.
SimplifyGraph(*g);

const int kNumInlineRounds = 10;
for (int i = 0; i < kNumInlineRounds; ++i) {
if (!ExpandInlineFunctions(lib, *g)) break;
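
Note: the hunk above runs SimplifyGraph once up front and then alternates inlining with re-simplification for a bounded number of rounds. A minimal standalone sketch of that bounded-fixpoint pattern is below; FakeGraph, OptimizeWithBoundedRounds, and the callbacks are hypothetical stand-ins for Graph, OptimizeGraph, SimplifyGraph, and ExpandInlineFunctions, not the TensorFlow code.

#include <functional>
#include <iostream>

// Purely illustrative graph type: counts how many nodes are still inlinable.
struct FakeGraph {
  int inlinable_nodes = 3;
};

// Simplify once even if nothing gets inlined, then alternate inlining with
// re-simplification until inlining stops making progress or the cap is hit.
void OptimizeWithBoundedRounds(FakeGraph* g,
                               const std::function<void(FakeGraph*)>& simplify,
                               const std::function<bool(FakeGraph*)>& try_inline,
                               int max_rounds = 10) {
  simplify(g);  // Rewrite away conversion-style ops up front.
  for (int round = 0; round < max_rounds; ++round) {
    if (!try_inline(g)) break;  // Fixed point: nothing left to inline.
    simplify(g);                // Clean up ops exposed by the inlined bodies.
  }
}

int main() {
  FakeGraph g;
  OptimizeWithBoundedRounds(
      &g, [](FakeGraph*) { std::cout << "simplify pass\n"; },
      [](FakeGraph* graph) { return graph->inlinable_nodes-- > 0; });
}

The round cap keeps the loop from spinning on pathological graphs while still reaching a fixed point for typical ones.
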
100 changes: 41 additions & 59 deletions tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -65,7 +65,7 @@ GPUBFCAllocator::GPUBFCAllocator(int device_id, size_t total_memory)
ptr_to_chunk_map_.insert(std::make_pair(c->ptr, c));

// Insert the chunk into the right bin.
ReassignChunkToBin(c);
InsertFreeChunkIntoBin(c);
}

GPUBFCAllocator::~GPUBFCAllocator() {
@@ -76,6 +76,7 @@ GPUBFCAllocator::~GPUBFCAllocator() {
}

gtl::STLDeleteValues(&bins_);
gtl::STLDeleteValues(&ptr_to_chunk_map_);
}

void* GPUBFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) {
@@ -115,10 +116,12 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
// Start searching from the first bin for the smallest chunk that fits
// rounded_bytes.
Bin* b = it->second;
for (GPUBFCAllocator::Chunk* chunk : b->chunks) {
if (!chunk->in_use && chunk->size > rounded_bytes) {
// We found an existing chunk that fits us that wasn't in use.
chunk->in_use = true;
for (GPUBFCAllocator::Chunk* chunk : b->free_chunks) {
DCHECK(!chunk->in_use);
if (chunk->size >= rounded_bytes) {
// We found an existing chunk that fits us that wasn't in use, so remove
// it from the free bin structure prior to using.
RemoveFreeChunkFromBin(chunk);

// If we can break the size of the chunk into two reasonably
// large pieces, do so.
@@ -132,6 +135,7 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
// The requested size of the returned chunk is what the user
// has allocated.
chunk->requested_size = num_bytes;
chunk->in_use = true;

VLOG(4) << "Returning: " << chunk->ptr;
return chunk->ptr;
@@ -152,6 +156,8 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
}

void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) {
CHECK(!c->in_use && !c->bin);

// Create a new chunk starting num_bytes after c
GPUBFCAllocator::Chunk* new_chunk = new GPUBFCAllocator::Chunk();
new_chunk->ptr = static_cast<void*>(static_cast<char*>(c->ptr) + num_bytes);
@@ -176,9 +182,8 @@ void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) {
c_neighbor->prev = new_chunk;
}

// Maintain the bins
ReassignChunkToBin(new_chunk);
ReassignChunkToBin(c);
// Add the newly free chunk to the free bin.
InsertFreeChunkIntoBin(new_chunk);
}

void GPUBFCAllocator::DeallocateRaw(void* ptr) {
@@ -200,19 +205,17 @@ void GPUBFCAllocator::DeallocateRawInternal(void* ptr) {

GPUBFCAllocator::Chunk* c = it->second;
VLOG(6) << "Chunk at " << c->ptr << " no longer in use";
// Mark the chunk as no longer in use
c->in_use = false;

// Consider coalescing it.
MaybeCoalesce(c);
FreeAndMaybeCoalesce(c);
}

// Merges c1 and c2 when c1->next is c2 and c2->prev is c1.
// We merge c2 into c1.
void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1,
GPUBFCAllocator::Chunk* c2) {
// We can only merge chunks that are not in use.
DCHECK(!c1->in_use && !c2->in_use);
CHECK(!c1->in_use && !c2->in_use);

// c1's prev doesn't change, still points to the same ptr, and is
// still not in use.
@@ -231,62 +234,42 @@ void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1,
// Set the new size
c1->size += c2->size;

DeleteChunk(c2);
}

void GPUBFCAllocator::DeleteChunk(Chunk* c) {
// Delete c2 and cleanup all state
RemoveChunkFromBin(c2);
VLOG(4) << "Removing: " << c->ptr;
ptr_to_chunk_map_.erase(c->ptr);
delete c;
}

void GPUBFCAllocator::ReassignChunkToBin(GPUBFCAllocator::Chunk* c) {
void GPUBFCAllocator::InsertFreeChunkIntoBin(GPUBFCAllocator::Chunk* c) {
CHECK(!c->in_use && !c->bin);
auto it = bins_.lower_bound(c->size);
CHECK(it != bins_.end()) << " Tried to reassign to non-existent bin for size "
<< c->size;

Bin* new_bin = it->second;

// If the bin has not changed, do nothing.
Bin* old_bin = c->bin;
if (old_bin != nullptr && new_bin == old_bin) {
return;
}

// The bin has changed. Add the chunk to the new bin and remove
// the chunk from the old bin.
new_bin->chunks.insert(c);
c->bin = new_bin;
new_bin->free_chunks.insert(c);
}

if (old_bin == nullptr) {
return;
}

// Remove chunk from old bin
for (auto it = old_bin->chunks.begin(); it != old_bin->chunks.end(); ++it) {
if (*it == c) {
old_bin->chunks.erase(it);
return;
}
}
CHECK(false) << "Could not find chunk in old bin";
void GPUBFCAllocator::RemoveFreeChunkFromBin(GPUBFCAllocator::Chunk* c) {
CHECK(!c->in_use && c->bin);
int count = c->bin->free_chunks.erase(c);
CHECK(count > 0) << "Could not find chunk in bin";
c->bin = nullptr;
}

void GPUBFCAllocator::RemoveChunkFromBin(GPUBFCAllocator::Chunk* c) {
Bin* b = c->bin;
for (auto it = b->chunks.begin(); it != b->chunks.end(); ++it) {
Chunk* other_c = *it;
if (other_c->ptr == c->ptr) {
b->chunks.erase(it);
VLOG(4) << "Removing: " << c->ptr;
ptr_to_chunk_map_.erase(c->ptr);
delete c;
return;
}
}
void GPUBFCAllocator::FreeAndMaybeCoalesce(GPUBFCAllocator::Chunk* c) {
CHECK(c->in_use && !c->bin);

CHECK(false) << "Could not find chunk in bin";
}
// Mark the chunk as no longer in use
c->in_use = false;

void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
// This chunk is no longer in-use, consider coalescing the chunk
// with adjacent chunks.
Chunk* chunk_to_reassign = nullptr;
Chunk* chunk_to_reassign = c;

// If the next chunk is free, coalesce the two, if the result would
// fit in an existing bin.
@@ -296,6 +279,7 @@ void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
chunk_to_reassign = c;

// Deletes c->next
RemoveFreeChunkFromBin(c->next);
Merge(c, c->next);
}

@@ -307,13 +291,11 @@ void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
chunk_to_reassign = c->prev;

// Deletes c
RemoveFreeChunkFromBin(c->prev);
Merge(c->prev, c);
}

// Reassign the final merged chunk into the right bin.
if (chunk_to_reassign) {
ReassignChunkToBin(chunk_to_reassign);
}
InsertFreeChunkIntoBin(chunk_to_reassign);
}

void GPUBFCAllocator::AddAllocVisitor(Visitor visitor) {
@@ -354,7 +336,7 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
size_t total_requested_bytes_in_bin = 0;
size_t total_chunks_in_use = 0;
size_t total_chunks_in_bin = 0;
for (Chunk* c : b->chunks) {
for (Chunk* c : b->free_chunks) {
total_bytes_in_bin += c->size;
total_requested_bytes_in_bin += c->requested_size;
++total_chunks_in_bin;
@@ -388,7 +370,7 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
<< " was " << strings::HumanReadableNumBytes(b->bin_size)
<< ", Chunk State: ";

for (Chunk* c : b->chunks) {
for (Chunk* c : b->free_chunks) {
LOG(INFO) << c->DebugString(true);
}
}
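
Note: the rewritten deallocation path above maintains one invariant — a chunk sits in a bin exactly when it is free — so coalescing first pulls free neighbors out of their bins, merges, and re-bins only the final chunk. The sketch below illustrates that flow with simplified standalone types; Chunk, Merge, and the bin helpers here are not the TensorFlow classes, and the Bin back-pointer is modeled as a plain bool.

#include <cassert>
#include <cstddef>
#include <iostream>

struct Chunk {
  size_t size;
  bool in_use;
  bool in_bin;
  Chunk* prev;
  Chunk* next;
};

void RemoveFreeChunkFromBin(Chunk* c) {
  assert(!c->in_use && c->in_bin);
  c->in_bin = false;
}

void InsertFreeChunkIntoBin(Chunk* c) {
  assert(!c->in_use && !c->in_bin);
  c->in_bin = true;
}

// Merge `right` into `left` (its address-order successor) and delete `right`.
Chunk* Merge(Chunk* left, Chunk* right) {
  assert(!left->in_use && !right->in_use);
  left->size += right->size;
  left->next = right->next;
  if (right->next != nullptr) right->next->prev = left;
  delete right;
  return left;
}

// Free chunks live in bins; in-use chunks do not. Coalescing removes the free
// neighbors from their bins, merges, and re-inserts only the merged result.
void FreeAndMaybeCoalesce(Chunk* c) {
  assert(c->in_use && !c->in_bin);
  c->in_use = false;
  if (c->next != nullptr && !c->next->in_use) {
    RemoveFreeChunkFromBin(c->next);
    c = Merge(c, c->next);
  }
  if (c->prev != nullptr && !c->prev->in_use) {
    RemoveFreeChunkFromBin(c->prev);
    c = Merge(c->prev, c);
  }
  InsertFreeChunkIntoBin(c);
}

int main() {
  // Three adjacent chunks: free, in-use, free. Freeing the middle one
  // coalesces all three into a single 96-byte free chunk.
  Chunk* a = new Chunk{32, false, true, nullptr, nullptr};
  Chunk* b = new Chunk{32, true, false, nullptr, nullptr};
  Chunk* c = new Chunk{32, false, true, nullptr, nullptr};
  a->next = b; b->prev = a; b->next = c; c->prev = b;
  FreeAndMaybeCoalesce(b);
  std::cout << "merged chunk size: " << a->size << "\n";  // 96
  delete a;
}
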
27 changes: 16 additions & 11 deletions tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -102,28 +102,33 @@ class GPUBFCAllocator : public VisitableAllocator {
Chunk* AllocateNewChunk(size_t num_bytes);
void SplitChunk(Chunk* c, size_t num_bytes);
void Merge(Chunk* c1, Chunk* c2);
void MaybeCoalesce(Chunk* c);

void ReassignChunkToBin(Chunk* c);
void RemoveChunkFromBin(Chunk* c);
void FreeAndMaybeCoalesce(Chunk* c);
void InsertFreeChunkIntoBin(Chunk* c);
void RemoveFreeChunkFromBin(Chunk* c);
void DeleteChunk(Chunk* c);

void DumpMemoryLog(size_t num_bytes);

// A Bin is a collection of similar-sized Chunks.
// A Bin is a collection of similar-sized free chunks.
struct Bin {
// All chunks in this bin have >= bin_size memory.
size_t bin_size = 0;

struct ChunkComparator {
bool operator()(Chunk* a, Chunk* b) { return a->size < b->size; }
// Sort first by size and then use pointer address as a tie breaker.
bool operator()(const Chunk* a, const Chunk* b) const {
if (a->size != b->size) {
return a->size < b->size;
}
return a->ptr < b->ptr;
}
};

// List of chunks within the bin, sorted by chunk size.
std::multiset<Chunk*, ChunkComparator> chunks;
// List of free chunks within the bin, sorted by chunk size.
// Chunk * not owned.
std::set<Chunk*, ChunkComparator> free_chunks;

explicit Bin(size_t bs) : bin_size(bs) {}

~Bin() { gtl::STLDeleteElements(&chunks); }
};

GPUAllocatorRetry retry_helper_;
@@ -142,7 +147,7 @@ class GPUBFCAllocator : public VisitableAllocator {

// Structures mutable after construction
mutable mutex lock_;
// Not owned.
// Chunk * owned.
std::unordered_map<void*, Chunk*> ptr_to_chunk_map_;

// Called once on each region, ASAP.
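
Note: the header change above replaces a multiset of all chunks with a std::set of only the free chunks, which is why the comparator needs the pointer tie-break — ordering by size alone would make two distinct free chunks of equal size compare equivalent, and the second insert would be silently dropped. Below is a small standalone sketch of that ordering plus a lower_bound-style best-fit lookup; Chunk, FreeChunkSet, and FindBestFit are simplified illustrative types, and the lookup collapses the allocator's two-level bin search into a single set for brevity.

#include <cstddef>
#include <iostream>
#include <set>

struct Chunk {
  void* ptr;
  size_t size;
};

// Order free chunks by size, breaking ties by address so distinct
// equal-sized chunks are not treated as duplicates by std::set.
struct ChunkComparator {
  bool operator()(const Chunk* a, const Chunk* b) const {
    if (a->size != b->size) return a->size < b->size;
    return a->ptr < b->ptr;
  }
};

using FreeChunkSet = std::set<Chunk*, ChunkComparator>;

// Best fit: the smallest free chunk whose size is >= rounded_bytes.
Chunk* FindBestFit(FreeChunkSet* free_chunks, size_t rounded_bytes) {
  Chunk probe{nullptr, rounded_bytes};
  auto it = free_chunks->lower_bound(&probe);
  if (it == free_chunks->end()) return nullptr;
  Chunk* c = *it;
  free_chunks->erase(it);  // Mirrors removing the chunk from its bin before use.
  return c;
}

int main() {
  Chunk a{reinterpret_cast<void*>(0x1000), 256};
  Chunk b{reinterpret_cast<void*>(0x2000), 256};  // Same size, different address.
  Chunk c{reinterpret_cast<void*>(0x3000), 1024};
  FreeChunkSet bin{&a, &b, &c};
  Chunk* hit = FindBestFit(&bin, 200);
  std::cout << "best fit size: " << (hit ? hit->size : 0) << "\n";  // 256
}

Using a plain std::set (rather than multiset) also makes erase(chunk) unambiguous, which is what the new RemoveFreeChunkFromBin relies on.
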
4 changes: 2 additions & 2 deletions tensorflow/core/common_runtime/gpu/process_state.cc
@@ -20,7 +20,7 @@ DEFINE_bool(record_mem_types, false,
DEFINE_bool(brain_mem_reg_cuda_dma, true,
"If true, register CPU RAM used to copy to/from GPU RAM "
"with the CUDA driver.");
DEFINE_bool(brain_gpu_use_bfc_allocator, false,
DEFINE_bool(brain_gpu_use_bfc_allocator, true,
"If true, uses the Best-Fit GPU allocator.");
DEFINE_bool(brain_gpu_region_allocator_debug, false,
"If true, checks for memory overwrites by writing "
@@ -34,7 +34,7 @@ bool FLAGS_record_mem_types = false;
bool FLAGS_brain_mem_reg_cuda_dma = true;
bool FLAGS_brain_gpu_region_allocator_debug = false;
bool FLAGS_brain_gpu_region_allocator_reset_to_nan = false;
bool FLAGS_brain_gpu_use_bfc_allocator = false;
bool FLAGS_brain_gpu_use_bfc_allocator = true;
#endif

namespace gpu = ::perftools::gputools;
1 change: 1 addition & 0 deletions tensorflow/core/kernels/concat_op.cc
@@ -135,6 +135,7 @@ REGISTER_CONCAT(bfloat16);
ConcatOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU

// A special GPU kernel for int32.
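
Note: the one-line additions in this file and the next extend a per-type registration macro to bfloat16 by invoking the macro explicitly after the type-list expansion. A minimal standalone illustration of that macro-expansion pattern follows; CALL_TYPES, REGISTER, and RegisterKernel_* are hypothetical names, not the TF_CALL_GPU_NUMBER_TYPES / REGISTER_GPU macros themselves.

#include <iostream>

// Applies a macro once per "standard" type; extra types are appended with an
// explicit extra invocation, as the diff does for bfloat16.
#define CALL_TYPES(m) m(float) m(double)

#define REGISTER(T) \
  void RegisterKernel_##T() { std::cout << "registered " #T "\n"; }

CALL_TYPES(REGISTER)
REGISTER(int)  // Analogous to the explicit REGISTER_GPU(bfloat16) line above.
#undef REGISTER

int main() {
  RegisterKernel_float();
  RegisterKernel_double();
  RegisterKernel_int();
}
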
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/concat_op_gpu.cu.cc
@@ -6,6 +6,7 @@

#include <memory>

#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"

@@ -34,6 +35,7 @@ void ConcatGPU(const GPUDevice& d,
typename TTypes<T, 2>::Matrix* output);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU

} // end namespace tensorflow
62 changes: 47 additions & 15 deletions tensorflow/core/kernels/conv_grad_ops.cc
@@ -541,12 +541,36 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
// The output image size is the spatial size of the output.
const int output_image_size = out_rows * out_cols;

// Shard 'batch' images into 'shard_size' groups of images to be fed
// into the parallel matmul. Calculate 'shard_size' by dividing the L3 cache
// size ('target_working_set_size') by the matmul size of an individual
// image ('work_unit_size').

// TODO(andydavis)
// *) Get L3 cache size from device at runtime (30MB is from ivybridge).
// *) Consider reducing 'target_working_set_size' if L3 is shared by
// other concurrently running tensorflow ops.
const size_t target_working_set_size = (30LL << 20) / sizeof(T);

const size_t size_A = output_image_size * filter_total_size;

const size_t size_B = output_image_size * out_depth;

const size_t size_C = filter_total_size * out_depth;

const size_t work_unit_size = size_A + size_B + size_C;

const size_t shard_size =
(target_working_set_size + work_unit_size - 1) / work_unit_size;

Tensor col_buffer;
OP_REQUIRES_OK(
context,
context->allocate_temp(
DataTypeToEnum<T>::value,
TensorShape({output_image_size, filter_total_size}), &col_buffer));
OP_REQUIRES_OK(context,
context->allocate_temp(
DataTypeToEnum<T>::value,
TensorShape({static_cast<int64>(shard_size),
static_cast<int64>(output_image_size),
static_cast<int64>(filter_total_size)}),
&col_buffer));

// The input offset corresponding to a single input image.
const int input_offset = input_rows * input_cols * in_depth;
@@ -571,21 +595,29 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
contract_dims[0].first = 0;
contract_dims[0].second = 0;

for (int image_id = 0; image_id < batch; ++image_id) {
// When we compute the gradient with respect to the filters, we need to do
// im2col to allow gemm-type computation.
Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows,
filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
stride, col_buffer_data);
for (int image_id = 0; image_id < batch; image_id += shard_size) {
const int shard_limit = std::min(static_cast<int>(shard_size),
static_cast<int>(batch) - image_id);
for (int shard_id = 0; shard_id < shard_limit; ++shard_id) {
// TODO(andydavis) Parallelize this loop.
// When we compute the gradient with respect to the filters, we need
// to do im2col to allow gemm-type computation.
Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows,
filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
stride, col_buffer_data + shard_id * size_A);

input_data += input_offset;
}

ConstTensorMap A(col_buffer_data, output_image_size, filter_total_size);
ConstTensorMap B(out_backprop_data + output_offset * image_id,
output_image_size, out_depth);
ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
filter_total_size);
ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
out_depth);

// Gradient with respect to filter.
C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);

input_data += input_offset;
out_backprop_data += output_offset * shard_limit;
}
}

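
Note: the comment block in this file derives shard_size by dividing a target working-set budget (30MB of L3, expressed in elements of T) by the per-image matmul footprint size_A + size_B + size_C, then walks the batch in steps of shard_size. The sketch below works through that arithmetic; the dimensions in main are made-up examples, not values from the diff.

#include <algorithm>
#include <cstddef>
#include <iostream>

// Ceiling division of the working-set budget by one image's matmul footprint.
size_t ComputeShardSize(size_t target_working_set_elems, size_t size_a,
                        size_t size_b, size_t size_c) {
  const size_t work_unit = size_a + size_b + size_c;
  return (target_working_set_elems + work_unit - 1) / work_unit;
}

int main() {
  // Hypothetical example: 30MB of float elements, a 56x56 output image,
  // 3x3x64 filters, and 128 output channels.
  const size_t target = (30ull << 20) / sizeof(float);
  const size_t output_image_size = 56 * 56;
  const size_t filter_total_size = 3 * 3 * 64;
  const size_t out_depth = 128;

  const size_t size_a = output_image_size * filter_total_size;  // im2col buffer
  const size_t size_b = output_image_size * out_depth;          // out_backprop slice
  const size_t size_c = filter_total_size * out_depth;          // filter gradient

  const size_t shard = ComputeShardSize(target, size_a, size_b, size_c);
  std::cout << "images per shard: " << shard << "\n";

  // The batch loop then advances in steps of `shard`, clamping the last group,
  // which is what the image_id / shard_limit loop in the diff does.
  const int batch = 50;
  for (int image_id = 0; image_id < batch;
       image_id += static_cast<int>(shard)) {
    const int shard_limit = std::min(static_cast<int>(shard), batch - image_id);
    std::cout << "shard starting at " << image_id << " covers " << shard_limit
              << " images\n";
  }
}

Grouping several images per im2col buffer lets a single larger contraction replace many small ones while keeping the working set within the cache budget.
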
