In [1]:
import pandas as pd
from tqdm import tqdm
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
from opensearchpy.helpers import bulk
import boto3
import requests
import polling2
import logging

# Logging setup: register a custom "POLL" level (60, i.e. above the built-in
# CRITICAL=50) so that only messages logged at that level are shown.
POLL = 60
logging.addLevelName(POLL, 'POLL')

logging.basicConfig(
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=POLL,
)
# basicConfig() is a no-op when the root logger already has handlers, so the
# threshold is also set explicitly on the root logger.
logging.getLogger().setLevel(POLL)

# Read Data

In [2]:
# Load the article records (including their embeddings) from a local pickle file.
# NOTE(review): unpickling can execute arbitrary code - only load this file if
# it comes from a trusted source.
articles = pd.read_pickle('../data/articles.pkl')

# Quick sanity check: frame dimensions, then a preview of the first two rows.
print(articles.shape)
articles.head(2)

(97685, 7)


Unnamed: 0,id,parent_id,relevancy_rank,brand,title,text,embedding
0,46683031823_6_Google,46683031823,2,Google,Governor Ron DeSantis accompanied wife to canc...,Everyone who reads our reporting knows the Gel...,"[-0.016234327, 0.096927114, -0.0736138, 0.0110..."
1,46686258924_0_Google,46686258924,1,Google,How to enable 2-step verification on your Goog...,Proceed to the next point and use the phone to...,"[-0.07577787, 0.05864354, 0.061040197, -0.0203..."


# OpenSearch

In [3]:
# Connection settings for the OpenSearch cluster.
# NOTE(review): `host` is empty in this copy - presumably redacted; set it to
# the cluster endpoint (without scheme) before running.
region = 'us-east-1'
port = 443
host = ""

# Requests are signed with AWS SigV4 using the credentials boto3 resolves
# for the default session.
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, region)

# The same `auth` object is reused below for raw `requests` calls.
client = OpenSearch(
    hosts=[{'host': host, 'port': port}],
    http_auth=auth,
    connection_class=RequestsHttpConnection,
    use_ssl=True,
    verify_certs=True,
)

In [4]:
# List the indices on the cluster via the _cat API, using a signed HTTP request.
# The raw text of the response is displayed as the cell output.
url = f'https://{host}:{port}/_cat/indices?v'
response = requests.get(url, auth=auth)
response.text

'health status index              uuid                   pri rep docs.count docs.deleted store.size pri.store.size\ngreen  open   full_article       I5BytbAAS5Wk5QCBDyjaDw  10   0          0            0        2kb            2kb\nyellow open   python-test-index3 zghM9fXRSJG_K8Nb5xJCOQ   4   1          0            0       832b           832b\ngreen  open   train_index        KRUXZtfEQb-C6MbeGZ7CQQ  10   0      97685            0      716mb          716mb\ngreen  open   sample_chunks      sEsNXH0tTY6-EWye1W9e_g  10   0      97685            0      861mb          861mb\ngreen  open   .kibana_1          jNXBc9tuT-WJbgs7u33hRA   1   0          1            0        5kb            5kb\n'

## Helpers

In [5]:
def upload_data(client, index_name: str, df: pd.DataFrame, chunk_size: int = 500) -> None:
    """Bulk-index the rows of ``df`` into ``index_name`` and refresh the index.

    Each row becomes one document: the ``id`` column is used as the document
    ``_id`` and every other column is stored in ``_source`` under its original
    column name.

    Args:
        client: OpenSearch client passed through to ``bulk`` and used for the
            final index refresh.
        index_name: Name of the target index.
        df: Rows to upload; must contain an ``id`` column.
        chunk_size: Number of rows sent per bulk request (must be positive).

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
        AssertionError: If a bulk request reports any per-document errors.
    """
    if chunk_size <= 0:
        # A negative step would yield an empty range and silently upload nothing.
        raise ValueError(f'chunk_size must be a positive integer, got {chunk_size!r}')

    for start in tqdm(range(0, len(df), chunk_size)):
        # Upload chunk_size number of rows at a time
        chunk = df.iloc[start:start + chunk_size]

        doc_ids = chunk['id'].tolist()
        # to_dict('records') keeps the real column names; itertuples()/_asdict()
        # would rename columns that are not valid Python identifiers to
        # positional names such as `_2`.
        sources = chunk.drop(columns='id').to_dict(orient='records')

        actions = [
            {'_index': index_name, '_id': doc_id, '_source': source}
            for doc_id, source in zip(doc_ids, sources)
        ]

        _, errors = bulk(client, actions, max_retries=2, request_timeout=100)
        # Fail loudly if the bulk helper reported any per-document errors.
        assert len(errors) == 0, errors

    # Refresh the data so the uploaded documents become visible to searches
    client.indices.refresh(index_name, request_timeout=1000)

## FAISS Model Training

### Create the train_index
- Stores vectors for training the FAISS model

