Create Dataset

In [1]:
#Imports
import json
import pandas as pd
import logging
from sklearn.model_selection import train_test_split
from medcat.cdb import CDB
from medcat.config_rel_cat import ConfigRelCAT
from medcat.rel_cat import RelCAT
from medcat.utils.relation_extraction.base_component import BaseComponent_RelationExtraction
from medcat.utils.relation_extraction.bert.model import BaseModel_RelationExtraction
from medcat.utils.relation_extraction.bert.config import BaseConfig_RelationExtraction
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper_RelationExtraction

# Emit INFO-level log records so library progress messages show up in the cell outputs
logging.basicConfig(level=logging.INFO)

In [2]:
# Load MedCAT data: read the exported JSON once and keep it in `data` for the later cells
export_path = "../data/MedCAT_Export_With_Text_2025-09-11_08_19_37.json"
with open(export_path, mode="r", encoding="utf-8") as export_file:
    data = json.load(export_file)

n_projects = len(data["projects"])
print(f"Loaded data with {n_projects} projects")

Loaded data with 1 projects


In [5]:
# Create dataset: one row per (date annotation, non-date annotation) pair of every document
DATE_CUI = "410671006"  # annotations carrying this CUI are treated as dates


def mark_entity_pair(text, first_span, second_span):
    """Wrap two character spans of ``text`` in ADE-style relation markers.

    Args:
        text: The unmodified document text.
        first_span: ``(start, end)`` offsets wrapped in ``[s1]`` / ``[e1]``.
        second_span: ``(start, end)`` offsets wrapped in ``[s2]`` / ``[e2]``.

    Returns:
        A copy of ``text`` with the four tags inserted.

    All four tags are placed in a single left-to-right pass over the ORIGINAL
    offsets, so inserting one tag can never shift the position of another.
    For non-overlapping spans the result equals wrapping the later span first
    and the earlier span second; for overlapping or nested spans every tag
    still lands exactly at its own offset.
    """
    # sorted() is stable: tags sharing an offset keep the order listed here,
    # so adjacent spans yield "...[e1][s2]..." and an empty span "[s1][e1]".
    insertions = sorted(
        [
            (first_span[0], "[s1]"),
            (first_span[1], "[e1]"),
            (second_span[0], "[s2]"),
            (second_span[1], "[e2]"),
        ],
        key=lambda insertion: insertion[0],
    )

    pieces = []
    cursor = 0
    for offset, tag in insertions:
        pieces.append(text[cursor:offset])
        pieces.append(tag)
        cursor = offset
    pieces.append(text[cursor:])
    return "".join(pieces)


rows = []

for project in data["projects"]:
    for doc in project["documents"]:
        doc_id = doc["id"]
        text = doc["text"]

        # All annotations in this doc, keyed by id (a repeated id keeps the last entry)
        anns = {a["id"]: a for a in doc.get("annotations", [])}

        # Collect dates and non-dates
        dates = [a for a in anns.values() if a.get("cui") == DATE_CUI]
        others = [a for a in anns.values() if a.get("cui") != DATE_CUI]

        # Every relation listed for the document counts as a link, regardless of direction.
        # NOTE(review): the relation's type/name is not inspected here — confirm the export
        # only contains relations that should be labelled LINK.
        link_pairs = {
            frozenset((rel["start_entity"], rel["end_entity"]))
            for rel in doc.get("relations", [])
        }

        # Create entity–date pairs (ent1 is always the date, ent2 the other annotation)
        for date in dates:
            for ent in others:
                ent1, ent2 = date, ent

                ent1_id, ent2_id = ent1["id"], ent2["id"]
                ent1_val, ent2_val = ent1["value"], ent2["value"]
                ent1_cui, ent2_cui = ent1.get("cui"), ent2.get("cui")
                ent1_s, ent1_e = ent1.get("start"), ent1.get("end")
                ent2_s, ent2_e = ent2.get("start"), ent2.get("end")

                # Determine label
                if frozenset((ent1_id, ent2_id)) in link_pairs:
                    label, label_id = "LINK", 1
                else:
                    label, label_id = "NO_LINK", 0

                # Insert ADE-style markers into text. All four offsets are required:
                # a missing end offset would otherwise slice to the end of the text and
                # duplicate it, so such pairs keep the unmarked text instead.
                offsets = (ent1_s, ent1_e, ent2_s, ent2_e)
                if any(offset is None for offset in offsets):
                    marked = text
                elif ent1_s < ent2_s:
                    marked = mark_entity_pair(text, (ent1_s, ent1_e), (ent2_s, ent2_e))
                else:
                    # NOTE(review): tags follow text order, not role — here the entity is
                    # wrapped in [s1]/[e1] and the date in [s2]/[e2] although the ent1_*
                    # columns describe the date. Confirm this matches what the RelCAT CSV
                    # loader expects.
                    marked = mark_entity_pair(text, (ent2_s, ent2_e), (ent1_s, ent1_e))

                rows.append({
                    "relation_token_span_ids": None,
                    "ent1_ent2_start": (ent1_s, ent2_s),
                    "ent1": ent1_val,
                    "ent2": ent2_val,
                    "label": label,
                    "label_id": label_id,
                    "ent1_type": "DATE",
                    "ent2_type": "ENTITY",
                    "ent1_id": ent1_id,
                    "ent2_id": ent2_id,
                    "ent1_cui": ent1_cui,
                    "ent2_cui": ent2_cui,
                    "doc_id": doc_id,
                    "text": marked
                })

