Create Dataset

In [None]:
#Imports
import logging
import json
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

from medcat.cdb import CDB
from medcat.config_rel_cat import ConfigRelCAT
from medcat.rel_cat import RelCAT
from medcat.utils.relation_extraction.base_component import BaseComponent_RelationExtraction
from medcat.utils.relation_extraction.bert.model import BaseModel_RelationExtraction
from medcat.utils.relation_extraction.bert.config import BaseConfig_RelationExtraction
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper_RelationExtraction

import sys, os
utils_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'utils'))
if utils_path not in sys.path:
    sys.path.insert(0, utils_path)

from medcat_etl import (
    DATE_CUI,
    get_validated_entities, get_validated_dates, get_validated_links,
    id2value_from_items
)

from relative_date_extractor import add_relative_dates

In [2]:
# Load MedCAT data
with open("../data/MedCAT_Export_With_Text_2025-09-11_08_19_37.json", "r", encoding="utf-8") as f:
    data = json.load(f)

In [3]:
# Count documents
num_docs = sum(len(p.get("documents", [])) for p in data.get("projects", []))
print(f"Loaded data with {num_docs} documents")

Loaded data with 5 documents


In [4]:
# Convert medcat data to df with json columns to match other extractors / notebooks
medcat_rows = []
for project in data.get("projects", []):
    for doc in project.get("documents", []):
        doc_id = doc.get("id")
        text = doc.get("text", "")
        
        # Extract validated entities and dates using medcat_etl.py logic
        ents = get_validated_entities(doc, DATE_CUI)
        dates = get_validated_dates(doc, DATE_CUI)
        links = get_validated_links(doc, DATE_CUI)
        
        # Convert to JSON format expected by add_relative_dates
        entities_json = json.dumps([
            {"id": e["id"], "value": e["value"], "cui": e.get("cui"), "start": e.get("start"), "end": e.get("end")} 
            for e in ents
        ])
        dates_json = json.dumps([
            {"id": d["id"], "value": d["value"], "start": d.get("start"), "end": d.get("end")} 
            for d in dates
        ])
        links_json = json.dumps([
            {"date": next((d["value"] for d in dates if d["id"] == L["date_id"]), ""),
             "entity": next((e["value"] for e in ents if e["id"] == L["entity_id"]), "")}
            for L in links
        ])
        
        medcat_rows.append({
            "doc_id": doc_id,
            "note_text": text,
            "entities_json": entities_json,
            "dates_json": dates_json,
            "links_json": links_json
        })

medcat_df = pd.DataFrame(medcat_rows)
print(f"Created MedCAT format dataset with {len(medcat_df)} documents")

Created MedCAT format dataset with 5 documents


In [5]:
# Add relative dates using the same logic as other notebooks
if 'relative_dates_json' not in medcat_df.columns:
    medcat_df = add_relative_dates(medcat_df)
    print("Added relative dates")
else:
    print("Relative dates already present, skipping extraction")

Added relative dates


In [6]:
#Inspect data
medcat_df

Unnamed: 0,doc_id,note_text,entities_json,dates_json,links_json,relative_dates_json
0,26461,Ultrasound (30nd Jun 2024): no significant fin...,"[{""id"": 308244, ""value"": ""history of meningiti...","[{""id"": 308320, ""value"": ""30nd Jun 2024"", ""sta...","[{""date"": ""12nd Sep 2024"", ""entity"": ""pituitar...",[]
1,26462,Labs (27th Sep 2024): anemia. resolving Skin:...,"[{""id"": 308371, ""value"": ""lesions"", ""cui"": ""52...","[{""id"": 308581, ""value"": ""22/11/24"", ""start"": ...","[{""date"": ""27th Sep 2024"", ""entity"": ""anemia""}...",[]
2,26463,URGENT REVIEW (2024-10-04): cough. suspect ost...,"[{""id"": 308886, ""value"": ""frequent urination"",...","[{""id"": 308940, ""value"": ""2024-10-04"", ""start""...","[{""date"": ""2024-10-04"", ""entity"": ""cough""}, {""...",[]
3,26464,URGENT REVIEW (13rd Feb 2025) MRI of the brain...,"[{""id"": 308951, ""value"": ""multiple_sclerosis"",...","[{""id"": 308996, ""value"": ""05-03-2025"", ""start""...","[{""date"": ""13rd Feb 2025"", ""entity"": ""visual""}...",[]
4,26465,New pt((18/11/24)): pt presents with nausea/vo...,"[{""id"": 308998, ""value"": ""history of neoplasm ...","[{""id"": 309070, ""value"": ""18/11/24"", ""start"": ...","[{""date"": ""18/11/24"", ""entity"": ""nausea/vomiti...",[]


