diff --git a/README.md b/README.md index f5f36b1..36370f9 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,9 @@ import numpy as np from np.flight import Client # Initialize the Flight client -client = Client('grpc://localhost:8815') +with Client('grpc://localhost:8815') as client: +... + ``` ### Sending Data diff --git a/src/np/flight/numpy_client.py b/src/np/flight/numpy_client.py index 438e79f..6f2bb4d 100644 --- a/src/np/flight/numpy_client.py +++ b/src/np/flight/numpy_client.py @@ -26,7 +26,23 @@ def __init__(self, location, **kwargs) -> None: Args: client: An initialized Flight client for handling network communication. """ - self._client = fl.connect(location, **kwargs) + self._location = location + self._kwargs = kwargs + + def __enter__(self): + """ + Open the database connection + """ + self._client = fl.connect(self._location, **self._kwargs) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Close the connection + """ + self.flight.close() + if exc_val: # pragma: no cover + raise @property def flight(self): @@ -134,9 +150,3 @@ def compute(self, command: str, data: Dict[str, np.ndarray]) -> Dict[str, np.nda # Retrieve and convert results back to NumPy arrays return pa_2_np(self.get(command)) - - def close(self): - """ - Close the Flight server. - """ - self.flight.close() diff --git a/src/tests/test_client.py b/src/tests/test_client.py index e9c1e3a..b0ccf17 100644 --- a/src/tests/test_client.py +++ b/src/tests/test_client.py @@ -28,12 +28,11 @@ def numpy_client(mocker, flight_client): kwargs = {"key": "value"} # Act - numpy_client = Client(location, **kwargs) - - # Assert - mock_fl_connect.assert_called_once_with(location, **kwargs) - assert numpy_client.flight == flight_client - return numpy_client + with Client(location, **kwargs) as numpy_client: + # Assert + mock_fl_connect.assert_called_once_with(location, **kwargs) + assert numpy_client.flight == flight_client + yield numpy_client def test_numpy_client_init(numpy_client, flight_client):