## Preprocessing i2b2 Datasets

We use the i2b2 challenge dataset from 2008 (Uzuner, 2009) to evaluate LLMs on the comorbidity detection task, based on de-identified discharge summaries.

Preprocessing code was adapted from Arroyo et al. (2024): https://github.com/alceballosa/clin-robust/tree/master/preprocessing_notebooks.

Before running the pipeline here, download and unzip the raw i2b2 dataset from the [Harvard DBMI portal](https://portal.dbmi.hms.harvard.edu/projects/n2c2-nlp/).

In [1]:
import os
import os.path as osp
import re
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
import xml.etree.ElementTree as ET
from collections import OrderedDict
import warnings
warnings.filterwarnings('ignore')

from datasets import Dataset, DatasetDict, concatenate_datasets
import huggingface_hub

from transformers import AutoTokenizer

# Directory passed as `cache_dir` when loading the Hugging Face tokenizer
HF_CACHE_DIR = '/data/hf_models' # Modify as needed
# Root data directory: the raw n2c2 files are read from below it and the processed datasets are saved into it
DATA_DIR = '/data' # Modify as needed
# Presumably where the raw n2c2 download was unzipped -- adjust if your layout differs
N2C2_DIR = osp.join(DATA_DIR, 'n2c2')
# Folder containing the 2008 Obesity Challenge XML record and annotation files
OBESITY_DIR = osp.join(N2C2_DIR, '2008 Obesity Challenge')

In [None]:
# Utility function for removing excessive whitespace
def transform_string(s):
    '''Normalizes the whitespace of a string.

    Every run of whitespace that contains at least one newline is replaced by a
    single newline, every remaining run of two or more whitespace characters is
    replaced by a single space, and leading/trailing whitespace is stripped.

    Args:
        s: The string to clean up.

    Returns:
        The whitespace-normalized string.
    '''
    # Consume the whitespace on BOTH sides of a line break in one match. Stripping
    # only one side (e.g., 'a \n b' -> 'a\n b') would leave a newline-plus-space run
    # that the space-collapsing step below would then turn into a plain space,
    # silently dropping the line break.
    s = re.sub(r'\s*\n\s*', '\n', s)
    # Collapse horizontal whitespace only ('[^\S\n]' = whitespace except newline),
    # so the line breaks kept above are never merged away.
    s = re.sub(r'[^\S\n]{2,}', ' ', s)
    return s.strip()

# Optional: Load a tokenizer of interest to filter out discharge notes longer than max token length
# The token is read from the HF_TOKEN environment variable so that it does not have to be
# written into (and accidentally shared with) the notebook. You can still paste it here instead.
hf_api_token = os.environ.get('HF_TOKEN', '') # Fill in with your own token
# An empty string is not a usable token; passing None lets login() ask for one instead.
huggingface_hub.login(token=hf_api_token or None)

# NOTE: We use Llama-2 to filter out long clinical notes, as the Llama-2 tokenizer has the smallest vocabulary size among the LLMs we evaluate.
model_id = 'meta-llama/Llama-2-7b-hf'
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)

def longer_than_max_length(note, tokenizer=tokenizer, max_length=3000, verbose=False):
    '''Checks if a given note is longer than the max token length.

    Args:
        note: Text of a single clinical note.
        tokenizer: Tokenizer that is called on the note and returns a mapping with
            an 'input_ids' entry. Defaults to the tokenizer loaded above (bound
            when this function is defined).
        max_length: Maximum number of tokens a note may have.
        verbose: If True, prints the token length of the note.

    Returns:
        True if the note has strictly more than max_length tokens, False otherwise.
    '''

    # Only the number of tokens is needed, so take the length of the plain list of
    # ids instead of requesting PyTorch tensors (which would require torch).
    token_length = len(tokenizer(note)['input_ids'])

    if verbose:
        print(f'Length: {token_length}')

    return token_length > max_length

def filter_by_len(data_list, max_length=3000):
    '''Drops every sample whose clinical note exceeds max_length tokens.

    Args:
        data_list: List of samples, each a dict with the note under 'text'.
        max_length: Maximum number of tokens a note may have to be kept.

    Returns:
        A new list with the kept samples, in their original order. The number
        of included and excluded samples is printed as a side effect.
    '''

    kept = [
        sample for sample in data_list
        if not longer_than_max_length(sample['text'], max_length=max_length)
    ]
    n_excluded = len(data_list) - len(kept)
    print(f'Included: {len(kept)}, Excluded: {n_excluded}')

    return kept

Following Arroyo et al. (2024), we focus on 4 prediction tasks (asthma, CAD, diabetes, and obesity), using the "intuitive" annotations on all clinical notes. See https://www.i2b2.org/NLP/Obesity/Documentation.php for more details on how the labels are defined.

In [229]:
# Clinical notes: the training records are split across two XML files
train_record_file = osp.join(OBESITY_DIR, 'obesity_patient_records_training.xml')
train_tree = ET.parse(train_record_file)
train_root = train_tree.getroot()

train2_record_file = osp.join(OBESITY_DIR, 'obesity_patient_records_training2.xml')
train2_tree = ET.parse(train2_record_file)
train2_root = train2_tree.getroot()

# Merge the two training document files
for child in train2_root:
    train_root.append(child)

# Labels ("intuitive")
train_annotation_file = osp.join(OBESITY_DIR, 'obesity_standoff_intuitive_annotations_training.xml')
train_annotation_tree = ET.parse(train_annotation_file)
train_annotation_root = train_annotation_tree.getroot()

# Collect one {"id", "text"} record per text element of every document
train_set = []
for docs in train_root:
    for doc in docs:
        doc_id = doc.attrib["id"]
        for text in doc:
            # NOTE: The raw text is kept on purpose; normalizing the whitespace with
            # transform_string() seemed to work worse.
            doc_text = text.text
            train_set.append({"id": doc_id, "text": doc_text})

