Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions mcbackend/backends/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import time
from datetime import datetime, timezone
from typing import Dict, Optional, Sequence, Tuple
from typing import Callable, Dict, Optional, Sequence, Tuple

import clickhouse_driver
import numpy
Expand Down Expand Up @@ -215,9 +215,14 @@ class ClickHouseRun(Run):
"""Represents an MCMC run stored in ClickHouse."""

def __init__(
self, meta: RunMeta, *, created_at: datetime = None, client: clickhouse_driver.Client
self,
meta: RunMeta,
*,
created_at: datetime = None,
client_fn: Callable[[], clickhouse_driver.Client],
) -> None:
self._client = client
self._client_fn = client_fn
self._client = client_fn()
if created_at is None:
created_at = datetime.now().astimezone(timezone.utc)
self.created_at = created_at
Expand All @@ -229,7 +234,7 @@ def __init__(
def init_chain(self, chain_number: int) -> ClickHouseChain:
cmeta = ChainMeta(self.meta.rid, chain_number)
create_chain_table(self._client, cmeta, self.meta)
chain = ClickHouseChain(cmeta, self.meta, client=self._client)
chain = ClickHouseChain(cmeta, self.meta, client=self._client_fn())
if self._chains is None:
self._chains = []
self._chains.append(chain)
Expand All @@ -245,16 +250,39 @@ def get_chains(self) -> Tuple[ClickHouseChain]:
chains = []
for (cid,) in self._client.execute(f"SHOW TABLES LIKE '{self.meta.rid}%'"):
cm = ChainMeta(self.meta.rid, int(cid.split("_")[-1]))
chains.append(ClickHouseChain(cm, self.meta, client=self._client))
chains.append(ClickHouseChain(cm, self.meta, client=self._client_fn()))
return tuple(chains)


class ClickHouseBackend(Backend):
"""A backend to store samples in a ClickHouse database."""

def __init__(self, client: clickhouse_driver.Client) -> None:
def __init__(
self,
client: clickhouse_driver.Client = None,
client_fn: Callable[[], clickhouse_driver.Client] = None,
):
"""Create a ClickHouse backend around a database client.

Parameters
----------
client : clickhouse_driver.Client
One client to use for all runs and chains.
client_fn : callable
A function to create database clients.
Use this in multithreading scenarios to get higher insert performance.
"""
if client is None and client_fn is None:
raise ValueError("Either a `client` or a `client_fn` must be provided.")
self._client_fn = client_fn
self._client = client
create_runs_table(client)

if client_fn is None:
self._client_fn = lambda: client
if client is None:
self._client = self._client_fn()

create_runs_table(self._client)
super().__init__()

def init_run(self, meta: RunMeta) -> ClickHouseRun:
Expand All @@ -271,7 +299,7 @@ def init_run(self, meta: RunMeta) -> ClickHouseRun:
proto=base64.encodebytes(bytes(meta)).decode("ascii"),
)
self._client.execute(query, [params])
return ClickHouseRun(meta, client=self._client, created_at=created_at)
return ClickHouseRun(meta, client_fn=self._client_fn, created_at=created_at)

def get_runs(self) -> pandas.DataFrame:
df = self._client.query_dataframe(
Expand All @@ -295,5 +323,5 @@ def get_run(self, rid: str) -> ClickHouseRun:
data = base64.decodebytes(rows[0][2].encode("ascii"))
meta = RunMeta().parse(data)
return ClickHouseRun(
meta, client=self._client, created_at=rows[0][1].replace(tzinfo=timezone.utc)
meta, client_fn=self._client_fn, created_at=rows[0][1].replace(tzinfo=timezone.utc)
)
62 changes: 61 additions & 1 deletion mcbackend/test_backend_clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from subprocess import call
from typing import Sequence, Tuple

import clickhouse_driver
Expand Down Expand Up @@ -37,6 +38,63 @@ def fully_initialized(
return run, chains


@pytest.mark.skipif(
    condition=not HAS_REAL_DB,
    reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.",
)
class TestClickHouseBackendInitialization:
    """Checks the supported ways of constructing a ``ClickHouseBackend``.

    This is separate because ``TestClickHouseBackend.setup_method`` depends on these things.
    """

    @pytest.fixture
    def temp_db(self):
        """Create a uniquely named database and drop it again after the test.

        Yields
        ------
        str
            Name of the temporary database.
        """
        db = "testing_" + hagelkorn.random()
        client_main = clickhouse_driver.Client("localhost")
        client_main.execute(f"CREATE DATABASE {db};")
        try:
            yield db
        finally:
            # Runs even when the test body failed, so no database is left behind.
            client_main.execute(f"DROP DATABASE {db};")
            client_main.disconnect()

    def test_exceptions(self):
        """Passing neither ``client`` nor ``client_fn`` must raise a ``ValueError``."""
        with pytest.raises(ValueError, match="must be provided"):
            ClickHouseBackend()

    def test_backend_from_client_object(self, temp_db):
        """A backend built from a client object hands that one client to every chain."""
        # When created from a client object, all chains share the client
        backend = ClickHouseBackend(
            client=clickhouse_driver.Client("localhost", database=temp_db)
        )
        assert callable(backend._client_fn)
        run = backend.init_run(make_runmeta())
        c1 = run.init_chain(0)
        c2 = run.init_chain(1)
        assert c1._client is c2._client

    def test_backend_from_client_function(self, temp_db):
        """A backend built from a client factory gives each chain its own client."""

        def client_fn():
            return clickhouse_driver.Client("localhost", database=temp_db)

        # When created from a client function, each chain has its own client
        backend = ClickHouseBackend(client_fn=client_fn)
        assert backend._client is not None
        run = backend.init_run(make_runmeta())
        c1 = run.init_chain(0)
        c2 = run.init_chain(1)
        assert c1._client is not c2._client

        # By passing both, one may use different settings
        bclient = client_fn()
        backend = ClickHouseBackend(client=bclient, client_fn=client_fn)
        assert backend._client is bclient


@pytest.mark.skipif(
condition=not HAS_REAL_DB,
reason="Integration tests need a ClickHouse server on localhost:9000 without authentication.",
Expand All @@ -52,7 +110,9 @@ def setup_method(self, method):
self._client_main = clickhouse_driver.Client("localhost")
self._client_main.execute(f"CREATE DATABASE {self._db};")
self._client = clickhouse_driver.Client("localhost", database=self._db)
self.backend = ClickHouseBackend(self._client)
self.backend = ClickHouseBackend(
client_fn=lambda: clickhouse_driver.Client("localhost", database=self._db)
)
return

def teardown_method(self, method):
Expand Down