In [None]:
import pandas as pd
import numpy as np
from ast import literal_eval
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
%matplotlib inline
import seaborn as sns
from tqdm import tqdm 
from transformers import AutoTokenizer
import string
sns.set_theme(style="whitegrid")

In [None]:
# Note collection to analyse. The commented-out alternative value is "allnotes".
directory = "dischargesum"#"allnotes"
# Fine-tuned model checkpoint name. The commented-out alternative is "psyroberta_p4_dedupcont_epoch12".
model_name = "psyroberta_p4_epoch12"#"psyroberta_p4_dedupcont_epoch12"
# Root folder of the fine-tuned models; combined with `directory` and `model_name`
# when the tokenizer is loaded further down.
model_path = "../../finetuning/acutereadm_finetuned_models/"

# The name of the text column in the data depends on the selected note collection.
if directory=="dischargesum":
    text_column_name = "text_names_removed_step2"
else:
    text_column_name = "DedupCont"

In [None]:
from azure.ai.ml import MLClient#, Input, command
from azure.identity import DefaultAzureCredential
import sys
sys.path.append("../..")
from utils import azure_ml_configs

# Azure ML workspace settings, read from the project-local `utils.azure_ml_configs` module.
# NOTE(review): `workspace_id` is not used anywhere in the visible cells.
workspace_id = azure_ml_configs.workspace_id 
subscription_id = azure_ml_configs.subscription_id 
resource_group = azure_ml_configs.resource_group
workspace_name = azure_ml_configs.workspace_name

# Get a handle to the workspace
ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id=subscription_id,
    resource_group_name=resource_group,
    workspace_name=workspace_name,
)

# Pick the registered data asset (version 1) that matches the selected note collection.
if directory=="dischargesum":
    data_asset = ml_client.data.get(name="clinicalNote_AcuteReadmission", version=1)
else:
    data_asset = ml_client.data.get(name="clinicalNote_AcuteReadmission_DedupCont", version=1)

print(f"Data asset URI: {data_asset.path}")
# The asset path is passed straight to pd.read_csv in the data-loading step.
data_path = data_asset.path

# Tokenizer of the fine-tuned model, loaded from <model_path><directory>/<model_name>.
# local_files_only=True: only files already on disk are used.
tokenizer = AutoTokenizer.from_pretrained(model_path+directory+"/"+model_name, local_files_only=True)

# Loading and preparing the data: only the columns needed below are read.
cols = [text_column_name, "Acute", "set", "Type", "PatientDurableKey", "EncounterKey", "CreationInstant"]
df = pd.read_csv(data_path, usecols=cols)
# Make sure the data is sorted by patient id, encounter and date, so that the
# per-encounter concatenation below joins the notes in that order.
df.sort_values(by=["PatientDurableKey", "EncounterKey", "CreationInstant"],inplace=True)
# Rename the main columns of interest to the generic names used in the rest of the notebook.
df.rename(columns={text_column_name: "text", "Acute": "label"}, inplace=True)

# For discharge summaries keep only rows whose Type matches one of the two spellings.
# The `==True` comparison also turns missing results (rows without a Type) into False.
if directory=="dischargesum":
    df = df[df["Type"].str.contains("Udskrivningsresume|Udskrivningsresumé")==True].copy()

# Concatenate the texts per patient and encounter (label and set are kept as grouping keys),
# separating the individual notes with the tokenizer's separator token.
# NOTE(review): str.join presumably requires every `text` value to be a string — confirm there are no missing texts.
df = df.groupby(["PatientDurableKey", "EncounterKey", "label", "set"]).text.apply(f'{tokenizer.sep_token}'.join).reset_index()

In [None]:
# Load the attention-rollout (AR) results of the selected model on the training set.
result = pd.read_csv(f'../../result_files/{directory}_{model_name}_AR_train_results.csv')

# The token and attention-rollout columns are parsed from their string form
# into Python objects with literal_eval.
result['tokens_2'] = result.tokens.apply(lambda x: literal_eval(str(x)))
result['attn_rollout_2'] = result.attn_rollout.apply(lambda x: literal_eval(str(x)))

