In [None]:
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, BertLMHeadModel,\
    BertTokenizerFast, XLNetLMHeadModel, XLNetTokenizer, AutoTokenizer, AutoModelForCausalLM,\
    Trainer, TrainingArguments
from datasets import load_dataset
import pandas as pd
import numpy as np
from itertools import chain
import spacy
from typing import List

# Calculating PPL in Clinical Text 

## Data Preparation

In [27]:
# Stroke-only notes. Later cells read the `hadm_id`, `text`, `category` and
# `description` columns of this frame.
df_stroke = pd.read_csv('../data/43411_notes.csv')

In [2]:
# All MIMIC notes (per the file name, filtered to primary conditions with at
# least 20 occurrences). Passed to `split_texts` below to build the LM corpus.
df_all = pd.read_csv('../data/mimic_3_all_notes_with_20_occurrences_of_primary_condition.csv')

  interactivity=interactivity, compiler=compiler, result=result)


In [5]:
def split_texts(df: pd.DataFrame, out_dir: str='.'):
    """Write train/val/test text files, split at the admission level.

    Rows of ``df`` are grouped by ``hadm_id`` so that every note of one
    admission lands in the same split. Admissions are taken in the order
    ``groupby`` yields them (no shuffling): the first ``round(0.8 * n)``
    admissions form the training set and the remainder is halved into
    validation and test (test receives the extra admission when the
    remainder is odd).

    Args:
        df: Frame with at least a ``hadm_id`` and a ``text`` column.
        out_dir: Directory receiving ``train_file.txt``, ``val_file.txt`` and
            ``test_file.txt``. It is created if it does not exist yet.

    Note:
        Texts are written back to back; no separator is inserted between
        consecutive notes.
    """
    import os

    # hadm_id -> Series holding the note texts of that admission
    texts = {hadm_id: df_adm.text for hadm_id, df_adm in df.groupby('hadm_id')}
    hadms = list(texts)

    # Slice the ordered key list instead of repeated `k not in list` checks,
    # which made the original partition quadratic in the number of admissions.
    train_set_len = round(len(hadms) * 0.8)
    othr_hadms = hadms[train_set_len:]
    val_len = len(othr_hadms) // 2
    splits = (
        ('train', hadms[:train_set_len]),
        ('val', othr_hadms[:val_len]),
        ('test', othr_hadms[val_len:]),
    )

    # Create the target directory up front so a fresh checkout does not fail
    # with FileNotFoundError on the first open().
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for split_name, split_hadms in splits:
        with open(f'{out_dir}/{split_name}_file.txt', 'w') as f:
            for hadm_id in split_hadms:
                for t in texts[hadm_id]:
                    f.write(t)

In [6]:
# Write the train/val/test text files for the full MIMIC note set.
split_texts(df_all, out_dir='lm_data_mimic')

In [34]:
# NOTE(review): `df` is not assigned in any visible cell — presumably it is the
# all-notes frame (`df_all`); confirm so this survives Restart & Run All.
# Builds a combined "category:description" label per note.
df['cat_desc'] = df.apply(lambda r: f'{r.category}:{r.description}', axis=1)

In [44]:
# Corpus summary: row count, mean `len(text)` per note, and the number of
# distinct `cat_desc` labels.
# NOTE(review): relies on `df`, which no visible cell assigns — TODO confirm.
print(f'# Docs: {df.shape[0]}')
print(f'# Avg Len:{np.average(df.text.apply(len))}')
print(f'# Doc Types:{len(df["cat_desc"].unique())}')

# Docs: 1172433
# Avg Len:2201.018789133366
# Doc Types:3127


In [45]:
# Same combined "category:description" label, for the stroke-only notes.
# Adds the column to `df_stroke` in place.
df_stroke['cat_desc'] = df_stroke.apply(lambda r: f'{r.category}:{r.description}', axis=1)

In [47]:
# Stroke-subset summary: row count, mean `len(text)` per note, and the number
# of distinct `cat_desc` labels.
print(f'# Docs: {df_stroke.shape[0]}')
print(f'# Avg Len:{np.average(df_stroke.text.apply(len))}')
print(f'# Doc Types:{len(df_stroke["cat_desc"].unique())}')

