Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions redisvl/cli/index.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import argparse
import sys
from tabulate import tabulate
from argparse import Namespace

from tabulate import tabulate

from redisvl.cli.log import get_logger
from redisvl.cli.utils import create_redis_url, add_index_parsing_options
from redisvl.cli.utils import add_index_parsing_options, create_redis_url
from redisvl.index import SearchIndex
from redisvl.utils.connection import get_redis_connection
from redisvl.utils.utils import convert_bytes, make_dict
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self):
"--format",
help="Output format for info command",
type=str,
default="rounded_outline"
default="rounded_outline",
)
parser = add_index_parsing_options(parser)

Expand Down Expand Up @@ -126,6 +126,7 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex:

return index


def _display_in_table(index_info, output_format="rounded_outline"):
print("\n")
attributes = index_info.get("attributes", [])
Expand Down Expand Up @@ -183,4 +184,4 @@ def _display_in_table(index_info, output_format="rounded_outline"):
headers=headers,
tablefmt=output_format,
)
)
)
4 changes: 1 addition & 3 deletions redisvl/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

from redisvl.cli.index import Index
from redisvl.cli.log import get_logger
from redisvl.cli.version import Version
from redisvl.cli.stats import Stats

from redisvl.cli.version import Version

logger = get_logger(__name__)

Expand Down Expand Up @@ -50,4 +49,3 @@ def version(self):
def stats(self):
Stats()
exit(0)

15 changes: 6 additions & 9 deletions redisvl/cli/stats.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import argparse
import sys
from tabulate import tabulate
from argparse import Namespace

from redisvl.cli.utils import create_redis_url, add_index_parsing_options
from tabulate import tabulate

from redisvl.cli.log import get_logger
from redisvl.cli.utils import add_index_parsing_options, create_redis_url
from redisvl.index import SearchIndex
from redisvl.utils.connection import get_redis_connection

from redisvl.cli.log import get_logger
logger = get_logger("[RedisVL]")

STATS_KEYS = [
Expand All @@ -32,6 +33,7 @@
"vector_index_sz_mb",
]


