In [1]:
import json
import re
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from medcat.cat import CAT
from medcat.cdb import CDB
from medcat.config import Config
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.meta_cat import MetaCAT
from medcat.preprocessing.tokenizers import TokenizerWrapperBPE, TokenizerWrapperBERT
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
from medcat.vocab import Vocab
from tokenizers import ByteLevelBPETokenizer
from tqdm.autonotebook import tqdm, trange


In [2]:
# Root folder of the synthetic data; later cells build file paths as DATA_DIR + "/<file>.csv".
# NOTE(review): the problems CSVs in config_dict are resolved against this same folder - confirm they live here.
DATA_DIR = "../medications_and_allergies/2023-07-19/"

## Preprocess data

In [220]:
# Helper function that finds the start and end indices of the concepts annotated in the
# pattern files and strips the annotation markers ({p ...}, {m ...}, {r ...}) from the text.
def clean_text_and_return_concept_indices(input_string, meds_allergies=False):
    """Strip annotation markers and locate the annotated concepts in the cleaned text.

    Args:
        input_string: Text containing "{p concept}" markers, or "{m concept}" /
            "{r concept}" markers when ``meds_allergies`` is True.
        meds_allergies: Selects which markers are handled: "m" and "r" when True,
            "p" otherwise.

    Returns:
        ``(text, m_start, m_end, r_start, r_end)`` when ``meds_allergies`` is True,
        otherwise ``(text, p_start, p_end)``. ``text`` is the input with the handled
        opening markers and every "}" removed. Offsets index into ``text`` such that
        ``text[start:end]`` is the concept; the first occurrence of each marker is
        used. A marker that is absent (or never closed) yields ``(-1, -1)``.
    """
    tags = ("m", "r") if meds_allergies else ("p",)

    def strip_markers(fragment):
        # Same removals as applied to the full string, so lengths stay comparable.
        for tag in tags:
            fragment = fragment.replace("{" + tag + " ", "")
        return fragment.replace("}", "")

    def locate(tag):
        opener = "{" + tag + " "
        open_at = input_string.find(opener)
        if open_at == -1:
            return -1, -1
        close_at = input_string.find("}", open_at)
        if close_at == -1:
            return -1, -1
        # The concept starts where the cleaned prefix ends; measuring the cleaned
        # prefix accounts for any number of markers removed before this one.
        start = len(strip_markers(input_string[:open_at]))
        end = start + len(strip_markers(input_string[open_at + len(opener):close_at]))
        return start, end

    result_string = strip_markers(input_string)
    offsets = [index for tag in tags for index in locate(tag)]
    return (result_string, *offsets)

In [218]:
# Example usage: one sentence carrying both a medication ({m ...}) and a reaction ({r ...}) annotation.
input_string = "intolerant of {m potassium} which gave them {r indigestion}."
result, m_start, m_end, r_start, r_end = clean_text_and_return_concept_indices(input_string, meds_allergies=True)
print("Result:", result)
# Slicing the cleaned text with the returned offsets should print each annotated concept.

print("med: ",result[m_start:m_end], m_start, m_end)
print("reaction: ", result[r_start:r_end], r_start, r_end)

Result: intolerant of potassium which gave them indigestion.
med:  potassium 14 23
reaction:  indigestion 40 51


The function below processes the CSV data provided in ```miade-dataset``` into a dataframe ```[text, cui, name, start, end, <category_name>...]```

This MedCAT PR https://github.com/CogStack/MedCAT/pull/329 (v1.7.1 release) added a new method called ```cat_train_supervised_raw```, which takes input similar to this dataframe, although the data still needs to be serialised into the following JSON-like format:

        {'projects':
            [ # list of projects
                { # project 1
                    'name': '<some name>',
                    # list of documents
                    'documents': [{'name': '<some name>',  # document 1
                                    'text': '<text of the document>',
                                    # list of annotations
                                    'annotations': [{'start': -1,  # annotation 1
                                                    'end': 1,
                                                    'cui': 'cui',
                                                    'value': '<text value>'}, ...],
                                    }, ...]
                }, ...
            ]
        }
        
