Imports

In [1]:
#Imports
import logging
import json
import pandas as pd

from medcat.cdb import CDB
from medcat.config_rel_cat import ConfigRelCAT
from medcat.rel_cat import RelCAT
from medcat.utils.relation_extraction.base_component import BaseComponent_RelationExtraction
from medcat.utils.relation_extraction.bert.model import BaseModel_RelationExtraction
from medcat.utils.relation_extraction.bert.config import BaseConfig_RelationExtraction
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper_RelationExtraction

import sys, os
utils_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'utils'))
if utils_path not in sys.path:
    sys.path.insert(0, utils_path)

from general_utils import load_data

Data Loading

In [2]:
# Load data: read the synthetic training notes through the project helper
# `load_data` (imported from the `general_utils` module found via the ../utils
# path added to sys.path) and report how many records came back.
df = load_data("../data/training_dataset_synthetic.csv")
print(f"Loaded {len(df)} records")

Loaded 101 records


In [3]:
# Inspect df: preview the first rows to check the columns used below
# (doc_id, note_text and the *_json columns) are present.
df.head()

Unnamed: 0,doc_id,note_text,entities_json,dates_json,relations_json,relative_dates_json
0,0,Ultrasound (30nd Jun 2024): no significant fin...,"[{'id': 300001, 'value': 'asthma', 'cui': 'pla...","[{'id': 308001, 'value': '02nd Aug 2024', 'sta...","[{'date': '02nd Aug 2024', 'entity': 'asthma',...",[]
1,1,Labs (27th Sep 2024): anemia. resolving Skin:...,"[{'id': 300001, 'value': 'multiple_sclerosis',...","[{'id': 308001, 'value': '27th Sep 2024', 'sta...","[{'date': '27th Sep 2024', 'entity': 'multiple...",[]
2,2,URGENT REVIEW (2024-10-04): cough. suspect ost...,"[{'id': 300001, 'value': 'osteoarthritis', 'cu...","[{'id': 308001, 'value': '2024-10-04', 'start'...","[{'date': '2024-10-04', 'entity': 'osteoarthri...",[]
3,3,URGENT REVIEW (13rd Feb 2025) MRI of the brain...,"[{'id': 300001, 'value': 'schizophrenia', 'cui...","[{'id': 308001, 'value': '13rd Feb 2025', 'sta...","[{'date': '13rd Feb 2025', 'entity': 'schizoph...",[]
4,4,New pt((18/11/24)): pt presents with nausea/vo...,"[{'id': 300001, 'value': 'diabetes_mellitus', ...","[{'id': 308001, 'value': '18/11/24', 'start': ...","[{'date': '18/11/24', 'entity': 'diabetes_mell...",[]


Create Training Dataset

In [4]:
# Marker insertion helper
def insert_marker(txt, start, end, tag_open, tag_close):
    """Return ``txt`` with ``txt[start:end]`` wrapped in ``tag_open`` / ``tag_close``.

    Args:
        txt: The source string.
        start: Offset at which the wrapped span begins.
        end: Offset at which the wrapped span ends (exclusive).
        tag_open: Text placed immediately before the span.
        tag_close: Text placed immediately after the span.

    Returns:
        A new string; ``txt`` itself is not modified.
    """
    before, span, after = txt[:start], txt[start:end], txt[end:]
    return "".join((before, tag_open, span, tag_close, after))

In [5]:
# Create dataset for training
#
# For every document, pair each date (absolute and relative) with each entity,
# label the pair RELATION / NO_RELATION from the document's relations, and wrap
# both spans in marker tags inside the note text. One output row per pair.
def _as_list(value):
    """Return ``value`` as-is when it is already a list, otherwise decode it as a JSON string."""
    return value if isinstance(value, list) else json.loads(value)


def _pair_key(date_id, entity_id):
    """Build the order-insensitive key used to look a date-entity pair up in the relations."""
    return tuple(sorted([date_id, entity_id]))


