In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
import re, os

from nltk import word_tokenize
from string import punctuation

import spacy

nlp = spacy.load("en_core_web_sm")

In [3]:
from utils import read_ann, read_report, format_doc_and_labels

In [4]:
# Root folder of the dataset (a relative path); the report and annotation paths are built from it.
DATA_DIR = 'chifir-cytology-and-histopathology-invasive-fungal-infection-reports-1.0.0/'

# 1. Digest reports and annotations

In [5]:
# Folders holding the annotation (.ann) files and the report (.txt) files.
root_ann_dir = os.path.join(DATA_DIR, 'annotations/')
root_txt_dir = os.path.join(DATA_DIR, 'reports/')

In [6]:
# Collect the entities of every annotation file to check which entity types exist.

all_ann_files = [f for f in os.listdir(root_ann_dir) if f.endswith('.ann')]

all_ents = []
for ann_f in all_ann_files:
    try:
        ents = read_ann(os.path.join(root_ann_dir, ann_f))
    except Exception:
        # Report the unreadable file and skip it. Without the `continue`, the
        # `extend` below would re-add the previous file's entities (or hit an
        # undefined `ents` if the very first file fails).
        print(ann_f)
        continue
    all_ents.extend(ents)
len(all_ents)

1155

In [7]:
# Each entity is a 4-tuple; its last field is the entity type.
ent_types = [label for (_f0, _f1, _f2, label) in all_ents]

set(ent_types)

{'ClinicalQuery',
 'FungalDescriptor',
 'Fungus',
 'Invasiveness',
 'SampleType',
 'Stain',
 'equivocal',
 'negative',
 'positive'}

In [8]:
# combine report and ann files for every pt to get the right format for bert
# os.path.splitext drops only the '.ann' extension; str.strip('.ann') would instead
# remove any of the characters '.', 'a', 'n' from BOTH ends of the file name.
all_pt = [os.path.splitext(f)[0] for f in os.listdir(root_ann_dir) if f.endswith('.ann')]

all_data = []

for pt in all_pt:
    note = read_report(os.path.join(DATA_DIR, f'reports/{pt}.txt'))
    ents = read_ann(os.path.join(DATA_DIR, f'annotations/{pt}.ann'))

    doc = nlp(note)

    try:
        lines, tags = format_doc_and_labels(doc, ents, remove_shorts=3)

        # every token sequence must come with a tag sequence
        assert len(lines) == len(tags)
        for i, (token, tag) in enumerate(zip(lines, tags)):
            all_data.append({
                'id': f'{pt}_{i}',
                'tokens': token,
                'tags': tag
            })

    except Exception:
        # Best effort: print the id of the report that could not be formatted and
        # move on. `Exception` (unlike a bare `except:`) still lets
        # KeyboardInterrupt / SystemExit stop the loop.
        print(pt)


pt57_r1


In [9]:
# number of distinct BIO tags that occur in the formatted data
unique_tags = set()
for example in all_data:
    unique_tags.update(example['tags'])
all_tags = sorted(unique_tags)
len(all_tags)

18

In [10]:
# pt57_r1 is a special case with a data error starting from T8
# (offsets are off by 1 char; the T8 string needs fixing)
from utils import process_note_ent_pair

pt = 'pt57_r1'

report_path = os.path.join(DATA_DIR, f'reports/{pt}.txt')
ann_path = os.path.join(DATA_DIR, f'annotations/{pt}.ann')

note = read_report(report_path)
ents = read_ann(ann_path)

d = process_note_ent_pair(note, ents, pt)

notseen not see
927 934
** pt57_r1 **
GMS  GM
938 941
** pt57_r1 **
Fungi  Fung
1280 1285
** pt57_r1 **
Pneumocystis /Pneumocysti
1286 1298
** pt57_r1 **
not seen  not see
1303 1311
** pt57_r1 **
GMS  GM
1315 1318
** pt57_r1 **


In [11]:
# Compare the slice shifted one character to the right with the slice at offsets 1315-1318.
note[1316:1319], note[1315:1318]

('GMS', ' GM')

In [12]:
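# pt57_r1 was fixed by hand; the corrected annotation lives in the new folder `annotations-fixed`
# (the test split is read from that folder when the datasets are built below)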

# 2. Get train/val/test sets

In [13]:
# Entity types, listed in the order in which their BIO tags are numbered.
_ENTITY_TYPES = [
    'ClinicalQuery',
    'FungalDescriptor',
    'Fungus',
    'Invasiveness',
    'SampleType',
    'Stain',
    'equivocal',
    'negative',
    'positive',
]
# Tags left out of the label set: no I-equivocal was found in the data.
_UNUSED_TAGS = {'I-equivocal'}

# B-/I- tag per entity type (minus the unused ones), followed by the outside tag.
TAGS = [
    f'{prefix}-{ent_type}'
    for ent_type in _ENTITY_TYPES
    for prefix in ('B', 'I')
    if f'{prefix}-{ent_type}' not in _UNUSED_TAGS
]
TAGS.append('O')

# tag name -> integer id (position in TAGS)
tag2id = dict(zip(TAGS, range(len(TAGS))))
len(tag2id)

18

In [14]:
import datasets
from datasets import load_dataset, load_from_disk
from datasets import Features, Sequence, Value, ClassLabel

In [15]:
# Report ids = annotation file names without the '.ann' extension.
# (str.strip('.ann') would remove any of the characters '.', 'a', 'n' from both
# ends of the name instead of the suffix, hence os.path.splitext.)
all_pt = [os.path.splitext(f)[0] for f in os.listdir(root_ann_dir) if f.endswith('.ann')]

def encode_tags(example):
    """Store under 'ner_tags' the `tag2id` id of every tag in example['tags'].

    The example is modified in place and returned. A tag that is missing from
    `tag2id` raises KeyError.
    """
    tag_ids = []
    for tag_name in example['tags']:
        tag_ids.append(tag2id[tag_name])
    example['ner_tags'] = tag_ids
    return example
    
def get_pt_data(pts, reports_path=root_txt_dir, annotations_path=root_ann_dir, remove_shorts=3):
    """Build a `datasets.Dataset` of token / tag sequences for the given reports.

    For every report id the text (`{pt}.txt`) and its annotations (`{pt}.ann`) are
    read, the text is run through the spaCy pipeline `nlp`, and the result is
    passed to `format_doc_and_labels`. Every (tokens, tags) pair it returns
    becomes one example with the id `{pt}_{i}`.

    Args:
        pts: iterable of report ids (file names without extension), e.g. 'pt57_r1'.
        reports_path: folder containing the `.txt` report files.
        annotations_path: folder containing the `.ann` annotation files.
        remove_shorts: forwarded unchanged to `format_doc_and_labels`
            (defaults to 3, the value that used to be hard-coded here).

    Returns:
        A `datasets.Dataset` with the columns 'id' (string), 'tokens'
        (sequence of strings) and 'ner_tags' (sequence of `ClassLabel` over `TAGS`).

    Reports whose formatting fails are printed and skipped. Errors while reading
    the report or annotation file are not caught and propagate to the caller.
    """
    all_data = []

    for pt in pts:
        note = read_report(os.path.join(reports_path, f'{pt}.txt'))
        ents = read_ann(os.path.join(annotations_path, f'{pt}.ann'))

        doc = nlp(note)

        try:
            lines, tags = format_doc_and_labels(doc, ents, remove_shorts=remove_shorts)

            # every token sequence must come with a tag sequence
            assert len(lines) == len(tags)
            for i, (token, tag) in enumerate(zip(lines, tags)):
                all_data.append({
                    'id': f'{pt}_{i}',
                    'tokens': token,
                    'ner_tags': tag
                })

        except Exception:
            # Best effort: print the id of the report that could not be formatted
            # and continue. `Exception` (unlike a bare `except:`) still lets
            # KeyboardInterrupt / SystemExit through.
            print(pt)

    # The tag names are stored as-is: the ClassLabel feature maps them to the
    # integer ids given by their position in TAGS, so no separate encoding pass
    # is required.
    features = Features({
        'id': Value('string'),
        'tokens': Sequence(Value('string')),
        'ner_tags': Sequence(ClassLabel(names=TAGS))
    })

    dset = datasets.Dataset.from_list(all_data, features=features)

    return dset

In [16]:
# per-report metadata table (patient id, report number, dataset split, validation fold, ...)
metadata_path = './chifir_metadata.csv'
mdf = pd.read_csv(metadata_path)
mdf.head()

Unnamed: 0,histopathology_id,patient_id,report_no,y_report,dataset,val_fold
0,658,13,1,Positive,development,10.0
1,189,14,1,Positive,development,7.0
2,529,28,1,Negative,development,8.0
3,325,28,2,Positive,development,8.0
4,559,28,3,Negative,development,8.0


In [17]:
# every metadata row must correspond to a report/annotation id found on disk
for patient_id, report_no in zip(mdf['patient_id'], mdf['report_no']):
    report_id = f'pt{patient_id}_r{report_no}'
    assert report_id in all_pt

In [18]:
# get data for three splits

def _report_ids(rows):
    """Return the 'pt{patient_id}_r{report_no}' id of every row of `rows`, in row order."""
    return [f"pt{row['patient_id']}_r{row['report_no']}" for _, row in rows.iterrows()]

# use val fold 1, 2 as validation to tune hparam
is_development = mdf.dataset == 'development'
in_val_fold = mdf.val_fold.isin([1.,2.])

train_pt = _report_ids(mdf[is_development & ~in_val_fold])
val_pt = _report_ids(mdf[is_development & in_val_fold])
test_pt = _report_ids(mdf[mdf.dataset == 'test'])

# the three splits together must account for every report on disk
assert len(train_pt) + len(val_pt) + len(test_pt) == len(all_pt)


In [19]:
# the report with the annotation error (pt57_r1) appears in the test set
'pt57_r1' in test_pt

True

In [20]:
# train / validation use the default annotation folder; the test split is read
# from the folder with the manually fixed annotations
fixed_ann_dir = os.path.join(DATA_DIR, 'annotations-fixed/')

ds_train = get_pt_data(train_pt)
ds_val = get_pt_data(val_pt)
ds_test = get_pt_data(test_pt, annotations_path=fixed_ann_dir)


In [21]:
# bundle the three splits under their split names and write them to disk
splits = {
    'train': ds_train,
    'validation': ds_val,
    'test': ds_test,
}
raw_datasets = datasets.DatasetDict(splits)

raw_datasets.save_to_disk('chifir_hf')

Saving the dataset (0/1 shards):   0%|          | 0/3346 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/819 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1049 [00:00<?, ? examples/s]