Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Go function to create serialized ConfigOptions protos #26682

Merged
merged 3 commits into from Mar 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
81 changes: 81 additions & 0 deletions tensorflow/go/session.go
Expand Up @@ -18,6 +18,7 @@ package tensorflow

// #include <stdlib.h>
// #include "tensorflow/c/c_api.h"
// #include "tensorflow/c/c_api_experimental.h"
import "C"

import (
Expand Down Expand Up @@ -315,6 +316,11 @@ type SessionOptions struct {
// Config is a binary-serialized representation of the
// tensorflow.ConfigProto protocol message
// (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
// You can populate this field in three ways. You can use the zero value
// of this field, which configures the session with a default set of
// options. Or you create a `Config` struct with your options and call that
// struct's `Bytes()` method. Or you can generate a byte string outside of
// Go and paste that string into your Go program as a string literal.
Config []byte
}

Expand Down Expand Up @@ -349,6 +355,81 @@ func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error)
}, nil
}

// JitLevel represents the level of optimization that the XLA compiler
// performs during just-in-time compilation.
type JitLevel int

const (
// JitDefault is the default setting for this version of TensorFlow
// and corresponds to the "DEFAULT" level in the Python API.
// Currently the default is `JitOff`, but it will change to `JitOn` in a
// future version.
JitDefault JitLevel = 0
// JitOff disables just-in-time compilation and will continue to do so
// even after JIT compilation is enabled by default.
JitOff JitLevel = -1
// JitOn enables just-in-time compilation. It is a synonym for the
// "ON_1" optimization level in the Python API.
JitOn JitLevel = 1
)

// Config represents session parameters as encoded in the tensorflow.ConfigProto
// protocol buffer message.
type Config struct {
// GlobalJitLevel controls the degree of optimization that the XLA just-in-time
// compiler will perform. The default is currently "off", but it is expected
// to change to "on" in a future version of TensorFlow.
GlobalJitLevel JitLevel

// AllowGPUMemoryGrowth controls whether the TensorFlow memory allocator
// pre-allocates the entire specified GPU memory region or instead starts
// with a small block of GPU memory and grows its memory usage as needed.
AllowGPUMemoryGrowth bool

// NumCPUs is the maximum number of CPU devices that the session will use.
// A value of 0 means "let the system pick an appropriate number"
NumCPUs int

// This struct only exposes the session options available via TensorFlow's
// experimental C API function `TF_CreateConfig()`.
// TODO(frreiss): Add additional options here as more session options are
// exposed via the C API.
}

// Bytes generates a serialized ConfigOptions protobuf for use in the `Config`
// field of a `SessionOptions` struct.
func (c *Config) Bytes() []byte {
// The C API expects an unsigned char that is 0 if XLA compilation is off and
// nonzero otherwise.
// There is currently no way in the C API to specify "use TensorFlow's default
// JIT level". The translation logic here ensures that the zero value of
// c.GlobalJitLevel means the same as the default value of
// OptimizerOptions.global_jit_level in the Python API.
enableXLACompilationAsChar := C.uchar(0)
switch c.GlobalJitLevel {
case JitDefault:
// TODO(frreiss): When the semantics of GlobalJitLevel.DEFAULT change to
// "on", uncomment the following line.
// enableXLACompilationAsChar = C.uchar(1)
case JitOn:
enableXLACompilationAsChar = C.uchar(1)
}
gpuMemoryAllowGrowthAsChar := C.uchar(0)
if c.AllowGPUMemoryGrowth {
gpuMemoryAllowGrowthAsChar = 1
}
// The C API doesn't currently have a way to say "let the system pick how many
// CPUs to use," so detect the number of CPUs here.
numCPUDevicesAsUint := C.uint(runtime.NumCPU())
if c.NumCPUs > 0 {
numCPUDevicesAsUint = C.uint(c.NumCPUs)
}
buf := C.TF_CreateConfig(enableXLACompilationAsChar, gpuMemoryAllowGrowthAsChar, numCPUDevicesAsUint)
defer C.TF_DeleteBuffer(buf)
// Copy out of C memory.
return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length))
}

// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
// values suitable for C library calls.
type cRunArgs struct {
Expand Down
20 changes: 4 additions & 16 deletions tensorflow/go/session_test.go
Expand Up @@ -250,27 +250,15 @@ func ExamplePartialRun() {
}

func TestSessionConfig(t *testing.T) {
// Exercise SessionOptions.
// Arguably, a better API would be for SessionOptions.Config to be the
// type generated by the protocol buffer compiler. But for now, the
// tensorflow package continues to be independent of protocol buffers
// and this test exercises the option since the implementation has a
// nuanced conversion to C types.
//
// Till then, the []byte form of Config here was generated using a toy
// tensorflow Python program:
/*
import tensorflow
c = tensorflow.ConfigProto()
c.intra_op_parallelism_threads = 1
print c.SerializeToString()
*/
// Exercise SessionOptions and Config structs
graph := NewGraph()
c, err := Const(graph, "Const", int32(14))
if err != nil {
t.Fatal(err)
}
opts := SessionOptions{Config: []byte("(\x01")}
// Use the zero values for Config.GlobalJitLevel and NumCPUs
config := Config{AllowGPUMemoryGrowth: true}
opts := SessionOptions{Config: config.Bytes()}
s, err := NewSession(graph, &opts)
if err != nil {
t.Fatal(err)
Expand Down