rows = []

for _, row in df.iterrows():
    doc_id = row["doc_id"]
    text = row["note_text"]

    # Parse the JSON columns - handle both string and already-parsed cases
    entities = _as_list(row["entities_json"])
    dates = _as_list(row["dates_json"])
    # Optional column: test for its presence BEFORE indexing the row; indexing first
    # raises KeyError on a missing column and the empty-list fallback is never reached
    relative_dates = _as_list(row["relative_dates_json"]) if "relative_dates_json" in row else []

    # Combine absolute and relative dates
    all_dates = dates + relative_dates

    # Get original relations for labeling
    relations = _as_list(row["relations_json"])

    # Build relation pairs from validated relations - match by IDs (compared as strings)
    relation_pairs = {
        _pair_key(str(relation["date_id"]), str(relation["entity_id"]))
        for relation in relations
    }

    # Create pairs for all date-entity combinations
    for date in all_dates:
        for entity in entities:
            # Use IDs for matching
            date_id = str(date["id"])
            entity_id = str(entity["id"])

            # Determine label
            if _pair_key(date_id, entity_id) in relation_pairs:
                label, label_id = "RELATION", 1
            else:
                label, label_id = "NO_RELATION", 0

            # Insert markers. s1/e1 are the date offsets, s2/e2 the entity offsets.
            s1, e1 = date.get("start"), date.get("end")
            s2, e2 = entity.get("start"), entity.get("end")

            if s1 is not None and s2 is not None:
                # The tags follow text order, not role: whichever span starts first gets
                # [s1]/[e1] and the other gets [s2]/[e2]. The later span is always wrapped
                # first so the offsets of the earlier span are still valid afterwards.
                if s1 < s2:
                    marked = insert_marker(text, s2, e2, "[s2]", "[e2]")
                    marked = insert_marker(marked, s1, e1, "[s1]", "[e1]")
                else:
                    marked = insert_marker(text, s1, e1, "[s2]", "[e2]")
                    marked = insert_marker(marked, s2, e2, "[s1]", "[e1]")
            else:
                # No offsets available: keep the note text unmarked
                marked = text

            rows.append({
                "relation_token_span_ids": None,
                "ent1_ent2_start": (s1, s2),  # offsets in the ORIGINAL (unmarked) text
                "ent1": date.get("value", ""),
                "ent2": entity.get("value", ""),
                "label": label,
                "label_id": label_id,
                "ent1_type": "DATE",
                "ent2_type": "ENTITY",
                "ent1_id": date_id,
                "ent2_id": entity_id,
                "ent1_cui": None,
                "ent2_cui": None,
                "doc_id": doc_id,
                "text": marked
            })

training_df = pd.DataFrame(rows)
print(f"Created {len(training_df)} date-entity pairs (including relative dates)")

Created 1242 date-entity pairs (including relative dates)


In [6]:
# Inspect data: sample count, class balance and a preview of the generated pairs
print(f"Dataset: {len(training_df)} samples")
print("\nLabel distribution:")
print(training_df["label"].value_counts())

# Last expression of the cell, so the preview is rendered as the cell output
training_df.head()

Dataset: 1242 samples

Label distribution:
label
NO_RELATION    1060
RELATION        182
Name: count, dtype: int64


Unnamed: 0,relation_token_span_ids,ent1_ent2_start,ent1,ent2,label,label_id,ent1_type,ent2_type,ent1_id,ent2_id,ent1_cui,ent2_cui,doc_id,text
0,,"(311, 57)",02nd Aug 2024,asthma,RELATION,1,DATE,ENTITY,308001,300001,,,0,Ultrasound (30nd Jun 2024): no significant fin...
1,,"(311, 410)",02nd Aug 2024,pituitary_adenoma,NO_RELATION,0,DATE,ENTITY,308001,300002,,,0,Ultrasound (30nd Jun 2024): no significant fin...
2,,"(311, 491)",02nd Aug 2024,rheumatoid_arthritis,NO_RELATION,0,DATE,ENTITY,308001,300003,,,0,Ultrasound (30nd Jun 2024): no significant fin...
3,,"(311, 1143)",02nd Aug 2024,pneumonia,NO_RELATION,0,DATE,ENTITY,308001,300004,,,0,Ultrasound (30nd Jun 2024): no significant fin...
4,,"(311, 1305)",02nd Aug 2024,gerd,NO_RELATION,0,DATE,ENTITY,308001,300005,,,0,Ultrasound (30nd Jun 2024): no significant fin...


