The current implementation supports Hugging Face (HF) Llama models and BERT models.
This section covers only BERT, because Llama usage is essentially the same; only the imports and a few model-specific settings (noted below) differ.

In [None]:
# Install medcat
! pip install "medcat[spacy,rel-cat,meta-cat]~=2.1.0" # NOTE: VERSION-STRING

Collecting medcat@ git+https://github.com/CogStack/cogstack-nlp@medcat/v0.11.2#subdirectory=medcat-v2 (from medcat[meta-cat,spacy]@ git+https://github.com/CogStack/cogstack-nlp@medcat/v0.11.2#subdirectory=medcat-v2)
  Cloning https://github.com/CogStack/cogstack-nlp (to revision medcat/v0.11.2) to /private/var/folders/h4/sklqg_zx1dbbbx76m2__zb8h0000gn/T/pip-install-7r4on_8p/medcat_c4a76da1eaa7411a9ff529c5127bf9eb
  Running command git clone --filter=blob:none --quiet https://github.com/CogStack/cogstack-nlp /private/var/folders/h4/sklqg_zx1dbbbx76m2__zb8h0000gn/T/pip-install-7r4on_8p/medcat_c4a76da1eaa7411a9ff529c5127bf9eb
  Running command git checkout -q b1ce30ba716ff7c1f3b912085ca02026b6de3f22
  Resolved https://github.com/CogStack/cogstack-nlp to commit b1ce30ba716ff7c1f3b912085ca02026b6de3f22
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone


In [None]:
import logging
import shutil
from pathlib import Path

from medcat.cdb import CDB
from medcat.config.config_rel_cat import ConfigRelCAT
from medcat.components.addons.relation_extraction.rel_cat import RelCAT
from medcat.components.addons.relation_extraction.base_component import RelExtrBaseComponent
from medcat.components.addons.relation_extraction.bert.model import RelExtrBertModel
from medcat.components.addons.relation_extraction.bert.config import RelExtrBertConfig
from medcat.components.addons.relation_extraction.tokenizer import BaseTokenizerWrapper
from medcat.config import Config
from medcat.tokenizing.tokenizers import create_tokenizer

  from .autonotebook import tqdm as notebook_tqdm


<h1>Training RelCAT models with custom datasets from scratch</h1>
<h2>1. Create the RelCAT config and set its parameters</h2>

In [3]:
# Start from the default RelCAT configuration and customise it in the cells below.
config = ConfigRelCAT()
# INFO level so that the loading/training progress messages (see the outputs below) are shown.
config.general.log_level = logging.INFO
config.general.model_name = "bert-base-uncased" # Hugging Face id of the base model; also used below to load the tokenizer, model config and model

<h3> 1.1 Depending on the model you use, you may need to adjust config.model.hidden_size, config.model.model_size and config.model.hidden_layers</h3>

In [4]:
# Size settings for the RelCAT model; these depend on the chosen base model.
# 2304 == 3 * 768 (768 is BERT-base's hidden size); use 4096 for Llama.
config.model.model_size = 2304
config.model.hidden_size = 256

<h3> 1.2 Other notable configurations</h3>

In [5]:
# --- context / sample selection (config.general) ---
# number of tokens kept to the left of the start entity
config.general.cntx_left = 15
# number of tokens kept to the right of the end entity
config.general.cntx_right = 15
# distance (in characters) within which two entities are considered as a possible relation
config.general.window_size = 300
# cap on training samples per class, to avoid overfitting on unbalanced datasets
config.general.limit_samples_per_class = 300

# --- training (config.train) ---
# number of relation classes in your MedCAT export / dataset
config.train.nclasses = 2
config.train.nepochs = 10
config.train.batch_size = 32
config.train.lr = 3e-5
config.train.adam_epsilon = 1e-8
config.train.adam_weight_decay = 0.0005

# --- model (config.model) ---
# set to True to freeze the layers of the base model
config.model.freeze_layers = False

<h2>2. Create a CDB</h2>
<p>It can be the CDB from another model of your choice, or an empty one. The CDB is used only when filtering by concept unique identifiers (CUIs) or concept type IDs (TUIs).</p>

In [7]:
# General MedCAT config, used to build the CDB and the base tokenizer that are passed to RelCAT below.
gen_cnf = Config()
gen_cnf.general.nlp.provider = 'spacy'
# CDB built from the general config (nothing is added to it here).
cdb = CDB(gen_cnf)
# Base tokenizer for the configured NLP provider (spaCy).
# In the run shown below, executing this cell downloaded the en_core_web_md spaCy model.
base_tokenizer = create_tokenizer(gen_cnf.general.nlp.provider, gen_cnf)