# Docs: 8213
# Avg Len:2231.6700353098745
# Doc Types:241


In [None]:
# Write the train/val/test text files for the stroke-only note set.
split_texts(df_stroke, out_dir='lm_data_mimic_stroke')

## Fine-Tuning of GPT-2
The training script `run_clm.py` is run via `tain_lm.sh`.
- Models are loaded from the final checkpoint saved by the training run.

In [None]:
def compute_ppl(model: str, test_data:List[str], stride_div=2):
    """Estimate the perplexity of a causal LM on ``test_data`` with a sliding window.

    The texts are joined with newlines, tokenized once with the ``'gpt2'``
    tokenizer, and scored in windows of at most ``model.config.n_positions``
    tokens. Each step advances by ``stride`` tokens; tokens before the current
    step only serve as context (their labels are set to -100) and do not
    contribute to the loss.

    Args:
        model: Name or path handed to ``AutoModelForCausalLM.from_pretrained``.
            The tokenizer is always ``'gpt2'``, so the model is presumably
            expected to share the GPT-2 vocabulary — confirm for other models.
        test_data: Texts to evaluate.
        stride_div: The stride is ``int(n_positions / stride_div)``; larger
            values give every scored token more left context at the cost of
            more forward passes.

    Returns:
        ``exp(sum(window_loss * scored_tokens) / total_tokens)`` as a torch
        tensor. ``outputs[0]`` is treated as the mean loss over the unmasked
        targets of a window (Hugging Face causal-LM convention — TODO confirm
        for other model classes).

    Raises:
        ValueError: If ``test_data`` tokenizes to nothing, or if the resulting
            stride is smaller than 1 or larger than the context length (which
            would leave tokens unscored).
    """
    # Fall back to CPU so the evaluation also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    lm = AutoModelForCausalLM.from_pretrained(model).to(device)
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    max_length = lm.config.n_positions

    encodings = tokenizer('\n'.join(test_data), return_tensors='pt')
    n_tokens = encodings.input_ids.size(1)
    if n_tokens == 0:
        raise ValueError('test_data produced no tokens; cannot compute perplexity')

    # the amount to window the input by
    stride = int(max_length / stride_div)
    if not 1 <= stride <= max_length:
        raise ValueError(
            f'stride_div={stride_div} gives stride {stride}; '
            f'it must lie in [1, {max_length}]'
        )

    lls = []
    for i in tqdm(range(0, n_tokens, stride)):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, n_tokens)
        trg_len = end_loc - i    # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Mask the context-only prefix so only the last trg_len tokens are scored.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = lm(input_ids, labels=target_ids)
            # Undo the per-window averaging so windows are weighted by size.
            log_likelihood = outputs[0] * trg_len

        lls.append(log_likelihood)

    # After the final window end_loc equals n_tokens, so this matches the
    # original normalisation.
    ppl = torch.exp(torch.stack(lls).sum() / n_tokens)
    return ppl

In [None]:
# WikiText-2 (raw) test split, fetched with `datasets.load_dataset`;
# `test_data` is its `text` column and is the input to the compute_ppl calls below.
test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
test_data = test['text']

In [None]:
# Baseline: the stock 'gpt2' checkpoint on the WikiText-2 test texts.
compute_ppl('gpt2', test_data)

In [None]:
# Model loaded from 'lm_outputs_mimic_stroke' (presumably the stroke-notes
# fine-tuning output directory — confirm), scored on the same WikiText-2 texts.
compute_ppl('lm_outputs_mimic_stroke', test_data)

In [None]:
# Model loaded from 'lm_outputs_mimic' (presumably the all-notes fine-tuning
# output directory — confirm), scored on the same WikiText-2 texts.
compute_ppl('lm_outputs_mimic', test_data)