# Look up the true label of every result row via its encounter id (`eid`).
# `.item()` requires exactly one matching row in `df` per encounter id and raises otherwise.
# NOTE(review): this filters `df` once per result row; a keyed lookup would avoid the repeated scans.
y_true = [df[df.EncounterKey==eid].label.item() for eid in result.eid.values]
result["y_true"] = y_true

In [None]:
def clean_token(token):
    """Replace known garbled character sequences in a token.

    The substitutions are applied one after another in the order listed
    below; sequences mapped to the empty string are removed.

    Args:
        token: The token string to clean.

    Returns:
        The token with every listed sequence replaced.
    """
    # (garbled sequence, replacement) pairs, applied top to bottom.
    replacements = (
        ("Â", "Ġ"),
        ("ł", ""),
        ("Ã¦", "æ"),
        ("Ã¸", "ø"),
        ("ãĺ", "ø"),
        ("Ã¥", "å"),
        ("âĢĿ", ""),
        ("Ã©", "é"),
        ("ãī", "é"),
    )

    for garbled, fixed in replacements:
        token = token.replace(garbled, fixed)

    return token

def only_punct_next(s):
    """Return True if every character of `s` is in `string.punctuation`.

    An empty string has no non-punctuation character and therefore yields True.
    """
    for char in s:
        if char not in string.punctuation:
            return False
    return True
    
def merge_roberta_tokens(tokens, importance, method="sum"):
    """Merge sub-word tokens into words and aggregate their importance scores.

    Adapted from
    https://github.com/beinborn/relative_importance/blob/main/extract_model_importance/tokenization_util.py
    and
    https://github.com/lautel/fair-rationales/blob/504aabb8f726df4372de45922facf949dd4fb044/src/extract_model_importance/tokenization_utils.py

    Everything from the first "<pad>" token onwards is dropped, and the first and
    last remaining tokens are skipped. A token is extended with the tokens that
    follow it for as long as the following token (after `clean_token`) does not
    start with "Ġ" and is not made up of punctuation only; a token that itself
    consists only of punctuation is never extended. Merged words that are empty
    or contain "[Name]" or "</s>" are discarded. Kept words are lower-cased.

    Args:
        tokens: List of token strings for one text.
        importance: Sequence of scores, one per token, aligned with `tokens`.
        method: How the scores of merged tokens are combined:
            "sum" (default), "mean" or "max".

    Returns:
        A tuple `(words, scores)` of two lists of equal length.

    Raises:
        ValueError: If `method` is not "mean", "sum" or "max". Without this
            check an unsupported method would append words but no scores and
            only fail at the final length assertion.
    """
    if method not in ("mean", "sum", "max"):
        raise ValueError(f'method must be "mean", "sum" or "max", got {method!r}')

    # we don't care about pad tokens
    pad_index = tokens.index("<pad>") if "<pad>" in tokens else None
    tokens = tokens[:pad_index]
    importance = importance[:pad_index]

    adjusted_tokens = []
    adjusted_importance = []
    # Index of the last token that is processed; the final token is ignored.
    last = len(tokens) - 2
    # Start at 1 because the first token is skipped as well.
    i = 1
    while i <= last:
        combined_token = clean_token(tokens[i])

        combined_heat = importance[i]
        combined_heat_max = [importance[i]]

        # keep track of the number of tokens combined
        combine_count = 1

        # check if token is only punctuation
        only_punct = only_punct_next(combined_token.replace("Ġ", ""))

        # The last processed token has no successor to merge with, and a
        # punctuation-only token is kept on its own.
        if i < last and not only_punct:
            while not clean_token(tokens[i + 1]).startswith("Ġ") and not only_punct_next(tokens[i + 1]):
                combined_token = combined_token + clean_token(tokens[i + 1])
                combined_heat = combined_heat + importance[i + 1]
                combined_heat_max.append(importance[i + 1])
                combine_count += 1
                i += 1
                if i == last:
                    break
        combined_token = combined_token.replace("Ġ", "")
        # as there are some errors in the placement of the Name token, we disregard these tokens
        if len(combined_token) > 0 and "[Name]" not in combined_token and "</s>" not in combined_token:
            # we lowercase the combined token before appending
            adjusted_tokens.append(combined_token.lower())
            if method == "mean":
                # mean of the scores over the combined tokens
                adjusted_importance.append(combined_heat / combine_count)
            elif method == "sum":
                # sum of the scores
                adjusted_importance.append(combined_heat)
            else:
                # method == "max": largest score among the combined tokens
                adjusted_importance.append(max(combined_heat_max))

        i += 1

    assert len(adjusted_tokens) == len(adjusted_importance)
    return adjusted_tokens, adjusted_importance

