In [None]:
# check to see if safetensors work
from safetensors import safe_open

In [None]:
from accelerate import Accelerator
import datasets
from torch.utils.data import DataLoader
from datasets import Dataset, DatasetDict
import torch
import torch.nn as nn
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
import numpy as np
import os
import argparse
import random
import pandas as pd
from attention_flow_and_rollout import compute_joint_attention

# Set the TOKENIZERS_PARALLELISM environment variable to "false" for this process.
# NOTE(review): presumably this silences the Hugging Face tokenizers fork/parallelism
# warning when DataLoader workers are used; confirm against the tokenizers docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
from azure.ai.ml import MLClient#, Input, command
from azure.identity import DefaultAzureCredential
import sys
sys.path.append("../..")
from utils import azure_ml_configs


# Open a client for the Azure ML workspace described in utils.azure_ml_configs.
workspace_settings = {
    "subscription_id": azure_ml_configs.subscription_id,
    "resource_group_name": azure_ml_configs.resource_group,
    "workspace_name": azure_ml_configs.workspace_name,
}
ml_client = MLClient(credential=DefaultAzureCredential(), **workspace_settings)

# Look up version 1 of the registered data asset; its path is passed on as --data.
# Alternative asset kept for reference:
#data_asset = ml_client.data.get(name="clinicalNote_AcuteReadmission_DedupCont", version=1)
data_asset = ml_client.data.get(name="clinicalNote_AcuteReadmission", version=1)

print(f"Data asset URI: {data_asset.path}")

In [None]:
# Run configuration. The options are declared argparse-style and parsed from an
# explicit list (not sys.argv) so the same settings can be used inside a notebook.
parser = argparse.ArgumentParser()

parser.add_argument("--data", type=str, default=None)
parser.add_argument("--pretrained_model_path", type=str, default="../")
parser.add_argument("--model_name", type=str, default="psyroberta_for_acute_readmission_prediction")
parser.add_argument("--checkpoint_dir", type=str)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--nrows", type=int, default=None, help="To load only n rows from data for development tests.")
parser.add_argument("--shuffle", action="store_true")
parser.add_argument("--discharge_notes_only",
                    action="store_true",
                    help="If passed, filters data to only include discharge summaries.")
parser.add_argument("--max_seq_splits",
                    type=int,
                    default=None,
                    help="The number of splits that long notes will be split to. Default None means that every part of a long note is included. However, some notes are extremely long (tens of thousands tokens) compared to the majority (median ~130 tokens). Based on descriptive analysis of sequence lengths across regions and psychiatric centers, max_seq_splits=4 might be appropriate, because for the center with longest notes, 75% of are less than 1600 tokens long.")
parser.add_argument("--scale_loss", action="store_true")
parser.add_argument("--on_test", action="store_true")
parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
# The explanation function reads args.report_to and args.log_dir when
# --with_tracking is set; without these two options that path raised AttributeError.
parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help="Tracker integration(s) handed to Accelerator(log_with=...) when --with_tracking is passed.",
    )
parser.add_argument(
        "--log_dir",
        type=str,
        default=None,
        help="Directory handed to Accelerator(project_dir=...) when --with_tracking is passed. Falls back to --checkpoint_dir.",
    )
parser.add_argument("--random_seed", type=int, default=42)
parser.add_argument("--text_column_name", type=str, default="text_names_removed_step2")


args = parser.parse_args(["--model_name", "psyroberta_p4_epoch12",
                          "--pretrained_model_path", "../../finetuning/acutereadm_finetuned_models/",
                          "--data", data_asset.path,
                          "--checkpoint_dir", "../../result_files/",
                          "--batch_size", "1",
                          "--random_seed", "22",
                          #"--text_column_name", "DedupCont",
                          "--discharge_notes_only",
                          "--scale_loss",
                          "--on_test"
                         ])

# Keep tracker output next to the other results unless a log directory was given.
if args.log_dir is None:
    args.log_dir = args.checkpoint_dir

print(args, "\n")
data_path = args.data
model_path = args.pretrained_model_path

# Sub-directory of the model path that holds the model for the selected note subset.
if args.discharge_notes_only:
    directory = "dischargesum"
