In [1]:
from transformers import AutoTokenizer
from transformers import pipeline
from transformers import AutoModelForTokenClassification
import pandas as pd
import torch
from tqdm.auto import tqdm
import evaluate
from datetime import datetime
import pytz
import os
import pandas as pd

tqdm.pandas()
metric = evaluate.load('seqeval')
map = {0: "O", 1: "B-FARMACO", 2: "I-FARMACO"}
reverse_map = {'O':0, 'B-FARMACO': 1, 'I-FARMACO': 2}
tokenizer_kwargs = {'padding': True, 'truncation':True, 'max_length':512}

task = "ner"
#base_model = 'aaaksenova/xlmr_medical'
#output_path = 'output/models'
#model_checkpoint = f"aaaksenova/xlmr_drug_classifier"
num_labels = 3
target_label = 'FARMACO'

#aaaksenova/xlmr_medical - for just xlmr
#aaaksenova/xlmr_drug_classifier + aaaksenova/xlmr_medical for filtering + xlmr
#Spanish: aaaksenova/SpanishRoberta_multicardioner , 
#English: aaaksenova/BioLinkBert_multicardioner , 
#Italian: aaaksenova/SpanishRoberta_it_medprocner
MODEL_CLASS = "aaaksenova/xlmr_drug_classifier"
MODEL_NER = "aaaksenova/xlmr_medical"
use_filtering = True
test_name = 'test' #dev
lang = 'it'

model_name = MODEL_NER[MODEL_NER.index('/')+1:]
sentences_file = f'data/{test_name}_sentences.tsv'
result_file = f'output/multicardioner_track2_{test_name}_{lang}_predictions_{model_name}_{use_filtering}.tsv'

In [2]:
result_file

'output/multicardioner_track2_test_it_predictions_xlmr_medical_True.tsv'

In [3]:
df_dev_sentences = pd.read_csv(sentences_file, sep='\t', keep_default_na=False)
df_dev_sentences.head()

Unnamed: 0,filename,batch_number,batch_start,text,lang
0,multicardioner_test+bg_1,1,0,Setting: primary care (PC).\n,en
1,multicardioner_test+bg_1,2,28,Reason for consultation: 26-year-old woman who...,en
2,multicardioner_test+bg_1,3,132,She had been seen at the PC emergency centre f...,en
3,multicardioner_test+bg_1,4,197,She explained abdominal pain of 2 weeks' evolu...,en
4,multicardioner_test+bg_1,5,333,Clinical history: personal history of no inter...,en


In [4]:
import re
TOKENIZATION_REGEX = re.compile(r"([0-9\w]+|[^0-9\w])", re.UNICODE) # ’ - es, ' - en, ' - it

In [5]:
def tokenize(text):
    original_token_offsets = []

    offset = 0
    new_offset = 0
    nonspace_token_seen = False

    tokens = [t for t in TOKENIZATION_REGEX.split(text) if t]
    for t in tokens:
        if not t.isspace():
            original_token_offsets.append([offset, offset + len(t), t, new_offset, new_offset + len(t)])
            nonspace_token_seen = True
            new_offset += len(t) + 1
        offset += len(t)
        

    tokenized_sentence = ' '.join([l[2] for l in original_token_offsets])

    # store original token offsets
    # pass the tokenized string for prediction
    return tokenized_sentence, original_token_offsets

In [6]:
classifier = pipeline(task="text-classification", model=MODEL_CLASS, device='cuda')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NER)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NER).to('cuda')

  return self.fget.__get__(instance, owner)()


In [7]:
tokenizer.model_max_length = 512
ner_pipe = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy="simple", stride=0, device='cuda')

In [8]:
ner_pipe('Cefazolin2g c/8 hs iv; gentamicin 3mg/kg/day iv; rifampicin 600 mg c/12 hsvo.')

[{'entity_group': 'FARMACO',
  'score': 0.9957895,
  'word': 'Cefazolin2g',
  'start': 0,
  'end': 11},
 {'entity_group': 'FARMACO',
  'score': 0.9987437,
  'word': 'gentamicin',
  'start': 23,
  'end': 33},
 {'entity_group': 'FARMACO',
  'score': 0.99871254,
  'word': 'rifampicin',
  'start': 49,
  'end': 59}]