In [7]:
#Insert marker function
def insert_marker(txt, start, end, tag_open, tag_close):
                    return txt[:start] + tag_open + txt[start:end] + tag_close + txt[end:]

In [8]:
# Create dataset for training
rows = []

for _, row in medcat_df.iterrows():
    doc_id = row["doc_id"]
    text = row["note_text"]
    
    # Parse the JSON columns
    entities = json.loads(row["entities_json"])
    dates = json.loads(row["dates_json"])
    relative_dates = json.loads(row["relative_dates_json"]) if "relative_dates_json" in row else []
    
    # Combine absolute and relative dates
    all_dates = dates + relative_dates
    
    # Get original links for labeling
    links = []
    for project in data.get("projects", []):
        for doc in project.get("documents", []):
            if doc.get("id") == doc_id:
                links = get_validated_links(doc, DATE_CUI)
                break
        if links:  # Break outer loop if we found links
            break
    
    # Build link pairs from validated links (convert to strings to match)
    link_pairs = {tuple(sorted([str(L["date_id"]), str(L["entity_id"])])) for L in links}
    
    # Create pairs for all date-entity combinations
    for date in all_dates:
        for entity in entities:
            # Use real ID for both absolute and relative dates
            date_id = str(date["id"])
            entity_id = str(entity["id"])
            
            # Determine label (both absolute and relative dates can have links)
            if tuple(sorted([date_id, entity_id])) in link_pairs:
                label, label_id = "LINK", 1
            else:
                label, label_id = "NO_LINK", 0
            
            # Insert markers
            s1, e1 = date.get("start"), date.get("end")
            s2, e2 = entity.get("start"), entity.get("end")
            
            if s1 is not None and s2 is not None:
                if s1 < s2:
                    marked = insert_marker(text, s2, e2, "[s2]", "[e2]")
                    marked = insert_marker(marked, s1, e1, "[s1]", "[e1]")
                else:
                    marked = insert_marker(text, s1, e1, "[s2]", "[e2]")
                    marked = insert_marker(marked, s2, e2, "[s1]", "[e1]")
            else:
                marked = text
            
            rows.append({
                "relation_token_span_ids": None,
                "ent1_ent2_start": (s1, s2),
                "ent1": date.get("value", ""),
                "ent2": entity.get("value", ""),
                "label": label,
                "label_id": label_id,
                "ent1_type": "DATE",
                "ent2_type": "ENTITY",
                "ent1_id": date_id,
                "ent2_id": entity_id,
                "ent1_cui": None,
                "ent2_cui": None,
                "doc_id": doc_id,
                "text": marked
            })

df = pd.DataFrame(rows)
print(f"Created {len(df)} date-entity pairs (including relative dates)")

Created 741 date-entity pairs (including relative dates)


In [9]:
#Inspect data
#df

In [10]:
# Per-document summary 
summary = []

for project in data.get("projects", []):
    for d in project.get("documents", []):
        ents  = get_validated_entities(d, DATE_CUI)
        dates = get_validated_dates(d, DATE_CUI)
        links = get_validated_links(d, DATE_CUI)
        summary.append({
            "doc_id": d.get("id"),
            "n_entities": len(ents),
            "n_dates": len(dates),
            "n_links": len(links),
        })

doc_level = pd.DataFrame(summary)
doc_level

Unnamed: 0,doc_id,n_entities,n_dates,n_links
0,26461,64,6,4
1,26462,21,7,11
2,26463,15,7,12
3,26464,11,3,6
4,26465,24,3,4


In [11]:
# Pair-based stats from the generated df
pair_stats = (
    df.groupby("doc_id")
      .agg(
          total_pairs=("label", "size"),
          link_pairs=("label", lambda s: (s == "LINK").sum()),
      )
      .reset_index()
)

pair_stats["link_pct"] = (100 * pair_stats["link_pairs"] / pair_stats["total_pairs"]).round(1)
pair_stats

Unnamed: 0,doc_id,total_pairs,link_pairs,link_pct
0,26461,384,4,1.0
1,26462,147,11,7.5
2,26463,105,12,11.4
3,26464,33,6,18.2
4,26465,72,4,5.6


