Skip to content

Commit

Permalink
feat: support for list in inputs
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitri Khokhlov <dkhokhlov@gmail.com>
  • Loading branch information
dkhokhlov committed Mar 6, 2024
1 parent 46d0eb3 commit 641da4c
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 16 deletions.
68 changes: 53 additions & 15 deletions pkg/predict/input.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
package predict

import (
"encoding/json"
"fmt"
"github.com/vincent-petithory/dataurl"
"mime"
"os"
"path/filepath"
"strings"

"github.com/mitchellh/go-homedir"
"github.com/vincent-petithory/dataurl"

"github.com/replicate/cog/pkg/util/console"
"github.com/replicate/cog/pkg/util/mime"
)

type Input struct {
String *string
File *string
Array *[]interface{}
}

type Inputs map[string]Input
Expand All @@ -25,15 +27,16 @@ func NewInputs(keyVals map[string]string) Inputs {
val := val
if strings.HasPrefix(val, "@") {
val = val[1:]
expandedVal, err := homedir.Expand(val)
if err != nil {
// FIXME: handle this better?
console.Warnf("Error expanding homedir: %s", err)
input[key] = Input{File: &val}
} else if strings.HasPrefix(val, "[") && strings.HasSuffix(val, "]") {
// Handle array of strings
var arr []interface{}
if err := json.Unmarshal([]byte(val), &arr); err != nil {
// Handle JSON unmarshalling error
console.Warnf("Error parsing array for key '%s': %s", key, err)
} else {
val = expandedVal
input[key] = Input{Array: &arr}
}

input[key] = Input{File: &val}
} else {
input[key] = Input{String: &val}
}
Expand All @@ -55,19 +58,54 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs {
return input
}

func (inputs *Inputs) toMap() (map[string]string, error) {
keyVals := map[string]string{}
func (inputs *Inputs) toMap() (map[string]interface{}, error) {
keyVals := map[string]interface{}{}
for key, input := range *inputs {
if input.String != nil {
// Directly assign the string value
keyVals[key] = *input.String
} else if input.File != nil {
content, err := os.ReadFile(*input.File)
// Single file handling: read content and convert to a data URL
dataURL, err := fileToDataURL(*input.File)
if err != nil {
return keyVals, err
}
mimeType := mime.TypeByExtension(filepath.Ext(*input.File))
keyVals[key] = dataurl.New(content, mimeType).String()
keyVals[key] = dataURL
} else if input.Array != nil {
// Handle array, potentially containing file paths
var dataURLs []string
for _, elem := range *input.Array {
if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") {
filePath := str[1:] // Remove '@' prefix
dataURL, err := fileToDataURL(filePath)
if err != nil {
return keyVals, err
}
dataURLs = append(dataURLs, dataURL)
} else if ok {
// Directly use the string if it's not a file path
dataURLs = append(dataURLs, str)
}
}
keyVals[key] = dataURLs
}
}
return keyVals, nil
}

// Helper function to read file content and convert to a data URL
func fileToDataURL(filePath string) (string, error) {
// Expand home directory if necessary
expandedVal, err := homedir.Expand(filePath)
if err != nil {
return "", fmt.Errorf("error expanding homedir for '%s': %w", filePath, err)
}

content, err := os.ReadFile(expandedVal)
if err != nil {
return "", err
}
mimeType := mime.TypeByExtension(filepath.Ext(expandedVal))
dataURL := dataurl.New(content, mimeType).String()
return dataURL, nil
}
2 changes: 1 addition & 1 deletion pkg/predict/predictor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type HealthcheckResponse struct {

type Request struct {
// TODO: could this be Inputs?
Input map[string]string `json:"input"`
Input map[string]interface{} `json:"input"`
}

type Response struct {
Expand Down

0 comments on commit 641da4c

Please sign in to comment.