In [1]:
import pandas as pd
import os
import sys
from IPython.display import display
from transformers import  DistilBertTokenizerFast,DistilBertModel
import numpy as np
import json
import torch

## Functions

In [2]:
#BOI
def BOI_labels(lab_map,offset_mapping,entity_labels):
    labels =[],
    label_ids =[]
    cur = lab_map.pop()
    prev=False
    for token in offset_mapping:
        #Case:B - if first match
        if cur[1]==token[0]:
            labels.append("B-" + entity_labels[cur[0]])
            #if second doesnt match prev =True
            if cur[2]!=token[1]:
                prev = True
            else:
                #pop label mappings
                temp = lab_map.pop() if len(lab_map)!=0 else ['dummy',-1,-1]
                while temp[1]==cur[1]:
                    cur = temp
                    temp = lab_map.pop() if len(lab_map)!=0 else ['dummy',-1,-1]
                cur = temp
        #Case:I - if first doesnt match and prev==True
        elif cur[1]!=token[0] and prev:
            labels.append("I-" + entity_labels[cur[0]])
            #if second matches, prev = False and pop label mappings
            if cur[2]==token[1]:
                prev = False
                #pop labels
                temp = lab_map.pop() if len(lab_map)!=0 else ['dummy',-1,-1]
                while temp[1]==cur[1]:
                    cur = temp
                    temp = lab_map.pop() if len(lab_map)!=0 else ['dummy',-1,-1]
                cur = temp
        #Case:O - if first doesn't match and prev == False
        elif cur[1]!=token[0] and not prev:
            labels.append("O")
    
    return labels

#BOIES
def BOIES_labels(lab_map,offset_mapping,entity_labels):
    labels =[]
    cur = lab_map.pop()
    prev=False
    for token in offset_mapping:
        #If first and second both match, then 'S'
        if cur[1]==token[0] and cur[2]==token[1]:
            labels.append("S-" + entity_labels[cur[0]])
            #pop label
            temp = lab_map.pop() if len(lab_map)!=0 else ['dummy',-1,-1]
            while temp[1]==cur[1]:
                cur = temp
                temp = lab_map.pop() if len(lab_map)!=0 else ['dummy',-1,-1]
            cur = temp


        #Elseif first not match and prev==False, then 'O'
        elif cur[1]!=token[0] and not prev:
            labels.append("O" )

        #Elseif first match and second dont match,then 'B'
        elif cur[1]==token[0] and cur[2]!=token[1]:
            labels.append("B-" + entity_labels[cur[0]])
            #set prev=True
            prev = True

        #Elseif first doesn't match,prev==True, and second doesn't match then 'I'
        elif cur[1]!=token[0] and cur[2]!=token[1] and prev:
            labels.append("I-" + entity_labels[cur[0]])

        #Elseif first doesn't match,prev==True, and second matches then 'E'
        elif cur[1]!=token[0] and cur[2]==token[1] and prev:
            labels.append("E-" + entity_labels[cur[0]])
            #prev=False and pop label
            prev =False
            temp = lab_map.pop() if len(lab_map)!=0 else ['dummy',-1,-1]
            while temp[1]==cur[1]:
                cur = temp
                temp = lab_map.pop() if len(lab_map)!=0 else ['dummy',-1,-1]
            cur = temp
    
    return labels

In [3]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

## Data Preprocessing

In [75]:
path_text = "../data/train/text/"
path_tsv = "../data/train/tsv/"
all_files = os.listdir(path_text)
tsv_files = os.listdir(path_tsv)

In [76]:
texts = []
text_labels = []
exceptions = []
for file in all_files:
    try:
        with open(os.path.join(path_text,file)) as f:
            text = f.read()
        tsv_data = pd.read_csv(os.path.join(path_tsv,file.split('.')[0]+".tsv"),sep="\t")[['annotType','startOffset','endOffset','text','annotId','other']].sort_values(by='startOffset',ascending=False)
        texts.append(text)
        text_labels.append(tsv_data.values.tolist())
    except Exception as e:
        exceptions.append(e)
        
        
                               

In [77]:
len(texts),len(text_labels)

(233, 233)

