Embedding-based response drift detection for LangChain agents.
Detects when an LLM starts answering outside its intended domain (a legal assistant drifting into cooking advice, a medical chatbot wandering into finance) without ground-truth labels or a separate classifier.
- How it works
- Installation
- Quick start
- Integration patterns
- LangGraph guardrail
- Async support
- Alert sinks
- Multi-topic corpora
- Domain auditing
- Building a corpus with FPS
- Distribution-level detection
- Visualisation
- Persisting a corpus
- DriftResult reference
- Development
- Build a reference corpus from representative on-topic texts.
- Embed each LLM response with the same model.
- Compare using two complementary signals:
  - Centroid distance: how close is the response to the centre of the corpus (or its nearest cluster)?
  - Nearest-neighbour distance: is the response close to at least one reference text?
- Flag drift when both signals agree the response is far from the reference domain.
Using both signals reduces false positives: a paraphrase that sits slightly off the centroid is rescued when it's still close to a known reference text.
The threshold for each signal is adaptive: the 5th percentile of within-corpus similarity scores, so ~95% of reference texts clear it with no manual tuning.
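The two signals and the percentile-based threshold can be sketched in plain Python. This is a minimal illustration rather than driftguard's actual implementation: the helper names (fit_thresholds, is_drift), the leave-one-out nearest-neighbour scoring, and the toy 2-D vectors standing in for real embeddings are all assumptions.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def mean_vector(vectors):
    """Component-wise mean of a list of vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def percentile(values, q):
    """Linear-interpolation percentile, q in [0, 100]."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)

def fit_thresholds(reference, q=5.0):
    """Hypothetical helper: derive both thresholds from the corpus itself."""
    centroid = mean_vector(reference)
    centroid_sims = [cosine(v, centroid) for v in reference]
    # Assumption: each text is scored against its nearest *other* reference text.
    nn_sims = [
        max(cosine(v, other) for j, other in enumerate(reference) if j != i)
        for i, v in enumerate(reference)
    ]
    return centroid, percentile(centroid_sims, q), percentile(nn_sims, q)

def is_drift(vec, reference, centroid, centroid_thr, nn_thr):
    """Flag drift only when BOTH signals fall below their thresholds."""
    centroid_sim = cosine(vec, centroid)
    nn_sim = max(cosine(vec, r) for r in reference)
    return centroid_sim < centroid_thr and nn_sim < nn_thr

# Toy 2-D "embeddings": the reference texts all point roughly along the x axis.
reference = [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.95, 0.05]]
centroid, c_thr, nn_thr = fit_thresholds(reference)
on_topic = is_drift([0.85, 0.15], reference, centroid, c_thr, nn_thr)
off_topic = is_drift([0.0, 1.0], reference, centroid, c_thr, nn_thr)
print(on_topic, off_topic)
```

Requiring both signals to agree is what lets a response that sits off the centroid but close to a known reference text pass.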
git clone https://github.com/vinerya/driftguard.git
cd driftguard
pip install -r requirements.txt
pip install -e .

Requires Python ≥ 3.9. The only runtime dependencies are langchain-core and numpy.
Optional extras:
pip install -e ".[viz]" # matplotlib + scikit-learn for corpus.plot()
pip install langgraph # LangGraph guardrail nodes

from driftguard import ReferenceCorpus, DriftDetector
from langchain_openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
# 1. Build the reference corpus from representative on-topic texts
corpus = ReferenceCorpus(embeddings_model=embeddings)
corpus.add_texts([
"tort law", "contract formation", "negligence standard",
"criminal intent", "due process rights",
])
# 2. Create the detector
detector = DriftDetector(corpus=corpus)
# 3. Check a response
result = detector.check("habeas corpus")
print(result.is_drift) # False (on-topic)
print(result.centroid_similarity) # e.g. 0.91
print(result.max_reference_similarity) # e.g. 0.95
print(result.threshold) # e.g. 0.87
result = detector.check("best pasta recipe")
print(result.is_drift) # True (off-topic)

Attach DriftCallbackHandler to any LangChain LLM or chat model. It runs on every response without interrupting the pipeline; use it for monitoring, logging, or metrics.
from langchain_openai import ChatOpenAI
from driftguard import DriftCallbackHandler, AlertManager
alerts = AlertManager(sinks=["log"])
handler = DriftCallbackHandler(detector=detector, alerts=alerts)
llm = ChatOpenAI(callbacks=[handler])
response = llm.invoke("What is the recipe for tiramisu?")
# Drift is logged as a WARNING; the response still returns normally.
print(handler.history[-1].is_drift) # True

Insert a DriftRunnable guard as a step in a LangChain chain. It raises DriftError on drift and passes the text through unchanged otherwise.
from driftguard import DriftRunnable, DriftError
from langchain_core.output_parsers import StrOutputParser
drift = DriftRunnable(detector=detector)
chain = llm | StrOutputParser() | drift.as_guard()
try:
    result = chain.invoke("What is the recipe for tiramisu?")
except DriftError as e:
    print(f"Blocked: centroid_sim={e.result.centroid_similarity:.3f} "
          f"< threshold={e.result.threshold:.3f}")

The passthrough variant annotates the chain output with drift metadata without halting. It is useful when you want to observe drift but still let the response through to the user.
chain = llm | StrOutputParser() | drift.as_passthrough()
output = chain.invoke("habeas corpus")
# {"output": "Habeas corpus is a legal right...", "drift": DriftResult(...)}
print(output["drift"].is_drift) # False

driftguard ships a first-class LangGraph integration. The node and routing helpers are plain callables that match LangGraph's expected signatures. The library itself never imports LangGraph, so the module loads fine even if LangGraph isn't installed.
from langgraph.graph import StateGraph
from typing import Any
from typing_extensions import TypedDict
from driftguard.langgraph import drift_node, route_on_drift
class AgentState(TypedDict):
    query: str
    response: str
    drift: Any # holds the DriftResult written by the drift node
graph = StateGraph(AgentState)
graph.add_node("llm", call_llm) # writes state["response"]
graph.add_node("drift_check", drift_node(detector)) # reads "response", writes "drift"
graph.add_node("fallback", handle_fallback)
graph.add_node("respond", finalize)
graph.set_entry_point("llm")
graph.add_edge("llm", "drift_check")
graph.add_conditional_edges(
    "drift_check",
    route_on_drift, # returns "drift" or "ok"
    {"drift": "fallback", "ok": "respond"},
)
app = graph.compile()

Custom state key: if your LLM node writes to a key other than "response":
graph.add_node("drift_check", drift_node(detector, text_key="output"))

Async graphs: swap drift_node for adrift_node:
from driftguard.langgraph import adrift_node
graph.add_node("drift_check", adrift_node(detector))

Custom route labels: use make_route_on_drift when your edge map uses different names:
from driftguard.langgraph import make_route_on_drift
router = make_route_on_drift(on_drift="blocked", on_ok="continue")
graph.add_conditional_edges(
    "drift_check", router, {"blocked": "fallback", "continue": "respond"}
)

Every public method has an async counterpart:
await corpus.aadd_texts(["tort law", "negligence"])
result = await detector.acheck("contract formation")

AsyncDriftCallbackHandler mirrors DriftCallbackHandler for async LangChain pipelines.
AlertManager dispatches drift alerts to one or more sinks simultaneously:
from driftguard import AlertManager
alerts = AlertManager(sinks=[
    "log", # WARNING via Python logging
    "https://your-service.example/webhook", # POST JSON payload
    lambda result: my_queue.put(result), # arbitrary sync or async callable
])

Pass an AlertManager instance to DriftCallbackHandler, DriftRunnable, or the LangGraph nodes; all accept one via the alerts argument.
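Fanning one alert out to a mix of sync and async callables can be sketched as follows. This is an illustration of the general pattern, not AlertManager's actual code: the dispatch function, the sink names, and the dict used as a stand-in for a DriftResult are hypothetical, and the "log" and webhook sink types are not modelled.

```python
import asyncio
import inspect

async def _call(sink, result):
    """Invoke one sink; await the outcome if the sink turned out to be async."""
    outcome = sink(result)
    if inspect.isawaitable(outcome):
        await outcome

async def dispatch(result, sinks):
    """Send `result` to every sink; a failing sink does not block the others."""
    outcomes = await asyncio.gather(
        *(_call(sink, result) for sink in sinks),
        return_exceptions=True,  # collect failures instead of raising the first one
    )
    return [o for o in outcomes if isinstance(o, Exception)]

received = []

def sync_sink(result):
    received.append(("sync", result))

async def async_sink(result):
    received.append(("async", result))

def broken_sink(result):
    raise RuntimeError("sink unavailable")

# A plain dict stands in for a DriftResult here.
errors = asyncio.run(dispatch({"is_drift": True}, [sync_sink, async_sink, broken_sink]))
print(len(received), "delivered,", len(errors), "failed")
```

Collecting exceptions rather than raising them keeps one unreachable sink from suppressing the remaining alerts.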
When your reference corpus spans several distinct topics, a single global centroid produces false positives for texts that are on-topic but far from the average. Set n_clusters to partition the corpus into groups; each query is then compared to its nearest cluster rather than the global centre.
corpus = ReferenceCorpus(embeddings_model=embeddings, n_clusters=2)
corpus.add_texts([
    # Legal cluster
    "tort law", "contract formation", "negligence",
    # Medical cluster
    "malpractice", "diagnosis", "clinical trial",
])
detector = DriftDetector(corpus=corpus)
detector.check("habeas corpus").is_drift # False (routes to legal cluster)
detector.check("prognosis").is_drift # False (routes to medical cluster)
detector.check("pasta recipe").is_drift # True (far from both clusters)

Clustering uses numpy k-means internally with no extra dependencies.
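To see why the nearest-cluster comparison helps, here is a small pure-Python sketch. It is not the library's numpy k-means: the kmeans function, its deterministic initialisation, and the toy 2-D vectors are assumptions made for illustration.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def mean_vector(vectors):
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def kmeans(vectors, k, iterations=10):
    """Tiny Lloyd's k-means that assigns points by cosine similarity.

    Initialisation is deterministic: start from the first vector, then
    repeatedly add the vector least similar to the centroids chosen so far.
    """
    centroids = [vectors[0]]
    while len(centroids) < k:
        centroids.append(
            min(vectors, key=lambda v: max(cosine(v, c) for c in centroids))
        )
    for _ in range(iterations):
        groups = [[] for _ in range(k)]
        for v in vectors:
            best = max(range(k), key=lambda i: cosine(v, centroids[i]))
            groups[best].append(v)
        # Keep the previous centroid if a cluster ends up empty.
        centroids = [mean_vector(g) if g else centroids[i] for i, g in enumerate(groups)]
    return centroids

# Two well-separated toy topics.
legal = [[1.0, 0.1], [0.9, 0.0], [1.0, 0.0]]
medical = [[0.0, 1.0], [0.1, 0.9], [0.0, 0.9]]
corpus_vectors = legal + medical

global_centroid = mean_vector(corpus_vectors)
cluster_centroids = kmeans(corpus_vectors, k=2)

query = [0.95, 0.05]  # clearly "legal"
global_sim = cosine(query, global_centroid)
cluster_sim = max(cosine(query, c) for c in cluster_centroids)
print(f"global={global_sim:.3f} nearest_cluster={cluster_sim:.3f}")
```

The global centroid sits between the two topics, so an on-topic query scores poorly against it while scoring close to 1 against its own cluster.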
The Auditor class runs drift detection over a batch of historical responses and returns a structured report: pass rate, score distribution, flagged outliers. Use it before deployment to validate your corpus, after incidents to understand what went wrong, or in CI to catch domain regressions between prompt versions.
from driftguard import Auditor
auditor = Auditor(detector)
report = auditor.run(production_responses)
print(f"Pass rate: {report.pass_rate:.1%}")
print(f"Drift rate: {report.drift_rate:.1%}")
print(f"Flagged: {report.flagged} / {report.total}")

Export the report for a compliance doc or CI artifact:
report.to_json() # structured JSON string
with open("report.html", "w") as f: # self-contained HTML report
    f.write(report.to_html())

The HTML report includes a summary dashboard, centroid similarity distribution (p5 → p95), and a table of all flagged responses with their scores.
Async: all responses are checked concurrently:
report = await auditor.arun(production_responses)

Detect domain shift between prompt versions, model upgrades, or dataset changes:
comparison = corpus_v1.compare(corpus_v2)
print(f"Centroid shift: {comparison.centroid_shift:.4f}") # cosine distance
print(f"Threshold delta: {comparison.threshold_delta:+.4f}")
print(f"Significant: {comparison.is_significant}") # shift > 0.05

A centroid_shift above 0.05 (configurable via significant_shift_threshold) means the two corpora represent meaningfully different domains, which is worth investigating before swapping one for the other.
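The centroid-shift figure is a cosine distance between the two corpus centroids. A minimal sketch of that calculation, assuming toy 2-D vectors in place of real embeddings and a hypothetical centroid_shift helper (not the library's compare() method):

```python
import math

def mean_vector(vectors):
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def cosine_distance(a, b):
    """1 - cosine similarity; 0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - (dot / norm if norm else 0.0)

def centroid_shift(vectors_a, vectors_b):
    """Hypothetical helper: cosine distance between the two corpus centroids."""
    return cosine_distance(mean_vector(vectors_a), mean_vector(vectors_b))

SIGNIFICANT_SHIFT = 0.05  # the default quoted above

v1 = [[1.0, 0.0], [0.9, 0.1]]
v2_similar = [[1.0, 0.05], [0.9, 0.12]]   # small rewording of the same domain
v2_different = [[0.5, 0.5], [0.4, 0.7]]   # a different domain

small = centroid_shift(v1, v2_similar)
large = centroid_shift(v1, v2_different)
print(f"{small:.4f} significant={small > SIGNIFICANT_SHIFT}")
print(f"{large:.4f} significant={large > SIGNIFICANT_SHIFT}")
```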
Hand-picking reference texts is tedious and easy to get wrong. ReferenceCorpus.from_texts() accepts a large pool of candidates and uses Farthest Point Sampling to automatically select the n most coverage-maximising texts; each new selection is the one farthest (in cosine distance) from all already-chosen texts.
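The selection rule described above can be sketched in a few lines of plain Python. This is an illustrative version of standard farthest point sampling, not the code behind from_texts(): the function name, the choice of the first candidate as the seed, and the toy vectors are assumptions.

```python
import math

def cosine_distance(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - (dot / norm if norm else 0.0)

def farthest_point_sampling(vectors, n, start=0):
    """Return indices of up to n vectors chosen greedily for coverage.

    Each step picks the candidate whose distance to its closest
    already-selected vector is largest.
    """
    if not vectors or n <= 0:
        return []
    n = min(n, len(vectors))
    selected = [start]  # assumption: seed with the first candidate
    # min_dist[i] = distance from candidate i to its closest selected vector
    min_dist = [cosine_distance(v, vectors[start]) for v in vectors]
    while len(selected) < n:
        nxt = max(range(len(vectors)), key=lambda i: min_dist[i])
        selected.append(nxt)
        for i, v in enumerate(vectors):
            min_dist[i] = min(min_dist[i], cosine_distance(v, vectors[nxt]))
    return selected

# Indices 1 and 4 are near-duplicates of 0 and 2, so they are picked last.
candidates = [[1.0, 0.0], [0.99, 0.01], [0.0, 1.0], [0.7, 0.7], [0.01, 0.99]]
picked = farthest_point_sampling(candidates, n=3)
print(picked)
```

Tracking only the distance to the closest selected vector keeps each step linear in the number of candidates.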
# 500 example legal responses; pick the 30 most diverse ones.
corpus = ReferenceCorpus.from_texts(
    candidates=my_500_legal_responses,
    embeddings_model=embeddings,
    n=30,
)

The result is a fully initialised ReferenceCorpus ready for use with DriftDetector. An async variant is also available:
corpus = await ReferenceCorpus.afrom_texts(candidates, embeddings_model=embeddings, n=30)

Per-response checks are sensitive to one-off anomalies. WindowedDriftDetector accumulates a sliding window of responses and checks whether the window's embedding distribution has shifted from the reference. Two signals can trigger drift:
- Centroid shift: the window's mean embedding has moved away from the reference.
- Drift fraction: more than drift_fraction_threshold (default 30%) of recent responses are individually off-topic.
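The two window-level signals can be sketched in plain Python as below. This is a simplified stand-in, not WindowedDriftDetector itself: the SlidingWindowCheck class, its tuple return value, and the assumption that the caller supplies each response's embedding and per-response drift flag are all hypothetical.

```python
import math
from collections import deque

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

class SlidingWindowCheck:
    """Hypothetical sketch of a window-level drift check."""

    def __init__(self, reference_centroid, centroid_threshold,
                 window_size=20, drift_fraction_threshold=0.3):
        self.reference_centroid = reference_centroid
        self.centroid_threshold = centroid_threshold
        self.drift_fraction_threshold = drift_fraction_threshold
        # Each entry is (embedding, individually_flagged); old entries fall off.
        self.window = deque(maxlen=window_size)

    def update(self, vector, individually_flagged):
        self.window.append((vector, individually_flagged))
        if len(self.window) < self.window.maxlen:
            return None  # window still filling
        vectors = [v for v, _ in self.window]
        window_centroid = [sum(col) / len(vectors) for col in zip(*vectors)]
        centroid_sim = cosine(window_centroid, self.reference_centroid)
        fraction = sum(flag for _, flag in self.window) / len(self.window)
        # Either signal on its own is enough to report window drift.
        drift = (centroid_sim < self.centroid_threshold
                 or fraction > self.drift_fraction_threshold)
        return drift, centroid_sim, fraction

checker = SlidingWindowCheck([1.0, 0.0], centroid_threshold=0.9, window_size=4)
stream = [([1.0, 0.1], False)] * 4 + [([0.0, 1.0], True)] * 2
results = [checker.update(vec, flagged) for vec, flagged in stream]
print(results[-1])
```

In this toy stream a single off-topic response leaves the window healthy; the second one pushes both signals over their limits.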
from driftguard import WindowedDriftDetector
wd = WindowedDriftDetector(corpus=corpus, window_size=20, drift_fraction_threshold=0.3)
for response in llm_responses:
    result = wd.update(response)
    if result is None:
        continue # window still filling
    if result.is_drift:
        print(f"Window drift detected: "
              f"centroid_sim={result.window_centroid_similarity:.3f}, "
              f"drift_fraction={result.drift_fraction:.0%}")

result is a WindowDriftResult returned on every call once the window is full. Use on_drift for async-friendly callbacks:
wd = WindowedDriftDetector(corpus=corpus, on_drift=lambda r: alert_queue.put(r))

Async usage mirrors the sync API:
result = await wd.aupdate(response)

corpus.plot() projects the reference corpus into 2D via t-SNE and optionally overlays texts colour-coded by drift status, which is useful for debugging false positives and tuning threshold_percentile.
pip install driftguard[viz] # adds matplotlib + scikit-learn
corpus.plot(check_texts=["habeas corpus", "pasta recipe", "clinical trial"])

Blue circles are reference texts; green triangles are on-topic detections; red X markers are flagged as drift.
For more control, call plot_corpus directly:
from driftguard.viz import plot_corpus
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(10, 7))
plot_corpus(corpus, check_texts=probe_texts, ax=ax)
plt.show()

Save a trained corpus to disk and reload it on the next run, so there is no need to re-embed reference texts every time.
corpus.save("legal_corpus")
# writes legal_corpus.npz (embeddings, centroid, thresholds, cluster data)
# legal_corpus.texts.json (original texts)
loaded = ReferenceCorpus(embeddings_model=embeddings)
loaded.load("legal_corpus")

Cluster data (centroids, per-cluster thresholds) is persisted alongside the embeddings.
Every call to detector.check() or detector.acheck() returns a frozen DriftResult:
| Field | Type | Description |
|---|---|---|
| is_drift | bool | True when both centroid and NN signals indicate drift |
| centroid_similarity | float | Cosine similarity to the nearest cluster (or global) centroid |
| max_reference_similarity | float | Cosine similarity to the closest individual reference text |
| threshold | float | Adaptive centroid threshold for this check |
| nn_threshold | float | Adaptive nearest-neighbour threshold |
| text | str | The checked text |
| timestamp | float | Unix timestamp |
| metadata | dict | Any kwargs passed to check(), e.g. run_id |
DriftError (raised by as_guard()) exposes the full DriftResult on its .result attribute.
WindowedDriftDetector.update() returns a WindowDriftResult once the window is full:
| Field | Type | Description |
|---|---|---|
| is_drift | bool | True when centroid or fraction signal fires |
| window_centroid_similarity | float | Cosine similarity of window centroid to reference |
| drift_fraction | float | Fraction of window responses individually flagged |
| window_size | int | Number of responses in the window |
| threshold | float | Reference threshold used for centroid check |
| drift_fraction_threshold | float | Configured fraction threshold |
| timestamp | float | Unix timestamp |
pip install -e ".[dev]"
pytest

All tests use deterministic FakeEmbeddings, so no API key or network access is required.
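For readers writing their own tests, a deterministic stand-in embedding can be built from a hash. The sketch below is an assumption about the general approach, not the FakeEmbeddings class the test suite uses; the fake_embed function is hypothetical.

```python
import hashlib
import math

def fake_embed(text, dim=8):
    """Map text to a unit-length vector derived from its SHA-256 digest.

    The same text always yields the same vector, but the vectors carry no
    semantic meaning, so they suit plumbing tests rather than accuracy checks.
    """
    if not 1 <= dim <= 32:
        raise ValueError("dim must be between 1 and 32 (SHA-256 yields 32 bytes)")
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    raw = [digest[i] / 255.0 - 0.5 for i in range(dim)]
    norm = math.sqrt(sum(x * x for x in raw)) or 1.0
    return [x / norm for x in raw]

a = fake_embed("tort law")
b = fake_embed("tort law")
c = fake_embed("pasta recipe")
print(a == b, a == c)
```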
MIT