Commit

Create and cache a DeviceMemoryAllocator in GemmFusionAutotuner rather than the one in StreamExecutor as an intermediate step in eliminating a circular dependency between the two classes.

PiperOrigin-RevId: 626445786
klucke authored and tensorflower-gardener committed Apr 19, 2024
1 parent 6932518 commit e021a72
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc
@@ -76,6 +76,7 @@ limitations under the License.
 #include "xla/status_macros.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/device_memory_allocator.h"
 #include "xla/stream_executor/gpu/redzone_allocator.h"
 #include "xla/stream_executor/stream.h"
 #include "xla/tools/hlo_decomposer.h"
@@ -768,8 +769,11 @@ absl::StatusOr<std::vector<AutotuneResult>> GemmFusionAutotunerImpl::Profile(
     return Internal("Failed to synchronize GPU for autotuning.");
   }
   se::DeviceMemoryAllocator* allocator = config_.GetAllocator();
+  std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator;
   if (allocator == nullptr) {
-    allocator = stream_exec->GetAllocator();
+    owned_allocator =
+        std::make_unique<se::StreamExecutorMemoryAllocator>(stream_exec);
+    allocator = owned_allocator.get();
   }
   TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream());
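
The fallback in this hunk follows a common ownership pattern: a non-owning raw pointer selects the allocator to use, and a `std::unique_ptr` declared in the same scope keeps a locally created fallback alive for as long as that raw pointer is in use. A minimal standalone sketch of the pattern follows. `Executor`, `Allocator`, `ExecutorAllocator`, `FixedAllocator`, and `ProfileWith` are hypothetical stand-ins invented for illustration; they are not the XLA types or their APIs.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-in for se::StreamExecutor (not the real class).
struct Executor {
  std::string name;
};

// Hypothetical stand-in for se::DeviceMemoryAllocator.
class Allocator {
 public:
  virtual ~Allocator() = default;
  virtual std::string Describe() const = 0;
};

// Hypothetical analogue of an allocator that wraps an executor.
class ExecutorAllocator : public Allocator {
 public:
  explicit ExecutorAllocator(Executor* executor) : executor_(executor) {
    assert(executor_ != nullptr);
  }
  std::string Describe() const override { return "owned:" + executor_->name; }

 private:
  Executor* executor_;  // Not owned.
};

// Hypothetical allocator supplied by the caller's configuration.
class FixedAllocator : public Allocator {
 public:
  std::string Describe() const override { return "configured"; }
};

// Uses the configured allocator if present; otherwise creates one locally.
// Returns a description of whichever allocator was selected.
std::string ProfileWith(Allocator* configured, Executor* executor) {
  Allocator* allocator = configured;           // Non-owning view.
  std::unique_ptr<Allocator> owned_allocator;  // Owns the fallback, if any.
  if (allocator == nullptr) {
    owned_allocator = std::make_unique<ExecutorAllocator>(executor);
    allocator = owned_allocator.get();
  }
  // owned_allocator outlives every use of `allocator` in this scope.
  return allocator->Describe();
}
```

Declaring the owner before the `if` matters: if the `unique_ptr` were declared inside the branch, the fallback would be destroyed at the closing brace and `allocator` would dangle.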
