Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions replicated/async_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional

from .fingerprint import get_machine_fingerprint
from .http_client import AsyncHTTPClient
from .services import AsyncCustomerService
from .state import StateManager
Expand All @@ -21,6 +22,7 @@ def __init__(
self.base_url = base_url
self.timeout = timeout
self.state_directory = state_directory
self._machine_id = get_machine_fingerprint()

self.http_client = AsyncHTTPClient(
base_url=base_url,
Expand Down
2 changes: 2 additions & 0 deletions replicated/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional

from .fingerprint import get_machine_fingerprint
from .http_client import SyncHTTPClient
from .services import CustomerService
from .state import StateManager
Expand All @@ -21,6 +22,7 @@ def __init__(
self.base_url = base_url
self.timeout = timeout
self.state_directory = state_directory
self._machine_id = get_machine_fingerprint()

self.http_client = SyncHTTPClient(
base_url=base_url,
Expand Down
12 changes: 6 additions & 6 deletions replicated/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
self._client = client
self.customer_id = customer_id
self.instance_id = instance_id
self._machine_id = client._machine_id
self._data = kwargs
self._status = "ready"
self._metrics: dict[str, Union[int, float, str]] = {}
Expand All @@ -80,7 +81,7 @@ def send_metric(self, name: str, value: Union[int, float, str]) -> None:
headers = {
**self._client._get_auth_headers(),
"X-Replicated-InstanceID": self.instance_id,
"X-Replicated-ClusterID": self.instance_id,
"X-Replicated-ClusterID": self._machine_id,
"X-Replicated-AppStatus": self._status,
}

Expand Down Expand Up @@ -148,11 +149,10 @@ def _report_instance(self) -> None:
json.dumps(instance_tags).encode()
).decode()

# cluster_id is same as instance_id for non-K8s environments
headers = {
**self._client._get_auth_headers(),
"X-Replicated-InstanceID": self.instance_id,
"X-Replicated-ClusterID": self.instance_id,
"X-Replicated-ClusterID": self._machine_id,
"X-Replicated-AppStatus": self._status,
"X-Replicated-InstanceTagData": instance_tags_b64,
}
Expand Down Expand Up @@ -185,6 +185,7 @@ def __init__(
self._client = client
self.customer_id = customer_id
self.instance_id = instance_id
self._machine_id = client._machine_id
self._data = kwargs
self._status = "ready"
self._metrics: dict[str, Union[int, float, str]] = {}
Expand All @@ -201,7 +202,7 @@ async def send_metric(self, name: str, value: Union[int, float, str]) -> None:
headers = {
**self._client._get_auth_headers(),
"X-Replicated-InstanceID": self.instance_id,
"X-Replicated-ClusterID": self.instance_id,
"X-Replicated-ClusterID": self._machine_id,
"X-Replicated-AppStatus": self._status,
}

Expand Down Expand Up @@ -269,11 +270,10 @@ async def _report_instance(self) -> None:
json.dumps(instance_tags).encode()
).decode()

# cluster_id is same as instance_id for non-K8s environments
headers = {
**self._client._get_auth_headers(),
"X-Replicated-InstanceID": self.instance_id,
"X-Replicated-ClusterID": self.instance_id,
"X-Replicated-ClusterID": self._machine_id,
"X-Replicated-AppStatus": self._status,
"X-Replicated-InstanceTagData": instance_tags_b64,
}
Expand Down
133 changes: 133 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,68 @@ def test_default_state_directory_unchanged(self):
assert "my-app" in state_dir_str
assert "Replicated" in state_dir_str

def test_client_has_machine_id(self):
"""Test that client initializes with a machine_id."""
client = ReplicatedClient(publishable_key="pk_test_123", app_slug="my-app")
assert hasattr(client, "_machine_id")
assert client._machine_id is not None
assert isinstance(client._machine_id, str)
assert len(client._machine_id) == 64 # SHA256 hash

@patch("replicated.http_client.httpx.Client")
def test_instance_has_machine_id_from_client(self, mock_httpx):
"""Test that instances created from client have the client's machine_id."""
from replicated.resources import Instance

mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"customer": {
"id": "customer_123",
"email": "test@example.com",
"name": "test user",
"serviceToken": "service_token_123",
"instanceId": "instance_123",
}
}

mock_client = Mock()
mock_client.request.return_value = mock_response
mock_httpx.return_value = mock_client

client = ReplicatedClient(publishable_key="pk_test_123", app_slug="my-app")
customer = client.customer.get_or_create("test@example.com")
instance = customer.get_or_create_instance()

assert isinstance(instance, Instance)
assert hasattr(instance, "_machine_id")
assert instance._machine_id == client._machine_id

@patch("replicated.http_client.httpx.Client")
def test_instance_uses_machine_id_in_headers(self, mock_httpx):
"""Test that instance methods use machine_id as cluster ID in headers."""
from replicated.resources import Instance

mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {}

mock_client = Mock()
mock_client.request.return_value = mock_response
mock_httpx.return_value = mock_client

client = ReplicatedClient(publishable_key="pk_test_123", app_slug="my-app")
instance = Instance(client, "customer_123", "instance_123")

# Send a metric
instance.send_metric("test_metric", 42)

# Verify the request was made with correct headers
call_args = mock_client.request.call_args
headers = call_args[1]["headers"]
assert "X-Replicated-ClusterID" in headers
assert headers["X-Replicated-ClusterID"] == client._machine_id


class TestAsyncReplicatedClient:
@pytest.mark.asyncio
Expand Down Expand Up @@ -168,3 +230,74 @@ async def test_default_state_directory_unchanged(self):
state_dir_str = str(client.state_manager._state_dir)
assert "my-app" in state_dir_str
assert "Replicated" in state_dir_str

@pytest.mark.asyncio
async def test_client_has_machine_id(self):
"""Test that async client initializes with a machine_id."""
client = AsyncReplicatedClient(publishable_key="pk_test_123", app_slug="my-app")
assert hasattr(client, "_machine_id")
assert client._machine_id is not None
assert isinstance(client._machine_id, str)
assert len(client._machine_id) == 64 # SHA256 hash

@pytest.mark.asyncio
async def test_instance_has_machine_id_from_client(self):
"""Test that async instances have the client's machine_id."""
from replicated.resources import AsyncInstance

with patch("replicated.http_client.httpx.AsyncClient") as mock_httpx:
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {
"customer": {
"id": "customer_123",
"email": "test@example.com",
"name": "test user",
"serviceToken": "service_token_123",
"instanceId": "instance_123",
}
}

mock_client = Mock()
mock_client.request.return_value = mock_response
mock_httpx.return_value = mock_client

client = AsyncReplicatedClient(
publishable_key="pk_test_123", app_slug="my-app"
)
customer = await client.customer.get_or_create("test@example.com")
instance = await customer.get_or_create_instance()

assert isinstance(instance, AsyncInstance)
assert hasattr(instance, "_machine_id")
assert instance._machine_id == client._machine_id

@pytest.mark.asyncio
async def test_instance_uses_machine_id_in_headers(self):
"""Test that async instance methods use machine_id as cluster ID in headers."""
from unittest.mock import AsyncMock

from replicated.resources import AsyncInstance

with patch("replicated.http_client.httpx.AsyncClient") as mock_httpx:
mock_response = Mock()
mock_response.is_success = True
mock_response.json.return_value = {}

mock_client = Mock()
mock_client.request = AsyncMock(return_value=mock_response)
mock_httpx.return_value = mock_client

client = AsyncReplicatedClient(
publishable_key="pk_test_123", app_slug="my-app"
)
instance = AsyncInstance(client, "customer_123", "instance_123")

# Send a metric
await instance.send_metric("test_metric", 42)

# Verify the request was made with correct headers
call_args = mock_client.request.call_args
headers = call_args[1]["headers"]
assert "X-Replicated-ClusterID" in headers
assert headers["X-Replicated-ClusterID"] == client._machine_id