In [None]:
"""
In this demo:

1. we will connect to a locally hosted Elasticsearch instance
2. we will extract 100 clinical notes from MIMIC-IV
3. we will index these notes into Elasticsearch as both text and embeddings
4. we will show how to run a hybrid search using both embeddings and text features
"""

In [None]:
%load_ext autoreload
%autoreload 2

## Get notes from MIMIC-IV

In [None]:
# You need to download the MIMIC-IV notes yourself and place them under a path you can access
from pathlib import Path
import pandas as pd
import numpy as np
import re


import os

# Directory that holds the MIMIC-IV note CSV files.
# Set the MIMIC_IV_NOTE_DIR environment variable to point at your own copy;
# the previously hard-coded location is kept as the fallback so existing
# setups keep working unchanged.
mimic_iv_path = os.environ.get(
    "MIMIC_IV_NOTE_DIR",
    "/Users/alexgre/Downloads/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/note/",
)
p = Path(mimic_iv_path)
# Target number of notes to collect for the demo.
target_num = 100

In [None]:
# Read discharge.csv from the notes directory into a DataFrame.
discharge_csv = p / "discharge.csv"
df = pd.read_csv(discharge_csv)

In [None]:
from itertools import islice

# Lazily walk the notes subject by subject (groupby yields subjects in sorted
# key order) using a plain row iterator rather than DataFrame.apply, which was
# only being used for its side effects.
note_rows = (
    (note.subject_id, note.note_id, note.text)
    for _, subject_notes in df.groupby("subject_id")
    for note in subject_notes.itertuples(index=False)
)

# Take exactly target_num notes. The previous `len(samples) > target_num`
# check ran only after a whole subject had been appended, so it always
# collected more than target_num notes.
# a sample is (sub_id, note_id, note_text)
samples = list(islice(note_rows, target_num))

len(samples)

In [None]:
# Whitespace-delimited word count of every sampled note (the text is the last tuple item).
ll = [len(note_text.split()) for *_, note_text in samples]
# Summary of note lengths in words: (median, mean, min, max)
tuple(stat(ll) for stat in (np.median, np.mean, np.min, np.max))

In [None]:
# convert into dict
# chunk notes to short paragraphs so we can better index them
index_samples = []

for sub_id, note_id, note_text in samples:
    # Step 1: a line that holds only a single space counts as a blank line.
    sample_text = re.sub("\n[ ]\n", "\n\n", note_text)
    # Step 2: collapse runs of 3+ newlines into one paragraph break.
    # This has to run on sample_text; running it on note_text (as the code
    # did before) silently discarded the result of step 1.
    sample_text = re.sub("\n{3,}", "\n\n", sample_text)

    # Split on paragraph breaks and drop chunks that are empty after
    # stripping, so no blank documents end up in the index.
    stripped_parts = (part.strip() for part in sample_text.split("\n\n"))
    sample_chunks = [part for part in stripped_parts if part]

    for i, chunk in enumerate(sample_chunks, start=1):
        index_samples.append(
            {
                "note_id": note_id,
                "subject_id": sub_id,
                # chunk ids are 1-based within a note: "<note_id>_<n>"
                "_id": f"{note_id}_{i}",
                "text": chunk,
            }
        )

In [None]:
# Peek at the first chunk record to sanity-check its structure.
index_samples[0]

## Index data

In [None]:
import os
import warnings
from getpass import getpass

# NOTE: this silences every Python warning for the rest of the session,
# not only the ones raised because certificate verification is disabled.
warnings.filterwarnings("ignore")

# you need to change to your local elasticsearch config
# Settings are read from environment variables so that no secret is stored in
# the notebook. User and URL fall back to the local defaults; the password has
# no default and is prompted for when ELASTICSEARCH_PASSWORD is not set.
ELASTICSEARCH_USER = os.environ.get("ELASTICSEARCH_USER", "elastic")
ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD") or getpass(
    "Elasticsearch password: "
)
ELASTICSEARCH_URL = os.environ.get("ELASTICSEARCH_URL", "https://localhost:9200")

In [None]:
import sys

# Make the local GatorRAG sources importable. The entry is relative, so it is
# resolved against the notebook's working directory.
GATORRAG_SRC = "../../GatorRAG/src/"
# Guard against piling up duplicate entries when this cell is re-run.
if GATORRAG_SRC not in sys.path:
    sys.path.append(GATORRAG_SRC)

In [None]:
from gatorag.elasticsearch_engine import ElasticSearchEngine

In [None]:
# Name of the index the engine is bound to.
index_name = "mimic_iv_sample_100"

# Client settings handed to the engine: target host, basic-auth credentials,
# and certificate verification switched off.
elastic_connect_info = dict(
    hosts=ELASTICSEARCH_URL,
    verify_certs=False,
    basic_auth=(ELASTICSEARCH_USER, ELASTICSEARCH_PASSWORD),
)

engine = ElasticSearchEngine(index_name=index_name)
engine.set_client(elastic_connect_info)

In [None]:
def _hnsw_cosine_vector(dims):
    """Build the mapping of an indexed dense_vector field with `dims` dimensions.

    The field uses cosine similarity and an HNSW index (m=16, ef_construction=256).
    """
    return {
        "type": "dense_vector",
        "dims": dims,
        "index": True,
        "similarity": "cosine",
        "index_options": {
            "type": "hnsw",
            "m": 16,
            "ef_construction": 256,
        },
    }


# Embed a throwaway sentence with each model to learn its vector length.
bge_dim = len(engine.bge.get_bge_embedding_single_sample("this is a test"))
instructor_dim = len(engine.instructor.get_instructor_embeddings_single_sample("this is a test"))

# Field definitions: ids as keywords, the chunk text analysed as English,
# a rank_features field for sparse features, and one dense vector per model.
mimic_iv_properties = {
    "subject_id": {"type": "keyword"},
    "note_id": {"type": "keyword"},
    "text": {"type": "text", "analyzer": "english"},
    "sparse_context": {
        "type": "rank_features",
        "positive_score_impact": True,
    },
    "instruct_emb": _hnsw_cosine_vector(instructor_dim),
    "bge_emb": _hnsw_cosine_vector(bge_dim),
}
mimic_iv_mapping = {"mappings": {"properties": mimic_iv_properties}}

engine.initialization()
engine.create_index(customized_mapping=mimic_iv_mapping)

In [None]:
# Hand every chunk record to the engine for indexing.
# NOTE(review): presumably the engine computes the embedding fields itself,
# since the records only carry ids and text — confirm in ElasticSearchEngine.index.
engine.index(index_samples)

In [None]:
# Free-text query used by the search examples.
query_text = "any discharge information about heart failure"

In [None]:
# BM25 only search
# Plain full-text match of the query against the "text" field; keep the 10 best hits.
text_match = {"match": {"text": {"query": query_text}}}
req_body = {"query": text_match}
engine.search(req_body, top_k=10)

In [None]:
# BM25 + bge hybrid search
bge_q_emb = engine.get_query_embedding(query=query_text, embedding_method="bge")

# Lexical half of the request: full-text match on the "text" field.
text_clause = {"match": {"text": {"query": query_text}}}
# Vector half of the request: kNN over the "bge_emb" field with the query embedding.
knn_clause = {
    "field": "bge_emb",
    "query_vector": bge_q_emb,
    "k": 20,
    "num_candidates": 64,
}

req_body = {"query": text_clause, "knn": knn_clause}
engine.search(req_body, top_k=10)

In [None]:
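# Reference only: everything below is commented out and never executed.
# It sketches request bodies for further search modes (BM25, kNN, sparse rank_feature, and hybrids).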
#     ins_emb = self.instructor.get_instructor_embeddings_single_sample(query=text, is_query=True)
#     gte_emb = self.gte.get_sentence_transformer_embeddings_single_sample(text)
#     bge_emb = self.bge.get_bge_embedding_single_sample(text, is_qurey=True)
#     sparse_query = [
#         {
#             "rank_feature": {
#                 "field": f"sparse_context.{k}",
#                 "linear": {},
#                 "boost": v,
#             }
#         }
#         for k, v in self.splade.get_splade_features_single_sample(query=text).items()
#     ]
#     sparse_query_bm25 = [
#         {"match": {self.text_key: text}},
#         {"match": {self.title_key: text}},
#     ] + sparse_query
    # def customized_search(self, text: str, top_hits: int, skip: int = 0, qtype: int = 0) -> Dict[str, object]:
    # if qtype > 1:
    #     ins_emb = self.instructor.get_instructor_embeddings_single_sample(query=text, is_query=True)
    #     gte_emb = self.gte.get_sentence_transformer_embeddings_single_sample(text)
    #     bge_emb = self.bge.get_bge_embedding_single_sample(text, is_qurey=True)
    #     sparse_query = [
    #         {
    #             "rank_feature": {
    #                 "field": f"sparse_context.{k}",
    #                 "linear": {},
    #                 "boost": v,
    #             }
    #         }
    #         for k, v in self.splade.get_splade_features_single_sample(query=text).items()
    #     ]
    #     sparse_query_bm25 = [
    #         {"match": {self.text_key: text}},
    #         {"match": {self.title_key: text}},
    #     ] + sparse_query
    # else:
    #     st_emb = []
    #     ins_emb = []
    #     sparse_query = []
    #     sparse_query_bm25 = []

    # if qtype == 1:
    #     # BM25
    #     req_body = {"query": {"match": {"text": {"query": text}}}}
    # elif qtype == 2:
    #     req_body = {
    #         "knn": {
    #             "field": "sb_emb",
    #             "query_vector": st_emb,
    #             "k": 20,
    #             "num_candidates": 100,
    #         },
    #     }
    # elif qtype == 3:
    #     req_body = {
    #         "knn": {
    #             "field": "ins_emb",
    #             "query_vector": ins_emb,
    #             "k": 20,
    #             "num_candidates": 100,
    #         },
    #     }
    # elif qtype == 4:
    #     req_body = {
    #         "query": {
    #             "bool": {
    #                 "should": sparse_query,
    #                 "boost": 0.5,
    #                 "minimum_should_match": 1,
    #             }
    #         }
    #     }
    # elif qtype == 5:
    #     req_body = {
    #         "knn": [
    #             {
    #                 "field": "sb_emb",
    #                 "query_vector": st_emb,
    #                 "k": 20,
    #                 "num_candidates": 100,
    #                 "boost": 0.5,
    #             },
    #             {
    #                 "field": "ins_emb",
    #                 "query_vector": ins_emb,
    #                 "k": 10,
    #                 "num_candidates": 100,
    #                 "boost": 0.5,
    #             },
    #         ],
    #     }
    # elif qtype == 6:
    #     req_body = {
    #         "query": {"bool": {"should": sparse_query_bm25}},
    #         "knn": {
    #             "field": "ins_emb",
    #             "query_vector": ins_emb,
    #             "k": 10,
    #             "num_candidates": 100,
    #         },
    #     }
    # elif qtype == 7:
    #     req_body = {
    #         "query": {
    #             "bool": {
    #                 "should": sparse_query_bm25,
    #                 "boost": 0.2,
    #                 "minimum_should_match": 1,
    #             },
    #         },
    #         "knn": [
    #             {
    #                 "field": "sb_emb",
    #                 "query_vector": st_emb,
    #                 "k": 10,
    #                 "num_candidates": 100,
    #                 "boost": 0.3,
    #             },
    #             {
    #                 "field": "ins_emb",
    #                 "query_vector": ins_emb,
    #                 "k": 10,
    #                 "num_candidates": 100,
    #                 "boost": 0.5,
    #             },
    #         ],
    #     }
    # elif qtype == 8:
    #     req_body = {
    #         "query": {
    #             "bool": {
    #                 "should": sparse_query,
    #                 "boost": 0.2,
    #                 "minimum_should_match": 1,
    #             },
    #         },
    #         "knn": [
    #             {
    #                 "field": "sb_emb",
    #                 "query_vector": st_emb,
    #                 "k": 10,
    #                 "num_candidates": 100,
    #                 "boost": 0.3,
    #             },
    #             {
    #                 "field": "ins_emb",
    #                 "query_vector": ins_emb,
    #                 "k": 10,
    #                 "num_candidates": 100,
    #                 "boost": 0.5,
    #             },
    #         ],
    #     }
    # elif qtype == 9:
    #     req_body = {
    #         "knn": {
    #             "field": "gte_emb",
    #             "query_vector": gte_emb,
    #             "k": 20,
    #             "num_candidates": 100,
    #         },
    #     }
    # elif qtype == 10:
    #     req_body = {
    #         "knn": {
    #             "field": "bge_emb",
    #             "query_vector": bge_emb,
    #             "k": 20,
    #             "num_candidates": 100,
    #         },
    #     }
    # elif qtype == 11:
    #     req_body = {
    #         "query": {
    #             "bool": {
    #                 "should": sparse_query,
    #                 "minimum_should_match": 1,
    #             },
    #         },
    #         "knn": {
    #             "field": "gte_emb",
    #             "query_vector": gte_emb,
    #             "k": 20,
    #             "num_candidates": 100,
    #         },
    #     }
    # elif qtype == 12:
    #     req_body = {
    #         "query": {
    #             "bool": {
    #                 "should": [
    #                     {"match": {self.text_key: text}},
    #                     {"match": {self.title_key: text}},
    #                 ],
    #                 "minimum_should_match": 1,
    #             },
    #         },
    #         "knn": {
    #             "field": "bge_emb",
    #             "query_vector": bge_emb,
    #             "k": 20,
    #             "num_candidates": 100,
    #         },
    #     }
    # elif qtype == 13:
    #     req_body = {
    #         "query": {
    #             "bool": {
    #                 "should": [
    #                     {"match": {self.text_key: text}},
    #                     {"match": {self.title_key: text}},
    #                 ]
    #             }
    #         },
    #         "knn": {
    #             "field": "ins_emb",
    #             "query_vector": ins_emb,
    #             "k": 20,
    #             "num_candidates": 100,
    #         },
    #     }
    # else:
    #     raise NotImplementedError(f"Query type {qtype} is not supported.")    