Reference article: https://www.analyticsvidhya.com/blog/2020/07/transfer-learning-for-nlp-fine-tuning-bert-for-text-classification/
### Model from:

https://github.com/allenai/scibert

In [1]:
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import re
import seaborn as sns
from tqdm import tqdm
import nltk
import random
nltk.download('punkt')
from nltk.tokenize import word_tokenize,sent_tokenize
import pickle

# Document ids are the JSON file names without their extension.
# os.path.splitext keeps ids that themselves contain dots intact, and the
# '.json' filter skips stray directory entries (e.g. checkpoint folders) that
# the loaders could not open anyway, since they always append '.json'.
train_example_names = [os.path.splitext(fn)[0] for fn in os.listdir('data/train') if fn.endswith('.json')]
test_example_names = [os.path.splitext(fn)[0] for fn in os.listdir('data/test') if fn.endswith('.json')]

# Label metadata; the columns read later in this notebook are
# 'dataset_label', 'dataset_title' and 'cleaned_label'.
metadata = pd.read_csv('data/train.csv')

# Independent copy of the training ids, iterated when building sentences.
docIdx = list(train_example_names)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\ozano\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Create dataframe for tokens and targets

In [5]:
import unidecode

# Single capture group matching one punctuation character; used to pad
# punctuation with spaces. The double quote lives in its own literal so the
# surrounding pieces can stay raw strings.
match_puncs_re = re.compile(
    r"([.,!?()\-;\[\]+\\\/@:<>#_{}&%'*="
    r'"'
    r"|])"
)

def load_train_example_by_name(name):
    """Load one training document from ``data/train/<name>.json``.

    Args:
        name: Document id, i.e. the file name without the ``.json`` suffix.

    Returns:
        The parsed JSON content of the file.

    Raises:
        FileNotFoundError: If no such file exists.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    doc_path = os.path.join('data/train', name + '.json')
    # Pin the encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode or reject non-ASCII characters in UTF-8 encoded JSON.
    with open(doc_path, encoding='utf-8') as f:
        return json.load(f)

def load_test_example_by_name(name):
    """Load one test document from ``data/test/<name>.json``.

    Args:
        name: Document id, i.e. the file name without the ``.json`` suffix.

    Returns:
        The parsed JSON content of the file.

    Raises:
        FileNotFoundError: If no such file exists.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    doc_path = os.path.join('data/test', name + '.json')
    # Pin the encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode or reject non-ASCII characters in UTF-8 encoded JSON.
    with open(doc_path, encoding='utf-8') as f:
        return json.load(f)

def text_cleaning_for_bert(text):
    """Normalise raw text into the space-separated form used as model input.

    Punctuation is kept, but every character matched by ``match_puncs_re`` is
    padded with spaces so that it becomes its own whitespace-delimited token.

    Args:
        text: Raw text (str).

    Returns:
        The lower-cased text with '^' replaced by a space, characters
        transliterated by ``unidecode``, URL-like tokens replaced by
        'specialwebsite', and whitespace runs collapsed to one space. The
        result is not stripped, so it may start or end with a single space.
    """
    text = text.replace('^', ' ')
    text = unidecode.unidecode(text)

    # Replace any space-delimited token that looks like a URL with a placeholder
    text = ' '.join(['specialwebsite' if 'http' in t or 'www' in t else t for t in text.split(' ')])

    # Pad punctuation with spaces
    text = match_puncs_re.sub(r' \1 ', text)

    # Collapse whitespace runs (raw string: "\s" is an invalid escape in a plain literal)
    text = re.sub(r"\s+", " ", text)

    return text.lower()

def text_cleaning_for_label(text):
    """Normalise text into the punctuation-free form used for label matching.

    Args:
        text: Raw text (str).

    Returns:
        The lower-cased text in which '^' is replaced by a space, characters
        are transliterated by ``unidecode``, every run of non-alphanumeric
        characters becomes a single space (with the ends stripped), and tokens
        containing 'http' or 'www' are replaced by 'specialwebsite'.
    """
    text = text.replace('^', ' ')
    text = unidecode.unidecode(text)

    # Keep only letters and digits; everything else becomes a separator
    text = re.sub('[^A-Za-z0-9]+', ' ', str(text)).strip()

    # Replace URL-like tokens. This runs after punctuation removal, so only the
    # 'http' / 'www' fragment of a URL is replaced, not the whole address.
    text = ' '.join(['specialwebsite' if 'http' in t or 'www' in t else t for t in text.split(' ')])

    # Collapse whitespace runs (raw string: "\s" is an invalid escape in a plain literal)
    text = re.sub(r"\s+", " ", text)

    return text.lower()

