In [None]:
# Import azure-core elements
import azureml.core
from azureml.core.workspace import Workspace
from azureml.core import ScriptRunConfig, Environment, Experiment
from azureml.core.environment import CondaDependencies
from azureml.core import Workspace, Datastore, Dataset
from azureml.data.dataset_factory import DataType

# Connect to the Azure ML workspace described by the local config file
workspace = Workspace.from_config()

# Look up the registered datastore by name and read the patient table from it
datastore_name = 'sp_data'
datastore = Datastore.get(workspace, datastore_name)

# (datastore, relative path) pairs are the path format the tabular factory is given here
datastore_paths = [(datastore, '/patients.parquet')] 
ds = Dataset.Tabular.from_parquet_files(path=datastore_paths)
# Materialise the tabular dataset as a pandas DataFrame; later cells read
# DurableKey, BirthDate, Sex and Ethnicity from it.
dfP = ds.to_pandas_dataframe()
# NOTE(review): head() renders patient-level rows in the cell output - clear outputs before sharing
dfP.head()

In [None]:
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import transformers
from transformers import AutoTokenizer, AutoModel
#import colorcet as cc
from datasets import Dataset, DatasetDict
# Notebook-wide plotting defaults: white grid background and the reversed "mako" palette
sns.set(style="whitegrid")
sns.set_palette('mako_r')

In [None]:
# Set path for the finetuned model
model_path = "../../finetuning/acutereadm_finetuned_models/dischargesum/"
model_name = "psyroberta_p4_epoch12"

# Use the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the checkpoint from local files only and request hidden states and
# attention weights from every forward pass.
# NOTE(review): AutoModel is loaded here, while `embedding_extraction` reads
# `outputs.logits` - confirm whether a model with a classification head
# (e.g. AutoModelForSequenceClassification) was intended.
model = AutoModel.from_pretrained(model_path+model_name, 
                                    local_files_only=True,
                                    use_safetensors=True, 
                                    output_hidden_states=True,
                                    output_attentions=True)#.cuda()
model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_path+model_name, local_files_only=True)

# The tokenisation cell passes return_overflowing_tokens / return_offsets_mapping,
# so fail early if the loaded tokenizer is not a fast tokenizer.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [None]:
# Event table for the acute-readmission task; later cells read the columns
# EncounterKey_dis, EncounterKey, PatientDurableKey and Date_dis from it.
events = pd.read_parquet("../../data/acuteReadmission/events_acute_labels.parquet")

In [None]:
import math
# Acute and non-acute events keep their encounter id in different columns:
# take EncounterKey_dis when it is present and fall back to EncounterKey.
events["encounter"] = [
    int(fallback_key if math.isnan(discharge_key) else discharge_key)
    for discharge_key, fallback_key in zip(events.EncounterKey_dis.values, events.EncounterKey.values)
]

# Attach each patient's birth date (inner join on the durable patient key) and
# drop the duplicated join column afterwards.
patient_birthdates = dfP[["DurableKey", "BirthDate"]]
event = events.merge(
    patient_birthdates,
    left_on='PatientDurableKey',
    right_on='DurableKey',
).drop(columns='DurableKey')

# Age in whole years at discharge: elapsed days / 365.25, rounded down.
days_alive = (pd.to_datetime(event.Date_dis) - pd.to_datetime(event.BirthDate)).dt.days
event['Age'] = np.floor(days_alive / 365.25).astype(int)

# Encounter -> age lookup table used by the aggregation cells further down.
age_df = event[["encounter", "Age"]].copy()
event.head()

In [None]:
from azure.ai.ml import MLClient#, Input, command
from azure.identity import DefaultAzureCredential
import sys
# Make the repository root importable; the config import below depends on this line running first
sys.path.append("../..")
import azure_ml_configs

# Workspace coordinates are read from the local config module rather than hard-coded here
workspace_id = azure_ml_configs.workspace_id 
subscription_id = azure_ml_configs.subscription_id 
resource_group = azure_ml_configs.resource_group
workspace_name = azure_ml_configs.workspace_name

# Get a handle to the workspace
ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id=subscription_id,
    resource_group_name=resource_group,
    workspace_name=workspace_name,
)

# Analysis switches: restrict to discharge summaries, and which text column to read
discharge_notes_only = True
text_column_name = "text_names_removed_step2"

# Resolve the registered data asset (version 1) to the path that pandas reads in the next cell
data_asset = ml_client.data.get(name="clinicalNote_AcuteReadmission", version=1)

print(f"Data asset URI: {data_asset.path}")

data_path = data_asset.path

In [None]:
# Load only the columns that are needed, order the notes chronologically within
# each patient/encounter, and give the text and target columns generic names.
cols = [text_column_name, "Acute", "set", "Type", "PatientDurableKey", "EncounterKey", "CreationInstant"]
df = (
    pd.read_csv(data_path, usecols=cols)
    .sort_values(by=["PatientDurableKey", "EncounterKey", "CreationInstant"])
    .rename(columns={text_column_name: "text", "Acute": "label"})
)

if discharge_notes_only:
    # Keep discharge summaries only; comparing with True turns missing note
    # types (NaN from str.contains) into False.
    is_discharge_summary = df["Type"].str.contains("Udskrivningsresume|Udskrivningsresum√©") == True
    df = df[is_discharge_summary].copy()

# One row per patient/encounter/label/split: join that encounter's note texts
# with the tokenizer's separator token, in the sorted order established above.
encounter_keys = ["PatientDurableKey", "EncounterKey", "label", "set"]
df = df.groupby(encounter_keys).text.apply(f'{tokenizer.sep_token}'.join).reset_index()

In [None]:
# Map the split names used downstream to the tags stored in the `set` column
# and wrap each subset of the frame as a `datasets.Dataset`.
split_tags = (("train", "train"), ("validation", "val"), ("test", "test"))
data_dict = {
    split_name: Dataset.from_pandas(df[df.set == tag])
    for split_name, tag in split_tags
}

raw_datasets = DatasetDict(data_dict)

# NOTE(review): this rebinds the config value set in the data-asset cell, so the
# CSV-loading cell cannot be re-run afterwards without re-running that cell first.
text_column_name = "text"

