In [None]:
# Dont forget to add all functions and connection_tokens to kaggle

https://www.analyticsvidhya.com/blog/2020/07/transfer-learning-for-nlp-fine-tuning-bert-for-text-classification/
### Model from:

https://github.com/allenai/scibert

In [1]:
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import re
import seaborn as sns
from tqdm import tqdm
import nltk
import random
from nltk.tokenize import word_tokenize,sent_tokenize
import pickle

train_example_names = [fn.split('.')[0] for fn in os.listdir('data/train')]
test_example_names = [fn.split('.')[0] for fn in os.listdir('data/test')]

metadata = pd.read_csv('data/train.csv')
docIdx = train_example_names.copy()

connection_tokens = {'s', 'of', 'and', 'in', 'on', 'for', 'from', 'the', 'act', 'coast', 'future', 'system', 'per', "'", ','}

## Dataset Name Selection

In [2]:
def text_cleaning(text):
    text = re.sub('[^A-Za-z]+', ' ', str(text)).strip() # remove unnecessary literals

    # remove extra spaces
    text = re.sub("\s+"," ", text)

    return text.lower().strip()

def is_name_ok(text):
    if len([c for c in text if c.isalnum()]) < 4:
        return False
    
    tokens = [t for t in text.split(' ') if len(t) > 3]
    tokens = [t for t in tokens if not t in connection_tokens]
    if len(tokens) < 3:
        return False

    return True

with open('data/all_preds_selected.csv', 'r') as f:
    selected_pred_labels = f.readlines()
    selected_pred_labels = [l.strip() for l in selected_pred_labels]

existing_labels = [text_cleaning(x) for x in metadata['dataset_label']] +\
                  [text_cleaning(x) for x in metadata['dataset_title']] +\
                  [text_cleaning(x) for x in metadata['cleaned_label']] +\
                  [text_cleaning(x) for x in selected_pred_labels]

"""to_remove = [
    'frequently asked questions', 'total maximum daily load tmd', 'health care facilities',
    'traumatic brain injury', 'north pacific high', 'droplet number concentration', 'great slave lake',
    'census block groups'
]"""


"""df = pd.read_csv(r'C:\projects\personal\kaggle\kaggle_coleridge_initiative\string_search\data\gov_data.csv')
print(len(df))


df['title'] = df.title.apply(text_cleaning)
titles = list(df.title.unique())
titles = [t for t in titles if not t in to_remove]
df = pd.DataFrame({'title': titles})
df = df.loc[df.title.apply(is_name_ok)]
df = pd.concat([df, pd.DataFrame({'title': existing_labels})], ignore_index= True).reset_index(drop = True)
titles = list(df.title.unique())
df = pd.DataFrame({'title': titles})
df['title'] = df.title.apply(text_cleaning)"""

# Sort labels by length in ascending order
#existing_labels = sorted(list(df.title.values), key = len, reverse = True)

existing_labels = list(set(existing_labels))
existing_labels = sorted(existing_labels, key = len, reverse = True)
existing_labels = [l for l in existing_labels if len(l.split(' ')) < 15]
#del df
#existing_labels.remove('adni')

print(len(existing_labels))

387


## Create dataframe for tokens and targets

In [3]:
def load_train_example_by_name(name):
    doc_path = os.path.join('data/train', name + '.json')
    with open(doc_path) as f:
        data = json.load(f)
    return data

def load_test_example_by_name(name):
    doc_path = os.path.join('data/test', name + '.json')
    with open(doc_path) as f:
        data = json.load(f)
    return data

## Make sentences

In [7]:
import unidecode

match_puncs_re = r"([.,!?()\-;\[\]+\\\/@:<>#_{}&%'*=" + r'"' + r"|])"
match_puncs_re = re.compile(match_puncs_re)

def text_cleaning_upper(text):
    text = re.sub('[^A-Za-z]+', ' ', str(text)).strip() # remove unnecessary literals

    # remove extra spaces
    text = re.sub("\s+"," ", text)

    return text.strip()