Collecting en-core-web-md==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.8.0/en_core_web_md-3.8.0-py3-none-any.whl (33.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.5/33.5 MB[0m [31m86.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: en-core-web-md
Successfully installed en-core-web-md-3.8.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_md')


<h2>3. Create a tokenizer</h2>

In [8]:
# Load the tokenizer of the base model (config.general.model_name), wrapped for RelCAT use.
tokenizer = BaseTokenizerWrapper.load(
    tokenizer_path=config.general.model_name,
    relcat_config=config,
)

<h2>4. Add the entity tag tokens to the tokenizer</h2>
<p>This step is optional, because the [s1], [e1], [s2], [e2] tags are already included in the default ConfigRelCAT.
If you are using a Llama-based model, you will also need to add the [PAD] token to the tokenizer, as shown below.</p>

In [9]:
# Entity-marker tokens: [s1]/[e1] and [s2]/[e2] surround the first and second entity.
special_ent_tokens = ["[s1]", "[e1]", "[s2]", "[e2]"]
tokenizer.hf_tokenizers.add_tokens(special_ent_tokens, special_tokens=True)
# The [PAD] token is needed for Llama tokenizers. bert-base-uncased already has [PAD],
# so nothing new is added for BERT (the call returned 0 in the original run).
# The trailing `;` stops that returned count from being echoed as stray cell output.
tokenizer.hf_tokenizers.add_special_tokens({'pad_token': '[PAD]'});

0

<h2>5. Add the tokens to the ConfigRelCAT</h2>

In [10]:
# Look up the vocabulary ids of the entity-marker tokens in the (extended) tokenizer...
config.general.annotation_schema_tag_ids = tokenizer.hf_tokenizers.convert_tokens_to_ids(special_ent_tokens)
# ...and store the marker strings themselves in the RelCAT config as well.
config.general.tokenizer_relation_annotation_special_tokens_tags = special_ent_tokens

<h2>6. Create the relCAT object and initialize its components</h2>

In [12]:
# If you wish to skip the steps in section 6.1, you can pass the init_model=True argument to initialize the components with the default ConfigRelCAT settings.
# RelCAT is built from the base tokenizer and CDB of section 2 plus the RelCAT config configured above.
relCAT = RelCAT(base_tokenizer, cdb, config=config)

INFO:medcat.components.addons.relation_extraction.base_component:RelExtrBaseComponent initialized


<h3>6.1 Use the RelExtrBaseComponent object, which holds the tokenizer, the model and the model config. Each of these has to be initialized beforehand.</h3>

<p>Resize the token embeddings, since we added tokens to the tokenizer earlier; this must be done after the tokens have been added. It is not required after saving and loading a model, as the value is retained.</p>

In [14]:
# Model config of the base model, loaded from the pretrained checkpoint.
model_config = RelExtrBertConfig.load(
    pretrained_model_name_or_path=config.general.model_name,
    relcat_config=config,
)

# Special tokens were added to the tokenizer, so the vocab size must follow the tokenizer.
model_config.hf_model_config.vocab_size = tokenizer.get_size()

# The padding id depends on the tokenizer in use; it is needed in both the RelCAT config and the model config.
# NOTE(review): vocab_size above is set on `hf_model_config`, while pad_token_id is set on the
# wrapper itself; confirm that the wrapper forwards it to the HF config.
config.model.padding_idx = model_config.pad_token_id = tokenizer.get_pad_id()

# Pretrained weights, loaded with the adjusted model config.
model = RelExtrBertModel.load(
    pretrained_model_name_or_path=config.general.model_name,
    model_config=model_config,
    relcat_config=config,
)

# Grow the embedding matrix to match the extended tokenizer vocabulary (must run after the tokens were added).
model.hf_model.resize_token_embeddings(len(tokenizer.hf_tokenizers)) # type: ignore

# Bundle tokenizer, model and both configs into one component and attach it to RelCAT.
component = RelExtrBaseComponent(tokenizer=tokenizer, config=config)
component.model = model
component.model_config = model_config
component.relcat_config = config
component.tokenizer = tokenizer

relCAT.component = component

INFO:medcat.components.addons.relation_extraction.bert.config:Loaded config from pretrained: bert-base-uncased
INFO:medcat.components.addons.relation_extraction.models:RelCAT model config: BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.53.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30526
}

INFO:medcat.components.addons.relation_extraction.models:RelCAT model config: BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": n

<h2> 7. Train the model on the ADE dataset </h2>

In [15]:
# Start from an empty output directory for the training checkpoints and the saved model.
# This is a pure-Python equivalent of `rm -rf` + `mkdir -p`, so it also works where no POSIX shell is available (e.g. Windows).
shutil.rmtree("./ade_relcat_model", ignore_errors=True)
Path("./ade_relcat_model").mkdir(parents=True, exist_ok=True)

In [17]:
# Train on the ADE relation dataset. The file is a .tsv, but it is passed through `train_csv_path`.
# `checkpoint_path` is the directory prepared in the previous cell (presumably where training checkpoints are written).
relCAT.train(train_csv_path="./data/rel_cat_ADE_V2.tsv", checkpoint_path="./ade_relcat_model")

# For MedCAT Trainer exports, use the export_data_path argument instead: relCAT.train(export_data_path="./data/MedCAT_Export_relation_extraction.json")


INFO:medcat.components.addons.relation_extraction.rel_dataset:CSV dataset | No. of relations detected: 7093 | from : ./data/rel_cat_ADE_V2.tsv | nclasses: 2 | idx2label: {0: 'DRUG-AE', 1: 'DRUG-DOSE'}
INFO:medcat.components.addons.relation_extraction.rel_dataset:Samples per class: 
INFO:medcat.components.addons.relation_extraction.rel_dataset: label: DRUG-AE | samples: 6814
INFO:medcat.components.addons.relation_extraction.rel_dataset: label: DRUG-DOSE | samples: 279
INFO:root:Relations after train, test split :  train - 524 | test - 115
INFO:root: label: DRUG-AE samples | train 300 | test 60
INFO:root: label: DRUG-DOSE samples | train 224 | test 55
INFO:root:Attempting to load RelCAT model on device: cpu
INFO:medcat.components.addons.relation_extraction.rel_cat:Starting training process...
INFO:medcat.components.addons.relation_extraction.rel_cat:Total epochs on this model: 10 | currently training epoch 0
100%|██████████| 524/524 [00:29<00:00, 17.52it/s]
INFO:medcat.components.addons.

In [18]:
# Save the trained RelCAT model into the same directory that was passed as checkpoint_path during training.
relCAT.save(save_path="./ade_relcat_model")