diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 63afa35349..80a88edea1 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -288,7 +288,7 @@ func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openap } func parseInputFlags(inputs []string) (predict.Inputs, error) { - keyVals := map[string]string{} + keyVals := map[string][]string{} for _, input := range inputs { var name, value string @@ -305,7 +305,8 @@ func parseInputFlags(inputs []string) (predict.Inputs, error) { value = value[1 : len(value)-1] } - keyVals[name] = value + // Append new values to the slice associated with the key + keyVals[name] = append(keyVals[name], value) } return predict.NewInputs(keyVals), nil diff --git a/pkg/predict/input.go b/pkg/predict/input.go index bc826d63b7..0f42e3dc66 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -1,45 +1,41 @@ package predict import ( - "encoding/json" "fmt" "mime" "os" "path/filepath" "strings" - "github.com/vincent-petithory/dataurl" - "github.com/mitchellh/go-homedir" - "github.com/replicate/cog/pkg/util/console" + "github.com/vincent-petithory/dataurl" ) type Input struct { String *string File *string - Array *[]interface{} + Array *[]any } type Inputs map[string]Input -func NewInputs(keyVals map[string]string) Inputs { +func NewInputs(keyVals map[string][]string) Inputs { input := Inputs{} - for key, val := range keyVals { - val := val - if strings.HasPrefix(val, "@") { - val = val[1:] - input[key] = Input{File: &val} - } else if strings.HasPrefix(val, "[") && strings.HasSuffix(val, "]") { - // Handle array of strings - var arr []interface{} - if err := json.Unmarshal([]byte(val), &arr); err != nil { - // Handle JSON unmarshalling error - console.Warnf("Error parsing array for key '%s': %s", key, err) + for key, vals := range keyVals { + if len(vals) == 1 { + val := vals[0] + if strings.HasPrefix(val, "@") { + val = val[1:] + input[key] = Input{File: &val} } else { - input[key] = Input{Array: &arr} + input[key] = Input{String: &val} } - } else { - input[key] = Input{String: &val} + } else if len(vals) > 1 { + var anyVals = make([]any, len(vals)) + for i, v := range vals { + anyVals[i] = v + } + input[key] = Input{Array: &anyVals} } } return input @@ -59,8 +55,8 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs { return input } -func (inputs *Inputs) toMap() (map[string]interface{}, error) { - keyVals := map[string]interface{}{} +func (inputs *Inputs) toMap() (map[string]any, error) { + keyVals := map[string]any{} for key, input := range *inputs { if input.String != nil { // Directly assign the string value @@ -74,18 +70,18 @@ func (inputs *Inputs) toMap() (map[string]interface{}, error) { keyVals[key] = dataURL } else if input.Array != nil { // Handle array, potentially containing file paths - var dataURLs []string - for _, elem := range *input.Array { + dataURLs := make([]string, len(*input.Array)) + for i, elem := range *input.Array { if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") { filePath := str[1:] // Remove '@' prefix dataURL, err := fileToDataURL(filePath) if err != nil { return keyVals, err } - dataURLs = append(dataURLs, dataURL) + dataURLs[i] = dataURL } else if ok { // Directly use the string if it's not a file path - dataURLs = append(dataURLs, str) + dataURLs[i] = str } } keyVals[key] = dataURLs diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index eeebeaa926..91d3f4d3ca 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -239,7 +239,7 @@ def test_predict_path_list_input(tmpdir_factory): fh.write("test1") with open(out_dir / "2.txt", "w") as fh: fh.write("test2") - cmd = ["cog", "predict", "-i", "paths=[\"@1.txt\",\"@2.txt\"]"] + cmd = ["cog", "predict", "-i", "paths=@1.txt", "-i", "paths=@2.txt"] result = subprocess.run( cmd,