-
Notifications
You must be signed in to change notification settings - Fork 522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support for list in inputs #1561
Changes from 4 commits
4a16aa1
a679639
13e39ca
c8b9353
71d0d7c
ecdf5b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,20 +1,23 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
package predict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"encoding/json" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"fmt" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"mime" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"os" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"path/filepath" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"strings" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"github.com/mitchellh/go-homedir" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"github.com/vincent-petithory/dataurl" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"github.com/mitchellh/go-homedir" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"github.com/replicate/cog/pkg/util/console" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"github.com/replicate/cog/pkg/util/mime" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
type Input struct { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
String *string | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
File *string | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Array *[]interface{} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two small points:
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
type Inputs map[string]Input | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -25,15 +28,16 @@ func NewInputs(keyVals map[string]string) Inputs { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
val := val | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if strings.HasPrefix(val, "@") { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
val = val[1:] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
expandedVal, err := homedir.Expand(val) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// FIXME: handle this better? | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
console.Warnf("Error expanding homedir: %s", err) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
input[key] = Input{File: &val} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if strings.HasPrefix(val, "[") && strings.HasSuffix(val, "]") { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Handle array of strings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
var arr []interface{} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if err := json.Unmarshal([]byte(val), &arr); err != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is JSON unmarshalling right here? That means the strings have to be quoted, which is inconsistent with the other types. For example, I can do
or
but with this implementation, the following will throw an error
I think there are two possibilities here. Either:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I chose repeated inputs. Good idea. Thanks. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Handle JSON unmarshalling error | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
console.Warnf("Error parsing array for key '%s': %s", key, err) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
val = expandedVal | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
input[key] = Input{Array: &arr} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
input[key] = Input{File: &val} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
input[key] = Input{String: &val} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -55,19 +59,54 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return input | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
func (inputs *Inputs) toMap() (map[string]string, error) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
keyVals := map[string]string{} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
func (inputs *Inputs) toMap() (map[string]interface{}, error) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
keyVals := map[string]interface{}{} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for key, input := range *inputs { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if input.String != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Directly assign the string value | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
keyVals[key] = *input.String | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if input.File != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
content, err := os.ReadFile(*input.File) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Single file handling: read content and convert to a data URL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dataURL, err := fileToDataURL(*input.File) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return keyVals, err | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mimeType := mime.TypeByExtension(filepath.Ext(*input.File)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
keyVals[key] = dataurl.New(content, mimeType).String() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
keyVals[key] = dataURL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if input.Array != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Handle array, potentially containing file paths | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
var dataURLs []string | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for _, elem := range *input.Array { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
filePath := str[1:] // Remove '@' prefix | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dataURL, err := fileToDataURL(filePath) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return keyVals, err | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dataURLs = append(dataURLs, dataURL) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else if ok { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Directly use the string if it's not a file path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dataURLs = append(dataURLs, str) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
keyVals[key] = dataURLs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not a big deal, but it would be more conventional when you know the length of the slice beforehand to do something like
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return keyVals, nil | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Helper function to read file content and convert to a data URL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
func fileToDataURL(filePath string) (string, error) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Expand home directory if necessary | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
expandedVal, err := homedir.Expand(filePath) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return "", fmt.Errorf("error expanding homedir for '%s': %w", filePath, err) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
content, err := os.ReadFile(expandedVal) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return "", err | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mimeType := mime.TypeByExtension(filepath.Ext(expandedVal)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dataURL := dataurl.New(content, mimeType).String() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return dataURL, nil | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -3,6 +3,7 @@ | |||
|
||||
This prints a JSON object describing the inputs of the model. | ||||
""" | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this change is required by one of new linters in |
||||
import json | ||||
|
||||
from ..errors import CogError, ConfigDoesNotExist, PredictorNotSet | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,9 +50,9 @@ class PredictionRequest(PredictionBaseModel): | |
output_file_prefix: t.Optional[str] | ||
|
||
webhook: t.Optional[pydantic.AnyHttpUrl] | ||
webhook_events_filter: t.Optional[ | ||
t.List[WebhookEvent] | ||
] = WebhookEvent.default_events() | ||
webhook_events_filter: t.Optional[t.List[WebhookEvent]] = ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changes like this suggest that you don't have your editor set up to do autoformatting the same as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this change is required by one of new linters in our |
||
WebhookEvent.default_events() | ||
) | ||
|
||
@classmethod | ||
def with_types(cls, input_type: t.Type[t.Any]) -> t.Any: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,10 @@ | ||
from cog import BasePredictor, Path | ||
from cog import BasePredictor, File | ||
|
||
|
||
class Predictor(BasePredictor): | ||
def predict(self, path: Path) -> str: | ||
with open(path) as f: | ||
return f.read() | ||
def predict(self, file: File) -> str: | ||
content = file.read() | ||
if isinstance(content, bytes): | ||
# Decode bytes to str assuming UTF-8 encoding; adjust if needed | ||
content = content.decode('utf-8') | ||
return content |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
text |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
build: | ||
python_version: "3.11" | ||
predict: "predict.py:Predictor" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from cog import BasePredictor, File | ||
|
||
|
||
class Predictor(BasePredictor): | ||
def predict(self, files: list[File]) -> str: | ||
output_parts = [] # Use a list to collect file contents | ||
for f in files: | ||
# Assuming file content is in bytes, decode to str before appending | ||
content = f.read() | ||
if isinstance(content, bytes): | ||
# Decode bytes to str assuming UTF-8 encoding; adjust if needed | ||
content = content.decode('utf-8') | ||
output_parts.append(content) | ||
return "\n\n".join(output_parts) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from cog import BasePredictor, Path | ||
|
||
|
||
class Predictor(BasePredictor): | ||
def predict(self, path: Path) -> str: | ||
with open(path) as f: | ||
return f.read() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
build: | ||
python_version: "3.11" | ||
predict: "predict.py:Predictor" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from cog import BasePredictor, Path | ||
|
||
class Predictor(BasePredictor): | ||
def predict(self, paths: list[Path]) -> str: | ||
output_parts = [] # Use a list to collect file contents | ||
for path in paths: | ||
with open(path) as f: | ||
output_parts.append(f.read()) | ||
return "".join(output_parts) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
build: | ||
python_version: "3.8" | ||
predict: "predict.py:Predictor" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This import moved, and it should move back :)