In [6]:
# Column that holds each row's embedding vector.
emb_col = 'embedding'
# Vector dimensionality, read off the embedding stored at row label 0.
emb_dim = len(articles.loc[0, emb_col])
# Index that stores the vectors used to train the FAISS model.
index_name = 'train_index'

In [7]:
# Definition of the training index: a single knn_vector field sized to the
# embeddings. refresh_interval is -1 here; upload_data() issues an explicit
# refresh once all rows have been sent.
index_body = {
    'settings': {
        'index': {
            'number_of_shards': 10,
            'number_of_replicas': 0,
            'refresh_interval': -1,
        },
    },
    'mappings': {
        'properties': {
            emb_col: {'type': 'knn_vector', 'dimension': emb_dim},
        },
    },
}

# Start from a clean slate: drop any index of the same name before creating it.
if client.indices.exists(index=index_name):
    client.indices.delete(index=index_name)

response = client.indices.create(index_name, body=index_body)
print(response)

{'acknowledged': True, 'shards_acknowledged': True, 'index': 'train_index'}


### Upload vectors

In [8]:
# Only the document id and its embedding are uploaded to the training index.
upload_data(
    client,
    index_name='train_index',
    df=articles[['id', emb_col]],
)

100%|██████████| 196/196 [14:45<00:00,  4.52s/it]


### Train FAISS Model

In [9]:
# Train a FAISS IVF-PQ model through the k-NN plugin's model API, using the
# vectors stored in the training index.
model_name = 'faiss_ivf_pq'
url = f'https://{host}:{port}/_plugins/_knn/models/{model_name}'

payload = {
    'training_index': index_name,
    'training_field': emb_col,
    'dimension': emb_dim,
    'description': 'FAISS IVF-PQ ANN index',
    'method': {
        'name': 'ivf',
        'engine': 'faiss',
        'space_type': 'l2',
        'parameters': {
            # IVF: number of buckets / buckets probed per query.
            'nlist': 2048,
            'nprobes': 256,
            'encoder': {
                'name': 'pq',
                # NOTE(review): PQ splits each vector into `m` sub-vectors, so
                # emb_dim presumably has to be divisible by `m` - confirm.
                'parameters': {
                    'code_size': 8,
                    'm': 8
                }
            }
        }
    }
}

# Delete a previously trained model of the same name, but only when the lookup
# actually found one. Any status other than 200/404 (auth failure, server
# error, ...) is surfaced instead of being treated as "model exists".
existing = requests.get(url, auth=auth)
if existing.status_code == 200:
    response = requests.delete(url, auth=auth)
    response.raise_for_status()
    print(response.json())
elif existing.status_code != 404:
    existing.raise_for_status()

# Train the model; show the response body, then fail here (rather than later
# in the polling cell) if the training request was rejected.
response = requests.post(f'{url}/_train', json=payload, auth=auth)
print(response.text)
response.raise_for_status()

{'model_id': 'faiss_ivf_pq', 'result': 'deleted'}
{"model_id":"faiss_ivf_pq"}


### Poll the Training job

In [10]:
# `filter_path=state` trims the model response down to its `state` field.
url = f'https://{host}:{port}/_plugins/_knn/models/{model_name}?filter_path=state&pretty'


def _training_finished():
    """Return True once the model's reported state is no longer 'training'."""
    state = requests.get(url, auth=auth).json()['state']
    return state != 'training'


# Check once a minute, with no timeout, logging each attempt at the POLL level.
polling2.poll(_training_finished, step=60, poll_forever=True, log=POLL)

# Leaving the 'training' state is not the same as succeeding: require 'created'.
training_status = requests.get(url, auth=auth).json()['state']
assert training_status == 'created', training_status

2022-05-21 09:38:09 POLL: poll() calls check_success(False)
2022-05-21 09:39:21 POLL: poll() calls check_success(False)
2022-05-21 09:40:41 POLL: poll() calls check_success(False)
2022-05-21 09:41:43 POLL: poll() calls check_success(False)
2022-05-21 09:42:44 POLL: poll() calls check_success(False)
2022-05-21 09:43:46 POLL: poll() calls check_success(False)
2022-05-21 09:44:47 POLL: poll() calls check_success(False)
2022-05-21 09:45:48 POLL: poll() calls check_success(False)
2022-05-21 09:46:50 POLL: poll() calls check_success(False)
2022-05-21 09:47:51 POLL: poll() calls check_success(False)
2022-05-21 09:48:52 POLL: poll() calls check_success(False)
2022-05-21 09:49:53 POLL: poll() calls check_success(False)
2022-05-21 09:50:55 POLL: poll() calls check_success(False)
2022-05-21 09:51:56 POLL: poll() calls check_success(True)


## ANN Search

### Create the ANN Search Index with the model trained in the previous step

In [11]:
index_name = 'chunk_texts_index'