else:
    directory = "allnotes"

model_name = args.model_name
# NOTE(review): `epoch` is not referenced in the cells visible here; confirm it is still needed.
epoch = 11
checkpoint_dir = args.checkpoint_dir # with azure, should be within "./outputs"
batch_size = args.batch_size

print("max_seq_splits=", args.max_seq_splits)
print("disharge notes only =", args.discharge_notes_only)
print("model:", model_path+directory+"/"+model_name)

# Seed every random number generator used in this notebook.
np.random.seed(args.random_seed)
random.seed(args.random_seed)
torch.manual_seed(args.random_seed)
torch.cuda.manual_seed_all(args.random_seed)

In [None]:
# Reproducible DataLoader workers, see https://pytorch.org/docs/stable/notes/randomness.html
def seed_worker(worker_id):
    """Seed NumPy and Python's `random` from the current torch seed.

    Intended as a DataLoader `worker_init_fn`. `worker_id` is accepted to match
    that callback signature but is not used.
    """
    derived_seed = torch.initial_seed() % 2**32
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(derived_seed)

# Dedicated generator, seeded from the run's random seed, handed to the DataLoaders.
g = torch.Generator()
g.manual_seed(args.random_seed)

# Directory holding the fine-tuned model and its tokenizer.
finetuned_model_dir = model_path + directory + "/" + model_name

# Load from local safetensors weights only, and ask the model to return hidden
# states and attentions on every forward pass.
model = AutoModelForSequenceClassification.from_pretrained(
    finetuned_model_dir,
    local_files_only=True,
    use_safetensors=True,
    output_hidden_states=True,
    output_attentions=True,
)
model = model.cuda()

tokenizer = AutoTokenizer.from_pretrained(finetuned_model_dir, local_files_only=True)

# The tokenization step requests offset mappings, so insist on a fast tokenizer.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [None]:
# Load only the columns needed downstream.
cols = [args.text_column_name, "Acute", "set", "Type", "PatientDurableKey", "EncounterKey", "CreationInstant"]
df = pd.read_csv(data_path, usecols=cols, nrows=args.nrows)
# Sort by patient id, encounter and date so the notes of an encounter are
# concatenated in that order below.
df = df.sort_values(by=["PatientDurableKey", "EncounterKey", "CreationInstant"])
# Rename the main columns of interest.
df = df.rename(columns={args.text_column_name: "text", "Acute": "label"})

if args.discharge_notes_only:
    # Accept the note type with a plain "e" or with an accented e (U+00E9). The
    # accented form is written as an escape so it does not depend on this file's encoding.
    # The last alternative is the mis-encoded spelling the filter used before; it is
    # kept so that every row matched previously is still matched.
    discharge_pattern = "Udskrivningsresume|Udskrivningsresum\u00e9|Udskrivningsresum\u221a\u00a9"
    # na=False treats a missing note type as "not a discharge summary".
    df = df[df["Type"].str.contains(discharge_pattern, na=False)].copy()
    # to do: take 1 when there are more than 1.

# Concatenate the texts of each patient/encounter, separated by the tokenizer's sep token.
df = df.groupby(["PatientDurableKey", "EncounterKey", "label", "set"]).text.apply(f'{tokenizer.sep_token}'.join).reset_index()

In [None]:
# One Hugging Face Dataset per split. The "set" column labels the validation
# rows "val", while the dataset key is "validation".
data_dict = {
    split_name: Dataset.from_pandas(df[df.set == set_value])
    for split_name, set_value in (("train", "train"), ("validation", "val"), ("test", "test"))
}

raw_datasets = DatasetDict(data_dict)

text_column_name = "text"

