Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test fixtures for vlmrun tests."""

import hashlib
import pytest
from typer.testing import CliRunner
from pydantic import BaseModel
Expand Down Expand Up @@ -474,26 +475,44 @@ class Agent:
def __init__(self, client):
self._client = client

def get(self, name=None, version=None, id=None):
def get(self, name=None, id=None, prompt=None):
from vlmrun.client.types import AgentInfo
from datetime import datetime

if id and name:
raise ValueError("Only one of `id` or `name` can be provided.")
if not id and not name:
raise ValueError("Either `id` or `name` must be provided.")
if id:
if name or prompt:
raise ValueError(
"Only one of `id` or `name` or `prompt` can be provided."
)
elif name:
if id or prompt:
raise ValueError(
"Only one of `id` or `name` or `prompt` can be provided."
)
elif prompt:
if id or name:
raise ValueError(
"Only one of `id` or `name` or `prompt` can be provided."
)
else:
raise ValueError(
"Either `id` or `name` or `prompt` must be provided."
)

if id:
agent_id = id
agent_name = f"agent-{id}"
else:
agent_id = f"agent-{name}-{version or 'latest'}"
elif name:
agent_id = f"agent-{name}"
agent_name = name
elif prompt:
hash_prompt = hashlib.sha256(prompt.encode()).hexdigest()
agent_id = f"agent-{hash_prompt}"
agent_name = f"agent-{hash_prompt}"

return AgentInfo(
id=agent_id,
name=agent_name,
version=version or "latest",
description="Test agent description",
prompt="Test agent prompt",
json_schema={
Expand All @@ -513,8 +532,7 @@ def list(self):
return [
AgentInfo(
id="agent-1",
name="test-agent-1",
version="1.0.0",
name="test-agent-1:1.0.0",
description="First test agent",
prompt="Test prompt 1",
json_schema={
Expand All @@ -528,8 +546,7 @@ def list(self):
),
AgentInfo(
id="agent-2",
name="test-agent-2",
version="2.0.0",
name="test-agent-2:2.0.0",
description="Second test agent",
prompt="Test prompt 2",
json_schema={
Expand All @@ -554,8 +571,7 @@ def create(self, config, name=None, inputs=None, callback_url=None):

return AgentCreationResponse(
id=f"agent-{name or 'created'}",
name=name or "created-agent",
version="1.0.0",
name=name or "created-agent:1.0.0",
created_at=datetime.fromisoformat("2024-01-01T00:00:00+00:00"),
updated_at=datetime.fromisoformat("2024-01-01T00:00:00+00:00"),
status="pending",
Expand All @@ -564,7 +580,6 @@ def create(self, config, name=None, inputs=None, callback_url=None):
def execute(
self,
name,
version=None,
inputs=None,
batch=True,
config=None,
Expand All @@ -582,7 +597,6 @@ def execute(
return AgentExecutionResponse(
id=f"execution-{name}",
name=name,
version=version or "latest",
status="completed",
created_at=datetime.fromisoformat("2024-01-01T00:00:00+00:00"),
completed_at=datetime.fromisoformat("2024-01-01T00:00:01+00:00"),
Expand Down
30 changes: 12 additions & 18 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ def test_agent_creation_response_creation(self):
"""Test creating an AgentCreationResponse instance."""
response_data = {
"id": "test-agent-123",
"name": "Test Agent",
"version": "1.0.0",
"name": "test-agent:1.0.0",
"created_at": datetime.now(),
"updated_at": datetime.now(),
"status": "completed",
Expand All @@ -29,8 +28,7 @@ def test_agent_creation_response_creation(self):
response = AgentCreationResponse(**response_data)

assert response.id == "test-agent-123"
assert response.name == "Test Agent"
assert response.version == "1.0.0"
assert response.name == "test-agent:1.0.0"
assert response.status == "completed"


Expand All @@ -40,11 +38,10 @@ class TestAgentMethods:
def test_agent_get_by_name_and_version(self, mock_client):
"""Test getting an agent by name and version."""
client = mock_client
response = client.agent.get(name="test-agent", version="1.0.0")
response = client.agent.get(name="test-agent:1.0.0")

assert isinstance(response, AgentInfo)
assert response.name == "test-agent"
assert response.version == "1.0.0"
assert response.name == "test-agent:1.0.0"
assert response.description == "Test agent description"
assert response.prompt == "Test agent prompt"
assert response.status == "completed"
Expand All @@ -57,18 +54,19 @@ def test_agent_get_by_id(self, mock_client):
assert isinstance(response, AgentInfo)
assert response.id == "agent-123"
assert response.name == "agent-agent-123"
assert response.version == "latest"

def test_agent_get_validation_error(self, mock_client):
"""Test that get method validates input parameters."""
client = mock_client

with pytest.raises(
ValueError, match="Only one of `id` or `name` can be provided."
ValueError, match="Only one of `id` or `name` or `prompt` can be provided."
):
client.agent.get(id="agent-123", name="test-agent")

with pytest.raises(ValueError, match="Either `id` or `name` must be provided."):
with pytest.raises(
ValueError, match="Either `id` or `name` or `prompt` must be provided."
):
client.agent.get()

def test_agent_list(self, mock_client):
Expand All @@ -79,8 +77,8 @@ def test_agent_list(self, mock_client):
assert isinstance(response, list)
assert len(response) == 2
assert all(isinstance(agent, AgentInfo) for agent in response)
assert response[0].name == "test-agent-1"
assert response[1].name == "test-agent-2"
assert response[0].name.startswith("test-agent-1")
assert response[1].name.startswith("test-agent-2")

def test_agent_create(self, mock_client):
"""Test creating an agent."""
Expand All @@ -99,7 +97,6 @@ def test_agent_create(self, mock_client):

assert isinstance(response, AgentCreationResponse)
assert response.name == "new-agent"
assert response.version == "1.0.0"
assert response.status == "pending"

def test_agent_create_validation_error(self, mock_client):
Expand All @@ -125,15 +122,13 @@ def test_agent_execute(self, mock_client):
)

response = client.agent.execute(
name="test-agent",
version="1.0.0",
name="test-agent:1.0.0",
inputs={"input": "test data"},
config=config,
)

assert isinstance(response, AgentExecutionResponse)
assert response.name == "test-agent"
assert response.version == "1.0.0"
assert response.name == "test-agent:1.0.0"
assert response.status == "completed"
assert response.response == {"result": "execution result"}
assert isinstance(response.usage, CreditUsage)
Expand All @@ -147,7 +142,6 @@ def test_agent_execute_without_version(self, mock_client):

assert isinstance(response, AgentExecutionResponse)
assert response.name == "test-agent"
assert response.version == "latest"

def test_agent_execute_batch_mode_required(self, mock_client):
"""Test that execute method requires batch mode."""
Expand Down
35 changes: 22 additions & 13 deletions vlmrun/client/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,42 @@ def __init__(self, client: "VLMRunProtocol") -> None:
def get(
self,
name: str | None = None,
version: str | None = None,
id: str | None = None,
prompt: str | None = None,
) -> AgentInfo:
"""Get an agent by name and version.
"""Get an agent by name, id, or prompt. Only one of `name`, `id`, or `prompt` can be provided.

Args:
name: Name of the agent (lookup either by name + version or by id alone)
version: Version of the agent
name: Name of the agent
id: ID of the agent
prompt: Prompt of the agent

Raises:
APIError: If the agent is not found (404) or the agent name is invalid (400)

Returns:
AgentInfo: Agent information response
"""
if id and name:
raise ValueError("Only one of `id` or `name` can be provided.")
elif id is not None:
if id:
if name or prompt:
raise ValueError(
"Only one of `id` or `name` or `prompt` can be provided."
)
data = {"id": id}
elif name is not None:
data = {"name": name, "version": version}
elif name:
if id or prompt:
raise ValueError(
"Only one of `id` or `name` or `prompt` can be provided."
)
data = {"name": name}
elif prompt:
if id or name:
raise ValueError(
"Only one of `id` or `name` or `prompt` can be provided."
)
data = {"prompt": prompt}
else:
raise ValueError("Either `id` or `name` must be provided.")
raise ValueError("Either `id` or `name` or `prompt` must be provided.")

