Skip to content

Commit

Permalink
Write model to temporary directory (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
robherley committed Apr 25, 2023
1 parent 1d1d79e commit 79be010
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 24 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ go get github.com/robherley/guesslang-go

See example usage in [`examples/main.go`](/example/main.go)

## Caveats

To work around some of the limitations of the Go TensorFlow bindings (and the wrapper library)[^1], the [SavedModel](https://www.tensorflow.org/guide/saved_model) is embeded in the binary and
when a [`Guesser`](https://pkg.go.dev/github.com/robherley/guesslang-go/pkg/guesser#Guesser) is initialized, it temporarily writes the model to a directory (and removes it after).

So, in order to use this package, you must at least have a writeable temporary directory that aligns with Go's [`os.TempDir()`](https://pkg.go.dev/os@go1.20.3#TempDir).

[^1]: https://github.com/galeone/tfgo/issues/44#issuecomment-841806254

## Acknowledgements

Powered by:
Expand Down
13 changes: 9 additions & 4 deletions pkg/guesser/guesser.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Guesser struct {
model *tg.Model
}

// New initializes a guesslang model
// New initializes a guesslang model. It will write the TensorFlow SavedModel temporarily to disk, then load it.
func New() (g *Guesser, err error) {
defer func() {
// unfortunately, the tfgo library panics instead of returning errors
Expand All @@ -28,9 +28,14 @@ func New() (g *Guesser, err error) {
}
}()

return &Guesser{
model: tg.LoadModel(hackyPathToModel(), []string{"serve"}, nil),
}, nil
modelPath, err := writeModelToTempDir()
if err != nil {
return nil, err
}
defer os.RemoveAll(modelPath)

model := tg.LoadModel(modelPath, []string{"serve"}, nil)
return &Guesser{model}, nil
}

// Guess executes the guesslang model on a code snippet, providing sorted confidences for all of the supported languages
Expand Down
File renamed without changes.
18 changes: 0 additions & 18 deletions pkg/guesser/hack.go

This file was deleted.

68 changes: 68 additions & 0 deletions pkg/guesser/model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package guesser

import (
"embed"
"io"
"os"
"path/filepath"
)

var (
//go:embed model/*
savedModel embed.FS
)

// as far as I know, there is no way to load a SavedModel/graph (with variables) from memory with the go tf library
// https://github.com/galeone/tfgo/issues/44#issuecomment-841806254
func writeModelToTempDir() (string, error) {
modelPath, err := os.MkdirTemp("", "guesslang-go")
if err != nil {
return "", err
}

savedModelFile, err := savedModel.Open("model/saved_model.pb")
if err != nil {
return "", err
}
defer savedModelFile.Close()

savedModelFileDisk, err := os.Create(modelPath + "/saved_model.pb")
if err != nil {
return "", err
}
defer savedModelFileDisk.Close()

if _, err := io.Copy(savedModelFileDisk, savedModelFile); err != nil {
return "", err
}

err = os.MkdirAll(filepath.Join(modelPath, "variables"), 0755)
if err != nil {
return "", err
}

variablesFiles, err := savedModel.ReadDir("model/variables")
if err != nil {
return "", err
}

for _, file := range variablesFiles {
variablesFile, err := savedModel.Open("model/variables/" + file.Name())
if err != nil {
return "", err
}
defer variablesFile.Close()

variablesFileDisk, err := os.Create(filepath.Join(modelPath, "variables", file.Name()))
if err != nil {
return "", err
}
defer variablesFileDisk.Close()

if _, err := io.Copy(variablesFileDisk, variablesFile); err != nil {
return "", err
}
}

return modelPath, nil
}
12 changes: 10 additions & 2 deletions script/install-libtensorflow
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ if [ "$(go env GOARCH)" != "amd64" ]; then
exit 1
fi

function maybe_sudo() {
if [ "$(id -u)" -ne 0 ]; then
sudo "$@"
else
"$@"
fi
}

OS=$(go env GOOS)

case "$OS" in
Expand All @@ -21,8 +29,8 @@ case "$OS" in
TENSORFLOW_VERSION=2.11.0
ARCHIVE_PATH=$(mktemp -d)/tensorflow.tar.gz
curl -o "$ARCHIVE_PATH" "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-${OS}-x86_64-${TENSORFLOW_VERSION}.tar.gz"
sudo tar -xzf "$ARCHIVE_PATH" -C /usr/local
sudo ldconfig /usr/local/lib
maybe_sudo tar -xzf "$ARCHIVE_PATH" -C /usr/local
maybe_sudo ldconfig /usr/local/lib
rm -rf "$ARCHIVE_PATH"
;;
*)
Expand Down

0 comments on commit 79be010

Please sign in to comment.