In [6]:
import string

def text_cleaning(text):
    """Reduce text to lower-case alphanumeric words separated by single spaces.

    Args:
        text: Any object; it is converted with ``str()`` first.

    Returns:
        The lower-cased string in which every run of characters other than
        A-Z, a-z and 0-9 is replaced by one space, with the ends stripped.
    """
    text = re.sub('[^A-Za-z0-9]+', ' ', str(text)).strip()

    # Collapse whitespace runs (raw string: "\s" is an invalid escape in a plain literal)
    text = re.sub(r"\s+", " ", text)

    return text.lower()

##### STEP 1: Make a list of the known labels provided to us

temp_1 = [text_cleaning(x) for x in metadata['dataset_label']]
temp_2 = [text_cleaning(x) for x in metadata['dataset_title']]
temp_3 = [text_cleaning(x) for x in metadata['cleaned_label']]

# Hand-picked labels added on top of the ones found in the metadata
extra_labels = ['programme for international student assessment', 'kindergarten cohort ecls', 'organization for economic cooperation and development', 'blsa']

# text_cleaning already lower-cases, so no further normalisation is needed.
# Building one set de-duplicates the extra labels against the metadata ones.
existing_labels = set(temp_1 + temp_2 + temp_3 + extra_labels)

# An empty label is a substring of every sentence and would mark everything as positive.
existing_labels.discard('')

# Sort labels by length in descending order so the longest label is matched first;
# the alphabetical tie-break makes the order reproducible between runs
# (set iteration order is not).
existing_labels = sorted(existing_labels, key=lambda label: (-len(label), label))

## Make sentences

In [7]:
# Raw sentences collected by process_doc: those containing a known label
# and a random sample of the rest.
pos_sentences, neg_sentences = [], []

def process_doc(doc_id):
    """Split one training document into sentences and sort them into the
    module-level ``pos_sentences`` / ``neg_sentences`` lists.

    A sentence is positive when its cleaned text contains any entry of
    ``existing_labels``; only the first matching (i.e. longest) label is
    considered. At most two sentences per document whose matched label contains
    'adni' or 'alzheimer' are kept. Sentences without any label are kept as
    negatives with probability 0.25.

    Args:
        doc_id: Training document id, passed to ``load_train_example_by_name``.

    Returns:
        None. Results are appended to ``pos_sentences`` and ``neg_sentences``.
    """
    doc_json = load_train_example_by_name(doc_id)
    doc_text = ' '.join(sec['text'] for sec in doc_json)

    adni_count = 0
    # Tokenize sentencewise
    for sentence in sent_tokenize(doc_text):
        clean_sentence = text_cleaning(sentence)

        # existing_labels is sorted longest-first, so this is the longest match
        matched_label = next((label for label in existing_labels if label in clean_sentence), None)

        if matched_label is None:
            # Down-sample label-free sentences.
            # NOTE(review): `random` is not seeded in this notebook, so the sample differs between runs.
            if random.uniform(0, 1) < 0.25:
                neg_sentences.append(sentence)
            continue

        if 'adni' in matched_label or 'alzheimer' in matched_label:
            adni_count += 1
            if adni_count > 2:
                # Over the per-document cap: drop the sentence. It must not be
                # sampled as a negative, because it does contain a label.
                continue

        pos_sentences.append(sentence)

## Generate and Save Sentences

In [5]:
import pickle

# Start from empty lists so that re-running this cell does not duplicate sentences
pos_sentences = []
neg_sentences = []

for doc_id in tqdm(docIdx):
    process_doc(doc_id)

# Make sure the cache directory exists before writing into it
os.makedirs('data/bert_ner_sentences', exist_ok=True)

with open('data/bert_ner_sentences/pos.pkl', 'wb') as f:
    pickle.dump(pos_sentences, f)

with open('data/bert_ner_sentences/neg.pkl', 'wb') as f:
    pickle.dump(neg_sentences, f)

print(f'pos size: {len(pos_sentences)}')
print(f'neg size: {len(neg_sentences)}')

100%|██████████| 14316/14316 [03:52<00:00, 61.47it/s]
pos size: 32509
neg size: 1036186


