In [364]:
%env AWS_PROFILE=platform-developer

env: AWS_PROFILE=platform-developer


In [366]:
from utils.aws import get_secret
import elasticsearch

import urllib3

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

ES_ENDPOINT = "https://semantic-playground-b28f61.es.eu-west-1.aws.elastic.cloud:443"
ES_API_KEY = get_secret("agnes/elasticsearch/semantic-playground")
ES_CLIENT = elasticsearch.Elasticsearch(ES_ENDPOINT, api_key=ES_API_KEY, request_timeout=120)

In [265]:
# Pretty print utils
import re

BOLD = "\033[1m"
RESET = "\033[0m"
RESET_COLOR = "\033[39m"


def get_work_url(work_id: str):
    return f"https://wellcomecollection.org/works/{work_id}"


def highlight_terms(text, terms, color="\033[92m"):
    for term in terms:
        pattern = re.compile(re.escape(term), re.IGNORECASE)
        text = pattern.sub(lambda m: f"{color}{m.group(0)}{RESET_COLOR}", text)

    return text


def print_bold(text: str):
    print(f"{BOLD}{text}{RESET}")


In [367]:
import math
import time

def get_production_query(query: str, *args):
    return {
        "bool": {
            "should": [
                {
                    "multi_match": {
                        "_name": "text_strict",
                        "query": f"{query}",
                        "fields": [
                            "query.title.*^5",
                            "query.title.cased^10",
                            "query.contributors.agent.label^10",
                            "query.subjects.concepts.label^10",
                            "query.genres.concepts.label^10",
                            "query.production.label.*^10",
                            "query.partOf.title.*^10",
                            "query.alternativeTitles.*",
                            "query.description.*",
                            "query.edition",
                            "query.languages.label",
                            "query.lettering.*",
                            "query.notes.contents.*",
                            "query.physicalDescription.*"
                        ],
                        "type": "cross_fields",
                        "minimum_should_match": "3<-20%",
                        "operator": "Or"
                    }
                },
                {
                    "match_phrase_prefix": {
                        "query.title.normalized_whole_phrase": {
                            "_name": "title_prefix",
                            "query": f"{query}",
                            "boost": 50
                        }
                    }
                },
                {
                    "multi_match": {
                        "_name": "ids_lax",
                        "query": f"{query}",
                        "analyzer": "lowercase_whitespace_tokens",
                        "fields": [
                            "query.id^5",
                            "query.sourceIdentifier.value^5",
                            "query.identifiers.value",
                            "query.items.id",
                            "query.items.identifiers.value",
                            "query.images.id",
                            "query.images.identifiers.value",
                            "query.referenceNumber*"
                        ],
                        "type": "cross_fields",
                        "boost": 100,
                        "operator": "OR",
                        "minimum_should_match": 1
                    }
                },
                {
                    "multi_match": {
                        "_name": "ids_with_path_lax",
                        "query": f"{query}",
                        "analyzer": "lowercase_whitespace_tokens",
                        "fields": ["query.items.shelfmark*", "query.collectionPath*"],
                        "type": "cross_fields",
                        "boost": 50,
                        "operator": "OR",
                        "minimum_should_match": 1
                    }
                },
                {
                    "bool": {
                        "must": [
                            {
                                "multi_match": {
                                    "_name": "text_lax",
                                    "query": f"{query}",
                                    "fields": [
                                        "query.title.*^5",
                                        "query.title.cased^10",
                                        "query.contributors.agent.label^10",
                                        "query.subjects.concepts.label^10",
                                        "query.genres.concepts.label^10",
                                        "query.production.label.*^10",
                                        "query.partOf.title.*^10",
                                        "query.alternativeTitles.*",
                                        "query.description.*",
                                        "query.edition",
                                        "query.languages.label",
                                        "query.lettering.*",
                                        "query.notes.contents.*",
                                        "query.physicalDescription.*"
                                    ],
                                    "type": "cross_fields",
                                    "tie_breaker": 0.4,
                                    "minimum_should_match": 1,
                                    "operator": "OR"
                                }
                            }
                        ],
                        "filter": [
                            {
                                "multi_match": {
                                    "query": f"{query}",
                                    "analyzer": "lowercase_whitespace_tokens",
                                    "fields": [
                                        "query.id^5",
                                        "query.sourceIdentifier.value^5",
                                        "query.referenceNumber*^5",
                                        "query.identifiers.value",
                                        "query.items.id",
                                        "query.items.identifiers.value",
                                        "query.items.shelfmark*",
                                        "query.images.id",
                                        "query.images.identifiers.value",
                                        "query.collectionPath*"
                                    ],
                                    "type": "cross_fields",
                                    "operator": "OR",
                                    "minimum_should_match": 1
                                }
                            }
                        ]
                    }
                }
            ]
        }
    }


