In [1]:
from pymilvus import MilvusClient, DataType,db,connections, AnnSearchRequest
import os
import pandas as pd

In [2]:
model = None


def get_embeddings(queries):
    """Encode ``queries`` with a lazily created, module-cached BGE-M3 model.

    The encoder is built on the first call and stored in the module-level
    ``model`` so that later calls reuse it.

    Args:
        queries: Texts handed unchanged to ``model.encode``.

    Returns:
        dict with two entries:
        ``"dense_vectors"`` - the encoder's ``dense_vecs`` output, and
        ``"sparse_vectors"`` - the encoder's ``lexical_weights`` output.
    """
    global model
    # Kept as a function-local import, exactly like the original cell.
    from FlagEmbedding import BGEM3FlagModel

    if model is None:
        model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
        print("Model loaded.")

    encoded = model.encode(
        queries,
        return_dense=True,
        return_sparse=True,
        return_colbert_vecs=False,
    )
    return {
        "dense_vectors": encoded['dense_vecs'],
        "sparse_vectors": encoded['lexical_weights'],
    }

In [3]:
# Connection settings for the Milvus instance used by every cell below.
# CLUSTER_DOMAIN="host.docker.internal"
CLUSTER_DOMAIN = "localhost"
PORT = 19530
POART = PORT  # backward-compatible alias for the original (misspelled) constant name
CLUSTER_ENDPOINT = f"http://{CLUSTER_DOMAIN}:{PORT}"
DATABASE_NAME = "HSBC"
COLLECTION_NAME = "banks_earnings_calls"
DEFAULT_EMBEDDING_MODEL_NAME = 'BAAI/bge-m3'

# Credentials come from the environment so they need not live in the notebook.
# The defaults are the values that used to be hardcoded here, so behaviour is
# unchanged when the variables are not set.
MILVUS_USER = os.environ.get("MILVUS_USER", "developers")
MILVUS_PASSWORD = os.environ.get("MILVUS_PASSWORD", "developers")

# The ORM-style connection backs the `db` helper calls below.
# NOTE(review): this connection is opened without user/password while the
# MilvusClient below authenticates - confirm that is intended.
# NOTE(review): `conn` is kept only because the original cell bound it; confirm
# whether connect() returns anything useful before relying on it.
conn = connections.connect(host=CLUSTER_DOMAIN, port=PORT)

# Create the target database on first run.
if DATABASE_NAME not in db.list_database():
    db.create_database(DATABASE_NAME)

client = MilvusClient(
    db_name=DATABASE_NAME,
    uri=CLUSTER_ENDPOINT,
    user=MILVUS_USER,
    password=MILVUS_PASSWORD,
)

In [4]:
# --- Schema -----------------------------------------------------------------
# Auto-generated INT64 primary key; dynamic fields are enabled, so keys that
# are not declared below can still be stored alongside each record.
schema = MilvusClient.create_schema(
    auto_id=True,
    enable_dynamic_field=True,
)
schema.add_field(
    field_name="id",
    datatype=DataType.INT64,
    is_primary=True,
)
schema.add_field(
    field_name="dense_vector",
    datatype=DataType.FLOAT_VECTOR,
    dim=1024,
)
schema.add_field(
    field_name="hash",
    datatype=DataType.VARCHAR,
    max_length=32,
)
schema.add_field(
    field_name="bank",
    datatype=DataType.VARCHAR,
    max_length=65535,
)
# "content" is intentionally not declared as a typed field:
# schema.add_field(field_name="content", datatype=DataType.JSON)
schema.add_field(
    field_name="approximate_tokens_size",
    datatype=DataType.INT32,
)

# --- Indexes ----------------------------------------------------------------
# Scalar indexes on id/hash/bank plus a FLAT inner-product index on the vector.
index_params = client.prepare_index_params()
index_params.add_index(
    field_name="id",
    index_name="id_index",
    index_type="STL_SORT",
)
index_params.add_index(
    field_name="hash",
    index_name="hash_index",
    index_type="Trie",
)
index_params.add_index(
    field_name="bank",
    index_name="bank_index",
    index_type="Trie",
)
index_params.add_index(
    field_name="dense_vector",
    index_type="FLAT",
    metric_type="IP",
)

# --- (Re)create the collection ----------------------------------------------
# WARNING: an existing collection with this name is dropped, data included.
if client.has_collection(collection_name=COLLECTION_NAME):
    client.drop_collection(collection_name=COLLECTION_NAME)
client.create_collection(
    collection_name=COLLECTION_NAME,
    schema=schema,
    index_params=index_params,
    consistency_level="Strong",
)

