In [1]:
import pandas as pd
from tqdm import tqdm
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
from opensearchpy.helpers import bulk
import boto3
import requests
import polling2
import logging
from random import randint

# Logging setup: register a custom "POLL" level (value 60, above CRITICAL=50)
# so polling2's progress messages (logged at this level) are shown while
# lower-level records are filtered out.
POLL = 60
logging.addLevelName(POLL, 'POLL')

LOG_FORMAT = '%(asctime)s %(levelname)s: %(message)s'
LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(format=LOG_FORMAT, datefmt=LOG_DATEFMT, level=POLL)

# basicConfig() does nothing if the root logger already has handlers,
# so set the root level explicitly as well.
root_logger = logging.getLogger()
root_logger.setLevel(POLL)

# Read Data

In [9]:
# Paragraph-level chunks with MiniLM embeddings and topic flags.
# NOTE: unpickling can execute arbitrary code — only load pickles from a trusted bucket.
DATA_PATH = 's3://trust-stream-nlp-data/trust-content-dsml/chunk_embedding_with_topic_sample_test.pkl'
COLUMN_RENAMES = {'paragraph_id': 'id', 'minilm_embeddings': 'embedding'}

articles = pd.read_pickle(DATA_PATH).rename(columns=COLUMN_RENAMES)

print(articles.shape)
articles.head(2)

(2000, 20)


Unnamed: 0,id,pull_date,brand,article_id,text,title,estimatedPublishedDate,embedding,anti_trans_legislation,covid_19,cybersecurity,data_privacy_gdpr,diversity_inclusion,gen_z,inflation,minimum_wage,ukraine_russia,vaccine,waste_reduction,work_from_home
0,46679408493_6_Volkswagen,2022-04-04,Volkswagen,46679408493,12/13/2021 – Volkswagen was given a new €235.0...,Volkswagen (VOW3) – Analysts’ Weekly Ratings C...,2022-01-01 00:22:41,"[0.07120183110237122, 0.09447652101516724, 0.0...",False,False,False,False,False,False,False,False,False,False,False,False
1,46679425484_7_PayPal,2022-01-02,PayPal,46679425484,PayPal Company Profile,"PayPal Holdings, Inc. (NASDAQ:PYPL) is Pelham ...",2022-01-01 00:28:42,"[-0.019512677565217018, -0.021549265831708908,...",False,False,False,False,False,False,False,False,False,False,False,False


# Opensearch

In [4]:
# Amazon OpenSearch Service domain endpoint (HTTPS on 443) and its AWS region.
host = "search-mynewdomain-gp74byna6ivfgazl4upbhbclfe.us-east-1.es.amazonaws.com"
port = 443
region = 'us-east-1'

# SigV4-sign requests with the credentials resolved by a default boto3 session.
session = boto3.Session()
credentials = session.get_credentials()
auth = AWSV4SignerAuth(credentials, region)

client_kwargs = dict(
    hosts=[{'host': host, 'port': port}],
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
)
client = OpenSearch(**client_kwargs)

In [6]:
# One row per cluster node: heap/RAM/CPU usage, roles, and the elected master (`*`).
url = 'https://{}:{}/_cat/nodes?v'.format(host, port)
response = requests.get(url, auth=auth)
print(response.text)

ip            heap.percent ram.percent cpu load_1m load_5m load_15m node.role master name
x.x.x.x           34          97   4    0.13    0.37     0.48 dimr      -      829c6120268cba75b105aebe5ded58f8
x.x.x.x           72          97   4    0.15    0.25     0.22 dimr      -      9def54968eaed1446cc87cfc0d17a082
x.x.x.x            11          97   4    0.37    0.62     0.50 dimr      *      0fa9a767a70fd6dc523b62929e3b8da1



In [5]:
# Existing indices with health, shard/replica counts, document counts and sizes.
url = 'https://{}:{}/_cat/indices?v'.format(host, port)
response = requests.get(url, auth=auth)
response.text

'health status index              uuid                   pri rep docs.count docs.deleted store.size pri.store.size\ngreen  open   article_chunks     OoKSWT4ZRwuJoLvDcbeZVw   4   1          0            0      1.6kb           852b\ngreen  open   chunk_texts_index  p3e-yfCaQQ6HiMRR--4xzQ  10   0    3885306            0     39.7gb         39.7gb\ngreen  open   python-test-index3 zghM9fXRSJG_K8Nb5xJCOQ   4   1          0            0      1.6kb           832b\ngreen  open   .kibana_1          jNXBc9tuT-WJbgs7u33hRA   1   1          1            0     10.1kb            5kb\ngreen  open   full_articles      OerehVoVQmudtWq-ubaueQ   4   1          0            0      1.6kb           832b\n'

## Helpers

In [59]:
def upload_data(client, index_name, df, chunk_size=500):
    """Bulk-index every row of ``df`` into ``index_name``, then refresh the index.

    Each row becomes one document: its ``_id`` is the row's ``id`` column and
    its ``_source`` holds every other column (the DataFrame index is not sent).

    Args:
        client: an ``OpenSearch`` client.
        index_name: name of the target index.
        df: DataFrame containing an ``id`` column.
        chunk_size: number of rows sent per ``bulk`` request.

    Raises:
        RuntimeError: if ``bulk`` reports failed documents for a chunk.
    """
    for start in tqdm(range(0, len(df), chunk_size)):
        subset_df = df.iloc[start:start + chunk_size]

        # to_dict('records') keeps the real column names; itertuples() would
        # rename columns that are not valid Python identifiers (to _1, _2, ...).
        actions = []
        for record in subset_df.to_dict(orient='records'):
            doc_id = record.pop('id')
            actions.append({'_index': index_name, '_id': doc_id, '_source': record})

        _, errors = bulk(client, actions, max_retries=2, request_timeout=100)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if errors:
            raise RuntimeError(
                f'{len(errors)} documents failed to index into {index_name!r}: {errors[:5]}'
            )

    # Indices in this notebook are created with refresh_interval=-1, so make
    # the newly indexed documents searchable explicitly.
    client.indices.refresh(index=index_name, request_timeout=1000)

def find_top_k_chunks(
    embeddings,
    k,
    cols_to_query,
    index_name,
    client,
    emb_col,
    chunk_size=500,
):
    """Run one k-NN query per embedding and return the top-``k`` hits for each.

    Queries are batched into ``msearch`` requests of at most ``chunk_size``
    searches each.

    Args:
        embeddings: sequence of query vectors (lists of floats), one query each.
        k: number of neighbours per query (used as both the k-NN ``k`` and the
            search ``size``).
        cols_to_query: ``_source`` fields to return for every hit.
        index_name: index to search.
        client: an ``OpenSearch`` client.
        emb_col: name of the ``knn_vector`` field to search.
        chunk_size: maximum number of searches per ``msearch`` call.

    Returns:
        A list with one entry per embedding, in input order. Each entry is a
        list of dicts holding ``_id``, ``_score`` and the requested source
        fields (nested objects flattened by ``pd.json_normalize``); it is an
        empty list when that query returned no hits.

    Raises:
        RuntimeError: if any individual search in an ``msearch`` batch failed.
    """
    req_head = {"index": index_name}
    responses = []

    for start in range(0, len(embeddings), chunk_size):
        request = []
        for embedding in embeddings[start : start + chunk_size]:
            req_body = {
                "query": {"knn": {emb_col: {"vector": embedding, "k": k}}},
                "size": k,
                "_source": cols_to_query,
            }
            # msearch bodies alternate header / query lines.
            request.extend([req_head, req_body])

        r = client.msearch(body=request)
        responses.extend(r["responses"])

    # Post processing: flatten each response into a list of records.
    chunks = []
    for query_no, item in enumerate(responses):
        # A failed search inside msearch comes back as an "error" entry without "hits".
        if "error" in item:
            raise RuntimeError(f"msearch query {query_no} failed: {item['error']}")

        hits = item["hits"]["hits"]
        if not hits:
            # An empty frame has no _id/_score columns, so short-circuit.
            chunks.append([])
            continue

        df = pd.DataFrame(hits)
        df = df[["_id", "_score"]].join(pd.json_normalize(df["_source"].tolist()))
        chunks.append(df.to_dict(orient="records"))

    return chunks

## FAISS Model Training

### Create the train_index
- Stores vectors for training the FAISS model

In [6]:
# Index holding the FAISS training vectors, the vector field name, and the
# vector dimensionality (read from the first row's embedding).
index_name, emb_col = 'train_index', 'embedding'
emb_dim = len(articles.at[0, emb_col])

In [7]:
# The training index only stores raw vectors: the knn_vector field needs a
# type and dimension, and no ANN method is attached here.
train_settings = {
    'number_of_shards': 10,
    'number_of_replicas': 0,
    'refresh_interval': -1,
}
index_body = {
    'settings': {'index': train_settings},
    'mappings': {
        'properties': {emb_col: {'type': 'knn_vector', 'dimension': emb_dim}},
    },
}

# Start from an empty index on every run.
if client.indices.exists(index=index_name):
    client.indices.delete(index=index_name)

response = client.indices.create(index_name, body=index_body)
print(response)

{'acknowledged': True, 'shards_acknowledged': True, 'index': 'train_index'}


### Upload vectors

In [8]:
# The training index only needs the document id and the vector itself.
train_df = articles[['id', emb_col]]
upload_data(client, index_name='train_index', df=train_df)

100%|██████████| 196/196 [14:45<00:00,  4.52s/it]


### Train FAISS Model

In [9]:
model_name = 'faiss_ivf_pq'
url = f'https://{host}:{port}/_plugins/_knn/models/{model_name}'

# Training request for a FAISS IVF index with PQ encoding, trained on the
# `emb_col` vectors stored in the training index. Per the OpenSearch k-NN docs:
# nlist = number of IVF buckets, nprobes = buckets searched per query,
# m = number of PQ sub-vectors (must divide the dimension),
# code_size = bits per sub-vector code.
# NOTE(review): IVF k-means training needs at least `nlist` training vectors —
# confirm the training index holds enough documents.
payload = {
    'training_index': index_name,
    'training_field': emb_col,
    'dimension': emb_dim,
    'description': 'FAISS IVF-PQ ANN index',
    'method': {
        'name': 'ivf',
        'engine': 'faiss',
        'space_type': 'l2',
        'parameters':{
            'nlist': 2048,
            'nprobes': 256,
            'encoder':{
                'name':'pq',
                'parameters':{
                    'code_size': 8,
                    'm': 8
                }
            }
        }
    }
}

if requests.get(url, auth=auth).status_code != 404:
    # A model with this id already exists: delete it so it can be retrained.
    response = requests.delete(url, auth=auth)
    print(response.json())
    # Stop here if the delete failed instead of trying to train over it.
    response.raise_for_status()

# Kick off training; it runs in the background and the model state is polled afterwards.
response = requests.post(f'{url}/_train', json=payload, auth=auth)
print(response.text)
# Fail loudly rather than polling a model whose training never started.
response.raise_for_status()

{'model_id': 'faiss_ivf_pq', 'result': 'deleted'}
{"model_id":"faiss_ivf_pq"}


### Poll the Training job

In [10]:
url = 'https://{}:{}/_plugins/_knn/models/{}?filter_path=state&pretty'.format(host, port, model_name)


def fetch_model_state():
    """Return the k-NN model's current ``state`` string (e.g. 'training', 'created')."""
    return requests.get(url, auth=auth).json()['state']


# Re-check once a minute, with no timeout, until the model leaves the 'training' state.
polling2.poll(
    lambda: fetch_model_state() != 'training',
    step=60,
    poll_forever=True,
    log=POLL,
)

# Any terminal state other than 'created' means training did not succeed.
training_status = fetch_model_state()
assert training_status == 'created', training_status

2022-05-21 09:38:09 POLL: poll() calls check_success(False)
2022-05-21 09:39:21 POLL: poll() calls check_success(False)
2022-05-21 09:40:41 POLL: poll() calls check_success(False)
2022-05-21 09:41:43 POLL: poll() calls check_success(False)
2022-05-21 09:42:44 POLL: poll() calls check_success(False)
2022-05-21 09:43:46 POLL: poll() calls check_success(False)
2022-05-21 09:44:47 POLL: poll() calls check_success(False)
2022-05-21 09:45:48 POLL: poll() calls check_success(False)
2022-05-21 09:46:50 POLL: poll() calls check_success(False)
2022-05-21 09:47:51 POLL: poll() calls check_success(False)
2022-05-21 09:48:52 POLL: poll() calls check_success(False)
2022-05-21 09:49:53 POLL: poll() calls check_success(False)
2022-05-21 09:50:55 POLL: poll() calls check_success(False)
2022-05-21 09:51:56 POLL: poll() calls check_success(True)


## ANN Search

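### Create the ANN Search Index

The vector field below is mapped as an HNSW graph built by `nmslib`. The FAISS IVF-PQ model trained in the previous step is only used if the mapping references its `model_id` instead.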

In [11]:
index_name = 'chunk_texts_index'

# Choose how the vector field is indexed:
#   False -> a standalone HNSW graph built by nmslib (the configuration used so far);
#   True  -> reuse the FAISS IVF-PQ model trained above (dimension and method
#            then come from the trained model rather than from this mapping).
use_trained_model = False

if use_trained_model:
    emb_mapping = {
        'type': 'knn_vector',
        'model_id': model_name,
    }
else:
    emb_mapping = {
        "type": "knn_vector",
        "dimension": emb_dim,
        "method": {
            "name": "hnsw",
            "space_type": "l2",
            "engine": "nmslib",
            "parameters": {
                "ef_construction": 128,
                "m": 24
            }
        }
    }

index_body = {
    'settings': {
        'index': {
            'knn': True,
            'number_of_shards': 10,
            'number_of_replicas': 0,
            # Refresh is disabled during the bulk load; upload_data() refreshes at the end.
            'refresh_interval': -1,
        }
    },
    'mappings': {
        'properties': {
            emb_col: emb_mapping,
        }
    }
}

if client.indices.exists(index=index_name):
    # Destructive: drops the existing index together with all of its documents.
    client.indices.delete(index=index_name)

response = client.indices.create(index_name, body=index_body)
print(response)

{'acknowledged': True, 'shards_acknowledged': True, 'index': 'sample_chunks'}


### Upload Data

In [12]:
# Index every column of `articles`, in bulk batches of 100 rows (the helper's default is 500).
upload_data(client, index_name, articles, chunk_size=100)

100%|██████████| 977/977 [24:08<00:00,  1.48s/it]  


### Warmup

In [13]:
# k-NN warmup API: load the index's native k-NN structures into memory
# ahead of the first search.
url = 'https://{}:{}/_plugins/_knn/warmup/{}?pretty'.format(host, port, index_name)
response = requests.get(url, auth=auth)
response.json()

{'_shards': {'total': 10, 'successful': 10, 'failed': 0}}

In [14]:
# Check number of records after update (cat API, JSON output).
record_count = client.cat.count(index_name, params={"format": "json"})
print(record_count)

[{'epoch': '1653099453', 'timestamp': '02:17:33', 'count': '97685'}]


## Query

### Query by ID

In [None]:
# Fetch a single document by its _id (the paragraph id used during upload).
doc_id = '46683031823_6_Google'
client.get(index=index_name, id=doc_id)

### Query by search term

In [None]:
q = 'google'
num_results = 5

# Full-text match of the search term against the brand, title and text fields.
search_fields = ['brand', 'title', 'text']
query = dict(
    size=num_results,
    query={'multi_match': {'query': q, 'fields': search_fields}},
)

response = client.search(body=query, index=index_name)

hits = response['hits']['hits']
assert len(hits) == num_results
print(response)

### Vector Search

In [None]:
# Things you can update
num_results = 10
## Recommend to keep it <=100 for good search speed

row_to_query = randint(0, len(articles) - 1)
## Random *positional* row in [0, len(articles) - 1]; randint's upper bound is inclusive

cols_to_query = ['title', 'text']
## Valid options for cols_to_query are: ['pull_date', 'brand', 'article_id', 'text', 'title', 'estimatedPublishedDate', 'anti_trans_legislation', 'covid_19', 'cybersecurity', 'data_privacy_gdpr', 'diversity_inclusion', 'gen_z', 'inflation', 'minimum_wage', 'ukraine_russia', 'vaccine', 'waste_reduction', 'work_from_home']. The more columns you extract, the slower the query

###################################################################

# DO NOT CHANGE BELOW!
# NOTE(review): this differs from 'chunk_texts_index', the index created and
# populated earlier in this notebook — confirm which index should be queried.
index_name = 'chunk_text_index'
emb_col = 'embedding'

# Select by position so any valid randint result maps to a row, whatever the index labels.
query_row = articles.iloc[row_to_query]
embeddings = [query_row[emb_col].tolist()]

responses = find_top_k_chunks(
    embeddings,
    k=num_results,
    cols_to_query=cols_to_query,
    index_name=index_name,
    client=client,
    emb_col=emb_col,
    chunk_size=500,
)

# Look at the output
print('🔎 Query:')
print(query_row[cols_to_query].values)
print('\n-------------------------------------------------------------\n')
print('📝 Results')
responses

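For the `l2` space, the OpenSearch k-NN documentation defines the score as a function of the distance:
$$score = \frac{1}{1+distance}$$

Here $distance$ is the *squared* Euclidean (L2) distance. Inverting the formula recovers it from a returned score:
$$distance = \frac{1}{score} - 1$$

Take the square root of this value to get the Euclidean distance itself.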

## Clean up

In [72]:
# Current indices on the cluster, to decide what to clean up.
url = 'https://{}:{}/_cat/indices?v'.format(host, port)
response = requests.get(url, auth=auth)
print(response.text)

health status index              uuid                   pri rep docs.count docs.deleted store.size pri.store.size
green  open   chunk_text_index   BBGZjZzvTzqGeE9OG_5xXg  10   0    3885306            0     39.7gb         39.7gb
green  open   article_chunks     OoKSWT4ZRwuJoLvDcbeZVw   4   1          0            0      1.6kb           852b
green  open   python-test-index3 zghM9fXRSJG_K8Nb5xJCOQ   4   1          0            0      1.6kb           832b
green  open   .kibana_1          jNXBc9tuT-WJbgs7u33hRA   1   1          1            0     10.1kb            5kb
green  open   full_articles      scnMPTNfTw2QuGTa4ZA88Q  10   0    3930943            0     15.5gb         15.5gb



In [28]:
# Drop the index that held the FAISS training vectors.
train_index_name = 'train_index'
response = client.indices.delete(index=train_index_name)
print(response)

{'acknowledged': True}


In [71]:
# Drop the ANN search index (destructive: all of its documents are removed).
ann_index_name = 'chunk_texts_index'
response = client.indices.delete(index=ann_index_name)
print(response)

{'acknowledged': True}