I might adapt this if time allows, but for now we use the MiADE_CAT train method.

In [215]:
def preprocess_miade_synthetic_data(data_path, lower_case=True, prefix="p"):
    """Turn a MiADE synthetic-pattern CSV into a training dataframe.

    Args:
        data_path: Anything accepted by ``pd.read_csv``. The CSV needs a ``text``
            column with annotated text and a concept column named after ``prefix``
            holding "<cui> | <name>" values.
        lower_case: When True every value is converted to a lower-case string.
        prefix: Which concept to extract: "p" (problems), "m" (medications) or
            "r" (reactions).

    Returns:
        DataFrame with columns ``text, cui, name, start, end`` followed by the meta
        columns of the chosen prefix (p: relevance, presence, laterality (generic);
        m: substance_category, allergy_type, severity; r: reaction_pos).
        ``start``/``end`` are the offsets returned by
        ``clean_text_and_return_concept_indices`` for the cleaned ``text``.

    Raises:
        ValueError: If ``prefix`` is not one of "p", "m", "r".
    """
    if prefix not in ("p", "m", "r"):
        raise ValueError(f"prefix must be one of 'p', 'm' or 'r', got {prefix!r}")

    data = pd.read_csv(data_path)

    # medications and reactions share one file: drop the concept column we are not extracting
    if prefix == "m":
        data = data.drop(columns="r")
    elif prefix == "r":
        data = data.drop(columns="m")

    # drop any NaNs in the concepts column
    data = data.dropna(subset=[prefix]).reset_index(drop=True)

    if prefix == "m":
        # Fill missing meta labels *before* the string conversion below: astype(str)
        # turns NaN into the literal "nan", which a later fillna() can no longer see.
        meta_columns = ("m_meta_category", "m_meta_allergytype", "m_meta_severity")
        data = data.fillna(dict.fromkeys(meta_columns, "unspecified"))

    if lower_case:
        data = data.astype(str).apply(lambda column: column.str.lower())  # all lower case

    # extract cui and name in separate columns
    data[["cui", "name"]] = data[prefix].str.extract(r"^(\d+)\s*\|\s*(.+)$")
    # remove words inside brackets e.g.(disease)
    data["name"] = data["name"].str.replace(r"\s*\([^)]*\)", "", regex=True)
    # drop the original column
    data = data.drop(columns=prefix)
    # some entries end with | - remove that
    data["name"] = data["name"].str.rstrip("|")

    # extract the start and end indices of the concept and remove the annotations e.g. {p ...}
    starts = []
    ends = []
    texts = []
    for raw_text in data["text"].values:
        if prefix == "p":
            cleaned, concept_start, concept_end = clean_text_and_return_concept_indices(raw_text)
        else:
            cleaned, m_start, m_end, r_start, r_end = clean_text_and_return_concept_indices(
                raw_text, meds_allergies=True
            )
            if prefix == "m":
                concept_start, concept_end = m_start, m_end
            else:
                concept_start, concept_end = r_start, r_end
        starts.append(concept_start)
        ends.append(concept_end)
        texts.append(cleaned)

    data["start"] = starts
    data["end"] = ends
    data["text"] = texts

    if prefix == "p":
        # convert labels
        data = data.replace("no laterality", "none").replace("positive", "present")
        # tidy up columns
        data = data.rename(columns={"p_meta_relevance": "relevance", "p_meta_confirmed": "presence", "p_meta_laterality": "laterality (generic)"})
        data = data[["text", "cui", "name", "start", "end", "relevance", "presence", "laterality (generic)"]]
    elif prefix == "m":
        data = data.rename(columns={"m_meta_category": "substance_category", "m_meta_allergytype": "allergy_type", "m_meta_severity": "severity"})
        # still needed for cui/name values the regex above could not extract
        data = data.fillna("unspecified")
        data = data[["text", "cui", "name", "start", "end", "substance_category", "allergy_type", "severity"]]
    else:
        data = data.replace("not a reaction", "none")
        data = data.rename(columns={"r_meta_reactionpos": "reaction_pos"})
        data = data[["text", "cui", "name", "start", "end", "reaction_pos"]]

    return data