def tokenize_function(examples):
    """Tokenize a batch of encounter texts into fixed-length chunks.

    Each text is encoded to sequences of 512 tokens (padded, with overflow
    returned as extra sequences using a stride of 32). At most
    `args.max_seq_splits` sequences are kept per text (all of them when it is
    None), and the label and ids of the source row are repeated for every kept
    sequence.

    Args:
        examples: batch with "text", "label", "PatientDurableKey" and
            "EncounterKey" columns.

    Returns:
        dict of equally long lists with keys "inputs", "attn_masks", "labels",
        "patient_id", "encounter_id" and "text_split" (the decoded chunk).
    """
    columns = {
        "inputs": [],
        "attn_masks": [],
        "labels": [],
        "patient_id": [],
        "encounter_id": [],
        "text_split": [],
    }
    rows = zip(examples["text"], examples["label"], examples["PatientDurableKey"], examples["EncounterKey"])
    for note_text, label, patient_id, encounter_id in rows:
        encoding = tokenizer.encode_plus(
            note_text,
            add_special_tokens=True,  # add the model's special tokens
            max_length=512,  # pad and truncate every sequence to this length
            padding="max_length",
            truncation=True,
            return_overflowing_tokens=True,  # tokens beyond max_length come back as extra sequences
            return_offsets_mapping=True,
            stride=32,  # stride used when a text is split across several sequences
            return_attention_mask=True,
            return_tensors='pt',
        )
        chunk_pairs = list(zip(encoding['input_ids'], encoding['attention_mask']))
        for chunk_ids, chunk_mask in chunk_pairs[:args.max_seq_splits]:
            columns["inputs"].append(chunk_ids)
            columns["text_split"].append(tokenizer.decode(chunk_ids))
            # The attention mask differentiates padding from non-padding.
            columns["attn_masks"].append(chunk_mask)
            columns["labels"].append(label)
            columns["patient_id"].append(patient_id)
            columns["encounter_id"].append(encounter_id)
    assert (
        len(columns["inputs"])
        == len(columns["attn_masks"])
        == len(columns["labels"])
        == len(columns["patient_id"])
        == len(columns["encounter_id"])
    )
    return columns

In [None]:
# Tokenize every split; the raw columns are dropped because one raw row can
# produce several tokenized rows.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    num_proc=None,
    remove_columns=raw_datasets['validation'].column_names,
    #load_from_cache_file=not args.overwrite_cache,
    desc="Running tokenizer on every text in dataset",
)

# Columns returned as torch tensors. Validation and test additionally return the
# remaining columns unformatted; train does not.
tensor_columns = ['inputs', 'attn_masks', 'labels', 'patient_id', 'encounter_id']
tokenized_datasets["train"].set_format(type='pt', columns=list(tensor_columns))
for split_name in ("validation", "test"):
    tokenized_datasets[split_name].set_format(
        type='pt',
        output_all_columns=True,
        columns=list(tensor_columns),
    )

traindata, valdata, testdata = (
    tokenized_datasets[split_name] for split_name in ("train", "validation", "test")
)

In [None]:
def create_dataloaders():
    """Build the train, validation and test DataLoaders with identical settings.

    All three loaders use `args.shuffle`, the global `batch_size`, `seed_worker`
    as worker initialiser and the shared generator `g`.

    Returns:
        tuple `(train_dataloader, val_dataloader, test_dataloader)`.
    """
    def _build_loader(split_data):
        # One place for the settings shared by every split.
        return DataLoader(
            dataset=split_data,
            shuffle=args.shuffle,
            batch_size=batch_size,
            worker_init_fn=seed_worker,
            generator=g,
        )

    return _build_loader(traindata), _build_loader(valdata), _build_loader(testdata)