In [9]:
# group annotations around a clinical procedure mention, based on the annotation label
def group_annotations_strict(annotations):
    groups = []
    i = 0
    while i < len(annotations):
        if annotations[i]['entity_group'] == 'LABEL_0':
            i += 1
            continue

        group = [] # for the strict strategy, a group is a B (or many Bs), followed by 1 or more Is
        if annotations[i]['entity_group'] == 'LABEL_1':
            group.append(annotations[i])
            i += 1

            while (i < len(annotations) and annotations[i]['entity_group'] == 'LABEL_1'):
                group.append(annotations[i])
                i += 1

            while (i < len(annotations) and annotations[i]['entity_group'] == 'LABEL_2'):
                group.append(annotations[i])
                i += 1

            groups.append(group)
        else:
            i+=1
            continue

    return groups

In [10]:
# merge grouped annotations to form a complete entity mention
def merge_annotation_group_entries(annotation_group, sentence):
    start = annotation_group[0]['start']
    end = annotation_group[len(annotation_group) - 1]['end']
    text = sentence[start:end]
    return {'start': start, 'end': end, 'text': text}

In [11]:
def get_mentions(sentence):
    if use_filtering:
        out = classifier(sentence, **tokenizer_kwargs)
    #print(out[0]['label'] == 'ent')
    if not use_filtering or out[0]['label'] == "ent":
        ner_result = ner_pipe(sentence)
        return [{'start': mention['start'], 'end': mention['end'], 'text': mention['word']} for mention in ner_result]
        #annotation_groups = group_annotations_strict(ner_result)        
        #return [merge_annotation_group_entries(group, sentence) for group in annotation_groups]
    else:
        return []

In [12]:
def get_original_mention_offset(mentions, sentence, original_token_offsets, original_sentence, filename, batch_start):
    original_mention_offsets = []
    current_offset_idx = 0 
    for mention in mentions:
        start = mention['start']
        end = mention['end']
        
        original_start = -1
        original_end = -1
        while current_offset_idx < len(original_token_offsets):
            token = original_token_offsets[current_offset_idx]
            
            if token[3] <= start:
                original_start = token[0]
            
            if token[4] >= end:
                original_end = token[1]
                break
            
            current_offset_idx += 1

        sentence_no_spaces = sentence[start:end].replace(' ', '')
        original_sentence_no_spaces = original_sentence[original_start:original_end].replace(' ', '')
        # check whether the detected span is contained in the original
        if sentence_no_spaces != original_sentence_no_spaces and not(sentence_no_spaces in original_sentence_no_spaces):
            print('potential offset issue ', filename, sentence[start:end], original_sentence[original_start:original_end])
        if original_start == -1 or original_end == -1:
            print('mention not found ', filename, mention)
            
        original_mention_offsets.append({
            'filename': filename, 
            'start_span':original_start+batch_start, 
            'end_span':original_end+batch_start, 
            'text': original_sentence[original_start:original_end]
        })
    
    return original_mention_offsets

In [13]:
text = "Treatment was started with piperacillin tazobactam 4.5g iv every 8h (after taking blood cultures) for 4 days."

In [14]:
tokenized_sentence, original_token_offsets = tokenize(text)
tokenized_sentence

'Treatment was started with piperacillin tazobactam 4 . 5g iv every 8h ( after taking blood cultures ) for 4 days .'

In [15]:
ner_pipe(tokenized_sentence)

[{'entity_group': 'FARMACO',
  'score': 0.9985007,
  'word': 'piperacillin tazobactam',
  'start': 27,
  'end': 50}]

In [16]:
mentions = get_mentions(tokenized_sentence)
mentions

[{'start': 27, 'end': 50, 'text': 'piperacillin tazobactam'}]

In [17]:
get_original_mention_offset(mentions, tokenized_sentence, original_token_offsets, text, 'filename', 10)

[{'filename': 'filename',
  'start_span': 37,
  'end_span': 60,
  'text': 'piperacillin tazobactam'}]

In [18]:
def get_mention_spans(doc, mentions):
    spans = []
    current_offset_idx = 0 
    for mention in mentions:
        start = mention['start']
        end = mention['end']
        
        span_start = -1
        span_end = -1
        while current_offset_idx < len(doc):
            token = doc[current_offset_idx]
            
            if token.idx <= start:
                span_start = token.idx
            
            if token.idx + len(token) >= end:
                span_end = token.idx + len(token)
                break
            
            current_offset_idx += 1
        
        spans.append(doc.char_span(span_start, span_end))
    return spans