In [None]:
# Merge the sub-word tokens of every result row into words, aggregating the
# attention-rollout scores of merged tokens with the "max" method.
all_tokens = []
all_importances = []
for i in tqdm(range(len(result))):
    t,v = merge_roberta_tokens(result.tokens_2.values[i], result.attn_rollout_2.values[i], method="max")
    all_tokens.append(t)
    all_importances.append(v)

# Store the merged words and their scores as new columns on `result`.
result["merged_tokens"] = all_tokens
result["merged_attentions"] = all_importances

# Round-trip every cell through str() and literal_eval().
# NOTE(review): the cells are already the lists returned by merge_roberta_tokens,
# so this looks redundant — confirm before removing.
result['merged_tokens'] = result.merged_tokens.apply(lambda x: literal_eval(str(x)))
result['merged_attentions'] = result.merged_attentions.apply(lambda x: literal_eval(str(x)))

In [None]:
def count_token_patient_freq(result_df):
    """Count, for every token, the number of patients whose texts contain it.

    The previous implementation first collected the vocabulary of the whole
    frame and intersected it with each patient's token set. Because that
    vocabulary is a superset of every patient's tokens, the intersection never
    removed anything, so each patient's distinct tokens are counted directly.

    Args:
        result_df: DataFrame with a "pid" column and a "merged_tokens" column
            holding one list of tokens per row.

    Returns:
        A `defaultdict(int)` mapping token -> number of distinct "pid" groups
        in which the token occurs at least once.
    """
    # Summing Python lists concatenates them: one token list per patient id.
    per_patient = result_df.groupby(["pid"]).agg({"merged_tokens": "sum"})

    token_pt_freq = defaultdict(int)

    for patient_tokens in tqdm(per_patient.merged_tokens.values):
        # set(): a token counts once per patient, however often it occurs.
        for token in set(patient_tokens):
            token_pt_freq[token] += 1

    return token_pt_freq

def sort_attentions(result_df):
    """Summarise attention per token over all rows of `result_df`.

    Args:
        result_df: DataFrame with the columns "merged_tokens" and
            "merged_attentions" (one list per row); it is also passed on to
            `count_token_patient_freq`.

    Returns:
        DataFrame with one row per distinct token and the columns
        "words" (the token), "attention" (mean score over all occurrences),
        "freq" (number of occurrences) and "freq_p" (the count returned by
        `count_token_patient_freq` for that token).
    """
    # Flatten the per-row lists into two parallel sequences.
    flat_tokens = [tok for row in result_df.merged_tokens.values for tok in row]
    flat_attentions = [att for row in result_df.merged_attentions.values for att in row]

    # Number of occurrences of every token.
    occurrences = defaultdict(int)
    for tok in flat_tokens:
        occurrences[tok] += 1

    # Accumulate the scores per token, then turn the totals into means.
    attention_by_token = defaultdict(float)
    for tok, att in zip(flat_tokens, flat_attentions):
        attention_by_token[tok] += att
    for tok in attention_by_token:
        attention_by_token[tok] /= occurrences[tok]

    patient_counts = count_token_patient_freq(result_df)

    columns = {
        "words": attention_by_token.keys(),
        "attention": attention_by_token.values(),
        "freq": [occurrences[tok] for tok in attention_by_token],
        "freq_p": [patient_counts[tok] for tok in attention_by_token],
    }

    return pd.DataFrame.from_dict(columns)