class Stats:
usage = "\n".join(
[
Expand All @@ -43,11 +45,7 @@ def __init__(self):
parser = argparse.ArgumentParser(usage=self.usage)

parser.add_argument(
"-f",
"--format",
help="Output format",
type=str,
default="rounded_outline"
"-f", "--format", help="Output format", type=str, default="rounded_outline"
)
parser = add_index_parsing_options(parser)
args = parser.parse_args(sys.argv[2:])
Expand All @@ -57,7 +55,6 @@ def __init__(self):
logger.error(e)
exit(0)


def stats(self, args: Namespace):
"""Obtain stats about an index

Expand Down
15 changes: 5 additions & 10 deletions redisvl/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from argparse import Namespace, ArgumentParser
from argparse import ArgumentParser, Namespace


def create_redis_url(args: Namespace) -> str:
Expand All @@ -18,20 +18,15 @@ def create_redis_url(args: Namespace) -> str:
url += args.host + ":" + str(args.port)
return url


def add_index_parsing_options(parser: ArgumentParser) -> ArgumentParser:
parser.add_argument(
"-i", "--index", help="Index name", type=str, required=False
)
parser.add_argument("-i", "--index", help="Index name", type=str, required=False)
parser.add_argument(
"-s", "--schema", help="Path to schema file", type=str, required=False
)
parser.add_argument("--host", help="Redis host", type=str, default="localhost")
parser.add_argument("-p", "--port", help="Redis port", type=int, default=6379)
parser.add_argument(
"--user", help="Redis username", type=str, default="default"
)
parser.add_argument("--user", help="Redis username", type=str, default="default")
parser.add_argument("--ssl", help="Use SSL", action="store_true")
parser.add_argument(
"-a", "--password", help="Redis password", type=str, default=""
)
parser.add_argument("-a", "--password", help="Redis password", type=str, default="")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we will eventually need to add TLS support here in the form of a path to the server CA, and client/user certs (public and private) too. With the bedrock launch coming, it would be great to push users here for an easier onboarding :)

return parser
7 changes: 5 additions & 2 deletions redisvl/query/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from redisvl.query.query import VectorQuery
from redisvl.query.query import FilterQuery, VectorQuery

__all__ = ["VectorQuery"]
__all__ = [
"VectorQuery",
"FilterQuery",
]
7 changes: 6 additions & 1 deletion redisvl/query/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class Geo(FilterField):
field in a Redis index.

"""

OPERATORS = {
FilterOperator.EQ: "==",
FilterOperator.NE: "!=",
Expand Down Expand Up @@ -174,12 +175,13 @@ def __init__(self, longitude: float, latitude: float, unit: str = "km"):

class GeoRadius(GeoSpec):
"""A GeoRadius is a GeoSpec representing a geographic radius"""

def __init__(
self,
longitude: float,
latitude: float,
radius: Optional[int] = 1,
unit: Optional[str] = "km"
unit: Optional[str] = "km",
):
"""Create a GeoRadius specification (GeoSpec)

Expand All @@ -202,6 +204,7 @@ def get_args(self) -> List[Union[float, int, str]]:

class Num(FilterField):
"""A Num is a FilterField representing a numeric field in a Redis index."""

OPERATORS = {
FilterOperator.EQ: "==",
FilterOperator.NE: "!=",
Expand Down Expand Up @@ -311,6 +314,7 @@ def __le__(self, other: str) -> "FilterExpression":

class Text(FilterField):
"""A Text is a FilterField representing a text field in a Redis index."""

OPERATORS = {
FilterOperator.EQ: "==",
FilterOperator.NE: "!=",
Expand Down Expand Up @@ -399,6 +403,7 @@ class FilterExpression:
... filter_expression=filter,
... )
"""

def __init__(
self,
_filter: str = None,
Expand Down
82 changes: 82 additions & 0 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,88 @@ def params(self) -> Dict[str, Any]:
pass


class FilterQuery(BaseQuery):
def __init__(
self,
return_fields: List[str],
filter_expression: FilterExpression,
num_results: Optional[int] = 10,
params: Optional[Dict[str, Any]] = None,
):
"""Query for a filter expression.

Args:
return_fields (List[str]): The fields to return.
filter_expression (FilterExpression): The filter expression to query for.
num_results (Optional[int], optional): The number of results to return. Defaults to 10.
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.

Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression

Examples:
>>> from redisvl.query import FilterQuery
>>> from redisvl.query.filter import Tag
>>> t = Tag("brand") == "Nike"
>>> q = FilterQuery(return_fields=["brand", "price"], filter_expression=t)
"""

super().__init__(return_fields, num_results)
self.set_filter(filter_expression)
self._params = params

def __str__(self) -> str:
return " ".join([str(x) for x in self.query.get_args()])

def set_filter(self, filter_expression: FilterExpression):
"""Set the filter for the query.

Args:
filter_expression (FilterExpression): The filter to apply to the query.

Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
"""
if not isinstance(filter_expression, FilterExpression):
raise TypeError(
"filter_expression must be of type redisvl.query.FilterExpression"
)
self._filter = str(filter_expression)

def get_filter(self) -> FilterExpression:
"""Get the filter for the query.

Returns:
FilterExpression: The filter for the query.
"""
return self._filter

@property
def query(self) -> Query:
"""Return a Redis-Py Query object representing the query.

Returns:
redis.commands.search.query.Query: The query object.
"""
base_query = str(self._filter)
query = (
Query(base_query)
.return_fields(*self._return_fields)
.paging(0, self._num_results)
.dialect(2)
)
return query

@property
def params(self) -> Dict[str, Any]:
"""Return the parameters for the query.

Returns:
Dict[str, Any]: The parameters for the query.
"""
return self._params


class VectorQuery(BaseQuery):
dtypes = {
"float32": np.float32,
Expand Down
2 changes: 1 addition & 1 deletion redisvl/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,4 @@ def read_schema(file_path: str):
with open(fp, "r") as f:
schema = yaml.safe_load(f)

return SchemaModel(**schema)
return SchemaModel(**schema)
6 changes: 3 additions & 3 deletions redisvl/vectorize/text/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def embed_many(
TypeError: If the wrong input type is passed in for the test.
"""
if not isinstance(texts, list):
raise TypeError("Must pass in a list of str values to embed.")
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")
raise TypeError("Must pass in a list of str values to embed.")
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
Expand Down
18 changes: 9 additions & 9 deletions redisvl/vectorize/text/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class OpenAITextVectorizer(BaseVectorizer):
API key to be passed in the api_config dictionary. The API key can be obtained from
https://api.openai.com/.
"""

def __init__(self, model: str, api_config: Optional[Dict] = None):
"""Initialize the OpenAI vectorizer.

Expand Down Expand Up @@ -45,14 +46,13 @@ def __init__(self, model: str, api_config: Optional[Dict] = None):
def _set_model_dims(self) -> int:
try:
embedding = self._model_client.create(
input=["dimension test"],
engine=self._model
input=["dimension test"], engine=self._model
)["data"][0]["embedding"]
except (KeyError, IndexError) as ke:
raise ValueError(f"Unexpected response from the OpenAI API: {str(ke)}")
except openai.error.AuthenticationError as ae:
raise ValueError(f"Error authenticating with the OpenAI API: {str(ae)}")
except Exception as e: # pylint: disable=broad-except
except Exception as e: # pylint: disable=broad-except
# fall back (TODO get more specific)
raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
return len(embedding)
Expand Down Expand Up @@ -87,9 +87,9 @@ def embed_many(
TypeError: If the wrong input type is passed in for the test.
"""
if not isinstance(texts, list):
raise TypeError("Must pass in a list of str values to embed.")
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")
raise TypeError("Must pass in a list of str values to embed.")
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
Expand Down Expand Up @@ -164,9 +164,9 @@ async def aembed_many(
TypeError: If the wrong input type is passed in for the test.
"""
if not isinstance(texts, list):
raise TypeError("Must pass in a list of str values to embed.")
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")
raise TypeError("Must pass in a list of str values to embed.")
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
Expand Down
Loading