def tokenize_function(examples):
    """Tokenize a batch of encounter texts into fixed-length, overlapping windows.

    Each text is encoded with a maximum length of 512 tokens; text beyond that
    is returned as additional windows (stride 32). Every window becomes one
    output row that repeats the label, patient id and encounter id of the text
    it was cut from.

    Parameters
    ----------
    examples : mapping with the batch columns "text", "label",
        "PatientDurableKey" and "EncounterKey".

    Returns
    -------
    dict with equally long lists under the keys "inputs", "attn_masks",
    "labels", "patient_id", "encounter_id" and "text_split" (the decoded
    token ids of each window).
    """
    sample = {"inputs": [],
              "attn_masks": [],
              "labels": [],
              "patient_id": [],
              "encounter_id": [],
              "text_split": []}

    batch_rows = zip(examples["text"], examples["label"],
                     examples["PatientDurableKey"], examples["EncounterKey"])
    for note_text, label, patient_id, encounter_id in batch_rows:
        encoded_dict = tokenizer.encode_plus(
            note_text,
            add_special_tokens=True,         # wrap each window in the model's special tokens
            max_length=512,                  # window size used for padding and truncation
            padding="max_length",
            truncation=True,
            return_overflowing_tokens=True,  # keep the text beyond 512 tokens as extra windows
            return_offsets_mapping=True,
            stride=32,                       # overlap between consecutive windows
            return_attention_mask=True,
            return_tensors='pt'
        )
        windows = zip(encoded_dict['input_ids'], encoded_dict['attention_mask'])
        for window_ids, window_mask in windows:
            sample["inputs"].append(window_ids)
            sample["text_split"].append(tokenizer.decode(window_ids))
            # The mask separates real tokens from padding.
            sample["attn_masks"].append(window_mask)
            sample["labels"].append(label)
            sample["patient_id"].append(patient_id)
            sample["encounter_id"].append(encounter_id)

    assert (len(sample["inputs"]) == len(sample["attn_masks"]) == len(sample["labels"])
            == len(sample["patient_id"]) == len(sample["encounter_id"]))
    return sample


# Tokenize every split in batches; the raw columns are dropped because one
# input row can expand into several windows.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    num_proc=None,
    remove_columns=raw_datasets['validation'].column_names,
    desc="Running tokenizer on every text in dataset",
)

# Expose the numeric columns as torch tensors (the decoded text is left out).
tensor_columns = ['inputs', 'attn_masks', 'labels', 'patient_id', 'encounter_id']
for split_name in ("train", "validation", "test"):
    tokenized_datasets[split_name].set_format(type='pt', columns=tensor_columns)

traindata = tokenized_datasets["train"]
valdata = tokenized_datasets["validation"]
testdata = tokenized_datasets["test"]

In [None]:
def embedding_extraction(data):
    """Run the global `model` over every window in `data` and collect embeddings.

    Adapted from:
    https://medium.com/@minamehdinia213/fine-tuned-bert-embeddings-and-t-sne-visualization-bdfd09563744

    Parameters
    ----------
    data : indexable with the columns "inputs" (token ids) and "attn_masks"
        (attention masks), one row per window.

    Returns
    -------
    cls_embeddings : np.ndarray
        Last-layer hidden state of the first token of every window.
    mean_pooled_embeddings : np.ndarray
        Attention-mask-weighted mean of the last-layer hidden states.
    preds : list
        Arg-max class per window. Only filled when the model output carries
        `logits`; a model without a classification head yields an empty list
        instead of raising AttributeError.
    """
    cls_embeddings = []
    mean_pooled_embeddings = []
    preds = []

    input_id_rows = data["inputs"]
    attention_mask_rows = data["attn_masks"]

    with torch.no_grad():
        for input_id, mask in tqdm(zip(input_id_rows, attention_mask_rows), total=len(input_id_rows)):
            # Move one window at a time to the device (batch dimension of 1)
            # instead of pushing the whole split there up front.
            input_id = torch.as_tensor(input_id).unsqueeze(0).to(device)
            mask = torch.as_tensor(mask).unsqueeze(0).to(device)

            # Forward pass, return hidden states
            outputs = model(input_ids=input_id,
                            attention_mask=mask,
                            output_hidden_states=True)

            # Predictions are only available when the model has a task head.
            logits = getattr(outputs, "logits", None)
            if logits is not None:
                probabilities = torch.nn.functional.softmax(logits, dim=-1).detach().cpu()
                preds.append(np.argmax(probabilities, axis=1).flatten())

            # The last entry of hidden_states is the final layer's output
            last_hidden_state = outputs.hidden_states[-1]

            # Embedding of the first token of the window
            cls_embedding = last_hidden_state[:, 0, :].detach().cpu().numpy()
            cls_embeddings.append(cls_embedding.flatten())

            # Mean pooling: zero out padded positions, then divide by the number
            # of real tokens (clamped so an all-padding mask cannot divide by zero).
            input_mask_expanded = mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
            sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            mean_pooled_embedding = (sum_embeddings / sum_mask).detach().cpu().numpy()
            mean_pooled_embeddings.append(mean_pooled_embedding.flatten())

    # convert to numpy arrays
    cls_embeddings = np.array(cls_embeddings)
    mean_pooled_embeddings = np.array(mean_pooled_embeddings)
    return cls_embeddings, mean_pooled_embeddings, preds

In [None]:
# Embed every test window and cache both embedding variants next to the notebook
# so the plotting section can reload them without re-running the model.
test_cls_embeddings, test_mean_pooled_embeddings, test_predictions = embedding_extraction(testdata)
np.save("test_cls_embeddings.npy", test_cls_embeddings)
np.save("test_mean_pooled_embeddings.npy", test_mean_pooled_embeddings)

In [None]:
# Embed every training window and cache both embedding variants on disk.
# The third return value holds predictions for the *training* split, so it is
# named accordingly (it was previously bound to `val_predictions`).
train_cls_embeddings, train_mean_pooled_embeddings, train_predictions = embedding_extraction(traindata)
np.save("train_cls_embeddings.npy", train_cls_embeddings)
np.save("train_mean_pooled_embeddings.npy", train_mean_pooled_embeddings)

