If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers and 🤗 Datasets (the `pip install` cell below does this), and uncomment the Google Drive mount cell below so that the `/content/drive/...` paths resolve.

In [None]:
# Where the documents to annotate are read from, and the root folder for everything this notebook writes.
data_path, output_path_root = (
    '/content/drive/MyDrive/AI/CLEF2023/data/validation_files',
    '/content/drive/MyDrive/AI/CLEF2023/inference/postprocessed/paper/Tag_only_BIoM_BERT_strict_full_processed_set',
)
# One prediction .tsv per input document is written into this sub-folder.
inference_files_output_path = output_path_root + '/inference_files'

# Hub model whose tokenizer is loaded for inference, and the fine-tuned checkpoint whose weights are used.
base_model = "sultan/BioM-BERT-PubMed-PMC-Large"
model_checkpoint_path = '/content/drive/MyDrive/AI/CLEF2023/paper_models/tag-only-80-20-BioM-BERT-PubMed-PMC-Large-finetuned-ner/checkpoint-1785'

In [None]:
# # If running in Colab
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Print the installed Java version (v8+ should be fine).
# NOTE(review): no visible cell in this notebook appears to call Java (the subprocess imports in the
# inference cell are unused) — this check may be a leftover; confirm before relying on it.
!java -version

openjdk version "11.0.19" 2023-04-18
OpenJDK Runtime Environment (build 11.0.19+7-post-Ubuntu-0ubuntu120.04.1)
OpenJDK 64-Bit Server VM (build 11.0.19+7-post-Ubuntu-0ubuntu120.04.1, mixed mode, sharing)


In [None]:
! pip install datasets transformers sentencepiece

