# Script to Extract MedCAT UMLS Entities and Compute SapBERT Embeddings for Them

In [None]:
import json
from medcat.cat import CAT
import pandas as pd
from pathlib import Path
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel  
from tqdm import tqdm

In [2]:
# Name of this experiment (used as a suffix in the output file names)
experiment_name = "umls_large"

# Input files: labelled hallucination datasets in JSONL format
bioc_labelled_hallucinations_10_valid_mimic_summaries_path = (
    '/home/s_hegs02/MedTator/13_agreed_label_silver_validation_examples/'
    'hallucinations_10_valid_mimic_agreed.jsonl'
)
bioc_labelled_hallucinations_100_mimic_summaries_path = (
    '/home/s_hegs02/MedTator/12_agreed_label_silver_examples/'
    'hallucinations_100_mimic_agreed.jsonl'
)
# TODO: Replace with the agreed dataset
bioc_labelled_hallucinations_100_generated_summaries = (
    '/home/s_hegs02/MedTator/20_label_halus_qualitatative_annotator_1/'
    'hallucinations_100_generated_annotator_1.jsonl'
)

# Short dataset name -> path of its JSONL file
dataset_paths = {
    'valid_mimic': bioc_labelled_hallucinations_10_valid_mimic_summaries_path,
    'test_mimic': bioc_labelled_hallucinations_100_mimic_summaries_path,
    'test_generated': bioc_labelled_hallucinations_100_generated_summaries,
}

# Directory that receives the extracted entities (and later the embeddings)
entities_output_path = "/home/s_hegs02/mimic-iv-note-di-bhc/entities/"

# MedCAT model pack
# Small model: UMLS Small (a model pack containing a subset of UMLS: disorders, symptoms, medications, ...). Trained on MIMIC-III.
# cat_model_path = "/home/s_hegs02/medcat/models/umls_sm_pt2ch.zip"
# Large model: UMLS Full. >4MM concepts, trained self-supervised on MIMIC-III. v2022AA of UMLS.
cat_model_path = "/home/s_hegs02/medcat/models/umls_self_train_model.zip"

# Number of worker processes used for MedCAT entity extraction
num_cpus = 4

# Semantic types used in Griffin's "What's in a Summary" paper:
# the Disorders, Chemicals & Drugs and Procedures semantic groups, plus Lab Results.
# See groups here: https://lhncbc.nlm.nih.gov/ii/tools/MetaMap/Docs/SemGroups_2018.txt
_disorder_types = [
    'T020', 'T190', 'T049', 'T019', 'T047', 'T050',
    'T033', 'T037', 'T048', 'T191', 'T046', 'T184',
]
_chemical_and_drug_types = [
    'T116', 'T195', 'T123', 'T122', 'T103', 'T120', 'T104', 'T200', 'T196', 'T126',
    'T131', 'T125', 'T129', 'T130', 'T197', 'T114', 'T109', 'T121', 'T192', 'T127',
]
_procedure_types = ['T060', 'T065', 'T058', 'T059', 'T063', 'T062', 'T061']
_lab_result_types = ['T034']

# Flat list of every semantic type to keep; the group order above is preserved.
filtered_semantic_types = [
    *_disorder_types,
    *_chemical_and_drug_types,
    *_procedure_types,
    *_lab_result_types,
]


