In [1]:
import pandas as pd
import numpy as np
import os
import json

## 1 - Load data

### 1.1 - Load the preprocessed data

In [2]:
processed_file_path = os.path.join('..', 'data', 'processed')
clustering_model_name = 'KMeans'
processed_file_name = f'train_{clustering_model_name}.csv'

train_df = pd.read_csv(os.path.join(processed_file_path, processed_file_name), low_memory=False)
train_df.head()

Unnamed: 0,file,VMONTH,VYEAR,VDAYR,YEAR,AGE,SEX,ETHNIC,RACE,USETOBAC,...,OTHPROV,MHP,NODISP,REFOTHMD,RETAPPT,OTHDISP,ERADMHOS,cluster,CombinedText,ProcessedText
0,opd2006.csv,December,2006.0,Friday,2006.0,55.0,Male,Not Hispanic or Latino,White Only,Not current,...,No,,One or more dispositions marked,No,No,No,No,9,55_year_old Middle_Aged Male Acute problem Inj...,55_year_old middle_age male acute problem inju...
1,opd2006.csv,November,2006.0,Thursday,2006.0,66.0,Male,Not Hispanic or Latino,White Only,Not current,...,No,,One or more dispositions marked,No,No,No,No,0,66_year_old Senior Male Acute problem Cough Ov...,66_year_old senior male acute problem cough ov...
2,opd2006.csv,November,2006.0,Wednesday,2006.0,71.0,Female,Not Hispanic or Latino,White Only,Not current,...,No,,One or more dispositions marked,Yes,No,No,No,7,71_year_old Senior Female Acute problem Genera...,71_year_old senior female acute problem genera...
3,opd2006.csv,November,2006.0,Tuesday,2006.0,1.0,Female,Not Hispanic or Latino,White Only,Not current,...,No,,One or more dispositions marked,No,No,No,No,2,1_year_old Infant Female Acute problem Other a...,1_year_old infant female acute problem unspeci...
4,opd2006.csv,November,2006.0,Monday,2006.0,21.0,Female,Not Hispanic or Latino,White Only,Current,...,No,,One or more dispositions marked,No,No,No,No,6,21_year_old Adult Female TOBACCO user Acute pr...,21_year_old adult female tobacco user acute pr...


### 1.2 - Load the vairiables dictionary

In [3]:
# Load the variables dictionary
variables_file_path = os.path.join('..', 'data', 'cleaned')

with open(os.path.join(variables_file_path, 'variables.json'), 'r') as f:
    variables = json.load(f)

print(f'Variable Categories:\n')
for category, list in variables.items():
    print(f'{category}')
    print(f'{list}')

Variable Categories:

dateOfVisit
['VMONTH', 'VYEAR', 'VDAYR', 'YEAR']
demographics
['AGE', 'SEX', 'ETHNIC', 'RACE', 'USETOBAC']
payment
['PAYPRIV', 'PAYMCARE', 'PAYMCAID', 'PAYWKCMP', 'PAYSELF', 'PAYNOCHG', 'PAYOTH', 'PAYDK', 'PAYTYPER']
visitReason
['INJDET', 'INJURY', 'MAJOR', 'RFV1', 'RFV2', 'RFV3']
patientClinicHistory
['SENBEFOR', 'PASTVIS']
vitalSigns
['HTIN', 'WTLB', 'BMI', 'TEMPF', 'BPSYS', 'BPDIAS']
imputedFields
['BDATEFL', 'SEXFL', 'SENBEFL', 'PASTFL']
physicianDiagnoses
['DIAG1', 'DIAG2', 'DIAG3']
differentialDiagnoses
['PRDIAG1', 'PRDIAG2', 'PRDIAG3']
presentSymptomsStatus
['ARTHRTIS', 'ASTHMA', 'CANCER', 'CASTAGE', 'CEBVD', 'CHF', 'CRF', 'COPD', 'DEPRN', 'DIABETES', 'HYPLIPID', 'HTN', 'IHD', 'OBESITY', 'OSTPRSIS', 'NOCHRON', 'TOTCHRON', 'DMP']
services
['BREAST', 'PELVIC', 'RECTAL', 'SKIN', 'DEPRESS', 'BONEDENS', 'MAMMO', 'MRI', 'ULTRASND', 'XRAY', 'OTHIMAGE', 'CBC', 'ELECTROL', 'GLUCOSE', 'HGBA', 'CHOLEST', 'PSA', 'OTHERBLD', 'BIOPSY', 'CHLAMYD', 'PAPCONV', 'PAPLIQ', 'P

## 2 - Preprocess user input

In [4]:
# Load custom function to combine text features
import sys
sys.path.append('../src/features/')

from combine_textual import combine_features


# Define the list of features to combine
feature_list = [
    'AGE', 'SEX', 'USETOBAC', 
    'MAJOR', 'RFV1', 'RFV2', 'RFV3', 
    'BMI', 'TEMPF', 'BPSYS', 'BPDIAS',
    'ARTHRTIS', 'ASTHMA', 'CANCER', 'CEBVD', 'CHF', 'CRF', 'COPD', 'DEPRN', 'DIABETES', 'HYPLIPID', 'HTN', 'IHD', 'OBESITY', 'OSTPRSIS', 'NOCHRON', 'DMP',
    'DIAG1', 'DIAG2', 'DIAG3'
]

In [5]:
user_input = combine_features(train_df.iloc[0].copy(), feature_list[:-3])

## 3 - Generate vectors and compute similarities/relevances

In [6]:
# Load model directly
from transformers import AutoTokenizer, AutoModel
import torch

if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device= 'mps'
else:
    device = 'cpu'
print(f'Using device: {device}')

model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)

