In [1]:
import json 
import re
import pandas as pd
import numpy as np
import seaborn as sns

from typing import Dict, Optional, Tuple, Iterable, List
from matplotlib import pyplot as plt
from medcat.cat import CAT
from medcat.cdb import CDB
from medcat.config import Config
from medcat.vocab import Vocab
from medcat.meta_cat import MetaCAT
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.preprocessing.tokenizers import TokenizerWrapperBPE, TokenizerWrapperBERT
from tokenizers import ByteLevelBPETokenizer
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase

  from tqdm.autonotebook import tqdm, trange


In [2]:
# Directory containing the synthetic MiADE problem CSVs (path is relative to the notebook's working directory).
DATA_DIR = "../problems/metacat_problems_07mar2023"

## Preprocess data
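The function below converts the CSV data provided in ```miade-dataset``` into a dataframe with columns ```[text, cui, name, start, end, <category_name>...]```, where ```start``` and ```end``` are the character offsets of the annotated concept within the cleaned text.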

In [9]:
def _locate_first_concept(markers):
    """Return ``(start, end)`` of the first annotated concept, or ``None``.

    ``markers`` are the regex matches for the annotation markers of one text,
    in order of appearance; group 1 is set for an opening ``{<prefix> `` marker
    and unset for a closing ``}``. The returned offsets refer to the text
    *after* every marker has been stripped, so the length of each marker that
    precedes a boundary is subtracted from that boundary's raw position.
    """
    removed = 0
    start = None
    for marker in markers:
        is_opening = marker.group(1) is not None
        if start is None:
            if is_opening:
                start = marker.start() - removed
        elif not is_opening:
            return start, marker.start() - removed
        removed += len(marker.group(0))
    return None


def preprocess_miade_synthetic_data(data_path, lower_case=True, prefix="p"):
    """Load a MiADE synthetic CSV and locate the annotated concept in each text.

    The CSV is expected to contain a ``text`` column in which the concept is
    wrapped as ``{<prefix> concept}``, and a ``<prefix>`` column of the form
    ``"<cui> | <name>"``. For ``prefix == "p"`` the label columns
    ``p_meta_relevance``, ``p_meta_confirmed`` and ``p_meta_laterality`` are
    required as well.

    Args:
        data_path: Path (or buffer) passed straight to ``pd.read_csv``.
        lower_case: If True, missing values become ``""`` and every column is
            converted to a lower-cased string. If False, only the three label
            columns are lower-cased (and only when ``prefix == "p"``).
        prefix: Annotation prefix; it names both the ``{<prefix> ...}`` marker
            and the concept column. Any string is accepted.

    Returns:
        A DataFrame with the markers stripped from ``text`` and with ``cui``,
        ``name``, ``start`` and ``end`` columns added (``start``/``end`` are
        character offsets into the cleaned text). For ``prefix == "p"`` the
        columns are reduced to ``[text, cui, name, start, end, relevance,
        presence, laterality (generic)]``.

    Raises:
        ValueError: If a text holds no complete ``{<prefix> ...}`` annotation.
    """
    label_columns = {
        "p_meta_relevance": "relevance",
        "p_meta_confirmed": "presence",
        "p_meta_laterality": "laterality (generic)",
    }

    data = pd.read_csv(data_path)

    if lower_case:
        data = data.fillna('').astype(str).apply(lambda col: col.str.lower())  # all lower case
    elif prefix == "p":
        # Lower-case the label values element-wise; calling str() on the whole
        # column would stringify the Series repr instead of its values.
        raw_labels = list(label_columns)
        data[raw_labels] = data[raw_labels].apply(
            lambda col: col.fillna('').astype(str).str.lower()
        )

    # extract cui and name in separate columns
    concept = data[prefix].str.extract(r"^(\d+)\s+\|\s+(.+)$")
    data["cui"] = concept[0]
    # remove words inside brackets e.g.(disease)
    data["name"] = concept[1].replace(r"\s*\([^)]*\)", "", regex=True)
    data = data.drop(columns=prefix)

    # Opening "{<prefix> " is captured in group 1; a closing "}" leaves it unset.
    marker_re = re.compile(r"(\{" + re.escape(prefix) + r" )|\}")

    starts = []
    ends = []
    cleaned_texts = []
    for row, text in enumerate(data["text"].fillna('').astype(str)):
        markers = list(marker_re.finditer(text))
        if len(markers) > 2:
            print("More than one concept in text, only processing first one")

        span = _locate_first_concept(markers)
        if span is None:
            raise ValueError(
                f"Row {row}: no complete '{{{prefix} ...}}' annotation found in text {text!r}"
            )

        starts.append(span[0])
        ends.append(span[1])
        # remove annotation brackets
        cleaned_texts.append(marker_re.sub("", text))

    data["text"] = cleaned_texts
    data["start"] = starts
    data["end"] = ends

    if prefix == "p":
        # convert labels -- restricted to the label columns so that a text or
        # name that happens to equal a label string is never rewritten
        raw_labels = list(label_columns)
        data[raw_labels] = (
            data[raw_labels]
            .replace("no laterality", "none")
            .replace("positive", "present")
        )
        # tidy up columns
        data = data.rename(columns=label_columns)
        data = data[["text", "cui", "name", "start", "end", "relevance", "presence", "laterality (generic)"]]

    return data

In [10]:
# Preprocess the "positive" problems CSV using the function's defaults (lower_case=True, prefix="p").
df = preprocess_miade_synthetic_data(DATA_DIR + "/metacat_problems_positive.csv")

In [11]:
# Display the preprocessed frame.
df

Unnamed: 0,text,cui,name,start,end,relevance,presence,laterality (generic)
0,arenaviral hemorrhagic fever injury of ulnar v...,24272004,injury of ulnar vein,29,49,present,confirmed,none
1,dieffenbachia species poisoning fluocinolone a...,293163000,fluocinolone adverse reaction,32,61,present,confirmed,none
2,infection - perineal wound sarcoid myopathy ch...,193251003,sarcoid myopathy,27,43,present,confirmed,none
3,ulcer of anus poisoning caused by flucytosine ...,66491001,poisoning caused by flucytosine,14,45,present,confirmed,none
4,eosinophilic duodenal ulcer eruption cyst of j...,42323001,eruption cyst of jaw,28,48,present,confirmed,none
...,...,...,...,...,...,...,...,...
10995,problem list optic chiasm disorder,70476006,optic chiasm disorder,13,34,present,confirmed,none
10996,leiomyosarcoma of colon stab wound of abdomen ...,283475002,stab wound of abdomen,24,45,present,confirmed,none
10997,sclerosing hemangioma of lung neuritis of saph...,429668008,neuritis of saphenous nerve,30,57,present,confirmed,none
10998,accidental clarithromycin overdose chondrocalc...,15705241000119106,chondrocalcinosis of bilateral shoulders,35,75,present,confirmed,bilateral


In [12]:
# Mean length, in characters, of the cleaned texts.
df['text'].str.len().mean()

74.931

In [13]:
# sanity checks: slicing each cleaned text with its start/end offsets
# should print only the annotated concept (first 10 rows)
for i in range(10):
    print(df.text.values[i][df.start.values[i]:df.end.values[i]])

injury of ulnar vein
fluocinolone adverse reaction
sarcoid myopathy
poisoning caused by flucytosine
eruption cyst of jaw
urachal abscess
uncomplicated non-allergic asthma
self-healing juvenile cutaneous mucinosis
beckwith-wiedemann syndrome
familial primary pulmonary hypertension


In [14]:
# Print the first 10 rows as [text, start, end, relevance] lists.
for i in range(10):
    print(df[["text", "start", "end", "relevance"]].iloc[i].tolist())

["arenaviral hemorrhagic fever injury of ulnar vein koenig's syndrome ii", 29, 49, 'present']
['dieffenbachia species poisoning fluocinolone adverse reaction pure sensorimotor lacunar infarction', 32, 61, 'present']
['infection - perineal wound sarcoid myopathy charcot-marie-tooth disease, type i', 27, 43, 'present']
['ulcer of anus poisoning caused by flucytosine piperazine overdose', 14, 45, 'present']
['eosinophilic duodenal ulcer eruption cyst of jaw cholesteatoma of attic', 28, 48, 'present']
['mycoplasma balanitis urachal abscess effects of hunger', 21, 36, 'present']
['benign neoplasm of back uncomplicated non-allergic asthma blau syndrome', 24, 57, 'present']
['sedative amnestic disorder self-healing juvenile cutaneous mucinosis episodic tension-type headache', 27, 68, 'present']
['intentional nitrous oxide overdose beckwith-wiedemann syndrome mild pulmonary valve regurgitation', 35, 62, 'present']
['problems familial primary pulmonary hypertension', 9, 48, 'present']


## Generate training data

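Takes a fixed number of rows from each category's preprocessed CSV (the first N rows, or the last N rows for the confirmed set) and concatenates them into a single training DataFrame ready for input into the next step. Note that the rows are sliced in file order rather than randomly sampled, and that every file is preprocessed in full before slicing, so this is not very efficient.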

In [40]:
# Configuration read by create_problems_training_data():
#   "<category>_path" - CSV file to preprocess for that category
#   "<category>_num"  - number of rows to take from that category's frame
config_dict = {
    "historic_path": DATA_DIR + "/metacat_problems_historic.csv",
    "suspected_path": DATA_DIR + "/metacat_problems_suspected.csv",
    "negated_path": DATA_DIR + "/metacat_problems_negated.csv",
    "irrelevant_path": DATA_DIR + "/metacat_problems_irrelevant.csv",
    # "present" and "confirmed" both read the positive file
    "present_path": DATA_DIR + "/metacat_problems_positive.csv",
    "confirmed_path": DATA_DIR + "/metacat_problems_positive.csv",
    "confirmed_num": 500,
    "historic_num": 5000,
    "irrelevant_num": 5000,
    "present_num": 500,
    "suspected_num": 5000,
    "negated_num": 5000}

In [41]:
def create_problems_training_data(config_dict):
    """Build one training DataFrame from the per-category synthetic CSVs.

    For each category (historic, suspected, irrelevant, negated, present,
    confirmed -- concatenated in that order) the CSV at
    ``config_dict["<category>_path"]`` is run through
    ``preprocess_miade_synthetic_data`` and ``config_dict["<category>_num"]``
    rows are kept: the first rows of the file, except for ``confirmed`` which
    keeps the last rows. A category with no ``<category>_path`` entry is
    skipped.

    Args:
        config_dict: Mapping holding the ``<category>_path`` and
            ``<category>_num`` entries described above.

    Returns:
        The concatenated DataFrame with a fresh 0..n-1 index. Its length is
        also printed.

    Raises:
        KeyError: If a category has a path but no ``<category>_num`` entry.
        ValueError: If no category has a path configured.
    """
    categories = ("historic", "suspected", "irrelevant", "negated", "present", "confirmed")

    # Several categories may point at the same file (present/confirmed do),
    # so each distinct path is preprocessed only once.
    preprocessed = {}
    samples = []
    for category in categories:
        path = config_dict.get(f"{category}_path")
        if path is None:
            continue
        if path not in preprocessed:
            preprocessed[path] = preprocess_miade_synthetic_data(path)
        category_df = preprocessed[path]

        wanted = config_dict[f"{category}_num"]
        if category == "confirmed":
            # tail() instead of [-wanted:] so that wanted == 0 yields no rows
            # rather than the whole frame.
            samples.append(category_df.tail(wanted))
        else:
            samples.append(category_df.head(wanted))

    if not samples:
        raise ValueError("config_dict does not define a '<category>_path' for any category")

    train_data = pd.concat(samples, ignore_index=True)
    print(len(train_data))

    return train_data

In [42]:
# Build the combined training frame; the function also prints its row count.
train_df = create_problems_training_data(config_dict)

20000


In [43]:
# Peek at the first 10 training rows.
train_df.head(10)

Unnamed: 0,text,cui,name,start,end,relevance,presence,laterality (generic)
0,decidual endometritis from age 12,75585005,decidual endometritis,0,21,historic,confirmed,none
1,bilateral metatarsus adductus prev,15667441000119108,bilateral metatarsus adductus,0,29,historic,confirmed,bilateral
2,crohn's disease of pylorus a few years previously,61424003,crohn's disease of pylorus,0,26,historic,confirmed,none
3,no prev hist of congenital facial nerve palsy,230542008,congenital facial nerve palsy,16,45,historic,negated,none
4,no prev hx accidental fusidic acid overdose,296643001,accidental fusidic acid overdose,11,43,historic,negated,none
5,benign neoplasm of rib last 8 months,92326008,benign neoplasm of rib,0,22,historic,confirmed,none
6,chronic lymphangitis upto age 63,78973009,chronic lymphangitis,0,20,historic,confirmed,none
7,contact granuloma of larynx few yrs previously,12181000119103,contact granuloma of larynx,0,27,historic,confirmed,none
8,reflux gastritis last 1 mth,57433008,reflux gastritis,0,16,historic,confirmed,none
9,prev hist inflammation of spermatic cord,737173008,inflammation of spermatic cord,10,40,historic,confirmed,none


In [44]:
# Label distribution of the "presence" category.
train_df.presence.value_counts()

confirmed    10079
suspected     5000
negated       4921
Name: presence, dtype: int64

In [45]:
# Label distribution of the "relevance" category.
train_df.relevance.value_counts()

present       10000
historic       5000
irrelevant     5000
Name: relevance, dtype: int64

In [46]:
# Save the combined training data as CSV, without the index column.
# NOTE(review): assumes the ./data directory already exists -- confirm before a fresh run.
train_df.to_csv("./data/problems_synthetic_train_data.csv", index=False)

## Tokenize and convert to training data input

Modified from the utility function prepare_from_json() from MedCAT:
https://github.com/CogStack/MedCAT/blob/3e979951b0bf817b56445d90c6e7fcab97ef0390/medcat/utils/meta_cat/data_utils.py#L5

Instead of using MedCATTrainer JSON, this function takes the preprocessed MiADE synthetic CSV data (the DataFrame returned by ```preprocess_miade_synthetic_data()```) and converts it into data that is ready for input into PyTorch model training, which is wrapped inside the MetaCAT ```train()``` function.

In [12]:
def prepare_from_miade_csv(
    data: pd.DataFrame,
    category_name: str,
    cntx_left: int,
    cntx_right: int,
    tokenizer: TokenizerWrapperBase,
    replace_center: Optional[str] = None,
    lowercase: bool = True,
    ) -> Dict:
    """Convert preprocessed MiADE rows into token-window samples for one category.

    For every row with a non-empty text, the text is tokenized, the token that
    contains the character offset ``start`` is taken as the centre, and a
    window of up to ``cntx_left`` tokens before and ``cntx_right`` tokens after
    it is kept.

    Args:
        data: Frame with ``text``, ``start``, ``end`` and ``category_name``
            columns; ``start``/``end`` are character offsets of the concept.
        category_name: Column whose value becomes the sample label.
        cntx_left: Maximum number of tokens kept to the left of the centre.
        cntx_right: Maximum number of tokens kept to the right of the centre.
        tokenizer: Callable returning a mapping with ``input_ids`` and
            ``offset_mapping`` (a sequence of ``(start, end)`` character spans).
        replace_center: If given, the tokens of the whole concept are replaced
            by the tokens of this string.
        lowercase: If True, ``replace_center`` is lower-cased before tokenizing.

    Returns:
        ``{category_name: [[tokens, centre_position, label], ...]}`` where
        ``centre_position`` indexes the centre token inside ``tokens``. The
        dict is empty when no row has usable text.
    """
    out_data: Dict = {}

    # The replacement tokens do not depend on the row, so tokenize them once.
    center_ids = None
    if replace_center is not None:
        if lowercase:
            replace_center = replace_center.lower()
        center_ids = tokenizer(replace_center)['input_ids']

    rows = zip(
        data["text"].values,
        data["start"].values,
        data["end"].values,
        data[category_name].values,
    )
    for text, start, end, value in rows:
        # Skip empty texts, and missing (non-string) ones such as NaN.
        if not isinstance(text, str) or len(text) == 0:
            continue

        doc_text = tokenizer(text)
        offsets = doc_text['offset_mapping']

        # Get the index of the center token. If no token contains `start`,
        # the last token is used (index 0 when there are no tokens).
        ind = 0
        for ind, (tok_start, tok_end) in enumerate(offsets):
            if tok_start <= start < tok_end:
                break

        _start = max(0, ind - cntx_left)
        _end = min(len(doc_text['input_ids']), ind + 1 + cntx_right)
        tkns = doc_text['input_ids'][_start:_end]
        # Position of the centre inside the window; smaller than cntx_left
        # when the window is clipped at the start of the text.
        cpos = cntx_left + min(0, ind - cntx_left)

        if center_ids is not None:
            # Reset per row: when a boundary token is not found, fall back to
            # the centre token instead of reusing indices from a previous row.
            s_ind = e_ind = ind
            for p_ind, (tok_start, tok_end) in enumerate(offsets):
                if tok_start <= start < tok_end:
                    s_ind = p_ind
                if tok_start < end <= tok_end:
                    e_ind = p_ind

            # Number of extra tokens the concept covers beyond its first one.
            ln = max(0, e_ind - s_ind)
            tkns = tkns[:cpos] + center_ids + tkns[cpos + ln + 1:]

        out_data.setdefault(category_name, []).append([tkns, cpos, value])

    return out_data

In [13]:
# Load an example MetaCAT model; the cells visible below only use its tokenizer (mc.tokenizer).
mc = MetaCAT.load("./examples/Status")

In [14]:
# Convert the full training frame into "presence" samples, replacing each concept with the word "disease".
prepare_from_miade_csv(train_df, category_name="presence", cntx_left=5, cntx_right=8, tokenizer=mc.tokenizer, replace_center="disease")

{'presence': [[[3564, 2430, 313, 8506], 3, 'confirmed'],
  [[9732, 2430, 8506], 2, 'confirmed'],
  [[8506, 360, 3073, 2541], 0, 'confirmed'],
  [[19, 15267, 8506], 2, 'confirmed'],
  [[8506, 1132, 2890, 8453, 785], 0, 'confirmed'],
  [[8506, 645, 476, 16841, 82], 0, 'confirmed'],
  [[1547, 1051, 313, 8506], 3, 'negated'],
  [[8506, 645, 464, 16841], 0, 'confirmed'],
  [[8506, 1922, 14636], 0, 'confirmed'],
  [[11694, 8506], 1, 'confirmed'],
  [[11405, 429, 8506], 2, 'suspected'],
  [[13956, 8506], 1, 'suspected'],
  [[4346, 8506], 1, 'suspected'],
  [[20327, 8506], 1, 'suspected'],
  [[11405, 884, 8506], 2, 'suspected'],
  [[9181, 3998, 8506], 2, 'suspected'],
  [[25626, 8506], 1, 'suspected'],
  [[8506, 7074], 0, 'suspected'],
  [[4346, 8506], 1, 'suspected'],
  [[474, 727, 8506], 2, 'suspected'],
  [[8260, 313, 8506], 2, 'confirmed'],
  [[3028, 6366, 8506], 2, 'confirmed'],
  [[12351, 82, 796, 8506], 3, 'confirmed'],
  [[3028, 6366, 796, 8506], 3, 'confirmed'],
  [[16756, 8107, 8506]

In [36]:
# Small test frame: the first 10 negated rows followed by the first 10 suspected rows.
test = pd.concat([train_df[train_df["presence"] == "negated"][:10], train_df[train_df["presence"] == "suspected"][:10]], ignore_index=True)

In [37]:
# Run the same conversion on the small negated/suspected test frame.
prepare_from_miade_csv(test, category_name="presence", cntx_left=5, cntx_right=8, tokenizer=mc.tokenizer, replace_center="disease")

{'presence': [[[1547, 1051, 313, 8506], 3, 'negated'],
  [[8506, 1310], 0, 'negated'],
  [[8506, 2771], 0, 'negated'],
  [[5434, 8506], 1, 'negated'],
  [[8506, 2771], 0, 'negated'],
  [[17808, 8506], 1, 'negated'],
  [[8506, 1700], 0, 'negated'],
  [[7920, 8506], 1, 'negated'],
  [[8506], 0, 'negated'],
  [[8506, 3418], 0, 'negated'],
  [[11405, 429, 8506], 2, 'suspected'],
  [[13956, 8506], 1, 'suspected'],
  [[4346, 8506], 1, 'suspected'],
  [[20327, 8506], 1, 'suspected'],
  [[11405, 884, 8506], 2, 'suspected'],
  [[9181, 3998, 8506], 2, 'suspected'],
  [[25626, 8506], 1, 'suspected'],
  [[8506, 7074], 0, 'suspected'],
  [[4346, 8506], 1, 'suspected'],
  [[474, 727, 8506], 2, 'suspected']]}