def get_basic_query(query: str, fields: list[str], *args) -> dict:
    return {
        "bool": {
            "should": [
                {"match": {field: {"query": query}}}
                for field in fields
            ]
        }
    }


def get_text_expansion_query(query: str, fields: list[str]) -> dict:
    return {
        "bool": {
            "should": [
                {
                    "text_expansion": {
                        field: {
                            "model_id": ".elser_model_2_linux-x86_64",
                            "model_text": query
                        }
                    }
                }
                for field in fields
            ]
        }
    }


def get_full_semantic_query(query: str, fields: list[str]) -> dict:
    full_query = get_production_query(query)
    semantic_query = {
        "bool": {
            "should": [
                {"match": {field: {"query": query}}}
                for field in fields
            ]
        }
    }

    semantic_with_lax_text = {
        "bool": {
            "_name": "text_lax_with_semantics",
            "must": [
                {
                    "multi_match": {
                        "_name": "text_lax",
                        "query": query,
                        "fields": [
                            "query.title.*^5",
                            "query.title.cased^10",
                            "query.contributors.agent.label^10",
                            "query.subjects.concepts.label^10",
                            "query.genres.concepts.label^10",
                            "query.production.label.*^10",
                            "query.partOf.title.*^10",
                            "query.alternativeTitles.*",
                            "query.description.*",
                            "query.edition",
                            "query.languages.label",
                            "query.lettering.*",
                            "query.notes.contents.*",
                            "query.physicalDescription.*"
                        ],
                        "type": "cross_fields",
                        "operator": "OR",
                        "minimum_should_match": 1
                    }
                }
            ],
            "should": [semantic_query]
        }
    }

    full_query["bool"]["should"].append(semantic_with_lax_text)
    return full_query


def get_basic_knn_query(query: str, fields: list[str], model_id: str, *args) -> dict:
    search_query = {
        "bool": {
            "should": []
        }
    }

    for field in fields:
        search_query["bool"]["should"].append({
            "knn": {
                "field": field,
                "k": 50,
                "num_candidates": 500,
                # "boost": 1.0,
                "query_vector_builder": {
                    "text_embedding": {
                        "model_id": model_id,
                        "model_text": query
                    }
                }
            }
        }
        )

    return search_query


def get_openai_knn_query(query: str, fields: list[str], *args):
    return get_basic_knn_query(query, fields, "openai-text_embedding-muvikv9j5f")


def get_titan_knn_query(query: str, fields: list[str], *args):
    return get_basic_knn_query(query, fields, "amazon-bedrock-titan-embeddings")


def get_basic_sparse_vector_query(query: str, fields: list[str]) -> dict:
    search_query = {
        "bool": {
            "should": [],
            #   "minimum_should_match": math.ceil(len(fields) / 2)
        }
    }

    for field in fields:
        search_query["bool"]["should"].append({
            "sparse_vector": {
                "field": field,
                "query": query,
                "prune": True,
                "pruning_config": {
                    "tokens_freq_ratio_threshold": 2,
                    "tokens_weight_threshold": 0.4,
                    "only_score_pruned_tokens": False
                }
            }
        }
        )

    return search_query


def get_rrf_query(query: str, semantic_query, min_score) -> dict:
    full_query = get_production_query(query)

    return {
        "retriever": {
            "rrf": {
                "retrievers": [
                    {
                        "standard": {
                            "query": full_query
                        }
                    },
                    {
                        "standard": {
                            "query": semantic_query,
                            "min_score": min_score,
                        }
                    }
                ],
                "rank_window_size": SIZE,
                "rank_constant": 20
            }
        }
    }