In [205]:
# Build the reactions frame: with prefix="r" the medication column is dropped and the {r ...} spans are indexed.
df = preprocess_miade_synthetic_data(DATA_DIR + "/patterns_medallerg.csv", prefix="r")

In [206]:
df

Unnamed: 0,text,cui,name,start,end,reaction_pos
0,feeling breathless today.,267036007,dyspnea,8,18,none
1,allergies: severe severe depression with co-te...,310497006,severe depression,18,35,before
2,allergies: severe liver palms with cannabidiol.,248413004,liver palms,18,29,before
3,allergies: severe oral dyspraxia with dapoxetine.,361275004,oral dyspraxia,18,32,before
4,allergies: severe postcholecystectomy diarrhea...,53156005,postcholecystectomy diarrhea,18,46,before
...,...,...,...,...,...,...
356,experienced anaphylaxis with romosozumab,39579001,anaphylaxis,12,23,before
357,had anaphylaxis due to multivitamin capsules,39579001,anaphylaxis,4,15,before
358,had anaphylaxis due to homeopathic hamamelis,39579001,anaphylaxis,4,15,before
359,had anaphylaxis with larvae sterile,39579001,anaphylaxis,4,15,before


In [207]:
df.cui

0      267036007
1      310497006
2      248413004
3      361275004
4       53156005
         ...    
356     39579001
357     39579001
358     39579001
359     39579001
360     39579001
Name: cui, Length: 361, dtype: object

In [192]:
# NOTE(review): substance_category only exists for prefix="m"; this cell fails on the prefix="r" frame built above.
df.substance_category.value_counts()

adverse reaction    482
irrelevant           60
taking               21
Name: substance_category, dtype: int64

In [193]:
# NOTE(review): allergy_type only exists for prefix="m"; this cell fails on the prefix="r" frame built above.
df.allergy_type.value_counts()

allergy        222
intolerance    200
unspecified    141
Name: allergy_type, dtype: int64

In [194]:
# NOTE(review): severity only exists for prefix="m"; this cell fails on the prefix="r" frame built above.
df.severity.value_counts()

unspecified    282
severe         161
mild            60
moderate        60
Name: severity, dtype: int64

In [208]:
df.reaction_pos.value_counts()

after     180
before    140
none       41
Name: reaction_pos, dtype: int64

In [210]:
df['text'].str.len().mean()

49.45706371191136

In [211]:
# sanity check: slicing each cleaned text with its start/end offsets should print the annotated concept
# (looks at the first 10 rows, so the frame needs at least 10 rows)
for i in range(10):
    print(df.text.values[i][df.start.values[i]:df.end.values[i]])

breathless
severe depression
liver palms
oral dyspraxia
postcholecystectomy diarrhea
ménière's disease
anxiety disorder
onychogryposis
depressive disorder
dyspnea


In [213]:
# show text, offsets and label side by side for the first 10 rows
for i in range(10):
    print(df[["text", "start", "end", "reaction_pos"]].iloc[i].tolist())

['feeling breathless today.', 8, 18, 'none']
['allergies: severe severe depression with co-tenidone.', 18, 35, 'before']
['allergies: severe liver palms with cannabidiol.', 18, 29, 'before']
['allergies: severe oral dyspraxia with dapoxetine.', 18, 32, 'before']
['allergies: severe postcholecystectomy diarrhea with pneumococcal vaccine.', 18, 46, 'before']
["allergies: severe ménière's disease with multivitamin capsules.", 18, 35, 'before']
['allergies: severe anxiety disorder with benzathine benzylpenicillin.', 18, 34, 'before']
['allergies: severe onychogryposis with crizanlizumab.', 18, 32, 'before']
['allergies: severe depressive disorder with sodium valproate.', 18, 37, 'before']
['allergies: severe dyspnea with glycerol phenylbutyrate.', 18, 25, 'before']


