In [24]:
from datetime import datetime
import pytz
import os
import pandas as pd

task = "ner"
base_model = 'aaaksenova/xlmr_medical'
output_path = 'output/models'
model_checkpoint = f"aaaksenova/xlmr_drug_classifier"
num_labels = 3
target_label = 'FARMACO'

sentences_file = 'data/dev_sentences.tsv'
lang = 'it'
result_file = f'output/multicardioner_track2_cardioccc_dev_{lang}_predictions.tsv'


In [2]:
df_dev_sentences = pd.read_csv(sentences_file, sep='\t')
df_dev_sentences.head()

Unnamed: 0,filename,batch_number,batch_start,text,lang
0,casos_clinicos_cardiologia10,1,0,Anamnesis\n,en
1,casos_clinicos_cardiologia10,2,10,"Male, 79 years old.",en
2,casos_clinicos_cardiologia10,3,30,Self-assisted.,en
3,casos_clinicos_cardiologia10,4,45,From Salto.\n,en
4,casos_clinicos_cardiologia10,5,57,Pathological history: -Chronic arterial hypert...,en


In [3]:
import re
TOKENIZATION_REGEX = re.compile(r"([0-9\w]+|[^0-9\w])", re.UNICODE) # ’ - es, ' - en, ' - it

In [4]:
def tokenize(text):
    original_token_offsets = []

    offset = 0
    new_offset = 0
    nonspace_token_seen = False

    tokens = [t for t in TOKENIZATION_REGEX.split(text) if t]
    for t in tokens:
        if not t.isspace():
            original_token_offsets.append([offset, offset + len(t), t, new_offset, new_offset + len(t)])
            nonspace_token_seen = True
            new_offset += len(t) + 1
        offset += len(t)
        

    tokenized_sentence = ' '.join([l[2] for l in original_token_offsets])

    # store original token offsets
    # pass the tokenized string for prediction
    return tokenized_sentence, original_token_offsets

In [5]:
from transformers import AutoTokenizer
from transformers import pipeline
from transformers import AutoModelForTokenClassification
import pandas as pd
import torch
from tqdm.auto import tqdm
import evaluate

tqdm.pandas()
metric = evaluate.load('seqeval')
map = {0: "O", 1: "B-FARMACO", 2: "I-FARMACO"}
reverse_map = {'O':0, 'B-FARMACO': 1, 'I-FARMACO': 2}
MODEL_CLASS = "aaaksenova/xlmr_drug_classifier"
MODEL_NER = "aaaksenova/xlmr_medical"
tokenizer_kwargs = {'padding': True, 'truncation':True, 'max_length':512}

In [6]:
classifier = pipeline(task="text-classification", model=MODEL_CLASS, device='cuda')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NER)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NER).to('cuda')

  return self.fget.__get__(instance, owner)()


In [7]:
ner_pipe = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy="simple", stride=0)

In [8]:
ner_pipe('Cefazolin2g c/8 hs iv; gentamicin 3mg/kg/day iv; rifampicin 600 mg c/12 hsvo.')

[{'entity_group': 'FARMACO',
  'score': 0.99786454,
  'word': 'Cefazolin',
  'start': 0,
  'end': 9},
 {'entity_group': 'FARMACO',
  'score': 0.9980075,
  'word': 'gentamicin',
  'start': 23,
  'end': 33},
 {'entity_group': 'FARMACO',
  'score': 0.99813527,
  'word': 'rifampicin',
  'start': 49,
  'end': 59}]

In [9]:
# group annotations around a clinical procedure mention, based on the annotation label
def group_annotations_strict(annotations):
    groups = []
    i = 0
    while i < len(annotations):
        if annotations[i]['entity_group'] == 'LABEL_0':
            i += 1
            continue

        group = [] # for the strict strategy, a group is a B (or many Bs), followed by 1 or more Is
        if annotations[i]['entity_group'] == 'LABEL_1':
            group.append(annotations[i])
            i += 1

            while (i < len(annotations) and annotations[i]['entity_group'] == 'LABEL_1'):
                group.append(annotations[i])
                i += 1

            while (i < len(annotations) and annotations[i]['entity_group'] == 'LABEL_2'):
                group.append(annotations[i])
                i += 1

            groups.append(group)
        else:
            i+=1
            continue

    return groups

In [10]:
# merge grouped annotations to form a complete entity mention
def merge_annotation_group_entries(annotation_group, sentence):
    start = annotation_group[0]['start']
    end = annotation_group[len(annotation_group) - 1]['end']
    text = sentence[start:end]
    return {'start': start, 'end': end, 'text': text}

In [11]:
def get_mentions(sentence):
    out = classifier(sentence, **tokenizer_kwargs)
    #print(out[0]['label'] == 'ent')
    if out[0]['label'] == "ent":
        ner_result = ner_pipe(sentence)
        return [{'start': mention['start'], 'end': mention['end'], 'text': mention['word']} for mention in ner_result]
        #annotation_groups = group_annotations_strict(ner_result)        
        #return [merge_annotation_group_entries(group, sentence) for group in annotation_groups]
    else:
        return []

