In [109]:
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
import torch
from torch.utils.data import Dataset, DataLoader
from torch import mps
from transformers import BertTokenizer, BertForTokenClassification, BertConfig


device = 'mps' if mps.is_available() else 'cpu'


In [110]:
path = 'data/ner_data.csv'
data = pd.read_csv(path, encoding = 'unicode_escape')

In [111]:
data.head(5)

Unnamed: 0,Sentence #,Word,POS,Tag
0,Sentence: 1,Thousands,NNS,O
1,,of,IN,O
2,,demonstrators,NNS,O
3,,have,VBP,O
4,,marched,VBN,O


In [112]:
data = data.fillna(method='ffill')

  data = data.fillna(method='ffill')


In [113]:
#convert iob tags to base tag
data['base_tag'] = data['Tag'].apply(lambda x: x.split('-')[-1])

In [114]:
freqs = data['Tag'].value_counts()
print("IOB tag count")
freqs

IOB tag count


Tag
O        887908
B-geo     37644
B-tim     20333
B-org     20143
I-per     17251
B-per     16990
I-org     16784
B-gpe     15870
I-geo      7414
I-tim      6528
B-art       402
B-eve       308
I-art       297
I-eve       253
B-nat       201
I-gpe       198
I-nat        51
Name: count, dtype: int64

In [115]:
iob_tags = []
for t, f in zip(freqs.index, freqs):
    iob_tags.append(t)

unique_tag =  []
for tag in iob_tags:
    s = tag.split('-')
    unique_tag.append(s[-1])

unique_tag = list(set(unique_tag))
print(f'Unique base tag: {unique_tag}')

Unique base tag: ['org', 'per', 'art', 'eve', 'O', 'geo', 'nat', 'tim', 'gpe']


In [116]:
data['base_tag'].value_counts()

base_tag
O      887908
geo     45058
org     36927
per     34241
tim     26861
gpe     16068
art       699
eve       561
nat       252
Name: count, dtype: int64

In [117]:
# art eve nat is not defined properly, removing them
to_remove = ['art','nat','eve']

data = data[~data.base_tag.isin(to_remove)]

In [118]:
data.head(5)

Unnamed: 0,Sentence #,Word,POS,Tag,base_tag
0,Sentence: 1,Thousands,NNS,O,O
1,Sentence: 1,of,IN,O,O
2,Sentence: 1,demonstrators,NNS,O,O
3,Sentence: 1,have,VBP,O,O
4,Sentence: 1,marched,VBN,O,O


In [119]:
labels = data['Tag'].value_counts().index

label2id = {}
id2label = {}
for idx, label in enumerate(labels):
    label2id[label] = idx
    id2label[idx] = label

In [None]:
data['Sentence'] = data[['Sentence #','Word', "Tag"]].groupby(['Sentence #'])["Word"].transform(lambda x: ' '.join(x))
data['Word_labels'] = data[['Sentence #','Word', "Tag"]].groupby(['Sentence #'])["Tag"].transform(lambda x: ' '.join(x))


In [88]:
data.head(10)

Unnamed: 0,Sentence #,Word,POS,Tag,base_tag,Sentence,Word_labels
0,Sentence: 1,Thousands,NNS,O,O,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
1,Sentence: 1,of,IN,O,O,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
2,Sentence: 1,demonstrators,NNS,O,O,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
3,Sentence: 1,have,VBP,O,O,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
4,Sentence: 1,marched,VBN,O,O,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
5,Sentence: 1,through,IN,O,O,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
6,Sentence: 1,London,NNP,B-geo,geo,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
7,Sentence: 1,to,TO,O,O,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
8,Sentence: 1,protest,VB,O,O,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
9,Sentence: 1,the,DT,O,O,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...


In [122]:
data = data[["Sentence", "Word_labels"]].drop_duplicates().reset_index(drop=True)
data.head()

Unnamed: 0,Sentence,Word_labels
0,Thousands of demonstrators have marched throug...,O O O O O O B-geo O O O O O B-geo O O O O O B-...
1,Families of soldiers killed in the conflict jo...,O O O O O O O O O O O O O O O O O O B-per O O ...
2,They marched from the Houses of Parliament to ...,O O O O O O O O O O O B-geo I-geo O
3,"Police put the number of marchers at 10,000 wh...",O O O O O O O O O O O O O O O
4,The protest comes on the eve of the annual con...,O O O O O O O O O O O B-geo O O B-org I-org O ...


In [123]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [174]:
input_ids = []
attention_mask = []
labels = []

for sent in data["Sentence"].head(5):
    encoded_dict =  tokenizer.encode_plus(
        sent,
        add_special_tokens = True,
        max_length = 128,
        pad_to_max_length = True,
        return_attention_mask = True,
        return_tensors = 'pt'
    )

    input_ids.append(encoded_dict['input_ids'])
    attention_mask.append(encoded_dict['attention_mask'])

for iob_labels in data["Word_labels"].head(5):
    l = iob_labels.split(' ')
    temp = [label2id[x] for x in l]
    temp = torch.tensor(temp)
    labels.append(temp)




In [175]:
input_ids[0]

tensor([[  101,  5190,  1997, 28337,  2031,  9847,  2083,  2414,  2000,  6186,
          1996,  2162,  1999,  5712,  1998,  5157,  1996, 10534,  1997,  2329,
          3629,  2013,  2008,  2406,  1012,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [176]:
attention_mask[0]

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])

In [188]:
len(labels[0])

24

In [179]:
data['Word_labels'][0]

'O O O O O O B-geo O O O O O B-geo O O O O O B-gpe O O O O O'

In [180]:
data['Sentence'][0]

'Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country .'

In [187]:
tokenizer.ids_to_tokens[102]

'[SEP]'