In [7]:
# Per-document summary from original data
def _count_items(row, column, optional=False):
    """Return the number of items in the list stored in ``row[column]``.

    The cell value may be an already-parsed list or a JSON string. When
    ``optional`` is true and the column is absent, 0 is returned; presence is
    checked before the row is indexed so a missing column cannot raise KeyError.
    """
    if optional and column not in row:
        return 0
    value = row[column]
    if not isinstance(value, list):
        value = json.loads(value)
    return len(value)


# One record per document with the size of each annotation list
summary = [
    {
        "doc_id": row["doc_id"],
        "n_entities": _count_items(row, "entities_json"),
        "n_dates": _count_items(row, "dates_json"),
        "n_relative_dates": _count_items(row, "relative_dates_json", optional=True),
        "n_relations": _count_items(row, "relations_json"),
    }
    for _, row in df.iterrows()
]

doc_level = pd.DataFrame(summary)
doc_level

Unnamed: 0,doc_id,n_entities,n_dates,n_relative_dates,n_relations
0,0,6,2,0,2
1,1,6,2,0,3
2,2,9,2,0,3
3,3,6,2,0,2
4,4,6,2,0,3
...,...,...,...,...,...
96,96,6,2,0,2
97,97,6,2,0,2
98,98,6,2,0,2
99,99,6,2,0,1


In [8]:
# Per-document pair statistics computed from the generated training frame:
# number of date-entity pairs, how many are labelled RELATION, and that share in percent.
pair_stats = (
    training_df
    .groupby("doc_id")
    .agg(
        total_pairs=("label", "size"),
        relation_pairs=("label", lambda labels: labels.eq("RELATION").sum()),
    )
    .reset_index()
    .assign(
        relation_pct=lambda stats: (100 * stats["relation_pairs"] / stats["total_pairs"]).round(1)
    )
)
pair_stats

Unnamed: 0,doc_id,total_pairs,relation_pairs,relation_pct
0,0,12,2,16.7
1,1,12,3,25.0
2,2,18,3,16.7
3,3,12,2,16.7
4,4,12,3,25.0
...,...,...,...,...
96,96,12,2,16.7
97,97,12,2,16.7
98,98,12,2,16.7
99,99,12,1,8.3


In [9]:
# Merge stats with doc_level summary
# Columns brought in from pair_stats. They are dropped from doc_level first so that
# re-running this cell does not produce suffixed duplicates (e.g. total_pairs_x / total_pairs_y);
# on a first run the drop is a no-op.
pair_columns = ["total_pairs", "relation_pct"]
doc_level = (
    doc_level.drop(columns=pair_columns, errors="ignore")
             .merge(pair_stats[["doc_id"] + pair_columns], on="doc_id", how="left")
             .sort_values("doc_id")
             .reset_index(drop=True)
)

# Calculate additional metrics over all generated pairs
is_relation = training_df["label"] == "RELATION"
total_pairs_overall = len(training_df)
relation_total = is_relation.sum()
relation_pct_overall = 100 * is_relation.mean()

print(f"total number of date-entity pairs is {total_pairs_overall}")
print(f"total number of relations: {relation_total}")
print(f"percentage positive class: {relation_pct_overall:.1f}%")

#Look at overall doc level summary
doc_level

