This notebook shows how to fine-tune a BERT model (from Hugging Face) for our dataset recognition task.

Note that internet access is needed during the training phase (to download the `bert-base-cased` model). It can be turned off during prediction.

## Install packages

In [None]:
# Install the libraries from files bundled in the attached `coleridge-packages` input.
# `--no-index --find-links=...` makes pip resolve `datasets` from that local folder only.
!pip install datasets --no-index --find-links=file:///kaggle/input/coleridge-packages/packages/datasets
# The remaining packages are installed straight from wheel files; the tokenizers wheel
# is built for CPython 3.7 (cp37) on Linux x86_64.
!pip install ../input/coleridge-packages/seqeval-1.2.2-py3-none-any.whl
!pip install ../input/coleridge-packages/tokenizers-0.10.1-cp37-cp37m-manylinux1_x86_64.whl
!pip install ../input/coleridge-packages/transformers-4.5.0.dev0-py3-none-any.whl

# Import

In [None]:
import os
import re
import json
import time
import datetime
import random
import glob
import importlib

import numpy as np
import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Seed the Python and NumPy global RNGs. `random` later drives the subsampling of
# negative samples and the final shuffle of the NER data.
random.seed(123)
np.random.seed(456)

In [None]:
# Copy my_seqeval.py to the working directory because the input directory is non-writable.
# NOTE(review): presumably the training script loads it from the working directory — confirm.
!cp /kaggle/input/coleridge-packages/my_seqeval.py ./

# Hyper-parameters

In [None]:
MAX_LENGTH = 64 # max no. of whitespace-separated tokens per generated sample ([CLS]/[SEP] markers included).
OVERLAP = 20 # only read by shorten_sentences: no. of words shared by consecutive chunks of an over-long sentence

MAX_SAMPLE = None # no. of raw train.csv rows to keep; set a small number for experimentation, set None for production.

# Load data

In [None]:
# Competition inputs: the folder holding one JSON file per paper, and the label table.
paper_train_folder = '../input/coleridgeinitiative-show-us-the-data/train'
train_path = '../input/coleridgeinitiative-show-us-the-data/train.csv'

# Read the label table and keep only the first MAX_SAMPLE rows (every row when it is None).
train = pd.read_csv(train_path)[:MAX_SAMPLE]
print(f'No. raw training rows: {len(train)}')

Group by publication so that the training labels have the same form as the expected output: all labels of a paper joined by `|`.

In [None]:
# Collapse the table to one row per publication Id: keep the first title and
# concatenate every label column with '|'.
train = (
    train
    .groupby('Id')
    .agg({
        'pub_title': 'first',
        'dataset_title': '|'.join,
        'dataset_label': '|'.join,
        'cleaned_label': '|'.join,
    })
    .reset_index()
)

print(f'No. grouped training rows: {len(train)}')

In [None]:
# Load the parsed JSON content of every training paper, keyed by publication Id.
papers = {}
for paper_id in train['Id'].unique():
    paper_path = f'{paper_train_folder}/{paper_id}.json'
    # JSON is read as UTF-8 explicitly so decoding does not depend on the platform locale.
    with open(paper_path, 'r', encoding='utf-8') as f:
        papers[paper_id] = json.load(f)

# Transform data to NER format

In [None]:
def clean_training_text(txt):
    """
    Normalise text for training without lowercasing it.

    Every run of characters other than ASCII letters and digits is replaced by a
    single space, and leading/trailing whitespace is stripped. The input is
    converted with str() first, so non-string values are accepted.
    """
    as_text = str(txt)
    collapsed = re.sub('[^A-Za-z0-9]+', ' ', as_text)
    return collapsed.strip()

def get_words(idx, sentences, length):
    """
    Gather up to `length` filler words from the sentences surrounding `sentences[idx]`.

    Sentences after `idx` are consumed first; if the budget is still not filled,
    the sentences before `idx` (from the start of the list) are consumed as well.
    A "[SEP]" marker is appended after each consumed sentence. The result is cut
    to `length` items and is simply shorter when the document runs out of words
    (no padding token is added).
    """
    collected = []

    def _absorb(chunk):
        # Append whole sentences until the collected words exceed the budget.
        for text in chunk:
            collected.extend(text.split())
            collected.append("[SEP]")
            if len(collected) > length:
                return

    _absorb(sentences[idx + 1:])

    # Still room left: continue with the sentences located before idx.
    if len(collected) < length:
        _absorb(sentences[:idx])

    return collected[:length]
        
