diff --git a/casbin_redis_watcher/__init__.py b/casbin_redis_watcher/__init__.py index e55ddc1..a896a23 100644 --- a/casbin_redis_watcher/__init__.py +++ b/casbin_redis_watcher/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. from .options import WatcherOptions -from .watcher import RedisWatcher, new_watcher +from .watcher import RedisWatcher, new_watcher, new_publish_watcher diff --git a/casbin_redis_watcher/watcher.py b/casbin_redis_watcher/watcher.py index 70ccabb..1a2ea14 100644 --- a/casbin_redis_watcher/watcher.py +++ b/casbin_redis_watcher/watcher.py @@ -30,7 +30,6 @@ def __init__(self): self.options: WatcherOptions = None self.close = None self.callback: callable = None - self.ctx = None self.subscribe_thread: Thread = Thread(target=self.subscribe, daemon=True) self.subscribe_event = Event() self.logger = logging.getLogger(__name__) @@ -42,18 +41,6 @@ def init_config(self, option: WatcherOptions): self.logger.warning("No callback function is set.Use the default callback function.") self.callback = self.default_callback_func - rds = Redis(host=option.host, port=option.port, password=option.password) - - if option.sub_client: - self.sub_client = option.sub_client - else: - self.sub_client = rds.client().pubsub() - - if option.pub_client: - self.pub_client = option.pub_client - else: - self.pub_client = rds.client() - self.options = option def set_update_callback(self, callback: callable): @@ -160,12 +147,10 @@ def new_watcher(option: WatcherOptions): option.init_config() w = RedisWatcher() rds = Redis(host=option.host, port=option.port, password=option.password) + if rds.ping() is False: + raise Exception("Redis server is not available.") w.sub_client = rds.client().pubsub() w.pub_client = rds.client() - if w.sub_client.ping() is False or w.pub_client.ping() is False: - w.logger.error("Casbin Redis Watcher error: Redis connection failed.") - w.ctx = None - w.close = None w.init_config(option) w.close = False w.subscribe_thread.start() @@ -173,12 +158,13 @@ def new_watcher(option: WatcherOptions): return w -# TODO -def new_publish_watcher(addr: str, option: WatcherOptions): - option.addr = addr +def new_publish_watcher(option: WatcherOptions): + option.init_config() w = RedisWatcher() - w.pub_client = Redis().client() - w.ctx = None - w.close = None + rds = Redis(host=option.host, port=option.port, password=option.password) + if rds.ping() is False: + raise Exception("Redis server is not available.") + w.pub_client = rds.client() w.init_config(option) + w.close = False return w diff --git a/tests/test_redis_watcher.py b/tests/test_redis_watcher.py index 40cf541..ffba488 100644 --- a/tests/test_redis_watcher.py +++ b/tests/test_redis_watcher.py @@ -19,7 +19,7 @@ import casbin import redis -from casbin_redis_watcher import new_watcher, WatcherOptions +from casbin_redis_watcher import WatcherOptions, new_watcher, new_publish_watcher def get_examples(path): @@ -37,6 +37,15 @@ def test_watcher_init(self): assert isinstance(w.sub_client, redis.client.PubSub) assert isinstance(w.pub_client, redis.client.Redis) + def test_publish_watcher_init(self): + test_option = WatcherOptions() + test_option.host = "localhost" + test_option.port = "6379" + test_option.optional_update_callback = lambda event: print("update callback, event: {}".format(event)) + w = new_publish_watcher(test_option) + assert w.sub_client is None + assert isinstance(w.pub_client, redis.client.Redis) + def test_watcher_init_without_callback(self): test_option = WatcherOptions() test_option.host = "localhost"