Skip to content

Commit

Permalink
fix docstrings in embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
azliu0 committed Apr 15, 2024
1 parent b9a99f7 commit a87efbf
Showing 1 changed file with 80 additions and 91 deletions.
171 changes: 80 additions & 91 deletions server/nlp/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Embeddings."""

import os
import time

Expand Down Expand Up @@ -31,17 +33,13 @@


def load_corpus(corpus: list[RedisDocument]):
"""Loads given corpus into redis
"""Loads given corpus into redis.
PARAMETERS
----------
corpus : :obj:`list` of :obj:`RedisDocument`
list of documents, each represented by dictionary
Args:
corpus: list of documents, each represented by dictionary
Raises:
------
Exception
if failed to load corpus into redis
Exception: if the corpus fails to load into redis.
"""
print("loading corpus...")

Expand All @@ -57,6 +55,14 @@ def load_corpus(corpus: list[RedisDocument]):


def compute_openai_embeddings(texts):
"""Compute embeddings from texts using OpenAI.
Args:
texts: list of texts to embed
Returns:
list of embeddings
"""
embeddings = []
for i in range(len(texts)):
embeddings.append(
Expand All @@ -68,7 +74,7 @@ def compute_openai_embeddings(texts):


def compute_embeddings():
"""Compute embeddings from redis documents"""
"""Compute embeddings from redis documents."""
print("computing embeddings...")

# get keys, questions, content
Expand Down Expand Up @@ -97,17 +103,14 @@ def compute_embeddings():


def load_embeddings(embeddings: list[list[float]]):
"""Load embeddings into redis
"""Load embeddings into redis.
PARAMETERS
----------
embeddings : :obj:`list` of :obj:`list` of :obj:`float`
list of embeddings
Args:
embeddings:
list of embeddings
Raises:
------
Exception
if failed to load embeddings into redis
Exception: if the embeddings fail to load into redis.
"""
print("loading embeddings into redis...")

Expand All @@ -125,18 +128,16 @@ def load_embeddings(embeddings: list[list[float]]):


def create_index(corpus_len: int):
"""Create search index in redis
assumes that documents and embeddings have already been loaded into redis
"""Create search index in redis.
PARAMETERS
----------
corpus_len : :obj:`int`
number of documents in corpus
Assumes that documents and embeddings have already been loaded into redis.
Args:
corpus_len:
number of documents in corpus
Raises:
------
Exception
if failed to create index
Exception: if the index cannot be created.
"""
print("creating index...")

Expand Down Expand Up @@ -178,12 +179,13 @@ def create_index(corpus_len: int):


def create_query(k: int):
"""Create k-NN redis query
"""Create k-NN redis query.
Args:
k: number of nearest neighbors to return
PARAMETERS
----------
k : :obj:`int`
number of nearest neighbors to return
Returns:
redis query object
"""
return (
Query(f"(*)=>[KNN {k} @vector $query_vector AS vector_score]")
Expand All @@ -194,18 +196,13 @@ def create_query(k: int):


def queries(query, queries: list[str]) -> list[dict]:
"""Run queries against redis
"""Run queries against redis.
PARAMETERS
----------
query : :obj:`Query`
redis query object
queries : :obj:`list` of :obj:`str`
list of question queries
Args:
query: redis query object
queries: list of question queries
Returns:
-------
:obj:`list` of :obj:`dict`
list of dictionaries containing query and result
"""
print("running queries...")
Expand Down Expand Up @@ -243,36 +240,27 @@ def queries(query, queries: list[str]) -> list[dict]:


def query_all(k: int, questions: list[str]):
"""Return k most similar documents for each query
"""Return k most similar documents for each query.
PARAMETERS
----------
k : :obj:`int`
number of nearest neighbors to return
questions : :obj:`list` of :obj:`str`
list of question queries
Args:
k: number of nearest neighbors to return
questions: list of question queries
Returns:
-------
:obj:`list` of :obj:`dict`
list of dictionaries containing query and result
"""
redis_query = create_query(k)
return queries(redis_query, questions)


def embed_corpus(corpus: list[RedisDocument]):
"""Load corpus, compute embeddings, load embeddings into redis
"""Load corpus, compute embeddings, load embeddings into redis.
PARAMETERS
----------
corpus : :obj:`list` of :obj:`dict`
list of documents, each represented by dictionary
Args:
corpus: list of documents, each represented by dictionary
Raises:
------
Exception
if failed to load corpus
Exception: if the corpus fails to load.
"""
# flush database
print("cleaning database...")
Expand All @@ -288,37 +276,38 @@ def embed_corpus(corpus: list[RedisDocument]):
create_index(len(corpus))


def test():
try:
embed_corpus()
except Exception as err:
print(f"Unexpected {err=}, {type(err)=}")
raise

questions = [
"What is the deadline to apply for the hackathon?",
"When is HackMIT?",
"What are the challenges?",
"How does judging work?",
"What building should I go to during the event?",
"What prizes are available?",
"How many people are allowed on a team?",
"What is HackMIT?",
"Can I attend HackMIT if I am an MIT grad student?",
"Can I attend HackMIT if I am a sophomore in high school?",
"I'm a high school student, but I'm really advanced. Can I attend HackMIT?",
"Do I need to bring money to the event?",
"Will we be able to sleep at the event?",
"Will we be able to stay overnight at the event?",
"What should I do if I am a beginner at the event?",
]
results = query_all(3, questions)

for result in results:
print(result["query"])
for doc in result["result"]:
print(f"Score: {doc['score']}")
print(f"Source: {doc['source']}")
print(f"Q: {doc['question']}")
print(f"A: {doc['content']}")
print()
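# TODO(azliu): turn this into a test case. Note that the embed_corpus() call
# below omits the required `corpus` argument and must be given one first.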
# def test():
# try:
# embed_corpus()
# except Exception as err:
# print(f"Unexpected {err=}, {type(err)=}")
# raise

# questions = [
# "What is the deadline to apply for the hackathon?",
# "When is HackMIT?",
# "What are the challenges?",
# "How does judging work?",
# "What building should I go to during the event?",
# "What prizes are available?",
# "How many people are allowed on a team?",
# "What is HackMIT?",
# "Can I attend HackMIT if I am an MIT grad student?",
# "Can I attend HackMIT if I am a sophomore in high school?",
# "I'm a high school student, but I'm really advanced. Can I attend HackMIT?",
# "Do I need to bring money to the event?",
# "Will we be able to sleep at the event?",
# "Will we be able to stay overnight at the event?",
# "What should I do if I am a beginner at the event?",
# ]
# results = query_all(3, questions)

# for result in results:
# print(result["query"])
# for doc in result["result"]:
# print(f"Score: {doc['score']}")
# print(f"Source: {doc['source']}")
# print(f"Q: {doc['question']}")
# print(f"A: {doc['content']}")
# print()

0 comments on commit a87efbf

Please sign in to comment.