Skip to content

Commit

Permalink
Create machine logs to support AI retraining (#585)
Browse files Browse the repository at this point in the history
To support collecting implicit feedback and retraining an AI, we need to
create machine-readable logs that contain the actual command executed.

The design is described in #574.

With this change and ai logs enabled.
  • Loading branch information
jlewi committed May 24, 2024
1 parent 61fc878 commit 7d1c469
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 33 deletions.
14 changes: 14 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,20 @@ make wasm

This builds the wasm file to `examples/web/runme.wasm`.

## Install Dev Tools

To install tools like `gofumpt` and `revive`, which are used for development (e.g. linting), run:

```sh
make install/dev
```

You will need the [pre-commit](https://pre-commit.com/) tool to run the pre-commit hooks. You can install it with:

```sh
python3 -m pip install pre-commit
```

## Linting

Like many complex go projects, this project uses a variety of linting tools to ensure code quality and prevent regressions! The main linter (revive) can be run with:
Expand Down
2 changes: 1 addition & 1 deletion cmd/gqltool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

var (
apiURL = flag.String("api-url", "https://api.stateful.com", "The API base address")
tokenDir = flag.String("token-dir", cmd.GetDefaultConfigHome(), "The directory with tokens")
tokenDir = flag.String("token-dir", cmd.GetUserConfigHome(), "The directory with tokens")
)

func init() {
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/api_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func setAPIFlags(flagSet *pflag.FlagSet) {
flagSet.StringVar(&authBaseURL, authURLF, defaultAuthURL, "Backend URL to authorize you")
flagSet.StringVar(&apiBaseURL, apiURLF, "https://api.stateful.com", "Backend URL with API")
flagSet.StringVar(&apiToken, apiTokenF, "", "API token")
flagSet.StringVar(&configDir, configDirF, GetDefaultConfigHome(), "Location where token will be saved")
flagSet.StringVar(&configDir, configDirF, GetUserConfigHome(), "Location where token will be saved")
flagSet.BoolVar(&trace, traceF, false, "Trace HTTP calls")
flagSet.BoolVar(&traceAll, traceAllF, false, "Trace all HTTP calls including authentication (it might leak sensitive data to output)")
flagSet.BoolVar(&enableChaos, enableChaosF, false, "Enable Chaos Monkey mode for GraphQL requests")
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/code_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func codeServerCmd() *cobra.Command {
return errors.New("currently, we only support coder's code server; please uninstall any other code-server installations to use this feature")
}

configDir := filepath.Join(GetDefaultConfigHome(), "code-server")
configDir := filepath.Join(GetUserConfigHome(), "code-server")

if codeServerConfigFile == "" {
codeServerConfigFile = filepath.Join(configDir, "config.yaml")
Expand Down
146 changes: 122 additions & 24 deletions internal/cmd/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"path/filepath"
"strconv"
"strings"
"time"

"go.uber.org/zap/zapcore"

"github.com/pkg/errors"
"github.com/spf13/cobra"
Expand All @@ -29,7 +32,7 @@ func getIdentityResolver() *identity.IdentityResolver {
}

func getProject() (*project.Project, error) {
logger, err := getLogger(false)
logger, err := getLogger(false, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -148,38 +151,129 @@ func getCodeBlocks() (document.CodeBlocks, error) {
return document.CollectCodeBlocks(node), nil
}

func getLogger(devMode bool) (*zap.Logger, error) {
if !fLogEnabled {
func getLogger(devMode bool, aiLogs bool) (*zap.Logger, error) {
if !fLogEnabled && !aiLogs {
return zap.NewNop(), nil
}

config := zap.Config{
Level: zap.NewAtomicLevelAt(zap.InfoLevel),
Development: false,
Sampling: &zap.SamplingConfig{
Initial: 100,
Thereafter: 100,
},
Encoding: "json",
EncoderConfig: zap.NewProductionEncoderConfig(),
OutputPaths: []string{"stderr"},
ErrorOutputPaths: []string{"stderr"},
cores := make([]zapcore.Core, 0, 2)
if fLogEnabled {
consoleCore, err := createCoreForConsole(devMode)
if err != nil {
return nil, errors.WithStack(err)
}
cores = append(cores, consoleCore)
}

if aiLogs {
aiCore, err := createAICoreLogger()
if err != nil {
return nil, errors.WithStack(err)
}
cores = append(cores, aiCore)
}

if len(cores) == 0 {
return zap.NewNop(), nil
}

// Create a multi-core logger with different encodings
core := zapcore.NewTee(cores...)

// Create the logger
newLogger := zap.New(core)
// Record the caller of the log message
newLogger = newLogger.WithOptions(zap.AddCaller())
return newLogger, nil
}

// createCoreForConsole creates a zapcore.Core for console output.
func createCoreForConsole(devMode bool) (zapcore.Core, error) {
if !fLogEnabled {
return zapcore.NewNopCore(), nil
}

encoderConfig := zap.NewProductionEncoderConfig()
lvl := zap.NewAtomicLevelAt(zap.InfoLevel)

if devMode {
config.Level = zap.NewAtomicLevelAt(zap.DebugLevel)
config.Development = true
config.Encoding = "console"
config.EncoderConfig = zap.NewDevelopmentEncoderConfig()
lvl = zap.NewAtomicLevelAt(zap.DebugLevel)
encoderConfig = zap.NewDevelopmentEncoderConfig()
}

path := "stderr"
if fLogFilePath != "" {
config.OutputPaths = []string{fLogFilePath}
config.ErrorOutputPaths = []string{fLogFilePath}
path = fLogFilePath
}

oFile, _, err := zap.Open(path)
if err != nil {
return nil, errors.Wrapf(err, "could not create writer for console logger")
}

l, err := config.Build()
return l, errors.WithStack(err)
var encoder zapcore.Encoder
if devMode {
encoder = zapcore.NewConsoleEncoder(encoderConfig)
} else {
encoder = zapcore.NewJSONEncoder(encoderConfig)
}

core := zapcore.NewCore(encoder, zapcore.AddSync(oFile), lvl)

if !devMode {
// For non-dev mode, add sampling.
core = zapcore.NewSamplerWithOptions(
core,
time.Second,
100,
100,
)
}
return core, nil
}

// createAICoreLogger creates a zapcore.Core that writes logs to a file under
// ${configDir}/logs, creating that directory if it does not exist yet.
//
// The purpose of these logs is to capture AI traces that are used for retraining.
// Because they are meant to be machine readable they are always encoded as JSON,
// and the core is enabled at info level since that is the level at which the
// entries needed for tracing are emitted.
//
// Each invocation targets a file named logs.<timestamp>.json, where the timestamp
// has one-second resolution; the per-start file name is the mechanism used to deal
// with log rotation.
//
// It returns an error if the config directory cannot be determined, if the log
// directory cannot be created, or if the log file cannot be opened.
func createAICoreLogger() (zapcore.Core, error) {
	// Configure encoder for JSON format.
	c := zap.NewProductionEncoderConfig()
	// We attach the function key to the logs because that is useful for identifying
	// the function that generated the log.
	c.FunctionKey = "function"

	jsonEncoder := zapcore.NewJSONEncoder(c)

	configDir := getConfigDir()
	if configDir == "" {
		return nil, errors.New("could not determine config directory")
	}

	logDir := filepath.Join(configDir, "logs")
	if _, err := os.Stat(logDir); os.IsNotExist(err) {
		// Logger won't be setup yet so we can't use it; report directly on stdout.
		if _, err := fmt.Fprintf(os.Stdout, "Creating log directory %s\n", logDir); err != nil {
			return nil, errors.Wrapf(err, "could not write to stdout")
		}
		if err := os.MkdirAll(logDir, 0o750); err != nil {
			return nil, errors.Wrapf(err, "could not create log directory %s", logDir)
		}
	}

	// We need to set a unique file name for the logs as a way of dealing with log rotation.
	// The timestamp deliberately avoids ':' because it is a reserved character in Windows
	// file names; the layout still sorts lexicographically in chronological order.
	name := fmt.Sprintf("logs.%s.json", time.Now().Format("2006-01-02T15-04-05"))
	logFile := filepath.Join(logDir, name)

	// TODO(jeremy): How could we handle invoking the log closer if there is one.
	// The close function returned by zap.Open is currently discarded, so the file
	// is never explicitly closed by this code.
	oFile, _, err := zap.Open(logFile)
	if err != nil {
		return nil, errors.Wrapf(err, "could not open log file %s", logFile)
	}

	// Force log level to be at least info. Because info is the level at which we capture
	// the logs we need for tracing.
	core := zapcore.NewCore(jsonEncoder, zapcore.AddSync(oFile), zapcore.InfoLevel)

	return core, nil
}

func validCmdNames(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
Expand Down Expand Up @@ -230,7 +324,11 @@ func printfInfo(msg string, args ...any) {
_, _ = os.Stderr.Write(buf.Bytes())
}

func GetDefaultConfigHome() string {
// GetUserConfigHome returns the user's configuration directory.
// The user configuration directory should be used for configuration that is specific to the user and thus
// shouldn't be included in project/repository configuration. An example of user location is where server logs
// should be stored.
func GetUserConfigHome() string {
dir, err := os.UserConfigDir()
if err != nil {
dir = os.TempDir()
Expand Down Expand Up @@ -332,7 +430,7 @@ func setRunnerFlags(cmd *cobra.Command, serverAddr *string) func() ([]client.Run

type runFunc func(context.Context) error

var defaultTLSDir = filepath.Join(GetDefaultConfigHome(), "tls")
var defaultTLSDir = filepath.Join(GetUserConfigHome(), "tls")

func promptEnvVars(cmd *cobra.Command, runner client.Runner, tasks ...project.Task) error {
for _, task := range tasks {
Expand Down
6 changes: 4 additions & 2 deletions internal/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func serverCmd() *cobra.Command {
devMode bool
enableRunner bool
tlsDir string
enableAILogs bool
)

cmd := cobra.Command{
Expand All @@ -57,7 +58,7 @@ The kernel is used to run long running processes like shells and interacting wit
}
},
RunE: func(cmd *cobra.Command, args []string) error {
logger, err := getLogger(devMode)
logger, err := getLogger(devMode, enableAILogs)
if err != nil {
return err
}
Expand Down Expand Up @@ -137,8 +138,9 @@ The kernel is used to run long running processes like shells and interacting wit
cmd.Flags().StringVarP(&addr, "address", "a", defaultAddr, "Address to create unix (unix:///path/to/socket) or IP socket (localhost:7890)")
cmd.Flags().BoolVar(&devMode, "dev", false, "Enable development mode")
cmd.Flags().BoolVar(&enableRunner, "runner", true, "Enable runner service (legacy, defaults to true)")
cmd.Flags().BoolVar(&enableAILogs, "ai-logs", false, "Enable logs to support training an AI")
cmd.Flags().StringVar(&tlsDir, "tls", defaultTLSDir, "Directory in which to generate TLS certificates & use for all incoming and outgoing messages")

cmd.Flags().StringVar(&configDir, configDirF, GetUserConfigHome(), "If ai logs is enabled logs will be written to ${config-dir}/logs")
_ = cmd.Flags().MarkHidden("runner")

return &cmd
Expand Down
38 changes: 37 additions & 1 deletion internal/runner/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package runner
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
Expand All @@ -11,6 +12,10 @@ import (
"sync"
"time"

"google.golang.org/protobuf/proto"

"google.golang.org/protobuf/encoding/protojson"

"github.com/creack/pty"
"github.com/gabriel-vasile/mimetype"
"github.com/pkg/errors"
Expand Down Expand Up @@ -216,7 +221,15 @@ func (r *runnerService) Execute(srv runnerv1.RunnerService_ExecuteServer) error
return errors.WithStack(err)
}

logger.Debug("received initial request", zap.Any("req", req))
// We want to always log the request because it is used for AI training.
// see: https://github.com/stateful/runme/issues/574
if req.KnownId != "" {
logger = logger.With(zap.String("knownID", req.KnownId))
}
if req.KnownName != "" {
logger = logger.With(zap.String("knownName", req.KnownName))
}
logger.Info("received initial request", zap.Any("req", zapProto(req, logger)))

createSession := func(envs []string) (*Session, error) {
// todo(sebastian): owl store?
Expand Down Expand Up @@ -539,6 +552,29 @@ func (r *runnerService) Execute(srv runnerv1.RunnerService_ExecuteServer) error
return werr
}

// zapProto converts a proto message into a generic map so that zap can log it as a
// nested JSON object that follows the proto JSON format.
//
// Logging a proto directly with zap would use the JSON serialization of the Go struct,
// which does not match the proto JSON format and therefore could not be deserialized
// back into the proto from the logs. To avoid that, the message is first marshaled with
// protojson and the resulting bytes are then decoded into a map.
//
// If either step fails, the failure is reported through logger and the returned map
// carries the error text under the "error" key.
//
// A more efficient solution would be to use https://github.com/kazegusuri/go-proto-zap-marshaler
// to generate a custom zapcore.ObjectMarshaler implementation for each proto message.
func zapProto(pb proto.Message, logger *zap.Logger) map[string]interface{} {
	out := make(map[string]interface{})

	raw, marshalErr := protojson.Marshal(pb)
	if marshalErr != nil {
		logger.Error("failed to marshal request", zap.Error(marshalErr))
		out["error"] = marshalErr.Error()
		return out
	}

	unmarshalErr := json.Unmarshal(raw, &out)
	if unmarshalErr == nil {
		return out
	}

	logger.Error("failed to unmarshal request", zap.Error(unmarshalErr))
	out["error"] = unmarshalErr.Error()
	return out
}

type output struct {
Stdout []byte
Stderr []byte
Expand Down
Loading

0 comments on commit 7d1c469

Please sign in to comment.