def find_ngrams(input_list, n):
    """Return all runs of `n` consecutive items of `input_list` as tuples.

    The list is paired with copies of itself shifted by 1 .. n-1 positions, so
    a list shorter than `n` (or `n == 0`) gives an empty result.
    """
    shifted_views = []
    for offset in range(n):
        shifted_views.append(input_list[offset:])
    return list(zip(*shifted_views))

def ngram_attentions(result_df, n):
    """Summarise attention per word n-gram, with corpus and patient frequencies.

    The token and attention lists of all rows belonging to the same "pid" are
    concatenated first, and n-grams are then built over the concatenated list.
    NOTE(review): n-grams can therefore span the boundary between two rows of
    the same patient — confirm this is intended.

    Everything is collected in a single pass over the patients. The previous
    version built the n-grams twice per patient, kept all n-grams in flat
    lists, and intersected each patient's n-gram set with the set of all
    n-grams, which is a superset and so never removed anything.

    Args:
        result_df: DataFrame with the columns "pid", "merged_tokens" and
            "merged_attentions" (one list per row, aligned with each other).
        n: Number of consecutive tokens per n-gram.

    Returns:
        DataFrame with one row per distinct n-gram, in order of first
        occurrence, and the columns
        "words" (tokens joined by a space),
        "attention" (mean over all occurrences of the summed token scores),
        "freq" (number of occurrences) and
        "freq_p" (number of patients with at least one occurrence).
    """
    # Summing Python lists concatenates them: one list per patient id.
    per_patient = result_df.groupby(["pid"]).agg({"merged_tokens": "sum",
                                                  "merged_attentions": "sum"})
    token_lists = per_patient.merged_tokens.values
    attention_lists = per_patient.merged_attentions.values

    frequency_dict = defaultdict(int)
    mean_attention_dict = defaultdict(float)
    ngram_pt_freq = defaultdict(int)

    for i in tqdm(range(len(per_patient))):
        curr_ngrams = [" ".join(x) for x in find_ngrams(token_lists[i], n)]
        # The score of an n-gram is the sum of the scores of its tokens.
        curr_ngrams_att = [sum(x) for x in find_ngrams(attention_lists[i], n)]

        assert len(curr_ngrams) == len(curr_ngrams_att)

        for gram, att in zip(curr_ngrams, curr_ngrams_att):
            frequency_dict[gram] += 1
            mean_attention_dict[gram] += att

        # set(): an n-gram counts once per patient, however often it occurs.
        for gram in set(curr_ngrams):
            ngram_pt_freq[gram] += 1

    # Turn the accumulated totals into means.
    for k in mean_attention_dict.keys():
        mean_attention_dict[k] /= frequency_dict[k]

    d = {"words": list(mean_attention_dict.keys()),
         "attention": list(mean_attention_dict.values()),
         "freq": [frequency_dict[k] for k in mean_attention_dict.keys()],
         "freq_p": [ngram_pt_freq[k] for k in mean_attention_dict.keys()]}

    att_df = pd.DataFrame.from_dict(d)
    return att_df

In [None]:
def plotting_simple(highdf, n, method, freq_rule="top 10% most frequent"):
    """Draw a single horizontal bar chart of attention per word.

    Args:
        highdf: DataFrame with the columns "attention" and "words".
        n: Only printed together with `freq_rule`.
        method: Not used in the body.
        freq_rule: Label that is printed before plotting.

    Side effects:
        Prints `freq_rule` and `n`, changes the global seaborn settings via
        `sns.set`, and creates a new 3x5 figure titled "Attention".
    """
    print(freq_rule, n)

    sns.set(font_scale=0.8, style="whitegrid")

    bar_palette = 'mako'
    fig, ax = plt.subplots(figsize=(3, 5))

    sns.barplot(
        data=highdf,
        x="attention",
        y="words",
        ax=ax,
        palette=bar_palette,
        label="Largest attention all tokens",
    )

    # Both axis labels are hidden; the title alone describes the chart.
    for axis in (ax.xaxis, ax.yaxis):
        axis.label.set_visible(False)
    ax.set_title("Attention")