df = pd.DataFrame(rows)

In [None]:
# Inspect data: display the pair dataset (one row per date/entity pair) as the cell output
df

In [6]:
# Split the dataset: 80/20 row-level split, stratified on the label column, fixed seed.
# NOTE(review): rows are split individually, so pairs from the same doc_id can end up in
# both sets, while the inference cell later runs on whole documents looked up by doc_id —
# consider a document-level (grouped) split if test documents must be unseen in training.
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df["label"],
)

print(f"Train set: {len(train_df)} samples")
print(f"Test set: {len(test_df)} samples")
for heading, frame in (
    ("\nTrain label distribution:", train_df),
    ("\nTest label distribution:", test_df),
):
    print(heading)
    print(frame["label"].value_counts())

# Save train and test sets as tab-separated files without the index column
for frame, out_path in (
    (train_df, "../data/relations_date_entity_train.tsv"),
    (test_df, "../data/relations_date_entity_test.tsv"),
):
    frame.to_csv(out_path, sep="\t", index=False)

Train set: 1405 samples
Test set: 352 samples

Train label distribution:
label
NO_LINK    1375
LINK         30
Name: count, dtype: int64

Test label distribution:
label
NO_LINK    344
LINK         8
Name: count, dtype: int64


Training

In [7]:
# Configure RelCAT (settings grouped by config section)
config = ConfigRelCAT()

# -- general: logging, base model, context/window sizes, per-class sample cap
config.general.log_level = logging.INFO
config.general.model_name = "bert-base-uncased"
config.general.cntx_left = 15
config.general.cntx_right = 15
config.general.window_size = 300
config.general.limit_samples_per_class = 300

# -- model
config.model.hidden_size = 256
config.model.model_size = 2304
config.model.freeze_layers = False

# -- train: two classes (NO_LINK / LINK), optimizer and schedule settings
config.train.nclasses = 2
config.train.nepochs = 3
config.train.batch_size = 32
config.train.lr = 3e-5
config.train.adam_epsilon = 1e-8
config.train.adam_weight_decay = 0.0005

# Freshly constructed CDB; it is handed to RelCAT(...) when the model is assembled
cdb = CDB()

In [8]:
# Create tokenizer for the configured base model
tokenizer = BaseTokenizerWrapper_RelationExtraction.load(
    tokenizer_path=config.general.model_name,
    relcat_config=config,
)

# Register the four entity boundary markers used in the dataset text, plus the pad token
special_ent_tokens = ["[s1]", "[e1]", "[s2]", "[e2]"]
tokenizer.hf_tokenizers.add_tokens(special_ent_tokens, special_tokens=True)
tokenizer.hf_tokenizers.add_special_tokens({'pad_token': '[PAD]'})

# Record the marker strings and their token ids on the RelCAT config
annotation_tag_ids = tokenizer.hf_tokenizers.convert_tokens_to_ids(special_ent_tokens)
config.general.tokenizer_relation_annotation_special_tokens_tags = special_ent_tokens
config.general.annotation_schema_tag_ids = annotation_tag_ids

INFO:medcat.utils.relation_extraction.tokenizer:Attempted to load Tokenizer from path:bert-base-uncased, but it doesn't exist, loading default toknizer from model_name relcat_config.general.model_name:bert-base-uncased
INFO:medcat.utils.relation_extraction.tokenizer:Addeding special tokens to tokenizer:['[s1]', '[e1]', '[s2]', '[e2]'] {'pad_token': '[PAD]'}