Collecting datasets
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/486.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━[0m [32m286.7/486.2 kB[0m [31m8.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m99.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m67.7 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl

Make sure your version of Transformers is at least 4.11.0, since some of the functionality used in this notebook was introduced in that version:

In [None]:
import transformers

print(transformers.__version__)
# The note above requires Transformers >= 4.11.0: fail fast here instead of relying on
# someone reading the printed version.
_installed_major_minor = tuple(int(part) for part in transformers.__version__.split('.')[:2])
assert _installed_major_minor >= (4, 11), f'Transformers >= 4.11.0 is required, found {transformers.__version__}'

4.30.2


In [None]:
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline, TokenClassificationPipeline
import pandas as pd
import os

# Load the fine-tuned token-classification model from the checkpoint directory configured above.
model = AutoModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint_path,
)

In [None]:
# The tokenizer is taken from the base Hub model, not from the checkpoint directory.
# A fast tokenizer is requested and its maximum sequence length is set to 512
# (presumably the fast tokenizer is needed for the pipeline's `stride` handling — confirm).
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=base_model,
    use_fast=True,
    model_max_length=512,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/719 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/220k [00:00<?, ?B/s]

In [None]:
import re

# Single characters that are deleted from every line before it is fed to the model.
special_chars_pattern = r'[~:\'+\[\\@^{%(\-"*|,&<`}._=\]!>;?#$)/™®©℠]'
# Digit runs (optionally containing one comma) that are replaced by the token 'NUMBER'.
numbers_pattern = r'[0-9]+,?[0-9]*'
# Both of the above in one pass; group 1 is set only for number matches. Built from the two
# patterns above (instead of repeating them literally) so the three can never drift apart.
combined_pattern = f'{special_chars_pattern}|({numbers_pattern})'

In [None]:
class Modification:
  """One regex match found while preprocessing a line, and how it was rewritten.

  All positions are character offsets. `src_*` index the original line; `mod_*` are the
  same span shifted by the net length change of every earlier modification on the line
  (as computed by get_modifications_from_regex_matches).
  """

  def __init__(self, src_start, src_end, mod_start, mod_end, original, replacement, length_difference):
    self.src_start = src_start  # start of the matched text in the original line
    self.src_end = src_end  # end (exclusive) of the matched text in the original line
    self.mod_start = mod_start  # src_start shifted by earlier modifications
    self.mod_end = mod_end  # src_end shifted by earlier modifications
    self.original = original  # the matched text
    self.replacement = replacement  # the text it was replaced with
    # how many characters we must add to the processed text to get the original length
    self.length_difference = length_difference

  def affects_position(self, position):
    """Return True if this modification ends at or before `position` (shifted coordinates).

    Previously referenced a non-existent `self.end` and raised AttributeError.
    """
    return self.mod_end <= position

def get_modifications_from_regex_matches(regex_iterator, line):
  """Describe every regex match on `line` as a Modification.

  A match containing a number (per `numbers_pattern`) is replaced by 'NUMBER'; any other
  match is deleted.

  Args:
    regex_iterator: iterator of re.Match objects over `line`,
      e.g. re.finditer(combined_pattern, line).
    line: the line being matched (kept for interface compatibility; not used).

  Returns:
    A list of Modification objects in match order. Their mod_start/mod_end are the match's
    src_start/src_end shifted by the net length change of all earlier matches.

  NOTE: this per-match model treats each match independently. It cannot represent digit runs
  that merge once separators between them are deleted (e.g. '12/05/2004' becomes a single
  'NUMBER' when specials are removed first and numbers replaced afterwards).
  """
  modifications = []
  offset = 0  # net length change (processed - original) caused by the matches seen so far
  for pattern_match in regex_iterator:
    matched_string = pattern_match.group(0)
    replacement = 'NUMBER' if re.search(numbers_pattern, matched_string) else ''
    # Characters to add to the processed text to recover the original length. A comma inside a
    # number match is part of matched_string and disappears with the same replacement, so no
    # extra correction is needed (subtracting one for it skewed every later offset).
    length_difference = len(matched_string) - len(replacement)

    modifications.append(Modification(
      pattern_match.start(), pattern_match.end(),
      pattern_match.start() + offset, pattern_match.end() + offset,
      matched_string, replacement, length_difference))

    offset -= length_difference

  return modifications

In [None]:
# Token-classification ('ner') pipeline over the fine-tuned model.
# aggregation_strategy="first" merges sub-word tokens into words using the first sub-token's
# label, which keeps the usual BIO scheme when merging. stride=0 is presumably there so lines
# longer than the tokenizer limit are split into non-overlapping chunks instead of truncated — confirm.
ner_pipe = pipeline(
    task='ner',
    model=model,
    tokenizer=tokenizer,
    pipeline_class=TokenClassificationPipeline,
    aggregation_strategy="first",
    stride=0,
)

In [None]:
# Group annotations around a clinical procedure mention, based on the annotation label.
def group_annotations_strict(annotations):
  """Group aggregated pipeline entities into mentions using strict BIO rules.

  A 'LABEL_1' entity (B) opens a new mention, and every 'LABEL_2' entity (I) that directly
  follows it extends that mention. A 'LABEL_2' that does not continue an open mention is
  dropped, and any other label closes the open mention.

  Returns:
    A list of groups (lists of the original entity dicts), in input order.
  """
  groups = []
  open_group = None  # the mention currently being extended, if any
  for entity in annotations:
    label = entity['entity_group']
    if label == 'LABEL_1':
      open_group = [entity]
      groups.append(open_group)
    elif label == 'LABEL_2' and open_group is not None:
      open_group.append(entity)
    else:
      open_group = None
  return groups

In [None]:
# Merge grouped annotations to form a complete entity mention.
def merge_annotations(annotation_group):
  """Collapse one non-empty group of entities into a single mention dict.

  The span runs from the first entity's 'start' to the last entity's 'end'. 'text' is the
  entities' 'word' values joined by single spaces; the main loop later overwrites it with
  the source text via restore_annotations_text.
  """
  first, last = annotation_group[0], annotation_group[-1]
  return {
    'start': first['start'],
    'end': last['end'],
    'text': ' '.join([entity['word'] for entity in annotation_group]),
  }

In [None]:
# Replace the text of inferred mentions with the original text from the input file.
def restore_annotations_text(processed_annotations, input_file_text):
  """Overwrite, in place, each annotation's 'text' with input_file_text[start:end].

  The 'start'/'end' offsets must index into `input_file_text`. Returns None.
  """
  for item in processed_annotations:
    span = slice(item['start'], item['end'])
    item['text'] = input_file_text[span]

In [None]:
def save_annotations_tsv(processed_annotations, file_name):
  """Write one document's predicted mentions to `<inference_files_output_path>/<doc id>.tsv`.

  Args:
    processed_annotations: iterable of dicts with 'start', 'end' and 'text' keys.
    file_name: the input document's file name. A trailing '.txt' is removed to form the
      document id, which is used both in the 'filename' column and as the output file name.

  Every row is labelled 'PROCEDIMIENTO'. An empty annotation list still writes the header.
  """
  # Remove only the literal '.txt' suffix: str.rstrip('.txt') strips any trailing run of '.',
  # 't' and 'x' characters and would turn e.g. 'case-text.txt' into 'case-te'.
  if file_name.endswith('.txt'):
    file_name = file_name[:-len('.txt')]

  rows = [
    {
      'filename': file_name,
      'label': 'PROCEDIMIENTO',
      'start_span': annotation['start'],
      'end_span': annotation['end'],
      'text': annotation['text'],
    }
    for annotation in processed_annotations
  ]
  # Explicit columns keep the header (and its order) even when there are no rows.
  df = pd.DataFrame(rows, columns=['filename', 'label', 'start_span', 'end_span', 'text'])
  df.to_csv(f'{inference_files_output_path}/{file_name}.tsv', sep='\t', index=False)

In [None]:
# Document ids (file names without extension) that already have a prediction file, so a re-run
# can resume where it stopped. The output folder is created first so this also works on a fresh run.
# os.path.splitext removes just the extension; str.strip('.tsv') would strip any leading or
# trailing '.', 't', 's' or 'v' characters.
os.makedirs(inference_files_output_path, exist_ok=True)
done = {
  os.path.splitext(file_name)[0]
  for file_name in os.listdir(inference_files_output_path)
  if file_name.endswith('.tsv')
}

In [None]:
from transformers.models.perceiver.modeling_perceiver import PerceiverOpticalFlowDecoder
from subprocess import Popen, PIPE, STDOUT
import re
import os

def _preprocess_with_alignment(line):
  """Preprocess `line` the way the model expects and record how to map positions back.

  The transformation is the notebook's preprocessing: every character matching
  `special_chars_pattern` is deleted, then every match of `numbers_pattern` in what is left
  is replaced by 'NUMBER'. Positions are tracked through both passes because deleting
  separators can merge digit runs (e.g. '12/05/2004' and '1.5' each become a single
  'NUMBER'), which a per-regex-match offset table cannot represent.

  Returns:
    (modified_line, start_map, end_map), all of the same length. For each index j of
    modified_line, start_map[j] is the index in `line` where a span starting at j begins,
    and end_map[j] is the index in `line` just past the text that produced character j, so a
    span ending (exclusively) at e maps to end_map[e - 1]. Deleted characters at a span's
    edges are therefore excluded from the mapped span; deleted characters inside it are kept.
  """
  # Pass 1: drop special characters, remembering where each surviving character came from.
  kept_chars = []
  kept_positions = []
  for position, char in enumerate(line):
    if re.fullmatch(special_chars_pattern, char) is None:
      kept_chars.append(char)
      kept_positions.append(position)
  stripped = ''.join(kept_chars)

  # Pass 2: replace numbers with 'NUMBER', mapping every output character back to `line`.
  pieces, start_map, end_map = [], [], []

  def copy_unchanged(begin, end):
    # Characters stripped[begin:end] pass through untouched.
    pieces.append(stripped[begin:end])
    for k in range(begin, end):
      start_map.append(kept_positions[k])
      end_map.append(kept_positions[k] + 1)

  replacement = 'NUMBER'
  cursor = 0
  for number in re.finditer(numbers_pattern, stripped):
    copy_unchanged(cursor, number.start())
    pieces.append(replacement)
    start_map.extend([kept_positions[number.start()]] * len(replacement))
    end_map.extend([kept_positions[number.end() - 1] + 1] * len(replacement))
    cursor = number.end()
  copy_unchanged(cursor, len(stripped))

  return ''.join(pieces), start_map, end_map


for file_name in os.listdir(data_path):
  print('File: ', file_name)

  # Skip documents that already have a prediction file (names compared without extension).
  if os.path.splitext(file_name)[0] in done:
    print('skipping...')
    continue

  with open(f'{data_path}/{file_name}', encoding='utf-8') as in_file:
    lines = in_file.readlines()
  # The readlines() pieces concatenate back to exactly the text the offsets refer to.
  file_text = ''.join(lines)

  # the final annotations for the current file, with offsets into file_text
  offset_annotations = []
  # offset of the current line within file_text; advanced for EVERY line (including lines
  # skipped below) so that later offsets stay aligned with the file
  line_start = 0

  for line in lines:
    modified_line, start_map, end_map = _preprocess_with_alignment(line)

    # lines with nothing left after preprocessing (e.g. blank lines) are not sent to the model
    if modified_line.strip():
      annotations = ner_pipe(modified_line)
      for group in group_annotations_strict(annotations):
        mention = merge_annotations(group)
        start, end = mention['start'], mention['end']
        # translate positions in modified_line back to positions in line, then in the file
        original_start = start_map[start]
        original_end = end_map[end - 1] if end > start else original_start
        mention['start'] = line_start + original_start
        mention['end'] = line_start + original_end
        offset_annotations.append(mention)

    line_start += len(line)

  # Replace the model-side text (joined words) with the exact source text of each span.
  # This must use the whole file, because the offsets are file-level.
  restore_annotations_text(offset_annotations, file_text)

  save_annotations_tsv(offset_annotations, file_name)

File:  es-S1130-05582004000400006-1-b-7.txt
File:  es-S1137-66272014000100021-1-b-22.txt
File:  es-S1137-66272014000100021-1-b-15.txt
File:  S0004-06142007000100011-1-b-2.txt
File:  es-S1130-05582004000400006-1-b-11.txt
File:  es-S1137-66272014000100021-1-b-19.txt
File:  es-S1130-05582004000400006-1-b-9.txt
File:  es-S1130-05582004000400006-1-b-19.txt
File:  es-S1137-66272014000100021-1-b-4.txt
File:  es-S1137-66272014000100021-1-b-9.txt
File:  es-S1130-05582007000600003-1-b-6.txt
File:  es-S0212-71992003000500006-1-b-3.txt
File:  es-S0212-71992003000500006-1-b-13.txt
File:  es-S1130-05582007000600003-1-b-10.txt
File:  es-S0212-71992003000500006-1-b-14.txt
File:  es-S1139-76322009000200009-1-b-18.txt
File:  es-S0365-66912006000100010-1-b-8.txt
File:  es-S1130-01082008001100009-1-b-4.txt
File:  es-S1139-76322009000200009-1-b-13.txt
File:  es-S1139-76322009000200009-1-b-3.txt
File:  es-S1130-01082008001100009-1-b-7.txt
File:  es-S1139-76322009000200009-1-b-9.txt
File:  es-S1139-763220090

In [None]:
import pandas as pd
import glob
import os

# Collect every per-document prediction file into one submission TSV.
li = []

# sorted() makes the row order reproducible (os.listdir order is arbitrary); only the
# .tsv prediction files are read.
for filename in sorted(os.listdir(inference_files_output_path)):
    if not filename.endswith('.tsv'):
        continue
    # keep_default_na=False keeps mention texts such as 'NA' (or empty strings) as text instead of NaN
    df = pd.read_csv(f'{inference_files_output_path}/{filename}', index_col=None, header=0, sep='\t', keep_default_na=False)
    li.append(df)

if not li:
    # pd.concat([]) would only fail with an unhelpful "No objects to concatenate"
    raise ValueError(f'No .tsv prediction files found in {inference_files_output_path}')

frame = pd.concat(li, axis=0, ignore_index=True)
frame.to_csv(f'{output_path_root}/val_20_strict_first_tag_only_full_preprocess.tsv', sep='\t', index=False)