def has_connected_uppercase(tokens):
    if len(tokens) < 5:
        return False

    group_len = 0
    n_long_tokens = 0
    for token in tokens:
        token_lower = token.lower()
        if token[0].isupper():
            if token_lower not in connection_tokens:
                if len(token) > 2:
                    n_long_tokens += 1

                group_len += 1
                if group_len > 2 and n_long_tokens > 0:
                    return True

        else:
            if token_lower not in connection_tokens:
                group_len = 0
                n_long_tokens = 0

    return False

def sent_has_acronym(tokens):
    # Acronym check
    for token in tokens:
        if len(token) > 3 and token.isupper():
            return True

    return False

def sent_is_candidate(clean_sentence):
    tokens = clean_sentence.split(' ')
    
    if sent_has_acronym(tokens):
        return True
    else:
        return has_connected_uppercase(tokens)
        

In [5]:
"""pos_sentences = []
neg_sentences = []
docs_no_pos = []
total_sentences = 0



def process_doc(doc_id):
    global total_sentences
    doc_json = load_train_example_by_name(doc_id)
    doc_text = ' '.join([sec['text'] for sec in doc_json])
    doc_has_pos = False

    # Tokenize sentencewise
    sentences = sent_tokenize(doc_text)
    total_sentences += len(sentences)

    for sentence in sentences:
        clean_sentence = text_cleaning_upper(sentence)
        is_candidate = sent_is_candidate(clean_sentence)

        has_label = False
        if is_candidate:
            clean_sentence_lower = clean_sentence.lower()
            for clean_label in existing_labels:
                if clean_label in clean_sentence_lower:
                    has_label = True
                    break
        
        # Store sentence in list if candidate
        # Non-candidate sentences are discarded
        if has_label:
            pos_sentences.append(sentence)
            doc_has_pos = True
        elif is_candidate:
            neg_sentences.append(sentence)

    if not doc_has_pos:
        docs_no_pos.append(doc_id)

#process_doc('0026563b-d5b3-417d-bd25-7656b97a044f')"""

"pos_sentences = []\nneg_sentences = []\ndocs_no_pos = []\ntotal_sentences = 0\n\n\n\ndef process_doc(doc_id):\n    global total_sentences\n    doc_json = load_train_example_by_name(doc_id)\n    doc_text = ' '.join([sec['text'] for sec in doc_json])\n    doc_has_pos = False\n\n    # Tokenize sentencewise\n    sentences = sent_tokenize(doc_text)\n    total_sentences += len(sentences)\n\n    for sentence in sentences:\n        clean_sentence = text_cleaning_upper(sentence)\n        is_candidate = sent_is_candidate(clean_sentence)\n\n        has_label = False\n        if is_candidate:\n            clean_sentence_lower = clean_sentence.lower()\n            for clean_label in existing_labels:\n                if clean_label in clean_sentence_lower:\n                    has_label = True\n                    break\n        \n        # Store sentence in list if candidate\n        # Non-candidate sentences are discarded\n        if has_label:\n            pos_sentences.append(sentence)\n       

## Generate and Save Sentences

In [6]:
"""import pickle
assert len(docIdx) > 0

pos_sentences = []
neg_sentences = []
docs_no_pos = []
total_sentences = 0

pbar = tqdm(docIdx)
for doc_id in pbar:
    process_doc(doc_id)
    pbar.set_description(\
        f'pos_size: {len(pos_sentences)}, neg_size: {len(neg_sentences)}, no pos label doc: {len(docs_no_pos)}, n_sentences: {total_sentences}')

with open(f'data/bert_ner_sentences/pos.pkl', 'wb') as f:
    pickle.dump(pos_sentences, f)

with open(f'data/bert_ner_sentences/neg.pkl', 'wb') as f:
    pickle.dump(neg_sentences, f)

print(f'pos size: {len(pos_sentences)}')
print(f'neg size: {len(neg_sentences)}')"""

"import pickle\nassert len(docIdx) > 0\n\npos_sentences = []\nneg_sentences = []\ndocs_no_pos = []\ntotal_sentences = 0\n\npbar = tqdm(docIdx)\nfor doc_id in pbar:\n    process_doc(doc_id)\n    pbar.set_description(        f'pos_size: {len(pos_sentences)}, neg_size: {len(neg_sentences)}, no pos label doc: {len(docs_no_pos)}, n_sentences: {total_sentences}')\n\nwith open(f'data/bert_ner_sentences/pos.pkl', 'wb') as f:\n    pickle.dump(pos_sentences, f)\n\nwith open(f'data/bert_ner_sentences/neg.pkl', 'wb') as f:\n    pickle.dump(neg_sentences, f)\n\nprint(f'pos size: {len(pos_sentences)}')\nprint(f'neg size: {len(neg_sentences)}')"

