In [20]:
from datetime import datetime
import pytz
import os
import pandas as pd

# --- Data files ----------------------------------------------------------------
# Sentence-level dev split (filename / batch_number / batch_start / text columns,
# as shown by the head() output further down) and the predictions TSV we write.
sentences_file = 'custom_dataset_track1/dev_sentences.tsv'
result_file = 'multicardioner_track1_cardioccc_dev_predictions.tsv'

# --- Task and labels -------------------------------------------------------------
# num_labels is the size of the token-classification head; the grouping code
# treats LABEL_1 as a mention start and LABEL_2 as its continuation.
task = "ner"
num_labels = 3
target_label = 'ENFERMEDAD'

# --- Model ---------------------------------------------------------------------
# The tokenizer comes from the base model; the fine-tuned weights are read from
# the local checkpoint directory derived from the pieces below.
base_model = 'PlanTL-GOB-ES/roberta-base-biomedical-clinical-es'
model_name = "roberta-base-biomedical-clinical-es"
output_path = 'output/models'
model_checkpoint = f"{output_path}/{model_name}-finetuned-{task}"

In [21]:
# Load the tab-separated, sentence-level dev split and show the first rows.
df_dev_sentences = pd.read_csv(sentences_file, delimiter='\t')
df_dev_sentences.head(5)

Unnamed: 0,filename,batch_number,batch_start,text
0,casos_clinicos_cardiologia10,1,0,Anamnesis\n
1,casos_clinicos_cardiologia10,2,10,"Sexo masculino, 79 años.\n"
2,casos_clinicos_cardiologia10,3,35,Autoválido.\n
3,casos_clinicos_cardiologia10,4,47,Procedente de Salto.\n
4,casos_clinicos_cardiologia10,5,68,Antecedentes patológicos: –Hipertensión arteri...


In [3]:
import re
# Word tokens are runs of Unicode word characters (letters, digits and '_').
# Every other character, including each individual whitespace character, is a
# token on its own. The capturing group makes re.split() keep the tokens.
# (Equivalent to the former '([0-9\w]+|[^0-9\w])' + re.UNICODE: digits are
# already in \w, and Unicode matching is the default for str patterns.)
TOKENIZATION_REGEX = re.compile(r'(\w+|\W)')

In [4]:
def tokenize(text, pattern=None):
    """Re-tokenize a sentence into space-separated tokens and keep an offset map.

    ``text`` is split with ``pattern.split`` (default: the module-level
    ``TOKENIZATION_REGEX``). Empty pieces and whitespace-only pieces are
    dropped, and the remaining tokens are joined with a single space. That
    joined string is what gets sent to the NER pipeline.

    Args:
        text: the original sentence.
        pattern: optional compiled regex used with ``re.split``. It should
            contain one capturing group so the matched tokens are kept.
            Defaults to ``TOKENIZATION_REGEX``.

    Returns:
        ``(tokenized_sentence, original_token_offsets)``. Each entry of
        ``original_token_offsets`` is
        ``[orig_start, orig_end, token, new_start, new_end]``: the token's
        character span in ``text`` and in ``tokenized_sentence``.
    """
    if pattern is None:
        pattern = TOKENIZATION_REGEX

    original_token_offsets = []
    offset = 0      # cursor in the original text
    new_offset = 0  # cursor in the space-joined tokenized sentence
    for t in pattern.split(text):
        if not t:
            continue  # re.split yields '' between adjacent matches
        if not t.isspace():
            original_token_offsets.append([offset, offset + len(t), t, new_offset, new_offset + len(t)])
            new_offset += len(t) + 1  # +1 for the single joining space
        offset += len(t)

    tokenized_sentence = ' '.join(entry[2] for entry in original_token_offsets)
    return tokenized_sentence, original_token_offsets

In [5]:
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline, TokenClassificationPipeline

# Fine-tuned token-classification weights from the local checkpoint directory.
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
# The tokenizer is taken from the original (non-fine-tuned) base model.
tokenizer = AutoTokenizer.from_pretrained(base_model)

In [6]:
# Token-classification pipeline returning aggregated groups
# ({'entity_group', 'score', 'word', 'start', 'end'}, see the sample output below).
# NOTE(review): stride=0 presumably enables chunking of inputs longer than the
# model's max length with no overlap; confirm against the installed transformers.
ner_pipe = pipeline(
    'ner',
    pipeline_class=TokenClassificationPipeline,
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    stride=0,
)

