In [1]:
# Select the AWS credentials profile for this kernel session.
# NOTE(review): presumably consumed by utils.aws.get_secret (via the AWS SDK)
# when fetching the Elasticsearch API key below — confirm.
%env AWS_PROFILE=platform-developer

env: AWS_PROFILE=platform-developer


In [3]:
from utils.aws import get_secret
import elasticsearch

import urllib3

# Silence urllib3's InsecureRequestWarning for the rest of the kernel session.
# NOTE(review): the client below does not turn off certificate verification,
# so this suppression may be unnecessary — confirm before relying on it.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

ES_ENDPOINT = "https://semantic-playground-b28f61.es.eu-west-1.aws.elastic.cloud:443"
# Pulled from the secrets store rather than hard-coded in the notebook.
ES_API_KEY = get_secret("agnes/elasticsearch/semantic-playground")
# Shared client for every search in this notebook; each request may take up
# to 120 seconds before timing out.
ES_CLIENT = elasticsearch.Elasticsearch(
    hosts=ES_ENDPOINT,
    api_key=ES_API_KEY,
    request_timeout=120,
)

In [33]:
# Pretty print utils
import re

BOLD = "\033[1m"
RESET = "\033[0m"
RESET_COLOR = "\033[39m"


def get_work_url(work_id: str) -> str:
    """Return the public wellcomecollection.org URL for a work ID."""
    return f"https://wellcomecollection.org/works/{work_id}"


def highlight_terms(text: str, terms, color: str = "\033[92m") -> str:
    """Wrap each case-insensitive occurrence of any of ``terms`` in ANSI colour codes.

    All terms are matched in one regex pass, so:

    * empty terms (e.g. produced by ``"a  b".split(" ")``) are ignored instead
      of matching between every character;
    * escape codes inserted around one match can never be matched by another
      term (e.g. a term ``"m"`` hitting the ``m`` in ``\\033[92m``);
    * when terms overlap the longest wins (``"japanese"`` over ``"japan"``).

    Args:
        text: the string to decorate.
        terms: iterable of literal (non-regex) search terms.
        color: ANSI sequence emitted before each match; RESET_COLOR follows it.

    Returns:
        ``text`` with every match wrapped, or ``text`` unchanged when there are
        no non-empty terms.
    """
    # Longest first so the regex alternation prefers the longer overlapping term;
    # the secondary key keeps the pattern deterministic.
    unique_terms = sorted({term for term in terms if term}, key=lambda t: (-len(t), t))
    if not unique_terms:
        return text

    pattern = re.compile("|".join(re.escape(term) for term in unique_terms), re.IGNORECASE)
    # m.group(0) keeps the original casing of the matched text.
    return pattern.sub(lambda m: f"{color}{m.group(0)}{RESET_COLOR}", text)


def print_bold(text: str) -> None:
    """Print ``text`` in bold, then reset all terminal attributes."""
    print(f"{BOLD}{text}{RESET}")
   