In [7]:
#metadata.loc[metadata.Id == docs_no_pos[0]]

## Load Sentences

#### import pickle

with open(f'data/classifier_output/pos_classified.pkl', 'rb') as f:
    pos_sentences = pickle.load(f)

"""with open(f'data/bert_ner_sentences/neg.pkl', 'rb') as f:
    neg_sentences = pickle.load(f)
"""
print(f'pos size: {len(pos_sentences)}')
#print(f'neg size: {len(neg_sentences)}')

In [5]:
import pickle

with open(f'data/selected_sentences/pos.pkl', 'rb') as f:
    pos_sentences_raw = pickle.load(f)

with open(f'data/selected_sentences/neg.pkl', 'rb') as f:
    neg_sentences_raw = pickle.load(f)
    
    
pos_sentences = pos_sentences_raw + neg_sentences_raw

In [6]:
pos_sentences[:3]

['In fact, organizations are now identifying digital skills or computer literacy as one of their core values for employability (such as the US Department of Education, the US Department of commerce, the OECD Program for the International Assessment of Adult Competencies and the European Commission).',
 'International studies on student achievement, such as Trends in International Mathematics and Science Study (TIMMS) and the Programme for International Student Assessment (PISA) from past several years have documented a narrowing gap in gender differences in science and mathematics achievement (Else-Quest, Hyde, & Linn, 2010; Martin, Mullis, Foy, & Hooper, 2016; OECD, 2016) .',
 '1 manages access to results of the Agricultural Resources Management Survey (ARMS), a fundamental source of information on agricultural practices, farm businesses and farm household financials.']

In [8]:
pos_sentences_processed = []
neg_sentences_processed = []
pos_labels = []
neg_labels = []

n_broken_sent = 0
n_pos_no_label = 0

def text_cleaning_for_bert(text):
    # Keeps puncs, pads them with whitespaces
    text = text.replace('^', ' ')
    text = unidecode.unidecode(text)
    
    text = re.sub(r'\[[0-9]+]', ' SpecialReference ', text)
    
    # Remove years
    text = re.sub(r'(19|20)[0-9][0-9]', ' SpecialYear ', text)
    
    # remove other digits
    text = re.sub(r'\d+', ' ', text)
    
    # Remove websites
    text = ' '.join(['SpecialWebsite' if 'http' in t or 'www' in t else t for t in text.split(' ') ])

    text = match_puncs_re.sub(r' \1 ', text)

    # remove extra spaces
    text = re.sub("\s+"," ", text)

    return text.strip()

def convert_tokens(text):
    if is_acronym(text):
        return 'ACRONYM'
    return text

def is_acronym(text):
    if len(text) < 3:
        return False
    if text.isupper():
        return True

def is_text_broken(tokens):
    # Some texts are like 'p a dsdv a d a ds f b', remove them
    if len(tokens) == 0:
        return True

    if len(tokens) < 50:
        return False

    one_char_token_ratio = len([l for l in tokens if len(l) == 1]) / len(tokens)
    return one_char_token_ratio > 0.2

def split_to_smaller_sent(tokens, s_size, overlap_size):
    # output sentences will be s_size + overlap_size long
    small_sents = []

    if len(tokens) <= s_size:
        return [tokens]

    n_parts = len(tokens) // s_size
    if len(tokens) % s_size != 0:
        n_parts += 1

    i_part = 0
    end_i = 0
    while end_i < len(tokens):
        start_i = i_part * s_size
        if i_part > 0:
            start_i -= overlap_size

        end_i = min(len(tokens), start_i + s_size)

        small_sents.append(tokens[start_i: end_i])
        i_part += 1

    return small_sents

def join_tuple_tokens(tuples):
    return ' '.join([t[1] for t in tuples])

def get_index(lst, el):
    idx = []
    for i, lst_el in enumerate(lst):
        if el in lst_el:
            idx.append(i)

    return idx