## Plotting

In [None]:
# Reload the cached embeddings (lets the plotting / KNN sections run without
# repeating the model forward passes).
train_cls_embeddings = np.load("train_cls_embeddings.npy")
train_mean_pooled_embeddings = np.load("train_mean_pooled_embeddings.npy")

test_cls_embeddings = np.load("test_cls_embeddings.npy")
test_mean_pooled_embeddings = np.load("test_mean_pooled_embeddings.npy")

# No cell in this notebook writes the validation files, so their absence must
# not stop a Restart & Run All; the validation arrays are then set to None.
try:
    val_cls_embeddings = np.load("val_cls_embeddings.npy")
    val_mean_pooled_embeddings = np.load("val_mean_pooled_embeddings.npy")
except FileNotFoundError:
    val_cls_embeddings = None
    val_mean_pooled_embeddings = None
    print("Validation embeddings not found on disk - skipping them.")

In [None]:
def agegroups(x):
    """Bucket an age in years into a named age group.

    Parameters
    ----------
    x : number, the string "Unknown", None or NaN.

    Returns
    -------
    "Children" (< 18), "Young adults" (18-34), "Adults" (35-54),
    "Seniors" (>= 55), or "Unknown" when the age is missing.
    """
    # Missing ages: the explicit marker, None, or NaN (NaN is the only value
    # that differs from itself). Without this check NaN fails every comparison
    # below and would be reported as "Seniors".
    if x is None or (isinstance(x, str) and x == "Unknown"):
        return "Unknown"
    if x != x:
        return "Unknown"
    if x < 18:
        return "Children"
    if x < 35:
        return "Young adults"
    if x < 55:
        return "Adults"
    return "Seniors"

In [None]:
# Encounter-level model output: average the positive-class probability of all
# windows per (encounter, patient) and threshold it at 0.5.
preds = pd.read_csv("../../result_files/dischargesum_psyroberta_p4_epoch12_AR_train_results.csv")
preds = preds[["pos_prob", "pid", "eid"]].copy()
preds = preds.groupby(by=["eid","pid"])["pos_prob"].mean().reset_index()
preds["pred"] = preds.pos_prob.apply(lambda x: 0 if x<0.5 else 1)

train_labels = traindata["labels"]
train_eids = traindata["encounter_id"]
train_pids = traindata["patient_id"]

# Collect the window positions of every encounter in ONE pass instead of
# scanning all windows once per encounter. Encounters are visited in the sorted
# order of np.unique, and positions inside a group stay ascending.
train_positions_by_eid = {}
for position, eid in enumerate(np.asarray(train_eids).tolist()):
    train_positions_by_eid.setdefault(eid, []).append(position)
train_unique_eids = np.unique(train_eids)
train_groups = [np.array(train_positions_by_eid[eid]) for eid in train_unique_eids.tolist()]

train_labels_np = np.asarray(train_labels)
train_pids_np = np.asarray(train_pids)

# Per-encounter means of the window embeddings / labels, window counts, and the
# patient id of each encounter's first window.
train_encounter_embs_cls = np.array([train_cls_embeddings[idx].mean(axis=0) for idx in train_groups])
train_encounter_embs_mean_pooled = np.array([train_mean_pooled_embeddings[idx].mean(axis=0) for idx in train_groups])
train_encounter_labs = np.array([train_labels_np[idx].mean(axis=0) for idx in train_groups])
train_num_notes_in_encounter = np.array([len(idx) for idx in train_groups])
train_encounter_pid = np.array([train_pids_np[idx[0]].item() for idx in train_groups])

# Model probability / prediction per encounter. verify_integrity raises if an
# encounter has more than one row; .loc raises KeyError if one is missing.
train_preds_by_eid = preds[preds.eid.isin(train_unique_eids)].set_index("eid", verify_integrity=True)
train_probs = train_preds_by_eid.loc[train_unique_eids, "pos_prob"].to_numpy()
train_preds = train_preds_by_eid.loc[train_unique_eids, "pred"].to_numpy()

# Demographics: dictionary lookups keyed by patient / encounter id, with
# "Unknown" for ids that are absent from the lookup tables.
intersection_train = set(train_encounter_pid.tolist()).intersection(set(dfP.DurableKey.values.tolist()))
train_patients = dfP[dfP["DurableKey"].isin(list(intersection_train))].set_index("DurableKey", verify_integrity=True)
train_sex_by_pid = train_patients["Sex"].to_dict()
train_eth_by_pid = train_patients["Ethnicity"].to_dict()
train_sex = [train_sex_by_pid.get(pid, "Unknown") for pid in train_encounter_pid.tolist()]
train_eth = [train_eth_by_pid.get(pid, "Unknown") for pid in train_encounter_pid.tolist()]

train_age_by_eid = (
    age_df[age_df.encounter.isin(train_unique_eids)]
    .set_index("encounter", verify_integrity=True)["Age"]
    .to_dict()
)
train_age = [train_age_by_eid.get(eid, "Unknown") for eid in train_unique_eids.tolist()]
train_age_groups = [agegroups(i) for i in train_age]

print(train_encounter_embs_cls.shape)
print(train_encounter_embs_mean_pooled.shape)
print(train_encounter_labs.shape)
print(train_num_notes_in_encounter.shape)
print(train_encounter_pid.shape)

In [None]:
test_labels = testdata["labels"]
test_eids = testdata["encounter_id"]
test_pids = testdata["patient_id"]

# Collect the window positions of every encounter in ONE pass instead of
# scanning all windows once per encounter. Encounters are visited in the sorted
# order of np.unique, and positions inside a group stay ascending.
test_positions_by_eid = {}
for position, eid in enumerate(np.asarray(test_eids).tolist()):
    test_positions_by_eid.setdefault(eid, []).append(position)
test_unique_eids = np.unique(test_eids)
test_groups = [np.array(test_positions_by_eid[eid]) for eid in test_unique_eids.tolist()]

test_labels_np = np.asarray(test_labels)
test_pids_np = np.asarray(test_pids)

