In [88]:
%pip install langchain-text-splitters transformers pandas numpy tqdm

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [89]:
import numpy as np
import pandas as pd

In [90]:
from transformers import AutoTokenizer
from langchain_text_splitters import CharacterTextSplitter

# Load a tokenizer for a BERT-like model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def count_tokens(text):
    return len(tokenizer.encode(text))

text_splitter = CharacterTextSplitter(
    separator="\n",
    chunk_size=400,
    chunk_overlap=50,
    # length_function=len,
    length_function=count_tokens,
    # is_separator_regex=False,
)

In [91]:
def chunk(text):
    texts = text_splitter.create_documents([text])
    chunks = []
    for i,t in enumerate(texts):
        # replace newlines with spaces this can help keep word boundires
        chunks.append(t.page_content.replace("\n", " "))
    return chunks

In [92]:
df_p = pd.read_csv('./datasets/positive_w_text.csv', delimiter='|')
df_n = pd.read_csv('./datasets/negitive_w_text.csv', delimiter='|')

print(f'Positive Count: {len(df_p)}')
print(f'Negitive Count: {len(df_n)}')

df_p = df_p.drop(columns=["SUBJECT_ID"])
df_n = df_n.drop(columns=["SUBJECT_ID"])

df = pd.concat([df_p, df_n], ignore_index=True)
# df.reset_index(inplace=True)

# This shuffles the dataframe, but in a non-reproducable way
# df = df.sample(frac=1).reset_index(drop=True)
# Instead shuffle dataframe, but in a reproducable way
SEED=42
np.random.seed(SEED)
idx = df.index.to_list()

np.random.shuffle(idx)
df = df.loc[idx].reset_index(drop=True)

# df = df.sample(frac=1).reset_index(drop=True)
display(df)

Positive Count: 800
Negitive Count: 1000


Unnamed: 0,class_number,chart_text
0,0,"Valve replacement, aortic bioprosthetic (AVR)\..."
1,0,"66 year-old male with pmh of dm, copd, obesity..."
2,0,"51 yo male with emphysema, OSA, obesity admitt..."
3,10,"male w/ PMHx sig for HTN, COPD, dementia, afib..."
4,0,Subdural hemorrhage (SDH)\n Assessment:\n ...
...,...,...
1795,0,"Pt was helmeted motorcycle driver, drove into ..."
1796,0,"Acute Pain\n Assessment:\n Pt A&Ox3 MAE, p..."
1797,0,Mr. [**Known lastname 4345**] is a 35 yo male ...
1798,0,"Shock, septic\n Assessment:\n Post fluid b..."


In [93]:
from tables import classes, schema_dicts
print(classes)
print(schema_dicts)