def get_connected_uppercase(tokens):
    # Acronyms should not be a part of connected uppercase texts
    tokens = [t if not is_acronym(t) else '*****' for t in tokens ]
    if len(tokens) == 0:
        return []
    
    groups = []
    this_group_tokens = []
    in_group = False
    last_token_connection = False

    group_len = 0
    n_long_tokens = 0
    for token in tokens:
        token_lower = token.lower()
        if token[0].isupper():
            in_group = True
            
            if token_lower not in connection_tokens:
                if len(token) > 2:
                    n_long_tokens += 1
                group_len += 1
                last_token_connection = False
                this_group_tokens.append(token)
            else:
                last_token_connection = True
                
                # Prevent connection tokens to be the first
                if group_len > 0:
                    this_group_tokens.append(token)
                
        else:
            if token_lower not in connection_tokens:
                if in_group:
                    if group_len > 2 and n_long_tokens > 0:
                        if last_token_connection:
                            this_group_tokens = this_group_tokens[:-1]
                        groups.append(this_group_tokens)
                    this_group_tokens = []
                
                last_token_connection = False
                group_len = 0
                n_long_tokens = 0
                in_group = False
                
            elif in_group:
                last_token_connection = True
                this_group_tokens.append(token)
                
                
    if in_group:
        if group_len > 2 and n_long_tokens > 0:
            if last_token_connection:
                this_group_tokens = this_group_tokens[:-1]
            groups.append(this_group_tokens)
    
    return groups

def index_list_in_list(lst, search_lst):
    for i_start in range(len(lst) - len(search_lst) + 1):
        i_end = i_start + len(search_lst)
        if lst[i_start:i_end] == search_lst:
            return i_start
        
    raise ValueError(f'{lst} not found in {search_lst[:50]}')

def convert_sentence(tokens, labels):
    connected_uppercase = get_connected_uppercase(tokens)
    tokens = [t if not is_acronym(t) else 'ACRONYM' for t in tokens]
    
    """for conn_tokens in connected_uppercase:
        i_start = index_list_in_list(tokens, conn_tokens)
        i_end = i_start + len(conn_tokens)
        
        tokens = tokens[:i_start] + ['UPPERCASEENTITY'] + tokens[i_end:]
        labels = labels[:i_start] + ['B'] + labels[i_end:]"""
        
    return tokens, labels

def process_pos_sentence(sentence):
    global n_broken_sent
    global last_doc_labels

    bert_sentence = text_cleaning_for_bert(sentence)
    label_sentence = text_cleaning_upper(bert_sentence).lower()

    if is_text_broken(label_sentence.split(' ')): # Can't use bert cleaning for this, because all punc.s are padded with spaces
        n_broken_sent += 1
        return
    
    bert_tokens = bert_sentence.split(' ')
    ### STEP 1: Split into fixed sized sentences ###
    for small_sentence_tokens in split_to_smaller_sent(bert_tokens, s_size = 64, overlap_size = 20):

        small_bert_sentence = ' '.join(small_sentence_tokens)

        # Need to remove punc.s and uppercase letters to find labels
        small_label_sentence = text_cleaning(small_bert_sentence)

        has_label = False
        sent_labels = []
        ### STEP 2: Match labels ###
        # Check if contains labels
        for clean_label in existing_labels:
            if re.search(r'\b{}\b'.format(clean_label), small_label_sentence):
                has_label = True

                # Remove label from the text, to only match the largest label
                small_label_sentence = small_label_sentence.replace(clean_label, '')
                sent_labels.append(clean_label)

        small_sent_targets = ['O' for _ in range(len(small_sentence_tokens))]

        if has_label:
            # Tokenize labels for matching
            sent_label_tokens = [l.split(' ') for l in sent_labels]

            # Get index, token tuples for clean tokens. Indices are for raw tokens
            small_sent_tuples = [(i, token.lower()) for i, token in enumerate(small_sentence_tokens) if text_cleaning_upper(token) != '']

            ### STEP 3: Set corresponding targets for each label ###
            # Target: (B, I, O), Label: adni
            for l in sent_labels:
                l_tokens = l.split(' ')
                small_sent_joined = [join_tuple_tokens(small_sent_tuples[i: i + len(l_tokens)]) for i in range(len(small_sent_tuples) - len(l_tokens) + 1)]

                label_start_idx = get_index(small_sent_joined, l) # list of indices
                for label_start_i in label_start_idx:
                    label_end_i = label_start_i + len(l_tokens) - 1

                    target_start_i = small_sent_tuples[label_start_i][0]
                    target_end_i = small_sent_tuples[label_end_i][0]

                    # Do not use the same tokens for multiple labels
                    #small_sent_tuples = small_sent_tuples[:label_start_i] + small_sent_tuples[label_end_i:]

                    try:
                        if small_sent_targets[target_start_i] == 'O': # If not was already labeled
                            small_sent_targets[target_start_i] = 'B'
                            if target_end_i - target_start_i > 0:
                                for i in range(target_start_i+1, target_end_i+1):
                                    small_sent_targets[i] = 'I'

                    except Exception as e:
                        print('DEBUG')
                        print(small_sentence_tokens)
                        print(len(small_sentence_tokens))
                        print(len(small_sent_targets))
                        print(target_start_i)
                        print(small_sent_joined)
                        print('DEBUG')
                        raise e
        
        ### STEP 4: Add sentence output to lists ###
        small_sentence_tokens, small_sent_targets = convert_sentence(small_sentence_tokens, small_sent_targets)
        if has_label:
            pos_sentences_processed.append(small_sentence_tokens)
            pos_labels.append(small_sent_targets)
        else:
            neg_sentences_processed.append(small_sentence_tokens)
            neg_labels.append(small_sent_targets)

