# Training and testing using cross-validation
This notebook uses predefined subsets of examples to train and test models.

In [1]:
import json
import numpy as np
import pandas as pd
from utils import evaluate_per_example
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBPE
from medcat.config_meta_cat import ConfigMetaCAT
from medcat.meta_cat import MetaCAT
from pathlib import Path

In [2]:
# Input
data_dir = Path.cwd().parents[0] / 'data'
annotation_file = data_dir / 'emc-dcc_ann.json'
split_list_file = data_dir / 'split_list.json'
model_dir = Path.cwd().parents[0] / 'models' / 'bilstm'
embeddings_file = model_dir / 'embeddings.npy'

# Output
annotations_split_dir = data_dir / 'annotations_split'
models_split_dir = model_dir / 'model_splits'
result_dir = Path.cwd().parents[0] / 'results'
score_result_file = result_dir / 'bilstm_scores_cv.csv.gz'
predictions_result_file = result_dir / 'bilstm_predictions_cv.csv.gz'

# Create output dirs (result_dir is needed for the CSV files written later)
annotations_split_dir.mkdir(exist_ok=True)
models_split_dir.mkdir(exist_ok=True)
result_dir.mkdir(exist_ok=True)

# Configure MetaCAT
config_metacat = ConfigMetaCAT()
config_metacat.general['category_name'] = 'Negation'
config_metacat.train['nepochs'] = 10

## Load tokenizer and embeddings matrix
Load the project-wide tokenizer and embeddings matrix, which were created in `01_tokenizer_embeddings.ipynb`.

In [3]:
tokenizer = TokenizerWrapperBPE.load(model_dir)
embeddings = np.load(embeddings_file)

## Split DCC file into smaller train and test files
Using the code in `utils/dcc_splitter.py`, we split the documents into 10 folds. These splits are saved in `data/split_list.json`. Below, each split is written to its own train and test annotation file.
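The contents of `utils/dcc_splitter.py` are not shown in this notebook. As a rough sketch of what a document-level 10-fold split can look like, the hypothetical `make_document_folds` function below shuffles document names and assigns them round-robin to folds, producing a list of dicts with the `split_id`, `train` and `test` keys that the next cell reads from `split_list.json`. The shuffling strategy and the seed are assumptions, not necessarily what `dcc_splitter.py` does. Splitting by document, as the next cell does via `document['name']`, keeps all annotations of one document in the same test set.

```python
import random

def make_document_folds(document_names, n_folds=10, seed=0):
    """Hypothetical sketch: assign whole documents to folds so that all
    annotations of one document end up in the same test set."""
    names = sorted(document_names)          # deterministic starting order
    random.Random(seed).shuffle(names)      # reproducible shuffle
    # Round-robin assignment: fold i gets every n_folds-th document
    folds = [names[i::n_folds] for i in range(n_folds)]

    split_lists = []
    for split_id, test_names in enumerate(folds):
        test_set = set(test_names)
        train_names = [n for n in names if n not in test_set]
        split_lists.append({'split_id': split_id,
                            'train': train_names,
                            'test': test_names})
    return split_lists

# Usage with made-up document names
splits = make_document_folds([f'DOC{i:03d}' for i in range(25)], n_folds=10)
print(len(splits), [len(s['test']) for s in splits])  # 10 [3, 3, 3, 3, 3, 2, 2, 2, 2, 2]
```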

In [4]:
# Load complete DCC data, which is in the MedCAT Trainer annotation format
with open(annotation_file) as f:
    annotations = json.load(f)
    
# Load the splits
with open(split_list_file) as f:
    split_lists = json.load(f)

# For each split, create train and test file
for split_list in split_lists:
    train_annotations = []
    test_annotations = []

    for document in annotations['projects'][0]['documents']:
        if document['name'] in split_list['train']:
            train_annotations.append(document)
        elif document['name'] in split_list['test']:
            test_annotations.append(document)
    #     else:
    #         print(f'{document["name"]} not found in either train or test')

    # Create an annotation file for the split following MetaCAT's annotation format
    project_train_annotations = {'projects': [{'documents': train_annotations}]}
    project_test_annotations = {'projects': [{'documents': test_annotations}]}

    # Write output files
    train_output_file = annotations_split_dir / f'train_annotations_{split_list["split_id"]}.json'
    with open(train_output_file, "w") as fp:
        json.dump(project_train_annotations, fp)

    test_output_file = annotations_split_dir / f'test_annotations_{split_list["split_id"]}.json'
    with open(test_output_file, "w") as fp:
        json.dump(project_test_annotations, fp)
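Cross-validation scores are only meaningful if no document appears in both the train and test file of a fold, so it can be worth checking the loaded `split_lists` before training. The hypothetical `check_splits` helper below is a sketch that assumes each entry has `train` and `test` lists of document names, as used in the cell above. It reports documents that leak between train and test, and documents that are not in exactly one test set.

```python
from collections import Counter

def check_splits(split_lists, all_document_names):
    """Hypothetical sanity check for a list of {'split_id', 'train', 'test'} dicts.

    Returns a list of problems; an empty list means train and test are
    disjoint in every fold and every document is tested in exactly one fold.
    """
    problems = []
    for split in split_lists:
        overlap = set(split['train']) & set(split['test'])
        if overlap:
            problems.append(f"split {split['split_id']}: {len(overlap)} document(s) in both train and test")

    # Count in how many test sets each document appears (missing keys count as 0)
    test_counts = Counter(name for split in split_lists for name in split['test'])
    for name in all_document_names:
        if test_counts[name] != 1:
            problems.append(f"{name} is in {test_counts[name]} test sets")
    return problems

# In this notebook one could call (assuming the variables from the cell above):
# names = [d['name'] for d in annotations['projects'][0]['documents']]
# print(check_splits(split_lists, names))
toy = [{'split_id': 0, 'train': ['A', 'B'], 'test': ['C']},
       {'split_id': 1, 'train': ['A', 'C'], 'test': ['B', 'C']}]
print(check_splits(toy, ['A', 'B', 'C']))
```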

## Train and test on folds
Per fold, a MetaCAT model is trained and then tested with MetaCAT's eval() function. eval() evaluates the model on a test set and returns a dictionary with scores and examples, but the examples do not include the example ID that we use to compare examples between different methods. Therefore, we use a separate evaluation function later in this notebook.
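An example ID makes it possible to line up the predictions of different methods on the same example. The sketch below assumes two DataFrames, each with an `entity_id` column and one prediction column, in the shape of the predictions file written at the end of this notebook. The second method's column name `other_method` and all toy values are hypothetical. An inner merge on `entity_id` keeps the examples both methods predicted, and the rows where the two columns differ are the disagreements.

```python
import pandas as pd

# Toy per-example predictions for two methods (values are made up)
bilstm = pd.DataFrame({'entity_id': ['D1_0_5', 'D1_10_18', 'D2_3_9'],
                       'bilstm_cv': ['negated', 'not negated', 'not negated']})
other = pd.DataFrame({'entity_id': ['D1_0_5', 'D1_10_18', 'D2_3_9'],
                      'other_method': ['negated', 'negated', 'not negated']})

# Align both methods on the shared example ID; validate guards against duplicate IDs
merged = bilstm.merge(other, on='entity_id', how='inner', validate='one_to_one')

# Examples on which the two methods disagree
disagreements = merged[merged['bilstm_cv'] != merged['other_method']]
print(disagreements)
```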

In [5]:
# List to store results of individual folds
score_result_list = []

for train_file in annotations_split_dir.rglob("train_annotations_*.json"):
    print(train_file)
    split_id = train_file.stem.split('_')[2]
    split_id_dir = models_split_dir / split_id
    split_id_dir.mkdir(exist_ok=True)
    
    # Initiate MetaCAT
    meta_cat = MetaCAT(tokenizer=tokenizer, embeddings=embeddings, config=config_metacat)
    
    # Train model
    train_results = meta_cat.train(json_path=train_file, save_dir_path=str(split_id_dir))
    
    # Evaluate using MetaCAT's eval function
    test_file = train_file.parent / train_file.name.replace('train_annotations_', 'test_annotations_')
    test_results = meta_cat.eval(json_path=test_file)
    
    # Save test results
    score_result_list.append([split_id,
                              round(test_results['f1'], 2),
                              len(test_results['examples']['TP']['negated']),
                              len(test_results['examples']['FP']['negated']),
                              len(test_results['examples']['FN']['negated'])])

D:\Repositories\negation-detection\data\annotations_split\train_annotations_0.json
Epoch: 0 **************************************************  Train
              precision    recall  f1-score   support

           0       0.95      0.98      0.97      8722
           1       0.87      0.69      0.77      1402

    accuracy                           0.94     10124
   macro avg       0.91      0.84      0.87     10124
weighted avg       0.94      0.94      0.94     10124

Epoch: 0 **************************************************  Test
              precision    recall  f1-score   support

           0       0.96      0.99      0.97       955
           1       0.92      0.74      0.82       169

    accuracy                           0.95      1124
   macro avg       0.94      0.86      0.90      1124
weighted avg       0.95      0.95      0.95      1124


##### Model saved to D:\Repositories\negation-detection\models\bilstm\model_splits\0\model.dat at epoch: 0 and f1: 0.948835988798

Epoch: 0 **************************************************  Eval
              precision    recall  f1-score   support

           0       0.99      0.98      0.98      1114
           1       0.88      0.92      0.90       189

    accuracy                           0.97      1303
   macro avg       0.93      0.95      0.94      1303
weighted avg       0.97      0.97      0.97      1303

D:\Repositories\negation-detection\data\annotations_split\train_annotations_1.json
Epoch: 0 **************************************************  Train
              precision    recall  f1-score   support

           0       0.95      0.98      0.97      8712
           1       0.88      0.69      0.77      1439

    accuracy                           0.94     10151
   macro avg       0.92      0.84      0.87     10151
weighted avg       0.94      0.94      0.94     10151

Epoch: 0 **************************************************  Test
              precision    recall  f1-score   support

         

Epoch: 0 **************************************************  Eval
              precision    recall  f1-score   support

           0       0.98      1.00      0.99      1099
           1       0.99      0.87      0.93       174

    accuracy                           0.98      1273
   macro avg       0.99      0.94      0.96      1273
weighted avg       0.98      0.98      0.98      1273

D:\Repositories\negation-detection\data\annotations_split\train_annotations_2.json
Epoch: 0 **************************************************  Train
              precision    recall  f1-score   support

           0       0.95      0.98      0.97      8827
           1       0.86      0.70      0.77      1436

    accuracy                           0.94     10263
   macro avg       0.91      0.84      0.87     10263
weighted avg       0.94      0.94      0.94     10263

Epoch: 0 **************************************************  Test
              precision    recall  f1-score   support

         

Epoch: 9 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8827
           1       0.99      0.99      0.99      1436

    accuracy                           1.00     10263
   macro avg       0.99      0.99      0.99     10263
weighted avg       1.00      1.00      1.00     10263

Epoch: 9 **************************************************  Test
              precision    recall  f1-score   support

           0       0.98      0.99      0.99       991
           1       0.94      0.87      0.90       149

    accuracy                           0.98      1140
   macro avg       0.96      0.93      0.94      1140
weighted avg       0.98      0.98      0.97      1140

Epoch: 0 **************************************************  Eval
              precision    recall  f1-score   support

           0       0.98      0.99      0.99       973
           1       0.96      0.90     

Epoch: 8 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8745
           1       0.99      0.99      0.99      1422

    accuracy                           1.00     10167
   macro avg       0.99      0.99      0.99     10167
weighted avg       1.00      1.00      1.00     10167

Epoch: 8 **************************************************  Test
              precision    recall  f1-score   support

           0       0.99      0.99      0.99       958
           1       0.92      0.95      0.93       171

    accuracy                           0.98      1129
   macro avg       0.96      0.97      0.96      1129
weighted avg       0.98      0.98      0.98      1129

Epoch: 9 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8745
           1       0.99      0.98    

Epoch: 8 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8699
           1       0.99      0.98      0.98      1429

    accuracy                           1.00     10128
   macro avg       0.99      0.99      0.99     10128
weighted avg       1.00      1.00      1.00     10128

Epoch: 8 **************************************************  Test
              precision    recall  f1-score   support

           0       0.98      0.99      0.99       968
           1       0.94      0.88      0.91       157

    accuracy                           0.98      1125
   macro avg       0.96      0.93      0.95      1125
weighted avg       0.97      0.98      0.97      1125

Epoch: 9 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8699
           1       0.99      0.99    

Epoch: 8 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8701
           1       0.99      0.99      0.99      1412

    accuracy                           1.00     10113
   macro avg       0.99      0.99      0.99     10113
weighted avg       1.00      1.00      1.00     10113

Epoch: 8 **************************************************  Test
              precision    recall  f1-score   support

           0       0.99      0.97      0.98       959
           1       0.86      0.92      0.89       164

    accuracy                           0.97      1123
   macro avg       0.92      0.95      0.93      1123
weighted avg       0.97      0.97      0.97      1123

Epoch: 9 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8701
           1       0.99      0.99    

Epoch: 8 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8777
           1       0.98      0.99      0.99      1447

    accuracy                           1.00     10224
   macro avg       0.99      0.99      0.99     10224
weighted avg       1.00      1.00      1.00     10224

Epoch: 8 **************************************************  Test
              precision    recall  f1-score   support

           0       0.98      0.99      0.99       961
           1       0.96      0.90      0.93       174

    accuracy                           0.98      1135
   macro avg       0.97      0.95      0.96      1135
weighted avg       0.98      0.98      0.98      1135

Epoch: 9 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8777
           1       0.98      0.99    

Epoch: 8 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8739
           1       0.99      0.99      0.99      1413

    accuracy                           1.00     10152
   macro avg       0.99      0.99      0.99     10152
weighted avg       1.00      1.00      1.00     10152

Epoch: 8 **************************************************  Test
              precision    recall  f1-score   support

           0       0.99      0.99      0.99       968
           1       0.94      0.92      0.93       159

    accuracy                           0.98      1127
   macro avg       0.96      0.96      0.96      1127
weighted avg       0.98      0.98      0.98      1127

Epoch: 9 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8739
           1       0.99      0.99    

Epoch: 8 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8763
           1       0.99      0.99      0.99      1416

    accuracy                           1.00     10179
   macro avg       0.99      0.99      0.99     10179
weighted avg       1.00      1.00      1.00     10179

Epoch: 8 **************************************************  Test
              precision    recall  f1-score   support

           0       0.99      0.98      0.98       969
           1       0.89      0.93      0.91       161

    accuracy                           0.97      1130
   macro avg       0.94      0.95      0.94      1130
weighted avg       0.97      0.97      0.97      1130

Epoch: 9 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8763
           1       0.99      0.99    

Epoch: 8 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8734
           1       0.99      0.99      0.99      1435

    accuracy                           1.00     10169
   macro avg       0.99      0.99      0.99     10169
weighted avg       1.00      1.00      1.00     10169

Epoch: 8 **************************************************  Test
              precision    recall  f1-score   support

           0       0.99      0.98      0.99       991
           1       0.89      0.96      0.92       138

    accuracy                           0.98      1129
   macro avg       0.94      0.97      0.96      1129
weighted avg       0.98      0.98      0.98      1129

Epoch: 9 **************************************************  Train
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      8734
           1       0.98      0.99    

## Gather scores from folds
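In this section, the results of all folds are gathered and saved in a single CSV. The `f1` value returned by eval() appears to be the support-weighted average over both classes: it matches the `weighted avg` F1 in the Eval logs above (e.g. 0.97 for fold 0, where the negated class alone scores 0.90). The TP, FP and FN counts, and the recall and precision derived from them, refer to the negated class only.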

At the time of writing, MetaCAT's eval() function does not return recall and precision, so they are calculated below from the TP, FP and FN counts. A future release will add this functionality (https://github.com/CogStack/MedCAT/pull/172).
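The same scores can also be computed column-wise instead of row by row with `apply`. The sketch below derives precision, recall and F1 for the negated class from the TP, FP and FN counts of folds 0 and 1 in the table of this section. The F1 here is that of the negated class (rather than a class-weighted average), computed as 2TP / (2TP + FP + FN), which equals the harmonic mean of precision and recall.

```python
import pandas as pd

# Negated-class counts for folds 0 and 1, copied from the score table in this section
counts = pd.DataFrame({'split_id': [0, 1],
                       'tp': [174, 152],
                       'fp': [24, 1],
                       'fn': [15, 22]})

# Vectorised metrics; a fold with tp + fp == 0 would give NaN instead of raising
counts['precision'] = (counts.tp / (counts.tp + counts.fp)).round(2)
counts['recall'] = (counts.tp / (counts.tp + counts.fn)).round(2)
counts['f1_negated'] = (2 * counts.tp / (2 * counts.tp + counts.fp + counts.fn)).round(2)
print(counts)
```

For these two folds this gives precision 0.88 and 0.99, recall 0.92 and 0.87, and negated-class F1 0.90 and 0.93, in line with the class-1 rows of the Eval logs.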

In [6]:
def calculate_recall(row):
    # Recall of the negated class: TP / (TP + FN)
    return round(row.tp / (row.tp + row.fn), 2)

def calculate_precision(row):
    # Precision of the negated class: TP / (TP + FP)
    return round(row.tp / (row.tp + row.fp), 2)

score_results = pd.DataFrame(score_result_list, columns=['split_id', 'f1', 'tp', 'fp', 'fn'])
score_results['recall'] = score_results.apply(calculate_recall, axis=1)
score_results['precision'] = score_results.apply(calculate_precision, axis=1)
score_results.to_csv(score_result_file, index=False, compression='gzip')
score_results

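|   | split_id | f1 | tp | fp | fn | recall | precision |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0.97 | 174 | 24 | 15 | 0.92 | 0.88 |
| 1 | 1 | 0.98 | 152 | 1 | 22 | 0.87 | 0.99 |
| 2 | 2 | 0.98 | 158 | 6 | 17 | 0.9 | 0.96 |
| 3 | 3 | 0.98 | 150 | 8 | 17 | 0.9 | 0.95 |
| 4 | 4 | 0.98 | 154 | 8 | 20 | 0.89 | 0.95 |
| 5 | 5 | 0.97 | 162 | 11 | 22 | 0.88 | 0.94 |
| 6 | 6 | 0.98 | 129 | 10 | 10 | 0.93 | 0.93 |
| 7 | 7 | 0.97 | 163 | 14 | 25 | 0.87 | 0.92 |
| 8 | 8 | 0.98 | 173 | 13 | 10 | 0.95 | 0.93 |
| 9 | 9 | 0.97 | 168 | 13 | 19 | 0.9 | 0.93 |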
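With ten folds there are two common ways to summarise such a table: average the per-fold scores (and report their spread), or pool the TP, FP and FN counts over all folds and compute the scores once (micro-averaging). The sketch below illustrates both options with the counts from the table above; it is not a step this notebook itself performs. Pooling counts every test example once, assuming each document appears in exactly one test set.

```python
import pandas as pd

# Negated-class counts per fold, copied from the score table above
scores = pd.DataFrame({
    'split_id': list(range(10)),
    'tp': [174, 152, 158, 150, 154, 162, 129, 163, 173, 168],
    'fp': [24, 1, 6, 8, 8, 11, 10, 14, 13, 13],
    'fn': [15, 22, 17, 17, 20, 22, 10, 25, 10, 19],
})

# Option 1: per-fold scores, summarised by their mean and standard deviation
per_fold = pd.DataFrame({
    'precision': scores.tp / (scores.tp + scores.fp),
    'recall': scores.tp / (scores.tp + scores.fn),
})
print(per_fold.agg(['mean', 'std']).round(3))

# Option 2: pool the counts over all folds and compute the scores once
tp, fp, fn = scores[['tp', 'fp', 'fn']].sum()
pooled = {
    'precision': tp / (tp + fp),
    'recall': tp / (tp + fn),
    'f1': 2 * tp / (2 * tp + fp + fn),
}
print({name: round(float(value), 3) for name, value in pooled.items()})
```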


## Custom evaluation per example per fold
In this project we want to know, for each individual example, whether its negation was predicted correctly. MetaCAT does not have such functionality; it only returns scores, predictions and examples, without an ID per example.

In this section we iterate through all annotations from an annotation file (MedCAT Trainer format), create an ID for every example (`exampleID = documentID_start_end`), collect the prediction per example and save all predictions in a CSV.
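The ID scheme can be sketched as follows. The hypothetical `example_ids` function walks a MedCAT Trainer-style export (`projects` / `documents` / `annotations`) and builds `documentID_start_end` for every annotation, using the document `name` as the document ID, as the split code above does. It assumes each annotation carries `start` and `end` character offsets; the actual `evaluate_per_example` helper in `utils` may be implemented differently.

```python
def example_ids(trainer_export):
    """Hypothetical sketch: map every annotation in a MedCAT Trainer-style
    export to an ID of the form documentID_start_end."""
    ids = []
    for project in trainer_export['projects']:
        for document in project['documents']:
            # Documents without annotations simply contribute no IDs
            for annotation in document.get('annotations', []):
                ids.append(f"{document['name']}_{annotation['start']}_{annotation['end']}")
    return ids

# Toy export containing one document with two annotations
toy_export = {'projects': [{'documents': [
    {'name': 'DL1111',
     'annotations': [{'start': 32, 'end': 46}, {'start': 272, 'end': 280}]},
]}]}
print(example_ids(toy_export))  # ['DL1111_32_46', 'DL1111_272_280']
```

Because the ID is built from character offsets, two annotations with exactly the same span in one document would get the same ID, so a uniqueness check on the resulting list can be a useful safeguard.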

In [7]:
# Evaluate models on their respective test sets
predictions_on_test_list = []
for annotation_filename in annotations_split_dir.rglob("test_annotations_*.json"):
    
    # Extract split ID
    split_id = annotation_filename.stem.split('_')[2]
    split_id_dir = models_split_dir / split_id
    print(f'Evaluating test set {split_id}')
    
    # Load MetaCAT model
    meta_cat = MetaCAT.load(split_id_dir)
    
    # Gather the predictions on every example in the provided annotation file.
    predictions_on_test_list.append(evaluate_per_example(annotation_filename, meta_cat, 'bilstm_cv'))
    
# Combine the predictions of all folds in a single dataframe.
# pd.concat replaces DataFrame.append (deprecated, removed in pandas 2.0);
# this assumes evaluate_per_example returns one DataFrame per fold.
predictions_on_test_df = pd.concat(predictions_on_test_list, ignore_index=True)

# Save predictions in a csv
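# Use '\n' (not the literal string '/n') between rows; pandas < 1.5 calls this argument line_terminator
predictions_on_test_df.to_csv(predictions_result_file, index=False, compression='gzip', lineterminator='\n')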
predictions_on_test_df

Evaluating test set 0
Evaluating test set 1
Evaluating test set 2
Evaluating test set 3
Evaluating test set 4
Evaluating test set 5
Evaluating test set 6
Evaluating test set 7
Evaluating test set 8
Evaluating test set 9


|   | entity_id | bilstm_cv |
|---|---|---|
| 0 | DL1111_32_46 | not negated |
| 1 | DL1111_272_280 | not negated |
| 2 | DL1111_363_377 | not negated |
| 3 | DL1116_32_41 | not negated |
| 4 | DL1116_137_148 | not negated |
| ... | ... | ... |
| 12546 | SP2100_201_212 | not negated |
| 12547 | SP2100_294_304 | not negated |
| 12548 | SP2107_87_92 | not negated |
| 12549 | SP2108_22_30 | not negated |
