In [None]:
%pip install -qU langchain-text-splitters transformers pandas numpy

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [None]:
import numpy as np
import pandas as pd

In [40]:
from transformers import AutoTokenizer
from langchain_text_splitters import CharacterTextSplitter

# Load a tokenizer for a BERT-like model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def count_tokens(text):
    return len(tokenizer.encode(text))

text_splitter = CharacterTextSplitter(
    separator="\n",
    chunk_size=400,
    chunk_overlap=50,
    # length_function=len,
    length_function=count_tokens,
    # is_separator_regex=False,
)

In [41]:
def chunk(text):
    texts = text_splitter.create_documents([text])
    chunks = []
    for i,t in enumerate(texts):
        # replace newlines with spaces this can help keep word boundires
        chunks.append(t.page_content.replace("\n", " "))
    return chunks

In [None]:
df = pd.read_csv('./datasets/text_classes.csv', delimiter='|')
# 4353
# display(df)

In [None]:
from tables import classes, schema_dicts
print(classes)
print(schema_dicts)

['Intellectual Developmental Disorders', 'Communication Disorders', 'Autism Spectrum Disorder', 'Attention-Deficit/Hyperactivity Disorder', 'Specific Learning Disorder', 'Motor Disorders', 'Other Neurodevelo mental Disorders p', 'Schizophrenia Spectrum and Other Psychotic Disorders', 'Catatonia', 'Bipolar and Related Disorders', 'Depressive Disorders', 'Anxiety Disorders', 'Obsessive-Compulsive and Related Disorders', 'Trauma- and Stressor-Related Disorders', 'Dissociative Disorders', 'Somatic Symptom and Related Disorders', 'Feeding and Eating Disorders', 'Elimination Disorders', 'Sleep-Wake Disorders', 'Breathing-Related Sleep Disorders', 'Parasomnias', 'Sexual Dysfunctions', 'Gender Dysphoria', 'Disruptive, Impulse-Control, and Conduct Disorders', 'Neurocognitive Disorders', 'Personality Disorders', 'Cluster A Personality Disorders', 'Cluster B Personality Disorders', 'Cluster C Personalit Disorders y', 'Other Personality Disorders', 'Paraphilic Disorders', 'Other Mental Disorders a

In [None]:
schema_index = 3
max_avg_distance = 8

In [60]:
import weaviate
import weaviate.classes as wvc

def query_top_five(text, limit=1, max_distance=8):
    client = weaviate.connect_to_local()
    res = None
    try:
        collection = client.collections.use(schema_dicts[schema_index]["schema"])

        results = collection.query.near_text(
            # Because near_text is a purely vector search, we get a
            # distance and no score
            return_metadata=wvc.query.MetadataQuery(
                score=True, explain_score=True, distance=True, certainty=True
            ),
            query=text,
            limit=limit
        )
        for r in results.objects:
            if r.metadata.distance < max_distance:
                res = [classes.index(r.properties['title']), r.metadata.distance]

        return res
    finally:
        client.close()

In [67]:
item = 1

chart_chunks = chunk(df.iloc[item]["chart_text"])
print(f"Chunks size: {len(chart_chunks)}")

scores = []
for c in chart_chunks:
    d = query_top_five(c)
    if d is not None:
        scores.append(d)

print(f"Scores size: {len(scores)}")

Chunks size: 53
Scores size: 14


In [68]:
def score_document(s, max=8):
    cat = {}
    top_results = []

    for i, _ in enumerate(classes):
        cat[i] = {"count":0, "distance":0}

    for _, d in enumerate(s):
        cat[d[0]]["count"] += 1
        cat[d[0]]["distance"] += d[1]
        cat[d[0]]["avg"] = cat[d[0]]["distance"] / cat[d[0]]["count"]

    for i, _ in enumerate(classes):
        if cat[i]["count"] > 0 and cat[i]["avg"] < max:
            top_results.append([cat[i]["avg"], i])

    return top_results

top = np.array(score_document(scores, max_avg_distance))
print(np.sort(top, axis=0))
print('--------------')
print(f'{df.iloc[item]["class_number"]:02} {classes[df.iloc[item]["class_number"]]}')

[[ 6.33185816 10.        ]
 [ 6.40727139 16.        ]
 [ 6.76102495 23.        ]
 [ 7.18070253 33.        ]]
--------------
10 Depressive Disorders


In [None]:
test_range = 20

results = []
for i in range(test_range): # 4353//4):
    chart_chunks = chunk(df.iloc[i]["chart_text"])
    print(f"Chunks size: {len(chart_chunks)}")

    scores = []
    for c in chart_chunks:
        d = query_top_five(c)
        if d is not None:
            scores.append(d)

    top = np.array(score_document(scores, max_avg_distance))
    top = np.sort(top, axis=0)

    results.append([i, df.iloc[i]["class_number"], top])

Chunks size: 8
Chunks size: 53
Chunks size: 29
Chunks size: 19
Chunks size: 701
Chunks size: 91
Chunks size: 43
Chunks size: 48
Chunks size: 218
Chunks size: 117
Chunks size: 2000
Chunks size: 130
Chunks size: 62
Chunks size: 297
Chunks size: 166
Chunks size: 81
Chunks size: 393
Chunks size: 522
Chunks size: 132
Chunks size: 1038


In [None]:
df_results = pd.DataFrame(results, columns=["record", "original", "predict"])

def top5(x):
    return np.array([y[1] for i,y in enumerate(x) if i < 5])

def in_top_5(orig, top5):
    return 1 if orig in top5 else 0

df_results['top5'] = df_results['predict'].map(top5)
df_results['inTop5'] = df_results.apply(lambda x: in_top_5(orig = x['original'], top5 = x['top5']), axis=1)

print(df_results['inTop5'].sum() / test_range)

display(df_results)

0.55


Unnamed: 0,record,original,predict,top5,inTop5
0,0,11,"[[4.398091475168864, 23.0]]",[23.0],0
1,1,10,"[[6.331858158111572, 10.0], [6.407271385192871...","[10.0, 16.0, 23.0, 33.0]",1
2,2,11,"[[6.395759582519531, 10.0], [7.718908548355102...","[10.0, 17.0]",0
3,3,10,"[[6.660593032836914, 20.0], [7.800957679748535...","[20.0, 23.0]",0
4,4,10,"[[6.194525814056396, 8.0], [6.391502857208252,...","[8.0, 9.0, 10.0, 16.0, 23.0]",1
5,5,10,"[[5.214808327811105, 10.0], [7.492531061172485...","[10.0, 16.0, 23.0]",1
6,6,10,"[[5.892345142364502, 9.0], [7.3724493980407715...","[9.0, 10.0, 16.0, 23.0, 28.0]",1
7,7,10,"[[6.818415641784668, 7.0], [6.915680090586345,...","[7.0, 8.0, 10.0, 16.0, 18.0]",1
8,8,10,"[[4.48370627562205, 7.0], [4.703305721282959, ...","[7.0, 9.0, 10.0, 16.0, 17.0]",1
9,9,11,"[[6.082551956176758, 8.0], [6.339075326919556,...","[8.0, 9.0, 10.0, 16.0, 17.0]",0
