Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ import numpy as np
from np.flight import Client

# Initialize the Flight client
client = Client('grpc://localhost:8815')
with Client('grpc://localhost:8815') as client:
...

```

### Sending Data
Expand Down
24 changes: 17 additions & 7 deletions src/np/flight/numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,23 @@ def __init__(self, location, **kwargs) -> None:
Args:
client: An initialized Flight client for handling network communication.
"""
self._client = fl.connect(location, **kwargs)
self._location = location
self._kwargs = kwargs

def __enter__(self):
"""
Open the database connection
"""
self._client = fl.connect(self._location, **self._kwargs)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""
Close the connection
"""
self.flight.close()
if exc_val: # pragma: no cover
raise

@property
def flight(self):
Expand Down Expand Up @@ -134,9 +150,3 @@ def compute(self, command: str, data: Dict[str, np.ndarray]) -> Dict[str, np.nda

# Retrieve and convert results back to NumPy arrays
return pa_2_np(self.get(command))

def close(self):
"""
Close the Flight server.
"""
self.flight.close()
11 changes: 5 additions & 6 deletions src/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@ def numpy_client(mocker, flight_client):
kwargs = {"key": "value"}

# Act
numpy_client = Client(location, **kwargs)

# Assert
mock_fl_connect.assert_called_once_with(location, **kwargs)
assert numpy_client.flight == flight_client
return numpy_client
with Client(location, **kwargs) as numpy_client:
# Assert
mock_fl_connect.assert_called_once_with(location, **kwargs)
assert numpy_client.flight == flight_client
yield numpy_client


def test_numpy_client_init(numpy_client, flight_client):
Expand Down