From 641da4c8698638e2070ce25092e985fa276f0e3b Mon Sep 17 00:00:00 2001 From: Dmitri Khokhlov Date: Tue, 5 Mar 2024 20:20:53 -0800 Subject: [PATCH] feat: support for list in inputs Signed-off-by: Dmitri Khokhlov --- pkg/predict/input.go | 68 +++++++++++++++++++++++++++++++--------- pkg/predict/predictor.go | 2 +- 2 files changed, 54 insertions(+), 16 deletions(-) diff --git a/pkg/predict/input.go b/pkg/predict/input.go index 81f88bd536..aeadd6bbe5 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -1,20 +1,22 @@ package predict import ( + "encoding/json" + "fmt" + "github.com/vincent-petithory/dataurl" + "mime" "os" "path/filepath" "strings" "github.com/mitchellh/go-homedir" - "github.com/vincent-petithory/dataurl" - "github.com/replicate/cog/pkg/util/console" - "github.com/replicate/cog/pkg/util/mime" ) type Input struct { String *string File *string + Array *[]interface{} } type Inputs map[string]Input @@ -25,15 +27,16 @@ func NewInputs(keyVals map[string]string) Inputs { val := val if strings.HasPrefix(val, "@") { val = val[1:] - expandedVal, err := homedir.Expand(val) - if err != nil { - // FIXME: handle this better? - console.Warnf("Error expanding homedir: %s", err) + input[key] = Input{File: &val} + } else if strings.HasPrefix(val, "[") && strings.HasSuffix(val, "]") { + // Handle array of strings + var arr []interface{} + if err := json.Unmarshal([]byte(val), &arr); err != nil { + // Handle JSON unmarshalling error + console.Warnf("Error parsing array for key '%s': %s", key, err) } else { - val = expandedVal + input[key] = Input{Array: &arr} } - - input[key] = Input{File: &val} } else { input[key] = Input{String: &val} } @@ -55,19 +58,54 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs { return input } -func (inputs *Inputs) toMap() (map[string]string, error) { - keyVals := map[string]string{} +func (inputs *Inputs) toMap() (map[string]interface{}, error) { + keyVals := map[string]interface{}{} for key, input := range *inputs { if input.String != nil { + // Directly assign the string value keyVals[key] = *input.String } else if input.File != nil { - content, err := os.ReadFile(*input.File) + // Single file handling: read content and convert to a data URL + dataURL, err := fileToDataURL(*input.File) if err != nil { return keyVals, err } - mimeType := mime.TypeByExtension(filepath.Ext(*input.File)) - keyVals[key] = dataurl.New(content, mimeType).String() + keyVals[key] = dataURL + } else if input.Array != nil { + // Handle array, potentially containing file paths + var dataURLs []string + for _, elem := range *input.Array { + if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") { + filePath := str[1:] // Remove '@' prefix + dataURL, err := fileToDataURL(filePath) + if err != nil { + return keyVals, err + } + dataURLs = append(dataURLs, dataURL) + } else if ok { + // Directly use the string if it's not a file path + dataURLs = append(dataURLs, str) + } + } + keyVals[key] = dataURLs } } return keyVals, nil } + +// Helper function to read file content and convert to a data URL +func fileToDataURL(filePath string) (string, error) { + // Expand home directory if necessary + expandedVal, err := homedir.Expand(filePath) + if err != nil { + return "", fmt.Errorf("error expanding homedir for '%s': %w", filePath, err) + } + + content, err := os.ReadFile(expandedVal) + if err != nil { + return "", err + } + mimeType := mime.TypeByExtension(filepath.Ext(expandedVal)) + dataURL := dataurl.New(content, mimeType).String() + return dataURL, nil +} diff --git a/pkg/predict/predictor.go b/pkg/predict/predictor.go index 9f175087c5..bf422a4e7d 100644 --- a/pkg/predict/predictor.go +++ b/pkg/predict/predictor.go @@ -24,7 +24,7 @@ type HealthcheckResponse struct { type Request struct { // TODO: could this be Inputs? - Input map[string]string `json:"input"` + Input map[string]interface{} `json:"input"` } type Response struct {