Using device: mps


Some weights of the model checkpoint at microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
def get_embedding(text):
    encoded_input = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)

    with torch.no_grad():
        output = model(**encoded_input)
        #sentence_embedding = output.last_hidden_state[:, 0, :]
        sentence_embedding = output.pooler_output
    return sentence_embedding


def compare_to_groups(user_input, records_by_group, batch_size=32):
    # preprocess user input
    #user_input = preprocess_text(user_input)

    user_embedding = get_embedding(user_input)

    group_relevances = {}
    for group_id, records in records_by_group.items():
        batch_embeddings = []
        for i in range(0, len(records), batch_size):
            batch_records = records[i:i+batch_size]
            batch_embeddings.append(get_embedding([text for text in batch_records]))

        group_embeddings = torch.cat(batch_embeddings, dim=0)

        group_similarities = torch.nn.functional.cosine_similarity(user_embedding, group_embeddings, dim=1)

        # Print the max and min similarity scores within the group
        print(f'Group {group_id}:')
        print(f'Max Similarity: {group_similarities.max().item()}')
        print(f'Min Similarity: {group_similarities.min().item()}')
        print()
    
        group_relevance = group_similarities.mean(dim=0).item()
        group_relevances[group_id] = group_relevance

    return group_relevances

In [8]:
records_by_group = train_df.groupby('cluster')['ProcessedText'].apply(lambda x: x.tolist()).to_dict()
records_by_group.keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [9]:
compare_to_groups(user_input, records_by_group)

Group 0:
Max Similarity: 0.9808236956596375
Min Similarity: 0.6218248009681702
Group 1:
Max Similarity: 0.9779689311981201
Min Similarity: 0.6466044187545776
Group 2:
Max Similarity: 0.9760890603065491
Min Similarity: 0.6302465200424194
Group 3:
Max Similarity: 0.9772753119468689
Min Similarity: 0.6411662697792053
Group 4:
Max Similarity: 0.980180025100708
Min Similarity: 0.676959753036499
Group 5:
Max Similarity: 0.9760475754737854
Min Similarity: 0.6320551633834839
Group 6:
Max Similarity: 0.978866696357727
Min Similarity: 0.6442828178405762
Group 7:
Max Similarity: 0.9787057042121887
Min Similarity: 0.641491174697876
Group 8:
Max Similarity: 0.9779014587402344
Min Similarity: 0.6317482590675354
Group 9:
Max Similarity: 0.9788384437561035
Min Similarity: 0.6553593277931213


{0: 0.9275656342506409,
 1: 0.9284870028495789,
 2: 0.9148563146591187,
 3: 0.9196761846542358,
 4: 0.9316330552101135,
 5: 0.9221423268318176,
 6: 0.9189413785934448,
 7: 0.919468879699707,
 8: 0.9163185358047485,
 9: 0.9324429035186768}