In [214]:
# for meds/allergies all the examples live in one file, so the preprocessed frame is already the
# training data and the sampling step in the next section is not needed; here the frame built
# above is written to CSV without the index column
df.to_csv("./reactions_synthetic_train_data.csv", index=False)

## Generate training data

Training data generation for problems: this step samples a number of examples from each file and concatenates them into one training dataframe, ready for input into the next step. (This is not very efficient - please feel free to write a nicer script!) It is needed because, unlike meds/allergies,
1) the problems synthetic data examples are in separate files, and
2) there is a very large amount of training data in total.

In [40]:
# Sampling configuration for the problems training set:
#   "<category>_path" - CSV file the examples of that category are read from
#   "<category>_num"  - number of rows taken from that file
#                       (first rows for every category except "confirmed", which takes the last rows)
# Note that "present" and "confirmed" both read the same positive-examples file.
config_dict = {
    "historic_path": DATA_DIR + "/metacat_problems_historic.csv",
    "suspected_path": DATA_DIR + "/metacat_problems_suspected.csv",
    "negated_path": DATA_DIR + "/metacat_problems_negated.csv",
    "irrelevant_path": DATA_DIR + "/metacat_problems_irrelevant.csv",
    "present_path": DATA_DIR + "/metacat_problems_positive.csv",
    "confirmed_path": DATA_DIR + "/metacat_problems_positive.csv",
    "confirmed_num": 500,
    "historic_num": 5000,
    "irrelevant_num": 5000,
    "present_num": 500,
    "suspected_num": 5000,
    "negated_num": 5000}

In [41]:
def create_problems_training_data(config_dict):
    """Sample the per-category problems files into one training dataframe.

    For each category (historic, suspected, irrelevant, negated, present, confirmed)
    the file at ``config_dict["<category>_path"]`` is preprocessed and
    ``config_dict["<category>_num"]`` rows are taken from it: the first rows for
    every category except "confirmed", which takes the last rows. This keeps the
    "present" and "confirmed" samples apart when both point at the same file.
    Categories without a path (missing key or None) are skipped.

    Args:
        config_dict: Mapping holding the ``*_path`` and ``*_num`` entries above.

    Returns:
        The sampled frames concatenated in the category order listed above, with
        a fresh integer index. The number of rows is also printed.

    Raises:
        ValueError: If no category has a path configured.
    """
    categories = ("historic", "suspected", "irrelevant", "negated", "present", "confirmed")

    preprocessed_by_path = {}  # several categories may share a file; preprocess it once
    samples = []
    for category in categories:
        path = config_dict.get(category + "_path")
        if path is None:
            continue
        if path not in preprocessed_by_path:
            preprocessed_by_path[path] = preprocess_miade_synthetic_data(path)
        frame = preprocessed_by_path[path]

        num = config_dict[category + "_num"]
        # tail()/head() also behave for num == 0, where frame[-0:] would return every row
        samples.append(frame.tail(num) if category == "confirmed" else frame.head(num))

    if not samples:
        raise ValueError("config_dict does not define any '<category>_path' entry")

    train_data = pd.concat(samples, ignore_index=True)
    print(len(train_data))

    return train_data

In [42]:
train_df = create_problems_training_data(config_dict)

20000


In [43]:
train_df.head(10)

Unnamed: 0,text,cui,name,start,end,relevance,presence,laterality (generic)
0,decidual endometritis from age 12,75585005,decidual endometritis,0,21,historic,confirmed,none
1,bilateral metatarsus adductus prev,15667441000119108,bilateral metatarsus adductus,0,29,historic,confirmed,bilateral
2,crohn's disease of pylorus a few years previously,61424003,crohn's disease of pylorus,0,26,historic,confirmed,none
3,no prev hist of congenital facial nerve palsy,230542008,congenital facial nerve palsy,16,45,historic,negated,none
4,no prev hx accidental fusidic acid overdose,296643001,accidental fusidic acid overdose,11,43,historic,negated,none
5,benign neoplasm of rib last 8 months,92326008,benign neoplasm of rib,0,22,historic,confirmed,none
6,chronic lymphangitis upto age 63,78973009,chronic lymphangitis,0,20,historic,confirmed,none
7,contact granuloma of larynx few yrs previously,12181000119103,contact granuloma of larynx,0,27,historic,confirmed,none
8,reflux gastritis last 1 mth,57433008,reflux gastritis,0,16,historic,confirmed,none
9,prev hist inflammation of spermatic cord,737173008,inflammation of spermatic cord,10,40,historic,confirmed,none


