Skip to content

Commit

Permalink
review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
dkhokhlov committed Mar 21, 2024
1 parent c8b9353 commit 1b36ac1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 29 deletions.
5 changes: 3 additions & 2 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openap
}

func parseInputFlags(inputs []string) (predict.Inputs, error) {
keyVals := map[string]string{}
keyVals := map[string][]string{}
for _, input := range inputs {
var name, value string

Expand All @@ -305,7 +305,8 @@ func parseInputFlags(inputs []string) (predict.Inputs, error) {
value = value[1 : len(value)-1]
}

keyVals[name] = value
// Append new values to the slice associated with the key
keyVals[name] = append(keyVals[name], value)
}

return predict.NewInputs(keyVals), nil
Expand Down
48 changes: 22 additions & 26 deletions pkg/predict/input.go
Original file line number Diff line number Diff line change
@@ -1,45 +1,41 @@
package predict

import (
"encoding/json"
"fmt"
"mime"
"os"
"path/filepath"
"strings"

"github.com/vincent-petithory/dataurl"

"github.com/mitchellh/go-homedir"
"github.com/replicate/cog/pkg/util/console"
"github.com/vincent-petithory/dataurl"
)

type Input struct {
String *string
File *string
Array *[]interface{}
Array *[]any
}

type Inputs map[string]Input

func NewInputs(keyVals map[string]string) Inputs {
func NewInputs(keyVals map[string][]string) Inputs {
input := Inputs{}
for key, val := range keyVals {
val := val
if strings.HasPrefix(val, "@") {
val = val[1:]
input[key] = Input{File: &val}
} else if strings.HasPrefix(val, "[") && strings.HasSuffix(val, "]") {
// Handle array of strings
var arr []interface{}
if err := json.Unmarshal([]byte(val), &arr); err != nil {
// Handle JSON unmarshalling error
console.Warnf("Error parsing array for key '%s': %s", key, err)
for key, vals := range keyVals {
if len(vals) == 1 {
val := vals[0]
if strings.HasPrefix(val, "@") {
val = val[1:]
input[key] = Input{File: &val}
} else {
input[key] = Input{Array: &arr}
input[key] = Input{String: &val}
}
} else {
input[key] = Input{String: &val}
} else if len(vals) > 1 {
var anyVals = make([]any, len(vals))
for i, v := range vals {
anyVals[i] = v
}
input[key] = Input{Array: &anyVals}
}
}
return input
Expand All @@ -59,8 +55,8 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs {
return input
}

func (inputs *Inputs) toMap() (map[string]interface{}, error) {
keyVals := map[string]interface{}{}
func (inputs *Inputs) toMap() (map[string]any, error) {
keyVals := map[string]any{}
for key, input := range *inputs {
if input.String != nil {
// Directly assign the string value
Expand All @@ -74,18 +70,18 @@ func (inputs *Inputs) toMap() (map[string]interface{}, error) {
keyVals[key] = dataURL
} else if input.Array != nil {
// Handle array, potentially containing file paths
var dataURLs []string
for _, elem := range *input.Array {
dataURLs := make([]string, len(*input.Array))
for i, elem := range *input.Array {
if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") {
filePath := str[1:] // Remove '@' prefix
dataURL, err := fileToDataURL(filePath)
if err != nil {
return keyVals, err
}
dataURLs = append(dataURLs, dataURL)
dataURLs[i] = dataURL
} else if ok {
// Directly use the string if it's not a file path
dataURLs = append(dataURLs, str)
dataURLs[i] = str
}
}
keyVals[key] = dataURLs
Expand Down
2 changes: 1 addition & 1 deletion test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def test_predict_path_list_input(tmpdir_factory):
fh.write("test1")
with open(out_dir / "2.txt", "w") as fh:
fh.write("test2")
cmd = ["cog", "predict", "-i", "paths=[\"@1.txt\",\"@2.txt\"]"]
cmd = ["cog", "predict", "-i", "paths=@1.txt", "-i", "paths=@2.txt"]

result = subprocess.run(
cmd,
Expand Down

0 comments on commit 1b36ac1

Please sign in to comment.