In [3]:
# Load dataset
def read_jsonl(path):
    """Read a JSON Lines file into a list of parsed objects.

    Args:
        path: Path of a file holding one JSON document per line.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
        An empty file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly so the result does not depend on the
    # platform's default locale encoding.
    with open(path, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line) instead of failing on them.
        return [json.loads(line) for line in f if line.strip()]

# Read every configured dataset into memory, keyed by its short name
datasets = {name: read_jsonl(file_path) for name, file_path in dataset_paths.items()}

# Sanity check: every label span must lie inside its summary and reproduce the labelled text
for dataset_name, dataset in datasets.items():
    for i, doc in enumerate(dataset):
        summary = doc['summary']
        for label in doc['labels']:
            start, end = label['start'], label['end']
            in_bounds = start >= 0 and end <= len(summary)
            assert in_bounds, f"Label {label} in dataset {dataset_name} is out of bounds for text of length {len(doc['summary'])} in document {i}"
            span_matches = summary[start:end] == label['text']
            assert span_matches, f"Label {label} in dataset {dataset_name} does not match text in document {i}"

In [None]:
# Extract MedCAT entities for all texts and summaries of every dataset and
# store them, filtered to the configured semantic types, as one JSON file per dataset.

# Load medcat model
cat = CAT.load_model_pack(cat_model_path)

# Set copy of the allowed semantic types for fast membership tests;
# the original list is still used where the reporting order matters.
allowed_semantic_types = set(filtered_semantic_types)

# Get entities for all texts and summaries in the datasets
for dataset_name, dataset in datasets.items():
    output_file = Path(entities_output_path) / f"medcat_entities_{dataset_name}_{experiment_name}.json"

    # Results are cached on disk: never recompute an existing file
    if output_file.exists():
        print(f"File {output_file} already exists. Skipping...")
        continue

    print(f"Extracting medcat entities for dataset {dataset_name}...")

    # Load input json as pandas dataframe
    df = pd.DataFrame(dataset)[['text', 'summary']]
    assert df.notnull().values.all()

    # Prepare input data to MedCat: flatten all cells column by column into
    # (running id, text) pairs. The id is what the results are keyed by below.
    in_data = list(enumerate(text for col in df.columns for text in df[col].values))

    # Perform concept extraction
    # out_data is a dictionary keyed by the running id; each value includes a
    # dictionary "entities" with all extracted entities for that text
    out_data = cat.multiprocessing(in_data, nproc=num_cpus)
    print(f'Total number of entities extracted: {sum([len(text["entities"]) for text in out_data.values()])}')

    # Count occurrences of the filtered semantic types in the extracted entities
    semantic_types_counts = {}
    for text in out_data.values():
        for entity in text['entities'].values():
            for semantic_type in entity['type_ids']:
                if semantic_type in allowed_semantic_types:
                    semantic_types_counts[semantic_type] = semantic_types_counts.get(semantic_type, 0) + 1
    print(f'Number of entities per inculded semantic type:')
    print({s: semantic_types_counts.get(s, 0) for s in filtered_semantic_types})

    # Keep only entities that carry at least one of the filtered semantic types
    for text in out_data.values():
        text['entities'] = {
            idx: entity
            for idx, entity in text['entities'].items()
            if not allowed_semantic_types.isdisjoint(entity['type_ids'])
        }
    print(f'Total number of entities extracted after filtering: {sum([len(text["entities"]) for text in out_data.values()])}')

    # Write back all extracted entities into the same format as the input.
    # Each column is replaced as a whole: the previous cell-by-cell
    # `df[col][j] = ...` was chained assignment, which pandas does not
    # guarantee to write through to `df`.
    num_docs = len(df)
    for col_pos, col in enumerate(df.columns):
        # Ids were handed out column by column, so column `col_pos` owns the
        # id range [col_pos * num_docs, (col_pos + 1) * num_docs).
        first_id = col_pos * num_docs
        df[col] = pd.Series(
            [[out_data[first_id + row]] for row in range(num_docs)],
            index=df.index,
            dtype=object,  # each cell holds a one-element list wrapping the result dict
        )

    # Save output to json (create the target directory first so that a long
    # extraction run is not lost to a missing folder)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    df.to_json(output_file, orient='records', indent=4)

In [5]:
# Load the SapBERT tokenizer and encoder from the same checkpoint,
# then move the encoder to the GPU
sapbert_checkpoint = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
tokenizer = AutoTokenizer.from_pretrained(sapbert_checkpoint)
model = AutoModel.from_pretrained(sapbert_checkpoint)
model = model.cuda()

In [6]:
# Create embeddings with SapBERT (Liu et al., 2021) for all extracted entities.
# For every dataset the MedCAT entity file is read, each entity gets a
# 'pretty_name_embedding' and a 'source_value_embedding', and the result is
# written to a separate JSON file.


def _iter_entities(entities_df):
    """Yield every entity dict: first those of all 'text' cells, then those of all 'summary' cells.

    Each cell is expected to be a one-element list whose item has an
    'entities' dictionary. The yielded dicts are the objects stored in the
    dataframe, so mutating them updates the dataframe content.
    """
    for col in ('text', 'summary'):
        for cell in entities_df[col].values:
            yield from cell[0]['entities'].values()


# Run the inputs on whatever device the model lives on instead of assuming 'cuda'
device = next(model.parameters()).device
model.eval()

# Get embeddings for all texts and summaries in the datasets
for dataset_name, dataset in datasets.items():
    output_file = Path(entities_output_path) / f"medcat_entities_sapbert_embeddings_{dataset_name}_{experiment_name}.json"

    # Results are cached on disk: never recompute an existing file
    if output_file.exists():
        print(f"File {output_file} already exists. Skipping...")
        continue

    print(f"Extracting sapbert embeddings for dataset {dataset_name}...")

    medcat_file = Path(entities_output_path) / f"medcat_entities_{dataset_name}_{experiment_name}.json"
    if not medcat_file.exists():
        print(f"File {medcat_file} does not exist. Skipping...")
        continue
    entities = pd.read_json(medcat_file)

    # Collect all entities once, in a fixed order (texts first, then summaries),
    # so the embeddings can be matched back to them by position.
    all_entities = list(_iter_entities(entities))
    total_num_entities = len(all_entities)
    all_pretty_names = [entity['pretty_name'] for entity in all_entities]
    all_source_values = [entity['source_value'] for entity in all_entities]
    print(f"Total number of entities: {total_num_entities}")

    # From: https://github.com/cambridgeltl/sapbert/blob/main/inference/inference_on_snomed.ipynb
    bs = 128
    # Embed both name variants in one pass: pretty names first, source values second
    all_names = all_pretty_names + all_source_values
    all_reps = []
    # Inference only: without no_grad an autograd graph would be kept for every batch
    with torch.no_grad():
        for i in tqdm(range(0, len(all_names), bs)):
            toks = tokenizer.batch_encode_plus(all_names[i:i+bs],
                                            padding="max_length",
                                            max_length=25,
                                            truncation=True,
                                            return_tensors="pt").to(device)
            output = model(**toks)
            # First-token ([CLS]) vector of the model's first output
            cls_rep = output[0][:, 0, :]

            all_reps.append(cls_rep.cpu().numpy())

    # np.concatenate fails on an empty list, so handle datasets without entities
    if all_reps:
        all_reps_emb = np.concatenate(all_reps, axis=0)
    else:
        all_reps_emb = np.empty((0, 0))

    assert len(all_reps_emb) == len(all_names), f"Number of embeddings {len(all_reps_emb)} does not match the number of names {len(all_names)}"

    # Split back into the two halves that were concatenated above
    all_pretty_names_emb = all_reps_emb[:total_num_entities]
    all_source_values_emb = all_reps_emb[total_num_entities:]

    # Save embeddings with entities (same traversal order as when the names were collected)
    for entity, pretty_name_emb, source_value_emb in zip(all_entities, all_pretty_names_emb, all_source_values_emb):
        entity['pretty_name_embedding'] = pretty_name_emb.tolist()
        entity['source_value_embedding'] = source_value_emb.tolist()

    # Save output to json
    entities.to_json(output_file, orient='records', indent=4)

Extracting sapbert embeddings for dataset valid_mimic...
Total number of entities: 519


  0%|          | 0/9 [00:00<?, ?it/s]

100%|██████████| 9/9 [00:00<00:00, 13.65it/s]


Extracting sapbert embeddings for dataset test_mimic...
Total number of entities: 4891


100%|██████████| 77/77 [00:01<00:00, 46.53it/s]


Extracting sapbert embeddings for dataset test_generated...
Total number of entities: 4638


100%|██████████| 73/73 [00:01<00:00, 46.18it/s]
