diff --git a/docs/python.md b/docs/python.md index d5f5ddc14b..f1b44c79b7 100644 --- a/docs/python.md +++ b/docs/python.md @@ -89,7 +89,7 @@ This _required_ method is where you call the model that was loaded during `setup The `predict()` method takes an arbitrary list of named arguments, where each argument name must correspond to an [`Input()`](#inputkwargs) annotation. -`predict()` can return strings, numbers, [`cog.Path`](#path) objects representing files on disk, or lists or dicts of those types. You can also define a custom [`Output()`](#outputbasemodel) for more complex return types. +`predict()` can return strings, numbers, [`cog.Path`](#path) or [`cog.Image`](#image) objects representing files on disk, or lists or dicts of those types. You can also define a custom [`Output()`](#outputbasemodel) for more complex return types. ## `async` predictors and concurrency @@ -304,6 +304,7 @@ Each parameter of the `predict()` method must be annotated with a type. The meth - `bool`: a boolean - [`cog.File`](#file): a file-like object representing a file - [`cog.Path`](#path): a path to a file on disk +- [`cog.Image`](#image): a path to an image file on disk - [`cog.Secret`](#secret): a string containing sensitive information ## `File()` @@ -351,6 +352,31 @@ class Predictor(BasePredictor): return Path(output_path) ``` +## `Image()` + +The `cog.Image` object is used to get image files in and out of models. It represents a _path to an image file on disk_. + +`cog.Image` is a subclass of `cog.Path` and works identically to it, but produces OpenAPI schemas with `format: "image"` instead of `format: "uri"`. This hints to UIs that the input or output is specifically an image, allowing them to provide image-specific widgets like image pickers or preview functionality. + +For models that return a `cog.Image` object, the prediction output returned by Cog's built-in HTTP server will be a URL. + +This example takes an input image, processes it, and returns a processed image: + +```python +import tempfile +from cog import BasePredictor, Input, Image + +class Predictor(BasePredictor): + def predict(self, image: Image = Input(description="Input image")) -> Image: + processed_image = do_some_processing(image) + + # To output `cog.Image` objects the file needs to exist, so create a temporary file first. + # This file will automatically be deleted by Cog after it has been returned. + output_path = Image(tempfile.mkdtemp()) / "processed.png" + processed_image.save(output_path) + return Image(output_path) +``` + ## `Secret` The `cog.Secret` type is used to signify that an input holds sensitive information, diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 6824cec5e2..aab86e213d 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -10,6 +10,7 @@ ConcatenateIterator, ExperimentalFeatureWarning, File, + Image, Input, Path, Secret, @@ -32,6 +33,7 @@ "ConcatenateIterator", "ExperimentalFeatureWarning", "File", + "Image", "Input", "Path", "Secret", diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index 9afb03e391..c003a82eae 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -307,8 +307,10 @@ "list": "array", "cog.Path": "string", "cog.File": "string", + "cog.Image": "string", "Path": "string", "File": "string", + "Image": "string", } @@ -446,7 +448,7 @@ def parse_class(classdef: ast.AST) -> "JSONDict": # cog.File: a file-like object representing a file # cog.Path: a path to a file on disk -BASE_TYPES = ["str", "int", "float", "bool", "File", "Path"] +BASE_TYPES = ["str", "int", "float", "bool", "File", "Path", "Image"] def resolve_name(node: ast.expr) -> str: @@ -490,7 +492,12 @@ def predict( if isinstance(annotation, ast.Subscript): # forget about other subscripts like Optional, and assume otherlib.File will still be an uri slice = resolve_name(annotation.slice) # pylint: disable=redefined-builtin - format = {"format": "uri"} if slice in ("Path", "File") else {} # pylint: disable=redefined-builtin + if slice == "Image": + format = {"format": "uri", "x-cog-type": "image"} # pylint: disable=redefined-builtin + elif slice in ("Path", "File"): + format = {"format": "uri"} # pylint: disable=redefined-builtin + else: + format = {} # pylint: disable=redefined-builtin array_type = {"x-cog-array-type": "iterator"} if "Iterator" in name else {} display_type = ( {"x-cog-array-display": "concatenate"} if "Concatenate" in name else {} @@ -507,7 +514,12 @@ def predict( } if name in BASE_TYPES: # otherwise figure this out... - format = {"format": "uri"} if name in ("Path", "File") else {} + if name == "Image": + format = {"format": "uri", "x-cog-type": "image"} + elif name in ("Path", "File"): + format = {"format": "uri"} + else: + format = {} return {}, {"title": "Output", "type": OPENAPI_TYPES.get(name, name), **format} # it must be a custom object schema: JSONDict = {name: parse_class(find(tree, name))} @@ -548,7 +560,10 @@ def extract_info(code: str) -> "JSONDict": # pylint: disable=too-many-branches, annotation = get_annotation(arg.annotation) arg_type = OPENAPI_TYPES.get(annotation, "string") - if annotation in ("Path", "File"): + if annotation == "Image": + input["format"] = "uri" + input["x-cog-type"] = "image" + elif annotation in ("Path", "File"): input["format"] = "uri" elif annotation.startswith("Literal["): input["enum"] = list( diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 27c65d616a..231454b1b1 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -44,6 +44,7 @@ from .types import ( File as CogFile, ) +from .types import Image as CogImage from .types import ( Path as CogPath, ) @@ -64,6 +65,7 @@ CogFile, CogPath, CogSecret, + CogImage, ] diff --git a/python/cog/types.py b/python/cog/types.py index cc240a0647..552f2eb7eb 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -254,6 +254,72 @@ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: field_schema.update(type="string", format="uri") +class Image(pathlib.PosixPath): # pylint: disable=abstract-method + """ + Image represents an image file. It's a subclass of pathlib.Path that + accepts URLs (http://, https://, data:) or local file paths. + + The OpenAPI schema will have format="uri" with x-cog-type="image" which UIs + can use to provide image-specific widgets like image pickers or preview functionality. + """ + + validate_always = True + + @classmethod + def validate(cls, value: Any) -> pathlib.Path: + if isinstance(value, pathlib.Path): + return value + + return URLPath( + source=value, + filename=get_filename(value), + fileobj=File.validate(value), + ) + + if PYDANTIC_V2: + from pydantic import GetCoreSchemaHandler + from pydantic.json_schema import JsonSchemaValue + from pydantic_core import CoreSchema + + @classmethod + def __get_pydantic_core_schema__( + cls, + source: Type[Any], # pylint: disable=unused-argument + handler: "pydantic.GetCoreSchemaHandler", # pylint: disable=unused-argument + ) -> "CoreSchema": + from pydantic_core import ( # pylint: disable=import-outside-toplevel + core_schema, + ) + + return core_schema.union_schema( + [ + core_schema.is_instance_schema(pathlib.Path), + core_schema.no_info_plain_validator_function(cls.validate), + ] + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: "CoreSchema", handler: "pydantic.GetJsonSchemaHandler" + ) -> "JsonSchemaValue": # type: ignore # noqa: F821 + json_schema = handler(core_schema) + json_schema.update(type="string", format="uri") + json_schema["x-cog-type"] = "image" + return json_schema + + else: + + @classmethod + def __get_validators__(cls) -> Iterator[Any]: + yield cls.validate + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + """Defines what this type should be in openapi.json""" + field_schema.update(type="string", format="uri") + field_schema["x-cog-type"] = "image" + + class URLPath(pathlib.PosixPath): # pylint: disable=abstract-method """ URLPath is a nasty hack to ensure that we can defer the downloading of a diff --git a/python/tests/server/fixtures/input_image.py b/python/tests/server/fixtures/input_image.py new file mode 100644 index 0000000000..8f6ae53958 --- /dev/null +++ b/python/tests/server/fixtures/input_image.py @@ -0,0 +1,8 @@ +from cog import BasePredictor, Image + + +class Predictor(BasePredictor): + def predict(self, image: Image) -> str: + with open(image) as fh: + extension = fh.name.split(".")[-1] + return f"{extension} {fh.read()}" diff --git a/python/tests/server/fixtures/openapi_image_input.py b/python/tests/server/fixtures/openapi_image_input.py new file mode 100644 index 0000000000..6a881a6051 --- /dev/null +++ b/python/tests/server/fixtures/openapi_image_input.py @@ -0,0 +1,9 @@ +from cog import BasePredictor, Image, Input + + +class Predictor(BasePredictor): + def predict( + self, + image: Image = Input(description="An input image"), + ) -> str: + return "success" diff --git a/python/tests/server/fixtures/output_image.py b/python/tests/server/fixtures/output_image.py new file mode 100644 index 0000000000..37adbc0b5d --- /dev/null +++ b/python/tests/server/fixtures/output_image.py @@ -0,0 +1,12 @@ +import tempfile +from pathlib import Path + +from cog import BasePredictor, Image + + +class Predictor(BasePredictor): + def predict(self) -> Image: + output_path = Path(tempfile.mkdtemp()) / "output.jpg" + with open(output_path, "w") as fh: + fh.write("test image data") + return Image(output_path) diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index df2dc18c06..6ac43e0f09 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -271,6 +271,29 @@ def test_openapi_specification_with_custom_user_defined_output_type( } +@uses_predictor("openapi_image_input") +def test_openapi_specification_with_image_input(client): + resp = client.get("/openapi.json") + assert resp.status_code == 200 + + schema = resp.json() + assert schema["components"]["schemas"]["Input"] == { + "title": "Input", + "required": ["image"], + "type": "object", + "properties": { + "image": { + "title": "Image", + "description": "An input image", + "type": "string", + "format": "uri", + "x-cog-type": "image", + "x-order": 0, + }, + }, + } + + @uses_predictor("openapi_optional_output_type") def test_openapi_specification_with_optional_output_type( client, diff --git a/python/tests/test_types.py b/python/tests/test_types.py index bef34aae78..6587b75cc3 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -6,7 +6,7 @@ import pytest import responses -from cog.types import Secret, URLFile, URLPath, get_filename +from cog.types import Image, Secret, URLFile, URLPath, get_filename def test_urlfile_protocol_validation(): @@ -167,3 +167,43 @@ def test_truncate_filename_if_long(): ) assert url_path.filename == "waffwyyg~~.zip" _ = url_path.convert() + + +def test_image_type_validates_urls(): + """Test that Image type can validate HTTP/HTTPS URLs""" + from pathlib import Path + + # Should accept HTTP URLs + result = Image.validate("https://example.com/image.jpg") + assert isinstance(result, URLPath) + assert str(result) == "https://example.com/image.jpg" + + # Should accept local paths + result = Image.validate(Path("/tmp/image.png")) + assert isinstance(result, Path) + + +def test_image_type_pydantic_schema(): + """Test that Image type produces correct OpenAPI schema format""" + from cog.types import PYDANTIC_V2 + + if PYDANTIC_V2: + from pydantic import BaseModel + + class TestModel(BaseModel): + image: Image + + schema = TestModel.model_json_schema() + assert schema["properties"]["image"]["type"] == "string" + assert schema["properties"]["image"]["format"] == "uri" + assert schema["properties"]["image"]["x-cog-type"] == "image" + else: + from pydantic import BaseModel + + class TestModel(BaseModel): + image: Image + + schema = TestModel.schema() + assert schema["properties"]["image"]["type"] == "string" + assert schema["properties"]["image"]["format"] == "uri" + assert schema["properties"]["image"]["x-cog-type"] == "image" diff --git a/test-integration/test_integration/fixtures/image-input-project/cog.yaml b/test-integration/test_integration/fixtures/image-input-project/cog.yaml new file mode 100644 index 0000000000..defc03d8cf --- /dev/null +++ b/test-integration/test_integration/fixtures/image-input-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.12" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/image-input-project/predict.py b/test-integration/test_integration/fixtures/image-input-project/predict.py new file mode 100644 index 0000000000..1ba506c4bb --- /dev/null +++ b/test-integration/test_integration/fixtures/image-input-project/predict.py @@ -0,0 +1,8 @@ +from cog import BasePredictor, Image + + +class Predictor(BasePredictor): + def predict(self, image: Image) -> str: + with open(image) as f: + content = f.read() + return f"image: {content}" diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index cfcfe7650f..9c5bc5261f 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -294,6 +294,24 @@ def test_predict_many_inputs_with_existing_image( assert "falling back to slow loader" not in str(result.stderr) +def test_predict_image_input(tmpdir_factory, cog_binary): + project_dir = Path(__file__).parent / "fixtures/image-input-project" + out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) + shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) + with open(out_dir / "test.jpg", "w", encoding="utf-8") as fh: + fh.write("fake image data") + cmd = [cog_binary, "predict", "-i", "image=@test.jpg"] + + result = subprocess.run( + cmd, + cwd=out_dir, + check=True, + capture_output=True, + text=True, + ) + assert result.stdout == "image: fake image data\n" + + def test_predict_path_list_input(tmpdir_factory, cog_binary): project_dir = Path(__file__).parent / "fixtures/path-list-input-project" out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))