## Amazon OpenSearch 테스트

여기에서는 AOSS(Amazon OpenSearch Serverless)를 활용합니다.

### 사전 준비사항

- SageMaker notebook 에서 사용하는 IAM role에 아래 권한을 추가해야 합니다.
  - `AmazonOpenSearchServiceFullAccess`


In [None]:
# Install the OpenSearch Python client and the AWS SigV4 request-signing helper (-q: quiet output).
!pip install -q opensearch-py requests-aws4auth

In [None]:
# List the installed versions of the two packages installed above.
!pip list | grep 'opensearch-py\|requests-aws4auth'

In [None]:
import json
import time

import boto3
import botocore
import sagemaker
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth

# SageMaker session and the execution role of this notebook.
sess = sagemaker.Session()
role = sagemaker.get_execution_role()

# boto3 client for the OpenSearch Serverless API (policies and collections).
client = boto3.client('opensearchserverless')

# Request-signing material for the 'aoss' service, built from the default
# boto3 session's region and credentials.
service = 'aoss'
region = boto3.Session().region_name
credentials = boto3.Session().get_credentials()
awsauth = AWS4Auth(
    credentials.access_key,
    credentials.secret_key,
    region,
    service,
    session_token=credentials.token,
)

In [None]:
# Name of the OpenSearch Serverless collection used throughout this notebook.
collection_name = "rag-hol-aoss-collection"

In [None]:
def createEncryptionPolicy(client, collection):
    """Create an encryption security policy for the given collection.

    The policy applies to every collection whose name starts with
    ``collection`` and selects an AWS-owned key.

    Args:
        client: boto3 ``opensearchserverless`` client.
        collection: Collection name; also used to derive the policy name
            (``<collection>-policy``).

    Returns:
        The ``create_security_policy`` response, or ``None`` when the call
        fails with ``ConflictException`` (a message is printed instead).

    Raises:
        botocore.exceptions.ClientError: For any error other than
            ``ConflictException``.
    """
    # Serialize with json.dumps instead of a hand-escaped template: this
    # avoids the invalid "\/" escape sequence in a Python string literal and
    # guarantees the collection name is correctly JSON-escaped.
    policy_string = json.dumps({
        "Rules": [
            {
                "ResourceType": "collection",
                "Resource": [f"collection/{collection}*"],
            },
        ],
        "AWSOwnedKey": True,
    })

    try:
        response = client.create_security_policy(
            description=f'Encryption policy for {collection}',
            name=f'{collection}-policy',
            policy=policy_string,
            type='encryption'
        )
        print('\nEncryption policy created:')
        print(response)
        return response
    except botocore.exceptions.ClientError as error:
        if error.response['Error']['Code'] == 'ConflictException':
            print(
                '[ConflictException] The policy name or rules conflict with an existing policy.')
        else:
            raise error


def createNetworkPolicy(client, collection):
    """Create a network security policy that allows public access.

    The policy covers both the ``dashboard`` and ``collection`` resource
    types for every collection whose name starts with ``collection``.

    Args:
        client: boto3 ``opensearchserverless`` client.
        collection: Collection name; also used to derive the policy name
            (``<collection>-policy``).

    Returns:
        The ``create_security_policy`` response, or ``None`` when the call
        fails with ``ConflictException`` (a message is printed instead).

    Raises:
        botocore.exceptions.ClientError: For any error other than
            ``ConflictException``.
    """
    resource_pattern = f"collection/{collection}*"
    # Serialize with json.dumps instead of a hand-escaped template: this
    # avoids the invalid "\/" escape sequence in a Python string literal and
    # guarantees the collection name is correctly JSON-escaped.
    policy_string = json.dumps([
        {
            "Description": f"Public access for {collection}",
            "Rules": [
                {"ResourceType": "dashboard", "Resource": [resource_pattern]},
                {"ResourceType": "collection", "Resource": [resource_pattern]},
            ],
            "AllowFromPublic": True,
        },
    ])

    try:
        response = client.create_security_policy(
            description=f'Network policy for {collection}',
            name=f'{collection}-policy',
            policy=policy_string,
            type='network'
        )
        print('\nNetwork policy created:')
        print(response)
        return response
    except botocore.exceptions.ClientError as error:
        if error.response['Error']['Code'] == 'ConflictException':
            print(
                '[ConflictException] A network policy with this name already exists.')
        else:
            raise error