def get_sentences_before(sentences, idx, jdx, cur_length, max_length=None):
    """
    Build the left context for `sentences[idx]` from up to `jdx` preceding sentences.

    Sentences are taken walking backwards from `idx - 1`. When fewer than `jdx`
    sentences precede `idx`, the walk wraps around and continues from the end of
    the list. Each sentence contributes its words followed by "[SEP]".

    Args:
        sentences: list of whitespace-tokenisable sentence strings.
        idx: index of the sentence that receives the context.
        jdx: number of preceding sentences to try to include.
        cur_length: number of tokens already occupied by the current sentence.
        max_length: total token budget; defaults to the module-level MAX_LENGTH,
            read at call time.

    Returns:
        The context words in reading order. At most
        `max_length - 1 - cur_length` words are returned (one slot stays free for
        "[CLS]"); when a sentence does not fit entirely, only its trailing words
        are kept. An empty list is returned when no room is left.
    """
    if max_length is None:
        max_length = MAX_LENGTH

    # Candidate sentences, nearest first.
    candidates = list(sentences[max(idx - jdx, 0):idx][::-1])
    if idx - jdx < 0:
        # Not enough sentences in front: wrap around to the end of the list.
        candidates.extend(sentences[idx - jdx:][::-1])

    reverse_words = []
    for sentence in candidates:
        words = sentence.split()
        words.append("[SEP]")
        words.reverse()

        # Clamp at 0 so that an already exhausted budget never turns into a
        # negative slice bound (which would pull words in instead of none).
        room = max(max_length - 1 - cur_length - len(reverse_words), 0)
        if len(words) > room:
            # Only the last words of this sentence fit; nothing further can be added.
            reverse_words.extend(words[:room])
            break
        reverse_words.extend(words)

    return reverse_words[::-1]

def shorten_sentences_with_context(sentences):
    """
    Turn a document's sentences into samples of at most MAX_LENGTH tokens,
    filling spare room with neighbouring sentences as context.

    Every sample starts with "[CLS]" and uses "[SEP]" to separate sentences.

    * A sentence longer than MAX_LENGTH - 2 words is cut into chunks of
      MAX_LENGTH - 1 words; a short final chunk is topped up with get_words().
    * Any other sentence yields several samples, one per amount of preceding
      context (0, 1, 2, ... sentences via get_sentences_before()), stopping as
      soon as the preceding context alone fills the sample. Remaining room is
      filled with get_words().

    Args:
        sentences: iterable of cleaned sentence strings, in document order.

    Returns:
        List of unique sample strings in first-seen order. (De-duplicating with a
        set would return them in hash-randomised order, which differs between
        kernel restarts and makes downstream seeded sampling irreproducible.)
    """
    sentences = list(sentences)
    short_sentences = []
    for idx, val in enumerate(sentences):
        val = val.split()

        # The sentence does not fit together with [CLS] and [SEP]: split it up.
        if len(val) > (MAX_LENGTH - 2):
            for p in range(0, len(val), MAX_LENGTH-1):
                temp = val[p:p+MAX_LENGTH-1]

                # Chunk leaves exactly one free slot: close it with [SEP]
                if len(temp) == MAX_LENGTH-2:
                    short_sentences.append("[CLS] " + ' '.join(temp) + " [SEP]")

                # Chunk fills the whole sample: there is no room for [SEP]
                elif len(temp) == MAX_LENGTH-1:
                    short_sentences.append("[CLS] " + ' '.join(temp))

                # Shorter chunk: close it and fill the rest with neighbouring words
                else:
                    temp.append("[SEP]")
                    temp.extend(get_words(idx, sentences, MAX_LENGTH-len(temp)-1))
                    short_sentences.append("[CLS] " + ' '.join(temp))

            continue

        # For each sentence there are multiple instances, using a larger set of
        # preceding sentences each time
        val.append("[SEP]")
        for jdx in range(len(sentences)):
            reverse_words = get_sentences_before(sentences, idx, jdx, len(val))

            # Combine the preceding sentences and this sentence
            tot_words = reverse_words + val

            # If this is full, no more context can be put in front,
            # so this loop is done; go to the next sentence
            if len(tot_words) == MAX_LENGTH-1:
                short_sentences.append("[CLS] " + ' '.join(tot_words))
                break

            # Otherwise append the following sentences until full
            tot_words.extend(get_words(idx, sentences, MAX_LENGTH-len(tot_words)-1))
            short_sentences.append("[CLS] " + ' '.join(tot_words))

    # Order-preserving de-duplication (dicts keep insertion order).
    return list(dict.fromkeys(short_sentences))

def shorten_sentences(sentences):
    """
    Split over-long sentences into overlapping windows.

    A sentence with at most MAX_LENGTH words is passed through unchanged.
    A longer one is replaced by windows of up to MAX_LENGTH words whose start
    positions advance by MAX_LENGTH - OVERLAP words.
    """
    pieces = []
    for text in sentences:
        tokens = text.split()
        if len(tokens) <= MAX_LENGTH:
            pieces.append(text)
            continue
        window_starts = range(0, len(tokens), MAX_LENGTH - OVERLAP)
        pieces.extend(' '.join(tokens[start:start + MAX_LENGTH]) for start in window_starts)
    return pieces

def find_sublist(big_list, small_list):
    """
    Return every start index at which `small_list` occurs as a contiguous
    slice of `big_list`, in ascending order.

    An empty `small_list` matches at every position from 0 to len(big_list).
    """
    width = len(small_list)
    last_start = len(big_list) - width
    return [
        start
        for start in range(last_start + 1)
        if big_list[start:start + width] == small_list
    ]