In [19]:
from spacy.util import filter_spans
def get_filtered_mentions(doc, phrase_mentions, mentions):
    span_mentions = get_mention_spans(doc, mentions)
    #phrase_mentions.extend(span_mentions)
    filtered_matches = filter_spans(span_mentions)
    return [{'start': doc[match.start].idx,'end': doc[match.end-1].idx + len(doc[match.end-1]), 'text':doc[match.start:match.end]} for match in filtered_matches if len(match) > 0]

In [20]:
original_mentions_list = []
for index, row in df_dev_sentences[df_dev_sentences['lang']==lang].iterrows():
    text = row['text'].rstrip()
    tokenized_sentence, original_token_offsets = tokenize(text)
    #doc = nlp(tokenized_sentence)
    #phrase_mentions = get_phrase_mentions(doc)
    filtered_mentions = get_mentions(tokenized_sentence)
    
    #filtered_mentions = get_filtered_mentions(doc, phrase_mentions, mentions)
    
    original_mentions = get_original_mention_offset(filtered_mentions, tokenized_sentence, original_token_offsets, text, row['filename'], row['batch_start'])
    original_mentions_list.extend(original_mentions)
    
    if (index+1) % 500 == 0:
        print(f'processed {index+1} sentences')



processed 386500 sentences
processed 387000 sentences
processed 387500 sentences
processed 388000 sentences
processed 388500 sentences
processed 389000 sentences
processed 389500 sentences
processed 390000 sentences
processed 390500 sentences
processed 391000 sentences
processed 391500 sentences
processed 392000 sentences
processed 392500 sentences
processed 393000 sentences
processed 393500 sentences
processed 394000 sentences
processed 394500 sentences
processed 395000 sentences
processed 395500 sentences
processed 396000 sentences
processed 396500 sentences
processed 397000 sentences
processed 397500 sentences
processed 398000 sentences
processed 398500 sentences
processed 399000 sentences
processed 399500 sentences
potential offset issue  multicardioner_test+bg_1573 acido clavulanico - acido clavulanico
-
processed 400000 sentences
processed 400500 sentences
processed 401000 sentences
processed 401500 sentences
processed 402000 sentences
processed 402500 sentences
processed 403000 

processed 515000 sentences
processed 515500 sentences
processed 516000 sentences
processed 516500 sentences
processed 517000 sentences
processed 517500 sentences
processed 518000 sentences
processed 518500 sentences
processed 519000 sentences
processed 519500 sentences
processed 520000 sentences
processed 520500 sentences
processed 521000 sentences
potential offset issue  multicardioner_test+bg_655 Ampicillina / sulbactam Claritromicina Ampicillina / sulbactam
Claritromicina
processed 521500 sentences
processed 522000 sentences
processed 522500 sentences
processed 523000 sentences
processed 523500 sentences
processed 524000 sentences
processed 524500 sentences
potential offset issue  multicardioner_test+bg_6691 13 - Fluconazolo 13
- Fluconazolo
potential offset issue  multicardioner_test+bg_6691 Fluconazolo a - Fluconazolo a
-
processed 525000 sentences
processed 525500 sentences
processed 526000 sentences
processed 526500 sentences
processed 527000 sentences
processed 527500 sentences

In [21]:
df_mentions = pd.DataFrame.from_records(original_mentions_list)
df_mentions.head()

Unnamed: 0,filename,start_span,end_span,text
0,multicardioner_test+bg_100,812,820,ossigeno
1,multicardioner_test+bg_100,821,832,umidificato
2,multicardioner_test+bg_100,1250,1258,fentanil
3,multicardioner_test+bg_100,1818,1827,midazolam
4,multicardioner_test+bg_100,1837,1844,morfina


In [22]:
df_mentions['label'] = target_label

In [23]:
df_mentions[['filename', 'label', 'start_span', 'end_span', 'text']].to_csv(result_file, sep='\t', index=False)

In [24]:
# w/ and w/o filtering
# es - |0.8636|0.9028|0.8827 || |0.8479|0.9104|0.878
# en - |0.8507|0.8924|0.8711 || |0.832|0.898|0.8638
# it - |0.8606|0.8789|0.8697 || |0.8414|0.8925|0.8662