"""def process_neg_sentence(sentence):
    global n_broken_sent
    
    bert_sentence = text_cleaning_upper(sentence)
    label_sentence = bert_sentence.lower()

    if is_text_broken(label_sentence.split(' ')): # Can't use bert cleaning for this, because all punc.s are padded with spaces
        n_broken_sent += 1
        return

    bert_tokens = bert_sentence.split(' ')
    
    ### STEP 1: Split into fixed sized sentences ###
    for small_sentence_tokens in split_to_smaller_sent(bert_tokens, s_size = 64, overlap_size = 20):
        small_sent_targets = ['O' for _ in range(len(bert_tokens))]

        neg_sentences_processed.append([convert_tokens(t) for t in small_sentence_tokens])
        neg_labels.append(small_sent_targets)"""

#process_pos_sentence(pos_sentences[2472])

"def process_neg_sentence(sentence):\n    global n_broken_sent\n    \n    bert_sentence = text_cleaning_upper(sentence)\n    label_sentence = bert_sentence.lower()\n\n    if is_text_broken(label_sentence.split(' ')): # Can't use bert cleaning for this, because all punc.s are padded with spaces\n        n_broken_sent += 1\n        return\n\n    bert_tokens = bert_sentence.split(' ')\n    \n    ### STEP 1: Split into fixed sized sentences ###\n    for small_sentence_tokens in split_to_smaller_sent(bert_tokens, s_size = 64, overlap_size = 20):\n        small_sent_targets = ['O' for _ in range(len(bert_tokens))]\n\n        neg_sentences_processed.append([convert_tokens(t) for t in small_sentence_tokens])\n        neg_labels.append(small_sent_targets)"

In [8]:
text_cleaning_for_bert(pos_sentences[100])

'During the biomass study , the tide was coming in moving from a low tide at : GMT with a water level of . m to high tide at : GMT with a water level of . m ( NOAA tide gauge Atlantic City , NJ ) .'

## Create NER Dataset and Save

In [11]:
assert len(pos_sentences) > 0

pos_sentences_processed = []
neg_sentences_processed = []
pos_labels = []
neg_labels = []

n_pos_no_label = 0
n_broken_sent = 0

for sent in tqdm(pos_sentences):
    process_pos_sentence(sent)

"""for sent in tqdm(neg_sentences):
        process_neg_sentence(sent)"""

import pickle

with open(f'data/bert_ner_processed/pos.pkl', 'wb') as f:
    pickle.dump(pos_sentences_processed, f)

with open(f'data/bert_ner_processed/neg.pkl', 'wb') as f:
    pickle.dump(neg_sentences_processed, f)

with open(f'data/bert_ner_processed/pos_labels.pkl', 'wb') as f:
    pickle.dump(pos_labels, f)

