Skip to content

Commit

Permalink
general codebase refactor
Browse files Browse the repository at this point in the history
Signed-off-by: wiseaidev <business@wiseai.dev>
  • Loading branch information
wiseaidev committed Aug 7, 2022
1 parent dcd84e0 commit dfc7898
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 41 deletions.
4 changes: 2 additions & 2 deletions aredis_om/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ async def check_for_command(conn, cmd):

@lru_cache(maxsize=None)
async def has_redis_json(conn=None):
if conn is None:
if not conn:
conn = get_redis_connection()
command_exists = await check_for_command(conn, "json.set")
return command_exists


@lru_cache(maxsize=None)
async def has_redisearch(conn=None):
if conn is None:
if not conn:
conn = get_redis_connection()
if has_redis_json(conn):
return True
Expand Down
6 changes: 3 additions & 3 deletions aredis_om/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def jsonable_encoder(
custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
sqlalchemy_safe: bool = True,
) -> Any:
if include is not None and not isinstance(include, (set, dict)):
if include and not isinstance(include, (set, dict)):
include = set(include)
if exclude is not None and not isinstance(exclude, (set, dict)):
if exclude and not isinstance(exclude, (set, dict)):
exclude = set(exclude)

if isinstance(obj, BaseModel):
Expand Down Expand Up @@ -107,7 +107,7 @@ def jsonable_encoder(
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and (value or not exclude_none)
and ((include and key in include) or not exclude or key not in exclude)
):
encoded_key = jsonable_encoder(
Expand Down
48 changes: 24 additions & 24 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def is_supported_container_type(typ: Optional[type]) -> bool:


def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]):
for field_name in field_values.keys():
for field_name in field_values:
if "__" in field_name:
obj = model
for sub_field in field_name.split("__"):
Expand Down Expand Up @@ -432,11 +432,11 @@ def validate_sort_fields(self, sort_fields: List[str]):