In [7]:
# Group aggregated pipeline annotations into entity mentions using a strict
# B/I strategy: LABEL_1 opens a mention, LABEL_2 continues it.
def group_annotations_strict(annotations):
    """Collect annotations that belong to the same entity mention.

    A group starts at a ``LABEL_1`` annotation and absorbs the ``LABEL_1``
    annotations immediately after it. It then absorbs zero or more
    immediately following ``LABEL_2`` annotations. ``LABEL_2`` annotations
    with no open group (orphan continuations), ``LABEL_0`` and any other label
    close the current group and are dropped.

    Args:
        annotations: list of dicts with at least an ``'entity_group'`` key.

    Returns:
        List of groups, each a non-empty list of the original annotation dicts.
    """
    groups = []
    open_group = None  # group currently being extended, or None
    for annotation in annotations:
        label = annotation['entity_group']
        # A further LABEL_1 only extends a group that has not reached its
        # LABEL_2 part yet; a LABEL_2 extends any open group.
        extends_open = open_group is not None and (
            label == 'LABEL_2'
            or (label == 'LABEL_1' and open_group[-1]['entity_group'] == 'LABEL_1')
        )
        if extends_open:
            open_group.append(annotation)
        elif label == 'LABEL_1':
            open_group = [annotation]
            groups.append(open_group)
        else:
            open_group = None
    return groups

In [8]:
# Collapse one annotation group into a single entity mention.
def merge_annotation_group_entries(annotation_group, sentence):
    """Return the mention spanning a whole annotation group.

    Args:
        annotation_group: non-empty list of dicts with ``'start'``/``'end'``
            character offsets into ``sentence``.
        sentence: the string the offsets refer to.

    Returns:
        ``{'start', 'end', 'text'}``, running from the first annotation's start
        to the last annotation's end.
    """
    first, last = annotation_group[0], annotation_group[-1]
    span_start, span_end = first['start'], last['end']
    return {'start': span_start, 'end': span_end, 'text': sentence[span_start:span_end]}

In [9]:
def get_mentions(sentence):
    """Run the NER pipeline on a tokenized sentence and return merged mentions.

    Returns a list of ``{'start', 'end', 'text'}`` dicts whose offsets refer to
    ``sentence``.
    """
    raw_annotations = ner_pipe(sentence)
    mentions = []
    for group in group_annotations_strict(raw_annotations):
        mentions.append(merge_annotation_group_entries(group, sentence))
    return mentions

In [10]:
def get_original_mention_offset(mentions, sentence, original_token_offsets, original_sentence, filename, batch_start):
    original_mention_offsets = []
    current_offset_idx = 0 
    for mention in mentions:
        start = mention['start']
        end = mention['end']
        
        original_start = -1
        original_end = -1
        while current_offset_idx < len(original_token_offsets):
            token = original_token_offsets[current_offset_idx]
            
            if token[3] <= start:
                original_start = token[0]
            
            if token[4] >= end:
                original_end = token[1]
                break
            
            current_offset_idx += 1

        sentence_no_spaces = sentence[start:end].replace(' ', '')
        original_sentence_no_spaces = original_sentence[original_start:original_end].replace(' ', '')
        # check whether the detected span is contained in the original
        if sentence_no_spaces != original_sentence_no_spaces and not(sentence_no_spaces in original_sentence_no_spaces):
            print('potential offset issue ', filename, sentence[start:end], original_sentence[original_start:original_end])
        if original_start == -1 or original_end == -1:
            print('mention not found ', filename, mention)
            
        original_mention_offsets.append({
            'filename': filename, 
            'start_span':original_start+batch_start, 
            'end_span':original_end+batch_start, 
            'text': original_sentence[original_start:original_end]
        })
    
    return original_mention_offsets

In [11]:
# Sample sentence used to sanity-check tokenization, prediction and offset
# re-alignment before the whole dev set is processed.
text = "• ECG: bloqueo aurículo-ventricular (BAV) de primer grado"

In [12]:
# Tokenize the sample and keep the offset map needed for re-alignment later.
tokenized_sentence, original_token_offsets = tokenize(text=text)
tokenized_sentence

'• ECG : bloqueo aurículo - ventricular ( BAV ) de primer grado'

In [13]:
# Inspect the raw aggregated predictions (LABEL_0/1/2 groups with character
# offsets into the tokenized sentence) for the sample.
ner_pipe(tokenized_sentence)