with open(f'data/bert_ner_processed/neg_labels.pkl', 'wb') as f:
    pickle.dump(neg_labels, f)


print('')
print(f'broken sentences: {n_broken_sent}')
print(f'n_pos_no_label: {n_pos_no_label}')
print(f'pos_proc size: {len(pos_sentences_processed)}')
print(f'neg_proc size: {len(neg_sentences_processed)}')

100%|████████████████████████████████████████████████████████████████████████| 971668/971668 [03:01<00:00, 5354.71it/s]



broken sentences: 4032
n_pos_no_label: 0
pos_proc size: 264346
neg_proc size: 821066


In [12]:
pos_sentences_processed[0]

['In',
 'fact',
 ',',
 'organizations',
 'are',
 'now',
 'identifying',
 'digital',
 'skills',
 'or',
 'computer',
 'literacy',
 'as',
 'one',
 'of',
 'their',
 'core',
 'values',
 'for',
 'employability',
 '(',
 'such',
 'as',
 'the',
 'US',
 'Department',
 'of',
 'Education',
 ',',
 'the',
 'US',
 'Department',
 'of',
 'commerce',
 ',',
 'the',
 'ACRONYM',
 'Program',
 'for',
 'the',
 'International',
 'Assessment',
 'of',
 'Adult',
 'Competencies',
 'and',
 'the',
 'European',
 'Commission',
 ')',
 '.']

In [13]:
i_ex = 0
pd.DataFrame({'token': pos_sentences_processed[i_ex], 'label': pos_labels[i_ex]})

Unnamed: 0,token,label
0,In,O
1,fact,O
2,",",O
3,organizations,O
4,are,O
5,now,O
6,identifying,O
7,digital,O
8,skills,O
9,or,O


## Load NER Dataset

In [11]:
import pickle

with open(f'data/bert_ner_processed/pos.pkl', 'rb') as f:
    pos_sentences_processed = pickle.load(f)

with open(f'data/bert_ner_processed/neg.pkl', 'rb') as f:
    neg_sentences_processed = pickle.load(f)

with open(f'data/bert_ner_processed/pos_labels.pkl', 'rb') as f:
    pos_labels = pickle.load(f)

with open(f'data/bert_ner_processed/neg_labels.pkl', 'rb') as f:
    neg_labels = pickle.load(f)

print(f'pos size: {len(pos_sentences_processed)}')
print(f'neg size: {len(neg_sentences_processed)}')
print(f'pos label size: {len(pos_labels)}')
print(f'neg label size: {len(neg_labels)}')

pos size: 264346
neg size: 821066
pos label size: 264346
neg label size: 821066


## Augmentation

In [12]:
import pickle

with open(f'data/gov_data_selected.pkl', 'rb') as f:
    unique_gov_names = pickle.load(f)
    
unique_gov_names += existing_labels

def replace_target(x, lst):
    if x['label'].iloc[0] == 'O':
        # if not a dataset name, do not augment
        lst.append(x)
    else:
        random_name = random.choice(unique_gov_names)
        random_name_tokens = random_name.split(' ')
        if random_name.islower():
            if len(random_name_tokens) > 1:
                random_name_tokens = [t for t in random_name_tokens if len(t) > 0]
                random_name_tokens = [r[0].upper() + r[1:] if not r.lower() in connection_tokens else r for r in random_name_tokens]
            else:
                random_name_tokens = ['ACRONYM']
        else:
            random_name_tokens = [t if not is_acronym(t) else 'ACRONYM' for t in random_name_tokens]

        new_x = pd.DataFrame()
        # Replace tokens
        new_x['token'] = random_name_tokens
        new_x['label'] = 'I'
        new_x.loc[new_x.index == 0, 'label'] = 'B'
        lst.append(new_x)

def augment_sentence(tokens, labels, augment_chance = 0.8):
    if random.uniform(0,1) > augment_chance:
        # No augmentation
        return tokens, labels

    df_pieces = []
    sent_df = pd.DataFrame({'token': tokens, 'label': labels})
    sent_df['label_o'] = sent_df.label == 'O'

    gb = sent_df.groupby((sent_df['label_o'].shift() != sent_df['label_o']).cumsum())
    for name, group in gb:
        replace_target(group, df_pieces)

    sent_df = pd.concat(df_pieces, ignore_index = True, axis = 0)

    return list(sent_df.token.values), list(sent_df.label.values)

