In [1]:
# add autoreload
%load_ext autoreload
%autoreload 2
import os
import sys
import json

import numpy as np
import pandas as pd
import scipy as sc

from collections import defaultdict
import re
import deduce

from tqdm import tqdm
import seaborn as sns

from gensim.models import phrases

import dill

# Context:
* About 100K echocardiographic reports are available.
* We want to extract diagnoses regarding left-ventricular function.
* We have 5,000 reports with labeled spans.

# Goal:
Train a "model" that can
1. identify the spans
2. classify the spans

# Approach: MedCAT - MetaCAT

## Two-step approach

* unsupervised training on the documents
* add a single custom entity with a custom identifier
* train a model to identify the custom entities
* supervised training on the meta-annotations of the entities

## One-step approach

* unsupervised training on the documents
* add custom entities based on the spans and their labels
* train a model to identify the custom entities

# Approach: biLSTM/transformer

## Two-step approach

* Train a model to identify the spans: self-supervision by randomly selecting non-span ranges as negative examples
* Train a model to classify the spans: supervised, based on the labeled spans
* Combine the models in one pipeline

## One-step approach
* Assign a label to each span
* Train a model to identify the spans

## Load MedCAT model pack

In [2]:
# Set environment variables from ../.env; the next cell reads `medcat_pack` from them.
from dotenv import load_dotenv
load_dotenv(dotenv_path='../.env')


True

In [3]:
from medcat.cat import CAT
from medcat.vocab import Vocab
from medcat.cdb import CDB
from medcat.config import Config
from medcat.meta_cat import MetaCAT

# Root folder holding the MedCAT model packs, taken from the `medcat_pack`
# environment variable (loaded from ../.env above). Fail here with a clear
# message rather than with a TypeError from os.path.join(None, ...) later on.
base_medcat_path = os.getenv('medcat_pack')
if not base_medcat_path:
    raise EnvironmentError("Environment variable 'medcat_pack' is not set (expected in ../.env)")
pack_location = 'umls-dutch-v1-10_echo'
# NOTE(review): `prep_medcat` is not referenced by any visible cell; presumably it
# was meant to gate the (re)training cells below — confirm or remove.
prep_medcat = False


  from tqdm.autonotebook import tqdm, trange





# Load texts

In [5]:
# Root folder of the echo research data. It can be overridden with the optional
# `echo_path` environment variable (e.g. in ../.env); by default it is the
# original shared-drive location.
echo_path = os.getenv('echo_path', 'T://lab_research/RES-Folder-UPOD/Echo_label/E_ResearchData/2_ResearchData')
# All reports, one JSON object per line (fields include text, _input_hash, _task_hash).
texts = pd.read_json(os.path.join(echo_path, 'outdb_140423.jsonl'), lines=True)

# Load train/test splits

In [6]:
# Report ids per split. Rows without an input hash are dropped: they cannot be
# matched to a report and would break the int conversion in the next cell.
train_indcs = (
    pd.read_csv(os.path.join(echo_path, 'train_echoid.csv'))
    .dropna(subset=['input_hash'])
)
test_indcs = (
    pd.read_csv(os.path.join(echo_path, 'test_echoid.csv'))
    .dropna(subset=['input_hash'])
)

In [7]:
def _as_int_set(frame, column):
    """Return the values of `frame[column]` as a set of Python ints."""
    return set(frame[column].astype(int).to_list())


# Lookup sets used to route reports / annotations to the train or test split.
TRAIN_INPUT_HASH = _as_int_set(train_indcs, 'input_hash')
TRAIN_TASK_HASH = _as_int_set(train_indcs, 'task_hash')

TEST_INPUT_HASH = _as_int_set(test_indcs, 'input_hash')
TEST_TASK_HASH = _as_int_set(test_indcs, 'task_hash')