def get_rrf_query_with_min_should_match(query: str, fields: list[str], min_score) -> dict:
    semantic_query = get_basic_sparse_vector_query(query, fields)
    return get_rrf_query(query, semantic_query, min_score)

def get_rrf_query_open_ai(query: str, fields: list[str], min_score) -> dict:
    semantic_query = get_openai_knn_query(query, fields)
    return get_rrf_query(query, semantic_query, min_score)


def get_rrf_query_with_multi_match(query: str, fields: list[str], min_score) -> dict:
    semantic_query = get_full_semantic_query(query, fields)
    return get_rrf_query(query, semantic_query, min_score)

In [368]:
def in_colour(text: str, colour: str):
    print(f"{colour}{text}{RESET_COLOR}", end=" ")


def get_es_request_body(query: str, config: dict):
    body = {"size": SIZE, "track_total_hits": True}

    full_query = config["get_query_function"](query, config.get("semantic_fields"), config.get("semantic_min_score"))

    if "retriever" in full_query:
        body = {**full_query, **body}
    else:
        body["query"] = full_query

    return body


def get_combined_query_results(query: str):
    results = {}
    for config in TO_COMPARE:
        body = get_es_request_body(query, config)

        t = time.time()
        response = dict(ES_CLIENT.search(index=config["index"], body=body))
        print(config["index"], time.time() - t)
        results[config["label"]] = response

        hits = response["hits"]["hits"]
        results[config["label"]]["ranking"] = {hit["_id"]: i + 1 for i, hit in enumerate(hits)}
        results[config["label"]]["ids"] = set(hit["_id"] for hit in hits)

    return results


def compare_query_results(query: str):
    print(f"{BOLD}Query:{RESET} {query}")
    query_terms = query.split(" ")

    results = get_combined_query_results(query)
    print(f"{BOLD}Total results:{RESET}", end=" ")
    for config in TO_COMPARE:
        text = f"{config["label"]} {results[config["label"]]["hits"]["total"]["value"]}"
        in_colour(text, colour=config["colour"])

    print("\n")

    seen = set()
    for i in range(PRINT_LIMIT):
        print(f"{BOLD}————— {i + 1} —————{RESET}\n")
        for config in TO_COMPARE:
            hits = results[config["label"]]["hits"]["hits"]

            if len(hits) > i:
                hit = hits[i]
                work_id = hit["_id"]
                if work_id not in seen:
                    seen.add(work_id)
                    print(get_work_url(work_id))
                    print_bold(highlight_terms(hit["_source"]["display"]["title"], query_terms))
                    if "description" in hit["_source"]["display"]:
                        print(highlight_terms(hit["_source"]["display"]["description"], query_terms))

                    for config_2 in TO_COMPARE:
                        text = f"{config_2["label"]} {results[config_2["label"]]["ranking"].get(work_id, "-")}"
                        in_colour(text, colour=config_2["colour"])

                    print("\n")


def find_needle_in_haystack(query: str, work_id: str):
    for config in TO_COMPARE:
        body = get_es_request_body(query, config)
        response = dict(ES_CLIENT.search(index=config["index"], body=body))

        ids = [h["_id"] for h in response["hits"]["hits"]]

        if work_id in ids:
            in_colour(f"{config["label"]} {ids.index(work_id) + 1}", colour=config["colour"])
        else:
            in_colour(f"{config["label"]} -", colour=config["colour"])

    print("\n")

In [371]:
# 100K sample tests
ELSER_100K = {
    "label": "ELSER",
    "index": "works-elser-title-description",
    "get_query_function": get_basic_query,
    "colour": "\033[34m",
    "semantic_fields": ["query.titleSemantic", "query.descriptionSemantic"]
}

TITAN_100K = {
    "label": "Titan",
    "index": "works-titan-title-description",
    "get_query_function": get_titan_knn_query,
    "colour": "\033[91m",
    "semantic_fields": ["query.titleSemantic", "query.descriptionSemantic"]
}

OPEN_AI_SEMANTIC_100K = {
    "label": "OpenAI",
    "index": "works-openai-title-description",
    "get_query_function": get_openai_knn_query,
    "colour": "\033[93m",
    "semantic_fields": ["query.titleSemantic", "query.descriptionSemantic"]
}