In [44]:
train_df.presence.value_counts()

confirmed    10079
suspected     5000
negated       4921
Name: presence, dtype: int64

In [45]:
train_df.relevance.value_counts()

present       10000
historic       5000
irrelevant     5000
Name: relevance, dtype: int64

In [46]:
# save the sampled problems training data to CSV without the index column
# NOTE(review): the reactions file above is written to "./", this one to "./data/" - confirm the folder exists
train_df.to_csv("./data/problems_synthetic_train_data.csv", index=False)

## Tokenize and convert to training data input

Modified from the utility function prepare_from_json() from MedCAT:
https://github.com/CogStack/MedCAT/blob/3e979951b0bf817b56445d90c6e7fcab97ef0390/medcat/utils/meta_cat/data_utils.py#L5

Instead of using MedCATTrainer JSON, this function takes the preprocessed MiADE synthetic data CSV (output from ```preprocess_miade_synthetic_data()```) and converts it to output data that is ready for input into Pytorch model training, which is wrapped inside the MetaCAT ```train()``` function.

Running this step is not necessary to generate the training data itself - it is only used for validation.

In [12]:
def prepare_from_miade_csv(
    data: pd.DataFrame,
    category_name: str,
    cntx_left: int,
    cntx_right: int,
    tokenizer: "TokenizerWrapperBase",
    replace_center: Optional[str] = None,
    lowercase: bool = True,
    ) -> Dict:
    """Convert a preprocessed MiADE dataframe into MetaCAT-style training samples.

    Args:
        data: Frame with ``text``, ``start``, ``end`` and ``category_name`` columns.
        category_name: Column holding the label to learn; also the key of the output.
        cntx_left: Number of tokens kept to the left of the concept's first token.
        cntx_right: Number of tokens kept to the right of the concept's first token.
        tokenizer: Callable returning a mapping with ``input_ids`` and
            ``offset_mapping`` (character ``(start, end)`` per token) for a text.
        replace_center: If given, the concept's tokens are replaced with the tokens
            of this string.
        lowercase: Lower-case ``replace_center`` before tokenizing it.

    Returns:
        ``{category_name: [[token_ids, center_position, label], ...]}`` with one
        sample per row whose text is non-empty; ``{}`` when there is none.
    """
    out_data: Dict = {}

    # The replacement is identical for every row, so tokenize it only once.
    replacement_ids = None
    if replace_center is not None:
        if lowercase:
            replace_center = replace_center.lower()
        replacement_ids = tokenizer(replace_center)['input_ids']

    for i in range(len(data)):
        text = data.text.values[i]
        if len(text) == 0:
            continue

        doc_text = tokenizer(text)
        offsets = doc_text['offset_mapping']
        input_ids = doc_text['input_ids']

        start = data.start.values[i]
        end = data.end.values[i]

        # Index of the center token: the first token whose span contains `start`.
        # If no token matches, the loop leaves `ind` on the last token.
        ind = 0
        for ind, pair in enumerate(offsets):
            if pair[0] <= start < pair[1]:
                break

        _start = max(0, ind - cntx_left)
        _end = min(len(input_ids), ind + 1 + cntx_right)
        tkns = input_ids[_start:_end]
        # Position of the center token inside the (possibly left-truncated) window.
        cpos = cntx_left + min(0, ind - cntx_left)

        if replacement_ids is not None:
            # Reset per row: without a default an unmatched span would reuse the
            # indices of the previous row (or fail on the first one). Falling back
            # to the center token replaces just that one token.
            s_ind = e_ind = ind
            for p_ind, pair in enumerate(offsets):
                if pair[0] <= start < pair[1]:
                    s_ind = p_ind
                if pair[0] < end <= pair[1]:
                    e_ind = p_ind

            # Extra tokens of the concept beyond the center one; never negative.
            ln = max(0, e_ind - s_ind)
            tkns = tkns[:cpos] + replacement_ids + tkns[cpos + ln + 1:]

        value = data[category_name].values[i]
        out_data.setdefault(category_name, []).append([tkns, cpos, value])

    return out_data