In [8]:
# Restrict the full report table to each split, matching on the input hash.
texts_train = texts[texts['_input_hash'].isin(TRAIN_INPUT_HASH)]
texts_test = texts[texts['_input_hash'].isin(TEST_INPUT_HASH)]

## Make split data

* MedCAT train files
* class-dictionary with dataframes

In [9]:
def get_medcat_json(filename: str=None, HashSet: set=False, ClassName: str=None):
    """Convert a span-annotation JSONL file into a MedCAT trainer-export dict.

    Every selected row becomes one MedCAT document; every span in the row's
    ``spans`` list becomes an annotation whose meta-annotation ``ClassName``
    carries the span's ``label``.

    Parameters
    ----------
    filename : str or file-like
        JSONL file with ``text``, ``_input_hash``, ``_task_hash`` and ``spans``
        fields; passed unchanged to ``pd.read_json(..., lines=True)``.
    HashSet : set, optional
        Keep only rows whose ``_input_hash`` is in this set. ``False`` (the
        default) or ``None`` keeps every row.
    ClassName : str
        Used as the project name and as the meta-annotation category name.

    Returns
    -------
    dict
        ``{"projects": [{"name", "id", "cuis", "tuis", "documents"}]}``. A
        document's ``id`` is the row's position in the file, not its position
        among the selected rows.
    """
    texts = pd.read_json(filename, lines=True)
    # Decide "no filtering" up front: evaluating `x in False` raises TypeError,
    # which is what the old `(... in HashSet) | (HashSet is False)` test did.
    keep_all = HashSet is None or HashSet is False

    documents = []
    for position, (_, row) in enumerate(texts.iterrows()):
        input_hash = row["_input_hash"]
        if not keep_all and input_hash not in HashSet:
            continue
        txt = row['text']
        spans = row.get("spans")
        # Rows without spans come back from read_json as None/NaN: no annotations.
        if not isinstance(spans, list):
            spans = []
        annotations = []
        for j, ann in enumerate(spans):
            annotations.append({
                'user': 'BVE',
                'cui': 123,  # constant placeholder CUI for every span; the label goes into meta_anns
                'id': j,
                'start': ann['start'],
                'end': ann['end'],
                'value': txt[ann['start']:ann['end']],
                'validated': True,
                'correct': True,
                'deleted': False,
                'alternative': False,
                'killed': False,
                'meta_anns': {
                    ClassName: {
                        "name": ClassName,
                        "validated": True,
                        "accuracy": 1.0,
                        "value": ann['label'],
                    }
                },
            })
        documents.append({
            'id': position,
            'text': txt,
            'input_hash': input_hash,
            'task_hash': row["_task_hash"],
            'annotations': annotations,
        })

    return {"projects": [{
        "name": ClassName,
        "id": 42,
        "cuis": "",
        "tuis": "",
        "documents": documents,
    }]}


In [10]:
# For every per-class label file (skipping merged/old exports): keep the raw
# frame, write MedCAT train/test exports, and record their paths and label set.
span_label_dir = os.path.join(echo_path, 'echo_span_labels')
class_names = sorted(t for t in os.listdir(span_label_dir)
                     if not (('merged' in t) or ('old' in t)))
split_hashes = {'train': TRAIN_INPUT_HASH, 'test': TEST_INPUT_HASH}
class_ds = defaultdict(dict)

for file_name in class_names:
    _class = file_name.split(".")[0]
    fn = os.path.join(span_label_dir, file_name)
    class_ds[_class]['ds'] = pd.read_json(fn, lines=True)

    for split, hash_set in split_hashes.items():
        medcat_json = get_medcat_json(fn, hash_set, _class)
        out_dir = os.path.join(echo_path, 'medcat_labels', split)
        os.makedirs(out_dir, exist_ok=True)
        medcat_out = os.path.join(out_dir, f'medcat_{_class}.json')
        with open(medcat_out, 'w') as fh:
            json.dump(medcat_json, fh)
        class_ds[_class][f'{split}_location'] = medcat_out

    # All labels seen for this class (over both splits); rows without spans are skipped.
    class_ds[_class]['labels'] = {span_dict['label']
                                  for span_list in class_ds[_class]['ds']['spans'].tolist()
                                  if isinstance(span_list, list)
                                  for span_dict in span_list}