# Index the notes by document id once, so that attaching a label is a dictionary
# lookup instead of a scan over all of train_set for every annotated document.
# A list per id preserves the behavior of matching every note that shares an id.
id_to_train_samples = {}
for sample in train_set:
    id_to_train_samples.setdefault(sample['id'], []).append(sample)

target_diseases = ['Asthma', 'CAD', 'Diabetes', 'Obesity']
disease_to_train_set = {}

# Optional: Set to True to drop notes longer than the max token length (False keeps all samples)
apply_filter = False

for diseaseset in train_annotation_root:
    for disease in diseaseset:
        disease_train_set = []
        disease_name = disease.attrib["name"]

        if disease_name not in target_diseases:
            continue

        pbar = tqdm(disease, desc=f'Fetching labels for "{disease_name}"')
        for doc in pbar:
            doc_id = doc.attrib['id']
            doc_judgment = doc.attrib['judgment'] # Y/N/Q

            # Add label to matching clinical note(s); the note dict itself is not
            # modified, a labeled copy is stored in the per-disease list.
            for sample in id_to_train_samples.get(doc_id, ()):
                disease_train_set.append(sample | {'label': doc_judgment})

        if apply_filter:
            disease_train_set = filter_by_len(disease_train_set)

        disease_to_train_set[disease_name.lower()] = Dataset.from_list(disease_train_set)

Fetching labels for "Asthma": 100%|██████████| 572/572 [00:00<00:00, 15315.50it/s]
Fetching labels for "CAD": 100%|██████████| 552/552 [00:00<00:00, 15436.89it/s]
Fetching labels for "Diabetes": 100%|██████████| 572/572 [00:00<00:00, 14832.87it/s]
Fetching labels for "Obesity": 100%|██████████| 554/554 [00:00<00:00, 15423.72it/s]


In [230]:
# Clinical notes of the test split
test_record_file = osp.join(OBESITY_DIR, 'obesity_patient_records_test.xml')
test_tree = ET.parse(test_record_file)
test_root = test_tree.getroot()

# Labels ("intuitive")
test_annotation_file = osp.join(OBESITY_DIR, 'obesity_standoff_annotations_test_intuitive.xml')
test_annotation_tree = ET.parse(test_annotation_file)
test_annotation_root = test_annotation_tree.getroot()

# Collect one {"id", "text"} record per text element of every document
test_set = []
for docs in test_root:
    for doc in docs:
        doc_id = doc.attrib["id"]
        for text in doc:
            doc_text = text.text
            test_set.append({"id": doc_id, "text": doc_text})

# Index the notes by document id once, so that attaching a label is a dictionary
# lookup instead of a scan over all of test_set for every annotated document.
# A list per id preserves the behavior of matching every note that shares an id.
id_to_test_samples = {}
for sample in test_set:
    id_to_test_samples.setdefault(sample['id'], []).append(sample)

target_diseases = ['Asthma', 'CAD', 'Diabetes', 'Obesity']
disease_to_test_set = {}

# Optional: Set to False to keep all samples
apply_filter = True

for diseaseset in test_annotation_root:
    for disease in diseaseset:
        disease_test_set = []
        disease_name = disease.attrib["name"]

        if disease_name not in target_diseases:
            continue

        pbar = tqdm(disease, desc=f'Fetching labels for "{disease_name}"')
        for doc in pbar:
            doc_id = doc.attrib["id"]
            doc_judgment = doc.attrib["judgment"] # Y/N/Q

            # Add label to matching clinical note(s); the note dict itself is not
            # modified, a labeled copy is stored in the per-disease list.
            for sample in id_to_test_samples.get(doc_id, ()):
                disease_test_set.append(sample | {'label': doc_judgment})

        if apply_filter:
            disease_test_set = filter_by_len(disease_test_set)

        disease_to_test_set[disease_name.lower()] = Dataset.from_list(disease_test_set)

Fetching labels for "Asthma": 100%|██████████| 471/471 [00:00<00:00, 23657.76it/s]




Included: 357, Excluded: 114


Fetching labels for "CAD": 100%|██████████| 458/458 [00:00<00:00, 23768.17it/s]


Included: 335, Excluded: 123


Fetching labels for "Diabetes": 100%|██████████| 479/479 [00:00<00:00, 24047.49it/s]


Included: 358, Excluded: 121


Fetching labels for "Obesity": 100%|██████████| 447/447 [00:00<00:00, 23994.14it/s]


Included: 334, Excluded: 113


In [None]:
# Save one DatasetDict (train/test splits) per target disease into its own directory under DATA_DIR
for disease_name in target_diseases:
    split_key = disease_name.lower()  # The per-disease datasets are keyed by the lowercased name

    save_dir = osp.join(DATA_DIR, f'n2c2_2008-obesity_{disease_name.lower()}')
    os.makedirs(save_dir, exist_ok=True)

    print(f'Saving comorbidity dataset for "{disease_name.lower()}"...')
    disease_dataset = DatasetDict({
        'train': disease_to_train_set[split_key],
        'test': disease_to_test_set[split_key],
    })
    disease_dataset.save_to_disk(save_dir)

In [233]:
from collections import Counter

# Class label counts (per disease, on the test split)
for disease_name in target_diseases:
    test_split = disease_to_test_set[disease_name.lower()]
    counts = Counter(test_split['label'])
    print(f'{disease_name}: {counts}')

Asthma: Counter({'N': 309, 'Y': 48})
CAD: Counter({'Y': 192, 'N': 142, 'Q': 1})
Diabetes: Counter({'Y': 227, 'N': 131})
Obesity: Counter({'N': 191, 'Y': 143})