# Per-encounter means of the window embeddings / labels, window counts, and the
# patient id of each encounter's first window.
test_encounter_embs_cls = np.array([test_cls_embeddings[idx].mean(axis=0) for idx in test_groups])
test_encounter_embs_mean_pooled = np.array([test_mean_pooled_embeddings[idx].mean(axis=0) for idx in test_groups])
test_encounter_labs = np.array([test_labels_np[idx].mean(axis=0) for idx in test_groups])
test_num_notes_in_encounter = np.array([len(idx) for idx in test_groups])
test_encounter_pid = np.array([test_pids_np[idx[0]].item() for idx in test_groups])

# Demographics: dictionary lookups keyed by patient / encounter id, with
# "Unknown" for ids that are absent from the lookup tables. verify_integrity
# raises if a looked-up id has more than one row.
intersection_test = set(test_encounter_pid.tolist()).intersection(set(dfP.DurableKey.values.tolist()))
test_patients = dfP[dfP["DurableKey"].isin(list(intersection_test))].set_index("DurableKey", verify_integrity=True)
test_sex_by_pid = test_patients["Sex"].to_dict()
test_eth_by_pid = test_patients["Ethnicity"].to_dict()
test_sex = [test_sex_by_pid.get(pid, "Unknown") for pid in test_encounter_pid.tolist()]
test_eth = [test_eth_by_pid.get(pid, "Unknown") for pid in test_encounter_pid.tolist()]

test_age_by_eid = (
    age_df[age_df.encounter.isin(test_unique_eids)]
    .set_index("encounter", verify_integrity=True)["Age"]
    .to_dict()
)
test_age = [test_age_by_eid.get(eid, "Unknown") for eid in test_unique_eids.tolist()]
test_age_groups = [agegroups(i) for i in test_age]

print(test_encounter_embs_cls.shape)
print(test_encounter_embs_mean_pooled.shape)
print(test_encounter_labs.shape)
print(test_num_notes_in_encounter.shape)
print(test_encounter_pid.shape)

In [None]:
# Standardise every embedding dimension, then project with PCA (all components kept)
pca = PCA()
pipe = Pipeline([('scaler', StandardScaler()), ('pca', pca)])

# The same pipeline object is fitted twice: each call returns its own projection,
# but afterwards `pipe` / `pca` hold the fit on the mean-pooled embeddings only.
pca_cls = pipe.fit_transform(train_encounter_embs_cls)
pca_mean_pooled = pipe.fit_transform(train_encounter_embs_mean_pooled)

In [None]:
# Encounter-level attributes shared by both projections (one entry per encounter,
# aligned with the rows of the PCA output).
train_plot_attributes = {'Label': [int(i) for i in train_encounter_labs],
                         'Num notes': train_num_notes_in_encounter,
                         "Sex": train_sex,
                         "Ethnicity": train_eth,
                         "Age": train_age_groups,
                         "Prediction": train_preds,
                         "Probability": train_probs}

# First two principal components of each embedding type, followed by the attributes.
plot_data_cls = {'PC 1': pca_cls[:, 0],
                 'PC 2': pca_cls[:, 1],
                 **train_plot_attributes}

plot_data_mean_pooled = {'PC 1': pca_mean_pooled[:, 0],
                         'PC 2': pca_mean_pooled[:, 1],
                         **train_plot_attributes}

In [None]:
import matplotlib.cm as cm

sns.set_style("white", {"axes.edgecolor": ".8"})
colormap =sns.color_palette(palette='mako_r',n_colors=20, as_cmap=True)

# One explicit normalisation shared by the points and the colourbar, spanning
# the observed range of window counts.
note_counts = plot_data_mean_pooled["Num notes"]
note_norm = plt.Normalize(vmin=min(note_counts), vmax=max(note_counts))

fig = plt.figure(figsize=(9,7))
# legend=False replaces the empty-legend workaround; the stray `cmap` keyword is
# dropped because the colours come from `palette`.
ax = sns.scatterplot(data=plot_data_mean_pooled, x='PC 1', y='PC 2', hue='Num notes', s=10,
                     palette=colormap, hue_norm=note_norm, legend=False)

# The mappable is not attached to any Axes, so the colourbar is given the Axes
# to take its space from.
scalarmappaple = cm.ScalarMappable(norm=note_norm, cmap=colormap)
scalarmappaple.set_array(note_counts)
cbar = fig.colorbar(scalarmappaple, ax=ax)
cbar.set_label('# Note splits')

plt.savefig("../../output/PCA_train_mean_pooled_noteSplit_s10_woGrid.pdf", bbox_inches="tight")
plt.savefig("../../output/PCA_train_mean_pooled_noteSplit_s10_woGrid.png", bbox_inches="tight")

In [None]:
sns.set_style("white", {"axes.edgecolor": ".8"})

# Two contrasting colours from the reversed "mako" palette, one per predicted class
c0 = sns.color_palette(palette='mako_r',n_colors=6)[0]
c1 = sns.color_palette(palette='mako_r',n_colors=6)[4]

colors=[c0,c1]
#colors=[c1,c0]

# Mean-pooled PCA projection coloured by the model's encounter-level prediction
fig = plt.figure(figsize=(7,7))
sns.scatterplot(data=plot_data_mean_pooled, x='PC 1', y='PC 2', hue='Prediction', s=10, palette=colors)
plt.savefig("../../output/PCA_train_mean_pooled_prediction_s10_woGrid.pdf", bbox_inches="tight")
plt.savefig("../../output/PCA_train_mean_pooled_prediction_s10_woGrid.png", bbox_inches="tight")

In [None]:
sns.set_style("white", {"axes.edgecolor": ".8"})

import matplotlib.cm as cm

colormap =sns.color_palette(palette='mako_r',n_colors=100, as_cmap=True)

# Probabilities live on [0, 1]. The same normalisation is used for the points
# and for the colourbar, so a colour read off the bar is the colour that was
# drawn (without hue_norm the points are scaled to the observed range instead).
probability_norm = plt.Normalize(vmin=0, vmax=1)

fig = plt.figure(figsize=(9,7))
# legend=False replaces the empty-legend workaround; the stray `cmap` keyword is
# dropped because the colours come from `palette`.
ax = sns.scatterplot(data=plot_data_mean_pooled, x='PC 1', y='PC 2', hue='Probability', s=10,
                     palette=colormap, hue_norm=probability_norm, legend=False)