In [None]:
def accelerate_forward_and_explain(model):
    """Run the model over the training split and save per-token attention rollout.

    For every batch the function runs a forward pass without gradients, computes
    attention rollout from the returned attentions, and records token ids,
    tokens, rollout scores, the positive-class probability, the patient and
    encounter ids and the decoded text. The records are written by the main
    process to
    `../../result_files/{directory}_{model_name}_AR_train_results.csv`.

    Only batches of exactly one sequence are supported (`--batch_size 1`, one process).

    Args:
        model: sequence-classification model whose outputs expose `attentions`
            and `logits`.

    Raises:
        ValueError: if a gathered batch holds more than one sequence.
    """
    accelerator_log_kwargs = {}
    if args.with_tracking:
        # Fall back to defaults when the parsed namespace lacks these optional
        # attributes instead of failing with AttributeError.
        accelerator_log_kwargs["log_with"] = getattr(args, "report_to", "all")
        accelerator_log_kwargs["project_dir"] = getattr(args, "log_dir", checkpoint_dir)

    accelerator = Accelerator(**accelerator_log_kwargs)

    # Verbose library logging on the local main process only.
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    train_dataloader, val_dataloader, test_dataloader = create_dataloaders()

    model, train_dataloader, val_dataloader, test_dataloader = accelerator.prepare(
        model, train_dataloader, val_dataloader, test_dataloader
    )

    def attention_rollout(all_attentions, inputs):
        """Return one rollout score per token for a single sequence.

        `inputs` is accepted for signature compatibility and is not used.
        """
        _attentions = [att.detach().cpu().numpy() for att in all_attentions]
        # Stacking adds a leading per-layer axis. Axis 1 is then the batch axis
        # (assumes each attention tensor is (batch, heads, seq, seq) as documented
        # for transformers' `output_attentions` - TODO confirm). Only that axis is
        # dropped, so a size-1 layer or head axis is not squeezed away by accident.
        attentions_mat = np.stack(_attentions, axis=0).squeeze(axis=1)
        # Average over axis 1 of the squeezed array.
        res_att_mat = attentions_mat.sum(axis=1)/attentions_mat.shape[1]
        joint_attentions = compute_joint_attention(res_att_mat, add_residual=True)

        # Use the last entry and sum over its first axis.
        return joint_attentions[-1].sum(axis=0)

    def run_attention_rollout(dataloader):
        """Collect one record per batch of `dataloader` (batch size must be 1)."""
        model.eval()
        tokens_attentions = []
        for batch in tqdm(dataloader):
            inputs, attn, targets = batch["inputs"], batch["attn_masks"], batch["labels"]
            with torch.no_grad():
                outputs = model(inputs.cuda(), attention_mask=attn.cuda(), labels=targets.cuda())

            all_attentions = outputs['attentions']
            all_attentions, all_token_ids = accelerator.gather_for_metrics((all_attentions, inputs))

            # Everything below indexes element 0 and calls .item(), which is only
            # valid for a single sequence; fail early with a clear message.
            if all_token_ids.shape[0] != 1:
                raise ValueError(
                    "Attention rollout expects exactly one sequence per batch "
                    f"(got {all_token_ids.shape[0]}); run with --batch_size 1 and a single process."
                )

            token_ids = all_token_ids.cpu()[0]
            token_id_list = token_ids.tolist()

            att_rollout = attention_rollout(all_attentions, all_token_ids)
            tokens = tokenizer.convert_ids_to_tokens(token_id_list)
            assert len(tokens) == len(token_id_list) == len(att_rollout)

            probabilities = nn.functional.softmax(outputs.logits, dim=-1)
            probabilities = accelerator.gather_for_metrics(probabilities)

            # Probability of class index 1; .item() requires this to be a single value.
            pos_probs = probabilities[:, 1:].flatten().item()

            pid, eid = accelerator.gather_for_metrics((batch["patient_id"], batch["encounter_id"]))

            tokens_attentions.append({"token_ids": token_id_list,
                                      "tokens": tokens,
                                      "attn_rollout": att_rollout.tolist(),
                                      "pos_prob": pos_probs,
                                      "pid": pid.item(),
                                      "eid": eid.item(),
                                      "text": tokenizer.decode(token_ids)})

        return tokens_attentions

    # Only the training split is processed here; the validation and test
    # dataloaders are prepared above but not used.
    attn_train = run_attention_rollout(train_dataloader)
    attn_train_df = pd.DataFrame(attn_train)
    if accelerator.is_main_process:
        attn_train_df.to_csv(f'../../result_files/{directory}_{model_name}_AR_train_results.csv')

In [None]:
from accelerate import notebook_launcher
# Set PyTorch's float32 matmul precision mode to 'high' before launching.
torch.set_float32_matmul_precision('high')

# Run the explanation function from within the notebook on a single process.
notebook_launcher(
    accelerate_forward_and_explain,
    args=(model,),
    num_processes=1,
)