Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix mispellings and add minor corrections #3258

Merged
merged 3 commits into from Aug 8, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 5 additions & 4 deletions tensorflow/contrib/go/README.md
Expand Up @@ -15,7 +15,7 @@ also to build, load and run Graphs.
## Installation

This package depends on the TensorFlow shared libraries, in order to compile
this libraries follow the [Installing fromsources](https://www.tensorflow.org/versions/r0.8/get_started/os_setup.html#installing-from-sources)
this libraries follow the [Installing fromsources](https://www.tensorflow.org/get_started/os_setup.html#installing-from-sources)
guide to clone and configure the repository.

After you have cloned the repository, run the next commands at the root of the
Expand Down Expand Up @@ -98,11 +98,12 @@ func main() {

// Load the graph from the file that we had generated from Python on
// the previous step
reader, err := os.Open("/tmp/graph/test_graph.pb")
b, err := ioutil.ReadFile("/tmp/graph/test_graph.pb")
if err != nil {
t.Fatal(err)
}
graph, err := tensorflow.NewGraphFromReader(reader, true)

graph, err := tensorflow.NewGraphFromBuffer(b)
if err != nil {
log.Fatal("Error reading Graph from the text file:", err)
}
Expand Down Expand Up @@ -170,7 +171,7 @@ execute them:

### ld: library not found for -ltensorflow

This package expects the linker to find the 'libtensorflow' shared library.
This package expects the linker to find the 'libtensorflow' shared library.

To generate this file run:

Expand Down
7 changes: 4 additions & 3 deletions tensorflow/contrib/go/doc.go
@@ -1,6 +1,7 @@
/*
This package provides a high-level Go API for TensorFlow, providing the
necessary tools to create and manipulate Tensors, Variables, Constants and
also to build, load and run Graphs.
This package provides a high-level Go API for TensorFlow.
It provised methods to create and manipulate Tensors, Variables, Constants and
to build, load and run Graphs.
*/

package tensorflow
26 changes: 14 additions & 12 deletions tensorflow/contrib/go/example_test.go
Expand Up @@ -2,8 +2,7 @@ package tensorflow_test

import (
"fmt"
"os"
"strings"
"io/ioutil"

"github.com/tensorflow/tensorflow/tensorflow/contrib/go"
)
Expand Down Expand Up @@ -31,7 +30,7 @@ func ExampleGraph_Op() {

for i := 0; i < len(inputSlice1); i++ {
val, _ := out[0].GetVal(int64(i))
fmt.Println("The result of: %d + (%d*%d) is: %d", inputSlice1[i], inputSlice2[i], additions, val)
fmt.Printf("The result of: %d + (%d*%d) is: %d\n", inputSlice1[i], inputSlice2[i], additions, val)
}
}

Expand All @@ -46,8 +45,8 @@ func ExampleNewTensor_scalar() {
tensorflow.NewTensor("Hello TensorFlow")
}

func ExampleNewGraphFromText() {
graph, err := tensorflow.NewGraphFromReader(strings.NewReader(`
func ExampleNewGraphFromString() {
graph, err := tensorflow.NewGraphFromString(`
node {
name: "output"
op: "Const"
Expand All @@ -69,9 +68,12 @@ func ExampleNewGraphFromText() {
}
}
}
version: 5`), true)
version: 5`)
if err != nil {
return
}

fmt.Println(graph, err)
fmt.Print(graph)
}

func ExampleGraph_Constant() {
Expand Down Expand Up @@ -145,11 +147,11 @@ func ExampleSession_ExtendGraph() {
s.ExtendGraph(graph)
}

func ExampleNewGraphFromReader() {
// Load the Graph from from a file who contains a previously generated
// Graph as text.
reader, _ := os.Open("/tmp/graph/test_graph.pb")
graph, _ := tensorflow.NewGraphFromReader(reader, true)
func ExampleNewGraphFromBuffer() {
// Load the Graph from from a file containing a serialized
// Graph.
b, _ := ioutil.ReadFile("/tmp/graph/test_graph.pb")
graph, _ := tensorflow.NewGraphFromBuffer(b)

// Create the Session and extend the Graph on it.
s, _ := tensorflow.NewSession()
Expand Down
43 changes: 20 additions & 23 deletions tensorflow/contrib/go/graph.go
Expand Up @@ -2,8 +2,6 @@ package tensorflow

import (
"fmt"
"io"
"io/ioutil"
"strings"

"github.com/golang/protobuf/proto"
Expand Down Expand Up @@ -33,14 +31,14 @@ type GraphNode struct {
outDataTypes map[string]DataType
}

// ErrExpectedVarAsinput is returned when the input value on an operation is
// ErrExpectedVarAsInput is returned when the input value on an operation is
// not a Variable and it must be a Variable.
type ErrExpectedVarAsinput struct {
type ErrExpectedVarAsInput struct {
Op string
InputPos int
}

func (e *ErrExpectedVarAsinput) Error() string {
func (e *ErrExpectedVarAsInput) Error() string {
return fmt.Sprintf(
"The input value at pos %d for the operation '%s' must be of type Variable",
e.InputPos, e.Op)
Expand Down Expand Up @@ -114,23 +112,22 @@ func NewGraph() *Graph {
}
}

// NewGraphFromReader reads from reader until an error or EOF and loads the
// NewGraphFromBuffer reads from reader until an error or EOF and loads the
// content into a new graph. Use the asText parameter to specify if the graph
// from the reader is provided in Text format.
func NewGraphFromReader(reader io.Reader, asText bool) (*Graph, error) {
graphStr, err := ioutil.ReadAll(reader)
if err != nil {
return nil, err
}

gr := NewGraph()
if asText {
err = proto.UnmarshalText(string(graphStr), gr.def)
} else {
err = proto.Unmarshal(graphStr, gr.def)
}
func NewGraphFromBuffer(b []byte) (*Graph, error) {
graph := NewGraph()
err := proto.Unmarshal(b, graph.def)
return graph, err
}

return gr, err
// NewGraphFromString reads from reader until an error or EOF and loads the
// content into a new graph. Use the asText parameter to specify if the graph
// from the reader is provided in Text format.
func NewGraphFromString(s string) (*Graph, error) {
graph := NewGraph()
err := proto.UnmarshalText(s, graph.def)
return graph, err
}

// Op adds a new Node to the Graph with the specified operation. This function
Expand Down Expand Up @@ -159,7 +156,7 @@ func (gr *Graph) Op(opName string, name string, input []*GraphNode, device strin
for i, inNode := range input {
if op.InputArg[i].IsRef {
if inNode.ref == nil {
return nil, &ErrExpectedVarAsinput{
return nil, &ErrExpectedVarAsInput{
Op: opName,
InputPos: i,
}
Expand Down Expand Up @@ -469,9 +466,9 @@ func (gr *Graph) Constant(name string, data interface{}) (*GraphNode, error) {
}

// matchTypes matches all the input/output parameters with their corresponding
// data types specified on the attribues or deducing the data type from other
// data types specified on the attributes or deducing the data type from other
// parameters. This method can return an error if the matching is not possible,
// for instance if two input paramters must have the same data type but one is
// for instance if two input parameters must have the same data type but one is
// int and the other float.
func (gr *Graph) matchTypes(input []*GraphNode, outNode *GraphNode, attrs map[string]interface{}, op *pb.OpDef) error {
// On this part the data type tags are associated with the data type
Expand All @@ -495,7 +492,7 @@ func (gr *Graph) matchTypes(input []*GraphNode, outNode *GraphNode, attrs map[st
}
}

// Now assign all the types we got from the inputs/ouputs to their
// Now assign all the types we got from the inputs/outputs to their
// bound attributes
for _, attr := range op.Attr {
if attr.Type == "type" {
Expand Down
34 changes: 18 additions & 16 deletions tensorflow/contrib/go/session.go
Expand Up @@ -15,13 +15,14 @@ type Session struct {
graph *Graph
}

// ErrStatusTf is an error message comming out from the TensorFlow C++ libraries.
type ErrStatusTf struct {
// ErrStatusCore represents an error message occurred in the TensorFlow C++ libraries.
type ErrStatusCore struct {
Code TF_Code
Message string
}

func (e *ErrStatusTf) Error() string {
// Error implements the error interface.
func (e *ErrStatusCore) Error() string {
return fmt.Sprintf("tensorflow: %d: %v", e.Code, e.Message)
}

Expand All @@ -38,7 +39,7 @@ func NewSession() (*Session, error) {
),
}

if err := s.statusToError(status); err != nil {
if err := statusToError(status); err != nil {
return nil, err
}

Expand All @@ -49,10 +50,10 @@ func NewSession() (*Session, error) {
}

// Run runs the operations on the target nodes, or all the operations if not
// targets are specified. the Parameter Input is a dictionary where the key is
// the Tensor name on the Graph, and the value, the Tensor. The parameter
// targets are specified. The first parameter is a dictionary where the key is
// the name of a Tensor in the Graph and the value is a Tensor. The parameter
// outputs is used to specify the tensors from the graph to be returned in the
// same order as they occur on the slice.
// same order as they occur on the slice. Targets .. ?
func (s *Session) Run(inputs map[string]*Tensor, outputs []string, targets []string) ([]*Tensor, error) {
inputNames := NewStringVector()
inputValues := NewTensorVector()
Expand Down Expand Up @@ -91,7 +92,7 @@ func (s *Session) Run(inputs map[string]*Tensor, outputs []string, targets []str
})
}

return result, s.statusToError(status)
return result, statusToError(status)
}

// ExtendGraph loads the Graph definition into the Session.
Expand All @@ -104,9 +105,13 @@ func (s *Session) ExtendGraph(graph *Graph) error {
}

TF_ExtendGraph(s.session, buf, status)
if statusToError(status); err != nil {
return err
}

s.graph = graph

return s.statusToError(status)
return nil
}

// ExtendAndInitializeAllVariables adds the "init" op to the Graph in order to
Expand All @@ -133,14 +138,11 @@ func (s *Session) FreeAllocMem() {
}

// statusToError converts a TF_Status returned by a C execution into a Go Error.
func (s *Session) statusToError(status TF_Status) error {
code := TF_GetCode(status)
message := TF_Message(status)

if code != 0 {
return &ErrStatusTf{
func statusToError(status TF_Status) error {
if code := TF_GetCode(status); code != 0 {
return &ErrStatusCore{
Code: code,
Message: message,
Message: TF_Message(status),
}
}

Expand Down
6 changes: 3 additions & 3 deletions tensorflow/contrib/go/session_test.go
@@ -1,7 +1,7 @@
package tensorflow_test

import (
"os"
"io/ioutil"
"testing"

tf "github.com/tensorflow/tensorflow/tensorflow/contrib/go"
Expand Down Expand Up @@ -133,11 +133,11 @@ func loadAndExtendGraphFromFile(t *testing.T, filePath string) (s *tf.Session) {
t.Fatal("Error creating Session:", err)
}

reader, err := os.Open(filePath)
b, err := ioutil.ReadFile(filePath)
if err != nil {
t.Fatal("Error reading Graph definition file:", err)
}
graph, err := tf.NewGraphFromReader(reader, true)
graph, err := tf.NewGraphFromBuffer(b)
if err != nil {
t.Fatal(err)
}
Expand Down
14 changes: 8 additions & 6 deletions tensorflow/contrib/go/tensor.go
Expand Up @@ -55,14 +55,14 @@ func (e *ErrInvalidTensorType) Error() string {
return fmt.Sprintf("Invalid tensor data type, tensor data type: '%s', required data type: '%s'", e.TensorType, e.ExpectedType)
}

// ErrTensorTypeNotSupported is returned when the tensor type is still not
// ErrTensorTypeNotSupported is returned when the tensor type is not
// supported.
type ErrTensorTypeNotSupported struct {
TensotType DataType
}

func (e *ErrTensorTypeNotSupported) Error() string {
return fmt.Sprintf("The tensor data type '%s' is still not supported", e.TensotType)
return fmt.Sprintf("The tensor data type '%s' is not supported", e.TensotType)
}

// ErrDimsOutOfTensorRange is returned when the specified number of dimensions
Expand Down Expand Up @@ -97,13 +97,13 @@ func (e *ErrSliceExpected) Error() string {
return fmt.Sprintf("The argument must be a Slice, but the data type is: '%s'", e.DataType)
}

// ErrDataTypeNotSupported is returned when the data type is not suported.
// ErrDataTypeNotSupported is returned when the data type is not supported.
type ErrDataTypeNotSupported struct {
DataType string
}

func (e *ErrDataTypeNotSupported) Error() string {
return fmt.Sprintf("The type of the provided data is not suported: '%s'", e.DataType)
return fmt.Sprintf("The type of the provided data is not supported: '%s'", e.DataType)
}

var (
Expand All @@ -130,7 +130,8 @@ var (
// DTUint16 corresponds to TF_UINT16.
DTUint16 = DataType(TF_UINT16)

// The next data types are still not supported
//NOTE: The next data types are still not supported

// DTBfloat corresponds to TF_BFLOAT16.
DTBfloat = DataType(TF_BFLOAT16)
// DTComplex corresponds to TF_COMPLEX.
Expand Down Expand Up @@ -165,7 +166,8 @@ func NewTensorWithShape(shape TensorShape, data interface{}) (*Tensor, error) {
}
}

dataType, err := getDataTypeFromReflect(v.Type().Elem().Kind(), int64(v.Type().Elem().Size()))
elem := v.Type().Elem()
dataType, err := getDataTypeFromReflect(elem.Kind(), int64(elem.Size()))
if err != nil {
return nil, err
}
Expand Down
8 changes: 3 additions & 5 deletions tensorflow/contrib/go/tensor_test.go
Expand Up @@ -3,14 +3,13 @@ package tensorflow_test
import (
"fmt"
"reflect"
"strings"
"testing"

tf "github.com/tensorflow/tensorflow/tensorflow/contrib/go"
)

func getTensorFromGraph(t *testing.T, dType, shapeVal string) *tf.Tensor {
graph, err := tf.NewGraphFromReader(strings.NewReader(fmt.Sprintf(`
graph, err := tf.NewGraphFromString(`
node {
name: "output"
op: "Const"
Expand All @@ -30,8 +29,7 @@ func getTensorFromGraph(t *testing.T, dType, shapeVal string) *tf.Tensor {
}
}
}
version: 5`,
dType, dType, shapeVal)), true)
version: 5`)
if err != nil {
t.Fatal(err)
}
Expand All @@ -46,7 +44,7 @@ func getTensorFromGraph(t *testing.T, dType, shapeVal string) *tf.Tensor {
}

if len(output) != 1 {
t.Fatalf("Expexted 1 tensor, got: %d tensors", len(output))
t.Fatalf("Expected 1 tensor, got: %d tensors", len(output))
}

return output[0]
Expand Down