def createAccessPolicy(client, collection, role):
    """Create a data access policy granting ``role`` access to the collection.

    The policy grants index create/delete/update/describe and document
    read/write permissions on indexes of collections whose name starts with
    ``collection``, plus ``aoss:CreateCollectionItems`` on those collections.

    Args:
        client: boto3 ``opensearchserverless`` client.
        collection: Collection name; also used to derive the policy name
            (``<collection>-policy``).
        role: Principal to authorize; placed verbatim in the policy's
            ``Principal`` list.

    Returns:
        The ``create_access_policy`` response, or ``None`` when the call
        fails with ``ConflictException`` (a message is printed instead).

    Raises:
        botocore.exceptions.ClientError: For any error other than
            ``ConflictException``.
    """
    # Serialize with json.dumps instead of a hand-escaped template: this
    # avoids the invalid "\/" escape sequence in a Python string literal and
    # guarantees the collection name and role are correctly JSON-escaped.
    policy_string = json.dumps([
        {
            "Rules": [
                {
                    "Resource": [f"index/{collection}*/*"],
                    "Permission": [
                        "aoss:CreateIndex",
                        "aoss:DeleteIndex",
                        "aoss:UpdateIndex",
                        "aoss:DescribeIndex",
                        "aoss:ReadDocument",
                        "aoss:WriteDocument",
                    ],
                    "ResourceType": "index",
                },
                {
                    "Resource": [f"collection/{collection}*"],
                    "Permission": ["aoss:CreateCollectionItems"],
                    "ResourceType": "collection",
                },
            ],
            "Principal": [role],
        },
    ])

    try:
        response = client.create_access_policy(
            description=f'Data access policy for {collection}',
            name=f'{collection}-policy',
            policy=policy_string,
            type='data'
        )
        print('\nAccess policy created:')
        print(response)
        return response
    except botocore.exceptions.ClientError as error:
        if error.response['Error']['Code'] == 'ConflictException':
            print(
                '[ConflictException] An access policy with this name already exists.')
        else:
            raise error


def createCollection(client, collection):
    """Request creation of a ``VECTORSEARCH`` collection named ``collection``.

    Returns the ``create_collection`` response. If the call fails with
    ``ConflictException`` a message is printed and ``None`` is returned; any
    other ``ClientError`` is re-raised.
    """
    try:
        created = client.create_collection(name=collection, type='VECTORSEARCH')
    except botocore.exceptions.ClientError as error:
        if error.response['Error']['Code'] != 'ConflictException':
            raise error
        print(
            '[ConflictException] A collection with this name already exists. Try another name.')
        return None
    print(created)
    return created


def waitForCollectionCreation(client, collection, poll_interval=30):
    """Wait for the collection to become active and return its endpoint host.

    Polls ``batch_get_collection`` while the reported status is ``CREATING``.

    Args:
        client: boto3 ``opensearchserverless`` client.
        collection: Name of the collection to wait for.
        poll_interval: Seconds to sleep between status checks.

    Returns:
        The collection endpoint with the ``https://`` scheme removed.

    Raises:
        RuntimeError: If the collection is not present in the response, or
            if it leaves ``CREATING`` in any status other than ``ACTIVE``.
    """
    while True:
        response = client.batch_get_collection(names=[collection])
        details = response.get('collectionDetails') or []
        if not details:
            # Fail with a clear message instead of an IndexError on details[0].
            raise RuntimeError(
                f"Collection '{collection}' was not found: "
                f"{response.get('collectionErrorDetails')}")
        status = details[0]['status']
        if status != 'CREATING':
            break
        print('Creating collection...')
        time.sleep(poll_interval)

    # Only report success for an active collection; previously any
    # non-CREATING status (e.g. FAILED) was announced as a success.
    if status != 'ACTIVE':
        raise RuntimeError(
            f"Collection '{collection}' ended in status {status}: {details[0]}")

    print('\nCollection successfully created:')
    print(details)
    # Extract the collection endpoint and drop the URL scheme
    host = details[0]['collectionEndpoint']
    return host.replace("https://", "")


In [None]:
# Create the encryption, network and data access policies, then the
# collection itself, and wait for its endpoint.
createEncryptionPolicy(client, collection_name)
createNetworkPolicy(client, collection_name)
createAccessPolicy(client, collection_name, role)
createCollection(client, collection_name)
aoss_endpoint = waitForCollectionCreation(client, collection_name)

# It can take up to a minute for data access rules to be enforced

In [None]:
# Show the collection endpoint host returned by waitForCollectionCreation.
print(aoss_endpoint)

In [None]:

def getOpenSearchClient(host):
    """Build an OpenSearch client for ``host`` on port 443 over verified SSL.

    Authentication uses the module-level ``awsauth`` object.
    """
    endpoint = {'host': host, 'port': 443}
    return OpenSearch(
        hosts=[endpoint],
        http_auth=awsauth,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
        timeout=6000,
    )
    
    