In [27]:
def compute_ppl(encodings, stride_div=1):
    """Sliding-window perplexity of the notebook-level ``model`` over ``encodings``.

    NOTE: this redefines ``compute_ppl`` with a different signature than the
    earlier ``compute_ppl(model, test_data, stride_div)`` and reads ``model``,
    ``device`` and ``max_length`` from the enclosing notebook namespace.

    Args:
        encodings: Tokenizer output exposing ``input_ids``; indexed as
            ``input_ids[:, a:b]``, so a 2-D tensor is expected.
        stride_div: The window advances by ``int(max_length / stride_div)``
            tokens per step.

    Returns:
        ``exp(sum(window_loss * scored_tokens) / total_tokens)`` as a torch tensor.
    """
    token_ids = encodings.input_ids
    n_tokens = token_ids.size(1)
    step = int(max_length / stride_div)

    window_nlls = []
    with torch.no_grad():
        for start in tqdm(range(0, n_tokens, step)):
            stop = min(start + step, n_tokens)
            ctx_start = max(start + step - max_length, 0)
            n_scored = stop - start  # shorter than `step` on the final window

            window = token_ids[:, ctx_start:stop].to(device)
            labels = window.clone()
            # Everything except the last n_scored tokens is context only.
            labels[:, :-n_scored] = -100

            # outputs[0] is scaled by the scored-token count before summing.
            window_nlls.append(model(window, labels=labels)[0] * n_scored)

    # The last window always stops at n_tokens, so this is the same divisor
    # the loop variable would hold after the final iteration.
    return torch.exp(torch.stack(window_nlls).sum() / n_tokens)

In [28]:
# NOTE(review): `encodings` is first assigned in a later cell of this notebook;
# in top-to-bottom order it is undefined here — TODO confirm execution order.
ppl = compute_ppl(encodings)

100%|██████████| 281/281 [00:10<00:00, 27.41it/s]


In [46]:
# Number of trainable parameters (elements of all tensors with requires_grad)
# in the notebook-level `model`.
sum(p.numel() for p in model.parameters() if p.requires_grad)

124439808

In [29]:
# log2 of the perplexity, i.e. the cross-entropy in bits (PPL = 2^H).
# The tensor is moved to CPU and converted to NumPy before taking the log.
cross_entropy = np.log2(ppl.detach().cpu().numpy())
cross_entropy

6.760457

In [34]:
# Perplexity / cross-entropy (bits) on a held-out text file.
# The file is opened in a `with` block so the handle is closed after reading.
# Iterating the file yields lines that keep their trailing newline, and the
# join adds two more between consecutive lines.
# NOTE(review): relies on a notebook-level `tokenizer`, and on the path
# 'lm_data/test_file.txt', which differs from the out_dirs written above —
# TODO confirm both.
with open('lm_data/test_file.txt') as f:
    encodings = tokenizer('\n\n'.join(f), return_tensors='pt')
ppl = compute_ppl(encodings)
cross_entropy = np.log2(ppl.detach().cpu().numpy())
cross_entropy

Token indices sequence length is longer than the specified maximum sequence length for this model (789575 > 1024). Running this sequence through the model will result in indexing errors
100%|██████████| 772/772 [00:28<00:00, 27.43it/s]


4.5310745

In [16]:
# GPT-2-large: ppl: 16.44
# model card on huggingface: huggingface.co/gpt2

In [79]:
# Shannon bound at the (sub)word level: log2(|V|) is the entropy of a uniform
# distribution over the vocabulary, i.e. the theoretical maximum of H(P) in bits.
np.log2(model.config.vocab_size)

15.617036934287741

In [None]:
# cross-entropy loss, i.e. H(P, Q) = H(P) + D_{KL}(P||Q), divergence >= 0, so H(P, Q) >= H(P)
# PPL ==2^{H(P,Q)}

In [None]:
# gpt perplexity on clinical text, non-fine-tuned

In [73]:
# Tokenize the clinical test texts as one long sequence, joined by blank lines.
# NOTE(review): `test_texts` is only a local inside `split_texts` in the visible
# cells; no notebook-level assignment is shown — TODO confirm where it comes from.
encodings = tokenizer('\n\n'.join(test_texts), return_tensors='pt')

Token indices sequence length is longer than the specified maximum sequence length for this model (1224234 > 1024). Running this sequence through the model will result in indexing errors


In [76]:
# Perplexity of the notebook-level `model` on the clinical-text encodings built
# in the previous cell.
ppl = compute_ppl(encodings)

100%|██████████| 1196/1196 [03:55<00:00,  5.07it/s]
