Skip to content

Commit

Permalink
Fix LOGFATAL for logical_buffer access with invalid id
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631562933
  • Loading branch information
zzzaries authored and tensorflower-gardener committed May 7, 2024
1 parent 62c4300 commit b7dfb6b
Showing 1 changed file with 46 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -220,15 +220,6 @@ class HloProtoBufferWrapper {
// Get the raw HLO proto.
const ::xla::HloProto& GetHloProto() const { return hlo_proto_; }

const BufferAllocationStruct& GetBufferAllocation(
int64_t buffer_allocation_id) const {
if (!id_to_buffer_allocation_.contains(buffer_allocation_id)) {
LOG(DFATAL) << "buffer_allocation_id " << buffer_allocation_id
<< " not found.";
}
return *id_to_buffer_allocation_.at(buffer_allocation_id);
}

std::vector<const BufferAllocationStruct*> GetBufferAllocations(
int64_t memory_color) const {
std::vector<const BufferAllocationStruct*> buffer_allocations;
Expand All @@ -239,11 +230,12 @@ class HloProtoBufferWrapper {
return buffer_allocations;
}

LogicalBufferStruct& GetLogicalBuffer(int64_t logical_buffer_id) const {
LogicalBufferStruct* GetLogicalBuffer(int64_t logical_buffer_id) const {
if (!id_to_logical_buffer_.contains(logical_buffer_id)) {
LOG(DFATAL) << "logical_buffer_id " << logical_buffer_id << "not found.";
return nullptr;
}
return *id_to_logical_buffer_.at(logical_buffer_id);
return id_to_logical_buffer_.at(logical_buffer_id).get();
}

// Get the logical buffers with indefinite lifetime (excluding thread_local).
Expand All @@ -262,11 +254,12 @@ class HloProtoBufferWrapper {
const LogicalBufferStruct* best_logical_buffer = nullptr;
size_t best_size = 0;
for (const auto& assigned : buffer_assignment->proto().assigned()) {
const auto& logical_buffer_struct =
const LogicalBufferStruct* logical_buffer_struct =
GetLogicalBuffer(assigned.logical_buffer_id());
if (logical_buffer_struct.size() > best_size) {
best_size = logical_buffer_struct.size();
best_logical_buffer = &logical_buffer_struct;
if (logical_buffer_struct == nullptr) continue;
if (logical_buffer_struct->size() > best_size) {
best_size = logical_buffer_struct->size();
best_logical_buffer = logical_buffer_struct;
}
}
if (best_logical_buffer) {
Expand Down Expand Up @@ -442,12 +435,13 @@ void Convert(const xla::BufferAllocationProto_Assigned& assigned,
const HloProtoBufferWrapper& wrapper, LogicalBuffer* result) {
result->set_id(assigned.logical_buffer_id()),
result->set_size_mib(BytesToMiB(assigned.size()));
const auto& logical_buffer =
const LogicalBufferStruct* logical_buffer =
wrapper.GetLogicalBuffer(assigned.logical_buffer_id());
result->set_hlo_name(std::string(logical_buffer.instruction_name()));
if (logical_buffer == nullptr) return;
result->set_hlo_name(std::string(logical_buffer->instruction_name()));
result->mutable_shape_index()->CopyFrom(
logical_buffer.proto.defined_at().shape_index());
result->set_shape(ShapeDescription(logical_buffer.shape));
logical_buffer->proto.defined_at().shape_index());
result->set_shape(ShapeDescription(logical_buffer->shape));
}

bool IsReusable(const BufferAllocationProto& buffer_allocation) {
Expand Down Expand Up @@ -542,9 +536,11 @@ struct HeapSimulatorStats {
// Update memory timelines and seen buffers.
heap_size_bytes_timeline.push_back(heap_size_bytes);
unpadded_heap_size_bytes_timeline.push_back(unpadded_heap_size_bytes);
const auto& logical_buffer = wrapper.GetLogicalBuffer(event.buffer_id());
seen_logical_buffers.insert(&logical_buffer);
seen_buffer_allocations.insert(&logical_buffer.buffer_allocation.proto());
const LogicalBufferStruct* logical_buffer =
wrapper.GetLogicalBuffer(event.buffer_id());
if (logical_buffer == nullptr) return;
seen_logical_buffers.insert(logical_buffer);
seen_buffer_allocations.insert(&logical_buffer->buffer_allocation.proto());
}

// Update stats when memory usage increase.
Expand Down Expand Up @@ -670,36 +666,44 @@ Status ProcessHeapSimulatorTrace(const HloProtoBufferWrapper& wrapper,
stats->SetSimulatorTraceEventSize(trace.events_size());
for (const auto& event : trace.events()) {
stats->UpdateOnSimulatorEvent(event);
auto& logical_buffer = wrapper.GetLogicalBuffer(event.buffer_id());
LogicalBufferStruct* logical_buffer =
wrapper.GetLogicalBuffer(event.buffer_id());
if (logical_buffer == nullptr) {
continue;
}
if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
// ALLOC event increases memory usage and initializes the buffer lifetime
// span.
logical_buffer.inc();
stats->IncreaseMemoryUsage(&logical_buffer,
logical_buffer->inc();
stats->IncreaseMemoryUsage(logical_buffer,
/*init_buffer_span=*/true);
} else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
auto ref_count = logical_buffer.dec();
auto ref_count = logical_buffer->dec();
if (ref_count < 0) {
return errors::InvalidArgument(absl::StrCat(
"Buffer ", logical_buffer.proto.id(), "is freed multiple times."));
"Buffer ", logical_buffer->proto.id(), "is freed multiple times."));
}
if (ref_count == 0) {
// There is no more reference to the canonical buffer, the canonical
// buffer is finally freed. Update memory usage and memory timespan
// using the metadata of canonical buffer.
auto& canonical_buffer = *logical_buffer.get_canonical_buffer();
auto& canonical_buffer = *logical_buffer->get_canonical_buffer();
TF_RETURN_IF_ERROR(stats->DecreaseMemoryUsage(&canonical_buffer));
}
} else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
int64_t canonical_buffer_id = event.share_with_canonical_id();
auto& canonical_buffer = wrapper.GetLogicalBuffer(canonical_buffer_id);
auto ref_count = logical_buffer.share_with(&canonical_buffer);
LogicalBufferStruct* canonical_buffer =
wrapper.GetLogicalBuffer(canonical_buffer_id);
if (canonical_buffer == nullptr) {
continue;
}
auto ref_count = logical_buffer->share_with(canonical_buffer);

if (ref_count == 1) {
// SHARE_WITH happens after the FREE of a canonical buffer.
// SHARE_WITH event does not initialize buffer lifetime span, it was
// initialized by ALLOC event using the canonical logical buffer.
stats->IncreaseMemoryUsage(&canonical_buffer,
stats->IncreaseMemoryUsage(canonical_buffer,
/*init_buffer_span=*/false);
}
} else {
Expand Down Expand Up @@ -735,8 +739,10 @@ struct PeakUsageSnapshot {
// Buffers from HeapSimulatorTrace.
for (const int64_t logical_buffer_id :
simulator_stats.peak_logical_buffers) {
const auto& logical_buffer = wrapper.GetLogicalBuffer(logical_buffer_id);
AddHeapObject(logical_buffer);
const LogicalBufferStruct* logical_buffer =
wrapper.GetLogicalBuffer(logical_buffer_id);
if (logical_buffer == nullptr) return;
AddHeapObject(*logical_buffer);
}

// Make a single HeapObject out of all the small buffers.
Expand Down Expand Up @@ -963,14 +969,15 @@ void ConvertAllocationTimeline(const HloProtoBufferWrapper& wrapper,
ba_colors[buffer_id % num_ba_colors]);

for (const auto& assigned : buffer_allocation->proto().assigned()) {
const LogicalBufferStruct& logical_buffer =
const LogicalBufferStruct* logical_buffer =
wrapper.GetLogicalBuffer(assigned.logical_buffer_id());
if (logical_buffer == nullptr) continue;
// Exclude non-canonical logical buffers.
if (!logical_buffer.span || logical_buffer.canonical_buffer) continue;
size_t width = logical_buffer.span->second - logical_buffer.span->first;
size_t height = buffer_allocation_offset + logical_buffer.size();
add_rect(logical_buffer.span->first, logical_buffer.offset, width, height,
logical_buffer.description(),
if (!logical_buffer->span || logical_buffer->canonical_buffer) continue;
size_t width = logical_buffer->span->second - logical_buffer->span->first;
size_t height = buffer_allocation_offset + logical_buffer->size();
add_rect(logical_buffer->span->first, logical_buffer->offset, width,
height, logical_buffer->description(),
lb_colors[node_id % num_lb_colors]);
}
}
Expand Down

0 comments on commit b7dfb6b

Please sign in to comment.