From dfc7898a6e5e0d99e3a1125d6b20a9ef769a0369 Mon Sep 17 00:00:00 2001 From: wiseaidev Date: Sat, 6 Aug 2022 20:15:30 +0300 Subject: [PATCH] general codebase refactor Signed-off-by: wiseaidev --- aredis_om/checks.py | 4 +-- aredis_om/model/encoders.py | 6 ++-- aredis_om/model/model.py | 48 ++++++++++++++++---------------- aredis_om/model/render_tree.py | 18 ++++++------ aredis_om/unasync_util.py | 2 +- tests/test_oss_redis_features.py | 2 +- 6 files changed, 39 insertions(+), 41 deletions(-) diff --git a/aredis_om/checks.py b/aredis_om/checks.py index be2332cf..ea59629a 100644 --- a/aredis_om/checks.py +++ b/aredis_om/checks.py @@ -12,7 +12,7 @@ async def check_for_command(conn, cmd): @lru_cache(maxsize=None) async def has_redis_json(conn=None): - if conn is None: + if not conn: conn = get_redis_connection() command_exists = await check_for_command(conn, "json.set") return command_exists @@ -20,7 +20,7 @@ async def has_redis_json(conn=None): @lru_cache(maxsize=None) async def has_redisearch(conn=None): - if conn is None: + if not conn: conn = get_redis_connection() if has_redis_json(conn): return True diff --git a/aredis_om/model/encoders.py b/aredis_om/model/encoders.py index 4007640f..3d541054 100644 --- a/aredis_om/model/encoders.py +++ b/aredis_om/model/encoders.py @@ -64,9 +64,9 @@ def jsonable_encoder( custom_encoder: Dict[Any, Callable[[Any], Any]] = {}, sqlalchemy_safe: bool = True, ) -> Any: - if include is not None and not isinstance(include, (set, dict)): + if include and not isinstance(include, (set, dict)): include = set(include) - if exclude is not None and not isinstance(exclude, (set, dict)): + if exclude and not isinstance(exclude, (set, dict)): exclude = set(exclude) if isinstance(obj, BaseModel): @@ -107,7 +107,7 @@ def jsonable_encoder( or (not isinstance(key, str)) or (not key.startswith("_sa")) ) - and (value is not None or not exclude_none) + and (value or not exclude_none) and ((include and key in include) or not exclude or key not in exclude) ): encoded_key = jsonable_encoder( diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 92bb6f9a..a69101eb 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -117,7 +117,7 @@ def is_supported_container_type(typ: Optional[type]) -> bool: def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]): - for field_name in field_values.keys(): + for field_name in field_values: if "__" in field_name: obj = model for sub_field in field_name.split("__"): @@ -432,11 +432,11 @@ def validate_sort_fields(self, sort_fields: List[str]): @staticmethod def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes: - if getattr(field.field_info, "primary_key", None) is True: + if getattr(field.field_info, "primary_key", None): return RediSearchFieldTypes.TAG elif op is Operators.LIKE: fts = getattr(field.field_info, "full_text_search", None) - if fts is not True: # Could be PydanticUndefined + if not fts: # Could be PydanticUndefined raise QuerySyntaxError( f"You tried to do a full-text search on the field '{field.name}', " f"but the field is not indexed for full-text search. Use the " @@ -464,7 +464,7 @@ def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes # is not itself directly indexed, but instead, we index any fields # within the model inside the list marked as `index=True`. return RediSearchFieldTypes.TAG - elif container_type is not None: + elif container_type: raise QuerySyntaxError( "Only lists and tuples are supported for multi-value fields. " f"Docs: {ERRORS_URL}#E4" @@ -567,7 +567,7 @@ def resolve_value( # The value contains the TAG field separator. We can work # around this by breaking apart the values and unioning them # with multiple field:{} queries. - values: filter = filter(None, value.split(separator_char)) + values: List[str] = [val for val in value.split(separator_char) if val] for value in values: value = escaper.escape(value) result += f"@{field_name}:{{{value}}}" @@ -1131,7 +1131,7 @@ async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel": raise NotImplementedError async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None): - if pipeline is None: + if not pipeline: db = self.db() else: db = pipeline @@ -1195,12 +1195,12 @@ def to_string(s): step = 2 # Because the result has content offset = 1 # The first item is the count of total matches. - for i in xrange(1, len(res), step): + for i in range(1, len(res), step): fields_offset = offset fields = dict( dict( - izip( + zip( map(to_string, res[i + fields_offset][::2]), map(to_string, res[i + fields_offset][1::2]), ) @@ -1244,7 +1244,7 @@ async def add( pipeline: Optional[Pipeline] = None, pipeline_verifier: Callable[..., Any] = verify_pipeline_response, ) -> Sequence["RedisModel"]: - if pipeline is None: + if not pipeline: # By default, send commands in a pipeline. Saving each model will # be atomic, but Redis may process other commands in between # these saves. @@ -1261,7 +1261,7 @@ async def add( # If the user didn't give us a pipeline, then we need to execute # the one we just created. - if pipeline is None: + if not pipeline: result = await db.execute() pipeline_verifier(result, expected_responses=len(models)) @@ -1303,7 +1303,7 @@ def __init_subclass__(cls, **kwargs): async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel": self.check() - if pipeline is None: + if not pipeline: db = self.db() else: db = pipeline @@ -1356,7 +1356,7 @@ def _get_value(cls, *args, **kwargs) -> Any: values. Is there a better way? """ val = super()._get_value(*args, **kwargs) - if val is None: + if not val: return "" return val @@ -1392,7 +1392,7 @@ def schema_for_fields(cls): name, _type, field.field_info ) schema_parts.append(redisearch_field) - elif getattr(field.field_info, "index", None) is True: + elif getattr(field.field_info, "index", None): schema_parts.append(cls.schema_for_type(name, _type, field.field_info)) elif is_subscripted_type: # Ignore subscripted types (usually containers!) that we don't @@ -1437,7 +1437,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): elif any(issubclass(typ, t) for t in NUMERIC_TYPES): schema = f"{name} NUMERIC" elif issubclass(typ, str): - if getattr(field_info, "full_text_search", False) is True: + if getattr(field_info, "full_text_search", False): schema = ( f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " f"{name} AS {name}_fts TEXT" @@ -1455,7 +1455,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): schema = " ".join(sub_fields) else: schema = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" - if schema and sortable is True: + if schema and sortable: schema += " SORTABLE" return schema @@ -1475,7 +1475,7 @@ def __init__(self, *args, **kwargs): async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel": self.check() - if pipeline is None: + if not pipeline: db = self.db() else: db = pipeline @@ -1633,7 +1633,7 @@ def schema_for_type( parent_type=typ, ) ) - return " ".join(filter(None, sub_fields)) + return " ".join([sub_field for sub_field in sub_fields if sub_field]) # NOTE: This is the termination point for recursion. We've descended # into models and lists until we found an actual value to index. elif should_index: @@ -1655,28 +1655,28 @@ def schema_for_type( # TODO: GEO field if parent_is_container_type or parent_is_model_in_container: - if typ is not str: + if not isinstance(typ, str): raise RedisModelError( "In this Preview release, list and tuple fields can only " f"contain strings. Problem field: {name}. See docs: TODO" ) - if full_text_search is True: + if full_text_search: raise RedisModelError( "List and tuple fields cannot be indexed for full-text " f"search. Problem field: {name}. See docs: TODO" ) schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" - if sortable is True: + if sortable: raise sortable_tag_error elif any(issubclass(typ, t) for t in NUMERIC_TYPES): schema = f"{path} AS {index_field_name} NUMERIC" elif issubclass(typ, str): - if full_text_search is True: + if full_text_search: schema = ( f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " f"{path} AS {index_field_name}_fts TEXT" ) - if sortable is True: + if sortable: # NOTE: With the current preview release, making a field # full-text searchable and sortable only makes the TEXT # field sortable. This means that results for full-text @@ -1685,11 +1685,11 @@ def schema_for_type( schema += " SORTABLE" else: schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" - if sortable is True: + if sortable: raise sortable_tag_error else: schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" - if sortable is True: + if sortable: raise sortable_tag_error return schema return "" diff --git a/aredis_om/model/render_tree.py b/aredis_om/model/render_tree.py index 8ac5748d..5366e8f5 100644 --- a/aredis_om/model/render_tree.py +++ b/aredis_om/model/render_tree.py @@ -21,7 +21,7 @@ def render_tree( write to a StringIO buffer, then use that buffer to accumulate written lines during recursive calls to render_tree(). """ - if buffer is None: + if not buffer: buffer = io.StringIO() if hasattr(current_node, nameattr): name = lambda node: getattr(node, nameattr) # noqa: E731 @@ -31,11 +31,9 @@ def render_tree( up = getattr(current_node, left_child, None) down = getattr(current_node, right_child, None) - if up is not None: + if up: next_last = "up" - next_indent = "{0}{1}{2}".format( - indent, " " if "up" in last else "|", " " * len(str(name(current_node))) - ) + next_indent = f'{indent}{" " if "up" in last else "|"}{" " * len(str(name(current_node)))}' render_tree( up, nameattr, left_child, right_child, next_indent, next_last, buffer ) @@ -49,7 +47,7 @@ def render_tree( else: start_shape = "├" - if up is not None and down is not None: + if up and down: end_shape = "┤" elif up: end_shape = "┘" @@ -59,14 +57,14 @@ def render_tree( end_shape = "" print( - "{0}{1}{2}{3}".format(indent, start_shape, name(current_node), end_shape), + f"{indent}{start_shape}{name(current_node)}{end_shape}", file=buffer, ) - if down is not None: + if down: next_last = "down" - next_indent = "{0}{1}{2}".format( - indent, " " if "down" in last else "|", " " * len(str(name(current_node))) + next_indent = ( + f'{indent}{" " if "down" in last else "|"}{len(str(name(current_node)))}' ) render_tree( down, nameattr, left_child, right_child, next_indent, next_last, buffer diff --git a/aredis_om/unasync_util.py b/aredis_om/unasync_util.py index 3bea28fe..c34eed27 100644 --- a/aredis_om/unasync_util.py +++ b/aredis_om/unasync_util.py @@ -14,7 +14,7 @@ async def f(): return None obj = f() - if obj is None: + if not obj: return False else: obj.close() # prevent unawaited coroutine warning diff --git a/tests/test_oss_redis_features.py b/tests/test_oss_redis_features.py index bdad16c9..30f8c02f 100644 --- a/tests/test_oss_redis_features.py +++ b/tests/test_oss_redis_features.py @@ -80,7 +80,7 @@ async def members(m): async def test_all_keys(members, m): pks = sorted([pk async for pk in await m.Member.all_pks()]) assert len(pks) == 3 - assert pks == sorted([m.pk for m in members]) + assert pks == sorted(m.pk for m in members) @py_test_mark_asyncio