total number of date-entity pairs is 1242
total number of relations: 182
percentage positive class: 14.7%


Unnamed: 0,doc_id,n_entities,n_dates,n_relative_dates,n_relations,total_pairs,relation_pct
0,0,6,2,0,2,12,16.7
1,1,6,2,0,3,12,25.0
2,2,9,2,0,3,18,16.7
3,3,6,2,0,2,12,16.7
4,4,6,2,0,3,12,25.0
...,...,...,...,...,...,...,...
96,96,6,2,0,2,12,16.7
97,97,6,2,0,2,12,16.7
98,98,6,2,0,2,12,16.7
99,99,6,2,0,1,12,8.3


In [10]:
# Save training dataset as a tab-separated file, without the DataFrame index.
# The training cell further down passes this same path as `train_csv_path`.
training_df.to_csv("../data/relcat_training_data.tsv", sep="\t", index=False)

Training & Model Config

In [11]:
#Choose model to use - any BERT model from HuggingFace can be used, see: https://huggingface.co/google-bert
# NOTE(review): `model` holds the model *name* here, but the later "Load model" cell rebinds
# it to the loaded model object. Re-running the config cell after that point would store the
# object, not the name, in config.general.model_name. Consider a distinct name for the string.
model = "bert-base-uncased"

In [12]:
# Set path to save trained model and checkpoints to; used as `checkpoint_path`
# when training and as `save_path` when saving the final model
model_save_path = '../models/relcat_models'

In [13]:
# Create RelCAT config and set the general parameters
config = ConfigRelCAT()
config.general.log_level = logging.INFO
# `model` is the HuggingFace model name string chosen above
config.general.model_name = model
# Optional: uncomment to configure the root logger at INFO level as well
#logging.basicConfig(level=logging.INFO)

In [14]:
# Hidden size and model size for the RelCAT model config
# (no hidden-layer count is set in this cell)
# NOTE(review): 2304 equals 3 x 768, and 768 is the BERT hidden size shown in the model
# config output further down. Presumably model_size must track the base model's hidden
# size; confirm against the MedCAT RelCAT documentation before switching models.
config.model.hidden_size= 256
config.model.model_size = 2304 # 4096 for llama

In [None]:
# Further config: context/window sizes, training hyper-parameters and class weighting
config.general.cntx_left = 15
config.general.cntx_right = 15
config.general.window_size = 400
config.train.nclasses = 2  # the dataset has two labels: NO_RELATION and RELATION
config.train.nepochs = 3
config.model.freeze_layers = False
config.general.limit_samples_per_class = 300
config.train.batch_size = 32
config.train.lr = 3e-5
config.train.adam_epsilon = 1e-8
config.train.adam_weight_decay = 0.0005
# NOTE(review): hard-coded weights, presumably ordered by label_id (NO_RELATION, RELATION).
# How these values were derived is not shown in this notebook - TODO document or compute them.
config.train.class_weights = [0.3027, 1.6973]
config.train.enable_class_weights = True

In [16]:
# Create a default-constructed CDB; it is only passed to the RelCAT constructor below
cdb = CDB()

In [17]:
# Create the tokenizer wrapper, loaded by the model name stored in the RelCAT config
tokenizer = BaseTokenizerWrapper_RelationExtraction.load(tokenizer_path=config.general.model_name,
                                                       relcat_config=config)

In [18]:
# Add special tokens: the four marker tags below are the same tags that were
# inserted around date/entity spans when the training text was built
special_ent_tokens = ["[s1]", "[e1]", "[s2]", "[e2]"]
tokenizer.hf_tokenizers.add_tokens(special_ent_tokens, special_tokens=True)
# Register the pad token explicitly
tokenizer.hf_tokenizers.add_special_tokens({'pad_token': '[PAD]'}) # used in llama tokenizer

0