def tag_sentence(sentence, labels): # requirement: both sentence and labels are already cleaned
    """
    Tag the words of `sentence` with B/I/O marks for the given dataset labels.

    Matching is done on whole whitespace-separated tokens with find_sublist(),
    and the same matches decide both the tags and whether the sample is
    positive, so the two can never disagree (a regex test on the raw string
    could, e.g. a label "SEP" matching inside the "[SEP]" marker).

    Args:
        sentence: cleaned sentence string.
        labels: iterable of cleaned label strings, or None. Labels that are
            empty after splitting are ignored instead of tagging every token.

    Returns:
        (is_positive, pairs) where is_positive is True when at least one label
        occurs in the sentence and pairs is a list of (word, tag) tuples with
        tag 'B' on the first word of a match, 'I' on its following words and
        'O' elsewhere. Later labels overwrite tags written by earlier ones.
    """
    sentence_words = sentence.split()
    nes = ['O'] * len(sentence_words)
    is_positive = False

    for label in labels or ():
        label_words = label.split()
        if not label_words:
            # An empty label would match at every position.
            continue

        for pos in find_sublist(sentence_words, label_words):
            is_positive = True
            nes[pos] = 'B'
            for i in range(pos+1, pos+len(label_words)):
                nes[i] = 'I'

    return is_positive, list(zip(sentence_words, nes))

In [None]:
# # TEST for each case in context function
# MAX_LENGTH = 6 # lower max_length for testing

# def gen_sen(n):
#     return ' '.join([str(i) for i in range(n)])
# # T1: Sentence too long
# t1 = [gen_sen(2), gen_sen(10), gen_sen(4)[-3:]]
# r1 = shorten_sentences_with_context(t1)
# print("Test 1")
# print([len(i.split()) for i in t1])
# print([len(i.split()) for i in r1])
# print(t1)
# print(r1)
# # T2: Add multiple sentences
# t2 = [gen_sen(2), gen_sen(1), gen_sen(4)[-1:]]
# r2 = shorten_sentences_with_context(t2)
# print("Test 2")
# print([len(i.split()) for i in t2])
# print([len(i.split()) for i in r2])
# print(t2)
# print(r2)
# MAX_LENGTH = 64 # Reset max_length

In [None]:
cnt_pos, cnt_neg = 0, 0 # number of kept samples that contain/do not contain labels
keep_percentage = 0.2 # fraction of the candidate negative samples to keep
ner_data = []

# The context manager closes the progress bar even if an iteration fails.
with tqdm(total=len(train)) as pbar:
    for paper_id, dataset_label in train[['Id', 'dataset_label']].itertuples(index=False):
        # paper
        paper = papers[paper_id]

        # labels
        labels = [clean_training_text(label) for label in dataset_label.split('|')]

        # sentences: de-duplicate while keeping document order. A set would hand
        # the sentences over in arbitrary, per-kernel order, so the "sentences
        # before/after" used as context would not be the real neighbours.
        sentences = list(dict.fromkeys(
            clean_training_text(sentence)
            for section in paper
            for sentence in section['text'].split('.')
        ))
        sentences = shorten_sentences_with_context(sentences) #shorten_sentences(sentences) # make sentences short
        # Only accept samples longer than 10 chars. Sort so the seeded negative
        # sampling below does not depend on the order in which the helper
        # returns the samples.
        sentences = sorted(sentence for sentence in sentences if len(sentence) > 10)

        # keep every positive sample; keep a random subset of the negative
        # samples that mention 'data' or 'study'
        for sentence in sentences:
            is_positive, tags = tag_sentence(sentence, labels)
            if is_positive:
                cnt_pos += 1
                ner_data.append(tags)
            elif any(word in sentence.lower() for word in ['data', 'study']):
                if random.random() < keep_percentage:
                    ner_data.append(tags)
                    cnt_neg += 1

        # progress bar
        pbar.update(1)
        pbar.set_description(f"Training data size: {cnt_pos} positives + {cnt_neg} negatives")

# shuffling
random.shuffle(ner_data)

Write the NER samples to `train_ner.json`, one JSON object per line.

In [None]:
# Serialise each sample as one JSON line with its words under 'tokens' and its
# B/I/O marks under 'tags'.
with open('train_ner.json', 'w') as out_file:
    for sample in ner_data:
        tokens, tag_seq = zip(*sample)
        out_file.write(json.dumps({'tokens' : tokens, 'tags' : tag_seq}) + '\n')

# Fine-tune a BERT model for NER

In [None]:
# Launch the NER training script on train_ner.json, starting from `bert-base-cased`,
# for one epoch with batch size 8; results go to ./output.
# NOTE(review): the validation file is the training file itself and only --do_train is
# passed, so presumably no held-out evaluation happens here — confirm in kaggle_run_ner.py.
!python ../input/kaggle-ner-utils/kaggle_run_ner.py \
--model_name_or_path 'bert-base-cased' \
--train_file './train_ner.json' \
--validation_file './train_ner.json' \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--save_steps 15000 \
--output_dir './output' \
--report_to 'none' \
--seed 123 \
--do_train 

After the tuning finishes, we should find our model in './output'.