update save to index.
shibing624 committed Apr 16, 2023
1 parent d2405d9 commit e666de1
Showing 16 changed files with 199 additions and 335 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -1,7 +1,6 @@
[![PyPI version](https://badge.fury.io/py/similarities.svg)](https://badge.fury.io/py/similarities)
[![Downloads](https://pepy.tech/badge/similarities)](https://pepy.tech/project/similarities)
[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![GitHub contributors](https://img.shields.io/github/contributors/shibing624/similarities.svg)](https://github.com/shibing624/similarities/graphs/contributors)
[![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
[![python_version](https://img.shields.io/badge/Python-3.5%2B-green.svg)](requirements.txt)
[![GitHub issues](https://img.shields.io/github/issues/shibing624/similarities.svg)](https://github.com/shibing624/similarities/issues)
4 changes: 2 additions & 2 deletions examples/base_demo.py
@@ -37,8 +37,8 @@
model.add_corpus(corpus)
res = model.most_similar(queries=sentences, topn=3)
print(res)
for q_id, c in res.items():
for q_id, id_score_dict in res.items():
print('query:', sentences[q_id])
print("search top 3:")
for corpus_id, s in c.items():
for corpus_id, s in id_score_dict.items():
print(f'\t{model.corpus[corpus_id]}: {s:.4f}')
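The rename above (c → id_score_dict) spells out what most_similar returns: a mapping from query index to an inner {corpus_id: score} dict. A minimal sketch of consuming that structure with the default Similarity model; the corpus and query strings here are illustrative, not taken from the demo:

from similarities import Similarity

model = Similarity()
model.add_corpus(["如何更换花呗绑定银行卡", "花呗更改绑定银行卡"])
res = model.most_similar(queries=["花呗如何绑卡"], topn=2)
# res looks like {0: {corpus_id: score, ...}} when queries is a list
for q_id, id_score_dict in res.items():
    for corpus_id, score in id_score_dict.items():
        print(q_id, model.corpus[corpus_id], f"{score:.4f}")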
2 changes: 2 additions & 0 deletions examples/base_english_demo.py
@@ -41,6 +41,8 @@
]

model.add_corpus(corpus)
model.save_index('en_corpus_emb.json')
model.load_index('en_corpus_emb.json')
res = model.most_similar(queries=sentences1, topn=3)
print(res)
for q_id, c in res.items():
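The two added lines give the English demo a persistence round trip: save_index writes the corpus and its embeddings to a JSON file, and load_index reads them back so a later run can skip re-encoding. Sketched in isolation below; the behaviour is inferred from this diff and the fastsim.py changes further down, and the corpus/query strings are illustrative:

from similarities import Similarity

model = Similarity()
model.add_corpus(["A man is eating food.", "A man is playing a guitar."])  # encode once
model.save_index('en_corpus_emb.json')      # persist corpus texts + embeddings as JSON

restored = Similarity()                     # fresh model, no corpus yet
restored.load_index('en_corpus_emb.json')   # restore without re-encoding
print(restored.most_similar("A man is eating pasta.", topn=2))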
Original file line number Diff line number Diff line change
@@ -5,15 +5,14 @@
"""
import datetime
import os
import pathlib
import random
import sys

from loguru import logger

sys.path.append('../..')
from similarities import BM25Similarity
from similarities.utils import http_get
from similarities.utils.get_file import http_get
from similarities.data_loader import SearchDataLoader
from similarities.evaluation import evaluate

Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@

sys.path.append('../..')
from similarities import Similarity
from similarities.utils import http_get
from similarities.utils.get_file import http_get
from similarities.data_loader import SearchDataLoader
from similarities.evaluation import evaluate

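Both evaluation scripts now import http_get from its concrete module, similarities.utils.get_file, instead of relying on a re-export from the similarities.utils package. Call sites are unchanged; a minimal sketch, in which the URL, the target path, and the exact http_get signature are assumptions rather than details taken from this diff:

from similarities.utils.get_file import http_get

# fetch an evaluation dataset to a local file (assumed (url, save_path) signature)
http_get("https://example.com/eval/corpus.jsonl.gz", "data/corpus.jsonl.gz")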
28 changes: 18 additions & 10 deletions examples/fast_sim_demo.py
@@ -23,17 +23,20 @@
]


def hnswlib_demo():
def annoy_demo():
corpus_new = [i + str(id) for id, i in enumerate(corpus * 10)]
print(corpus_new)
model = HnswlibSimilarity(corpus=corpus_new)
model = AnnoySimilarity(corpus=corpus_new)
print(model)
similarity_score = model.similarity(sentences[0], sentences[1])
print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}")
model.add_corpus(corpus)
model.build_index()
model.save_index('test.model')
model.save_index('annoy_model.bin')
print(model.most_similar("men喜欢这首歌"))
# Semantic Search batch
del model
model = AnnoySimilarity()
model.load_index('annoy_model.bin')
print(model.most_similar("men喜欢这首歌"))
queries = ["如何更换花呗绑定银行卡", "men喜欢这首歌"]
res = model.most_similar(queries, topn=3)
@@ -44,20 +47,25 @@ def hnswlib_demo():
for corpus_id, s in c.items():
print(f'\t{model.corpus[corpus_id]}: {s:.4f}')

os.remove('test.model')
# os.remove('annoy_model.bin')
print('-' * 50 + '\n')


def annoy_demo():
def hnswlib_demo():
corpus_new = [i + str(id) for id, i in enumerate(corpus * 10)]
model = AnnoySimilarity(corpus=corpus_new)
print(corpus_new)
model = HnswlibSimilarity(corpus=corpus_new)
print(model)
similarity_score = model.similarity(sentences[0], sentences[1])
print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}")
model.add_corpus(corpus)
model.build_index()
model.save_index('test.model')
model.save_index('hnsw_model.bin')
print(model.most_similar("men喜欢这首歌"))
# Semantic Search batch
del model
model = HnswlibSimilarity()
model.load_index('hnsw_model.bin')
print(model.most_similar("men喜欢这首歌"))
queries = ["如何更换花呗绑定银行卡", "men喜欢这首歌"]
res = model.most_similar(queries, topn=3)
@@ -68,10 +76,10 @@ def annoy_demo():
for corpus_id, s in c.items():
print(f'\t{model.corpus[corpus_id]}: {s:.4f}')

os.remove('test.model')
# os.remove('hnsw_model.bin')
print('-' * 50 + '\n')


if __name__ == '__main__':
hnswlib_demo()
annoy_demo()
hnswlib_demo()
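Both rewritten demos now walk through a full persist/reload cycle: build the index, save it, delete the in-memory model, reload from disk, and query again, with the os.remove cleanup commented out so the saved files survive the run. The core pattern for the Annoy backend, as a sketch (file name follows the demo; the .bin.json sidecar of corpus embeddings is described in the fastsim.py diff below, and the corpus strings are illustrative):

from similarities import AnnoySimilarity

corpus = ["如何更换花呗绑定银行卡", "花呗更改绑定银行卡", "我什么时候开通了花呗"]
model = AnnoySimilarity(corpus=corpus)   # encodes the corpus and builds the Annoy index
model.save_index('annoy_model.bin')      # also writes annoy_model.bin.json with the embeddings

del model                                # simulate a fresh process
model = AnnoySimilarity()
model.load_index('annoy_model.bin')      # restores the index and the corpus embeddings
print(model.most_similar("men喜欢这首歌", topn=3))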
6 changes: 1 addition & 5 deletions setup.py
Expand Up @@ -4,7 +4,7 @@
from setuptools import setup, find_packages

# Avoids IDE errors, but actual version is read from version.py
__version__ = None
__version__ = ""
exec(open('similarities/version.py').read())

if sys.version_info < (3,):
@@ -33,10 +33,6 @@
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
keywords='similarities,Chinese Text Similarity Calculation Tool,similarity,word2vec',
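In setup.py the placeholder changes from None to an empty string, presumably so __version__ already holds a str before the exec call overwrites it with the real value from similarities/version.py, and the per-minor-version classifiers are collapsed into the bare "Programming Language :: Python :: 3" entry. The read-the-version-without-importing-the-package pattern, shown in isolation:

# setup.py fragment: obtain the version without importing the package itself
__version__ = ""                                # placeholder, overwritten by the exec below
exec(open('similarities/version.py').read())    # version.py defines __version__ = "x.y.z"
print(__version__)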
92 changes: 57 additions & 35 deletions similarities/fastsim.py
@@ -19,12 +19,11 @@ def __init__(
self,
corpus: Union[List[str], Dict[str, str]] = None,
model_name_or_path="shibing624/text2vec-base-chinese",
embedding_size: int = 768,
n_trees: int = 256
):
super().__init__(corpus, model_name_or_path)
self.index = None
self.embedding_size = embedding_size
self.embedding_size = self.get_sentence_embedding_dimension()
self.n_trees = n_trees
if corpus is not None and self.corpus_embeddings:
self.build_index()
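Note that the constructor no longer takes an embedding_size argument: the dimension is read from the loaded encoder via get_sentence_embedding_dimension(), so the ANN index can no longer be created with a size that disagrees with the model; HnswlibSimilarity below gets the same treatment. A quick way to see the derived value (model name is the library default, and 768 is this encoder's known dimension):

from similarities import AnnoySimilarity

model = AnnoySimilarity(model_name_or_path="shibing624/text2vec-base-chinese")
print(model.embedding_size)   # 768, taken from the encoder rather than passed in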
@@ -35,47 +34,58 @@ def __str__(self):
base += f", corpus size: {len(self.corpus)}"
return base

def build_index(self):
"""Build Annoy index after add new documents."""
# Create Annoy Index
def create_index(self):
"""Create Annoy Index."""
try:
from annoy import AnnoyIndex
except ImportError:
raise ImportError("Annoy is not installed. Please install it first, e.g. with `pip install annoy`.")

# Creating the annoy index
self.index = AnnoyIndex(self.embedding_size, 'angular')
logger.debug(f"Init Annoy index, embedding_size: {self.embedding_size}")

logger.info(f"Init Annoy index, embedding_size: {self.embedding_size}")
def build_index(self):
"""Build Annoy index after add new documents."""
self.create_index()
logger.debug(f"Building index with {self.n_trees} trees.")

for i in range(len(self.corpus_embeddings)):
self.index.add_item(i, self.corpus_embeddings[i])
self.index.build(self.n_trees)

def save_index(self, index_path: str):
def save_index(self, index_path: str = "annoy_index.bin"):
"""Save the annoy index to disk."""
if self.index and index_path:
logger.info(f"Saving index to: {index_path}")
if index_path:
if self.index is None:
self.build_index()
self.index.save(index_path)
corpus_emb_json_path = index_path + ".json"
super().save_index(corpus_emb_json_path)
logger.info(f"Saving Annoy index to: {index_path}, corpus embedding to: {corpus_emb_json_path}")
else:
logger.warning("No index path given. Index not saved.")

def load_index(self, index_path: str):
def load_index(self, index_path: str = "annoy_index.bin"):
"""Load Annoy Index from disc."""
if index_path and os.path.exists(index_path):
logger.info(f"Loading index from: {index_path}")
corpus_emb_json_path = index_path + ".json"
logger.info(f"Loading index from: {index_path}, corpus embedding from: {corpus_emb_json_path}")
super().load_index(corpus_emb_json_path)
if self.index is None:
self.create_index()
self.index.load(index_path)
else:
logger.warning("No index path given. Index not loaded.")

def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10):
def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10,
score_function: str = "cos_sim"):
"""Find the topn most similar texts to the query against the corpus."""
result = {}
if self.corpus_embeddings and self.index is None:
logger.warning(f"No index found. Please add corpus and build index first, e.g. with `build_index()`."
f"Now returning slow search result.")
return super().most_similar(queries, topn)
return super().most_similar(queries, topn, score_function=score_function)
if not self.corpus_embeddings:
logger.error("No corpus_embeddings found. Please add corpus first, e.g. with `add_corpus()`.")
return result
@@ -91,7 +101,7 @@ def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int
corpus_ids, distances = self.index.get_nns_by_vector(queries_embeddings[idx], topn, include_distances=True)
for corpus_id, distance in zip(corpus_ids, distances):
score = 1 - (distance ** 2) / 2
result[qid][self.corpus_ids_map[corpus_id]] = score
result[qid][corpus_id] = score

return result
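Two details in the most_similar hunk above are worth spelling out. First, result keys are now the raw Annoy item ids rather than values remapped through corpus_ids_map, matching the sequential ids used when build_index adds items. Second, score = 1 - (distance ** 2) / 2 converts Annoy's 'angular' distance back into cosine similarity: for unit-length vectors Annoy reports d = sqrt(2·(1 − cos θ)), so cos θ = 1 − d²/2. A tiny self-check of that identity with made-up unit vectors:

import numpy as np

a = np.array([0.6, 0.8])                 # unit-length toy vectors
b = np.array([1.0, 0.0])
cos = float(a @ b)
angular = np.sqrt(2 * (1 - cos))         # what the 'angular' metric reports
recovered = 1 - angular ** 2 / 2         # the conversion used in most_similar
assert abs(recovered - cos) < 1e-9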

@@ -106,10 +116,10 @@ def __init__(
self,
corpus: Union[List[str], Dict[str, str]] = None,
model_name_or_path="shibing624/text2vec-base-chinese",
embedding_size: int = 768, ef_construction: int = 400, M: int = 64, ef: int = 50
ef_construction: int = 400, M: int = 64, ef: int = 50
):
super().__init__(corpus, model_name_or_path)
self.embedding_size = embedding_size
self.embedding_size = self.get_sentence_embedding_dimension()
self.ef_construction = ef_construction
self.M = M
self.ef = ef
@@ -123,52 +133,64 @@ def __str__(self):
base += f", corpus size: {len(self.corpus)}"
return base

def build_index(self):
"""Build Hnswlib index after add new documents."""
# Create hnswlib Index
def create_index(self):
"""Create Hnswlib Index."""
try:
import hnswlib
except ImportError:
raise ImportError("Hnswlib is not installed. Please install it first, e.g. with `pip install hnswlib`.")

# We use Inner Product (dot-product) as Index. We will normalize our vectors to unit length,
# then is Inner Product equal to cosine similarity
# Creating the hnswlib index
self.index = hnswlib.Index(space='cosine', dim=self.embedding_size)
self.index.init_index(max_elements=len(self.corpus_embeddings), ef_construction=self.ef_construction, M=self.M)
# Controlling the recall by setting ef:
self.index.set_ef(self.ef) # ef should always be > top_k_hits
logger.debug(f"Init Hnswlib index, embedding_size: {self.embedding_size}")

def build_index(self):
"""Build Hnswlib index after add new documents."""
# Init the HNSWLIB index
logger.info(f"Creating HNSWLIB index, max_elements: {len(self.corpus)}")
self.create_index()
logger.info(f"Building HNSWLIB index, max_elements: {len(self.corpus)}")
logger.debug(f"Parameters Required: M: {self.M}")
logger.debug(f"Parameters Required: ef_construction: {self.ef_construction}")
logger.debug(f"Parameters Required: ef(>topn): {self.ef}")

self.index.init_index(max_elements=len(self.corpus_embeddings), ef_construction=self.ef_construction, M=self.M)
# Then we train the index to find a suitable clustering
self.index.add_items(self.corpus_embeddings, list(range(len(self.corpus_embeddings))))
# Controlling the recall by setting ef:
self.index.set_ef(self.ef) # ef should always be > top_k_hits

def save_index(self, index_path: str):
"""Save the annoy index to disk."""
if self.index and index_path:
logger.info(f"Saving index to: {index_path}")
def save_index(self, index_path: str = "hnswlib_index.bin"):
"""Save the index to disk."""
if index_path:
if self.index is None:
self.build_index()
self.index.save_index(index_path)
corpus_emb_json_path = index_path + ".json"
super().save_index(corpus_emb_json_path)
logger.info(f"Saving hnswlib index to: {index_path}, corpus embedding to: {corpus_emb_json_path}")
else:
logger.warning("No index path given. Index not saved.")

def load_index(self, index_path: str):
"""Load Annoy Index from disc."""
def load_index(self, index_path: str = "hnswlib_index.bin"):
"""Load Index from disc."""
if index_path and os.path.exists(index_path):
logger.info(f"Loading index from: {index_path}")
corpus_emb_json_path = index_path + ".json"
logger.info(f"Loading index from: {index_path}, corpus embedding from: {corpus_emb_json_path}")
super().load_index(corpus_emb_json_path)
if self.index is None:
self.create_index()
self.index.load_index(index_path)
else:
logger.warning("No index path given. Index not loaded.")

def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10):
def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10,
score_function: str = "cos_sim"):
"""Find the topn most similar texts to the query against the corpus."""
result = {}
if self.corpus_embeddings and self.index is None:
logger.warning(f"No index found. Please add corpus and build index first, e.g. with `build_index()`."
f"Now returning slow search result.")
return super().most_similar(queries, topn)
return super().most_similar(queries, topn, score_function=score_function)
if not self.corpus_embeddings:
logger.error("No corpus_embeddings found. Please add corpus first, e.g. with `add_corpus()`.")
return result
@@ -186,6 +208,6 @@ def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int
hits = [{'corpus_id': id, 'score': 1 - distance} for id, distance in zip(corpus_ids[i], distances[i])]
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
for hit in hits:
result[qid][self.corpus_ids_map[hit['corpus_id']]] = hit['score']
result[qid][hit['corpus_id']] = hit['score']

return result
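For the hnswlib backend the score conversion is simpler: an index created with space='cosine' returns distance = 1 − cosine similarity, so the 1 - distance in the hit list is already the cosine score. A hedged end-to-end sketch of the new save/load behaviour, mirroring the Annoy example earlier (corpus and query strings illustrative; the .bin.json sidecar path follows save_index above):

from similarities import HnswlibSimilarity

model = HnswlibSimilarity(corpus=["如何更换花呗绑定银行卡", "花呗更改绑定银行卡"])
model.save_index("hnswlib_index.bin")     # writes hnswlib_index.bin plus hnswlib_index.bin.json

restored = HnswlibSimilarity()
restored.load_index("hnswlib_index.bin")  # reloads corpus embeddings, then the HNSW graph
print(restored.most_similar("花呗如何绑卡", topn=2))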