## Unsupervised learning for NER+L

In [None]:
# Unsupervised NER+L training of the base pack on the training-split reports,
# saved back under the same pack location.
pack_dir = os.path.join(base_medcat_path, pack_location)
MCAT = CAT.load_model_pack(pack_dir)

MCAT.train(texts_train.text.values,
           nepochs=3,
           progress_print=10,
           is_resumed=True)
# NOTE(review): confirm where create_model_pack writes (some MedCAT versions add a
# versioned sub-folder) so that the load above picks up the updated pack on re-runs.
MCAT.create_model_pack(pack_dir)

### Add label spans from Prodigy annotations

In [None]:
# Per class: the distinct surface strings of the labelled spans in the training
# split. These become the names of the custom concepts added below.
span_sets = defaultdict(set)
for _class, dsd in tqdm(class_ds.items()):
    ds = dsd['ds']
    ds = ds[ds._input_hash.isin(TRAIN_INPUT_HASH) & ds.spans.notna()]
    span_sets[_class] = {text[_span['start']:_span['end']]
                         for _spans, text in zip(ds.spans.values, ds.text.values)
                         for _span in _spans}

In [None]:
# TODO: would it help to also add span variants without the abbreviations?
# TODO: should paraphrased variants be included as well?

In [None]:
# Register every training span string as a name of its class's concept (the class
# name is used as the CUI). Sets of strings iterate in hash-seed-dependent order,
# so sort them to make the add/train sequence reproducible across runs.
for _class, span_set in tqdm(span_sets.items()):
    for _span in sorted(span_set):
        MCAT.add_and_train_concept(cui=_class,
                                   name=_span,
                                   do_add_concept=True,
                                   negative=False)

## Supervised learning for NER+L

In [None]:
# Supervised NER+L training on the trainer export stored inside the base pack;
# the result is saved as a new "<pack_location>V2" pack.
MCAT.train_supervised(data_path=os.path.join(base_medcat_path, pack_location,
                                             "input", "ner_l_anno", "trainer_export.json"),
                      nepochs=4,
                      print_stats=0,
                      use_filters=False)
medcat_path = os.path.join(base_medcat_path, f"{pack_location}V2")
MCAT.create_model_pack(medcat_path)

## Supervised learning of MetaCAT models

In [11]:
# Reload the supervised pack from disk so the MetaCAT cells below do not depend
# on the training cells above having run in this session.
medcat_path = os.path.join(base_medcat_path, pack_location + 'V2')
MCAT = CAT.load_model_pack(medcat_path)



In [13]:
from medcat.meta_cat import MetaCAT
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBPE, ByteLevelBPETokenizer

In [14]:
# Load the byte-level BPE tokenizer (per the original note, the one from the
# negation model) and store a wrapped copy in the model pack's assets.
# The language-model root can be overridden with the optional
# `clinical_lm_path` environment variable; it defaults to the original location.
lm_root = os.getenv(
    'clinical_lm_path',
    'T:/laupodteam/AIOS/Bram/language_modeling/Clinical_embeddings/bigrams/with_tokenizer/v2',
)
tok_folder = lm_root + '/tokenizer'
emb_folder = lm_root + '/SG'
tokenizer = ByteLevelBPETokenizer.from_file(os.path.join(tok_folder, 'vocab.json'),
                                            os.path.join(tok_folder, 'merges.txt'))
wrapped_tokenizer = TokenizerWrapperBPE(hf_tokenizers=tokenizer)
tokenizer_dir = medcat_path + "/assets/tokenizer"
os.makedirs(tokenizer_dir, exist_ok=True)
wrapped_tokenizer.save(tokenizer_dir)