## Load Sentences

In [8]:
import pickle

# Read back the sentence lists cached by the "Generate and Save Sentences" cell
cached_sentences = {}
for polarity in ('pos', 'neg'):
    with open(f'data/bert_ner_sentences/{polarity}.pkl', 'rb') as f:
        cached_sentences[polarity] = pickle.load(f)

pos_sentences = cached_sentences['pos']
neg_sentences = cached_sentences['neg']

print(f'pos size: {len(pos_sentences)}')
print(f'neg size: {len(neg_sentences)}')

pos size: 32509
neg size: 1036186


In [5]:
# Token lists and their per-token targets, filled by process_pos_sentence / process_neg_sentence
pos_sentences_processed, pos_labels = [], []
neg_sentences_processed, neg_labels = [], []

# Counters reported after processing
n_broken_sent = n_pos_no_label = 0

def is_text_broken(tokens):
    """Tell whether a token list looks like garbled text such as 'p a dsdv a d a ds f b'.

    Args:
        tokens: Sequence of string tokens.

    Returns:
        True for an empty sequence; False for sequences shorter than 50 tokens;
        otherwise True when more than 15% of the tokens are one character long.
    """
    n_tokens = len(tokens)
    if n_tokens == 0:
        return True
    if n_tokens < 50:
        return False

    n_single_char = sum(1 for tok in tokens if len(tok) == 1)
    return n_single_char / n_tokens > 0.15