NON_SEMANTIC_100K = {
    "label": "Control",
    "index": "works-titan-title-description",
    "get_query_function": get_basic_query,
    "colour": "\033[95m",
    "semantic_fields": ["query.title", "query.description"]
}

TO_COMPARE = [ELSER_100K, TITAN_100K, OPEN_AI_SEMANTIC_100K, NON_SEMANTIC_100K]
SIZE = 1000
PRINT_LIMIT = 10

# https://www.elastic.co/search-labs/blog/semantic-search-match-knn-sparse-vector
# "token pruning is about pruning irrelevant tokens to improve pefind_needle_in_haystack("flower magazine", "c2jj7zfd")rformance, not drastically change recall or relevance"
# Vector search is meant to improve recall. Lexical search will help with precision.

QUERY = "ancient manuscript on astronomy"
QUERY = "czech republic capital"
# QUERY = "surgery knife"

# Testing for problematic connections
# All models seem to connect 'savages' with 'Africa' (2)
# QUERY = "photos of savages"
# QUERY = "backward cultures"

# OpenAI seems to connect 'photos of inferior race' with 'Jewish' (3)
QUERY = "photos of inferior race"

#find_needle_in_haystack(QUERY, "a24brmcv")
compare_query_results(QUERY)

# 1) Both models improve recall and ranking
# 2) ELSER matches too many documents (low precision). Might not be a big issue.
# 3) Titan tends to make problematic connections