# The mappable is not attached to any Axes, so the colourbar is given the Axes
# to take its space from.
scalarmappaple = cm.ScalarMappable(norm=probability_norm, cmap=colormap)
scalarmappaple.set_array(plot_data_mean_pooled["Probability"])
cbar = fig.colorbar(scalarmappaple, ax=ax)
cbar.set_label('Probability')

plt.savefig("../../output/PCA_train_mean_pooled_probability_s10_woGrid.pdf", bbox_inches="tight")
plt.savefig("../../output/PCA_train_mean_pooled_probability_s10_woGrid.png", bbox_inches="tight")

In [None]:
sns.set_style("white", {"axes.edgecolor": ".8"})

# One colour per plotted category. The earlier three-colour variants were
# overwritten before use and have been removed.
c0 = sns.color_palette(palette='mako_r',n_colors=6)[0]
c1 = sns.color_palette(palette='mako_r',n_colors=6)[4]

colors=[c0,c1]

# hue_order lists only the two recorded sexes, so encounters whose sex is
# "Unknown" are not given a colour in this figure.
fig = plt.figure(figsize=(7,7))
sns.scatterplot(data=plot_data_mean_pooled, 
                x='PC 1', 
                y='PC 2', 
                hue='Sex', 
                s=5, 
                palette=colors, 
                hue_order=["Kvinde", "Mand"])

#plt.savefig("../../output/PCA_train_mean_pooled_sex_s5_wo_unknown_woGrid.pdf", bbox_inches="tight")
#plt.savefig("../../output/PCA_train_mean_pooled_sex_s5_wo_unknown_woGrid.png", bbox_inches="tight")

In [None]:
# Load the per-encounter diagnosis table; the action diagnosis SKS (ICD-10) code
# of every encounter is selected from it by get_action_diagnosis below.
diagnosis_df = pd.read_parquet("../../data/acuteReadmission/afregningsdiagnose-copy.parquet")

def sort_SKS_hierarchy(SKSCode):
    """Return a sort key that ranks DF20 codes first, other DF codes next, the rest last.

    The key is the code itself with a prefix: '0' for DF20 and its sub-codes
    (e.g. 'DF200'), no prefix for any other code starting with 'DF', and 'Z'
    for every other code. Sorting the keys as strings therefore keeps DF20 as
    the most important action diagnosis.

    DF20 is matched by prefix, consistent with `skscode_to_diagnosis`; an exact
    comparison would leave sub-codes such as 'DF200' ranked behind 'DF1x'.
    """
    if SKSCode.startswith('DF'):
        if SKSCode.startswith('DF20'):
            return '0' + SKSCode
        return SKSCode
    return 'Z' + SKSCode  # Place non-'DF' codes at the end

def get_action_diagnosis(afregnings):
    """
    Select one action (main) diagnosis per encounter.

    Parameters
    ------------
    - afregnings. Afregningsdiagnose dataframe with the columns
      PatientDurableKey, EncounterKey, SKSCode and IsActionDiagnosis.

    Returns
    ------------
    A new dataframe with the columns PatientDurableKey, EncounterKey and
    SKSCode, holding at most one row per EncounterKey. When an encounter has
    several action diagnoses, the one that sorts first under
    `sort_SKS_hierarchy` is kept. The input frame is not modified.
    """
    wanted_columns = ['PatientDurableKey', 'EncounterKey', 'SKSCode', 'IsActionDiagnosis']
    # Only action diagnoses that can be tied to an encounter (-1 marks "no encounter")
    is_action = afregnings['IsActionDiagnosis'] == 1
    has_encounter = afregnings['EncounterKey'] != -1

    # A single chain builds a fresh frame, so nothing is assigned into a
    # filtered slice of the caller's data.
    afregnings_action = (
        afregnings.loc[is_action & has_encounter, wanted_columns]
        .drop(columns='IsActionDiagnosis')
        # Temporary sort key that encodes the diagnosis hierarchy
        .assign(SortingColumn=lambda frame: frame['SKSCode'].apply(sort_SKS_hierarchy))
        .sort_values(by=['EncounterKey', 'SortingColumn'])
        # First row of every encounter = highest-ranked diagnosis
        .groupby('EncounterKey')
        .head(1)
        .drop(columns='SortingColumn')
        .reset_index(drop=True)
    )
    return afregnings_action

# One action-diagnosis row per encounter; used for the diagnosis lookups below
afregnings_action = get_action_diagnosis(diagnosis_df)

In [None]:
intersection_diagnosis_train = set(train_eids.tolist()).intersection(set(afregnings_action.EncounterKey.values.tolist()))

# Encounter -> SKS code dictionary built once, instead of filtering the whole
# diagnosis frame for every encounter. verify_integrity raises if a looked-up
# encounter has more than one diagnosis row.
train_code_by_encounter = (
    afregnings_action[afregnings_action.EncounterKey.isin(list(intersection_diagnosis_train))]
    .set_index("EncounterKey", verify_integrity=True)["SKSCode"]
    .to_dict()
)

# Same encounter order as the aggregation cell (sorted unique encounter ids);
# encounters without an action diagnosis get "unknown".
train_diagnosis = np.array([train_code_by_encounter.get(eid, "unknown") for eid in np.unique(train_eids).tolist()])
print(train_diagnosis.shape)
np.save("train_diagnosis.npy", train_diagnosis)

In [None]:
# Optionally reload the cached per-encounter SKS codes written by the previous cell
train_diagnosis = np.load("train_diagnosis.npy")

In [None]:
def skscode_to_diagnosis(sks):
    """Map an SKS code to a coarse diagnosis group by code prefix.

    Prefixes are checked in order and the first match wins, so 'DF20' is
    tested before the broader 'DF2'. Codes that match no prefix (including
    the "unknown" placeholder) are reported as "Other".
    """
    prefix_groups = (
        (("DF20",), "Schizophrenia"),
        (("DF2",), "Other psychosis"),
        (("DF30", "DF31"), "Bipolar/manic"),
        (("DF32", "DF33"), "Depression"),
        (("DF40", "DF41", "DF42"), "Anxiety/OCD"),
        (("DF6",), "Personality disorder"),
        (("DF1",), "SUD"),
    )
    for prefixes, group_name in prefix_groups:
        if sks.startswith(prefixes):
            return group_name
    return "Other"