response, status_code, headers = self._requestor.request(
method="GET",
Expand Down Expand Up @@ -124,7 +136,6 @@ def create(
def execute(
self,
name: str | None = None,
version: str | None = None,
inputs: Optional[dict[str, Any]] = None,
batch: bool = True,
config: Optional[AgentExecutionConfig] = None,
Expand All @@ -135,7 +146,6 @@ def execute(

Args:
name: Name of the agent to execute. If not provided, we use the prompt to identify the unique agent.
version: Optional version of the agent to execute
inputs: Optional inputs to the agent
batch: Whether to process in batch mode (async)
config: Optional agent execution configuration
Expand All @@ -150,7 +160,6 @@ def execute(

data = {
"name": name,
"version": version,
"batch": batch,
"inputs": inputs,
}
Expand Down
3 changes: 0 additions & 3 deletions vlmrun/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ class AgentExecutionResponse(BaseModel):

id: str = Field(..., description="ID of the agent")
name: str = Field(..., description="Name of the agent")
version: str = Field(..., description="Version of the agent.")
created_at: datetime = Field(
..., description="Date and time when the agent was created (in UTC timezone)"
)
Expand Down Expand Up @@ -259,7 +258,6 @@ class AgentExecutionConfig(AgentExecutionOrCreationConfig):
class AgentInfo(BaseModel):
id: str = Field(..., description="ID of the agent")
name: str = Field(..., description="Name of the agent")
version: str = Field(..., description="Version of the agent.")
description: str = Field(..., description="Description of the agent")
prompt: str = Field(..., description="The prompt of the agent")
json_schema: Optional[Dict[str, Any]] = Field(
Expand All @@ -280,7 +278,6 @@ class AgentInfo(BaseModel):
class AgentCreationResponse(BaseModel):
id: str = Field(..., description="ID of the agent")
name: str = Field(..., description="Name of the agent")
version: str = Field(..., description="Version of the agent.")
created_at: datetime = Field(
..., description="Date and time when the agent was created (in UTC timezone)"
)
Expand Down
2 changes: 1 addition & 1 deletion vlmrun/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.2"
__version__ = "0.3.3"