Skip to content

Commit

Permalink
Adjust API to match current public repo
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidNorman authored and martinwicke committed Jun 14, 2017
1 parent 43d8475 commit d59b64f
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 10 deletions.
10 changes: 6 additions & 4 deletions tensorflow/compiler/plugin/executor/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,16 @@ ExecutorCompiler::CompileAheadOfTime(
"AOT compilation not supported on Executor");
}

int64 ExecutorCompiler::ShapeSizeBytes(const Shape& shape) const {
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}

se::Platform::Id ExecutorCompiler::PlatformId() const {
return sep::kExecutorPlatformId;
}

HloCostAnalysis::ShapeSizeFunction
ExecutorCompiler::ShapeSizeBytesFunction() const {
return ExecutorExecutable::ShapeSizeBytes;
}


} // namespace executorplugin
} // namespace xla

Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/plugin/executor/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ExecutorCompiler : public Compiler {
std::vector<std::unique_ptr<HloModule>> module,
HloDumper dump_hlo, const AotCompilationOptions& options) override;

int64 ShapeSizeBytes(const Shape& shape) const override;
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;

perftools::gputools::Platform::Id PlatformId() const override;

Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/plugin/executor/device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const char* const DEVICE_XLA_EXEC = "XLA_EXEC";
const char* const DEVICE_EXEC_XLA_JIT = "XLA_EXEC_JIT";

constexpr std::array<DataType, 5> kExecAllTypes = {
{DT_INT32, DT_FLOAT, DT_BOOL}};
{DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}};

class XlaExaDeviceFactory : public DeviceFactory {
public:
Expand All @@ -50,7 +50,7 @@ Status XlaExaDeviceFactory::CreateDevices(const SessionOptions& options,
return Status::OK();
}

REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 210);
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 110);

// Kernel registrations

Expand Down
10 changes: 9 additions & 1 deletion tensorflow/compiler/plugin/executor/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace xla {
namespace executorplugin {

ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module)
: Executable(std::move(hlo_module)) {}
: Executable(std::move(hlo_module), ShapeSizeBytes) {}

ExecutorExecutable::~ExecutorExecutable() {}

Expand Down Expand Up @@ -135,5 +135,13 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream(
"ExecuteAsyncOnStream is not yet supported on Executor.");
}

/*static*/ int64 ExecutorExecutable::ShapeSizeBytes(const Shape& shape) {
if (ShapeUtil::IsOpaque(shape)) {
return sizeof(void*);
}
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}


} // namespace executorplugin
} // namespace xla
2 changes: 2 additions & 0 deletions tensorflow/compiler/plugin/executor/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class ExecutorExecutable : public Executable {
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments) override;

static int64 ShapeSizeBytes(const Shape& shape);

private:
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorExecutable);
};
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/plugin/executor/transfer_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ int64 ExecutorTransferManager::GetByteSizeRequirement(const Shape& shape) {
} // namespace executorplugin
} // namespace xla

static xla::TransferManager* CreateExecutorTransferManager() {
return new xla::executorplugin::ExecutorTransferManager();
static std::unique_ptr<xla::TransferManager> CreateExecutorTransferManager() {
return xla::MakeUnique<xla::executorplugin::ExecutorTransferManager>();
}

static bool InitModule() {
Expand Down

0 comments on commit d59b64f

Please sign in to comment.