# ICD-10 chapter-F block: characters 1-2 of a "DF.." code (e.g. "F2"); any other code becomes "Other"
train_diagnosis_simple = np.array([sks[1:3] if sks.startswith("DF") else "Other" for sks in train_diagnosis])

# Named diagnosis group per encounter, see skscode_to_diagnosis
train_diagnosis_specific = np.array([skscode_to_diagnosis(sks) for sks in train_diagnosis])

In [None]:
# Encounter-level attributes shared by both projections, now including the two
# diagnosis groupings (one entry per encounter, aligned with the PCA rows).
train_plot_attributes_dx = {'Label': [int(i) for i in train_encounter_labs],
                            'Num notes': train_num_notes_in_encounter,
                            "Sex": train_sex,
                            "Ethnicity": train_eth,
                            "Age": train_age_groups,
                            "Prediction": train_preds,
                            "Probability": train_probs,
                            "Diagnosis": train_diagnosis_simple,
                            "Diagnosis_specific": train_diagnosis_specific}

# First two principal components of each embedding type, followed by the attributes.
plot_data_cls = {'PC 1': pca_cls[:, 0],
                 'PC 2': pca_cls[:, 1],
                 **train_plot_attributes_dx}

plot_data_mean_pooled = {'PC 1': pca_mean_pooled[:, 0],
                         'PC 2': pca_mean_pooled[:, 1],
                         **train_plot_attributes_dx}

In [None]:
sns.set_style("white", {"axes.edgecolor": ".8"})

# Fixed legend order: the seven named groups (reversed) followed by "Other".
diagnosis_order = ['Anxiety/OCD', 
                   'Personality disorder', 
                   'Bipolar/manic', 
                   'Depression', 
                   'Other psychosis', 
                   'SUD', 
                   'Schizophrenia' 
                   ][::-1]+["Other"]

# Size the palette from the fixed category list rather than from the groups
# that happen to occur in the data, so the seven picks below always exist and
# "black" always lands on "Other".
colors = sns.color_palette(palette='colorblind',n_colors=len(diagnosis_order)+3)
colors = [colors[0]]+colors[2:5]+colors[7:10]#+colors[13:]

fig = plt.figure(figsize=(7,7))
sns.scatterplot(data=plot_data_mean_pooled, 
                x='PC 1', 
                y='PC 2', 
                hue='Diagnosis_specific', 
                s=10, 
                palette=colors+["black"], 
                #style="Diagnosis_specific",
                hue_order=diagnosis_order
               )
plt.legend(loc="upper right", ncol=1, prop={'size': 9});
plt.savefig("../../output/PCA_train_mean_pooled_diagnosis_specific_fig77_s10_colorblind_noGrid.pdf", bbox_inches="tight")
plt.savefig("../../output/PCA_train_mean_pooled_diagnosis_specific_fig77_s10_colorblind_noGrid.png", bbox_inches="tight")

In [None]:
sns.set_style("white", {"axes.edgecolor": ".8"})