In [15]:
from gensim.models import Word2Vec, KeyedVectors
# Skip-gram vectors (gensim KeyedVectors) stored as 'sg' inside the embedding folder.
vec_path = os.path.join(emb_folder, 'sg')
print(vec_path)  # records in the output which vectors were loaded
w2v = KeyedVectors.load(vec_path)

T:/laupodteam/AIOS/Bram/language_modeling/Clinical_embeddings/bigrams/with_tokenizer/v2/SG\sg


In [16]:
# Build the MetaCAT embedding matrix: row i is the word vector of BPE token i.
# Tokens without a vector get the mean of the vectors that WERE found. The mean is
# computed before filling the gaps, so it is neither biased by placeholders nor
# dependent on an unseeded random generator.
vocab_size = tokenizer.get_vocab_size()
emb_dim = w2v.vector_size
embeddings_array = np.zeros((vocab_size, emb_dim), dtype=np.float64)
words_not_present = []

for i in range(vocab_size):
    word = tokenizer.id_to_token(i)
    if word in w2v:
        embeddings_array[i] = w2v[word]
    else:
        words_not_present.append(i)

present = np.ones(vocab_size, dtype=bool)
if words_not_present:
    present[words_not_present] = False
if present.any():
    mean_vector = embeddings_array[present].mean(axis=0)
else:
    # No token has a vector: fall back to zeros rather than a NaN mean.
    mean_vector = np.zeros(emb_dim)
if words_not_present:
    embeddings_array[words_not_present] = mean_vector

# Save the embeddings in the model pack's assets (next to the tokenizer).
embeddings_path = medcat_path + "/assets/embeddings/embedding.npy"
os.makedirs(os.path.dirname(embeddings_path), exist_ok=True)
np.save(embeddings_path, embeddings_array)

print(f"Words not present:{len(words_not_present)}")

Words not present:2807


In [17]:
def _register_metacat_in_model_card(model_pack_path, category_name, label_dict, model_name):
    """Add, or replace, the MetaCAT entry for `category_name` in the pack's model_card.json.

    Replacing an existing entry instead of always appending keeps the card free
    of duplicates when this cell is re-run.
    """
    card_path = os.path.join(model_pack_path, 'model_card.json')
    with open(card_path, 'r') as fh:
        model_card = json.load(fh)
    entries = [entry for entry in model_card.get("MetaCAT models", [])
               if entry.get("Category Name") != category_name]
    entries.append({
        "Category Name": category_name,
        "Description": "No description",
        "Classes": label_dict,
        "Model": model_name,
    })
    model_card["MetaCAT models"] = entries
    with open(card_path, 'w') as fh:
        json.dump(model_card, fh)


# Train one biLSTM MetaCAT classifier per label category (except 'normal'), save it
# inside the model pack as meta_<category>, and register it in model_card.json.
for _class, d in class_ds.items():
    if _class == 'normal':
        continue
    print(f"Commencing training of biLSTM-span for {_class}...")
    config_metacat = ConfigMetaCAT()
    config_metacat.general['category_name'] = _class
    config_metacat.train['nepochs'] = 25
    config_metacat.train['score_average'] = 'weighted'
    config_metacat.model['hidden_size'] = 256
    # Must equal the width of the embedding matrix built above.
    config_metacat.model['input_size'] = embeddings_array.shape[1]
    config_metacat.model['dropout'] = 0.3
    config_metacat.model['num_layers'] = 3
    config_metacat.model['num_directions'] = 2
    config_metacat.model['nclasses'] = len(d['labels'])
    config_metacat.model['model_name'] = 'lstm'

    meta_cat = MetaCAT(tokenizer=wrapped_tokenizer,
                       embeddings=embeddings_array,
                       config=config_metacat)

    train_path = d['train_location']
    model_path = os.path.join(medcat_path, f"meta_{_class}")
    os.makedirs(model_path, exist_ok=True)

    meta_cat.train(json_path=train_path,
                   save_dir_path=model_path)
    meta_cat.save(save_dir_path=model_path)

    # Manually register the model in the pack's card; the label -> id mapping is
    # presumably populated by training above.
    label_dict = config_metacat.general.category_value2id
    _register_metacat_in_model_card(medcat_path, _class, label_dict,
                                    config_metacat.model['model_name'])

