Skip to content

Commit

Permalink
feat: support for list in inputs
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitri Khokhlov <dkhokhlov@gmail.com>
  • Loading branch information
dkhokhlov committed Mar 7, 2024
1 parent 46d0eb3 commit 4a16aa1
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 18 deletions.
67 changes: 53 additions & 14 deletions pkg/predict/input.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
package predict

import (
"encoding/json"
"fmt"
"mime"
"os"
"path/filepath"
"strings"

"github.com/mitchellh/go-homedir"
"github.com/vincent-petithory/dataurl"

"github.com/mitchellh/go-homedir"
"github.com/replicate/cog/pkg/util/console"
"github.com/replicate/cog/pkg/util/mime"
)

type Input struct {
String *string
File *string
Array *[]interface{}
}

type Inputs map[string]Input
Expand All @@ -25,15 +28,16 @@ func NewInputs(keyVals map[string]string) Inputs {
val := val
if strings.HasPrefix(val, "@") {
val = val[1:]
expandedVal, err := homedir.Expand(val)
if err != nil {
// FIXME: handle this better?
console.Warnf("Error expanding homedir: %s", err)
input[key] = Input{File: &val}
} else if strings.HasPrefix(val, "[") && strings.HasSuffix(val, "]") {
// Handle array of strings
var arr []interface{}
if err := json.Unmarshal([]byte(val), &arr); err != nil {
// Handle JSON unmarshalling error
console.Warnf("Error parsing array for key '%s': %s", key, err)
} else {
val = expandedVal
input[key] = Input{Array: &arr}
}

input[key] = Input{File: &val}
} else {
input[key] = Input{String: &val}
}
Expand All @@ -55,19 +59,54 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs {
return input
}

func (inputs *Inputs) toMap() (map[string]string, error) {
keyVals := map[string]string{}
func (inputs *Inputs) toMap() (map[string]interface{}, error) {
keyVals := map[string]interface{}{}
for key, input := range *inputs {
if input.String != nil {
// Directly assign the string value
keyVals[key] = *input.String
} else if input.File != nil {
content, err := os.ReadFile(*input.File)
// Single file handling: read content and convert to a data URL
dataURL, err := fileToDataURL(*input.File)
if err != nil {
return keyVals, err
}
mimeType := mime.TypeByExtension(filepath.Ext(*input.File))
keyVals[key] = dataurl.New(content, mimeType).String()
keyVals[key] = dataURL
} else if input.Array != nil {
// Handle array, potentially containing file paths
var dataURLs []string
for _, elem := range *input.Array {
if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") {
filePath := str[1:] // Remove '@' prefix
dataURL, err := fileToDataURL(filePath)
if err != nil {
return keyVals, err
}
dataURLs = append(dataURLs, dataURL)
} else if ok {
// Directly use the string if it's not a file path
dataURLs = append(dataURLs, str)
}
}
keyVals[key] = dataURLs
}
}
return keyVals, nil
}

// Helper function to read file content and convert to a data URL
func fileToDataURL(filePath string) (string, error) {
// Expand home directory if necessary
expandedVal, err := homedir.Expand(filePath)
if err != nil {
return "", fmt.Errorf("error expanding homedir for '%s': %w", filePath, err)
}

content, err := os.ReadFile(expandedVal)
if err != nil {
return "", err
}
mimeType := mime.TypeByExtension(filepath.Ext(expandedVal))
dataURL := dataurl.New(content, mimeType).String()
return dataURL, nil
}
2 changes: 1 addition & 1 deletion pkg/predict/predictor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type HealthcheckResponse struct {

type Request struct {
// TODO: could this be Inputs?
Input map[string]string `json:"input"`
Input map[string]interface{} `json:"input"`
}

type Response struct {
Expand Down
1 change: 1 addition & 0 deletions python/cog/command/openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This prints a JSON object describing the inputs of the model.
"""

import json

from ..errors import CogError, ConfigDoesNotExist, PredictorNotSet
Expand Down
6 changes: 3 additions & 3 deletions python/cog/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ class PredictionRequest(PredictionBaseModel):
output_file_prefix: t.Optional[str]

webhook: t.Optional[pydantic.AnyHttpUrl]
webhook_events_filter: t.Optional[
t.List[WebhookEvent]
] = WebhookEvent.default_events()
webhook_events_filter: t.Optional[t.List[WebhookEvent]] = (
WebhookEvent.default_events()
)

@classmethod
def with_types(cls, input_type: t.Type[t.Any]) -> t.Any:
Expand Down

0 comments on commit 4a16aa1

Please sign in to comment.