Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion docs/python.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ This _required_ method is where you call the model that was loaded during `setup

The `predict()` method takes an arbitrary list of named arguments, where each argument name must correspond to an [`Input()`](#inputkwargs) annotation.

`predict()` can return strings, numbers, [`cog.Path`](#path) objects representing files on disk, or lists or dicts of those types. You can also define a custom [`Output()`](#outputbasemodel) for more complex return types.
`predict()` can return strings, numbers, [`cog.Path`](#path) or [`cog.Image`](#image) objects representing files on disk, or lists or dicts of those types. You can also define a custom [`Output()`](#outputbasemodel) for more complex return types.

## `async` predictors and concurrency

Expand Down Expand Up @@ -304,6 +304,7 @@ Each parameter of the `predict()` method must be annotated with a type. The meth
- `bool`: a boolean
- [`cog.File`](#file): a file-like object representing a file
- [`cog.Path`](#path): a path to a file on disk
- [`cog.Image`](#image): a path to an image file on disk
- [`cog.Secret`](#secret): a string containing sensitive information

## `File()`
Expand Down Expand Up @@ -351,6 +352,31 @@ class Predictor(BasePredictor):
return Path(output_path)
```

## `Image()`

The `cog.Image` object is used to get image files in and out of models. It represents a _path to an image file on disk_.

`cog.Image` is a `pathlib.Path` subclass that works identically to `cog.Path`, but its OpenAPI schema carries an extra `x-cog-type: "image"` field alongside `format: "uri"`. This hints to UIs that the input or output is specifically an image, allowing them to provide image-specific widgets like image pickers or preview functionality.

For models that return a `cog.Image` object, the prediction output returned by Cog's built-in HTTP server will be a URL.

This example takes an input image, processes it, and returns a processed image:

```python
import tempfile
from cog import BasePredictor, Input, Image

class Predictor(BasePredictor):
def predict(self, image: Image = Input(description="Input image")) -> Image:
processed_image = do_some_processing(image)

# To output `cog.Image` objects the file needs to exist, so create a temporary file first.
# This file will automatically be deleted by Cog after it has been returned.
output_path = Image(tempfile.mkdtemp()) / "processed.png"
processed_image.save(output_path)
return Image(output_path)
```

## `Secret`

The `cog.Secret` type is used to signify that an input holds sensitive information,
Expand Down
2 changes: 2 additions & 0 deletions python/cog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ConcatenateIterator,
ExperimentalFeatureWarning,
File,
Image,
Input,
Path,
Secret,
Expand All @@ -32,6 +33,7 @@
"ConcatenateIterator",
"ExperimentalFeatureWarning",
"File",
"Image",
"Input",
"Path",
"Secret",
Expand Down
23 changes: 19 additions & 4 deletions python/cog/command/ast_openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,10 @@
"list": "array",
"cog.Path": "string",
"cog.File": "string",
"cog.Image": "string",
"Path": "string",
"File": "string",
"Image": "string",
}


Expand Down Expand Up @@ -446,7 +448,7 @@ def parse_class(classdef: ast.AST) -> "JSONDict":
# cog.File: a file-like object representing a file
# cog.Path: a path to a file on disk

BASE_TYPES = ["str", "int", "float", "bool", "File", "Path"]
BASE_TYPES = ["str", "int", "float", "bool", "File", "Path", "Image"]