In [13]:
pos_sentences_processed_aug = []
pos_labels_aug = []

for _ in range(5):
    for s_tokens, s_labels in tqdm(zip(pos_sentences_processed, pos_labels), total = len(pos_labels)):
        aug_tokens, aug_labels = augment_sentence(s_tokens, s_labels)
        pos_sentences_processed_aug.append(aug_tokens)
        pos_labels_aug.append(aug_labels)

pos_sentences_processed = pos_sentences_processed_aug
pos_labels = pos_labels_aug

100%|█████████████████████████████████████████████████████████████████████████| 264346/264346 [16:20<00:00, 269.68it/s]
100%|█████████████████████████████████████████████████████████████████████████| 264346/264346 [14:57<00:00, 294.57it/s]
100%|█████████████████████████████████████████████████████████████████████████| 264346/264346 [14:39<00:00, 300.47it/s]
100%|█████████████████████████████████████████████████████████████████████████| 264346/264346 [14:37<00:00, 301.17it/s]
100%|█████████████████████████████████████████████████████████████████████████| 264346/264346 [14:36<00:00, 301.62it/s]


In [14]:
i_ex = 1001
pd.DataFrame({'token': pos_sentences_processed[i_ex], 'label': pos_labels[i_ex]})


Unnamed: 0,token,label
0,The,O
1,Plant,B
2,Wall,I
3,Degradative,I
4,Compounds,I
...,...,...
56,-,O
57,year,O
58,publicprivate,O
59,partnership,O


## Create Training Data

In [15]:
from sklearn.model_selection import train_test_split
import numpy as np

neg_size = 500000
neg_idx = np.random.permutation(len(neg_labels))
neg_sentences_processed = [neg_sentences_processed[i] for i in neg_idx[:neg_size]]
neg_labels = [neg_labels[i] for i in neg_idx[:neg_size]]

sentences = pos_sentences_processed + neg_sentences_processed
labels = pos_labels + neg_labels

idx = np.random.permutation(len(sentences))
sentences = [sentences[i] for i in idx]
labels = [labels[i] for i in idx]

In [16]:
with open(f'data/bert_ner_data/train_sentences.pkl', 'wb') as f:
    pickle.dump(sentences, f)

with open(f'data/bert_ner_data/train_labels.pkl', 'wb') as f:
    pickle.dump(labels, f)
    
print(len(sentences))

In [17]:
print(len(sentences))

1821730


## Load Training Data

In [2]:
import pickle

with open(f'data/bert_ner_data/train_sentences.pkl', 'rb') as f:
    sentences = pickle.load(f)

with open(f'data/bert_ner_data/train_labels.pkl', 'rb') as f:
    labels = pickle.load(f)
    
print(len(sentences))

In [18]:
sentences = sentences[:1500000]
labels = labels[:1500000]

## Fine Tune Bert

In [19]:
import os
import math
import random
import csv
import sys
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.metrics import classification_report
import statistics as stats
from bert_sklearn import BertTokenClassifier 

In [20]:
model = BertTokenClassifier(
    bert_model='scibert-scivocab-cased',
    num_mlp_hiddens= 500,
    max_seq_length=64, 
    epochs=1,
    #gradient accumulation
    gradient_accumulation_steps=4,
    learning_rate=3e-5,
    train_batch_size=8,#batch size for training
    eval_batch_size=8, #batch size for evaluation
    validation_fraction=0.0, 
    #ignore the tokens with label ‘O’
    ignore_label=['O']
)

Building sklearn token classifier...


In [None]:
model.fit(sentences, labels)

  return np.array(X)


Loading scibert-scivocab-cased model...
Defaulting to linear classifier/regressor
Loading Pytorch checkpoint
train data size: 1500000, validation data size: 0


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:1005.)
  next_m.mul_(beta1).add_(1 - beta1, grad)
Training  :  34%|██████████████▊                             | 253202/750000 [7:34:48<16:37:32,  8.30it/s, loss=0.0286]

In [None]:
# save model to disk
savefile='data/sklearn_bert_ner_cased_all_sents.bin'
model.save(savefile)