diff --git a/mcbackend/backends/clickhouse.py b/mcbackend/backends/clickhouse.py index 0beb223..de13a78 100644 --- a/mcbackend/backends/clickhouse.py +++ b/mcbackend/backends/clickhouse.py @@ -5,7 +5,7 @@ import logging import time from datetime import datetime, timezone -from typing import Dict, Optional, Sequence, Tuple +from typing import Callable, Dict, Optional, Sequence, Tuple import clickhouse_driver import numpy @@ -215,9 +215,14 @@ class ClickHouseRun(Run): """Represents an MCMC run stored in ClickHouse.""" def __init__( - self, meta: RunMeta, *, created_at: datetime = None, client: clickhouse_driver.Client + self, + meta: RunMeta, + *, + created_at: datetime = None, + client_fn: Callable[[], clickhouse_driver.Client], ) -> None: - self._client = client + self._client_fn = client_fn + self._client = client_fn() if created_at is None: created_at = datetime.now().astimezone(timezone.utc) self.created_at = created_at @@ -229,7 +234,7 @@ def __init__( def init_chain(self, chain_number: int) -> ClickHouseChain: cmeta = ChainMeta(self.meta.rid, chain_number) create_chain_table(self._client, cmeta, self.meta) - chain = ClickHouseChain(cmeta, self.meta, client=self._client) + chain = ClickHouseChain(cmeta, self.meta, client=self._client_fn()) if self._chains is None: self._chains = [] self._chains.append(chain) @@ -245,16 +250,39 @@ def get_chains(self) -> Tuple[ClickHouseChain]: chains = [] for (cid,) in self._client.execute(f"SHOW TABLES LIKE '{self.meta.rid}%'"): cm = ChainMeta(self.meta.rid, int(cid.split("_")[-1])) - chains.append(ClickHouseChain(cm, self.meta, client=self._client)) + chains.append(ClickHouseChain(cm, self.meta, client=self._client_fn())) return tuple(chains) class ClickHouseBackend(Backend): """A backend to store samples in a ClickHouse database.""" - def __init__(self, client: clickhouse_driver.Client) -> None: + def __init__( + self, + client: clickhouse_driver.Client = None, + client_fn: Callable[[], clickhouse_driver.Client] = None, + ): + """Create a ClickHouse backend around a database client. + + Parameters + ---------- + client : clickhouse_driver.Client + One client to use for all runs and chains. + client_fn : callable + A function to create database clients. + Use this in multithreading scenarios to get higher insert performance. + """ + if client is None and client_fn is None: + raise ValueError("Either a `client` or a `client_fn` must be provided.") + self._client_fn = client_fn self._client = client - create_runs_table(client) + + if client_fn is None: + self._client_fn = lambda: client + if client is None: + self._client = self._client_fn() + + create_runs_table(self._client) super().__init__() def init_run(self, meta: RunMeta) -> ClickHouseRun: @@ -271,7 +299,7 @@ def init_run(self, meta: RunMeta) -> ClickHouseRun: proto=base64.encodebytes(bytes(meta)).decode("ascii"), ) self._client.execute(query, [params]) - return ClickHouseRun(meta, client=self._client, created_at=created_at) + return ClickHouseRun(meta, client_fn=self._client_fn, created_at=created_at) def get_runs(self) -> pandas.DataFrame: df = self._client.query_dataframe( @@ -295,5 +323,5 @@ def get_run(self, rid: str) -> ClickHouseRun: data = base64.decodebytes(rows[0][2].encode("ascii")) meta = RunMeta().parse(data) return ClickHouseRun( - meta, client=self._client, created_at=rows[0][1].replace(tzinfo=timezone.utc) + meta, client_fn=self._client_fn, created_at=rows[0][1].replace(tzinfo=timezone.utc) ) diff --git a/mcbackend/test_backend_clickhouse.py b/mcbackend/test_backend_clickhouse.py index 0adccf9..ff85160 100644 --- a/mcbackend/test_backend_clickhouse.py +++ b/mcbackend/test_backend_clickhouse.py @@ -1,4 +1,5 @@ import logging +from subprocess import call from typing import Sequence, Tuple import clickhouse_driver @@ -37,6 +38,63 @@ def fully_initialized( return run, chains +@pytest.mark.skipif( + condition=not HAS_REAL_DB, + reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.", +) +class TestClickHouseBackendInitialization: + """This is separate because ``TestClickHouseBackend.setup_method`` depends on these things.""" + + def test_exceptions(self): + with pytest.raises(ValueError, match="must be provided"): + ClickHouseBackend() + pass + + def test_backend_from_client_object(self): + db = "testing_" + hagelkorn.random() + _client_main = clickhouse_driver.Client("localhost") + _client_main.execute(f"CREATE DATABASE {db};") + + try: + # When created from a client object, all chains share the client + backend = ClickHouseBackend(client=clickhouse_driver.Client("localhost", database=db)) + assert callable(backend._client_fn) + run = backend.init_run(make_runmeta()) + c1 = run.init_chain(0) + c2 = run.init_chain(1) + assert c1._client is c2._client + finally: + _client_main.execute(f"DROP DATABASE {db};") + _client_main.disconnect() + pass + + def test_backend_from_client_function(self): + db = "testing_" + hagelkorn.random() + _client_main = clickhouse_driver.Client("localhost") + _client_main.execute(f"CREATE DATABASE {db};") + + def client_fn(): + return clickhouse_driver.Client("localhost", database=db) + + try: + # When created from a client function, each chain has its own client + backend = ClickHouseBackend(client_fn=client_fn) + assert backend._client is not None + run = backend.init_run(make_runmeta()) + c1 = run.init_chain(0) + c2 = run.init_chain(1) + assert c1._client is not c2._client + + # By passing both, one may use different settings + bclient = client_fn() + backend = ClickHouseBackend(client=bclient, client_fn=client_fn) + assert backend._client is bclient + finally: + _client_main.execute(f"DROP DATABASE {db};") + _client_main.disconnect() + pass + + @pytest.mark.skipif( condition=not HAS_REAL_DB, reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.", @@ -52,7 +110,9 @@ def setup_method(self, method): self._client_main = clickhouse_driver.Client("localhost") self._client_main.execute(f"CREATE DATABASE {self._db};") self._client = clickhouse_driver.Client("localhost", database=self._db) - self.backend = ClickHouseBackend(self._client) + self.backend = ClickHouseBackend( + client_fn=lambda: clickhouse_driver.Client("localhost", database=self._db) + ) return def teardown_method(self, method):