Changes from all commits (19 commits)
- `0268e21` feat: reclaim 'cog run' for predictions, deprecate 'cog predict' (markphelps, Apr 9, 2026)
- `7e33957` feat: support 'run:' key in cog.yaml alongside 'predict:' (markphelps, Apr 9, 2026)
- `f855a39` feat: add BaseRunner with run() method to Python SDK (markphelps, Apr 9, 2026)
- `11c7098` feat: detect run() vs predict() in Rust coglet (markphelps, Apr 9, 2026)
- `e217084` feat: schema parser tries run() before predict() (markphelps, Apr 9, 2026)
- `de3141c` feat: cog init generates run.py with BaseRunner (markphelps, Apr 9, 2026)
- `3d5c6e0` test: rename predict->run in integration tests (markphelps, Apr 9, 2026)
- `ce50efd` test: add backwards compatibility integration tests (markphelps, Apr 9, 2026)
- `311a07d` docs: update all docs to use run/Runner/BaseRunner as primary (markphelps, Apr 9, 2026)
- `5d86228` chore: regenerate CLI docs, llms.txt, and apply rust fmt (markphelps, Apr 9, 2026)
- `20b28b6` fix: walk MRO for method detection, fix return type, warn on ambiguou… (markphelps, Apr 9, 2026)
- `e118403` fix: revert legacy test, use &'static str for method name, fix wsl2 docs (markphelps, Apr 9, 2026)
- `1af85f0` fix: review cleanup - validation, schema desc, template comment, MRO … (markphelps, Apr 9, 2026)
- `222a226` fix: remove needless borrows after &'static str change (markphelps, Apr 9, 2026)
- `b4b5794` merge: resolve conflict with main (setup_is_async + predict_method_name) (markphelps, Apr 10, 2026)
- `7baa3e5` refactor: collapse BaseRunner/BasePredictor into single class with alias (markphelps, Apr 10, 2026)
- `9826069` fix: CI failures - update test assertion, fix no_predictor match, rus… (markphelps, Apr 10, 2026)
- `6728a5a` fix: address review findings (2 medium, 5 low) (markphelps, Apr 10, 2026)
- `2fc8930` fix: tighten no_predictor test assertion, mark predict schema deprecated (michaeldwan, Apr 10, 2026)
8 changes: 4 additions & 4 deletions AGENTS.md
@@ -8,7 +8,7 @@ Cog is a tool that packages machine learning models in production-ready containers

It consists of:
- **Cog CLI** (`cmd/cog/`) - Command-line interface for building, running, and deploying models, written in Go
- **Python SDK** (`python/cog/`) - Python library for defining model predictors and training in Python
- **Python SDK** (`python/cog/`) - Python library for defining model runners and training in Python
- **Coglet** (`crates/`) - Rust-based prediction server that runs inside containers, with Python bindings via PyO3

Documentation for the CLI and SDK is available by reading ./docs/llms.txt.
@@ -159,7 +159,7 @@ For detailed architecture documentation, see `crates/README.md` and `crates/cogl

The main commands for working on Coglet are:
- `mise run build:coglet` - Build and install coglet wheel for development (macOS, for local Rust/Python tests)
- `mise run build:coglet:wheel:linux-x64` - Build Linux x86_64 wheel (required to test Rust changes in Docker containers via `cog predict`/`cog train`)
- `mise run build:coglet:wheel:linux-x64` - Build Linux x86_64 wheel (required to test Rust changes in Docker containers via `cog run`/`cog train`)
- `mise run test:rust` - Run Rust unit tests
- `mise run lint:rust` - Run clippy linter
- `mise run fmt:rust:fix` - Format code
@@ -207,7 +207,7 @@ The CLI follows a command pattern with subcommands. The main components are:

### Python SDK Architecture
- `python/cog/` - Core SDK
- `base_predictor.py` - Base class for model predictors
- `predictor.py` - Base classes for model runners (`BaseRunner`) and predictors (`BasePredictor`)
- `types.py` - Input/output type definitions
- `server/` - HTTP/queue server implementation
- `command/` - Runner implementations for predict/train
@@ -281,7 +281,7 @@ Tools disabled in CI are listed in `MISE_DISABLE_TOOLS` in `ci.yaml`.
- `cog.yaml` - User-facing model configuration
- `pkg/config/config.go` - Go code for parsing and validating `cog.yaml`
- `pkg/config/data/config_schema_v1.0.json` - JSON schema for `cog.yaml`
- `python/cog/base_predictor.py` - Predictor interface
- `python/cog/predictor.py` - Runner/Predictor interface (BaseRunner, BasePredictor)
- `crates/Cargo.toml` - Rust workspace configuration (version must match VERSION.txt)
- `crates/README.md` - Coglet architecture overview
- `mise.toml` - Task definitions for development workflow
24 changes: 12 additions & 12 deletions CONTRIBUTING.md
@@ -121,7 +121,7 @@ There are a few concepts used throughout Cog that might be helpful to understand
- **Model**: A user's machine learning model, consisting of code and weights.
- **Output**: Output from a **prediction**, as an arbitrarily complex JSON object.
- **Prediction**: A single run of the model that takes **input** and produces **output**.
- **Predictor**: Defines how Cog runs **predictions** on a **model**.
- **Runner**: Defines how Cog runs **predictions** on a **model**.
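
The **Runner** contract above can be sketched in a few lines. This is an illustration only: `BaseRunner` is stubbed so the snippet runs without Cog installed; in real models it comes from the `cog` package.

```python
class BaseRunner:
    """Stub standing in for cog.BaseRunner (illustration only)."""

    def setup(self):
        # Optional hook: load model weights once, before any predictions
        pass


class Runner(BaseRunner):
    def setup(self):
        # Expensive work happens once here, not per prediction
        self.prefix = "hello "

    def run(self, s: str) -> str:
        # One prediction: typed input in, output out
        return self.prefix + s


runner = Runner()
runner.setup()
print(runner.run("world"))  # → hello world
```

The split between `setup()` and `run()` is the point: `setup()` runs once per container, `run()` once per prediction.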

## Running tests

@@ -176,21 +176,21 @@ When adding new functionality, add integration tests in `integration-tests/tests
Example test structure:

```txtar
# Test string predictor
# Test string runner
cog build -t $TEST_IMAGE
cog predict $TEST_IMAGE -i s=world
cog run $TEST_IMAGE -i s=world
stdout 'hello world'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"
run: "run.py:Runner"

-- predict.py --
from cog import BasePredictor
-- run.py --
from cog import BaseRunner

class Predictor(BasePredictor):
def predict(self, s: str) -> str:
class Runner(BaseRunner):
def run(self, s: str) -> str:
return "hello " + s
```

@@ -216,7 +216,7 @@

```txtar
retry-curl POST /predictions '{"input":{"s":"test"}}' 30 1s
stdout '"output":"hello test"'
```

**Example: Testing predictor with subprocess in setup**
**Example: Testing runner with subprocess in setup**

```txtar
cog build -t $TEST_IMAGE
cog serve
retry-curl POST /predictions '{"input":{"s":"test"}}' 30 1s
stdout '"output":"hello test"'

-- predict.py --
class Predictor(BasePredictor):
-- run.py --
import subprocess

from cog import BaseRunner

class Runner(BaseRunner):
def setup(self):
self.process = subprocess.Popen(["./background.sh"])

def predict(self, s: str) -> str:
def run(self, s: str) -> str:
return "hello " + s
```

14 changes: 7 additions & 7 deletions README.md
@@ -28,22 +28,22 @@

```yaml
build:
- "libglib2.0-0"
python_version: "3.13"
python_requirements: requirements.txt
predict: "predict.py:Predictor"
run: "run.py:Runner"
```

Define how predictions are run on your model with `predict.py`:
Define how predictions are run on your model with `run.py`:

```python
from cog import BasePredictor, Input, Path
from cog import BaseRunner, Input, Path
import torch

class Predictor(BasePredictor):
class Runner(BaseRunner):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
self.model = torch.load("./weights.pth")

# The arguments and types the model takes as input
def predict(self,
def run(self,
image: Path = Input(description="Grayscale input image")
) -> Path:
"""Run a single prediction on the model"""
```

@@ -57,7 +57,7 @@ In the above we accept a path to the image as an input, and return a path to our
Now, you can run predictions on this model:

```console
$ cog predict -i image=@input.jpg
$ cog run -i image=@input.jpg
--> Building Docker image...
--> Running Prediction...
--> Output written to output.jpg
```

@@ -180,7 +180,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for how to set up a development environment
- [Take a look at some examples of using Cog](https://github.com/replicate/cog-examples)
- [Deploy models with Cog](docs/deploy.md)
- [`cog.yaml` reference](docs/yaml.md) to learn how to define your model's environment
- [Prediction interface reference](docs/python.md) to learn how the `Predictor` interface works
- [Prediction interface reference](docs/python.md) to learn how the `Runner` interface works
- [Training interface reference](docs/training.md) to learn how to add a fine-tuning API to your model
- [HTTP API reference](docs/http.md) to learn how to use the HTTP API that models serve

108 changes: 90 additions & 18 deletions crates/coglet-python/src/predictor.rs
@@ -299,6 +299,8 @@ pub struct PythonPredictor {
instance: PyObject,
/// The predictor's kind (class or standalone function) and method execution types
kind: PredictorKind,
/// The name of the predict method ("run" or "predict")
predict_method_name: &'static str,
/// Whether the setup() method is an async def
setup_is_async: bool,
}
@@ -324,7 +326,7 @@ impl PythonPredictor {
.call_method1("isfunction", (instance.bind(py),))?
.extract()?;

let kind = if is_function {
let (kind, predict_method_name) = if is_function {
// Standalone function - detect its async nature
let (is_async, is_async_gen) = Self::detect_async(py, &instance, "")?;
let predict_kind = if is_async_gen {
@@ -337,18 +339,69 @@
tracing::info!("Detected sync train()");
PredictKind::Sync
};
PredictorKind::StandaloneFunction(predict_kind)
// Standalone functions don't use predict_method_name (dispatched
// directly via PredictorKind::StandaloneFunction match arms).
(PredictorKind::StandaloneFunction(predict_kind), "")
} else {
// Class instance - detect predict() and train() methods
let (is_async, is_async_gen) = Self::detect_async(py, &instance, "predict")?;
// Class instance - detect run() vs predict() method
// Walk MRO to support multi-level inheritance, skipping base stub classes.
// Note: BasePredictor is an alias for BaseRunner in Python, so these
// resolve to the same object. We check both for clarity.
let instance_bound = instance.bind(py);
let cls = instance_bound.getattr("__class__")?;
let mro = cls.getattr("__mro__")?;
let base_predictor = py.import("cog.predictor")?.getattr("BasePredictor")?;
let base_runner = py.import("cog.predictor")?.getattr("BaseRunner")?;
let object_type = py.eval(c"object", None, None)?;

let mut has_run = false;
let mut has_predict = false;

for item in mro.try_iter()? {
let klass: Bound<'_, PyAny> = item?;
// Skip base stubs and object
if klass.eq(&base_predictor)?
|| klass.eq(&base_runner)?
|| klass.eq(&object_type)?
{
continue;
}
let class_dict = klass.getattr("__dict__")?;
if !has_run && class_dict.contains("run")? {
has_run = true;
}
if !has_predict && class_dict.contains("predict")? {
has_predict = true;
}
if has_run && has_predict {
break;
}
}

let predict_method_name = match (has_run, has_predict) {
(true, true) => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"define either run() or predict(), not both",
));
}
(true, false) => "run",
(false, true) => "predict",
(false, false) => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"run() or predict() method not found",
));
}
};

let (is_async, is_async_gen) = Self::detect_async(py, &instance, predict_method_name)?;
let predict_kind = if is_async_gen {
tracing::info!("Detected async generator predict()");
tracing::info!("Detected async generator {}()", predict_method_name);
PredictKind::AsyncGen
} else if is_async {
tracing::info!("Detected async predict()");
tracing::info!("Detected async {}()", predict_method_name);
PredictKind::Async
} else {
tracing::info!("Detected sync predict()");
tracing::info!("Detected sync {}()", predict_method_name);
PredictKind::Sync
};

@@ -366,10 +419,13 @@ impl PythonPredictor {
TrainKind::None
};

PredictorKind::Class {
predict: predict_kind,
train: train_kind,
}
(
PredictorKind::Class {
predict: predict_kind,
train: train_kind,
},
predict_method_name,
)
};

// Detect if setup() is async
@@ -386,6 +442,7 @@
let predictor = Self {
instance,
kind,
predict_method_name,
setup_is_async,
};

@@ -396,7 +453,11 @@
if is_function {
Self::unwrap_field_info_defaults(py, &predictor.instance, "")?;
} else {
Self::unwrap_field_info_defaults(py, &predictor.instance, "predict")?;
Self::unwrap_field_info_defaults(
py,
&predictor.instance,
predictor.predict_method_name,
)?;
if matches!(predictor.kind, PredictorKind::Class { train, .. } if train != TrainKind::None)
{
Self::unwrap_field_info_defaults(py, &predictor.instance, "train")?;
@@ -587,7 +648,7 @@ impl PythonPredictor {
pub fn predict_func<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let instance = self.instance.bind(py);
match &self.kind {
PredictorKind::Class { .. } => instance.getattr("predict"),
PredictorKind::Class { .. } => instance.getattr(self.predict_method_name),
PredictorKind::StandaloneFunction(_) => Ok(instance.clone()),
}
}
@@ -607,7 +668,7 @@
pub fn predict_raw(&self, py: Python<'_>, input: &Bound<'_, PyDict>) -> PyResult<PyObject> {
let (method_name, is_async) = match &self.kind {
PredictorKind::Class { predict, .. } => (
"predict",
self.predict_method_name,
matches!(predict, PredictKind::Async | PredictKind::AsyncGen),
),
PredictorKind::StandaloneFunction(predict_kind) => (
@@ -693,7 +754,10 @@

// PreparedInput cleans up temp files on drop (RAII)
let func = self.predict_func(py).map_err(|e| {
PredictionError::Failed(format!("Failed to get predict function: {}", e))
PredictionError::Failed(format!(
"Failed to get {} function: {}",
self.predict_method_name, e
))
})?;
let prepared = input::prepare_input(py, raw_input_dict, &func)
.map_err(|e| PredictionError::InvalidInput(format_validation_error(py, &e)))?;
@@ -969,7 +1033,10 @@
})?;

let func = self.predict_func(py).map_err(|e| {
PredictionError::Failed(format!("Failed to get predict function: {}", e))
PredictionError::Failed(format!(
"Failed to get {} function: {}",
self.predict_method_name, e
))
})?;
let prepared = input::prepare_input(py, raw_input_dict, &func)
.map_err(|e| PredictionError::InvalidInput(format_validation_error(py, &e)))?;
@@ -978,8 +1045,13 @@
// Call predict - returns coroutine
let instance = self.instance.bind(py);
let coro = instance
.call_method("predict", (), Some(&input_dict))
.map_err(|e| PredictionError::Failed(format!("Failed to call predict: {}", e)))?;
.call_method(self.predict_method_name, (), Some(&input_dict))
.map_err(|e| {
PredictionError::Failed(format!(
"Failed to call {}: {}",
self.predict_method_name, e
))
})?;

// For async generators, wrap to collect all values
let is_async_gen = matches!(
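
For illustration, the MRO walk in `predictor.rs` above can be mirrored in Python. This is a hypothetical sketch of the detection logic, not SDK code; `BaseRunner` is stubbed here, with `BasePredictor` as an alias, matching the PR's collapse of the two classes.

```python
class BaseRunner:
    """Stub for cog.predictor.BaseRunner; BasePredictor aliases it."""


BasePredictor = BaseRunner  # alias, as in the SDK after the collapse commit


def detect_method_name(instance) -> str:
    # Walk the MRO, skipping the base stubs and object, so a run()/predict()
    # defined on any intermediate class is still found.
    has_run = has_predict = False
    for klass in type(instance).__mro__:
        if klass in (BaseRunner, BasePredictor, object):
            continue
        has_run = has_run or "run" in klass.__dict__
        has_predict = has_predict or "predict" in klass.__dict__
    if has_run and has_predict:
        raise ValueError("define either run() or predict(), not both")
    if has_run:
        return "run"
    if has_predict:
        return "predict"
    raise ValueError("run() or predict() method not found")


class Mid(BaseRunner):
    def run(self, s: str) -> str:
        return "hello " + s


class Leaf(Mid):  # run() lives on an intermediate class, found via the MRO
    pass


print(detect_method_name(Leaf()))  # → run
```

Checking each class's own `__dict__` rather than using `hasattr` matters: `hasattr` would also see methods inherited from the base stubs, which is exactly what the skip list exists to avoid.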