def split_to_smaller_sent(tokens, s_size, overlap_size):
    """Cut a token list into consecutive windows of ``s_size`` tokens.

    Every window except the first is extended backwards by ``overlap_size``
    tokens, so those windows hold up to ``s_size + overlap_size`` tokens.

    Args:
        tokens: Token list to split.
        s_size: Window size before the overlap is added.
        overlap_size: Number of preceding tokens repeated at the start of each later window.

    Returns:
        A list of token lists. When ``tokens`` has at most ``s_size`` items the
        result is ``[tokens]``, holding the original list object.
    """
    total = len(tokens)
    if total <= s_size:
        return [tokens]

    # Ceiling division: number of windows needed to cover all tokens
    n_windows = -(-total // s_size)

    windows = []
    for k in range(n_windows):
        lo = k * s_size - overlap_size if k > 0 else 0
        hi = min(total, (k + 1) * s_size)
        windows.append(tokens[lo:hi])

    return windows

def join_tuple_tokens(tuples):
    """Join the second item of every tuple with single spaces.

    Args:
        tuples: Iterable of indexable items whose element 1 is a string token.

    Returns:
        The tokens joined into one space-separated string.
    """
    tokens = (pair[1] for pair in tuples)
    return ' '.join(tokens)

def get_index(lst, el):
    """Find the position of ``el`` in ``lst``, falling back to containment.

    Args:
        lst: List of candidates.
        el: Element to look for.

    Returns:
        The index of the first candidate equal to ``el``; if there is none,
        the index of the first candidate for which ``el in candidate`` holds.

    Raises:
        ValueError: If neither an equal nor a containing candidate exists.
    """
    if el in lst:
        return lst.index(el)

    for position, candidate in enumerate(lst):
        if el in candidate:
            return position

    raise ValueError(f'Element {el} not found in {lst}')

def process_pos_sentence(sentence):
    """Turn one raw sentence into token windows with B/I/O targets.

    The sentence is cleaned, split into windows by ``split_to_smaller_sent``
    and every window is searched for the entries of ``existing_labels``
    (longest first). The first token of a matched label is tagged 'B', the
    remaining tokens up to its last token 'I', everything else 'O'. Windows
    that contain at least one label are appended to the module-level
    ``pos_sentences_processed`` / ``pos_labels`` lists; other windows are dropped.

    Sentences rejected by ``is_text_broken`` only increment ``n_broken_sent``.

    Args:
        sentence: Raw sentence text.

    Returns:
        None.

    Raises:
        ValueError: Propagated from ``get_index`` when a label found in the
            cleaned window text cannot be located among the window tokens.
    """
    global n_broken_sent

    bert_sentence = text_cleaning_for_bert(sentence)
    label_sentence = text_cleaning_for_label(sentence)

    # Can't use bert cleaning for this check, because there all punctuation is padded with spaces
    if is_text_broken(label_sentence.split(' ')):
        n_broken_sent += 1
        return

    bert_tokens = bert_sentence.split(' ')

    ### STEP 1: Split into fixed sized sentences ###
    for small_sentence_tokens in split_to_smaller_sent(bert_tokens, s_size = 125, overlap_size = 25):

        small_bert_sentence = ' '.join(small_sentence_tokens)

        # Need to remove punctuation and uppercase letters to find labels
        small_label_sentence = text_cleaning_for_label(small_bert_sentence)

        ### STEP 2: Match labels ###
        sent_labels = []
        for clean_label in existing_labels:
            if clean_label in small_label_sentence:
                # Remove label from the text, to only match the largest label
                small_label_sentence = small_label_sentence.replace(clean_label, '')
                sent_labels.append(clean_label)

        if not sent_labels:
            # Windows without any label are not part of the positive set
            continue

        small_sent_targets = ['O'] * len(small_sentence_tokens)

        # (index, token) tuples for tokens that survive label cleaning.
        # Indices refer to positions in small_sentence_tokens.
        small_sent_tuples = [(i, token) for i, token in enumerate(small_sentence_tokens) if text_cleaning_for_label(token) != '']

        ### STEP 3: Set corresponding targets for each label ###
        # Target: (B, I, O), Label: adni
        for l in sent_labels:
            n_label_tokens = len(l.split(' '))

            # Every run of n_label_tokens consecutive remaining tokens, joined for comparison with the label
            small_sent_joined = [join_tuple_tokens(small_sent_tuples[i: i + n_label_tokens]) for i in range(len(small_sent_tuples) - n_label_tokens + 1)]

            label_start_i = get_index(small_sent_joined, l)
            label_end_i = label_start_i + n_label_tokens - 1

            target_start_i = small_sent_tuples[label_start_i][0]
            target_end_i = small_sent_tuples[label_end_i][0]

            # Do not use the same tokens for multiple labels: drop the whole
            # matched span, including its last token (hence the + 1).
            small_sent_tuples = small_sent_tuples[:label_start_i] + small_sent_tuples[label_end_i + 1:]

            small_sent_targets[target_start_i] = 'B'
            for i in range(target_start_i + 1, target_end_i + 1):
                small_sent_targets[i] = 'I'

        ### STEP 4: Add sentence output to lists ###
        pos_sentences_processed.append(small_sentence_tokens)
        pos_labels.append(small_sent_targets)

def process_neg_sentence(sentence):
    """Turn one label-free raw sentence into token windows with all-'O' targets.

    The sentence is cleaned and split into windows by ``split_to_smaller_sent``;
    each window and a target list of the same length are appended to the
    module-level ``neg_sentences_processed`` / ``neg_labels`` lists.

    Sentences rejected by ``is_text_broken`` only increment ``n_broken_sent``.

    Args:
        sentence: Raw sentence text.

    Returns:
        None.
    """
    global n_broken_sent

    bert_sentence = text_cleaning_for_bert(sentence)
    label_sentence = text_cleaning_for_label(sentence)

    # Can't use bert cleaning for this check, because there all punctuation is padded with spaces
    if is_text_broken(label_sentence.split(' ')):
        n_broken_sent += 1
        return

    bert_tokens = bert_sentence.split(' ')

    ### STEP 1: Split into fixed sized sentences ###
    for small_sentence_tokens in split_to_smaller_sent(bert_tokens, s_size = 125, overlap_size = 25):
        # One target per token of THIS window (not of the whole sentence), so
        # tokens and targets stay aligned when a long sentence is split.
        small_sent_targets = ['O'] * len(small_sentence_tokens)

        neg_sentences_processed.append(small_sentence_tokens)
        neg_labels.append(small_sent_targets)

# Smoke-test the positive-sentence pipeline on a single example (index 2472).
# NOTE(review): this needs `pos_sentences` to exist in the kernel already; the
# recorded output of this cell is a NameError because it did not.
process_pos_sentence(pos_sentences[2472])

NameError: name 'pos_sentences' is not defined

## Create NER Dataset and Save

In [7]:
# Reset the accumulators so that re-running this cell starts from scratch
pos_sentences_processed, pos_labels = [], []
neg_sentences_processed, neg_labels = [], []
n_pos_no_label = n_broken_sent = 0

# Positive sentences first, then negatives, each with its own progress bar
for raw_sentences, handle_sentence in (
    (pos_sentences, process_pos_sentence),
    (neg_sentences, process_neg_sentence),
):
    for sent in tqdm(raw_sentences):
        handle_sentence(sent)

import pickle

# Make sure the cache directory exists before writing into it
os.makedirs('data/bert_ner_data', exist_ok=True)

# Token windows and their targets, positives and negatives in separate files
with open('data/bert_ner_data/pos.pkl', 'wb') as f:
    pickle.dump(pos_sentences_processed, f)

with open('data/bert_ner_data/neg.pkl', 'wb') as f:
    pickle.dump(neg_sentences_processed, f)

with open('data/bert_ner_data/pos_labels.pkl', 'wb') as f:
    pickle.dump(pos_labels, f)

with open('data/bert_ner_data/neg_labels.pkl', 'wb') as f:
    pickle.dump(neg_labels, f)


print('')
print(f'broken sentences: {n_broken_sent}')
print(f'n_pos_no_label: {n_pos_no_label}')
print(f'pos_proc size: {len(pos_sentences_processed)}')
print(f'neg_proc size: {len(neg_sentences_processed)}')

100%|██████████| 32509/32509 [00:08<00:00, 3897.80it/s]
100%|██████████| 1036186/1036186 [00:57<00:00, 18128.76it/s]

broken sentences: 8296
n_pos_no_label: 0
pos_proc size: 32235
neg_proc size: 1032513


## Load NER Dataset

In [2]:
import pickle

# Read back the four pickles written by the "Create NER Dataset and Save" cell
ner_cache = {}
for stem in ('pos', 'neg', 'pos_labels', 'neg_labels'):
    with open(f'data/bert_ner_data/{stem}.pkl', 'rb') as f:
        ner_cache[stem] = pickle.load(f)

pos_sentences_processed = ner_cache['pos']
neg_sentences_processed = ner_cache['neg']
pos_labels = ner_cache['pos_labels']
neg_labels = ner_cache['neg_labels']

print(f'pos size: {len(pos_sentences_processed)}')
print(f'neg size: {len(neg_sentences_processed)}')
print(f'pos label size: {len(pos_labels)}')
print(f'neg label size: {len(neg_labels)}')

pos size: 32235
neg size: 1032513
pos label size: 32235
neg label size: 1032513


In [3]:
# Show one positive example as a token / target table
i_ex = 1000
example_columns = {'token': pos_sentences_processed[i_ex], 'label': pos_labels[i_ex]}
pd.DataFrame(example_columns)

Unnamed: 0,token,label
0,based,O
1,on,O
2,the,O
3,2003,O
4,united,O
5,states,O
6,department,O
7,of,O
8,agriculture,O
9,rural,B


## Create Training Data

In [4]:
from sklearn.model_selection import train_test_split
import numpy as np

# Only the positive windows are used; the negatives
# (neg_sentences_processed / neg_labels) are deliberately left out.
sentences, labels = pos_sentences_processed, pos_labels

print('Splitting data...')
# 80/20 split with a fixed seed
split_parts = train_test_split(sentences, labels, test_size=0.20, random_state=42)
train_sents, val_sents, train_labels, val_labels = split_parts

Splitting data...


## Save Datasets

In [9]:
# Output directory for the train/dev/test text files. Built with os.path.join
# so it also resolves on non-Windows systems; on Windows this yields the same
# 'data\ner\coleridge' string as before.
data_dir = os.path.join('data', 'ner', 'coleridge')

In [13]:
import os

# Save train
dst_path = os.path.join(data_dir, 'train.txt')
os.makedirs(data_dir, exist_ok=True)

# One line per token: '<token> NN O <target>'; a blank line separates sentences.
# NOTE(review): looks like a CoNLL-style four-column layout with constant
# placeholders in columns 2 and 3 - confirm against the reader that consumes it.
lines = []
docsize = 0
# Loop names differ from the module-level `labels` so that it is not clobbered
for sent_tokens, sent_targets in zip(train_sents, train_labels):
    for t, l in zip(sent_tokens, sent_targets):
        if len(t) == 0:
            # Empty tokens come from leading/trailing or doubled spaces
            continue
        lines.append(f'{t} NN O {l}')
        docsize += 1

    lines.append('')

# Drop the separator after the last sentence, then prepend the header with the token count
lines = lines[:-1]
lines = [f'-DOCSTART- ({docsize})', ''] + lines
with open(dst_path, 'w') as f:
    f.writelines(line + '\n' for line in lines)

In [14]:
import os

# One line per token: '<token> NN O <target>'; a blank line separates sentences.
lines = []
docsize = 0
# Loop names differ from the module-level `labels` so that it is not clobbered
for sent_tokens, sent_targets in zip(val_sents, val_labels):
    for t, l in zip(sent_tokens, sent_targets):
        if len(t) == 0:
            # Empty tokens come from leading/trailing or doubled spaces
            continue
        lines.append(f'{t} NN O {l}')
        docsize += 1

    lines.append('')

# Drop the separator after the last sentence, then prepend the header with the token count
lines = lines[:-1]
lines = [f'-DOCSTART- ({docsize})', ''] + lines

os.makedirs(data_dir, exist_ok=True)

# Save dev; test.txt receives exactly the same validation lines
dst_path = os.path.join(data_dir, 'dev.txt')
test_path = os.path.join(data_dir, 'test.txt')
for out_path in (dst_path, test_path):
    with open(out_path, 'w') as f:
        f.writelines(line + '\n' for line in lines)

## Fine-Tune BERT

In [5]:
 import os
 import math
 import random
 import csv
 import sys
 import numpy as np
 import pandas as pd
 from sklearn import metrics
 from sklearn.metrics import f1_score, precision_score, recall_score
 from sklearn.metrics import classification_report
 import statistics as stats
 from bert_sklearn import BertTokenClassifier 

In [6]:
# Hyper-parameters for the token classifier, gathered in one place
bert_config = dict(
    bert_model='scibert-scivocab-uncased',
    max_seq_length=178,
    epochs=3,
    gradient_accumulation_steps=4,  # gradient accumulation
    learning_rate=3e-5,
    train_batch_size=8,             # batch size for training
    eval_batch_size=8,              # batch size for evaluation
    validation_fraction=0.,
    ignore_label=['O'],             # ignore the tokens with label 'O'
)

model = BertTokenClassifier(**bert_config)

Building sklearn token classifier...


In [14]:
def flatten(l):
    """Concatenate the items of every sub-iterable of ``l`` into one flat list.

    Args:
        l: Iterable of iterables.

    Returns:
        A new list with the items in their original order (one level flattened).
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

In [15]:
#X_train, y_train = flatten(train_sents), flatten(train_labels)

In [7]:
# Fine-tune the classifier on the training split: one token list and one
# target list per sentence. Being the last expression, the call's return
# value is shown as the cell output.
model.fit(train_sents, train_labels)

  return np.array(X)
Loading scibert-scivocab-uncased model...
Defaulting to linear classifier/regressor
Loading Pytorch checkpoint
train data size: 25788, validation data size: 0
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:1005.)
  next_m.mul_(beta1).add_(1 - beta1, grad)
Training  : 100%|██████████| 12894/12894 [35:47<00:00,  6.00it/s, loss=0.0086]
Training  : 100%|██████████| 12894/12894 [37:20<00:00,  5.76it/s, loss=0.001]
Training  : 100%|██████████| 12894/12894 [36:39<00:00,  5.86it/s, loss=0.000288]


BertTokenClassifier(bert_model='scibert-scivocab-uncased', do_lower_case=True,
                    gradient_accumulation_steps=4, ignore_label=['O'],
                    label_list=array(['B', 'I', 'O'], dtype='<U1'),
                    learning_rate=3e-05, max_seq_length=178, train_batch_size=8,
                    validation_fraction=0.0)

In [9]:
# Save the fitted model to disk.
# NOTE(review): the path is relative to the kernel's working directory.
savefile='data/sklearn_bert.bin'
model.save(savefile)

In [10]:
# Run the fitted model on the validation sentences.
# Presumably returns one predicted tag list per sentence (it is indexed per
# sentence when displayed further down) - confirm against the library docs.
val_preds = model.predict(val_sents)

  return np.array(X)
Predicting: 100%|██████████| 806/806 [01:38<00:00,  8.19it/s]


In [12]:
# Show the tokens of one validation sentence next to their predicted tags
ex_i = 0
prediction_columns = {'token': val_sents[ex_i], 'pred': val_preds[ex_i]}
pd.DataFrame(prediction_columns)

Unnamed: 0,pred,token
0,O,this
1,O,work
2,O,describes
3,O,the
4,O,validation
5,O,of
6,O,this
7,O,automated
8,O,hippocampal
9,O,volumetric