In [9]:
# Initialize model

# Load model configuration for the configured base model
model_config = BaseConfig_RelationExtraction.load(
    pretrained_model_name_or_path=config.general.model_name,
    relcat_config=config,
)

# Update vocab size and padding id from the tokenizer (it was extended with marker tokens)
model_config.hf_model_config.vocab_size = tokenizer.get_size()
pad_token_id = tokenizer.get_pad_id()
config.model.padding_idx = pad_token_id
model_config.pad_token_id = pad_token_id

# Load model
model = BaseModel_RelationExtraction.load(
    pretrained_model_name_or_path=config.general.model_name,
    model_config=model_config,
    relcat_config=config,
)

# Resize embeddings to the tokenizer's current length
model.hf_model.resize_token_embeddings(len(tokenizer.hf_tokenizers))

# Create component and attach the objects built above
component = BaseComponent_RelationExtraction(tokenizer=tokenizer, config=config)
component.model = model
component.model_config = model_config
component.relcat_config = config
component.tokenizer = tokenizer

# Create RelCAT and swap in the prepared component
relCAT = RelCAT(cdb, config=config)
relCAT.component = component

You are using a model of type bert to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:medcat.utils.relation_extraction.config:Loaded config from : bert-base-uncased\model_config.json
INFO:medcat.utils.relation_extraction.models:RelCAT model config: PretrainedConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.55.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30526
}

