VisionReasoner
====

 **VisionReasoner: Unified Visual Perception and Reasoning via Reinforcement Learning**

* Paper: https://arxiv.org/abs/2505.12081

![VisionReasoner Overview](../assets/visionreasoner_overview.jpg)

* Installation

```bash
git clone https://github.com/dvlab-research/VisionReasoner.git VisionReasoner_repo
cd VisionReasoner_repo
conda create -n visionreasoner_test python=3.12
conda activate visionreasoner_test
pip3 install torch torchvision
pip install -r requirements.txt
```

* FlashAttention
If you don't have flash-attention installed and don't want to install it, change `attn_implementation` to `"eager"` on the following lines:

```
# Line 63 - File: VisionReasoner_repo/vision_reasoner/models/vision_reasoner_model.py
attn_implementation="flash_attention_2", # -> "eager"

# Line 11 - File: VisionReasoner_repo/vision_reasoner/models/task_router.py
attn_implementation="flash_attention_2", # -> "eager"
```
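To decide which value to use, you can check whether the `flash_attn` package is importable in the current environment. The sketch below does that with the standard library only. `pick_attn_implementation` is a hypothetical helper, not part of the VisionReasoner repo, and an importable package does not by itself guarantee that the GPU supports FlashAttention.

```python
import importlib.util


def pick_attn_implementation() -> str:
    """Return "flash_attention_2" if flash_attn is importable, else "eager".

    Hypothetical helper for illustration; the repo hard-codes the value.
    """
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    return "flash_attention_2" if has_flash_attn else "eager"


# Prints the value to put in the two lines listed above.
print(pick_attn_implementation())
```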

* Download the model

```bash
mkdir pretrained_models
cd pretrained_models
git lfs install
git clone https://huggingface.co/Ricky06662/VisionReasoner-7B
git clone https://huggingface.co/Ricky06662/TaskRouter-1.5B
```
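Before loading the models, it can help to confirm that both checkpoints ended up where the notebook expects them. The sketch below assumes each cloned repository contains a `config.json` at its top level (usual for Hugging Face model repos, but not verified here). `missing_checkpoints` is a hypothetical helper, and the paths are relative to the directory in which `mkdir pretrained_models` was run.

```python
from pathlib import Path


def missing_checkpoints(paths):
    """Return the checkpoint directories that lack a config.json file.

    Assumption: a usable checkpoint directory has config.json at its root.
    """
    missing = []
    for path in map(Path, paths):
        if not (path / "config.json").is_file():
            missing.append(str(path))
    return missing


expected = [
    "pretrained_models/VisionReasoner-7B",
    "pretrained_models/TaskRouter-1.5B",
]
# An empty list means both directories look complete.
print(missing_checkpoints(expected))
```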

In [1]:
import os
import sys

from PIL import Image
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

sys.path.append("VisionReasoner_repo/vision_reasoner")
from models.vision_reasoner_model import VisionReasonerModel
from utils import (
    visualize_results_enhanced,
    visualize_pose_estimation_results_enhanced,
    visualize_depth_estimation_results_enhanced,
    visualize_results_video,
    visualize_pose_estimation_results_video,
    visualize_depth_estimation_results_video
)



In [None]:
model_path = 'pretrained_models/VisionReasoner-7B'
task_router_model_path = "pretrained_models/TaskRouter-1.5B"
segmentation_model_path = "facebook/sam2-hiera-large"

task_type = "auto"
# Possible choices for task_type:
#  - "auto"
#  - "detection"
#  - "segmentation"
#  - "counting"
#  - "vqa"
#  - "generation"
#  - "depth_estimation"
#  - "pose_estimation"

hybrid_mode = False
yolo_model_path = "yolov8x-worldv2.pt"
generation_mode = False
generation_model_name = "gpt-image-1"
depth_estimation_model_path = None  # e.g. "facebook/VGGT-1B"
pose_estimation_model_path = None  # e.g. "usyd-community/vitpose-plus-base"


In [3]:
model = VisionReasonerModel(
    reasoning_model_path=model_path, 
    task_router_model_path=task_router_model_path, 
    segmentation_model_path=segmentation_model_path,
    depth_estimation_model_path=depth_estimation_model_path,
    pose_estimation_model_path=pose_estimation_model_path
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.11it/s]
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


In [11]:

def inference(image, task_type, query, refer_image_path="", image_prompt=""):
    """Run one query through the model method that matches `task_type`."""
    if task_type == "auto":
        # The task router picks the task; the routed type is returned but unused here.
        result, task_type = model.process_single_image(image, query, return_task_type=True)
    elif task_type == "detection":
        result = model.detect_objects(image, query)
    elif task_type == "segmentation":
        result = model.segment_objects(image, query)
    elif task_type == "counting":
        result = model.count_objects(image, query)
    elif task_type == "generation":
        # Generation uses the reference image path and image prompt, not image/query.
        result = model.generate_image(refer_image_path, image_prompt)
    elif task_type == "depth_estimation":
        result = model.depth_estimation(image, query)
    elif task_type == "pose_estimation":
        result = model.pose_estimation(image, query)
    else:  # "vqa", and any unrecognized task_type, falls through to question answering
        result = model.answer_question(image, query)

    return result


In [None]:
image_path = "../samples/plants.jpg"
query = "How many plants are there in this image?"
refer_image_path = ""
image_prompt = ""

image = Image.open(image_path).convert("RGB")

result = inference(
    image=image,
    task_type=task_type,
    query=query,
    refer_image_path=refer_image_path,
    image_prompt=image_prompt
)

result.keys()

dict_keys(['count', 'bboxes', 'points', 'thinking', 'full_response', 'pred_answer'])

In [5]:
import textwrap

print(textwrap.fill(result["thinking"], width=80))

 The task involves identifying the number of plants in the image and comparing
the objects to find the most closely matched ones. I'll start through the image
step by step:  1. First, I'll identify the plants in the image. There are two
distinct plants visible. 2. Next to each plant, there is a gold-colored metal
stand that appears to be a stand holder or stand stand. 3. The plants are are in
black pots, which are are are not the main focus of the image but but are are
are part of the plants. 4. The metal holders are are the most closely matched
objects to the question, as they are are the main focus of the image and the
question asks about the number of plants.


In [6]:
print(textwrap.fill(result["full_response"], width=80))

<think> The task involves identifying the number of plants in the image and
comparing the objects to find the most closely matched ones. I'll start through
the image step by step:  1. First, I'll identify the plants in the image. There
are two distinct plants visible. 2. Next to each plant, there is a gold-colored
metal stand that appears to be a stand holder or stand stand. 3. The plants are
are in black pots, which are are are not the main focus of the image but but are
are are part of the plants. 4. The metal holders are are the most closely
matched objects to the question, as they are are the main focus of the image and
the question asks about the number of plants.  </think> <answer>[{"bbox_2d":
[108,354,364,518], "point_2d": [218,468]}, {"bbox_2d": [337,56,620,412],
"point_2d": [515,280]}]</answer>


In [9]:
result["pred_answer"]

[{'bbox_2d': [108, 354, 364, 518], 'point_2d': [218, 468]},
 {'bbox_2d': [337, 56, 620, 412], 'point_2d': [515, 280]}]
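
`pred_answer` appears to be the parsed contents of the `<answer>` tag in `full_response`. A minimal sketch of that parsing, using only `re` and `json`, is shown below. `parse_response` is a hypothetical helper and not necessarily how the repo does it; the `<think>` text in `sample` is shortened for illustration, while the `<answer>` payload is copied from the output above. Free-text answers (for example from VQA) would not be valid JSON and would make `json.loads` raise.

```python
import json
import re


def parse_response(full_response: str):
    """Split '<think>...</think> <answer>...</answer>' into (thinking, answer).

    Hypothetical helper; assumes the answer payload is JSON.
    """
    think = re.search(r"<think>(.*?)</think>", full_response, re.DOTALL)
    answer = re.search(r"<answer>(.*?)</answer>", full_response, re.DOTALL)
    thinking = think.group(1).strip() if think else ""
    parsed = json.loads(answer.group(1)) if answer else []
    return thinking, parsed


sample = (
    "<think> Two plants are visible. </think> "
    '<answer>[{"bbox_2d": [108,354,364,518], "point_2d": [218,468]}, '
    '{"bbox_2d": [337,56,620,412], "point_2d": [515,280]}]</answer>'
)
thinking, objects = parse_response(sample)
print(len(objects), objects[0]["bbox_2d"])  # 2 [108, 354, 364, 518]
```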

### task_type="segmentation"

In [13]:
result = inference(
    image=image,
    task_type=task_type,
    query=query,
    refer_image_path=refer_image_path,
    image_prompt=image_prompt
)
result.keys()

dict_keys(['count', 'bboxes', 'points', 'thinking', 'full_response', 'pred_answer'])
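
Note that the boxes in `result["bboxes"]` (printed in the next cell) differ from those in `pred_answer` for the same image, for example `[137, 479, 463, 701]` versus `[108, 354, 364, 518]`. One plausible explanation, which this notebook does not confirm, is that the model predicts on a resized copy of the image and the wrapper scales the coordinates back to the original size. The sketch below illustrates that idea. The 840-pixel model input side and the 1068x1137 image size are assumptions, chosen only because they reproduce the numbers shown; `rescale_box` is a hypothetical helper.

```python
def rescale_box(box, scale_x, scale_y):
    """Scale an [x1, y1, x2, y2] box and round each value to the nearest pixel."""
    x1, y1, x2, y2 = box
    return [
        round(x1 * scale_x),
        round(y1 * scale_y),
        round(x2 * scale_x),
        round(y2 * scale_y),
    ]


# Assumed values, not stated anywhere in this notebook.
MODEL_SIDE = 840
ORIG_W, ORIG_H = 1068, 1137

sx, sy = ORIG_W / MODEL_SIDE, ORIG_H / MODEL_SIDE
print(rescale_box([108, 354, 364, 518], sx, sy))  # [137, 479, 463, 701]
print(rescale_box([337, 56, 620, 412], sx, sy))   # [428, 76, 788, 558]
```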

In [14]:
result

{'count': 2,
 'bboxes': [[137, 479, 463, 701], [428, 76, 788, 558]],
 'points': [[277, 633], [655, 379]],
 'thinking': " The task involves identifying the number of plants in the image and comparing the objects to find the most closely matched ones. I'll start through the image step by step:\n\n1. First, I'll identify the plants in the image. There are two distinct plants visible.\n2. Next to each plant, there is a gold-colored metal stand that appears to be a stand holder or stand stand.\n3. The plants are are in black pots, which are are are not the main focus of the image but but are are are part of the plants.\n4. The metal holders are are the most closely matched objects to the question, as they are are the main focus of the image and the question asks about the number of plants.\n\n",
 'full_response': '<think> The task involves identifying the number of plants in the image and comparing the objects to find the most closely matched ones. I\'ll start through the image step by step