In [19]:
# Add tokens to config: store the marker tags and the ids the tokenizer assigned to them
config.general.tokenizer_relation_annotation_special_tokens_tags = special_ent_tokens
config.general.annotation_schema_tag_ids = tokenizer.hf_tokenizers.convert_tokens_to_ids(special_ent_tokens)

In [20]:
# Create RelCAT object from the CDB and the config built above
relCAT = RelCAT(cdb, config=config)

INFO:medcat.utils.relation_extraction.base_component:BaseComponent_RelationExtraction initialized


In [21]:
# Load model configuration for the model name stored in the RelCAT config
model_config = BaseConfig_RelationExtraction.load(pretrained_model_name_or_path=config.general.model_name,
                                                 relcat_config=config)

You are using a model of type bert to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:medcat.utils.relation_extraction.config:Loaded config from : bert-base-uncased\model_config.json


In [22]:
# Update vocab size in model config to match the tokenizer
# (marker tokens were added to the tokenizer in an earlier cell)
model_config.hf_model_config.vocab_size = tokenizer.get_size()

In [23]:
# Set the padding idx in both the model config and the RelCAT config (chained assignment);
# this is necessary as it depends on which tokenizer you use
config.model.padding_idx = model_config.pad_token_id = tokenizer.get_pad_id()

In [24]:
# Load model
# NOTE: this rebinds `model`, which until now held the model-name string,
# to the loaded model object; later cells use it as the model.
model = BaseModel_RelationExtraction.load(pretrained_model_name_or_path=config.general.model_name,
                                         model_config=model_config,
                                         relcat_config=config)

INFO:medcat.utils.relation_extraction.models:RelCAT model config: PretrainedConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.55.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30526
}

INFO:medcat.utils.relation_extraction.bert.model:RelCAT model config: PretrainedConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  

In [25]:
# Resize the model's token embeddings to the tokenizer length,
# which changed when the marker tokens were added
model.hf_model.resize_token_embeddings(len(tokenizer.hf_tokenizers))

Embedding(30526, 768, padding_idx=0)

In [26]:
# Create RelCAT component, attach the loaded model, its config, the RelCAT config and the
# tokenizer to it, then hand the component to the RelCAT object
component = BaseComponent_RelationExtraction(tokenizer=tokenizer, config=config)
component.model = model
component.model_config = model_config
component.relcat_config = config
# tokenizer was already passed to the constructor above; it is assigned again here
component.tokenizer = tokenizer
relCAT.component = component

INFO:medcat.utils.relation_extraction.base_component:BaseComponent_RelationExtraction initialized


Training

In [27]:
# Train the model using the dataset we created earlier (the TSV written by the save cell);
# `model_save_path` is passed as the checkpoint location
relCAT.train(
    train_csv_path="../data/relcat_training_data.tsv",  
    checkpoint_path=model_save_path
)

INFO:medcat.utils.relation_extraction.rel_dataset:CSV dataset | No. of relations detected:514| from : ../data/relcat_training_data.tsv | nclasses: 2 | idx2label: {0: 'NO_RELATION', 1: 'RELATION'}
INFO:medcat.utils.relation_extraction.rel_dataset:Samples per class: 
INFO:medcat.utils.relation_extraction.rel_dataset: label: NO_RELATION | samples: 395
INFO:medcat.utils.relation_extraction.rel_dataset: label: RELATION | samples: 119
INFO:root:Relations after train, test split :  train - 396 | test - 83
INFO:root: label: RELATION samples | train 96 | test 23
INFO:root: label: NO_RELATION samples | train 300 | test 60
INFO:root:Attempting to load RelCAT model on device: cpu
INFO:medcat.rel_cat:Starting training process...
INFO:medcat.rel_cat:Total epochs on this model: 3 | currently training epoch 0
100%|██████████| 396/396 [04:41<00:00,  1.41it/s]
Consider using tensor.detach() first. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\autograd\generated\pyth

In [28]:
# Save model to the same directory that was passed as the checkpoint path during training
relCAT.save(save_path=model_save_path)