In [55]:
def get_full_query(query: str):
    """Build the full (production-style) search query for ``query``.

    Five clauses are combined under ``bool.should``:

    1. ``text_strict``: cross-field match over the descriptive text fields,
       requiring most terms (``minimum_should_match: "3<-20%"``).
    2. ``title_prefix``: phrase-prefix match on the whole normalised title.
    3. ``ids_lax``: any-term match over identifier fields.
    4. ``ids_with_path_lax``: any-term match over shelfmark / collection path.
    5. ``text_lax``: any-term text match, gated by a filter requiring some term
       to also hit an identifier-like field.

    Args:
        query: free-text search string; interpolated into every clause.

    Returns:
        A dict suitable as the ``query`` entry of a search request body.
    """
    text = f"{query}"

    # Descriptive-text fields (with boosts) used by both the strict and the lax
    # text clauses. Each clause receives its own list copy.
    text_fields = (
        "query.title.*^5",
        "query.title.cased^10",
        "query.contributors.agent.label^10",
        "query.subjects.concepts.label^10",
        "query.genres.concepts.label^10",
        "query.production.label.*^10",
        "query.partOf.title.*^10",
        "query.alternativeTitles.*",
        "query.description.*",
        "query.edition",
        "query.languages.label",
        "query.lettering.*",
        "query.notes.contents.*",
        "query.physicalDescription.*",
    )
    id_analyzer = "lowercase_whitespace_tokens"

    text_strict = {
        "multi_match": {
            "_name": "text_strict",
            "query": text,
            "fields": list(text_fields),
            "type": "cross_fields",
            "tie_breaker": 0.4,
            "minimum_should_match": "3<-20%",
            # NOTE(review): spelled "Or" here but "OR" in the other clauses;
            # Elasticsearch appears to accept either — confirm before normalising.
            "operator": "Or",
        }
    }

    title_prefix = {
        "match_phrase_prefix": {
            "query.title.normalized_whole_phrase": {
                "_name": "title_prefix",
                "query": text,
                "boost": 50,
            }
        }
    }

    ids_lax = {
        "multi_match": {
            "_name": "ids_lax",
            "query": text,
            "analyzer": id_analyzer,
            "fields": [
                "query.id^5",
                "query.sourceIdentifier.value^5",
                "query.identifiers.value",
                "query.items.id",
                "query.items.identifiers.value",
                "query.images.id",
                "query.images.identifiers.value",
                "query.referenceNumber*",
            ],
            "type": "cross_fields",
            "boost": 100,
            "operator": "OR",
            "minimum_should_match": 1,
        }
    }

    ids_with_path_lax = {
        "multi_match": {
            "_name": "ids_with_path_lax",
            "query": text,
            "analyzer": id_analyzer,
            "fields": ["query.items.shelfmark*", "query.collectionPath*"],
            "type": "cross_fields",
            "boost": 50,
            "operator": "OR",
            "minimum_should_match": 1,
        }
    }

    # Lax text match, only admitted when some term also hits an identifier-like
    # field. Filter clauses do not score, so the boosts listed there are inert.
    text_lax_gated = {
        "bool": {
            "must": [
                {
                    "multi_match": {
                        "_name": "text_lax",
                        "query": text,
                        "fields": list(text_fields),
                        "type": "cross_fields",
                        "tie_breaker": 0.4,
                        "minimum_should_match": 1,
                        "operator": "OR",
                    }
                }
            ],
            "filter": [
                {
                    "multi_match": {
                        "query": text,
                        "analyzer": id_analyzer,
                        "fields": [
                            "query.id^5",
                            "query.sourceIdentifier.value^5",
                            "query.referenceNumber*^5",
                            "query.identifiers.value",
                            "query.items.id",
                            "query.items.identifiers.value",
                            "query.items.shelfmark*",
                            "query.images.id",
                            "query.images.identifiers.value",
                            "query.collectionPath*",
                        ],
                        "type": "cross_fields",
                        "operator": "OR",
                        "minimum_should_match": 1,
                    }
                }
            ],
        }
    }

    clauses = [text_strict, title_prefix, ids_lax, ids_with_path_lax, text_lax_gated]
    return {"bool": {"should": clauses}}


def get_basic_query(query: str, fields: list[str]) -> dict:
    """Build a ``bool.should`` query with one plain ``match`` clause per field.

    Args:
        query: free-text search string, passed unchanged to every clause.
        fields: document fields to match against, in clause order.

    Returns:
        ``{"bool": {"should": [{"match": {field: {"query": query}}}, ...]}}``.
    """
    clauses = []
    for field_name in fields:
        clauses.append({"match": {field_name: {"query": query}}})
    return {"bool": {"should": clauses}}
    

def get_basic_knn_query(
    query: str,
    fields: list[str],
    k: int = 50,
    num_candidates: int = 500,
    model_id: str = "amazon-bedrock-titan-embeddings",
) -> dict:
    """Build a ``bool.should`` query with one ``knn`` clause per vector field.

    Each clause's ``query_vector_builder.text_embedding`` asks Elasticsearch to
    embed ``query`` with ``model_id`` rather than sending a precomputed vector.

    Args:
        query: free-text search string to embed.
        fields: dense-vector fields to search, in clause order.
        k: nearest neighbours requested per clause.
        num_candidates: candidates considered per clause; Elasticsearch requires
            this to be at least ``k``.
        model_id: inference model used to embed ``query``.

    Returns:
        ``{"bool": {"should": [{"knn": {...}}, ...]}}``.

    Raises:
        ValueError: if ``num_candidates < k``, checked here rather than waiting
            for the server to reject the request.
    """
    if num_candidates < k:
        raise ValueError(f"num_candidates ({num_candidates}) must be >= k ({k})")

    def knn_clause(field: str) -> dict:
        # One kNN sub-query for a single vector field.
        return {
            "knn": {
                "field": field,
                "k": k,
                "num_candidates": num_candidates,
                "query_vector_builder": {
                    "text_embedding": {
                        "model_id": model_id,
                        "model_text": query,
                    }
                },
            }
        }

    return {"bool": {"should": [knn_clause(field) for field in fields]}}

In [143]:
# Hits fetched per index by get_combined_query_results.
SIZE = 100
# Ranks shown by compare_query_results.
PRINT_LIMIT = 10

