Skip to content

Commit

Permalink
fix: Adopt connection pooling for HBase (feast-dev#3793)
Browse files Browse the repository at this point in the history
  • Loading branch information
sudohainguyen committed Oct 21, 2023
1 parent 175d796 commit b3852bf
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 81 deletions.
6 changes: 5 additions & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,11 @@ def _list_feature_views(
for fv in self._registry.list_feature_views(
self.project, allow_cache=allow_cache
):
if hide_dummy_entity and fv.entities[0] == DUMMY_ENTITY_NAME:
if (
hide_dummy_entity
and fv.entities
and fv.entities[0] == DUMMY_ENTITY_NAME
):
fv.entities = []
fv.entity_columns = []
feature_views.append(fv)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

from happybase import Connection
from happybase import ConnectionPool
from happybase.connection import DEFAULT_PROTOCOL, DEFAULT_TRANSPORT
from pydantic import StrictStr
from pydantic.typing import Literal

from feast import Entity
from feast.feature_view import FeatureView
from feast.infra.online_stores.helpers import compute_entity_id
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.hbase_utils import HbaseConstants, HbaseUtils
from feast.infra.utils.hbase_utils import HBaseConnector, HbaseConstants
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
Expand All @@ -23,35 +25,20 @@ class HbaseOnlineStoreConfig(FeastConfigBaseModel):
type: Literal["hbase"] = "hbase"
"""Online store type selector"""

host: str
host: StrictStr
"""Hostname of Hbase Thrift server"""

port: str
port: StrictStr
"""Port in which Hbase Thrift server is running"""

connection_pool_size: int = 4
"""Number of connections to Hbase Thrift server to keep in the connection pool"""

class HbaseConnection:
"""
Hbase connecttion to connect to hbase.
Attributes:
store_config: Online store config for Hbase store.
"""
protocol: StrictStr = DEFAULT_PROTOCOL
"""Protocol used to communicate with Hbase Thrift server"""

def __init__(self, store_config: HbaseOnlineStoreConfig):
self._store_config = store_config
self._real_conn = Connection(
host=store_config.host, port=int(store_config.port)
)

@property
def real_conn(self) -> Connection:
"""Stores the real happybase Connection to connect to hbase."""
return self._real_conn

def close(self) -> None:
"""Close the happybase connection."""
self.real_conn.close()
transport: StrictStr = DEFAULT_TRANSPORT
"""Transport used to communicate with Hbase Thrift server"""


class HbaseOnlineStore(OnlineStore):
Expand All @@ -62,7 +49,7 @@ class HbaseOnlineStore(OnlineStore):
_conn: Happybase Connection to connect to hbase thrift server.
"""

_conn: Connection = None
_conn: ConnectionPool = None

def _get_conn(self, config: RepoConfig):
"""
Expand All @@ -76,7 +63,13 @@ def _get_conn(self, config: RepoConfig):
assert isinstance(store_config, HbaseOnlineStoreConfig)

if not self._conn:
self._conn = Connection(host=store_config.host, port=int(store_config.port))
self._conn = ConnectionPool(
host=store_config.host,
port=int(store_config.port),
size=int(store_config.connection_pool_size),
protocol=store_config.protocol,
transport=store_config.transport,
)
return self._conn

@log_exceptions_and_usage(online_store="hbase")
Expand All @@ -102,7 +95,7 @@ def online_write_batch(
the online store. Can be used to display progress.
"""

hbase = HbaseUtils(self._get_conn(config))
hbase = HBaseConnector(self._get_conn(config))
project = config.project
table_name = self._table_id(project, table)

Expand Down Expand Up @@ -154,7 +147,7 @@ def online_read(
entity_keys: a list of entity keys that should be read from the FeatureStore.
requested_features: a list of requested feature names.
"""
hbase = HbaseUtils(self._get_conn(config))
hbase = HBaseConnector(self._get_conn(config))
project = config.project
table_name = self._table_id(project, table)

Expand Down Expand Up @@ -206,7 +199,7 @@ def update(
tables_to_delete: Tables to delete from the Hbase Online Store.
tables_to_keep: Tables to keep in the Hbase Online Store.
"""
hbase = HbaseUtils(self._get_conn(config))
hbase = HBaseConnector(self._get_conn(config))
project = config.project

# We don't create any special state for the entites in this implementation.
Expand All @@ -232,7 +225,7 @@ def teardown(
config: The RepoConfig for the current FeatureStore.
tables: Tables to delete from the feature repo.
"""
hbase = HbaseUtils(self._get_conn(config))
hbase = HBaseConnector(self._get_conn(config))
project = config.project

for table in tables:
Expand Down
127 changes: 78 additions & 49 deletions sdk/python/feast/infra/utils/hbase_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from typing import List

from happybase import Connection

from feast.infra.key_encoding_utils import serialize_entity_key
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
from happybase import ConnectionPool


class HbaseConstants:
Expand All @@ -28,7 +25,7 @@ def get_col_from_feature(feature):
return HbaseConstants.DEFAULT_COLUMN_FAMILY + ":" + feature


class HbaseUtils:
class HBaseConnector:
"""
Utils class to manage different Hbase operations.
Expand All @@ -40,14 +37,22 @@ class HbaseUtils:
"""

def __init__(
self, conn: Connection = None, host: str = None, port: int = None, timeout=None
self,
pool: ConnectionPool = None,
host: str = None,
port: int = None,
connection_pool_size: int = 4,
):
if conn is None:
if pool is None:
self.host = host
self.port = port
self.conn = Connection(host=host, port=port, timeout=timeout)
self.pool = ConnectionPool(
host=host,
port=port,
size=connection_pool_size,
)
else:
self.conn = conn
self.pool = pool

def create_table(self, table_name: str, colm_family: List[str]):
"""
Expand All @@ -60,7 +65,9 @@ def create_table(self, table_name: str, colm_family: List[str]):
cf_dict: dict = {}
for cf in colm_family:
cf_dict[cf] = dict()
return self.conn.create_table(table_name, cf_dict)

with self.pool.connection() as conn:
return conn.create_table(table_name, cf_dict)

def create_table_with_default_cf(self, table_name: str):
"""
Expand All @@ -69,7 +76,8 @@ def create_table_with_default_cf(self, table_name: str):
Arguments:
table_name: Name of the Hbase table.
"""
return self.conn.create_table(table_name, {"default": dict()})
with self.pool.connection() as conn:
return conn.create_table(table_name, {"default": dict()})

def check_if_table_exist(self, table_name: str):
"""
Expand All @@ -78,16 +86,18 @@ def check_if_table_exist(self, table_name: str):
Arguments:
table_name: Name of the Hbase table.
"""
return bytes(table_name, "utf-8") in self.conn.tables()
with self.pool.connection() as conn:
return bytes(table_name, "utf-8") in conn.tables()

def batch(self, table_name: str):
"""
Returns a 'Batch' instance that can be used for mass data manipulation in the hbase table.
Returns a "Batch" instance that can be used for mass data manipulation in the hbase table.
Arguments:
table_name: Name of the Hbase table.
"""
return self.conn.table(table_name).batch()
with self.pool.connection() as conn:
return conn.table(table_name).batch()

def put(self, table_name: str, row_key: str, data: dict):
"""
Expand All @@ -98,8 +108,9 @@ def put(self, table_name: str, row_key: str, data: dict):
row_key: Row key of the row to be inserted to hbase table.
data: Mapping of column family name:column name to column values
"""
table = self.conn.table(table_name)
table.put(row_key, data)
with self.pool.connection() as conn:
table = conn.table(table_name)
table.put(row_key, data)

def row(
self,
Expand All @@ -119,8 +130,9 @@ def row(
timestamp: timestamp specifies the maximum version the cells can have.
include_timestamp: specifies if (column, timestamp) to be return instead of only column.
"""
table = self.conn.table(table_name)
return table.row(row_key, columns, timestamp, include_timestamp)
with self.pool.connection() as conn:
table = conn.table(table_name)
return table.row(row_key, columns, timestamp, include_timestamp)

def rows(
self,
Expand All @@ -140,52 +152,69 @@ def rows(
timestamp: timestamp specifies the maximum version the cells can have.
include_timestamp: specifies if (column, timestamp) to be return instead of only column.
"""
table = self.conn.table(table_name)
return table.rows(row_keys, columns, timestamp, include_timestamp)
with self.pool.connection() as conn:
table = conn.table(table_name)
return table.rows(row_keys, columns, timestamp, include_timestamp)

def print_table(self, table_name):
"""Prints the table scanning all the rows of the hbase table."""
table = self.conn.table(table_name)
scan_data = table.scan()
for row_key, cols in scan_data:
print(row_key.decode("utf-8"), cols)
with self.pool.connection() as conn:
table = conn.table(table_name)
scan_data = table.scan()
for row_key, cols in scan_data:
print(row_key.decode("utf-8"), cols)

def delete_table(self, table: str):
"""Deletes the hbase table given the table name."""
if self.check_if_table_exist(table):
self.conn.delete_table(table, disable=True)
with self.pool.connection() as conn:
conn.delete_table(table, disable=True)

def close_conn(self):
"""Closes the happybase connection."""
self.conn.close()
with self.pool.connection() as conn:
conn.close()


def main():
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
from feast.protos.feast.types.Value_pb2 import Value

connection = Connection(host="localhost", port=9090)
table = connection.table("test_hbase_driver_hourly_stats")
row_keys = [
serialize_entity_key(
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1004)]),
entity_key_serialization_version=2,
).hex(),
serialize_entity_key(
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1005)]),
entity_key_serialization_version=2,
).hex(),
serialize_entity_key(
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1024)]),
entity_key_serialization_version=2,
).hex(),
]
rows = table.rows(row_keys)

for row_key, row in rows:
for key, value in row.items():
col_name = bytes.decode(key, "utf-8").split(":")[1]
print(col_name, value)
print()
pool = ConnectionPool(
host="localhost",
port=9090,
size=2,
)
with pool.connection() as connection:
table = connection.table("test_hbase_driver_hourly_stats")
row_keys = [
serialize_entity_key(
EntityKey(
join_keys=["driver_id"], entity_values=[Value(int64_val=1004)]
),
entity_key_serialization_version=2,
).hex(),
serialize_entity_key(
EntityKey(
join_keys=["driver_id"], entity_values=[Value(int64_val=1005)]
),
entity_key_serialization_version=2,
).hex(),
serialize_entity_key(
EntityKey(
join_keys=["driver_id"], entity_values=[Value(int64_val=1024)]
),
entity_key_serialization_version=2,
).hex(),
]
rows = table.rows(row_keys)

for _, row in rows:
for key, value in row.items():
col_name = bytes.decode(key, "utf-8").split(":")[1]
print(col_name, value)
print()


if __name__ == "__main__":
Expand Down

0 comments on commit b3852bf

Please sign in to comment.