In [1]:
%pip install -qU langchain-text-splitters transformers pandas

Note: you may need to restart the kernel to use updated packages.


In [2]:
from transformers import AutoTokenizer
from langchain_text_splitters import CharacterTextSplitter

# Load a tokenizer for a BERT-like model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def count_tokens(text):
    return len(tokenizer.encode(text))

text_splitter = CharacterTextSplitter(
    separator="\n",
    chunk_size=400,
    chunk_overlap=50,
    # length_function=len,
    length_function=count_tokens,
    # is_separator_regex=False,
)

In [3]:
def chunk(text):
    texts = text_splitter.create_documents([text])
    chunks = []
    for i,t in enumerate(texts):
        # replace newlines with spaces this can help keep word boundires
        chunks.append(t.page_content.replace("\n", " "))
    return chunks

In [18]:
import pandas as pd
df = pd.read_csv('./datasets/text_classes.csv', delimiter='|')
# 4353
# display(df)

In [6]:
classes = [
    "Intellectual Developmental Disorders",                        # 0
    "Communication Disorders",
    "Autism Spectrum Disorder",
    "Attention-Deficit/Hyperactivity Disorder",
    "Specific Learning Disorder",
    "Motor Disorders",
    "Other Neurodevelo mental Disorders p",
    "Schizophrenia Spectrum and Other Psychotic Disorders",
    "Catatonia",
    "Bipolar and Related Disorders",
    "Depressive Disorders",                                        # 10
    "Anxiety Disorders",
    "Obsessive-Compulsive and Related Disorders",
    "Trauma- and Stressor-Related Disorders",
    "Dissociative Disorders",
    "Somatic Symptom and Related Disorders",
    "Feeding and Eating Disorders",
    "Elimination Disorders",
    "Sleep-Wake Disorders",
    "Breathing-Related Sleep Disorders",
    "Parasomnias",                                                # 20
    "Sexual Dysfunctions",
    "Gender Dysphoria",
    "Disruptive, Impulse-Control, and Conduct Disorders",
    "Neurocognitive Disorders",
    "Personality Disorders",
    "Cluster A Personality Disorders",
    "Cluster B Personality Disorders",
    "Cluster C Personalit Disorders y",
    "Other Personality Disorders",
    "Paraphilic Disorders",                                       # 30
    "Other Mental Disorders and Additional Codes",
    "Additional Codes",
    "Medication-Induced Movement Disorders and Other Adverse Effects of Medication",
    "Other Conditions That May Be a Focus of Clinical Attention", # 34
]

In [7]:
schema_dicts = [
    { "name": "Distilbert", "url": "http://t2v-transformers:8080", "schema": "DSMDistilbert" },
    { "name": "Roberta", "url": "http://t2v-transformers-drobert:8080", "schema": "DSMRoberta" },
    { "name": "Biobert", "url": "http://t2v-transformers-biobert:8080", "schema": "DSMBiobert" },
    { "name": "Clinicalbert", "url": "http://t2v-transformers-clicbert:8080", "schema": "DSMClinicalbert" },
]
schema_index = 0

In [8]:
import weaviate
import weaviate.classes as wvc

def query_top_five(text, limit=1):
    client = weaviate.connect_to_local()
    res = None
    try:
        collection = client.collections.use(schema_dicts[schema_index]["schema"])

        results = collection.query.near_text(
            # Because near_text is a purely vector search, we get a
            # distance and no score
            return_metadata=wvc.query.MetadataQuery(
                score=True, explain_score=True, distance=True, certainty=True
            ),
            query=text,
            limit=limit
        )
        for r in results.objects:
            if r.metadata.distance < 8:
                res = [classes.index(r.properties['title']), r.metadata.distance]

        return res
    finally:
        client.close()

In [16]:
item = 4

chart_chunks = chunk(df.iloc[item]["chart_text"])
print(f"Chunks size: {len(chart_chunks)}")

scores = []
for c in chart_chunks:
    d = query_top_five(c)
    if d is not None:
        scores.append(d)

print(f"Scores size: {len(scores)}")

Chunks size: 701
Scores size: 552


In [17]:
def score_document(s):
    cat = {}

    for i, cls in enumerate(classes):
        cat[i] = {"count":0, "distance":0}

    for c, d in enumerate(s):
        cat[d[0]]["count"] += 1
        cat[d[0]]["distance"] += d[1]
        cat[d[0]]["avg"] = cat[d[0]]["distance"] / cat[d[0]]["count"]

    for i, cls in enumerate(classes):
        if cat[i]["count"] > 0 and cat[i]["avg"] < 8:
            print(f'{i:02} {cat[i]["avg"]} {classes[i]}')

score_document(scores)
print('--------------')
print(f'{df.iloc[item]["class_number"]:02} {classes[df.iloc[item]["class_number"]]}')

09 6.637991837092808 Bipolar and Related Disorders
10 5.920181900262833 Depressive Disorders
11 7.039983749389648 Anxiety Disorders
16 7.254909873008728 Feeding and Eating Disorders
17 6.788992881774902 Elimination Disorders
19 6.615343553048593 Breathing-Related Sleep Disorders
21 6.951370716094971 Sexual Dysfunctions
23 6.156726730042609 Disruptive, Impulse-Control, and Conduct Disorders
24 6.5375165505842725 Neurocognitive Disorders
34 5.394190788269043 Other Conditions That May Be a Focus of Clinical Attention
--------------
10 Depressive Disorders
