Quality-aware circuit breaker for production ML pipelines.
Standard circuit breakers trip on HTTP errors and latency. ML pipelines fail softly: the service returns 200 and latency is fine, but model quality has degraded (low-confidence outputs, bad embeddings, hallucinated content). Breakers that watch only errors and latency never see this. ml-breaker lets engineers define quality-based trip conditions using the model's own outputs.
pip install ml-breaker

from ml_breaker import circuit_breaker, CircuitBreakerOpen
from ml_breaker._conditions import consecutive_failures

@circuit_breaker(
    trip_on=consecutive_failures(5),  # open after 5 errors in a row
    fallback=lightweight_model,       # called with the same args while the breaker is open
    half_open_after=30,               # seconds spent OPEN before probing recovery
)
def call_reranker(query, candidates):
    return reranker_model(query, candidates)

There are two ways to emit a quality score:
# Option A: report_quality(), for when validation logic is complex
@circuit_breaker(trip_on=quality_below(threshold=0.7, window=20), ...)
def call_reranker(query, candidates):
    result = reranker_model(query, candidates)
    report_quality(result.confidence)  # emits to the context side-channel
    return result                      # return type unchanged

# Option B: score_fn, for when quality is directly on the return value
@circuit_breaker(
    trip_on=quality_below(threshold=0.7, window=20, score_fn=lambda r: r.confidence),
    ...
)
def call_reranker(query, candidates):
    return reranker_model(query, candidates)

If both are present, the value passed to report_quality() takes precedence over score_fn.
Set require_quality_signal=True on the decorator to enforce that every call emits a quality score. This is useful when instrumentation must be strict: a call that emits no score raises immediately instead of passing silently.
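The README does not say how the report_quality() side-channel is implemented. A minimal sketch, assuming a `contextvars`-based slot, shows one way a decorator could collect a score without changing the function's return type, apply the "report_quality wins" rule, and enforce require_quality_signal. `with_quality` is a hypothetical helper and this `report_quality` is a stand-in, not ml-breaker's function; unlike the real decorator, the sketch returns `(result, score)` so the resolved score is visible. Keeping the last value when report_quality() is called more than once is also an assumption.

```python
import contextvars
import functools
from typing import Any, Callable, Optional

# Hypothetical side-channel: each guarded call gets its own list, stored in a
# ContextVar so concurrent threads or asyncio tasks never see each other's scores.
_quality_slot = contextvars.ContextVar("_quality_slot", default=None)


def report_quality(score: float) -> None:
    """Record a quality score for the guarded call currently running (sketch)."""
    slot = _quality_slot.get()
    if slot is None:
        raise RuntimeError("report_quality() called outside a guarded function")
    slot.append(score)


def with_quality(fn: Callable, score_fn: Optional[Callable] = None,
                 require_quality_signal: bool = False) -> Callable:
    """Hypothetical wrapper: each call returns (result, resolved_score)."""
    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any):
        slot: list = []
        token = _quality_slot.set(slot)
        try:
            result = fn(*args, **kwargs)
        finally:
            _quality_slot.reset(token)    # never leak the slot to the caller
        if slot:
            score = slot[-1]              # report_quality() wins (last report kept)
        elif score_fn is not None:
            score = score_fn(result)      # otherwise score the return value
        else:
            score = None
        if score is None and require_quality_signal:
            raise RuntimeError(f"{fn.__name__} emitted no quality score")
        return result, score
    return wrapper
```

Because the slot lives in a context variable rather than in the return value, the guarded function keeps its original return type, which matches the "return type unchanged" note on Option A.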
CLOSED ──(condition trips)──► OPEN ──(half_open_after s)──► HALF_OPEN
  ▲                                                             │
  └──(recovery_threshold probes pass)───────────────────────────┤
                                                                │
                              OPEN ◄──(any probe fails)─────────┘
| State | Behavior |
|---|---|
| CLOSED | Normal operation. Conditions are evaluated on every call. |
| OPEN | All calls are rejected immediately (or routed to the fallback). No traffic reaches the model. |
| HALF_OPEN | A limited number of probe calls are allowed through to test recovery. |
recovery_threshold defaults to 2. A single successful probe is too fragile a signal for ML workloads, where one lucky call says little about whether the model has actually recovered.
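The state diagram and table above can be sketched as a small single-threaded state machine. This is an illustration under stated assumptions, not ml-breaker's implementation: `BreakerSketch`, `BreakerOpenError`, and the `consecutive` trip helper are hypothetical names, the trip condition only sees a pass/fail flag rather than quality scores, and counting a failed probe as another trip is an assumption. The injectable `clock` exists only so the cool-down can be tested without sleeping.

```python
import time
from enum import Enum
from typing import Any, Callable


class State(Enum):
    CLOSED = 0
    OPEN = 1
    HALF_OPEN = 2


class BreakerOpenError(Exception):
    """Stand-in for ml_breaker's CircuitBreakerOpen in this sketch."""


_NO_FALLBACK = object()  # sentinel: distinguishes "no fallback" from fallback=None


def consecutive(n: int) -> Callable[[bool], bool]:
    """Hypothetical trip condition: True once n calls in a row have failed."""
    streak = 0

    def trip(ok: bool) -> bool:
        nonlocal streak
        streak = 0 if ok else streak + 1
        return streak >= n

    return trip


class BreakerSketch:
    """Single-threaded CLOSED -> OPEN -> HALF_OPEN state machine."""

    def __init__(self, trip: Callable[[bool], bool], half_open_after: float = 30.0,
                 recovery_threshold: int = 2, fallback: Any = _NO_FALLBACK,
                 clock: Callable[[], float] = time.monotonic):
        self.trip, self.fallback, self.clock = trip, fallback, clock
        self.half_open_after = half_open_after
        self.recovery_threshold = recovery_threshold
        self.state = State.CLOSED
        self.opened_at = 0.0
        self.probe_successes = 0
        self.trip_count = 0

    def call(self, fn: Callable, *args: Any, **kwargs: Any) -> Any:
        if self.state is State.OPEN:
            if self.clock() - self.opened_at < self.half_open_after:
                return self._reject(*args, **kwargs)  # no traffic reaches the model
            self.state = State.HALF_OPEN               # cool-down over: start probing
            self.probe_successes = 0
        try:
            result = fn(*args, **kwargs)
        except Exception:
            self._record(ok=False)
            raise                                      # errors still reach the caller
        self._record(ok=True)
        return result

    def _record(self, ok: bool) -> None:
        tripped = self.trip(ok)            # always feed the condition its observation
        if self.state is State.HALF_OPEN:
            if not ok:
                self._open()               # any failed probe re-opens immediately
                return
            self.probe_successes += 1
            if self.probe_successes >= self.recovery_threshold:
                self.state = State.CLOSED  # enough probes passed: recover
        elif tripped:
            self._open()

    def _open(self) -> None:
        self.state = State.OPEN
        self.opened_at = self.clock()
        self.trip_count += 1

    def _reject(self, *args: Any, **kwargs: Any) -> Any:
        if self.fallback is _NO_FALLBACK:
            raise BreakerOpenError(f"breaker is open (trip_count={self.trip_count})")
        if callable(self.fallback):
            return self.fallback(*args, **kwargs)  # callable: same args as fn
        return self.fallback                       # static value: returned as-is
```

Because calls are sequential here, at most recovery_threshold probes pass through HALF_OPEN before the breaker either closes or re-opens; a thread-safe version would also need a lock and an explicit cap on concurrent probes.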
| Condition | Signature | Trips when |
|---|---|---|
| quality_below | (threshold, window=20, score_fn=None) | the rolling mean of quality scores over a full window drops below threshold |
| latency_above | (threshold_ms, window=20) | every latency in a full window exceeds threshold_ms |
| error_rate_above | (rate, window=20) | the error rate over a full window exceeds rate |
| consecutive_failures | (n) | n consecutive errors occur (no window) |
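window is a number of calls, not a time span: window=20 means the most recent 20 calls. A windowed condition cannot trip until it has seen a full window.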
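The "full window" semantics in the table can be sketched with one bounded `collections.deque` per condition. This is a minimal illustration, not ml-breaker's code: `QualityBelow` and `LatencyAbove` are hypothetical classes, and `observe()` returning True is assumed to mean "trip".

```python
from collections import deque


class QualityBelow:
    """Trips when the mean score over a full window drops below threshold."""

    def __init__(self, threshold: float, window: int = 20):
        self.threshold = threshold
        self.scores = deque(maxlen=window)  # keeps only the last `window` calls

    def observe(self, score: float) -> bool:
        self.scores.append(score)
        if len(self.scores) < self.scores.maxlen:
            return False                    # window not full yet: never trip
        return sum(self.scores) / len(self.scores) < self.threshold


class LatencyAbove:
    """Trips only when every latency in a full window exceeds threshold_ms."""

    def __init__(self, threshold_ms: float, window: int = 20):
        self.threshold_ms = threshold_ms
        self.latencies = deque(maxlen=window)

    def observe(self, latency_ms: float) -> bool:
        self.latencies.append(latency_ms)
        full = len(self.latencies) == self.latencies.maxlen
        return full and all(x > self.threshold_ms for x in self.latencies)
```

Requiring every latency in the window to exceed the threshold means a single fast call keeps latency_above from tripping: isolated spikes are tolerated, sustained slowness is not.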
from ml_breaker import any_of, all_of

# Trips when either condition trips
trip_on=any_of(quality_below(0.7, window=20), latency_above(2000, window=10))
# Trips only when both conditions trip
trip_on=all_of(error_rate_above(0.3, window=50), consecutive_failures(3))

Each condition in a composition keeps its own state and window; any_of and all_of do not merge them into a shared buffer.
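Independent state per child can be sketched as follows. This is an assumption-laden illustration rather than the library's code: `AnyOf`, `AllOf`, `ErrorRateAbove`, and `ConsecutiveFailures` are hypothetical classes, and every child here observes the same pass/fail outcome, whereas the real quality_below and latency_above would observe different signals.

```python
from collections import deque


class ErrorRateAbove:
    """Trips when the error rate over a full window exceeds rate."""

    def __init__(self, rate: float, window: int = 20):
        self.rate = rate
        self.outcomes = deque(maxlen=window)  # this condition's own buffer

    def observe(self, ok: bool) -> bool:
        self.outcomes.append(ok)
        if len(self.outcomes) < self.outcomes.maxlen:
            return False
        errors = sum(1 for o in self.outcomes if not o)
        return errors / len(self.outcomes) > self.rate


class ConsecutiveFailures:
    """Trips after n failures in a row; keeps a streak counter, no window."""

    def __init__(self, n: int):
        self.n = n
        self.streak = 0

    def observe(self, ok: bool) -> bool:
        self.streak = 0 if ok else self.streak + 1
        return self.streak >= self.n


class AnyOf:
    """Trips when at least one child condition trips."""

    def __init__(self, *conditions):
        self.conditions = conditions

    def _observe_all(self, ok: bool) -> list:
        # Build a list first: any()/all() over a generator would short-circuit
        # and starve later children of observations.
        return [c.observe(ok) for c in self.conditions]

    def observe(self, ok: bool) -> bool:
        return any(self._observe_all(ok))


class AllOf(AnyOf):
    """Trips only when every child condition trips on the same call."""

    def observe(self, ok: bool) -> bool:
        return all(self._observe_all(ok))
```

In this sketch, three failures satisfy the consecutive-failure child, but an all-of composition with a four-call error-rate window still waits for that window to fill before tripping.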
The fallback argument supports three forms:
# 1. Callable: called with the same args as the guarded function
@circuit_breaker(trip_on=..., fallback=lightweight_model)
def call_reranker(query, candidates):
    return reranker_model(query, candidates)

# 2. Static value: returned as-is when the breaker is open
@circuit_breaker(trip_on=..., fallback=[])
def call_reranker(query, candidates):
    return reranker_model(query, candidates)

# 3. No fallback: raises CircuitBreakerOpen
@circuit_breaker(trip_on=...)
def call_reranker(query, candidates):
    return reranker_model(query, candidates)
try:
    result = call_reranker(query, candidates)
except CircuitBreakerOpen as e:
    print(e.name)        # breaker name
    print(e.state)       # State.OPEN
    print(e.trip_count)  # number of times this breaker has tripped

cb = my_fn.breaker               # attached to the decorated function
cb = CircuitBreaker.get("name")  # global registry lookup
cb.state                         # State.CLOSED | State.OPEN | State.HALF_OPEN
cb.trip_count                    # int
cb.reset()                       # force back to CLOSED; useful in tests
CircuitBreaker.all()             # dict[str, CircuitBreaker]

Use name= to share a single breaker across multiple functions in the same process:
@circuit_breaker(name="reranker", trip_on=consecutive_failures(5), ...)
def call_reranker_v1(query, candidates): ...
@circuit_breaker(name="reranker", trip_on=consecutive_failures(5), ...)
def call_reranker_v2(query, candidates): ...
# Both functions reference the same CircuitBreaker instance
assert call_reranker_v1.breaker is call_reranker_v2.breaker

pip install ml-breaker[metrics]

from ml_breaker.metrics import PrometheusMetrics
PrometheusMetrics.register()  # call once at startup

| Metric | Type | Description |
|---|---|---|
| ml_breaker_state | Gauge | Current state per breaker (0=CLOSED, 1=OPEN, 2=HALF_OPEN) |
| ml_breaker_trip_total | Counter | Total number of trips per breaker |
| ml_breaker_call_total | Counter | Total calls, labeled by outcome (success/failure/rejected) |
| ml_breaker_quality_score | Histogram | Quality scores emitted via report_quality or score_fn |
| ml_breaker_latency_ms | Histogram | Call latency in milliseconds per breaker |
For setups without a Prometheus dependency, use the on_state_change callback instead:
import logging

logger = logging.getLogger(__name__)

@circuit_breaker(
    trip_on=...,
    on_state_change=lambda breaker, old_state, new_state: logger.warning(
        "breaker %s: %s -> %s", breaker.name, old_state, new_state
    ),
)
def call_reranker(query, candidates):
    return reranker_model(query, candidates)

Roadmap:
- v0.2: async support, Redis-backed distributed state, built-in quality metrics for embeddings and softmax outputs
PRs welcome. Open an issue first for anything non-trivial: github.com/ritabanb/ml-breaker/issues.