In [1]:
import json
import os
import re
from typing import Dict, Optional, Tuple, Iterable, List

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from tokenizers import ByteLevelBPETokenizer
from tqdm.autonotebook import tqdm, trange

from medcat.cat import CAT
from medcat.cdb import CDB
from medcat.config import Config
from medcat.vocab import Vocab
from medcat.meta_cat import MetaCAT
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.preprocessing.tokenizers import TokenizerWrapperBPE, TokenizerWrapperBERT
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase


In [2]:
# Root folder of the MiADE synthetic "problems" CSVs.
# The default is a machine-specific absolute path; set the MIADE_DATA_DIR
# environment variable to point the notebook at a different copy of the data
# without editing this cell.
DATA_DIR = os.environ.get(
    "MIADE_DATA_DIR",
    "/home/jennifer/Documents/miade-datasets/problems/metacat_problems_07mar2023",
)

## Preprocess data
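The function below converts a CSV file from ```miade-datasets``` into a dataframe with the columns ```[text, cui, name, start, end, <category_name>...]```, where ```start``` and ```end``` are the character offsets of the annotated concept in ```text```.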

In [3]:
def preprocess_miade_synthetic_data(data_path, lower_case=True, prefix="p"):
    """Load a MiADE synthetic-data CSV and locate the annotated concept in each text.

    Every row's ``text`` is expected to contain a concept wrapped as
    ``{<prefix> concept}``, and the ``<prefix>`` column is expected to hold
    ``"<cui> | <name>"``.

    Args:
        data_path: Path or buffer accepted by ``pd.read_csv``.
        lower_case: If True, missing values become ``""`` and every cell is
            converted to a lower-cased string. If False, only the three
            ``p_meta_*`` label columns are lower-cased (``prefix == "p"`` only).
        prefix: Annotation tag used inside the text markers; also the name of
            the ``"<cui> | <name>"`` column.

    Returns:
        The dataframe with the markers stripped from ``text`` and with ``cui``,
        ``name``, ``start`` and ``end`` columns added, where ``start``/``end``
        are character offsets of the first concept in the stripped text.
        For ``prefix == "p"`` the label columns are renamed to ``relevance``,
        ``presence`` and ``laterality (generic)`` and only
        ``[text, cui, name, start, end, <labels>]`` is returned.

    Raises:
        ValueError: If a text does not contain an opening marker and a closing
            brace, so the concept span cannot be located.
    """
    data = pd.read_csv(data_path)

    label_columns = ["p_meta_relevance", "p_meta_confirmed", "p_meta_laterality"]
    if lower_case:
        data = data.fillna("").astype(str).apply(lambda col: col.str.lower())  # all lower case
    elif prefix == "p":
        # Lower-case each label cell. Calling str() on the column itself would
        # stringify the whole Series instead of its values.
        data[label_columns] = data[label_columns].apply(lambda col: col.astype(str).str.lower())

    # extract cui and name in separate columns
    data[["cui", "name"]] = data[prefix].str.extract(r"^(\d+)\s+\|\s+(.+)$")
    # remove words inside brackets e.g.(disease)
    # (assigned back rather than replaced in place: an in-place replace on a
    # column selection does not update the frame under pandas Copy-on-Write)
    data["name"] = data["name"].replace(r"\s*\([^)]*\)", "", regex=True)
    data = data.drop(columns=prefix)

    # Marker pattern for "{<prefix> " or "}", built from the prefix so that any
    # tag works, not only p/m/r.
    opening = "{" + prefix + " "
    marker = re.compile(r"(" + re.escape(opening) + r")|\}")
    # Removing the opening marker shifts everything after it left by its length.
    shift = len(opening)

    start = []
    end = []
    for row, text in enumerate(data["text"].values):
        hits = list(marker.finditer(text))
        if len(hits) < 2:
            raise ValueError(
                f"Row {row}: could not find a '{opening}...}}' annotation in text {text!r}"
            )
        if len(hits) > 2:
            print("More than one concept in text, only processing first one")

        start.append(hits[0].end() - shift)
        end.append(hits[1].start() - shift)

    data["start"] = start
    data["end"] = end

    # remove annotation brackets
    data["text"] = data["text"].replace(marker, "", regex=True)

    if prefix == "p":
        # convert labels
        data = data.replace("no laterality", "none")
        # tidy up columns
        data = data.rename(columns={
            "p_meta_relevance": "relevance",
            "p_meta_confirmed": "presence",
            "p_meta_laterality": "laterality (generic)",
        })
        data = data[["text", "cui", "name", "start", "end", "relevance", "presence", "laterality (generic)"]]

    return data

In [4]:
# Preprocess the "historic" problems CSV with the function defaults
# (lower_case=True, prefix="p").
df = preprocess_miade_synthetic_data(DATA_DIR + "/metacat_problems_historic.csv")

In [5]:
# Display the preprocessed dataframe.
df

Unnamed: 0,text,cui,name,start,end,relevance,presence,laterality (generic)
0,decidual endometritis from age 12,75585005,decidual endometritis,0,21,historic,confirmed,none
1,bilateral metatarsus adductus prev,15667441000119108,bilateral metatarsus adductus,0,29,historic,confirmed,bilateral
2,crohn's disease of pylorus a few years previously,61424003,crohn's disease of pylorus,0,26,historic,confirmed,none
3,no prev hist of congenital facial nerve palsy,230542008,congenital facial nerve palsy,16,45,historic,negated,none
4,no prev hx accidental fusidic acid overdose,296643001,accidental fusidic acid overdose,11,43,historic,negated,none
...,...,...,...,...,...,...,...,...
10995,4 previous acute malnutrition in childhood,762497007,acute malnutrition in childhood,11,42,historic,confirmed,none
10996,nonspecific interstitial pneumonia while 24 y o,129452008,nonspecific interstitial pneumonia,0,34,historic,confirmed,none
10997,no pmh of oxyphilic adenocarcinoma,443261008,oxyphilic adenocarcinoma,10,34,historic,negated,none
10998,neonatal hyperglycemia previously,276557002,neonatal hyperglycemia,0,22,historic,confirmed,none


In [6]:
# Mean character length of the texts (annotation markers already stripped).
df['text'].str.len().mean()

38.00836363636363

In [7]:
# sanity check: slicing each text by its start/end offsets should print the concept mention
for row in range(10):
    text = df.text.values[row]
    lo, hi = df.start.values[row], df.end.values[row]
    print(text[lo:hi])

decidual endometritis
bilateral metatarsus adductus
crohn's disease of pylorus
congenital facial nerve palsy
accidental fusidic acid overdose
benign neoplasm of rib
chronic lymphangitis
contact granuloma of larynx
reflux gastritis
inflammation of spermatic cord


In [8]:
# Print the text, its concept offsets and the relevance label for the first ten rows.
preview = df[["text", "start", "end", "relevance"]]
for row in range(10):
    print(preview.iloc[row].tolist())

['decidual endometritis from age 12', 0, 21, 'historic']
['bilateral metatarsus adductus prev', 0, 29, 'historic']
["crohn's disease of pylorus a few years previously", 0, 26, 'historic']
['no prev hist of congenital facial nerve palsy', 16, 45, 'historic']
['no prev hx accidental fusidic acid overdose', 11, 43, 'historic']
['benign neoplasm of rib last 8 months', 0, 22, 'historic']
['chronic lymphangitis upto age 63', 0, 20, 'historic']
['contact granuloma of larynx few yrs previously', 0, 27, 'historic']
['reflux gastritis last 1 mth', 0, 16, 'historic']
['prev hist inflammation of spermatic cord', 10, 40, 'historic']


## Generate training data

Randomly sample a configurable number of rows from each source CSV and concatenate them into a single training dataframe, ready for input into the next step.

In [10]:
# Sampling configuration: "<label>_path" is the source CSV for a label group and
# "<label>_num" is the number of rows to sample from it.
config_dict = {
    "historic_path": DATA_DIR + "/metacat_problems_historic.csv",
    "suspected_path": DATA_DIR + "/metacat_problems_suspected.csv",
    "negated_path": DATA_DIR + "/metacat_problems_negated.csv",
    "irrelevant_path": DATA_DIR + "/metacat_problems_irrelevant.csv",
    # NOTE(review): "confirmed_num" and "present_num" have no matching "*_path" entry
    # and are not read by create_problems_training_data() — confirm whether they are needed.
    "confirmed_num": 10,
    "historic_num": 10,
    "irrelevant_num": 10,
    "present_num": 10,
    "suspected_num": 10,
    "negated_num": 10}

In [11]:
def create_problems_training_data(config_dict, random_state=None):
    """Sample rows from each configured source CSV and stack them into one frame.

    For every label in (historic, suspected, irrelevant, negated), in that order,
    the CSV at ``config_dict["<label>_path"]`` is preprocessed with
    ``preprocess_miade_synthetic_data`` and ``config_dict["<label>_num"]`` rows
    are drawn at random. A label whose path is missing or None is skipped.

    Args:
        config_dict: Mapping with ``"<label>_path"`` / ``"<label>_num"`` entries.
        random_state: Optional seed passed to ``DataFrame.sample`` so the draw
            can be reproduced. The default (None) keeps the sampling unseeded.

    Returns:
        The concatenated samples with a fresh 0..n-1 index. The number of rows
        is also printed.

    Raises:
        ValueError: If ``config_dict`` provides no source path at all.
        KeyError: If a path is given without its matching ``"<label>_num"``.
    """
    sampled_frames = []
    for label in ("historic", "suspected", "irrelevant", "negated"):
        path = config_dict.get(f"{label}_path")
        if path is None:
            # Optional source: skip it instead of handing None to the CSV reader.
            continue
        source_df = preprocess_miade_synthetic_data(path)
        sampled_frames.append(
            source_df.sample(n=config_dict[f"{label}_num"], random_state=random_state)
        )

    if not sampled_frames:
        raise ValueError("config_dict does not provide any '<label>_path' entry to sample from")

    train_data = pd.concat(sampled_frames, ignore_index=True)
    print(len(train_data))

    return train_data

In [12]:
# Build the example training set; the rows are drawn at random, so the result
# differs between runs.
train_df = create_problems_training_data(config_dict)

40


In [13]:
# Display the sampled training dataframe.
train_df

Unnamed: 0,text,cui,name,start,end,relevance,presence,laterality (generic)
0,no hist of sympathetic nervous structure injury,282747005,sympathetic nervous structure injury,11,47,historic,negated,none
1,agenesis of nasal bone when 30 yr o,1003577003,agenesis of nasal bone,0,22,historic,confirmed,none
2,previous acute skin sarcoidosis,238674006,acute skin sarcoidosis,9,31,historic,confirmed,none
3,infectious disease of abdomen last 7 month,128070006,infectious disease of abdomen,0,29,historic,confirmed,none
4,no ph of cannabis poisoning,1149328002,cannabis poisoning,9,27,historic,negated,none
5,prev acute myocarditis - tuberculous,194949003,acute myocarditis - tuberculous,5,36,historic,confirmed,none
6,previous hx of exotropia of bilateral eyes,15632691000119105,exotropia of bilateral eyes,15,42,historic,confirmed,bilateral
7,past medical history of hypertrophy of gallbla...,76875008,hypertrophy of gallbladder,24,50,historic,confirmed,none
8,laceration of right buttock last 3 month,10904511000119107,laceration of right buttock,0,27,historic,confirmed,right
9,several acute pyelitis,32801008,acute pyelitis,8,22,historic,confirmed,none


In [None]:
# Save the sampled training dataframe as CSV in the working directory (index column omitted).
train_df.to_csv("./problems_synthetic_train_example.csv", index=False)

## Tokenize and convert to training data input

Modified from the utility function prepare_from_json() from MedCAT:
https://github.com/CogStack/MedCAT/blob/3e979951b0bf817b56445d90c6e7fcab97ef0390/medcat/utils/meta_cat/data_utils.py#L5

Instead of MedCATTrainer JSON, this function takes the preprocessed MiADE synthetic data (the dataframe returned by ```preprocess_miade_synthetic_data()```) and converts it into the sample format used as input for PyTorch model training, which is wrapped inside the MetaCAT ```train()``` function.

In [15]:
def prepare_from_miade_csv(
    data: pd.DataFrame,
    category_name: str,
    cntx_left: int,
    cntx_right: int,
    tokenizer: TokenizerWrapperBase,
    replace_center: Optional[str] = None,
    lowercase: bool = True,
    ) -> Dict:
    """Convert a preprocessed MiADE dataframe into token-window samples.

    For every row with non-empty text, the text is tokenised, the token that
    contains the concept's ``start`` offset is located, and a window of at most
    ``cntx_left`` tokens before it and ``cntx_right`` tokens after it is kept.

    Args:
        data: Frame with ``text``, ``start``, ``end`` and ``category_name`` columns.
        category_name: Column holding the label to attach to each sample.
        cntx_left: Number of tokens kept to the left of the concept token.
        cntx_right: Number of tokens kept to the right of the concept token.
        tokenizer: Callable returning a mapping with ``'input_ids'`` and
            ``'offset_mapping'`` for a given string.
        replace_center: If given, the concept's tokens are replaced by the
            tokens of this string.
        lowercase: If True, ``replace_center`` is lower-cased before it is
            tokenised. The row texts themselves are not modified here.

    Returns:
        ``{category_name: [[token_ids, center_position, label], ...]}``, or an
        empty dict when no row was processed.
    """
    out_data: Dict = {}

    # The replacement tokens do not depend on the row, so tokenise them once.
    replacement_ids = None
    if replace_center is not None:
        if lowercase:
            replace_center = replace_center.lower()
        replacement_ids = list(tokenizer(replace_center)['input_ids'])

    for i in range(len(data)):
        text = data.text.values[i]

        # Skip rows without usable text (empty string or a missing value).
        if not isinstance(text, str) or len(text) == 0:
            continue

        doc_text = tokenizer(text)
        offsets = doc_text['offset_mapping']
        input_ids = doc_text['input_ids']

        start = data.start.values[i]
        end = data.end.values[i]

        # Get the index of the center token (falls back to the last token when
        # no token contains `start`)
        ind = 0
        for ind, pair in enumerate(offsets):
            if start >= pair[0] and start < pair[1]:
                break

        _start = max(0, ind - cntx_left)
        _end = min(len(input_ids), ind + 1 + cntx_right)
        tkns = input_ids[_start:_end]
        # Position of the center token inside the window; smaller than
        # cntx_left when the window was clipped at the start of the text.
        cpos = cntx_left + min(0, ind - cntx_left)

        if replacement_ids is not None:
            # Reset the span for every row: if `start`/`end` match no token the
            # span defaults to the center token rather than reusing indices left
            # over from the previous row.
            s_ind = e_ind = ind
            for p_ind, pair in enumerate(offsets):
                if start >= pair[0] and start < pair[1]:
                    s_ind = p_ind
                if end > pair[0] and end <= pair[1]:
                    e_ind = p_ind

            # Number of extra tokens covered by the concept; never negative.
            ln = max(0, e_ind - s_ind)
            tkns = tkns[:cpos] + replacement_ids + tkns[cpos + ln + 1:]

        value = data[category_name].values[i]
        out_data.setdefault(category_name, []).append([tkns, cpos, value])

    return out_data

In [16]:
# example tokenizer: load a saved MetaCAT model from ./data/Status so that its
# tokenizer can be passed to prepare_from_miade_csv()
mc = MetaCAT.load("./data/Status")

In [17]:
# Example: build samples for the "presence" category with a window of 5 tokens to the
# left and 8 to the right, replacing the concept tokens with the tokens of "disease".
prepare_from_miade_csv(train_df, category_name="presence", cntx_left=5, cntx_right=8, tokenizer=mc.tokenizer, replace_center="disease")

{'presence': [[[1547, 11292, 313, 8506], 3, 'negated'],
  [[8506, 1132, 950, 6186, 290], 0, 'confirmed'],
  [[5500, 8506], 1, 'confirmed'],
  [[8506, 645, 464, 3015], 0, 'confirmed'],
  [[1547, 1051, 313, 8506], 3, 'negated'],
  [[1396, 85, 8506], 2, 'confirmed'],
  [[5500, 2430, 313, 8506], 3, 'confirmed'],
  [[3564, 702, 682, 313, 8506], 4, 'confirmed'],
  [[8506, 645, 373, 3015], 0, 'confirmed'],
  [[11946, 8506], 1, 'confirmed'],
  [[8506, 7074], 0, 'suspected'],
  [[28577, 313, 8506], 2, 'suspected'],
  [[1218, 8506], 1, 'suspected'],
  [[13910, 8506], 1, 'suspected'],
  [[8506, 7074], 0, 'suspected'],
  [[9181, 3998, 8506], 2, 'suspected'],
  [[6606, 8506], 1, 'suspected'],
  [[25626, 8506], 1, 'suspected'],
  [[13481, 8506], 1, 'suspected'],
  [[11405, 429, 8506], 2, 'suspected'],
  [[15629, 2486, 8506], 2, 'confirmed'],
  [[12351, 82, 8506], 2, 'confirmed'],
  [[18258, 6144, 8506], 2, 'confirmed'],
  [[3028, 11292, 8506], 2, 'confirmed'],
  [[9127, 21684, 8506], 2, 'confirmed']