[{'entity_group': 'LABEL_0',
  'score': 0.9996733,
  'word': ' • ECG :',
  'start': 0,
  'end': 7},
 {'entity_group': 'LABEL_1',
  'score': 0.99480635,
  'word': ' bloqueo',
  'start': 8,
  'end': 15},
 {'entity_group': 'LABEL_2',
  'score': 0.9984991,
  'word': ' aurículo - ventricular',
  'start': 16,
  'end': 38},
 {'entity_group': 'LABEL_0',
  'score': 0.8087198,
  'word': ' (',
  'start': 39,
  'end': 40},
 {'entity_group': 'LABEL_1',
  'score': 0.96049446,
  'word': ' BAV',
  'start': 41,
  'end': 44},
 {'entity_group': 'LABEL_2',
  'score': 0.99676716,
  'word': ' ) de primer grado',
  'start': 45,
  'end': 62}]

In [14]:
# Merged mentions for the sample, offsets relative to the tokenized sentence.
mentions = get_mentions(sentence=tokenized_sentence)
mentions

[{'start': 8, 'end': 38, 'text': 'bloqueo aurículo - ventricular'},
 {'start': 41, 'end': 62, 'text': 'BAV ) de primer grado'}]

In [15]:
# Map the sample mentions back to the original text; batch_start=10 is an
# arbitrary shift to check that document offsets are applied.
get_original_mention_offset(
    mentions,
    tokenized_sentence,
    original_token_offsets,
    original_sentence=text,
    filename='filename',
    batch_start=10,
)

[{'filename': 'filename',
  'start_span': 17,
  'end_span': 45,
  'text': 'bloqueo aurículo-ventricular'},
 {'filename': 'filename',
  'start_span': 47,
  'end_span': 67,
  'text': 'BAV) de primer grado'}]

In [16]:
# Predict every dev sentence and translate each mention to document offsets.
original_mentions_list = []
for index, row in df_dev_sentences.iterrows():
    text = row['text'].rstrip()
    tokenized_sentence, original_token_offsets = tokenize(text)
    mentions = get_mentions(tokenized_sentence)
    original_mentions = get_original_mention_offset(
        mentions,
        tokenized_sentence,
        original_token_offsets,
        text,
        row['filename'],
        row['batch_start'],
    )
    original_mentions_list += original_mentions

    # Progress heartbeat every 500 sentences.
    if (index + 1) % 500 == 0:
        print(f'processed {index+1} sentences')

processed 500 sentences
processed 1000 sentences
processed 1500 sentences
processed 2000 sentences
processed 2500 sentences
processed 3000 sentences
processed 3500 sentences
processed 4000 sentences
processed 4500 sentences
processed 5000 sentences
processed 5500 sentences
processed 6000 sentences
processed 6500 sentences
processed 7000 sentences
processed 7500 sentences
processed 8000 sentences
processed 8500 sentences
processed 9000 sentences
processed 9500 sentences
processed 10000 sentences
processed 10500 sentences
processed 11000 sentences
processed 11500 sentences
processed 12000 sentences
processed 12500 sentences
processed 13000 sentences
processed 13500 sentences
processed 14000 sentences
processed 14500 sentences
processed 15000 sentences
processed 15500 sentences
processed 16000 sentences
processed 16500 sentences
processed 17000 sentences
processed 17500 sentences
processed 18000 sentences
processed 18500 sentences
processed 19000 sentences


In [17]:
# Collect all mapped mentions into one frame. The columns are named explicitly
# so the schema (and the column selection used for the TSV export) stays valid
# even when no mention was predicted. An empty list would otherwise produce a
# frame with no columns at all.
df_mentions = pd.DataFrame.from_records(
    original_mentions_list,
    columns=['filename', 'start_span', 'end_span', 'text'],
)
df_mentions.head()

Unnamed: 0,filename,start_span,end_span,text
0,casos_clinicos_cardiologia10,95,124,Hipertensión arterial crónica
1,casos_clinicos_cardiologia10,126,139,Ex-tabaquista
2,casos_clinicos_cardiologia10,142,215,Diabetes mellitus tipo 2 con repercusione...
3,casos_clinicos_cardiologia10,217,238,cardiopatía isquémica
4,casos_clinicos_cardiologia10,240,260,arteriopatía de MMII


In [18]:
# Every predicted mention gets the single target entity type.
df_mentions = df_mentions.assign(label=target_label)

In [22]:
# Write the predictions TSV in the expected column order, without the index.
df_mentions.to_csv(
    result_file,
    sep='\t',
    index=False,
    columns=['filename', 'label', 'start_span', 'end_span', 'text'],
)