['Intellectual Developmental Disorders', 'Communication Disorders', 'Autism Spectrum Disorder', 'Attention-Deficit/Hyperactivity Disorder', 'Specific Learning Disorder', 'Motor Disorders', 'Other Neurodevelo mental Disorders p', 'Schizophrenia Spectrum and Other Psychotic Disorders', 'Catatonia', 'Bipolar and Related Disorders', 'Depressive Disorders', 'Anxiety Disorders', 'Obsessive-Compulsive and Related Disorders', 'Trauma- and Stressor-Related Disorders', 'Dissociative Disorders', 'Somatic Symptom and Related Disorders', 'Feeding and Eating Disorders', 'Elimination Disorders', 'Sleep-Wake Disorders', 'Breathing-Related Sleep Disorders', 'Parasomnias', 'Sexual Dysfunctions', 'Gender Dysphoria', 'Disruptive, Impulse-Control, and Conduct Disorders', 'Neurocognitive Disorders', 'Personality Disorders', 'Cluster A Personality Disorders', 'Cluster B Personality Disorders', 'Cluster C Personalit Disorders y', 'Other Personality Disorders', 'Paraphilic Disorders', 'Other Mental Disorders a

**Schema Index**
| ID | Model | Port |
|----|-------|------|
| 0 | Distilbert | 9090 |
| 1 | Roberta | 9091 |
| 2 | Biobert | 9092 |
| 3 | Clinicalbert | 9093 |

In [94]:
# Which BERT model to use
schema_index = 3
# When using vector distance, this is the farthest
# a vector can be to be considered
max_avg_distance = 20
# Total number of documents to process
# For example, we grab 2000 positive, 2000 negitive 
# shuffle them, and then process this many total
test_range = len(df) # 2000 

print(schema_index, max_avg_distance, test_range)

3 20 1800


In [95]:
import weaviate
import weaviate.classes as wvc

def query_top_x(text, limit=1, max_distance=20):
    client = weaviate.connect_to_local()
    res = None
    try:
        collection = client.collections.use(schema_dicts[schema_index]["schema"])

        results = collection.query.near_text(
            # Because near_text is a purely vector search, we get a
            # distance and no score
            return_metadata=wvc.query.MetadataQuery(
                score=True, explain_score=True, distance=True, certainty=True
            ),
            query=text,
            limit=limit
        )
        for r in results.objects:
            if r.metadata.distance < max_distance:
                res = [classes.index(r.properties['title']), r.metadata.distance]

        return res
    finally:
        client.close()

In [96]:
def score_document(s, max_dist=10):
    cat = {}
    top_results = []

    for i, _ in enumerate(classes):
        cat[i] = {"count":0, "distance":0}

    for _, d in enumerate(s):
        cat[d[0]]["count"] += 1
        cat[d[0]]["distance"] += d[1]
        cat[d[0]]["avg"] = cat[d[0]]["distance"] / cat[d[0]]["count"]

    for i, _ in enumerate(classes):
        if cat[i]["count"] > 0 and cat[i]["distance"] < max_dist:
            top_results.append([cat[i]["avg"], i])

    return top_results

In [97]:
from tqdm import tqdm
results = []
for i in tqdm(range(test_range)):
    text = str(df.iloc[i]["chart_text"])

    # print(type(text), len(text))

    if type(text) == str:
        chart_chunks = chunk(text[:500000])
        # print(f"Chart: {len(text)} \t Chunks: {len(chart_chunks)}")
        scores = []
        for c in chart_chunks:
            d = query_top_x(c, 5)
            if d is not None:
                scores.append(d)

        top = np.array(score_document(scores, max_avg_distance))
        top = np.sort(top, axis=0)

        results.append([i, df.iloc[i]["class_number"], top])

100%|██████████| 1800/1800 [26:23:17<00:00, 52.78s/it]    


## Example Scoring

Correct answer in the top 5 - just to get a view of the results. See 5_results.ipynb for the full results process

In [98]:
df_results = pd.DataFrame(results, columns=["record", "original", "predict"])

def top5(x):
    return np.array([y[1] for i,y in enumerate(x) if i < 5])

def in_top_5(orig, top5):
    return 1 if orig in top5 else 0

def correct_predict(orig, top5):
    if orig == 0 and 11 not in top5 and 10 not in top5:
        return 1
    return 1 if orig in top5 else 0

df_results['top5'] = df_results['predict'].map(top5)
df_results['inTop5'] = df_results.apply(lambda x: in_top_5(orig=x['original'], top5=x['top5']), axis=1)
df_results['correct'] = df_results.apply(lambda x: correct_predict(orig=x['original'], top5=x['top5']), axis=1)

print(f"Correct top 5    {df_results['inTop5'].sum() / test_range}")
print(f"Correct category {df_results['correct'].sum() / test_range}")
display(df_results)

# Sample size 30
# i     %   MMR
# 0   .70  .20
# 1   .60  .16
# 2   .53  .14
# 3   .66  .18


Correct top 5    0.28833333333333333
Correct category 0.3661111111111111


Unnamed: 0,record,original,predict,top5,inTop5,correct
0,0,0,"[[0.09662508964538574, 7.0], [0.10977473855018...","[7.0, 8.0, 9.0, 10.0, 11.0]",0,0
1,1,0,"[[0.1388481685093471, 7.0], [0.139701657825046...","[7.0, 8.0, 9.0, 10.0, 13.0]",0,0
2,2,0,"[[0.12216955423355103, 3.0], [0.13968069496608...","[3.0, 7.0, 8.0, 9.0, 10.0]",0,0
3,3,10,"[[0.14886939525604248, 16.0], [0.1796432733535...","[16.0, 18.0, 19.0, 20.0, 23.0]",0,0
4,4,0,"[[0.11608517169952393, 7.0], [0.12344068288803...","[7.0, 8.0, 9.0, 10.0, 16.0]",0,0
...,...,...,...,...,...,...
1795,1795,0,"[[0.1316316823164622, 1.0], [0.140423774719238...","[1.0, 3.0, 5.0, 7.0, 8.0]",0,1
1796,1796,0,"[[0.1611930951476097, 7.0], [0.161896912947944...","[7.0, 9.0, 10.0, 11.0, 12.0]",0,0
1797,1797,0,"[[0.12209852933883666, 8.0], [0.12868483662605...","[8.0, 9.0, 10.0, 16.0, 18.0]",0,0
1798,1798,0,"[[0.12893107533454895, 8.0], [0.13033321729073...","[8.0, 9.0, 10.0, 11.0, 12.0]",0,0


### Mean Reciprocal Rank (MRR)

The position of the first correct result is important

For each query, find the rank of the first correct result. The Reciprocal Rank (RR) is $1 / rank$. MRR is the average RR across all queries.

Example: Suppose you have 3 queries:
- Query 1: Correct result is ranked 1st → RR = 1/1 = 1.0
- Query 2: Correct result is ranked 3rd → RR = 1/3 ≈ 0.33
- Query 3: Correct result is ranked 2nd → RR = 1/2 = 0.5

$$
MRR = (1.0 + 0.33 + 0.5) / 3 ≈ 0.61
$$

$$
MRR = \frac{1}{∣Q∣} \sum_{i=1}^{∣Q∣}\frac{​1}{rank_i}
$$

A good Mean Reciprocal Rank (MRR) value typically ranges from 0.6 to 1.0

In [99]:
def mmr(orig, top5):
    idx = np.where(top5 == orig)[0]
    return (1/(idx+1) if idx.size != 0 else [0])[0]

df_results['mmr'] = df_results.apply(lambda x: mmr(orig=x['original'], top5=x['top5']), axis=1)

print(f"MMR: {df_results['mmr'].sum() / len(df)}")
display(df_results)

MMR: 0.1168148148148148


Unnamed: 0,record,original,predict,top5,inTop5,correct,mmr
0,0,0,"[[0.09662508964538574, 7.0], [0.10977473855018...","[7.0, 8.0, 9.0, 10.0, 11.0]",0,0,0.0
1,1,0,"[[0.1388481685093471, 7.0], [0.139701657825046...","[7.0, 8.0, 9.0, 10.0, 13.0]",0,0,0.0
2,2,0,"[[0.12216955423355103, 3.0], [0.13968069496608...","[3.0, 7.0, 8.0, 9.0, 10.0]",0,0,0.0
3,3,10,"[[0.14886939525604248, 16.0], [0.1796432733535...","[16.0, 18.0, 19.0, 20.0, 23.0]",0,0,0.0
4,4,0,"[[0.11608517169952393, 7.0], [0.12344068288803...","[7.0, 8.0, 9.0, 10.0, 16.0]",0,0,0.0
...,...,...,...,...,...,...,...
1795,1795,0,"[[0.1316316823164622, 1.0], [0.140423774719238...","[1.0, 3.0, 5.0, 7.0, 8.0]",0,1,0.0
1796,1796,0,"[[0.1611930951476097, 7.0], [0.161896912947944...","[7.0, 9.0, 10.0, 11.0, 12.0]",0,0,0.0
1797,1797,0,"[[0.12209852933883666, 8.0], [0.12868483662605...","[8.0, 9.0, 10.0, 16.0, 18.0]",0,0,0.0
1798,1798,0,"[[0.12893107533454895, 8.0], [0.13033321729073...","[8.0, 9.0, 10.0, 11.0, 12.0]",0,0,0.0


In [100]:
df_results.to_csv(f"datasets/{schema_index}_{max_avg_distance}_{test_range}.csv")