# k-NN enabled index whose embedding field is backed by the trained model.
# Alternative considered (not used): define the field inline with an explicit
# `dimension` and an nmslib HNSW method (space_type l2, ef_construction=128,
# m=24) instead of referencing a trained model.
index_body = {
    'settings': {
        'index': {
            'knn': True,
            'number_of_shards': 10,
            'number_of_replicas': 0,
            'refresh_interval': -1,
        },
    },
    'mappings': {
        'properties': {
            emb_col: {'type': 'knn_vector', 'model_id': model_name},
        },
    },
}

# Start from a clean slate: drop any index of the same name before creating it.
if client.indices.exists(index=index_name):
    client.indices.delete(index=index_name)

response = client.indices.create(index_name, body=index_body)
print(response)

{'acknowledged': True, 'shards_acknowledged': True, 'index': 'sample_chunks'}


### Upload Data

In [12]:
# Upload every column of the articles frame, 100 rows per bulk request.
upload_data(
    client,
    index_name=index_name,
    df=articles,
    chunk_size=100,
)

100%|██████████| 977/977 [24:08<00:00,  1.48s/it]  


### Warmup

In [13]:
# Call the k-NN plugin's warmup endpoint for the search index and show the
# per-shard result it reports.
url = f'https://{host}:{port}/_plugins/_knn/warmup/{index_name}?pretty'
response = requests.get(url, auth=auth)
response.json()

{'_shards': {'total': 10, 'successful': 10, 'failed': 0}}

In [14]:
# Check the number of records in the index after the upload; the `format`
# parameter asks the cat API for JSON instead of plain text.
print(client.cat.count(index_name, params={"format": "json"}))

[{'epoch': '1653099453', 'timestamp': '02:17:33', 'count': '97685'}]


## Query

### Query by ID

In [None]:
# Fetch a single document by its `_id` (upload_data() uses the `id` column as
# `_id`; this is the id of the first row shown by `articles.head(2)`).
client.get(index=index_name, id='46683031823_6_Google')

### Query by search term

In [None]:
num_results = 5
q = 'google'

# Full-text search: match the term against the brand, title and text fields.
query = {
    'size': num_results,
    'query': {
        'multi_match': {'query': q, 'fields': ['brand', 'title', 'text']},
    },
}

response = client.search(body=query, index=index_name)

# Expect a full page of hits for this term.
assert len(response['hits']['hits']) == num_results
print(response)

### Vector Search

In [20]:
# Title of the row (label 2) whose embedding is used as the query vector in
# the vector-search cell.
articles.loc[2, 'title']

'How to add Google Sheets to Google Slides'

In [19]:
index_name = 'chunk_texts_index'
num_results = 5

# Use the stored embedding of row 2 as the query vector (converted to a plain
# list for the request body).
query_vector = articles.loc[2, emb_col].tolist()

# Approximate nearest-neighbour search on the embedding field; only the listed
# fields are requested back for each hit.
query = {
    'size': num_results,
    '_source': ['_id', 'title', 'text'],
    'query': {
        'knn': {
            emb_col: {'vector': query_vector, 'k': num_results},
        },
    },
}

response = client.search(body=query, index=index_name)

response['hits']['hits']

[{'_index': 'sample_chunks',
  '_type': '_doc',
  '_id': '46687483861_0_Google',
  '_score': 0.74406916,
  '_source': {'text': 'You can use Google Apps Script to insert data from Google Sheets into Google Slides templates and create hundreds of beautiful slides in minutes. In this post, I’ll show you how to integrate Google Sheets with your Google Slides presentations and some helpful tips to help you get the most out of your integration',
   'title': 'How to add Google Sheets to Google Slides'}},
 {'_index': 'sample_chunks',
  '_type': '_doc',
  '_id': '46687483861_1_Google',
  '_score': 0.7034443,
  '_source': {'text': 'How to add a Google spreadsheet to Google Slides Add a table to your presentation : Adding a Google Sheets chart to your Google Slides presentation is an easy process. You can do this by following these steps: : Open the Google Slides presentation where you want to embed the graphic. Click the number of the slide you want to view. Open the Google Sheets file that requ

## Clean up

In [25]:
# List the indices currently on the cluster (signed request to the _cat API).
url = f'https://{host}:{port}/_cat/indices?v'
response = requests.get(url, auth=auth)
print(response.text)

health status index              uuid                   pri rep docs.count docs.deleted store.size pri.store.size
green  open   full_article       I5BytbAAS5Wk5QCBDyjaDw  10   0          0            0        2kb            2kb
yellow open   python-test-index3 zghM9fXRSJG_K8Nb5xJCOQ   4   1          0            0       832b           832b
green  open   .kibana_1          jNXBc9tuT-WJbgs7u33hRA   1   0          1            0        5kb            5kb



In [28]:
# Drop the index that held the training vectors.
response = client.indices.delete(index='train_index')
print(response)

{'acknowledged': True}


In [26]:
# Drop the ANN search index.
# NOTE(review): none of the visible clean-up cells deletes the trained model
# ('faiss_ivf_pq') - confirm whether it should be removed as well.
response = client.indices.delete(index='chunk_texts_index')
print(response)

{'acknowledged': True}