In [13]:
# load an existing MetaCAT model; only its tokenizer (mc.tokenizer) is used in the examples below
mc = MetaCAT.load("./examples/Status")

In [14]:
prepare_from_miade_csv(train_df, category_name="presence", cntx_left=5, cntx_right=8, tokenizer=mc.tokenizer, replace_center="disease")

{'presence': [[[3564, 2430, 313, 8506], 3, 'confirmed'],
  [[9732, 2430, 8506], 2, 'confirmed'],
  [[8506, 360, 3073, 2541], 0, 'confirmed'],
  [[19, 15267, 8506], 2, 'confirmed'],
  [[8506, 1132, 2890, 8453, 785], 0, 'confirmed'],
  [[8506, 645, 476, 16841, 82], 0, 'confirmed'],
  [[1547, 1051, 313, 8506], 3, 'negated'],
  [[8506, 645, 464, 16841], 0, 'confirmed'],
  [[8506, 1922, 14636], 0, 'confirmed'],
  [[11694, 8506], 1, 'confirmed'],
  [[11405, 429, 8506], 2, 'suspected'],
  [[13956, 8506], 1, 'suspected'],
  [[4346, 8506], 1, 'suspected'],
  [[20327, 8506], 1, 'suspected'],
  [[11405, 884, 8506], 2, 'suspected'],
  [[9181, 3998, 8506], 2, 'suspected'],
  [[25626, 8506], 1, 'suspected'],
  [[8506, 7074], 0, 'suspected'],
  [[4346, 8506], 1, 'suspected'],
  [[474, 727, 8506], 2, 'suspected'],
  [[8260, 313, 8506], 2, 'confirmed'],
  [[3028, 6366, 8506], 2, 'confirmed'],
  [[12351, 82, 796, 8506], 3, 'confirmed'],
  [[3028, 6366, 796, 8506], 3, 'confirmed'],
  [[16756, 8107, 8506]

In [36]:
# small validation sample: the first 10 negated examples followed by the first 10 suspected ones
test = pd.concat(
    [train_df[train_df["presence"] == label][:10] for label in ("negated", "suspected")],
    ignore_index=True,
)

In [37]:
prepare_from_miade_csv(test, category_name="presence", cntx_left=5, cntx_right=8, tokenizer=mc.tokenizer, replace_center="disease")

{'presence': [[[1547, 1051, 313, 8506], 3, 'negated'],
  [[8506, 1310], 0, 'negated'],
  [[8506, 2771], 0, 'negated'],
  [[5434, 8506], 1, 'negated'],
  [[8506, 2771], 0, 'negated'],
  [[17808, 8506], 1, 'negated'],
  [[8506, 1700], 0, 'negated'],
  [[7920, 8506], 1, 'negated'],
  [[8506], 0, 'negated'],
  [[8506, 3418], 0, 'negated'],
  [[11405, 429, 8506], 2, 'suspected'],
  [[13956, 8506], 1, 'suspected'],
  [[4346, 8506], 1, 'suspected'],
  [[20327, 8506], 1, 'suspected'],
  [[11405, 884, 8506], 2, 'suspected'],
  [[9181, 3998, 8506], 2, 'suspected'],
  [[25626, 8506], 1, 'suspected'],
  [[8506, 7074], 0, 'suspected'],
  [[4346, 8506], 1, 'suspected'],
  [[474, 727, 8506], 2, 'suspected']]}