In [78]:
entity_labels={
    'Quantity':'QUANTITY',
    'MeasuredEntity': 'MEASURED_ENTITY',
    'MeasuredProperty': 'MEASURED_PROPERTY',
    'Qualifier': 'QUALIFIER' 
}    

entity_labels

{'Quantity': 'QUANTITY',
 'MeasuredEntity': 'MEASURED_ENTITY',
 'MeasuredProperty': 'MEASURED_PROPERTY',
 'Qualifier': 'QUALIFIER'}

In [79]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')

In [80]:
encodings = tokenizer(texts,return_offsets_mapping=True, padding=True, truncation=True)
offset_mappings = encodings.offset_mapping

In [81]:
labels =[]
for offset_mapping,text_label in zip(offset_mappings,text_labels):
    label = BOI_labels(text_label,offset_mapping,entity_labels)
    labels.append(label)
encodings['labels'] = [ [lab_to_id[lab] for lab in item] for item in labels]

In [82]:
pd.set_option('display.max_rows', 5)
i=2
df = pd.DataFrame({
    'tokens': tokenizer.convert_ids_to_tokens(encodings.input_ids[i]),
    'labels':encodings.labels[i]
})


In [83]:
BOI = ["B-","I-"]
id_to_lab = []
for _,el in entity_labels.items():
    for b in BOI:
        id_to_lab.append(b+el)
id_to_lab.append('O')

In [84]:
id_to_lab

['B-QUANTITY',
 'I-QUANTITY',
 'B-MEASURED_ENTITY',
 'I-MEASURED_ENTITY',
 'B-MEASURED_PROPERTY',
 'I-MEASURED_PROPERTY',
 'B-QUALIFIER',
 'I-QUALIFIER',
 'O']

In [85]:
lab_to_id = {lab:i for i,lab in enumerate(id_to_lab)}

In [86]:
lab_to_id


{'B-QUANTITY': 0,
 'I-QUANTITY': 1,
 'B-MEASURED_ENTITY': 2,
 'I-MEASURED_ENTITY': 3,
 'B-MEASURED_PROPERTY': 4,
 'I-MEASURED_PROPERTY': 5,
 'B-QUALIFIER': 6,
 'I-QUALIFIER': 7,
 'O': 8}

## Datasets and Dataloaders

In [87]:
class NERDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        item = {key: torch.Tensor(val[idx]) for key, val in self.encodings.items()}
        return item

    def __len__(self):
        return len(self.encodings.input_ids)

In [88]:
train_dataset = NERDataset(encodings)

In [89]:
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

{'input_ids': tensor([  101.,  1109.,  2525.,  6468.,  4143.,   117.,   157.,  1403.,   117.,
          1104.,  1155.,  1103.,  9634.,  8025.,  1108.,  7140.,  1606.,  9652.,
           118.,  6676., 10735.,  3622.,   113.,   141., 13910.,  1592.,   114.,
          1114.,   170.,   154., 18910.,  1568.,   141.,  8271.,  1121.,   157.,
          1592., 25832.,   117.,  1993.,   119.,   138.,  2702.,   118.,  1169.,
         23677.,  4121.,  5418.,  1120.,   122., 25364.,  1108.,  4071.,  1606.,
          2774.,  9985.,  2539.,   240.,  1275.,   240.,   124.,  2608.,  1495.,
          1107.,  2060.,   119.,  1109.,  4143.,  2079.,  1215.,  1108.,   851.,
         20150.,  5702.,  1106.,  2363.,  5702.,  1114.,   170., 11187.,  2603.,
          1104.,   125.,  5702.,   120., 11241.,   119.,  1109.,  2860.,  1104.,
           157.,  1403.,  1108.,  3552.,  1120.,  1103.,  4709.,  2860.,  1104.,
         15925.,   421.,   119.,  1109.,  1295.,  1903.,  9546.,  2841.,  1206.,
          2771.

[[8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  2,
  3,
  3,
  4,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  0,
  1,
  1,
  1,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  2,
  4,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  0,
  1,
  1,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  2,
  3,
  3,
  6,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  6,
  4,
  8,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  8,
  0,