In [12]:
# Merge stats with doc_level summary
doc_level = (
    doc_level.merge(pair_stats[["doc_id", "total_pairs", "link_pct"]], on="doc_id", how="left")
             .sort_values("doc_id")
             .reset_index(drop=True)
)

# Calculate additional metrics
total_pairs_overall = len(df)
link_total = (df["label"] == "LINK").sum()
link_pct_overall = 100 * (df["label"] == "LINK").mean()

print(f"total number of date-entity pairs is {total_pairs_overall}")
print(f"total number of links: {link_total}")
print(f"percentage positive class: {link_pct_overall:.1f}%")

#Look at overall doc level summary
doc_level

total number of date-entity pairs is 741
total number of links: 37
percentage positive class: 5.0%


Unnamed: 0,doc_id,n_entities,n_dates,n_links,total_pairs,link_pct
0,26461,64,6,4,384,1.0
1,26462,21,7,11,147,7.5
2,26463,15,7,12,105,11.4
3,26464,11,3,6,33,18.2
4,26465,24,3,4,72,5.6


In [13]:
# Save dataset
#df.to_csv("../data/relcat_training_data.tsv", sep="\t", index=False)

In [14]:
# Split the dataset and save

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])

print(f"Train set: {len(train_df)} samples")
print(f"Test set: {len(test_df)} samples")
print("\nTrain label distribution:")
print(train_df["label"].value_counts())
print("\nTest label distribution:")
print(test_df["label"].value_counts())

# Save train and test sets
train_df.to_csv("../data/relcat_training_data.tsv", sep="\t", index=False)
test_df.to_csv("../data/relcat_test_data.tsv", sep="\t", index=False)

Train set: 592 samples
Test set: 149 samples

Train label distribution:
label
NO_LINK    562
LINK        30
Name: count, dtype: int64

Test label distribution:
label
NO_LINK    142
LINK         7
Name: count, dtype: int64


Config & Training

In [15]:
#Create RelCAT config and set parameters
config = ConfigRelCAT()
config.general.log_level = logging.INFO
config.general.model_name = "bert-base-uncased" # base model that you want to use, we're going to use the HuggingFace bert-base-uncased model

#logging.basicConfig(level=logging.INFO)

In [16]:
#Hidden size, model size and hidden layers
config.model.hidden_size= 256
config.model.model_size = 2304 # 4096 for llama

In [17]:
# Further config
config.general.cntx_left = 15 # how many tokens to the left of the start entity we select
config.general.cntx_right = 15 # how many tokens to the right of the end entity we selecd
config.general.window_size = 300 # distance (in characters) between two entities to be considered a relation
config.train.nclasses = 2 # number of classes in your medcat export / dataset
config.train.nepochs = 3 # number of epochs to train for
config.model.freeze_layers = False # whether to freeze the layers of the base model
config.general.limit_samples_per_class = 300 # limit the number of training samples per class to this number, to avoid overfitting in unbalanced datasets
config.train.batch_size = 32 # batch size
config.train.lr = 3e-5
config.train.adam_epsilon = 1e-8
config.train.adam_weight_decay = 0.0005

In [18]:
#Create CDB
cdb = CDB()

In [19]:
#Create tokenizer
tokenizer = BaseTokenizerWrapper_RelationExtraction.load(tokenizer_path=config.general.model_name,
                                                       relcat_config=config)

In [20]:
#Add special tokens
special_ent_tokens = ["[s1]", "[e1]", "[s2]", "[e2]"]
tokenizer.hf_tokenizers.add_tokens(special_ent_tokens, special_tokens=True)
tokenizer.hf_tokenizers.add_special_tokens({'pad_token': '[PAD]'}) # used in llama tokenizer

0

In [21]:
#Add tokens to config
config.general.tokenizer_relation_annotation_special_tokens_tags = special_ent_tokens
config.general.annotation_schema_tag_ids = tokenizer.hf_tokenizers.convert_tokens_to_ids(special_ent_tokens)

In [22]:
#Create RelCAT object and initialize components
# if you wish to skip the steps in section 6.1 you can pass the init_model=True arguement to intialize the components with the default ConfigRelCAT settings.
relCAT = RelCAT(cdb, config=config)

INFO:medcat.utils.relation_extraction.base_component:BaseComponent_RelationExtraction initialized


In [23]:
# Load model configuration
model_config = BaseConfig_RelationExtraction.load(pretrained_model_name_or_path=config.general.model_name,
                                                 relcat_config=config)