@staticmethod
def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes:
if getattr(field.field_info, "primary_key", None) is True:
if getattr(field.field_info, "primary_key", None):
return RediSearchFieldTypes.TAG
elif op is Operators.LIKE:
fts = getattr(field.field_info, "full_text_search", None)
if fts is not True: # Could be PydanticUndefined
if not fts: # Could be PydanticUndefined
raise QuerySyntaxError(
f"You tried to do a full-text search on the field '{field.name}', "
f"but the field is not indexed for full-text search. Use the "
Expand Down Expand Up @@ -464,7 +464,7 @@ def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes
# is not itself directly indexed, but instead, we index any fields
# within the model inside the list marked as `index=True`.
return RediSearchFieldTypes.TAG
elif container_type is not None:
elif container_type:
raise QuerySyntaxError(
"Only lists and tuples are supported for multi-value fields. "
f"Docs: {ERRORS_URL}#E4"
Expand Down Expand Up @@ -567,7 +567,7 @@ def resolve_value(
# The value contains the TAG field separator. We can work
# around this by breaking apart the values and unioning them
# with multiple field:{} queries.
values: filter = filter(None, value.split(separator_char))
values: List[str] = [val for val in value.split(separator_char) if val]
for value in values:
value = escaper.escape(value)
result += f"@{field_name}:{{{value}}}"
Expand Down Expand Up @@ -1131,7 +1131,7 @@ async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
raise NotImplementedError

async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
if pipeline is None:
if not pipeline:
db = self.db()
else:
db = pipeline
Expand Down Expand Up @@ -1195,12 +1195,12 @@ def to_string(s):
step = 2 # Because the result has content
offset = 1 # The first item is the count of total matches.

for i in xrange(1, len(res), step):
for i in range(1, len(res), step):
fields_offset = offset

fields = dict(
dict(
izip(
zip(
map(to_string, res[i + fields_offset][::2]),
map(to_string, res[i + fields_offset][1::2]),
)
Expand Down Expand Up @@ -1244,7 +1244,7 @@ async def add(
pipeline: Optional[Pipeline] = None,
pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
) -> Sequence["RedisModel"]:
if pipeline is None:
if not pipeline:
# By default, send commands in a pipeline. Saving each model will
# be atomic, but Redis may process other commands in between
# these saves.
Expand All @@ -1261,7 +1261,7 @@ async def add(

# If the user didn't give us a pipeline, then we need to execute
# the one we just created.
if pipeline is None:
if not pipeline:
result = await db.execute()
pipeline_verifier(result, expected_responses=len(models))

Expand Down Expand Up @@ -1303,7 +1303,7 @@ def __init_subclass__(cls, **kwargs):

async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
self.check()
if pipeline is None:
if not pipeline:
db = self.db()
else:
db = pipeline
Expand Down Expand Up @@ -1356,7 +1356,7 @@ def _get_value(cls, *args, **kwargs) -> Any:
values. Is there a better way?
"""
val = super()._get_value(*args, **kwargs)
if val is None:
if not val:
return ""
return val

Expand Down Expand Up @@ -1392,7 +1392,7 @@ def schema_for_fields(cls):
name, _type, field.field_info
)
schema_parts.append(redisearch_field)
elif getattr(field.field_info, "index", None) is True:
elif getattr(field.field_info, "index", None):
schema_parts.append(cls.schema_for_type(name, _type, field.field_info))
elif is_subscripted_type:
# Ignore subscripted types (usually containers!) that we don't
Expand Down Expand Up @@ -1437,7 +1437,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
schema = f"{name} NUMERIC"
elif issubclass(typ, str):
if getattr(field_info, "full_text_search", False) is True:
if getattr(field_info, "full_text_search", False):
schema = (
f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} "
f"{name} AS {name}_fts TEXT"
Expand All @@ -1455,7 +1455,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
schema = " ".join(sub_fields)
else:
schema = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if schema and sortable is True:
if schema and sortable:
schema += " SORTABLE"
return schema

Expand All @@ -1475,7 +1475,7 @@ def __init__(self, *args, **kwargs):

async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
self.check()
if pipeline is None:
if not pipeline:
db = self.db()
else:
db = pipeline
Expand Down Expand Up @@ -1633,7 +1633,7 @@ def schema_for_type(
parent_type=typ,
)
)
return " ".join(filter(None, sub_fields))
return " ".join([sub_field for sub_field in sub_fields if sub_field])
# NOTE: This is the termination point for recursion. We've descended
# into models and lists until we found an actual value to index.
elif should_index:
Expand All @@ -1655,28 +1655,28 @@ def schema_for_type(

# TODO: GEO field
if parent_is_container_type or parent_is_model_in_container:
if typ is not str:
if not isinstance(typ, str):
raise RedisModelError(
"In this Preview release, list and tuple fields can only "
f"contain strings. Problem field: {name}. See docs: TODO"
)
if full_text_search is True:
if full_text_search:
raise RedisModelError(
"List and tuple fields cannot be indexed for full-text "
f"search. Problem field: {name}. See docs: TODO"
)
schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if sortable is True:
if sortable:
raise sortable_tag_error
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
schema = f"{path} AS {index_field_name} NUMERIC"
elif issubclass(typ, str):
if full_text_search is True:
if full_text_search:
schema = (
f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} "
f"{path} AS {index_field_name}_fts TEXT"
)
if sortable is True:
if sortable:
# NOTE: With the current preview release, making a field
# full-text searchable and sortable only makes the TEXT
# field sortable. This means that results for full-text
Expand All @@ -1685,11 +1685,11 @@ def schema_for_type(
schema += " SORTABLE"
else:
schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if sortable is True:
if sortable:
raise sortable_tag_error
else:
schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
if sortable is True:
if sortable:
raise sortable_tag_error
return schema
return ""
Expand Down
18 changes: 8 additions & 10 deletions aredis_om/model/render_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def render_tree(
write to a StringIO buffer, then use that buffer to accumulate written lines
during recursive calls to render_tree().
"""
if buffer is None:
if not buffer:
buffer = io.StringIO()
if hasattr(current_node, nameattr):
name = lambda node: getattr(node, nameattr) # noqa: E731
Expand All @@ -31,11 +31,9 @@ def render_tree(
up = getattr(current_node, left_child, None)
down = getattr(current_node, right_child, None)

if up is not None:
if up:
next_last = "up"
next_indent = "{0}{1}{2}".format(
indent, " " if "up" in last else "|", " " * len(str(name(current_node)))
)
next_indent = f'{indent}{" " if "up" in last else "|"}{" " * len(str(name(current_node)))}'
render_tree(
up, nameattr, left_child, right_child, next_indent, next_last, buffer
)
Expand All @@ -49,7 +47,7 @@ def render_tree(
else:
start_shape = "├"

if up is not None and down is not None:
if up and down:
end_shape = "┤"
elif up:
end_shape = "┘"
Expand All @@ -59,14 +57,14 @@ def render_tree(
end_shape = ""

print(
"{0}{1}{2}{3}".format(indent, start_shape, name(current_node), end_shape),
f"{indent}{start_shape}{name(current_node)}{end_shape}",
file=buffer,
)

if down is not None:
if down:
next_last = "down"
next_indent = "{0}{1}{2}".format(
indent, " " if "down" in last else "|", " " * len(str(name(current_node)))
next_indent = (
f'{indent}{" " if "down" in last else "|"}{len(str(name(current_node)))}'
)
render_tree(
down, nameattr, left_child, right_child, next_indent, next_last, buffer
Expand Down
2 changes: 1 addition & 1 deletion aredis_om/unasync_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def f():
return None

obj = f()
if obj is None:
if not obj:
return False
else:
obj.close() # prevent unawaited coroutine warning
Expand Down
2 changes: 1 addition & 1 deletion tests/test_oss_redis_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def members(m):
async def test_all_keys(members, m):
pks = sorted([pk async for pk in await m.Member.all_pks()])
assert len(pks) == 3
assert pks == sorted([m.pk for m in members])
assert pks == sorted(m.pk for m in members)


@py_test_mark_asyncio
Expand Down

0 comments on commit dfc7898

Please sign in to comment.