def plotting(highdf, 
             n, 
             method, 
             freq_rule="top 10% most frequent", 
             xlim_lst = None,
             annot="a"):
    """Plot attention, word frequency and patient frequency as aligned bar charts.

    Without `xlim_lst` three panels are drawn side by side. With `xlim_lst`
    the two frequency panels are each split into a pair of axes with separate
    x-ranges and diagonal break marks between them (a "broken axis").

    Args:
        highdf: DataFrame with the columns "words", "attention", "freq" and
            "freq_p". NOTE: its "words" column is modified in place —
            "schaumburg" is replaced with "[name]" — so the caller's frame
            changes as well.
        n: Not used in the body.
        method: Not used in the body.
        freq_rule: Label that is printed before plotting.
        xlim_lst: None, or a sequence of eight x-limits in the order
            [freq left lo, freq left hi, freq right lo, freq right hi,
             freq_p left lo, freq_p left hi, freq_p right lo, freq_p right hi].
        annot: Not used in the body.

    Side effects:
        Prints `freq_rule`, changes the global seaborn settings via `sns.set`
        and creates a new 8x5 figure.
    """
    print(freq_rule)
    
    palette_name1 = 'mako'
    
    sns.set(font_scale=0.8,style="whitegrid")

    fig = plt.figure(figsize=(8,5))
    
    # Identity check instead of `== None`: the comparison stays a plain boolean
    # for any kind of sequence passed as xlim_lst.
    if xlim_lst is None:
        gs = GridSpec(1, 4, figure=fig)
        
        # attention (twice as wide as the frequency panels)
        ax1 = fig.add_subplot(gs[0,0 :2])
        ax1.xaxis.label.set_visible(False)
        ax1.yaxis.label.set_visible(False)
        
        # freq
        ax2 = fig.add_subplot(gs[0,2 :3], sharey=ax1)
        plt.setp(ax2.get_yticklabels(), visible=False)
        ax2.yaxis.label.set_visible(False)
        ax2.xaxis.label.set_visible(False)
        
        # freq_p
        ax3 = fig.add_subplot(gs[0,3 :4], sharey=ax1)
        plt.setp(ax3.get_yticklabels(), visible=False)
        ax3.yaxis.label.set_visible(False)
        ax3.xaxis.label.set_visible(False)
        
        ax1.set_title("Attention")
        ax2.set_title("Word frequency")
        ax3.set_title("Patient frequency\nof word occurrence")
        
    else:
        gs = GridSpec(1, 12, figure=fig)
        
        # attention
        ax1 = fig.add_subplot(gs[0,0 :6])
        ax1.xaxis.label.set_visible(False)
        ax1.yaxis.label.set_visible(False)
        
        # freq: left part (ax2) and right part (ax2b) of the broken axis
        ax2 = fig.add_subplot(gs[0,6 :8], sharey=ax1)
        plt.setp(ax2.get_yticklabels(), visible=False)
        ax2.yaxis.label.set_visible(False)
        ax2.xaxis.label.set_visible(False)
        
        ax2b = fig.add_subplot(gs[0,8 :9], sharey=ax1)
        plt.setp(ax2b.get_yticklabels(), visible=False)
        ax2b.yaxis.label.set_visible(False)
        ax2b.xaxis.label.set_visible(False)
        
        ax2.set_xlim(xlim_lst[0],xlim_lst[1])
        ax2b.set_xlim(xlim_lst[2],xlim_lst[3])

        # freq_p: left part (ax3) and right part (ax3b) of the broken axis
        ax3 = fig.add_subplot(gs[0,9 :11], sharey=ax1)
        plt.setp(ax3.get_yticklabels(), visible=False)
        ax3.yaxis.label.set_visible(False)
        ax3.xaxis.label.set_visible(False)
        ax3b = fig.add_subplot(gs[0,11 :12], sharey=ax1)
        plt.setp(ax3b.get_yticklabels(), visible=False)
        ax3b.yaxis.label.set_visible(False)
        ax3b.xaxis.label.set_visible(False)
        
        ax3.set_xlim(xlim_lst[4],xlim_lst[5])
        ax3b.set_xlim(xlim_lst[6],xlim_lst[7])
        

    # In-place edit of the caller's frame (see the docstring).
    highdf["words"] = highdf["words"].apply(lambda x: x.replace("schaumburg", "[name]"))

    sns.barplot(x="attention", y="words", data=highdf,
                label="Largest attention all tokens" , ax=ax1,
               palette=palette_name1)

    sns.barplot(x="freq", y="words", data=highdf,
                label="" , ax=ax2,
               palette=palette_name1)
    
    sns.barplot(x="freq_p", y="words", data=highdf,
                label="" , ax=ax3,
               palette=palette_name1)
    
    if xlim_lst is not None:
        # The same bars are drawn again on the right-hand parts; the x-limits
        # set above decide which range each part shows.
        sns.barplot(x="freq", y="words", data=highdf,
                label="" , ax=ax2b,
               palette=palette_name1)
    
        sns.barplot(x="freq_p", y="words", data=highdf,
                label="" , ax=ax3b,
               palette=palette_name1)
    
        d = .015  # how big to make the diagonal lines in axes coordinates
        # arguments to pass plot, just so we don't keep repeating them
        kwargs = dict(transform=ax2.transAxes, color='grey', clip_on=False)
        ax2.plot((1-d, 1+d), (-d, +d), **kwargs, zorder=10)
        ax2.plot((1-d, 1+d), (1-d, 1+d), **kwargs, zorder=10)
        
        kwargs.update(transform=ax2b.transAxes)  # switch to the right-hand axes
        ax2b.plot((-d, +d), (1-d, 1+d), **kwargs, zorder=10)
        ax2b.plot((-d, +d), (-d, +d), **kwargs, zorder=10)
        
        # hide the spines between ax2 and ax2b
        ax2.spines['right'].set_visible(False)
        ax2b.spines['left'].set_visible(False)
        
        kwargs = dict(transform=ax3.transAxes, color='grey', clip_on=False)
        ax3.plot((1-d, 1+d), (-d, +d), **kwargs, zorder=10)
        ax3.plot((1-d, 1+d), (1-d, 1+d), **kwargs, zorder=10)
        kwargs.update(transform=ax3b.transAxes)  # switch to the right-hand axes
        ax3b.plot((-d, +d), (1-d, 1+d), **kwargs, zorder=10)
        ax3b.plot((-d, +d), (-d, +d), **kwargs, zorder=10)
        
        # hide the spines between ax3 and ax3b
        ax3.spines['right'].set_visible(False)
        ax3b.spines['left'].set_visible(False)
        
        plt.setp(ax1.get_xticklabels(), rotation=45, rotation_mode="anchor")
        plt.setp(ax2.get_xticklabels(), rotation=45, rotation_mode="anchor")
        plt.setp(ax2b.get_xticklabels(), rotation=45, rotation_mode="anchor")
        plt.setp(ax3.get_xticklabels(), rotation=45, rotation_mode="anchor")
        plt.setp(ax3b.get_xticklabels(), rotation=45, rotation_mode="anchor")
        
        # Invisible axes spanning each panel group; they only carry the titles
        # so that a title is centred over both parts of a broken axis.
        ghost1 = fig.add_subplot(gs[:6],label="attention title")
        ghost1.axis('off')
        ghost1.set_title("Attention")

        ghost2 = fig.add_subplot(gs[0,6:9], label="freq title")
        ghost2.axis('off')
        ghost2.set_title("Word frequency")
        
        ghost3 = fig.add_subplot(gs[0,9:12], label="freq_p title")
        ghost3.axis('off')
        ghost3.set_title("Patient frequency\nof word occurrence")
        
    
    fig.tight_layout()
    
    if xlim_lst is not None:
        fig.subplots_adjust(wspace=1, hspace=1)

    
    