You are using a model of type bert to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO:medcat.utils.relation_extraction.config:Loaded config from : bert-base-uncased\model_config.json


In [24]:
# Update vocab size in model config to match tokenizer
model_config.hf_model_config.vocab_size = tokenizer.get_size()

In [25]:
# set the padding idx in the model config and relcat config, this is necesasry as it depends on what tokenizer you use
config.model.padding_idx = model_config.pad_token_id = tokenizer.get_pad_id()

In [26]:
# Load model
model = BaseModel_RelationExtraction.load(pretrained_model_name_or_path=config.general.model_name,
                                         model_config=model_config,
                                         relcat_config=config)

INFO:medcat.utils.relation_extraction.models:RelCAT model config: PretrainedConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.55.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30526
}

INFO:medcat.utils.relation_extraction.bert.model:RelCAT model config: PretrainedConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  

In [27]:
# Resize embeddings to match tokenizer
model.hf_model.resize_token_embeddings(len(tokenizer.hf_tokenizers)) # type: ignore

Embedding(30526, 768, padding_idx=0)

In [28]:
# Create component
component = BaseComponent_RelationExtraction(tokenizer=tokenizer, config=config)
component.model = model
component.model_config = model_config
component.relcat_config = config
component.tokenizer = tokenizer
relCAT.component = component

INFO:medcat.utils.relation_extraction.base_component:BaseComponent_RelationExtraction initialized


In [29]:
# Train the model using the dataset we created
relCAT.train(
    train_csv_path="../data/relcat_training_data.tsv",  
    checkpoint_path="../models/relcat_models"
)

INFO:medcat.utils.relation_extraction.rel_dataset:CSV dataset | No. of relations detected:208| from : ../data/relcat_training_data.tsv | nclasses: 2 | idx2label: {0: 'LINK', 1: 'NO_LINK'}
INFO:medcat.utils.relation_extraction.rel_dataset:Samples per class: 
INFO:medcat.utils.relation_extraction.rel_dataset: label: LINK | samples: 28
INFO:medcat.utils.relation_extraction.rel_dataset: label: NO_LINK | samples: 180
INFO:root:Relations after train, test split :  train - 167 | test - 41
INFO:root: label: NO_LINK samples | train 144 | test 36
INFO:root: label: LINK samples | train 23 | test 5
INFO:root:Attempting to load RelCAT model on device: cpu
INFO:medcat.rel_cat:Starting training process...
INFO:medcat.rel_cat:Total epochs on this model: 3 | currently training epoch 0
100%|██████████| 167/167 [01:15<00:00,  2.22it/s]
Consider using tensor.detach() first. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\autograd\generated\python_variable_methods.cpp:83

In [30]:
#Save model
relCAT.save(save_path="../models/relcat_models")

Inference & Testing

In [31]:
#Load model
relCAT = RelCAT.load("../models/relcat_models")

INFO:medcat.rel_cat:The default CDB file name 'cdb.dat' doesn't exist in the specified path, you will need to load & set                 a CDB manually via rel_cat.cdb = CDB.load('path') 
INFO:root:Loaded config.json
INFO:medcat.utils.relation_extraction.bert.config:Loaded config from file: ../models/relcat_models\model_config.json
INFO:medcat.utils.relation_extraction.tokenizer:Tokenizer loaded TokenizerWrapperBERT_RelationExtraction from:../models/relcat_models
INFO:medcat.utils.relation_extraction.models:RelCAT model config: BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "po

In [32]:
# Load test data
test_df = pd.read_csv("../data/relcat_test_data.tsv", sep="\t")
print(f"Loaded test data: {len(test_df)} samples")

Loaded test data: 149 samples


In [54]:
test_df