# Fixed legend order: the ten F-chapter blocks followed by "Other".
f_chapter_order = ["F0", "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "Other"]

# One colour per F block plus black for "Other". The palette is sized from the
# fixed category list, so it matches hue_order even when a block does not
# occur in the data.
colors = sns.color_palette(palette='colorblind',n_colors=len(f_chapter_order)-1)

fig = plt.figure(figsize=(7,7))
sns.scatterplot(data=plot_data_mean_pooled, 
                x='PC 1', 
                y='PC 2', 
                hue='Diagnosis', 
                s=10, 
                palette=colors+["black"], 
                #style="Diagnosis",
                hue_order=f_chapter_order
               )
plt.legend(loc="upper right", ncol=2);

plt.savefig("../../output/PCA_train_mean_pooled_Fdiagnosis_fig77_s10_colorblind_noGrid.pdf", bbox_inches="tight")
plt.savefig("../../output/PCA_train_mean_pooled_Fdiagnosis_fig77_s10_colorblind_noGrid.png", bbox_inches="tight")

## KNN

In [None]:
from sklearn.metrics import f1_score, roc_auc_score, matthews_corrcoef,  precision_recall_curve, roc_curve, auc, precision_score, recall_score
from sklearn.neighbors import KNeighborsClassifier

def compute_specificity(targets, preds):
    """Recall of the negative class (label 0); None when `targets` hold more than two classes."""
    is_multiclass = len(np.unique(targets)) > 2
    if is_multiclass:
        return None
    return recall_score(targets, preds, pos_label=0, average="binary")

def compute_auroc(targets,preds):
    """Area under the ROC curve from class-probability scores.

    Parameters
    ----------
    targets : class labels, one per sample.
    preds : 2-D array of class scores, one column per class.

    Returns
    -------
    Macro-averaged one-vs-rest AUROC for multi-class problems; for binary
    problems the AUROC of the scores in column 1.

    The multi-class branch is taken when the score matrix has more than two
    columns OR more than two classes occur in `targets`. Looking at the score
    matrix as well prevents a multi-class problem whose evaluation set happens
    to contain only two classes from being scored as binary on column 1.
    """
    preds = np.asarray(preds)
    is_multiclass = preds.shape[1] > 2 or len(np.unique(targets)) > 2
    if is_multiclass:
        auc_score = roc_auc_score(
                            targets,
                            preds,
                            multi_class="ovr",
                            average="macro",
                        )
    else:
        fpr, tpr, thresholds = roc_curve(targets,preds[:,1])
        auc_score = auc(fpr, tpr)
    return auc_score

def compute_auprc(targets,preds):
    """Area under the precision-recall curve (recall on the x axis)."""
    precision_values, recall_values, _thresholds = precision_recall_curve(targets, preds)
    return auc(recall_values, precision_values)

#def compute_f1_neg(targets,preds):
#    return f1_score(targets,preds,pos_label=0)

def compute_f1_weighted(targets,preds):
    """F1 score with the "weighted" average over classes."""
    weighted_f1 = f1_score(targets, preds, average="weighted")
    return weighted_f1

def compute_f1(targets,preds):
    """F1 score: "weighted" average for more than two target classes, "binary" otherwise."""
    averaging = "weighted" if len(np.unique(targets)) > 2 else "binary"
    return f1_score(targets, preds, average=averaging)
        
def compute_precision(targets,preds):
    """Precision: "weighted" average for more than two target classes, "binary" otherwise."""
    averaging = "weighted" if len(np.unique(targets)) > 2 else "binary"
    return precision_score(targets, preds, average=averaging)

def compute_recall(targets,preds):
    """Recall: "weighted" average for more than two target classes, "binary" otherwise."""
    averaging = "weighted" if len(np.unique(targets)) > 2 else "binary"
    return recall_score(targets, preds, average=averaging)


def run_eval(probs,preds,targets):
    """Compute the evaluation metrics for one set of predictions.

    Parameters
    ----------
    probs : class-probability scores (used for the AUC).
    preds : predicted class labels (used for all other metrics).
    targets : true class labels.

    Returns
    -------
    (values, names): the metric values and, in the same order, their names
    "AUC", "MCC", "F1 AVG", "F1", "precision", "recall".
    """
    named_metrics = [
        ("AUC", compute_auroc(targets, probs)),
        ("MCC", matthews_corrcoef(targets, preds)),
        ("F1 AVG", compute_f1_weighted(targets, preds)),
        ("F1", compute_f1(targets, preds)),
        ("precision", compute_precision(targets, preds)),
        ("recall", compute_recall(targets, preds)),
    ]

    index_ = [metric_name for metric_name, _ in named_metrics]
    metrics_results_list = [metric_value for _, metric_value in named_metrics]

    return metrics_results_list, index_
    


def run_knn(train_embs, train_labs, test_embs, test_labs, k, n_jobs=12):
    """Fit a standardised k-nearest-neighbour classifier and score both splits.

    Parameters
    ----------
    train_embs, train_labs : embeddings and class labels used to fit the model.
    test_embs : embeddings to score.
    test_labs : accepted for a uniform call signature; not used here.
    k : number of neighbours.
    n_jobs : parallel jobs for the neighbour search (default 12, the value
        that was previously hard-coded).

    Returns
    -------
    (test_probs, train_probs): class-probability matrices for the test and the
    training embeddings. The training probabilities are predicted on the same
    points the model was fitted on.
    """
    KNN = Pipeline(
    steps=[("scaler", StandardScaler()), ("knn", KNeighborsClassifier(n_neighbors=k, n_jobs=n_jobs))])

    KNN.fit(train_embs, train_labs)

    test_probs = KNN.predict_proba(test_embs)

    train_probs = KNN.predict_proba(train_embs)

    return (test_probs, train_probs)

def eval_and_save(test_probs, train_probs, test_labs, train_labs, k, embs_type, task):
    """Evaluate KNN probabilities on both splits and write the metric table to CSV.

    Predicted classes are the arg-max column of each probability matrix. The
    resulting table has one row per metric and the columns "Train" and "Test";
    it is saved under ../../output/ with k, the embedding type and the task in
    the file name, and is also returned.
    """
    predicted_test = np.argmax(test_probs, axis=1)
    predicted_train = np.argmax(train_probs, axis=1)

    test_metrics, _ = run_eval(test_probs, predicted_test, test_labs)
    train_metrics, index_ = run_eval(train_probs, predicted_train, train_labs)

    metric_results = pd.DataFrame(
        {"Train": train_metrics, "Test": test_metrics},
        index=index_,
    )

    output_file = "../../output/KNN_{}_{}_{}.csv".format(k,embs_type,task)
    metric_results.to_csv(output_file)

    return metric_results

In [None]:
# Encounter-level training data: mean-pooled embeddings plus all attributes
train_data_mean_pooled = {'Embs': train_encounter_embs_mean_pooled,
                         'Label': [int(i) for i in train_encounter_labs],
                         'Num notes': train_num_notes_in_encounter,
                          "Sex": train_sex,
                         "Ethnicity": train_eth,
                         "Age": train_age_groups,
                         "Prediction": train_preds,
                         "Probability": train_probs,
                         "Diagnosis": train_diagnosis_simple,
                         "Diagnosis_specific": train_diagnosis_specific
                         }

intersection_diagnosis_test = set(test_eids.tolist()).intersection(set(afregnings_action.EncounterKey.values.tolist()))

# Encounter -> SKS code dictionary built once, instead of filtering the whole
# diagnosis frame for every encounter. verify_integrity raises if a looked-up
# encounter has more than one diagnosis row.
test_code_by_encounter = (
    afregnings_action[afregnings_action.EncounterKey.isin(list(intersection_diagnosis_test))]
    .set_index("EncounterKey", verify_integrity=True)["SKSCode"]
    .to_dict()
)

# Same encounter order as the aggregation cell (sorted unique encounter ids);
# encounters without an action diagnosis get "unknown".
test_diagnosis = np.array([test_code_by_encounter.get(eid, "unknown") for eid in np.unique(test_eids).tolist()])
print(test_diagnosis.shape)

# ICD-10 chapter-F block (characters 1-2 of a "DF.." code) and named diagnosis group
test_diagnosis_simple = np.array([sks[1:3] if sks.startswith("DF") else "Other" for sks in test_diagnosis])
test_diagnosis_specific = np.array([skscode_to_diagnosis(sks) for sks in test_diagnosis])

# Encounter-level test data (no model prediction / probability columns)
test_data_mean_pooled = {'Embs': test_encounter_embs_mean_pooled,
                         'Label': [int(i) for i in test_encounter_labs],
                         'Num notes': test_num_notes_in_encounter,
                          "Sex": test_sex,
                         "Ethnicity": test_eth,
                         "Age": test_age_groups,
                         "Diagnosis": test_diagnosis_simple,
                         "Diagnosis_specific": test_diagnosis_specific
                         }

In [None]:
# KNN probe: can the readmission label be recovered from the mean-pooled embeddings?
train_embs = train_encounter_embs_mean_pooled
train_labs = train_encounter_labs
test_embs = test_encounter_embs_mean_pooled
test_labs = test_encounter_labs

embs_type = "mean_pooled"
task = "finetuned_psyroberta_label"

# The KNN outputs get their own names: binding them to `train_probs` would
# overwrite the model probabilities that the plotting cells read.
for k in [5,10,50]:
    knn_test_probs, knn_train_probs = run_knn(train_embs, train_labs, test_embs, test_labs, k)
    eval_and_save(knn_test_probs, knn_train_probs, test_labs, train_labs, k, embs_type, task)

In [None]:
# KNN probe: ICD-10 F-block of the action diagnosis
train_embs = train_encounter_embs_mean_pooled
# Integer class ids follow the sorted unique categories of the training split.
# NOTE(review): a category that occurs only in the test split raises KeyError below.
id2index = {k:v for v, k in enumerate(np.unique(train_diagnosis_simple))}
train_labs = [id2index[i] for i in train_diagnosis_simple]

test_embs = test_encounter_embs_mean_pooled
test_labs = [id2index[i] for i in test_diagnosis_simple]

embs_type = "mean_pooled"
task = "finetuned_psyroberta_diagnosis"

# The KNN outputs get their own names: binding them to `train_probs` would
# overwrite the model probabilities that the plotting cells read.
for k in [5,10,50]:
    knn_test_probs, knn_train_probs = run_knn(train_embs, train_labs, test_embs, test_labs, k)
    eval_and_save(knn_test_probs, knn_train_probs, test_labs, train_labs, k, embs_type, task)

In [None]:
# KNN probe: ethnicity ("Unknown" is kept as its own class here)
train_embs = train_encounter_embs_mean_pooled
# Integer class ids follow the sorted unique categories of the training split.
# NOTE(review): a category that occurs only in the test split raises KeyError below.
id2index = {k:v for v, k in enumerate(np.unique(train_eth))}
train_labs = [id2index[i] for i in train_eth]

test_embs = test_encounter_embs_mean_pooled
test_labs = [id2index[i] for i in test_eth]

embs_type = "mean_pooled"
task = "finetuned_psyroberta_ethnicity"

# The KNN outputs get their own names: binding them to `train_probs` would
# overwrite the model probabilities that the plotting cells read.
for k in [5,10,50]:
    knn_test_probs, knn_train_probs = run_knn(train_embs, train_labs, test_embs, test_labs, k)
    eval_and_save(knn_test_probs, knn_train_probs, test_labs, train_labs, k, embs_type, task)

In [None]:
# KNN probe: named diagnosis group (see skscode_to_diagnosis)
train_embs = train_encounter_embs_mean_pooled
# Integer class ids follow the sorted unique categories of the training split.
# NOTE(review): a category that occurs only in the test split raises KeyError below.
id2index = {k:v for v, k in enumerate(np.unique(train_diagnosis_specific))}
train_labs = [id2index[i] for i in train_diagnosis_specific]

test_embs = test_encounter_embs_mean_pooled
test_labs = [id2index[i] for i in test_diagnosis_specific]


embs_type = "mean_pooled"
task = "finetuned_psyroberta_diagnosis_specific"

# The KNN outputs get their own names: binding them to `train_probs` would
# overwrite the model probabilities that the plotting cells read.
for k in [5,10,50]:
    knn_test_probs, knn_train_probs = run_knn(train_embs, train_labs, test_embs, test_labs, k)
    eval_and_save(knn_test_probs, knn_train_probs, test_labs, train_labs, k, embs_type, task)

In [None]:
# KNN probe: age group, restricted to encounters with a known age
train_indels = [i for i,j in enumerate(train_age_groups) if j != "Unknown"]
train_age_groups_ = np.array(train_age_groups)[np.array(train_indels)]

train_embs = train_encounter_embs_mean_pooled[np.array(train_indels)]
# Integer class ids follow the sorted unique categories of the training split.
id2index = {k:v for v, k in enumerate(np.unique(train_age_groups_))}
train_labs = [id2index[i] for i in train_age_groups_]


test_indels = [i for i,j in enumerate(test_age_groups) if j != "Unknown"]
test_age_groups_ = np.array(test_age_groups)[np.array(test_indels)]

test_embs = test_encounter_embs_mean_pooled[np.array(test_indels)]
test_labs = [id2index[i] for i in test_age_groups_]

embs_type = "mean_pooled"
task = "finetuned_psyroberta_age"

# The KNN outputs get their own names: binding them to `train_probs` would
# overwrite the model probabilities that the plotting cells read.
for k in [5,10,50]:
    knn_test_probs, knn_train_probs = run_knn(train_embs, train_labs, test_embs, test_labs, k)
    eval_and_save(knn_test_probs, knn_train_probs, test_labs, train_labs, k, embs_type, task)

In [None]:
# KNN probe: sex, restricted to encounters with a recorded sex
train_indels = [i for i,j in enumerate(train_sex) if j != "Unknown"]
train_sex_ = np.array(train_sex)[np.array(train_indels)]

train_embs = train_encounter_embs_mean_pooled[np.array(train_indels)]
# Integer class ids follow the sorted unique categories of the training split.
id2index = {k:v for v, k in enumerate(np.unique(train_sex_))}
train_labs = [id2index[i] for i in train_sex_]

test_indels = [i for i,j in enumerate(test_sex) if j != "Unknown"]
test_sex_ = np.array(test_sex)[np.array(test_indels)]
test_embs = test_encounter_embs_mean_pooled[np.array(test_indels)]
test_labs = [id2index[i] for i in test_sex_]

embs_type = "mean_pooled"
task = "finetuned_psyroberta_sex_wo_unknown"

# The KNN outputs get their own names: binding them to `train_probs` would
# overwrite the model probabilities that the plotting cells read.
for k in [5,10,50]:
    knn_test_probs, knn_train_probs = run_knn(train_embs, train_labs, test_embs, test_labs, k)
    eval_and_save(knn_test_probs, knn_train_probs, test_labs, train_labs, k, embs_type, task)