92 changes: 92 additions & 0 deletions docs/foundation/depth_estimation.md
@@ -0,0 +1,92 @@
<a href="https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf" target="_blank">Depth-Anything-V2-Small</a> is a depth estimation model available on Hugging Face.

You can use Depth-Anything-V2-Small to estimate the depth of objects in images, creating a depth map where:
- Each pixel's value represents its relative distance from the camera
- Lower values (darker colors) indicate closer objects
- Higher values (lighter colors) indicate further objects
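
To make this convention concrete, here is a minimal sketch of reading a depth map, assuming it is a 2D NumPy array normalized to [0, 1] (as the `normalized_depth` output shown later in this guide is):

```python
import numpy as np

depth_map = np.random.rand(480, 640)  # stand-in for a real depth map

# Lowest value = pixel closest to the camera.
closest_px = np.unravel_index(np.argmin(depth_map), depth_map.shape)
# Highest value = pixel furthest from the camera.
furthest_px = np.unravel_index(np.argmax(depth_map), depth_map.shape)
print(f"closest pixel (row, col): {closest_px}, furthest: {furthest_px}")
```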

You can deploy Depth-Anything-V2-Small with Inference.

### Installation

To install Inference with the extra dependencies needed to run Depth-Anything-V2-Small, run:

```pip install "inference[transformers]"```

or, for GPU environments:

```pip install "inference-gpu[transformers]"```

### How to Use Depth-Anything-V2-Small

Create a new Python file called `app.py` and add the following code:

```python
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from inference.models.depth_estimation.depthestimation import DepthEstimator

# Initialize the model
model = DepthEstimator()

# Load an image
image = Image.open("your_image.jpg")

# Run inference
results = model.predict(image)

# Get the depth map and visualization
depth_map = results[0]['normalized_depth']
visualization = results[0]['image']

# Convert visualization to numpy array for display
visualization_array = visualization.numpy()

# Display the results
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title('Original Image')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(visualization_array)
plt.title('Depth Map')
plt.axis('off')

plt.show()
```

In this code, we:
1. Load the Depth-Anything-V2-Small model
2. Load an image for depth estimation
3. Run inference to get the depth map
4. Display both the original image and the depth map visualization

The depth map visualization uses a viridis colormap where:
- Darker colors (purple/blue) represent objects closer to the camera
- Lighter colors (yellow/green) represent objects further from the camera
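
If you want to colorize the raw `normalized_depth` array yourself rather than using the returned visualization, a short sketch with matplotlib's viridis colormap reproduces the same convention (`colorize_depth` is an illustrative name, not part of the library):

```python
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def colorize_depth(depth_map: np.ndarray) -> Image.Image:
    # Re-normalize defensively in case values are not already in [0, 1].
    d = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
    rgba = plt.cm.viridis(d)  # (H, W, 4) float array in [0, 1]
    return Image.fromarray((rgba[..., :3] * 255).astype(np.uint8))

colorize_depth(np.random.rand(480, 640)).save("depth_viridis.png")
```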

To use Depth-Anything-V2-Small with Inference, you will need a Hugging Face token. If you don't already have a Hugging Face account, <a href="https://huggingface.co/join" target="_blank">sign up for a free Hugging Face account</a>.

Then, set your Hugging Face token as an environment variable:

```bash
export HUGGING_FACE_HUB_TOKEN=your_token_here
```

Or you can log in using the Hugging Face CLI:

```bash
huggingface-cli login
```
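
Alternatively, you can authenticate from Python; a short sketch using `huggingface_hub` (installed as a dependency of transformers):

```python
# Log in programmatically; login() prompts for a token,
# or you can pass one explicitly.
from huggingface_hub import login

login()  # or: login(token="hf_...")
```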

Then, run the Python script you have created:

```bash
python app.py
```

The script will display both the original image and the depth map visualization.
16 changes: 16 additions & 0 deletions inference/core/entities/requests/inference.py
@@ -89,6 +89,22 @@ class CVInferenceRequest(InferenceRequest):
)


class DepthEstimationRequest(BaseRequest):
"""Request for depth estimation.

Attributes:
image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) for which depth will be estimated.
"""

image: Union[List[InferenceRequestImage], InferenceRequestImage]

visualize_predictions: Optional[bool] = Field(
default=False,
examples=[False],
description="If true, the predictions will be drawn on the original image and returned as a base64 string",
)


class ObjectDetectionInferenceRequest(CVInferenceRequest):
"""Object Detection inference request.

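For reference, a hypothetical sketch of constructing this request directly, assuming both names are importable from this module and that `InferenceRequestImage` uses the usual `type`/`value` schema:

```python
from inference.core.entities.requests.inference import (
    DepthEstimationRequest,
    InferenceRequestImage,
)

# The image URL is a placeholder; visualize_predictions defaults to False.
request = DepthEstimationRequest(
    image=InferenceRequestImage(type="url", value="https://example.com/room.jpg"),
    visualize_predictions=True,
)
```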
2 changes: 2 additions & 0 deletions inference/core/env.py
@@ -149,6 +149,8 @@

QWEN_2_5_ENABLED = str2bool(os.getenv("QWEN_2_5_ENABLED", True))

DEPTH_ESTIMATION_ENABLED = str2bool(os.getenv("DEPTH_ESTIMATION_ENABLED", True))

SMOLVLM2_ENABLED = str2bool(os.getenv("SMOLVLM2_ENABLED", True))

MOONDREAM2_ENABLED = str2bool(os.getenv("MOONDREAM2_ENABLED", True))
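Since the flag is read once at import time via `str2bool`, disabling the model has to happen before `inference.core.env` is imported; a minimal sketch (assuming `str2bool` accepts the usual "True"/"False" strings):

```python
# Turn the depth estimation model off for this process; the environment
# variable must be set before inference.core.env is first imported.
import os

os.environ["DEPTH_ESTIMATION_ENABLED"] = "False"

from inference.core.env import DEPTH_ESTIMATION_ENABLED

print(DEPTH_ESTIMATION_ENABLED)  # False
```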
16 changes: 10 additions & 6 deletions inference/core/managers/base.py
@@ -57,14 +57,18 @@ def add_model(
f"ModelManager - model with model_id={resolved_identifier} is already loaded."
)
return

logger.debug("ModelManager - model initialisation...")

model = self.model_registry.get_model(resolved_identifier, api_key)(
model_id=model_id,
api_key=api_key,
)
logger.debug("ModelManager - model successfully loaded.")
self._models[resolved_identifier] = model
        try:
            model = self.model_registry.get_model(resolved_identifier, api_key)(
                model_id=model_id,
                api_key=api_key,
            )
            logger.debug("ModelManager - model successfully loaded.")
            self._models[resolved_identifier] = model
        except Exception as e:
            logger.error(f"ModelManager - failed to load model {resolved_identifier}: {e}")
            raise

def check_for_model(self, model_id: str) -> None:
"""Checks whether the model with the given ID is in the manager.
1 change: 1 addition & 0 deletions inference/core/registries/roboflow.py
@@ -55,6 +55,7 @@
"yolo_world": ("object-detection", "yolo-world"),
"owlv2": ("object-detection", "owlv2"),
"smolvlm2": ("lmm", "smolvlm-2.2b-instruct"),
"depth-anything-v2": ("depth-estimation", "small"),
"moondream2": ("lmm", "moondream2"),
}

6 changes: 5 additions & 1 deletion inference/core/utils/roboflow.py
@@ -10,7 +10,9 @@ def get_model_id_chunks(
model_id_chunks = model_id.split("/")
if len(model_id_chunks) != 2:
raise InvalidModelIDError(f"Model ID: `{model_id}` is invalid.")

dataset_id, version_id = model_id_chunks[0], model_id_chunks[1]

if dataset_id.lower() in {
"clip",
"doctr",
@@ -25,9 +27,11 @@
"yolo_world",
"smolvlm2",
"moondream2",
"depth-anything-v2",
}:
return dataset_id, version_id

try:
return dataset_id, str(int(version_id))
except Exception:
except Exception as e:
return model_id, None
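To make the resolution rule concrete, a small sketch of the expected behavior (assuming the import path in this diff): registered aliases such as `depth-anything-v2` pass their version string through untouched, ordinary Roboflow IDs must carry a numeric version, and anything else falls back to `(model_id, None)`.

```python
from inference.core.utils.roboflow import get_model_id_chunks

assert get_model_id_chunks("depth-anything-v2/small") == ("depth-anything-v2", "small")
assert get_model_id_chunks("my-project/3") == ("my-project", "3")
assert get_model_id_chunks("my-project/latest") == ("my-project/latest", None)
```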
4 changes: 4 additions & 0 deletions inference/core/workflows/core_steps/loader.py
@@ -167,6 +167,9 @@
from inference.core.workflows.core_steps.models.foundation.cog_vlm.v1 import (
CogVLMBlockV1,
)
from inference.core.workflows.core_steps.models.foundation.depth_estimation.v1 import (
DepthEstimationBlockV1,
)
from inference.core.workflows.core_steps.models.foundation.florence2.v1 import (
Florence2BlockV1,
)
@@ -515,6 +518,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
DynamicCropBlockV1,
DetectionsFilterBlockV1,
DetectionOffsetBlockV1,
DepthEstimationBlockV1,
ByteTrackerBlockV1,
RelativeStaticCropBlockV1,
DetectionsTransformationBlockV1,
172 changes: 172 additions & 0 deletions inference/core/workflows/core_steps/models/foundation/depth_estimation/v1.py
@@ -0,0 +1,172 @@
from typing import List, Literal, Optional, Type, Union

from pydantic import ConfigDict, Field

from inference.core.entities.requests.inference import DepthEstimationRequest
from inference.core.managers.base import ModelManager
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.execution_engine.entities.base import (
Batch,
OutputDefinition,
WorkflowImageData,
)
from inference.core.workflows.execution_engine.entities.types import (
IMAGE_KIND,
NUMPY_ARRAY_KIND,
ROBOFLOW_MODEL_ID_KIND,
ImageInputField,
RoboflowModelField,
Selector,
)
from inference.core.workflows.prototypes.block import (
BlockResult,
WorkflowBlock,
WorkflowBlockManifest,
)


class BlockManifest(WorkflowBlockManifest):
# Standard model configuration for UI, schema, etc.
model_config = ConfigDict(
json_schema_extra={
"name": "Depth Estimation",
"version": "v1",
"short_description": "Run Depth Estimation on an image.",
"long_description": (
"""
🎯 This workflow block performs depth estimation on images using the Depth-Anything-V2 model. It analyzes the spatial relationships
and depth information in images to create a depth map where:

📊 Each pixel's value represents its relative distance from the camera
🔍 Lower values (darker colors) indicate closer objects
🔭 Higher values (lighter colors) indicate further objects

The block outputs:
1. 🗺️ A normalized depth map showing the relative distances of objects in the scene
2. 🖼️ A visualization of the depth map rendered on the input image

This is particularly useful for:
- 🏗️ Understanding 3D structure from 2D images
- 🎨 Creating depth-aware visualizations
- 📏 Analyzing spatial relationships in scenes
- 🕶️ Applications in augmented reality and 3D reconstruction

⚡ The model runs efficiently on Apple Silicon (M1-M4) using Metal Performance Shaders (MPS) for accelerated inference.
"""
),
"license": "Apache-2.0",
"block_type": "model",
"search_keywords": [
"Depth Estimation",
"Depth Anything",
"Depth Anything V2",
"Hugging Face",
"HuggingFace",
],
"is_vlm_block": True,
"ui_manifest": {
"section": "model",
"icon": "fal fa-atom",
"blockPriority": 5.5,
},
},
protected_namespaces=(),
)
type: Literal["roboflow_core/depth_estimation@v1"]
images: Selector(kind=[IMAGE_KIND]) = ImageInputField

model_version: str = Field(
default="depth-anything-v2/small",
description="The Depth Estimation model to be used for inference.",
examples=["depth-anything-v2/small"],
)

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return [
OutputDefinition(name="image", kind=[IMAGE_KIND]),
OutputDefinition(name="normalized_depth", kind=[NUMPY_ARRAY_KIND]),
]

@classmethod
def get_parameters_accepting_batches(cls) -> List[str]:
# Only images can be passed in as a list/batch
return ["images"]

@classmethod
def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.3.0,<2.0.0"


class DepthEstimationBlockV1(WorkflowBlock):
def __init__(
self,
model_manager: ModelManager,
api_key: Optional[str],
step_execution_mode: StepExecutionMode,
):
self._model_manager = model_manager
self._api_key = api_key
self._step_execution_mode = step_execution_mode

@classmethod
def get_init_parameters(cls) -> List[str]:
return ["model_manager", "api_key", "step_execution_mode"]

@classmethod
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
return BlockManifest

def run(
self,
images: Batch[WorkflowImageData],
model_version: str = "depth-anything-v2/small",
) -> BlockResult:
if self._step_execution_mode == StepExecutionMode.LOCAL:
return self.run_locally(
images=images,
model_version=model_version,
)
elif self._step_execution_mode == StepExecutionMode.REMOTE:
raise NotImplementedError(
"Remote execution is not supported for Depth Estimation. Please use a local or dedicated inference server."
)
else:
raise ValueError(
f"Unknown step execution mode: {self._step_execution_mode}"
)

def run_locally(
self,
images: Batch[WorkflowImageData],
model_version: str = "depth-anything-v2/small",
) -> BlockResult:
# Convert each image to the format required by the model.
inference_images = [
i.to_inference_format(numpy_preferred=False) for i in images
]

        # Register the depth estimation model with the model manager.
        self._model_manager.add_model(model_id=model_version, api_key=self._api_key)

        predictions = []
        for image in inference_images:
            # Run inference on each image in the batch.
            request = DepthEstimationRequest(image=image)
            prediction = self._model_manager.infer_from_request_sync(
                model_id=model_version, request=request
            )
            predictions.append(prediction.response)

return predictions
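For context, a hypothetical sketch of a minimal workflow definition that wires this block in; field names follow the usual Roboflow workflows schema, and the selectors reference the outputs declared in `describe_outputs()` above:

```python
DEPTH_WORKFLOW = {
    "version": "1.0",
    "inputs": [{"type": "WorkflowImage", "name": "image"}],
    "steps": [
        {
            "type": "roboflow_core/depth_estimation@v1",
            "name": "depth",
            "images": "$inputs.image",
            "model_version": "depth-anything-v2/small",
        }
    ],
    "outputs": [
        {"type": "JsonField", "name": "normalized_depth", "selector": "$steps.depth.normalized_depth"},
        {"type": "JsonField", "name": "depth_image", "selector": "$steps.depth.image"},
    ],
}
```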
1 change: 1 addition & 0 deletions inference/models/README.md
@@ -27,6 +27,7 @@ The models supported by Roboflow Inference have their own licenses. View the lic
| `inference/models/yolov11` | [AGPL-3.0](https://github.com/ultralytics/ultralytics/blob/master/LICENSE) | ✅ |
| `inference/models/yolov12` | [AGPL-3.0](https://github.com/sunsmarterjie/yolov12?tab=AGPL-3.0-1-ov-file) | ✅ |
| `inference/models/smolvlm2` | [Apache 2.0](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) | 👍 |
| `inference/models/depth_estimation` | [Apache 2.0](https://huggingface.co/depth-anything/Depth-Anything-V2-Small) | 👍 |
| `inference/models/rfdetr` | [Apache 2.0](https://github.com/roboflow/rf-detr/blob/main/LICENSE) | 👍 |
| `inference/models/moondream2` | [Apache 2.0](https://github.com/vikhyat/moondream/blob/main/LICENSE) | 👍 |