# Comparison key -> Elasticsearch index. Dict order is the order in which the
# indexes are queried and printed.
INDEXES = {
    "elser_eis": "works-elser-title-description",
    "bedrock_titan": "works-titan-title-description",
    "non_semantic": "works-non-semantic",
    # Same index as "non_semantic", but searched with get_full_query (see QUERY_TYPES).
    "production": "works-non-semantic",
}
# Short label printed for each comparison.
COMPARISON_LABELS = {
    "elser_eis": "ELSER",
    "bedrock_titan": "Titan",
    "non_semantic": "Control",
    "production": "Production",
}
# ANSI foreground colour used when printing each comparison.
INDEX_COLORS = {
    "elser_eis": "\033[34m",
    "bedrock_titan": "\033[91m",
    "non_semantic": "\033[93m",
    "production": "\033[95m",
}
# Query builder for each "basic" comparison, called as fn(query, fields).
BASIC_QUERY_FUNCTIONS = {
    "non_semantic": get_basic_query,
    "elser_eis": get_basic_query,
    "bedrock_titan": get_basic_knn_query,
}
# Fields handed to the basic builder for each "basic" comparison.
BASIC_QUERY_FIELDS = {
    "non_semantic": ["query.title", "query.description"],
    "elser_eis": ["query.titleSemantic", "query.descriptionSemantic"],
    "bedrock_titan": ["query.titleSemantic", "query.descriptionSemantic"],
}
# "basic" -> BASIC_QUERY_FUNCTIONS / BASIC_QUERY_FIELDS; "full" -> get_full_query.
QUERY_TYPES = {
    "non_semantic": "basic",
    "elser_eis": "basic",
    "bedrock_titan": "basic",
    "production": "full",
}


def _config_problems() -> list[str]:
    """List every gap in the per-comparison tables above (empty when consistent)."""
    problems = []
    for key in INDEXES:
        required = {"COMPARISON_LABELS": COMPARISON_LABELS, "INDEX_COLORS": INDEX_COLORS, "QUERY_TYPES": QUERY_TYPES}
        if QUERY_TYPES.get(key) == "basic":
            required.update(BASIC_QUERY_FUNCTIONS=BASIC_QUERY_FUNCTIONS, BASIC_QUERY_FIELDS=BASIC_QUERY_FIELDS)
        problems.extend(f"{key!r} missing from {name}" for name, table in required.items() if key not in table)
        if key in QUERY_TYPES and QUERY_TYPES[key] not in {"basic", "full"}:
            problems.append(f"{key!r} has unknown query type {QUERY_TYPES[key]!r}")
    return problems


# Fail fast on an incomplete comparison instead of a KeyError halfway through a search.
assert not _config_problems(), _config_problems()

In [144]:
def get_combined_query_results(query: str):
    """Run ``query`` against every index in INDEXES and collect the responses.

    "basic" comparisons use their BASIC_QUERY_FUNCTIONS builder over
    BASIC_QUERY_FIELDS; any other type uses get_full_query. Each search asks for
    SIZE hits with an exact total hit count.

    Returns:
        Dict keyed like INDEXES. Each value is the search response as a plain
        dict, plus:
          * "ranking": {work_id: 1-based rank} over the returned hits;
          * "ids": set of the returned work IDs.
    """
    combined = {}
    for key, index_name in INDEXES.items():
        if QUERY_TYPES[key] == "basic":
            builder = BASIC_QUERY_FUNCTIONS[key]
            es_query = builder(query, BASIC_QUERY_FIELDS[key])
        else:
            es_query = get_full_query(query)

        request_body = {"size": SIZE, "track_total_hits": True, "query": es_query}
        response = dict(ES_CLIENT.search(index=index_name, body=request_body))

        hit_ids = [hit["_id"] for hit in response["hits"]["hits"]]
        response["ranking"] = {hit_id: rank for rank, hit_id in enumerate(hit_ids, start=1)}
        response["ids"] = set(hit_ids)
        combined[key] = response

    return combined