Unnamed: 0,relation_token_span_ids,ent1_ent2_start,ent1,ent2,label,label_id,ent1_type,ent2_type,ent1_id,ent2_id,ent1_cui,ent2_cui,doc_id,text
0,,"(715, 305)",07.05.25,weight gain,NO_LINK,0,DATE,ENTITY,308945,308898,,,26463,URGENT REVIEW (2024-10-04): cough. suspect ost...
1,,"(588, 1221)",23rd Oct 2024,improved,NO_LINK,0,DATE,ENTITY,308323,308426,,,26461,Ultrasound (30nd Jun 2024): no significant fin...
2,,"(935, 946)",16/02/25,elevated CRP,LINK,1,DATE,ENTITY,308885,308820,,,26462,Labs (27th Sep 2024): anemia. resolving Skin:...
3,,"(339, 307)",22/11/24,multiple_sclerosis,NO_LINK,0,DATE,ENTITY,308581,308809,,,26462,Labs (27th Sep 2024): anemia. resolving Skin:...
4,,"(471, 50)",26/12/24,rashes,NO_LINK,0,DATE,ENTITY,308882,308853,,,26462,Labs (27th Sep 2024): anemia. resolving Skin:...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
144,,"(588, 429)",23rd Oct 2024,CLINIC,NO_LINK,0,DATE,ENTITY,308323,308297,,,26461,Ultrasound (30nd Jun 2024): no significant fin...
145,,"(443, 854)",16 Sep'24,mild,NO_LINK,0,DATE,ENTITY,308322,308438,,,26461,Ultrasound (30nd Jun 2024): no significant fin...
146,,"(935, 641)",16/02/25,multiple_sclerosis,NO_LINK,0,DATE,ENTITY,308885,308810,,,26462,Labs (27th Sep 2024): anemia. resolving Skin:...
147,,"(12, 328)",30nd Jun 2024,reveals,NO_LINK,0,DATE,ENTITY,308320,308284,,,26461,Ultrasound ([s1]30nd Jun 2024[e1]): no signifi...


In [46]:
# Run inference on all test documents using the same data source as training
all_predictions = []

test_doc_ids = test_df['doc_id'].unique()

for doc_id in test_doc_ids:
    # Find the document in medcat_df (same source as training)
    doc_row = medcat_df[medcat_df['doc_id'] == doc_id].iloc[0]
    
    # Parse the same JSON columns used in training
    entities = json.loads(doc_row["entities_json"])
    dates = json.loads(doc_row["dates_json"])
    
    # Create annotations in the same format as training
    annotations = []
    for entity in entities:
        annotations.append({
            "value": entity["value"],
            "cui": entity.get("cui"),
            "start": entity.get("start"),
            "end": entity.get("end")
        })
    for date in dates:
        annotations.append({
            "value": date["value"],
            "cui": DATE_CUI,
            "start": date.get("start"),
            "end": date.get("end")
        })
    
    try:
        # Run inference
        output_doc_with_relations = relCAT.predict_text_with_anns(
            text=doc_row["note_text"], 
            annotations=annotations
        )
        
        # Collect results - only keep date-entity pairs
        for relation in output_doc_with_relations._.relations:
            # Check if this is a date-entity pair (not entity-entity)
            if (relation['ent1_text'] in [d['value'] for d in dates] and 
                relation['ent2_text'] in [e['value'] for e in entities]) or \
               (relation['ent2_text'] in [d['value'] for d in dates] and 
                relation['ent1_text'] in [e['value'] for e in entities]):
                
                all_predictions.append({
                    'entity_label': relation['ent1_text'],
                    'date': relation['ent2_text'],
                    'confidence': relation['confidence'],
                    'doc_id': doc_id
                })
                
    except Exception as e:
        print(f"Error processing document {doc_id}: {e}")
        continue

print(f"Processed {len(test_doc_ids)} test documents")
print(f"Total predictions: {len(all_predictions)}")

INFO:medcat.rel_cat:total relations for doc: 45
INFO:medcat.rel_cat:processing...
100%|██████████| 45/45 [00:08<00:00,  5.24it/s]


Error processing document 26461: min() arg is an empty sequence


INFO:medcat.rel_cat:total relations for doc: 55
INFO:medcat.rel_cat:processing...
100%|██████████| 55/55 [00:10<00:00,  5.31it/s]
INFO:medcat.rel_cat:total relations for doc: 20
INFO:medcat.rel_cat:processing...


Error processing document 26465: min() arg is an empty sequence


100%|██████████| 20/20 [00:03<00:00,  5.62it/s]

Processed 5 test documents
Total predictions: 31





In [47]:
# Show results
print("Test Set Results:")
print(f"Total predictions: {len(all_predictions)}")

# Show first 10 predictions
print("\nFirst 10 predictions:")
for i, pred in enumerate(all_predictions[:10]):
    print(f"{i+1}. {pred['entity_label']} -> {pred['date']} (conf: {pred['confidence']:.3f}) [doc: {pred['doc_id']}]")