In [5]:
# Load the JSON-lines dataset whose rows are embedded and inserted below
# (the ingestion reads the "chunk", "summary" and "full_summary" columns).
df_path = os.path.join("dataset", "entities_transcripts.jsonl")
df = pd.read_json(path_or_buf=df_path, lines=True)
def insert_into_milvus(frame=None):
    """Embed every row of the dataset and insert two records per row.

    For each row two records are written to ``COLLECTION_NAME``; both carry all
    of the row's columns plus ``content`` (the row's ``chunk``), and differ only
    in ``dense_vector``:

    1. the embedding of ``full_summary`` with ``summary`` replaced by ``chunk``;
    2. the embedding of ``chunk`` on its own.

    The previous version computed the second embedding but never stored it, so
    the second insert duplicated the first record exactly.

    Args:
        frame: Optional DataFrame to ingest. Defaults to the module-level
            ``df`` so existing zero-argument calls behave as before.
    """
    rows = df if frame is None else frame
    for idx, row in rows.iterrows():
        content = row["chunk"]
        summary = row["summary"]
        full_summary = row["full_summary"]

        # Fields shared by both records for this row.
        base_record = row.to_dict()
        base_record["content"] = content

        # Record 1: the chunk embedded in the context of the full summary.
        contextual_text = full_summary.replace(summary, content)
        contextual_vector = get_embeddings([contextual_text])['dense_vectors'][0].tolist()
        contextual_record = dict(base_record, dense_vector=contextual_vector)
        client.insert(collection_name=COLLECTION_NAME, data=[contextual_record])

        # Record 2: the chunk embedded on its own. A separate dict is used so
        # this vector is actually stored instead of re-sending record 1.
        content_vector = get_embeddings([content])['dense_vectors'][0].tolist()
        content_record = dict(base_record, dense_vector=content_vector)
        client.insert(collection_name=COLLECTION_NAME, data=[content_record])

        # Disabled idea kept from the original cell: also insert one record per
        # entry of row["questions"], with the question as "content" and its own
        # embedding as "dense_vector".

        if idx % 100 == 0:
            print(f"Inserted {idx} rows")
# Run the ingestion over the module-level `df`. The first embedding call also
# loads the model (see "Model loaded." in this cell's output).
insert_into_milvus()

  from .autonotebook import tqdm as notebook_tqdm
Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 14367.34it/s]
  colbert_state_dict = torch.load(os.path.join(model_dir, 'colbert_linear.pt'), map_location='cpu')
  sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')


Model loaded.
Inserted 0 rows
Inserted 100 rows


In [6]:
import hashlib
# Load the JSON-lines file of hardcoded question/cypher pairs.
# NOTE: this rebinds the module-level `df` that the transcript cell also uses.
df_path = os.path.join("dataset", "hardcode_cyphers_questions.jsonl")
df = pd.read_json(path_or_buf=df_path, lines=True)
def insert_hardcoded_questions():
    """Insert one record per row of the module-level ``df`` of question/cypher pairs.

    Each record keeps all of the row's columns and adds:
    ``dense_vector`` (embedding of the question), ``content`` (the question),
    ``cypher``, ``hash`` (MD5 hex digest of the cypher's string form),
    ``bank`` (empty string) and ``approximate_tokens_size`` (-1).
    """
    for idx, row in df.iterrows():
        question = row["question"]
        question_vector = get_embeddings([question])['dense_vectors'][0].tolist()
        cypher = row["cypher"]

        record = {
            **row.to_dict(),
            "dense_vector": question_vector,
            "content": question,
            "cypher": cypher,
            "hash": hashlib.md5(str(cypher).encode()).hexdigest(),
            # These rows have no bank or token count; placeholders are stored.
            "bank": "",
            "approximate_tokens_size": -1,
        }
        client.insert(collection_name=COLLECTION_NAME, data=[record])

        if idx % 100 == 0:
            print(f"Inserted {idx} rows")
# insert_hardcoded_questions()

In [7]:
import json

# Dense-only similarity search for a single question, top 100 hits.
embeddings = get_embeddings(["How does HSBC perform so far?"])
dense_vectors = embeddings['dense_vectors'].tolist()
dense_search_params = dict(metric_type="IP")  # same metric as the vector index
res = client.search(
    collection_name=COLLECTION_NAME,
    data=dense_vectors,
    search_params=dense_search_params,
    output_fields=["id", "hash", "bank", "content", "chunk"],
    limit=100,
)

# Serialised copy of the hits, plus a JSON dump on disk.
result = json.dumps(res)
pd.DataFrame(res).to_json(
    "search_result.json",
    orient="records",
    force_ascii=False,
)

In [8]:
import json

# Dense-only similarity search for a single question, top 10 hits; "cypher" is
# requested as well so hardcoded question records can return their query.
embeddings = get_embeddings(["Who went to multiple banks' earnings calls event?  Give me a short and simple answer."])
dense_vectors = embeddings['dense_vectors'].tolist()
dense_search_params = dict(metric_type="IP")  # same metric as the vector index
res = client.search(
    collection_name=COLLECTION_NAME,
    data=dense_vectors,
    search_params=dense_search_params,
    output_fields=["id", "hash", "bank", "content", "chunk", "cypher"],
    limit=10,
)

# Serialised copy of the hits, plus a JSON dump on disk.
# NOTE(review): this writes to the same "search_result.json" path as the
# previous search cell, so that cell's output file is overwritten.
result = json.dumps(res)
pd.DataFrame(res).to_json(
    "search_result.json",
    orient="records",
    force_ascii=False,
)