def compare_query_results(query: str):
    """Print a rank-by-rank comparison of the top hits for ``query`` across INDEXES.

    First prints each index's total hit count. Then, for each rank
    1..PRINT_LIMIT, prints every index's hit at that rank, unless that work was
    already printed. Each printed work shows its URL, highlighted title and
    description, and its rank in every index ("-" when outside that index's
    returned hits).
    """
    print(f"{BOLD}Query:{RESET} {query}")
    # split() rather than split(" "): repeated spaces must not produce empty
    # terms, which would match (and highlight) between every character.
    query_terms = query.split()

    results = get_combined_query_results(query)

    # Values are pulled out first: reusing the outer quote character inside an
    # f-string replacement field is a SyntaxError before Python 3.12.
    print(f"{BOLD}Total results:{RESET}", end=" ")
    for key in INDEXES:
        total = results[key]["hits"]["total"]["value"]
        print(f"{INDEX_COLORS[key]}{COMPARISON_LABELS[key]} {total}{RESET_COLOR}", end=" ")
    print("\n")

    seen = set()
    for i in range(PRINT_LIMIT):
        print(f"{BOLD}————— {i + 1} —————{RESET}\n")
        for key in INDEXES:
            hits = results[key]["hits"]["hits"]
            if i >= len(hits):
                continue
            work_id = hits[i]["_id"]
            if work_id in seen:
                continue
            seen.add(work_id)

            display = hits[i]["_source"]["display"]
            print(get_work_url(work_id))
            print_bold(highlight_terms(display.get("title") or "", query_terms))
            description = display.get("description")
            if description is not None:
                print(highlight_terms(description, query_terms))

            for other in INDEXES:
                rank = results[other]["ranking"].get(work_id, "-")
                print(f"{INDEX_COLORS[other]}{COMPARISON_LABELS[other]} {rank}{RESET_COLOR}", end=" ")
            print("\n")

def find_needle_in_haystack(query: str, work_id: str, size: int = 10_000):
    """Print the 1-based rank of ``work_id`` for ``query`` in every index.

    Each index is searched with its configured query ("basic" builder or
    get_full_query) for up to ``size`` hits, without fetching ``_source``.
    Prints "<label> <rank>" per index on a single line, or "<label> -" when the
    work is not among the fetched hits.

    Args:
        query: free-text search string.
        work_id: document ``_id`` to look for.
        size: hits fetched per index. The default 10,000 equals Elasticsearch's
            default ``index.max_result_window``; larger values presumably need
            that setting raised or a paginated search.
    """
    for key, index_name in INDEXES.items():
        if QUERY_TYPES[key] == "basic":
            es_query = BASIC_QUERY_FUNCTIONS[key](query, BASIC_QUERY_FIELDS[key])
        else:
            es_query = get_full_query(query)

        body = {"size": size, "track_total_hits": True, "_source": False, "query": es_query}
        # Only the hit list is needed, so the full response is not retained.
        hits = dict(ES_CLIENT.search(index=index_name, body=body))["hits"]["hits"]

        # Single scan; the first occurrence wins, as list.index() would.
        rank = next((pos for pos, hit in enumerate(hits, start=1) if hit["_id"] == work_id), "-")
        print(f"{INDEX_COLORS[key]}{COMPARISON_LABELS[key]} {rank}{RESET_COLOR}", end=" ")

In [161]:
# Candidate queries for the comparison; set QUERY to any of them and re-run.
EXAMPLE_QUERIES = [
    "surgery knife",
    "HIV",
    "violent criminal",
    "lung neoplasm",
    "lung neoplasm inflammation",
    "cardiac failure",
    "edo japan",
]
QUERY = "edo japan"
compare_query_results(QUERY)

[1mQuery:[0m edo japan
[1mTotal results:[0m [34mELSER 24206[39m [91mTitan 573[39m [93mControl 261[39m [95mProduction 0[39m 

[1m————— 1 —————[0m

https://wellcomecollection.org/works/dqcsv9tp
[1m"[92mJapan[39m 1"[0m
Reference file of publications relating to [92mJapan[39m, including a brochure for University of Birmingham [92mJapan[39m Centre, brochures for [92mJapan[39mese buildings, an issue of the British Council's <i>Creative Industries Magazine</i> and a British Chamber of Commerce [92mJapan[39m Fact File.
[34mELSER 1[39m [91mTitan 14[39m [93mControl 2[39m [95mProduction -[39m 

https://wellcomecollection.org/works/jqeq5cp9
[1m[92mJapan[39m[0m
Various items of correspondence, particularly with Tanabe Seiyaku Co Ltd, and some notes of meetings.  Correspondence concerning the supply of Newcastle Disease Vaccine.
[34mELSER 38[39m [91mTitan 1[39m [93mControl 35[39m [95mProduction -[39m 

https://wellcomecollection.org/works/sbz6qqth
[1m"

In [159]:
# Rank of work s3e28zhn for "lung neoplasm" in each index ("-" = not in the fetched hits).
find_needle_in_haystack(query="lung neoplasm", work_id="s3e28zhn")

[34mELSER 377[39m [91mTitan 59[39m [93mControl -[39m [95mProduction 1[39m 

In [160]:
# Rank of work e37qcyfm for "cardiac failure" in each index ("-" = not in the fetched hits).
find_needle_in_haystack(query="cardiac failure", work_id="e37qcyfm")

[34mELSER 7[39m [91mTitan 1[39m [93mControl 60[39m [95mProduction -[39m 