def getIndex(host, index_name):
    """Fetch and print the definition of ``index_name``.

    Args:
        host: Collection endpoint host passed to ``getOpenSearchClient``.
        index_name: Name of the index to look up.

    Returns:
        The ``indices.get`` response, or ``None`` if any step raised (the
        error is printed rather than propagated).
    """
    try:
        client = getOpenSearchClient(host)
        # Pass the index by keyword, like the document helpers in this file.
        # NOTE(review): newer opensearch-py releases appear to accept these
        # arguments by keyword only - confirm against the installed version.
        response = client.indices.get(index=index_name)
        print(response)
        return response
    except Exception as e:
        # Deliberate best-effort lookup: report the problem and return None.
        print(f"Error: {e}")

        
def deleteIndex(host, index_name):
    """Delete ``index_name`` and return the ``indices.delete`` response.

    Args:
        host: Collection endpoint host passed to ``getOpenSearchClient``.
        index_name: Name of the index to delete.
    """
    client = getOpenSearchClient(host)
    # Pass the index by keyword, like the document helpers in this file.
    # NOTE(review): newer opensearch-py releases appear to accept these
    # arguments by keyword only - confirm against the installed version.
    response = client.indices.delete(index=index_name)
    print('\nDeleting index:')
    print(response)

    return response
    

# Note that it can take up to a minute for data access rules to be enforced
def createIndex(host, index_name, index_schema=None):
    """Create ``index_name`` and return the ``indices.create`` response.

    Args:
        host: Collection endpoint host passed to ``getOpenSearchClient``.
        index_name: Name of the index to create.
        index_schema: Optional request body (settings/mappings). When it is
            ``None`` or empty the index is created without a body.
    """
    client = getOpenSearchClient(host)

    # Pass the index by keyword, like the document helpers in this file.
    # NOTE(review): newer opensearch-py releases appear to accept these
    # arguments by keyword only - confirm against the installed version.
    if index_schema:
        response = client.indices.create(index=index_name, body=index_schema)
    else:
        response = client.indices.create(index=index_name)
    print('\nCreating index:')
    print(response)

    return response


def addSampleData(host, index_name, data):
    """Index ``data`` as a document in ``index_name`` and return the response."""
    indexing_result = getOpenSearchClient(host).index(index=index_name, body=data)
    print('\nDocument added:')
    print(indexing_result)
    return indexing_result


def getSampleData(host, index_name, doc_id):
    """Fetch the document ``doc_id`` from ``index_name``, print and return it."""
    document = getOpenSearchClient(host).get(index=index_name, id=doc_id)
    print(document)
    return document

In [None]:
# Create an index without an explicit schema (no body is sent).
index_name = "rag-hol-index-simple"
createIndex(aoss_endpoint, index_name)

# Uncomment to remove the index again:
# deleteIndex(aoss_endpoint, index_name)


In [None]:
# Print the definition of the index created above.
getIndex(aoss_endpoint, index_name)

In [None]:
# A small sample document to index.
sample_data = {
    'title': 'Seinfeld',
    'creator': 'Larry David',
    'year': 1989
}

# Keep the indexing response; its "_id" is read in the next cell.
add_result = addSampleData(aoss_endpoint, index_name, sample_data)

In [None]:
# ID assigned to the document, taken from the indexing response.
document_id = add_result["_id"]
print(document_id)

In [None]:
# Indexing can take more than a minute, so the document may not be
# retrievable immediately.
getSampleData(aoss_endpoint, index_name, document_id)

In [None]:
# k-NN search parameter and the dimension of the stored embedding vectors.
ef_search = 512
embedding_model_dimensions = 1024

# Index definition: k-NN enabled, with a vector field, a text field and a
# free-form metadata object.
index_schema = {
    "settings": {
        "index": {
            "knn": True,
            "knn.algo_param.ef_search": ef_search,
        },
    },
    "mappings": {
        "properties": {
            # Embedding vector field searched with HNSW.
            "content_embeddings": {
                "type": "knn_vector",
                "dimension": embedding_model_dimensions,
                "method": {
                    "name": "hnsw",
                    # Alternative distance: "space_type": "l2"
                    "space_type": "cosinesimil",
                    "engine": "nmslib",
                    "parameters": {"ef_construction": 512, "m": 16},
                },
            },
            # Text content, analyzed with the "nori" analyzer.
            "content": {"type": "text", "analyzer": "nori"},
            "metadata": {"type": "object"},
        },
    },
}

In [None]:
# Create the vector index using the schema defined in index_schema.
vector_index_name = "rag-hol-index-vector"
createIndex(aoss_endpoint, vector_index_name, index_schema)

In [None]:
# Print the definition of the vector index created above.
getIndex(aoss_endpoint, vector_index_name)

In [None]:
# Save these variables with the IPython %store magic (presumably for reuse
# from other notebooks - confirm against the rest of the workshop).
%store collection_name
%store vector_index_name
%store aoss_endpoint