# Show high confidence predictions
high_conf = [p for p in all_predictions if p['confidence'] > 0.7]
print(f"\nHigh confidence predictions (>0.7): {len(high_conf)}")
for i, pred in enumerate(high_conf[:5]):  # Show first 5
    print(f"{i+1}. {pred['entity_label']} -> {pred['date']} (conf: {pred['confidence']:.3f})")

Test Set Results:
Total predictions: 31

First 10 predictions:
1. frequent urination -> 07.05.25 (conf: 0.911) [doc: 26463]
2. diabetes_mellitus -> 23rd Feb 2025 (conf: 0.852) [doc: 26463]
3. osteoarthritis -> 21-11-2024 (conf: 0.889) [doc: 26463]
4. weight gain -> 23rd Feb 2025 (conf: 0.863) [doc: 26463]
5. hemorrhage -> 23rd Feb 2025 (conf: 0.871) [doc: 26463]
6. PET scan -> 21-11-2024 (conf: 0.864) [doc: 26463]
7. PET scan -> 23rd Feb 2025 (conf: 0.832) [doc: 26463]
8. anemia -> 07.05.25 (conf: 0.848) [doc: 26463]
9. cough -> 21-11-2024 (conf: 0.858) [doc: 26463]
10. cough -> 21-11-2024 (conf: 0.859) [doc: 26463]

High confidence predictions (>0.7): 31
1. frequent urination -> 07.05.25 (conf: 0.911)
2. diabetes_mellitus -> 23rd Feb 2025 (conf: 0.852)
3. osteoarthritis -> 21-11-2024 (conf: 0.889)
4. weight gain -> 23rd Feb 2025 (conf: 0.863)
5. hemorrhage -> 23rd Feb 2025 (conf: 0.871)


In [48]:
# Let's debug the exact counts
print(f"Total test pairs: {len(test_df)}")
print(f"Total predictions: {len(all_predictions)}")

# Count how many test pairs were actually predicted
predicted_count = 0
for _, row in test_df.iterrows():
    found = False
    for pred in all_predictions:
        if (pred['doc_id'] == row['doc_id'] and 
            ((pred['entity_label'] == row['ent1'] and pred['date'] == row['ent2']) or
             (pred['entity_label'] == row['ent2'] and pred['date'] == row['ent1']))):
            found = True
            break
    if found:
        predicted_count += 1

print(f"Test pairs that were predicted: {predicted_count}")
print(f"Test pairs that were NOT predicted: {len(test_df) - predicted_count}")

# Also check if there are predictions for documents not in test set
test_doc_ids = set(test_df['doc_id'].unique())
pred_doc_ids = set([p['doc_id'] for p in all_predictions])
print(f"Test doc IDs: {test_doc_ids}")
print(f"Prediction doc IDs: {pred_doc_ids}")
print(f"Extra predictions (not in test): {len(pred_doc_ids - test_doc_ids)}")

Total test pairs: 149
Total predictions: 31
Test pairs that were predicted: 7
Test pairs that were NOT predicted: 142
Test doc IDs: {26464, 26465, 26461, 26462, 26463}
Prediction doc IDs: {26464, 26462, 26463}
Extra predictions (not in test): 0


In [49]:
# Create predictions for all test pairs
all_test_predictions = []

for _, row in test_df.iterrows():
    # Check if this pair was predicted as a link
    found = False
    for pred in all_predictions:
        if (pred['doc_id'] == row['doc_id'] and 
            ((pred['entity_label'] == row['ent1'] and pred['date'] == row['ent2']) or
             (pred['entity_label'] == row['ent2'] and pred['date'] == row['ent1']))):
            all_test_predictions.append('LINK')
            found = True
            break
    
    if not found:
        all_test_predictions.append('NO_LINK')

# Now calculate metrics on all test pairs
y_true_all = test_df['label'].tolist()
y_pred_all = all_test_predictions

print(f"\nAll Test Pairs Metrics:")
print(f"Accuracy: {sum(1 for t, p in zip(y_true_all, y_pred_all) if t == p) / len(y_true_all):.3f}")
print(classification_report(y_true_all, y_pred_all, labels=['LINK', 'NO_LINK']))


All Test Pairs Metrics:
Accuracy: 0.919
              precision    recall  f1-score   support

        LINK       0.14      0.14      0.14         7
     NO_LINK       0.96      0.96      0.96       142

    accuracy                           0.92       149
   macro avg       0.55      0.55      0.55       149
weighted avg       0.92      0.92      0.92       149