[1mQuery:[0m photos of inferior race
works-elser-title-description 2.488131046295166
works-titan-title-description 1.4267640113830566
works-openai-title-description 1.506113052368164
works-titan-title-description 1.5875251293182373
[1mTotal results:[0m [34mELSER 61682[39m [91mTitan 556[39m [93mOpenAI 561[39m [95mControl 52195[39m 

[1m————— 1 —————[0m

https://wellcomecollection.org/works/j5v38qn9
[1mPapers by Dicks on [92mrace[39m issues[0m
<p>Comprises:
</p><li>PP/HVD/E/2/1: “Psychological factors on prejudice”, draft [92mof[39m a paper published in [92mRace[39m Relations, 1959.</li>
<li>PP/HVD/E/2/2: Outline [92mof[39m lecture given at the London Hospital, June 1963.</li>
<li>PP/HVD/E/2/3: “[92mRace[39m prejudice”, notes for a lecture given in Birmingham, c.1960s.</li>
<li>PP/HVD/E/2/4: “Psychology [92mof[39m [92mrace[39m prejudice”, c.1960s.</li>
<li>PP/HVD/E/2/5: “Thoughts on the relation between psycho-analysis and social science”, paper given at Sus

In [363]:
# 1M full dataset tests using hybrid search (reciprocal rank fusion)
ELSER_1M_TITLE_DESCRIPTION = {
    "label": "ELSER",
    "index": "works-elser-full",
    "get_query_function": get_rrf_query_with_min_should_match,
    "colour": "\033[34m",
    "semantic_fields": ["query.titleSemantic", "query.descriptionSemantic"],
    "semantic_min_score": 10
}

OPENAI_1M_FULL = {
    "label": "OpenAI",
    "index": "works-openai-full",
    "get_query_function": get_rrf_query_open_ai,
    "colour": "\033[36m",
    "semantic_fields": ["query.titleSemantic", "query.descriptionSemantic"],
    "semantic_min_score": 0
}

ELSER_1M_FULL = {
    "label": "ELSER (full)",
    "index": "works-elser-full",
    "get_query_function": get_rrf_query_with_min_should_match,
    "colour": "\033[36m",
    "semantic_fields": ["query.titleSemantic", "query.descriptionSemantic", "query.alternativeTitlesSemantic",
                        "query.contributorsSemantic", "query.genresSemantic", "query.subjectsSemantic",
                        "query.notesSemantic"],
    "semantic_min_score": 30
}

NON_SEMANTIC_1M = {
    "label": "Prod",
    "index": "works-elser-full",
    "get_query_function": get_production_query,
    "colour": "\033[95m"
}

TO_COMPARE = [OPENAI_1M_FULL, NON_SEMANTIC_1M]
SIZE = 10000
PRINT_LIMIT = 10

QUERY = "ancient manuscript on astronomy"
# QUERY = "cardiac failure"
# QUERY = "lung neoplasm"

# All models are associate 'surgery knife' with sharp surgical tools (scalpel, saw, scissors) (4, 9)
# QUERY = "surgery knife"

# All models make a connection between 'Czech Republic capital' and Prague (2)
# QUERY = "czech republic capital"

# Connection between 'child doctor' and 'pediatrician' (7, 10)
# QUERY = "child doctor"

# OpenAI's model is the only one connecting 'Florence Nightingale' to 'Lady with the Lamp' (7, 8, 9)
# Other models frequently return low-relevance results mentioning women and lamps (6)
# OpenAI's model also lowers the ranking of low-relevance results returned by prod query (6)
# QUERY = "Lady with the Lamp"

# Production query returns 0 results
# ELSER returns low-relevance results (4)
# OpenAI's model is the only one returning specific dishes (6, 10)
# QUERY = "nutritious plant-based dish"

#find_needle_in_haystack(QUERY, "a24brmcv")
compare_query_results(QUERY)

[1mQuery:[0m ancient manuscript on astronomy
works-openai-full 1.1467540264129639
works-elser-full 0.29523301124572754
[1mTotal results:[0m [36mOpenAI 550[39m [95mProd 11[39m 

[1m————— 1 —————[0m

https://wellcomecollection.org/works/yupgnk6b
[1m[92mOn[39m the acti[92mon[39m and influence of the mo[92mon[39m / Sergius of Resaina, [translated by] Joseph Zolin.[0m
[36mOpenAI 1[39m [95mProd 1[39m 

[1m————— 2 —————[0m

https://wellcomecollection.org/works/ce7r34bj
[1mCollecti[92mon[39m of texts relating to astrology by Albumasar, John of Seville and Ibn al-Saffar[0m
<p>Collecti[92mon[39m of texts relating to astrology by Albumasar, John of Seville and Ibn al-Saffar, in Latin, copied by the unidentified surge[92mon[39m Guidotus of Vicenza in Northern Italy; with decorated initials and rubricati[92mon[39m.</p>
<p><b>C[92mon[39mtents</b> (described according to the more recent modern foliati[92mon[39m; earlier foliati[92mon[39m recorded in brackets):<

In [343]:
# QUERY = "consumption"
# QUERY = "tuberculosis"

QUERY = "smart large black bird"
QUERY = "large bird"  # A large tree with a small bird flying towards it (prod), non-semantic struggles with adjectives in between
QUERY = "london gardens unesco"
QUERY = "着物"
QUERY = "cardiac failure"
QUERY = "lung neoplasm nerve inflammation"

In [241]:
# Searching for animals by describing prominent features works

# https://wellcomecollection.org/works/nfgzazqm/images?id=wzr92r6d
find_needle_in_haystack("mouse with long nose", "nfgzazqm")

# https://wellcomecollection.org/works/d8qqspwv/images?id=wjzph6wv
find_needle_in_haystack("fox species with large ears", "d8qqspwv")

# https://wellcomecollection.org/works/njacsf2g/items
find_needle_in_haystack("large flightless bird", "njacsf2g")

[34mELSER 143[39m [36mOpenAI 13[39m [95mProd -[39m 

[34mELSER -[39m [36mOpenAI 80[39m [95mProd -[39m 

[34mELSER -[39m [36mOpenAI 81[39m [95mProd -[39m 



In [253]:
# Synonyms/descriptions

find_needle_in_haystack("pig mum with babies", "xkm6ubyq")

# https://wellcomecollection.org/works/dpte8snu/items
find_needle_in_haystack("ant baby", "dpte8snu")

# https://wellcomecollection.org/works/ag3zz4dx/images?id=a2yxbhw5&resultPosition=16
find_needle_in_haystack("bear eating seal", "ag3zz4dx")

# https://wellcomecollection.org/works/rt7bk7dt/images?id=a57y2s4z&resultPosition=29
find_needle_in_haystack("dog eye", "rt7bk7dt")

[34mELSER 274[39m [36mOpenAI 82[39m [95mProd -[39m 

[34mELSER 172[39m [36mOpenAI 121[39m [95mProd -[39m 

[34mELSER 18[39m [36mOpenAI 61[39m [95mProd -[39m 

[34mELSER 26[39m [36mOpenAI 4[39m [95mProd -[39m 



In [208]:
# https://wellcomecollection.org/works/a227y9ye
find_needle_in_haystack("the blitz", "a227y9ye")

[34mELSER -[39m [36mOpenAI 10[39m [95mProd -[39m 



In [209]:
# https://wellcomecollection.org/works/a24brmcv
find_needle_in_haystack("how to make meth", "a24brmcv")

[34mELSER -[39m [36mOpenAI 2[39m [95mProd -[39m 



In [243]:
# https://wellcomecollection.org/works/jvbc3r5f/images?id=ab5ywfpy
find_needle_in_haystack("man riding a pig", "jvbc3r5f")

# https://wellcomecollection.org/works/jvbc3r5f/images?id=ab5ywfpy
find_needle_in_haystack("butcher riding a pig", "jvbc3r5f")

[34mELSER -[39m [36mOpenAI 46[39m [95mProd -[39m 

[34mELSER 6[39m [36mOpenAI 1[39m [95mProd -[39m 



In [237]:
# OpenAI's model is multilingual

# https://wellcomecollection.org/works/b5kqccbb
find_needle_in_haystack("animal anatomy treatise", "b5kqccbb")

find_needle_in_haystack("boucher sur un cochon", "jvbc3r5f")
find_needle_in_haystack("řezník na praseti", "jvbc3r5f")

# Traité des maladies du coeur
find_needle_in_haystack("treaty heart diseases", "a239wxjg")

[34mELSER -[39m [36mOpenAI 72[39m [95mProd -[39m 

[34mELSER -[39m [36mOpenAI 15[39m [95mProd -[39m 

[34mELSER -[39m [36mOpenAI 4[39m [95mProd -[39m 

[34mELSER -[39m [36mOpenAI 24[39m [95mProd -[39m 



In [301]:
find_needle_in_haystack("tanuki", "pwjrcz4t")

[34mELSER -[39m [36mOpenAI -[39m [95mProd -[39m 



In [254]:



# Testing for problematic connections

QUERY = "photos of savages"
# QUERY = "backward cultures"
# QUERY = "photos of inferior race"

#find_needle_in_haystack(QUERY, "a24brmcv")
compare_query_results(QUERY)

[1mQuery:[0m photos of savages
[1mTotal results:[0m [34mELSER 198[39m [36mOpenAI 606[39m [95mProd 8[39m 

[1m————— 1 —————[0m

https://wellcomecollection.org/works/w5phw38q
[1mCorrespondence with organisations supported by Savage[0m
Correspondence with organisations which Savage supported. Much [92mof[39m the correspondence is to express support and apologise for being oversubscribed and therefore unable to become more involved in the organisation.
[34mELSER 1[39m [36mOpenAI -[39m [95mProd -[39m 

https://wellcomecollection.org/works/sx5g9jjh
[1mPhotographs[0m
A small number [92mof[39m photographs [92mof[39m CJS 6 clinical photographs, possibly taken in CJS in Abyssinia in 1904. Also includes a photographic plate mounted to wood [92mof[39m a portrait [92mof[39m an unidentified man.
[34mELSER -[39m [36mOpenAI 1[39m [95mProd -[39m 

https://wellcomecollection.org/works/tketcabe
[1mCardiovascular physiology in the sixteenth and early seventeenth cent

In [302]:
%pip install detoxify


Collecting detoxify
  Downloading detoxify-0.5.2-py3-none-any.whl.metadata (13 kB)
Collecting transformers (from detoxify)
  Downloading transformers-5.1.0-py3-none-any.whl.metadata (31 kB)
Collecting torch>=1.7.0 (from detoxify)
  Downloading torch-2.10.0-1-cp312-none-macosx_11_0_arm64.whl.metadata (31 kB)
Collecting sentencepiece>=0.1.94 (from detoxify)
  Downloading sentencepiece-0.2.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (10 kB)
Collecting filelock (from torch>=1.7.0->detoxify)
  Downloading filelock-3.20.3-py3-none-any.whl.metadata (2.1 kB)
Collecting setuptools (from torch>=1.7.0->detoxify)
  Downloading setuptools-82.0.0-py3-none-any.whl.metadata (6.6 kB)
Collecting sympy>=1.13.3 (from torch>=1.7.0->detoxify)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch>=1.7.0->detoxify)
  Downloading networkx-3.6.1-py3-none-any.whl.metadata (6.8 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch>=1.7.0->detoxify)
  Downloadi

In [303]:
from detoxify import Detoxify


  from .autonotebook import tqdm as notebook_tqdm


In [306]:
results = Detoxify('multilingual').predict(['example text','exemple de texte','texto de ejemplo','testo di esempio','texto de exemplo','örnek metin','пример текста'])
results

Loading weights: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 1276.72it/s, Materializing param=roberta.encoder.layer.11.output.dense.weight]
[1mXLMRobertaForSequenceClassification LOAD REPORT[0m from: None
Key                             | Status     |  | 
--------------------------------+------------+--+-
roberta.embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


{'toxicity': [0.00019621763203758746,
  0.0005629529478028417,
  0.0006275926134549081,
  0.000991594628430903,
  0.0006123465136624873,
  0.0005906997248530388,
  0.0005345010431483388],
 'severe_toxicity': [0.0001925499818753451,
  0.004809971898794174,
  0.0031421624589711428,
  0.004266710486263037,
  0.0015480158617720008,
  0.002832299331203103,
  0.003825804451480508],
 'obscene': [0.001262642559595406,
  0.030098777264356613,
  0.022164227440953255,
  0.03333602473139763,
  0.012159483507275581,
  0.021325847133994102,
  0.026461614295840263],
 'identity_attack': [0.0003226218977943063,
  0.005532296374440193,
  0.003466877853497863,
  0.0054079694673419,
  0.0018412952776998281,
  0.0031340194400399923,
  0.004029527772217989],
 'insult': [0.0008828384452499449,
  0.026441361755132675,
  0.017626892775297165,
  0.028724247589707375,
  0.009404674172401428,
  0.017864996567368507,
  0.02190164290368557],
 'threat': [0.00013756829139310867,
  0.0021339100785553455,
  0.001496575

In [310]:
results = Detoxify('multilingual').predict('photos of inferior race')
print(results)

Loading weights: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 1254.60it/s, Materializing param=roberta.encoder.layer.11.output.dense.weight]
[1mXLMRobertaForSequenceClassification LOAD REPORT[0m from: None
Key                             | Status     |  | 
--------------------------------+------------+--+-
roberta.embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


{'toxicity': np.float32(0.050552037), 'severe_toxicity': np.float32(0.00045800567), 'obscene': np.float32(0.004736124), 'identity_attack': np.float32(0.007991177), 'insult': np.float32(0.011319938), 'threat': np.float32(0.00044642107), 'sexual_explicit': np.float32(0.0008477475)}


In [320]:
d = Detoxify('unbiased')



Downloading: "https://github.com/unitaryai/detoxify/releases/download/v0.3-alpha/toxic_debiased-c7548aa0.ckpt" to /Users/brychtas/.cache/torch/hub/checkpoints/toxic_debiased-c7548aa0.ckpt


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 476M/476M [01:12<00:00, 6.85MB/s]
Loading weights: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 1334.73it/s, Materializing param=roberta.encoder.layer.11.output.dense.weight]
[1mRobertaForSequenceClassification LOAD REPORT[0m from: None
Key                             | Status     |  | 
--------------------------------+------------+--+-
roberta.embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


In [315]:
dm = Detoxify('multilingual')

Loading weights: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 1312.38it/s, Materializing param=roberta.encoder.layer.11.output.dense.weight]
[1mXLMRobertaForSequenceClassification LOAD REPORT[0m from: None
Key                             | Status     |  | 
--------------------------------+------------+--+-
roberta.embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m


In [321]:
d.predict('photos of savages')

{'toxicity': np.float32(0.027025623),
 'severe_toxicity': np.float32(6.9970283e-06),
 'obscene': np.float32(0.0004700063),
 'identity_attack': np.float32(0.0009840501),
 'insult': np.float32(0.010121866),
 'threat': np.float32(0.0002726626),
 'sexual_explicit': np.float32(0.00019775685)}