In [1]:
import pandas as pd
import numpy as np
import os

## 1 - Load the preprocessed data

In [2]:
processed_file_path = os.path.join('..', 'data', 'processed')
clustering_model_name = 'KMeans'
processed_file_name = f'train_{clustering_model_name}.csv'

train_df = pd.read_csv(os.path.join(processed_file_path, processed_file_name), low_memory=False)
train_df.head()

Unnamed: 0,file,VMONTH,VYEAR,VDAYR,YEAR,AGE,SEX,ETHNIC,RACE,USETOBAC,...,OTHPROV,MHP,NODISP,REFOTHMD,RETAPPT,OTHDISP,ERADMHOS,cluster,CombinedText,ProcessedText
0,opd2006.csv,December,2006.0,Friday,2006.0,55.0,Male,Not Hispanic or Latino,White Only,Not current,...,No,,One or more dispositions marked,No,No,No,No,9,55_year_old Middle_Aged Male Acute problem Inj...,55_year_old middle_age male acute problem inju...
1,opd2006.csv,November,2006.0,Thursday,2006.0,66.0,Male,Not Hispanic or Latino,White Only,Not current,...,No,,One or more dispositions marked,No,No,No,No,0,66_year_old Senior Male Acute problem Cough Ov...,66_year_old senior male acute problem cough ov...
2,opd2006.csv,November,2006.0,Wednesday,2006.0,71.0,Female,Not Hispanic or Latino,White Only,Not current,...,No,,One or more dispositions marked,Yes,No,No,No,7,71_year_old Senior Female Acute problem Genera...,71_year_old senior female acute problem genera...
3,opd2006.csv,November,2006.0,Tuesday,2006.0,1.0,Female,Not Hispanic or Latino,White Only,Not current,...,No,,One or more dispositions marked,No,No,No,No,2,1_year_old Infant Female Acute problem Other a...,1_year_old infant female acute problem unspeci...
4,opd2006.csv,November,2006.0,Monday,2006.0,21.0,Female,Not Hispanic or Latino,White Only,Current,...,No,,One or more dispositions marked,No,No,No,No,6,21_year_old Adult Female TOBACCO user Acute pr...,21_year_old adult female tobacco user acute pr...


## 2 - Generate vectors and compute similarities/relevances

In [3]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

Some weights of the model checkpoint at microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
def get_embedding(text):
    encoded_input = tokenizer(text, return_tensors='pt')

    with torch.no_grad():
        output = model(**encoded_input)
        #sentence_embedding = output.last_hidden_state[:, 0, :]
        sentence_embedding = output.pooler_output
    return sentence_embedding


def compute_similarity(embedding1, embedding2):
    # Cosine similarity
    cs_score = torch.nn.functional.cosine_similarity(embedding1, embedding2)

    # Dot product
    dp_score = torch.dot(embedding1.squeeze(), embedding2.squeeze())

    return cs_score, dp_score


def compare_to_groups(user_input, records_by_group):
    # preprocess user input
    #user_input = preprocess_text(user_input)

    user_embedding = get_embedding(user_input)

    group_relevances = {}
    for group_id, records in records_by_group():
        group_similarities = []
        for text in records:
            text_embedding = get_embedding(text)
            cs_score, dp_score = compute_similarity(user_embedding, text_embedding)
            group_similarities.append((cs_score, dp_score))
        
        group_relevance = np.mean(np.arry(group_similarities), axis=0)
        group_relevances[group_id] = group_relevance

    return group_relevances

In [5]:
records_by_group = train_df.groupby('cluster')['ProcessedText'].apply(list).to_dict()