def plot_attentions(result, method="max", n=1):
    """Merge tokens, compute n-gram attention tables, plot and print the top entries.

    Steps:
      1. Merge the sub-word tokens of every row with `merge_roberta_tokens`
         and store the outcome in the columns "merged_tokens" and
         "merged_attentions" of `result` (the caller's frame is modified).
      2. Build n-gram tables for all rows and for the rows with
         `pos_prob >= 0.7`.
      3. Plot three selections, each limited to n-grams seen for at least
         3 patients and to the 25 highest mean attentions:
         all n-grams; the most frequent n-grams; and the most frequent
         n-grams within the `pos_prob >= 0.7` subset.
         The number of "most frequent" n-grams kept is 10% of the number of
         distinct n-grams in the full table, taken after the 3-patient filter.
      4. Print the words and rounded attentions of the three selections.

    Args:
        result: DataFrame with the columns "tokens_2", "attn_rollout_2",
            "pos_prob" and "pid".
        method: Aggregation passed to `merge_roberta_tokens`.
        n: Size of the n-grams.

    Returns:
        Tuple `(all_freq, top_k_freq_high, top_k_freq_highrisk)` with the three
        plotted DataFrames.
    """

    def _print_words_then_scores(frame):
        # One word per line, followed by one rounded attention value per line.
        for word in frame.words.values:
            print(word)
        for score in frame.attention.values:
            print(np.round(score, 3))

    merged_words = []
    merged_scores = []
    for row in tqdm(range(len(result))):
        words, scores = merge_roberta_tokens(
            result.tokens_2.values[row],
            result.attn_rollout_2.values[row],
            method=method,
        )
        merged_words.append(words)
        merged_scores.append(scores)

    result["merged_tokens"] = merged_words
    result["merged_attentions"] = merged_scores

    for column in ("merged_tokens", "merged_attentions"):
        result[column] = result[column].apply(lambda cell: literal_eval(str(cell)))

    result_highrisk = result[result.pos_prob >= 0.7].copy()

    att_df = ngram_attentions(result, n)
    att_df_highrisk = ngram_attentions(result_highrisk, n)

    # Selection 1: all n-grams seen for at least 3 patients.
    all_freq = (
        att_df[att_df.freq_p >= 3]
        .sort_values(by=["attention"], ascending=False)[:25]
    )
    plotting(all_freq, n=n, method=method, freq_rule="all")

    k = 10
    top_k = int(len(att_df) * (k / 100))
    print(top_k)

    # Selection 2: the top_k most frequent n-grams, ranked by attention.
    top_k_freq = (
        att_df[att_df.freq_p >= 3]
        .sort_values(by=["freq"], ascending=False)[:top_k]
    )
    top_k_freq_high = top_k_freq.sort_values(by=["attention"], ascending=False)[:25]
    plotting(top_k_freq_high, n=n, method=method)

    # Selection 3: the same frequent n-grams, scored on the high-risk subset.
    in_frequent_set = att_df_highrisk.words.isin(top_k_freq.words)
    seen_for_3_patients = att_df_highrisk.freq_p >= 3
    top_k_freq_highrisk = (
        att_df_highrisk[(in_frequent_set) & (seen_for_3_patients)]
        .sort_values(by=["attention"], ascending=False)[:25]
    )
    plotting(top_k_freq_highrisk, n=n, method=method)

    print(n, method)
    print("With 10% most freqent:")
    _print_words_then_scores(top_k_freq_high)

    print("With all words:")
    _print_words_then_scores(all_freq)

    print("highrisk with 10% most freqeunt words (globally in set)")
    _print_words_then_scores(top_k_freq_highrisk)

    return all_freq, top_k_freq_high, top_k_freq_highrisk

In [None]:
# Example: extract attentions on trigrams (n=3), aggregating merged tokens with the
# maximum attention score ("mean" and "sum" are also possible).
# plot_attentions draws three figures, prints the selected words and scores, and
# returns the three plotted tables.

all_freq, top_k_freq_high, top_k_freq_highrisk = plot_attentions(result, method="max", n=3)

# Re-plot the table of the most frequent trigrams as a single attention-only bar chart.
plotting_simple(top_k_freq_high, 
             3, 
             method="max", 
             freq_rule="top 10% most frequent",)
# Uncomment to save the last figure to disk:
#plt.savefig("train_top10_max_trigram.pdf", bbox_inches="tight")
#plt.savefig("train_top10_max_trigram.png", bbox_inches="tight")