Commencing training of biLSTM-span for aortic_regurgitation...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Commencing training of biLSTM-span for aortic_stenosis...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Commencing training of biLSTM-span for diastolic_dysfunction...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Commencing training of biLSTM-span for lv_dil...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Commencing training of biLSTM-span for lv_syst_func...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Commencing training of biLSTM-span for mitral_regurgitation...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Commencing training of biLSTM-span for pe...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Commencing training of biLSTM-span for rv_dil...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Commencing training of biLSTM-span for rv_syst_func...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Commencing training of biLSTM-span for tricuspid_regurgitation...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Commencing training of biLSTM-span for wma...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


## Load new model pack

In [18]:
# Fresh load of the pack, presumably including the meta_* models written above.
MCATnew = CAT.load_model_pack(medcat_path)

  for entry_point in AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, []):
  for entry_point in AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, []):


## Apply to texts

In [19]:
from spacy import displacy

In [20]:
# Render the entities found in one hand-picked example report.
i = 232
example_text = texts.text.values[i]
doc = MCATnew(example_text)
displacy.render(doc, style='ent')

  from IPython.core.display import HTML, display


In [21]:
def create_dict_with_conf(ent):
    """Flatten an entity's meta-annotations into ``{category: value, 'conf_<category>': confidence}``.

    Keys are emitted in meta-annotation order, each value followed by its confidence.
    """
    flattened = {}
    for category, annotation in ent._.meta_anns.items():
        flattened.update({
            category: annotation['value'],
            f'conf_{category}': annotation['confidence'],
        })
    return flattened

In [22]:

# Tabulate meta-annotation values and confidences per entity of the example
# report (`i` is set in the display cell above).
doc = MCATnew(texts.text.values[i])
inds = [ent.text for ent in doc.ents]
res = [create_dict_with_conf(ent) for ent in doc.ents]
res_df = pd.DataFrame(res, index=inds)

## Evaluate

In [23]:
# Evaluate each saved MetaCAT model on its test export and collect the metrics
# together with the gold label frequencies in that export.
TEST_RESULTS = {}
for k in class_ds.keys():
    if k == 'normal':
        continue
    CATmodel = MetaCAT.load(os.path.join(medcat_path, f'meta_{k}'))
    test_location = class_ds[k]['test_location']
    with open(test_location, 'r') as fh:
        test_labels = json.load(fh)
    bulk_res = CATmodel.eval(test_location)

    # Gold label counts for category `k` over all test annotations.
    res_count = defaultdict(int)
    for doc in test_labels['projects'][0]['documents']:
        for ann in doc['annotations']:
            res_count[ann['meta_anns'][k]['value']] += 1

    TEST_RESULTS[k] = {
        'f1': bulk_res['f1'],
        'precision': bulk_res['precision'],
        'recall': bulk_res['recall'],
        'confusion_df': bulk_res['confusion matrix'],
        'real_presence': res_count,
    }

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [24]:
# Persist the per-category test metrics.
artifact_path = "../artifacts/MetaCAT_test_results.pkl"
os.makedirs(os.path.dirname(artifact_path), exist_ok=True)
with open(artifact_path, "wb") as fh:
    dill.dump(TEST_RESULTS, fh)
# To reload without re-running the evaluation (trusted files only — unpickling can execute code):
# with open(artifact_path, "rb") as fh:
#     TEST_RESULTS = dill.load(fh)

## Assessment of performance in the wild