In [12]:
def get_original_mention_offset(mentions, sentence, original_token_offsets, original_sentence, filename, batch_start):
    original_mention_offsets = []
    current_offset_idx = 0 
    for mention in mentions:
        start = mention['start']
        end = mention['end']
        
        original_start = -1
        original_end = -1
        while current_offset_idx < len(original_token_offsets):
            token = original_token_offsets[current_offset_idx]
            
            if token[3] <= start:
                original_start = token[0]
            
            if token[4] >= end:
                original_end = token[1]
                break
            
            current_offset_idx += 1

        sentence_no_spaces = sentence[start:end].replace(' ', '')
        original_sentence_no_spaces = original_sentence[original_start:original_end].replace(' ', '')
        # check whether the detected span is contained in the original
        if sentence_no_spaces != original_sentence_no_spaces and not(sentence_no_spaces in original_sentence_no_spaces):
            print('potential offset issue ', filename, sentence[start:end], original_sentence[original_start:original_end])
        if original_start == -1 or original_end == -1:
            print('mention not found ', filename, mention)
            
        original_mention_offsets.append({
            'filename': filename, 
            'start_span':original_start+batch_start, 
            'end_span':original_end+batch_start, 
            'text': original_sentence[original_start:original_end]
        })
    
    return original_mention_offsets

In [13]:
text = "Treatment was started with piperacillin tazobactam 4.5g iv every 8h (after taking blood cultures) for 4 days."

In [14]:
tokenized_sentence, original_token_offsets = tokenize(text)
tokenized_sentence

'Treatment was started with piperacillin tazobactam 4 . 5g iv every 8h ( after taking blood cultures ) for 4 days .'

In [15]:
ner_pipe(tokenized_sentence)

[{'entity_group': 'FARMACO',
  'score': 0.9969311,
  'word': 'piperacillin tazobactam',
  'start': 27,
  'end': 50}]

In [16]:
mentions = get_mentions(tokenized_sentence)
mentions

[{'start': 27, 'end': 50, 'text': 'piperacillin tazobactam'}]

In [17]:
get_original_mention_offset(mentions, tokenized_sentence, original_token_offsets, text, 'filename', 10)

[{'filename': 'filename',
  'start_span': 37,
  'end_span': 60,
  'text': 'piperacillin tazobactam'}]

In [18]:
def get_mention_spans(doc, mentions):
    spans = []
    current_offset_idx = 0 
    for mention in mentions:
        start = mention['start']
        end = mention['end']
        
        span_start = -1
        span_end = -1
        while current_offset_idx < len(doc):
            token = doc[current_offset_idx]
            
            if token.idx <= start:
                span_start = token.idx
            
            if token.idx + len(token) >= end:
                span_end = token.idx + len(token)
                break
            
            current_offset_idx += 1
        
        spans.append(doc.char_span(span_start, span_end))
    return spans

In [19]:
from spacy.util import filter_spans
def get_filtered_mentions(doc, phrase_mentions, mentions):
    span_mentions = get_mention_spans(doc, mentions)
    #phrase_mentions.extend(span_mentions)
    filtered_matches = filter_spans(span_mentions)
    return [{'start': doc[match.start].idx,'end': doc[match.end-1].idx + len(doc[match.end-1]), 'text':doc[match.start:match.end]} for match in filtered_matches if len(match) > 0]

In [25]:
original_mentions_list = []
for index, row in df_dev_sentences[df_dev_sentences['lang']==lang].iterrows():
    text = row['text'].rstrip()
    tokenized_sentence, original_token_offsets = tokenize(text)
    #doc = nlp(tokenized_sentence)
    #phrase_mentions = get_phrase_mentions(doc)
    filtered_mentions = get_mentions(tokenized_sentence)
    
    #filtered_mentions = get_filtered_mentions(doc, phrase_mentions, mentions)
    
    original_mentions = get_original_mention_offset(filtered_mentions, tokenized_sentence, original_token_offsets, text, row['filename'], row['batch_start'])
    original_mentions_list.extend(original_mentions)
    
    if (index+1) % 500 == 0:
        print(f'processed {index+1} sentences')



processed 37500 sentences
processed 38000 sentences
processed 38500 sentences
processed 39000 sentences
processed 39500 sentences
processed 40000 sentences
processed 40500 sentences
processed 41000 sentences
processed 41500 sentences
processed 42000 sentences
processed 42500 sentences
processed 43000 sentences
processed 43500 sentences
processed 44000 sentences
processed 44500 sentences
processed 45000 sentences
processed 45500 sentences
processed 46000 sentences
processed 46500 sentences
processed 47000 sentences
processed 47500 sentences
processed 48000 sentences
processed 48500 sentences
processed 49000 sentences
processed 49500 sentences
processed 50000 sentences
processed 50500 sentences
processed 51000 sentences
processed 51500 sentences
processed 52000 sentences
processed 52500 sentences
processed 53000 sentences
processed 53500 sentences
processed 54000 sentences
processed 54500 sentences


In [26]:
df_mentions = pd.DataFrame.from_records(original_mentions_list)
df_mentions.head()

Unnamed: 0,filename,start_span,end_span,text
0,casos_clinicos_cardiologia10,2557,2569,Cefazolina2g
1,casos_clinicos_cardiologia10,2581,2592,gentamicina
2,casos_clinicos_cardiologia10,2608,2619,rifampicina
3,casos_clinicos_cardiologia10,3265,3274,amikacina
4,casos_clinicos_cardiologia10,3280,3291,vancomicina


In [27]:
df_mentions['label'] = target_label

In [28]:
df_mentions[['filename', 'label', 'start_span', 'end_span', 'text']].to_csv(result_file, sep='\t', index=False)

In [None]:
# es - |0.8636|0.9028|0.8827
# en - |0.8507|0.8924|0.8711
# it - |0.8606|0.8789|0.8697