def resolve_name(node: ast.expr) -> str:
Expand Down Expand Up @@ -490,7 +492,12 @@ def predict(
if isinstance(annotation, ast.Subscript):
# forget about other subscripts like Optional, and assume otherlib.File will still be an uri
slice = resolve_name(annotation.slice) # pylint: disable=redefined-builtin
format = {"format": "uri"} if slice in ("Path", "File") else {} # pylint: disable=redefined-builtin
if slice == "Image":
format = {"format": "uri", "x-cog-type": "image"} # pylint: disable=redefined-builtin
elif slice in ("Path", "File"):
format = {"format": "uri"} # pylint: disable=redefined-builtin
else:
format = {} # pylint: disable=redefined-builtin
array_type = {"x-cog-array-type": "iterator"} if "Iterator" in name else {}
display_type = (
{"x-cog-array-display": "concatenate"} if "Concatenate" in name else {}
Expand All @@ -507,7 +514,12 @@ def predict(
}
if name in BASE_TYPES:
# otherwise figure this out...
format = {"format": "uri"} if name in ("Path", "File") else {}
if name == "Image":
format = {"format": "uri", "x-cog-type": "image"}
elif name in ("Path", "File"):
format = {"format": "uri"}
else:
format = {}
return {}, {"title": "Output", "type": OPENAPI_TYPES.get(name, name), **format}
# it must be a custom object
schema: JSONDict = {name: parse_class(find(tree, name))}
Expand Down Expand Up @@ -548,7 +560,10 @@ def extract_info(code: str) -> "JSONDict": # pylint: disable=too-many-branches,

annotation = get_annotation(arg.annotation)
arg_type = OPENAPI_TYPES.get(annotation, "string")
if annotation in ("Path", "File"):
if annotation == "Image":
input["format"] = "uri"
input["x-cog-type"] = "image"
elif annotation in ("Path", "File"):
input["format"] = "uri"
elif annotation.startswith("Literal["):
input["enum"] = list(
Expand Down
2 changes: 2 additions & 0 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from .types import (
File as CogFile,
)
from .types import Image as CogImage
from .types import (
Path as CogPath,
)
Expand All @@ -64,6 +65,7 @@
CogFile,
CogPath,
CogSecret,
CogImage,
]


Expand Down
66 changes: 66 additions & 0 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,72 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update(type="string", format="uri")


class Image(pathlib.PosixPath):  # pylint: disable=abstract-method
    """
    A path to an image file. Subclasses pathlib.Path and accepts either
    URLs (http://, https://, data:) or local file paths.

    In the OpenAPI schema the field is emitted as a string with format="uri"
    plus x-cog-type="image", so UIs can offer image-specific widgets such as
    image pickers or previews.
    """

    validate_always = True

    @classmethod
    def validate(cls, value: Any) -> pathlib.Path:
        """Return path instances untouched; wrap any other value in a URLPath."""
        if not isinstance(value, pathlib.Path):
            return URLPath(
                source=value,
                filename=get_filename(value),
                fileobj=File.validate(value),
            )
        return value

    if PYDANTIC_V2:
        from pydantic import GetCoreSchemaHandler
        from pydantic.json_schema import JsonSchemaValue
        from pydantic_core import CoreSchema

        @classmethod
        def __get_pydantic_core_schema__(
            cls,
            source: Type[Any],  # pylint: disable=unused-argument
            handler: "pydantic.GetCoreSchemaHandler",  # pylint: disable=unused-argument
        ) -> "CoreSchema":
            """Build a union schema: a pathlib.Path instance, or cls.validate."""
            from pydantic_core import (  # pylint: disable=import-outside-toplevel
                core_schema,
            )

            path_instance = core_schema.is_instance_schema(pathlib.Path)
            via_validate = core_schema.no_info_plain_validator_function(cls.validate)
            return core_schema.union_schema([path_instance, via_validate])

        @classmethod
        def __get_pydantic_json_schema__(
            cls, core_schema: "CoreSchema", handler: "pydantic.GetJsonSchemaHandler"
        ) -> "JsonSchemaValue":  # type: ignore # noqa: F821
            """Mark the handler's JSON schema as an image URI string."""
            image_schema = handler(core_schema)
            image_schema.update(
                {"type": "string", "format": "uri", "x-cog-type": "image"}
            )
            return image_schema

    else:

        @classmethod
        def __get_validators__(cls) -> Iterator[Any]:
            """Yield the single validator used for this type."""
            yield from (cls.validate,)

        @classmethod
        def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
            """Defines what this type should be in openapi.json"""
            field_schema.update(
                {"type": "string", "format": "uri", "x-cog-type": "image"}
            )


class URLPath(pathlib.PosixPath): # pylint: disable=abstract-method
"""
URLPath is a nasty hack to ensure that we can defer the downloading of a
Expand Down
8 changes: 8 additions & 0 deletions python/tests/server/fixtures/input_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from cog import BasePredictor, Image


class Predictor(BasePredictor):
    # Fixture predictor: echoes the input file's extension followed by its
    # text content, separated by a single space.
    def predict(self, image: Image) -> str:
        with open(image) as handle:
            # Everything after the last "." in the opened file's name.
            suffix = handle.name.rsplit(".", 1)[-1]
            body = handle.read()
        return f"{suffix} {body}"
9 changes: 9 additions & 0 deletions python/tests/server/fixtures/openapi_image_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from cog import BasePredictor, Image, Input


class Predictor(BasePredictor):
    # Fixture predictor for the OpenAPI schema test: the only thing of
    # interest is the `Image` input annotation; the body returns a constant.
    def predict(self, image: Image = Input(description="An input image")) -> str:
        return "success"
12 changes: 12 additions & 0 deletions python/tests/server/fixtures/output_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import tempfile
from pathlib import Path

from cog import BasePredictor, Image


class Predictor(BasePredictor):
    # Fixture predictor: writes placeholder text into a fresh temporary
    # directory and returns that file wrapped as an `Image`.
    def predict(self) -> Image:
        scratch_dir = Path(tempfile.mkdtemp())
        target = scratch_dir / "output.jpg"
        target.write_text("test image data")
        return Image(target)
23 changes: 23 additions & 0 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,29 @@ def test_openapi_specification_with_custom_user_defined_output_type(
}


@uses_predictor("openapi_image_input")
def test_openapi_specification_with_image_input(client):
    """The Input schema for an `Image` parameter is a URI string tagged as an image."""
    expected_image_property = {
        "title": "Image",
        "description": "An input image",
        "type": "string",
        "format": "uri",
        "x-cog-type": "image",
        "x-order": 0,
    }
    expected_input_schema = {
        "title": "Input",
        "required": ["image"],
        "type": "object",
        "properties": {"image": expected_image_property},
    }

    resp = client.get("/openapi.json")
    assert resp.status_code == 200

    input_schema = resp.json()["components"]["schemas"]["Input"]
    assert input_schema == expected_input_schema


@uses_predictor("openapi_optional_output_type")
def test_openapi_specification_with_optional_output_type(
client,
Expand Down
42 changes: 41 additions & 1 deletion python/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
import responses

from cog.types import Secret, URLFile, URLPath, get_filename
from cog.types import Image, Secret, URLFile, URLPath, get_filename


def test_urlfile_protocol_validation():
Expand Down Expand Up @@ -167,3 +167,43 @@ def test_truncate_filename_if_long():
)
assert url_path.filename == "waffwyyg~~.zip"
_ = url_path.convert()


def test_image_type_validates_urls():
    """Test that Image type can validate HTTP/HTTPS URLs"""
    from pathlib import Path

    # A URL string is wrapped in a URLPath that stringifies back to the URL.
    url = "https://example.com/image.jpg"
    from_url = Image.validate(url)
    assert isinstance(from_url, URLPath)
    assert str(from_url) == url

    # A local path object is accepted as a Path.
    from_path = Image.validate(Path("/tmp/image.png"))
    assert isinstance(from_path, Path)


def test_image_type_pydantic_schema():
    """Test that Image type produces correct OpenAPI schema format"""
    from pydantic import BaseModel

    from cog.types import PYDANTIC_V2

    class TestModel(BaseModel):
        image: Image

    # Only the schema-export call differs between pydantic major versions;
    # the model definition and the assertions are shared so the two code
    # paths cannot drift apart.
    schema = TestModel.model_json_schema() if PYDANTIC_V2 else TestModel.schema()

    image_schema = schema["properties"]["image"]
    assert image_schema["type"] == "string"
    assert image_schema["format"] == "uri"
    assert image_schema["x-cog-type"] == "image"
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.12"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from cog import BasePredictor, Image


class Predictor(BasePredictor):
    # Integration fixture: reads the input image as text and echoes it back
    # behind an "image: " prefix.
    def predict(self, image: Image) -> str:
        with open(image) as handle:
            return f"image: {handle.read()}"
18 changes: 18 additions & 0 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,24 @@ def test_predict_many_inputs_with_existing_image(
assert "falling back to slow loader" not in str(result.stderr)


def test_predict_image_input(tmpdir_factory, cog_binary):
    """`cog predict -i image=@file` passes the file's contents to an Image input."""
    fixture_dir = Path(__file__).parent / "fixtures/image-input-project"
    work_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
    shutil.copytree(fixture_dir, work_dir, dirs_exist_ok=True)

    # The fixture predictor reads the input as text, so a plain-text
    # stand-in is enough for the "image".
    (work_dir / "test.jpg").write_text("fake image data", encoding="utf-8")

    completed = subprocess.run(
        [cog_binary, "predict", "-i", "image=@test.jpg"],
        cwd=work_dir,
        check=True,
        capture_output=True,
        text=True,
    )
    assert completed.stdout == "image: fake image data\n"


def test_predict_path_list_input(tmpdir_factory, cog_binary):
project_dir = Path(__file__).parent / "fixtures/path-list-input-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
Expand Down
Loading