-
Notifications
You must be signed in to change notification settings - Fork 522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support for list in inputs #1561
Changes from all commits
4a16aa1
a679639
13e39ca
c8b9353
71d0d7c
ecdf5b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -3,6 +3,7 @@ | |||
|
||||
This prints a JSON object describing the inputs of the model. | ||||
""" | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this change is required by one of new linters in |
||||
import json | ||||
|
||||
from ..errors import CogError, ConfigDoesNotExist, PredictorNotSet | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,9 +50,9 @@ class PredictionRequest(PredictionBaseModel): | |
output_file_prefix: t.Optional[str] | ||
|
||
webhook: t.Optional[pydantic.AnyHttpUrl] | ||
webhook_events_filter: t.Optional[ | ||
t.List[WebhookEvent] | ||
] = WebhookEvent.default_events() | ||
webhook_events_filter: t.Optional[t.List[WebhookEvent]] = ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changes like this suggest that you don't have your editor set up to do autoformatting the same as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this change is required by one of new linters in our |
||
WebhookEvent.default_events() | ||
) | ||
|
||
@classmethod | ||
def with_types(cls, input_type: t.Type[t.Any]) -> t.Any: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,10 @@ | ||
from cog import BasePredictor, Path | ||
from cog import BasePredictor, File | ||
|
||
|
||
class Predictor(BasePredictor): | ||
def predict(self, path: Path) -> str: | ||
with open(path) as f: | ||
return f.read() | ||
def predict(self, file: File) -> str: | ||
content = file.read() | ||
if isinstance(content, bytes): | ||
# Decode bytes to str assuming UTF-8 encoding; adjust if needed | ||
content = content.decode('utf-8') | ||
return content |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
text |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
build: | ||
python_version: "3.11" | ||
predict: "predict.py:Predictor" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from cog import BasePredictor, File | ||
|
||
|
||
class Predictor(BasePredictor): | ||
def predict(self, files: list[File]) -> str: | ||
output_parts = [] # Use a list to collect file contents | ||
for f in files: | ||
# Assuming file content is in bytes, decode to str before appending | ||
content = f.read() | ||
if isinstance(content, bytes): | ||
# Decode bytes to str assuming UTF-8 encoding; adjust if needed | ||
content = content.decode('utf-8') | ||
output_parts.append(content) | ||
return "\n\n".join(output_parts) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from cog import BasePredictor, Path | ||
|
||
|
||
class Predictor(BasePredictor): | ||
def predict(self, path: Path) -> str: | ||
with open(path) as f: | ||
return f.read() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
build: | ||
python_version: "3.11" | ||
predict: "predict.py:Predictor" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from cog import BasePredictor, Path | ||
|
||
class Predictor(BasePredictor): | ||
def predict(self, paths: list[Path]) -> str: | ||
output_parts = [] # Use a list to collect file contents | ||
for path in paths: | ||
with open(path) as f: | ||
output_parts.append(f.read()) | ||
return "".join(output_parts) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
build: | ||
python_version: "3.8" | ||
predict: "predict.py:Predictor" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This import moved, and it should move back :)