INFO:medcat.utils.relation_extraction.bert.model:RelCAT model config: PretrainedConfig {
  "architectures": [
    "BertForMaske

In [10]:
# Train the model on the training split written by the split cell
train_csv_path = "../data/relations_date_entity_train.tsv"
checkpoint_path = "../models/relcat_models"
relCAT.train(train_csv_path=train_csv_path, checkpoint_path=checkpoint_path)

INFO:medcat.utils.relation_extraction.rel_dataset:CSV dataset | No. of relations detected:482| from : ../data/relations_date_entity_train.tsv | nclasses: 2 | idx2label: {0: 'NO_LINK', 1: 'LINK'}
INFO:medcat.utils.relation_extraction.rel_dataset:Samples per class: 
INFO:medcat.utils.relation_extraction.rel_dataset: label: NO_LINK | samples: 454
INFO:medcat.utils.relation_extraction.rel_dataset: label: LINK | samples: 28
INFO:root:Relations after train, test split :  train - 323 | test - 65
INFO:root: label: NO_LINK samples | train 300 | test 60
INFO:root: label: LINK samples | train 23 | test 5
INFO:root:Attempting to load RelCAT model on device: cpu
INFO:medcat.rel_cat:Starting training process...
INFO:medcat.rel_cat:Total epochs on this model: 3 | currently training epoch 0
100%|██████████| 323/323 [02:27<00:00,  2.20it/s]
Consider using tensor.detach() first. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\autograd\generated\python_variable_methods

In [11]:
# Save model: write the trained RelCAT to the directory that the inference section loads from
relCAT.save(save_path="../models/relcat_models")

Inference

In [12]:
# Load model: rebinds `relCAT` to the instance restored from disk.
# NOTE: the load log reports that no 'cdb.dat' exists at this path, so a CDB has to be
# set manually (rel_cat.cdb = CDB.load('path')) if one is needed.
relCAT = RelCAT.load("../models/relcat_models")

INFO:medcat.rel_cat:The default CDB file name 'cdb.dat' doesn't exist in the specified path, you will need to load & set                 a CDB manually via rel_cat.cdb = CDB.load('path') 
INFO:root:Loaded config.json
INFO:medcat.utils.relation_extraction.bert.config:Loaded config from file: ../models/relcat_models\model_config.json
INFO:medcat.utils.relation_extraction.tokenizer:Tokenizer loaded TokenizerWrapperBERT_RelationExtraction from:../models/relcat_models
INFO:medcat.utils.relation_extraction.models:RelCAT model config: BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "po

In [14]:
# Load test data back from the TSV written by the split cell
test_tsv_path = "../data/relations_date_entity_test.tsv"
test_df = pd.read_csv(test_tsv_path, sep="\t")
n_test_rows = len(test_df)
print(f"Loaded test data: {n_test_rows} samples")

Loaded test data: 352 samples


In [20]:
# Run inference on all test documents with error handling
all_predictions = []
failed_doc_ids = []  # documents that were missing from the export or raised during inference

test_doc_ids = test_df['doc_id'].unique()

# Index the exported documents by id once, instead of rescanning every project and
# document for each test doc id. The first document seen for an id wins.
documents_by_id = {}
for project in data["projects"]:
    for doc in project["documents"]:
        documents_by_id.setdefault(doc["id"], doc)

for doc_id in test_doc_ids:
    doc = documents_by_id.get(doc_id)
    if doc is None:
        print(f"Document {doc_id} not found in the loaded export")
        failed_doc_ids.append(doc_id)
        continue

    try:
        # Run inference on the full document text with its exported annotations
        output_doc_with_relations = relCAT.predict_text_with_anns(
            text=doc["text"],
            annotations=doc["annotations"]
        )

        # Collect results. The 'entity_label' / 'date' keys are kept for the cells that
        # read them; they hold ent1_text / ent2_text of whatever pair was scored.
        # NOTE(review): the recorded output shows pairs where neither side is a date,
        # so these key names may be misleading — confirm.
        doc_predictions = [
            {
                'entity_label': relation['ent1_text'],
                'date': relation['ent2_text'],
                'confidence': relation['confidence'],
                # assumes the predicted label is stored under 'relation' — TODO confirm;
                # None when the key is absent. Without it the confidence cannot be tied
                # to LINK vs NO_LINK.
                'relation': relation.get('relation'),
                'doc_id': doc_id
            }
            for relation in output_doc_with_relations._.relations
        ]
    except Exception as e:
        # Best-effort: report the failure, remember the document, and keep going
        print(f"Error processing document {doc_id}: {e}")
        failed_doc_ids.append(doc_id)
        continue

    all_predictions.extend(doc_predictions)

print(f"Processed {len(test_doc_ids)} test documents")
print(f"Total predictions: {len(all_predictions)}")
if failed_doc_ids:
    print(f"Documents without predictions ({len(failed_doc_ids)}): {list(failed_doc_ids)}")

INFO:medcat.rel_cat:total relations for doc: 183
INFO:medcat.rel_cat:processing...
100%|██████████| 183/183 [00:33<00:00,  5.50it/s]


Error processing document 26461: min() arg is an empty sequence
Error processing document 26462: min() arg is an empty sequence


INFO:medcat.rel_cat:total relations for doc: 213
INFO:medcat.rel_cat:processing...
100%|██████████| 213/213 [00:39<00:00,  5.35it/s]


Error processing document 26465: min() arg is an empty sequence
Processed 5 test documents
Total predictions: 396


In [21]:
# Show results: totals, a preview of the first predictions, then the high-confidence subset
print("Test Set Results:")
print(f"Total predictions: {len(all_predictions)}")

# Show first 10 predictions
print("\nFirst 10 predictions:")
for rank, pred in enumerate(all_predictions[:10], start=1):
    print(f"{rank}. {pred['entity_label']} -> {pred['date']} (conf: {pred['confidence']:.3f}) [doc: {pred['doc_id']}]")

# Show high confidence predictions (strictly above 0.7)
high_conf = []
for pred in all_predictions:
    if pred['confidence'] > 0.7:
        high_conf.append(pred)

print(f"\nHigh confidence predictions (>0.7): {len(high_conf)}")
for rank, pred in enumerate(high_conf[:5], start=1):  # Show first 5
    print(f"{rank}. {pred['entity_label']} -> {pred['date']} (conf: {pred['confidence']:.3f})")

Test Set Results:
Total predictions: 396

First 10 predictions:
1. laboratory findings -> current medication (conf: 0.971) [doc: 26464]
2. laboratory findings -> compliance (conf: 0.956) [doc: 26464]
3. laboratory findings -> Cardiology (conf: 0.962) [doc: 26464]
4. laboratory findings -> etiology (conf: 0.972) [doc: 26464]
5. laboratory findings -> reports (conf: 0.965) [doc: 26464]
6. laboratory findings -> stroke (conf: 0.960) [doc: 26464]
7. laboratory findings -> follow (conf: 0.959) [doc: 26464]
8. laboratory findings -> COPD (conf: 0.959) [doc: 26464]
9. Current medications -> multiple_sclerosis (conf: 0.963) [doc: 26464]
10. Current medications -> review of systems (conf: 0.957) [doc: 26464]

High confidence predictions (>0.7): 396
1. laboratory findings -> current medication (conf: 0.971)
2. laboratory findings -> compliance (conf: 0.956)
3. laboratory findings -> Cardiology (conf: 0.962)
4. laboratory findings -> etiology (conf: 0.972)
5. laboratory findings -> reports (conf: