In [None]:
#default_exp train

In [None]:
#export
import os
import torch

import pandas as pd
import numpy as np
import warnings

import Bert4NER.config as config
import Bert4NER.model.model as model
import Bert4NER.utils.utils as utils
import Bert4NER.utils.engine as engine
import Bert4NER.dataset.dataset as dataset


from functools import partial
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from transformers import AdamW, get_linear_schedule_with_warmup

warnings.filterwarnings("ignore") 

In [None]:
#export
# Single seed value for the run; it is handed to the project's seeding helper.
# NOTE(review): presumably seed_everything seeds Python/NumPy/torch RNGs — confirm in utils.
SEED = 42
utils.seed_everything(SEED)

In [None]:
#export
# This cell must be exported: later exported cells read and modify `df`, so leaving the
# load in a `#hide` cell makes the generated module reference an undefined name.
# The CSV is read with an explicit latin-1 encoding from the configured data directory.
df = pd.read_csv(config.DATA_PATH/'ner_datasetreference.csv', encoding='latin-1')
df.head()

Unnamed: 0,Sentence #,Word,POS,Tag
0,Sentence: 1,Thousands,NNS,O
1,,of,IN,O
2,,demonstrators,NNS,O
3,,have,VBP,O
4,,marched,VBN,O


We use pandas' forward-fill method to fill the NaNs in the `Sentence #` column, so that every word row carries the id of the sentence it belongs to.

In [None]:
#hide
# Preview only: show the forward-filled sentence ids without modifying `df`.
# `Series.ffill()` is used because `fillna(method='ffill')` is deprecated in current pandas.
df['Sentence #'].ffill()

0              Sentence: 1
1              Sentence: 1
2              Sentence: 1
3              Sentence: 1
4              Sentence: 1
                ...       
1048570    Sentence: 47959
1048571    Sentence: 47959
1048572    Sentence: 47959
1048573    Sentence: 47959
1048574    Sentence: 47959
Name: Sentence #, Length: 1048575, dtype: object

In [None]:
#export
# Only the first word of each sentence has its `Sentence #` set; propagate that id down
# to the following rows. `Series.ffill()` replaces the deprecated `fillna(method='ffill')`.
df['Sentence #'] = df['Sentence #'].ffill()

In total, we can see that there are 47959 sentences in our dataset.

In [None]:
#hide
# Count the distinct values in the `Sentence #` column.
sentence_ids = df['Sentence #'].unique()
len(sentence_ids)

47959

Now let us encode the labels (`POS` and `Tag`) of every word in every sentence.

In [None]:
#hide
# Two independent encoders, later fitted on the `POS` and `Tag` columns respectively.
# Both are freshly constructed (unfitted) at this point.
le_pos, le_tag = LabelEncoder(), LabelEncoder()

In [None]:
#export
# Hand both encoders to the project's persistence helper.
# NOTE(review): at this point the encoders have only been constructed (fitting happens in a
# later cell) and they are created in a `#hide` cell, so the exported module has no
# `le_pos` / `le_tag` here — confirm the intended order and export flags.
utils.save_label_encoders(le_pos=le_pos, le_tag=le_tag)

encoders already exist


In [None]:
#export
# Rebind both encoder names to whatever the project loader returns;
# the result is unpacked as (le_pos, le_tag).
le_pos, le_tag = utils.load_label_encoders()

In [None]:
#hide
# Fit each encoder on its label column and keep the resulting integer codes
# in new columns alongside the raw labels.
df["encoded_POS"] = le_pos.fit_transform(df["POS"])
df["encoded_Tag"] = le_tag.fit_transform(df["Tag"])

In [None]:
#export
# Turn the word-level frame into three parallel collections via the project helper.
# NOTE(review): presumably one entry per sentence (the length check below reports 47959
# for each) — confirm against utils.process_data.
sentences, tags, pos = utils.process_data(df)

In [None]:
#hide
# Sanity check: the three collections should all have the same length.
tuple(len(collection) for collection in (sentences, tags, pos))

(47959, 47959, 47959)

## Data split

I'll be using a simple train-test split, holding out 20% of the sentences for validation.

In [None]:
#export
# Hold out 20% of the sentences for validation. The three inputs are split with the same
# shuffle, so sentence / tag / POS entries stay aligned.
# `random_state=SEED` pins the shuffle: the split is identical on every run and when this
# cell is re-executed on its own, instead of depending on the global RNG state.
(train_sentences, valid_sentences,
 train_tag, valid_tag,
 train_pos, valid_pos) = train_test_split(
    sentences, tags, pos, test_size=0.2, random_state=SEED
)

In [None]:
#export
# Build one loader per split through the project helper, each with its own batch size
# taken from the project config.
train_dl = utils.create_loader(train_sentences, train_tag, train_pos, bs=config.TRAIN_BATCH_SIZE)
valid_dl = utils.create_loader(valid_sentences, valid_tag, valid_pos, bs=config.VALID_BATCH_SIZE)

In [None]:
#export
# The model is told how many distinct tag and POS classes the fitted encoders hold.
n_tag_classes = len(le_tag.classes_)
n_pos_classes = len(le_pos.classes_)
modeller = model.EntityModel(num_tag=n_tag_classes, num_pos=n_pos_classes)

In [None]:
#export
# (name, parameter) pairs of the model. The export flag is restored (it was commented out):
# the exported optimizer-group cell filters on these names, so without this cell the
# generated module would reference an undefined `model_params`.
model_params = list(modeller.named_parameters())

In [None]:
#export
# Parameters whose name contains any of these fragments get no weight decay.
no_decay = ['bias', 'LayerNorm.weight', 'LayerNorm.bias']


def _skips_decay(param_name):
    """Return True when `param_name` contains at least one `no_decay` fragment."""
    return any(fragment in param_name for fragment in no_decay)


# Split the named parameters into the two groups, preserving their original order.
_decayed = [param for name, param in model_params if not _skips_decay(name)]
_undecayed = [param for name, param in model_params if _skips_decay(name)]

optimizer_params = [
    # regular parameters: small weight decay
    {'params': _decayed, 'weight_decay': 0.001},
    # bias / LayerNorm parameters: no weight decay
    {'params': _undecayed, 'weight_decay': 0.0},
]

In [None]:
#export
# Learning rate is taken from the project config.
lr = config.LR

In [None]:
#hide
# Display the configured learning rate.
lr

1e-05

In [None]:
#export
# Optimizer over the two parameter groups; each group carries its own weight decay.
# NOTE(review): this is the AdamW imported from transformers; recent transformers releases
# deprecate it in favour of torch.optim.AdamW — confirm against the pinned version.
optimizer = AdamW(params=optimizer_params, lr=lr)

In [None]:
#export
# Number of optimizer steps the linear LR schedule has to span.
# - Only the training split is used for optimisation, so count `train_sentences`;
#   the full `sentences` list also contains the held-out validation sentences.
# - Training runs for `config.NUM_EPOCHS + 2` epochs (see the `learn.fit` call), so the
#   schedule uses that same epoch count. Keep the two in sync, otherwise the learning
#   rate reaches zero before training finishes.
# Assumes the engine steps the scheduler once per training batch — TODO confirm in engine.
scheduled_epochs = config.NUM_EPOCHS + 2
num_train_steps = int(len(train_sentences) / config.TRAIN_BATCH_SIZE * scheduled_epochs)

In [None]:
#export
# Linear learning-rate schedule over `num_train_steps` steps, with no warmup phase.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_train_steps,
)

In [None]:
#export
# Metric callables handed to the fitter: plain accuracy and macro-averaged F1.
metric_fns = [accuracy_score, partial(f1_score, average='macro')]

# Bundle model, (train, valid) loaders, optimizer, metrics and device into the project's
# fitter; the scheduler and the log file name are passed by keyword.
learn = engine.BertFitter(
    modeller,
    (train_dl, valid_dl),
    optimizer,
    metric_fns,
    config.DEVICE,
    scheduler=scheduler,
    log_file='training_log.txt',
)

In [None]:
#hide
# Display the epoch count set in the project config.
config.NUM_EPOCHS

4

In [None]:
#export
# Train for two epochs more than the configured count.
# Keep the scheduler's `num_training_steps` in sync with this epoch count; otherwise the
# linear schedule will not end on the last training step.
NUM_EPOCHS = config.NUM_EPOCHS + 2
# `model_path` points at entity_model.bin under the configured model directory
# (presumably where the fitter stores the weights — confirm in engine).
learn.fit(NUM_EPOCHS, model_path=config.MODEL_PATH/'entity_model.bin')

epoch,train_loss,valid_loss,tag_accuracy,tag_f1_score,pos_accuracy,pos_f1_score,time