In [25]:
# MedCAT-style documents from the merged (all-class) label file, per split; the
# meta-annotation category is called 'merged'. Passing the path (instead of an
# open() handle) lets pandas open and close the file itself.
merged_labels_path = os.path.join(echo_path, 'echo_span_labels', 'merged_labels.jsonl')
merged_labels_train = get_medcat_json(merged_labels_path,
                                      TRAIN_INPUT_HASH, 'merged')['projects'][0]['documents']

merged_labels_test = get_medcat_json(merged_labels_path,
                                     TEST_INPUT_HASH, 'merged')['projects'][0]['documents']

See https://numpy.org/devdocs/release/1.25.0-notes.html and the docs for more information.  (Deprecated NumPy 1.25)
  return np.find_common_type(types, [])
See https://numpy.org/devdocs/release/1.25.0-notes.html and the docs for more information.  (Deprecated NumPy 1.25)
  return np.find_common_type(types, [])


In [26]:
# Run the full pipeline on every test document and attach the predictions, keyed
# by (start_char, end_char), to the document dict itself.
comparison_list = []
for medcat_doc in tqdm(merged_labels_test):
    parsed_doc = MCATnew(medcat_doc['text'])
    medcat_doc['predicted'] = {
        (ent.start_char, ent.end_char): create_dict_with_conf(ent)
        for ent in parsed_doc.ents
    }
    comparison_list.append(medcat_doc)
comparison_list_with_annotations_present = [d for d in comparison_list if d['annotations']]

100%|██████████| 991/991 [02:56<00:00,  5.62it/s]


In [30]:
def get_pred_list(pred_dict: dict=None):
    """Return the predicted values from a ``create_dict_with_conf`` dict.

    Confidence entries (keys starting with ``conf_``) are dropped; a category
    whose name merely *contains* ``conf_`` is kept. ``None`` or an empty dict
    yields an empty list.
    """
    if not pred_dict:
        return []
    return [value for key, value in pred_dict.items() if not key.startswith('conf_')]
    
# For each test document with gold spans, pair the gold (span, label) list with
# the predicted (span, [labels]) list for side-by-side comparison.
span_suggester_comparison_list = []
for d in comparison_list_with_annotations_present:
    gold = [((ann['start'], ann['end']), ann['meta_anns']['merged']['value'])
            for ann in d['annotations']]
    predicted = [(span, get_pred_list(pred)) for span, pred in d['predicted'].items()]
    span_suggester_comparison_list.append([gold, predicted])

In [31]:
# TODO: measure span coverage, i.e. count gold spans that are (partially) overlapped by a predicted span.
# TODO: for covered spans, measure the token overlap (e.g. Jaccard index).
# Inspect the first document: [gold (span, label) pairs, predicted (span, [labels]) pairs].
span_suggester_comparison_list[0]

[[((37, 69), 'tricuspid_valve_native_regurgitation_not_present'),
  ((13, 35), 'rv_sys_func_normal'),
  ((13, 35), 'lv_sys_func_normal'),
  ((37, 69), 'aortic_valve_native_regurgitation_not_present'),
  ((37, 69), 'aortic_valve_native_regurgitation_not_present'),
  ((37, 69), 'mitral_valve_native_regurgitation_not_present'),
  ((37, 70), 'aortic_valve_native_stenosis_not_present')],
 [((13, 36),
   ['aortic_valve_native_regurgitation_not_present',
    'aortic_valve_native_stenosis_not_present',
    'lv_dias_func_normal',
    'lv_dil_normal',
    'lv_sys_func_normal',
    'mitral_valve_native_regurgitation_not_present',
    'not negated',
    'pe',
    'rv_dil_present',
    'rv_sys_func_normal',
    'tricuspid_valve_native_regurgitation_not_present',
    'wma_not_present']),
  ((37, 70),
   ['aortic_valve_native_regurgitation_not_present',
    'aortic_valve_native_stenosis_not_present',
    'lv_dias_func_normal',
    'lv_dil_normal',
    'lv_sys_func_normal',
    'mitral_valve_native_re