### Calculating PPL in Clinical Text 

In [113]:
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, BertLMHeadModel,\
    BertTokenizerFast, XLNetLMHeadModel, XLNetTokenizer, AutoTokenizer, AutoModelForCausalLM,\
    Trainer, TrainingArguments
from datasets import load_dataset
import pandas as pd
import numpy as np
from itertools import chain
import spacy
from typing import List
import pickle 
import os

In [48]:
device = 'cuda'

In [22]:
loaded_open_web_text = load_dataset('openwebtext')['train']

In [None]:
ds = load_dataset('wikitext', 'wikitext-2-raw-v1')
test_data = ds['test']['text']

In [22]:
def compute_ppl(model: str, test_data:List[str], stride_div=2):
    model = AutoModelForCausalLM.from_pretrained(model).to(device)
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    max_length = model.config.n_positions
    
    encodings = tokenizer('\n'.join(test_data), return_tensors='pt')
    
    """the amount to window the input by"""
    stride = int(max_length / stride_div) 
    lls = []
    for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, encodings.input_ids.size(1))
        trg_len = end_loc - i    # may be different from stride on last loop
        input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        target_ids[:,:-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            log_likelihood = outputs[0] * trg_len

        lls.append(log_likelihood)
    ppl = torch.exp(torch.stack(lls).sum() / end_loc)
    return ppl

In [8]:
compute_ppl('gpt2', test_data)

Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 562/562 [00:20<00:00, 27.65it/s]


tensor(25.1705, device='cuda:0')

In [2]:
# wikiText2 downloaded via huggingface datasets 
# downloaded to disk, as GPU box didn't have access to certain (i.e. github) URLs.

In [26]:
wiki_text = pickle.load(open('wikiText2-test.pickle', 'rb'))

In [None]:
open_web_text = load_dataset('openwebtext')['train']

In [4]:
open_web_text_val = open_web_text[np.random.randint(0, len(open_web_text), size=5000)]

In [92]:
with (open('open_web_text_val.txt', 'w')) as f:
      f.writelines(open_web_text_val['text'])

In [7]:
compute_ppl('gpt2', open_web_text_val['text'])

Token indices sequence length is longer than the specified maximum sequence length for this model (5602680 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 10943/10943 [06:39<00:00, 27.39it/s]


tensor(21.1045, device='cuda:0')

In [96]:
with open('wikieText2-test.txt', 'w') as f:
    f.writelines(wiki_text)

In [14]:
compute_ppl('gpt2', ds['validation']['text'])

Token indices sequence length is longer than the specified maximum sequence length for this model (251048 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 491/491 [00:17<00:00, 27.69it/s]


tensor(26.3928, device='cuda:0')

In [34]:
wiki_text = pickle.load(open('wikiText2-test.pickle', 'rb'))

In [35]:
compute_ppl('gpt2', wiki_text)

Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 562/562 [00:20<00:00, 27.76it/s]


tensor(25.1705, device='cuda:0')

In [36]:
compute_ppl('lm_outputs_mimic_stroke', wiki_text)

Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 562/562 [00:20<00:00, 27.69it/s]


tensor(126.3712, device='cuda:0')

In [37]:
compute_ppl('lm_outputs_mimic', wiki_text)

Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 562/562 [00:20<00:00, 27.61it/s]


tensor(427.2962, device='cuda:0')

In [84]:
val_ds = load_dataset('text.py', data_files={'validation': 'lm_data_mimic/val_file.txt'})

Using custom data configuration default
Reusing dataset text (/home/thomass/.cache/huggingface/datasets/text/default-2f2b11b60e8f6607/0.0.0/71d5fe73c3304ac69797de827d0332aab54788917557b56f6e36824791677ac5)


In [None]:
print(compute_ppl('lm_outputs_mimic_stroke', mimic_stroke_val))

In [None]:
print(compute_ppl('lm_outputs_mimic', mimic_val))

Token indices sequence length is longer than the specified maximum sequence length for this model (99884736 > 1024). Running this sequence through the model will result in indexing errors
 58%|█████▊    | 112372/195088 [1:08:42<50:21, 27.37it/s]  

In [79]:
# shannon bound - sub(word) level entropy / theoretical maximum entropy. i.e. H(P)
np.log2(model.config.vocab_size)

15.617036934287741

## Train / Val / Test Splits 

In [98]:
df = pd.read_csv('../data/43411_notes.csv')

In [100]:
df = pd.read_csv('../data/mimic_3_all_notes_with_20_occurrences_of_primary_condition.csv')

  interactivity=interactivity, compiler=compiler, result=result)


In [105]:
df = pd.read_pickle('stroke_notes.pickle')

In [117]:
df['text'] = df.body_analysed

In [115]:
def split_texts(df: pd.DataFrame, out_dir: str='.', hadm_id_col='hadm_id'):   
    hadm_ids = df[hadm_id_col].unique()
    texts = {}
    for i, df_adm in df.groupby(hadm_id_col):
        texts[i] = df_adm.text
    train_set_len = round(len(texts.keys()) * 0.8)
    train_hadms = [k for i, k in enumerate(texts.keys()) if i < train_set_len]
    othr_hadms = [k for k in texts.keys() if k not in train_hadms]
    val_hadms = [k for k in othr_hadms[0:int(len(othr_hadms) / 2)]]
    test_hadms = [k for k in othr_hadms if k not in val_hadms]             

    train_texts = list(chain.from_iterable([texts[i] for i in train_hadms]))
    val_texts = list(chain.from_iterable([texts[i] for i in val_hadms]))
    test_texts = list(chain.from_iterable([texts[i] for i in test_hadms]))

    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/train_file.txt', 'w') as f:
        for t in train_texts:
            f.write(t)
    with open(f'{out_dir}/val_file.txt', 'w') as f:
        for t in val_texts:
            f.write(t)
    with open(f'{out_dir}/test_file.txt', 'w') as f:
        for t in test_texts:
            f.write(t)

In [116]:
split_texts(df, out_dir='kch_stroke_data', hadm_id_col='clientvisit_guid')

In [None]:
# gpt perplexity on clinical text, non-fine-tuned

In [73]:
encodings = tokenizer('\n\n'.join(test_texts), return_tensors='pt')

Token indices sequence length is longer than the specified maximum sequence length for this model (1224234 > 1024). Running this sequence through the model will result in indexing errors


In [76]:
ppl = compute_ppl(encodings)

100%|██████████| 1196/1196 [03:55<00:00,  5.07it/s]


In [81]:
np.log2(ppl.detach().cpu())

tensor(4.2651)