In [1]:
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import re
import seaborn as sns
from tqdm import tqdm
import nltk
import random
from nltk.tokenize import word_tokenize,sent_tokenize
import pickle

train_example_names = [fn.split('.')[0] for fn in os.listdir('data/train')]
test_example_names = [fn.split('.')[0] for fn in os.listdir('data/test')]

metadata = pd.read_csv('data/train.csv')
docIdx = train_example_names.copy()

connection_tokens = {'s', 'of', 'and', 'in', 'on', 'for', 'from', 'the', 'act', 'coast', 'future', 'system', 'per'}

In [2]:
def text_cleaning(text):
    text = re.sub('[^A-Za-z]+', ' ', str(text)).strip() # remove unnecessary literals

    # remove extra spaces
    text = re.sub("\s+"," ", text)

    return text.lower().strip()

def is_name_ok(text):
    if len([c for c in text if c.isalnum()]) < 4:
        return False
    
    tokens = [t for t in text.split(' ') if len(t) > 3]
    tokens = [t for t in tokens if not t in connection_tokens]
    if len(tokens) < 3:
        return False

    return True
    
with open('data/all_preds_selected.txt', 'r') as f:
    selected_pred_labels = f.readlines()
    selected_pred_labels = [l.strip() for l in selected_pred_labels]

existing_labels = [text_cleaning(x) for x in metadata['dataset_label']] +\
                  [text_cleaning(x) for x in metadata['dataset_title']] +\
                  [text_cleaning(x) for x in metadata['cleaned_label']] +\
                  [text_cleaning(x) for x in selected_pred_labels]

to_remove = [
    'frequently asked questions', 'total maximum daily load tmd', 'health care facilities',
    'traumatic brain injury', 'north pacific high', 'droplet number concentration', 'great slave lake',
    'census block groups'
]

df = pd.read_csv(r'C:\projects\personal\kaggle\kaggle_coleridge_initiative\string_search\data\gov_data.csv')
print(len(df))


df['title'] = df.title.apply(text_cleaning)
titles = list(df.title.unique())
titles = [t for t in titles if not t in to_remove]
df = pd.DataFrame({'title': titles})
df = df.loc[df.title.apply(is_name_ok)]
df = pd.concat([df, pd.DataFrame({'title': existing_labels})], ignore_index= True).reset_index(drop = True)
titles = list(df.title.unique())
df = pd.DataFrame({'title': titles})
df['title'] = df.title.apply(text_cleaning)

# Sort labels by length in ascending order
existing_labels = sorted(list(df.title.values), key = len, reverse = True)
existing_labels = [l for l in existing_labels if len(l.split(' ')) < 10]
del df

print(len(existing_labels))

291984
60188


## Load sentences
- pos: has dataset name

In [22]:
import pickle

with open(f'data/bert_ner_sentences/pos.pkl', 'rb') as f:
    pos_sentences = pickle.load(f)

with open(f'data/bert_ner_sentences/neg.pkl', 'rb') as f:
    neg_sentences = pickle.load(f)

"""pos_sentences = [text_cleaning(s) for s in pos_sentences_raw]
neg_sentences = [text_cleaning(s) for s in neg_sentences_raw]"""

print(f'pos size: {len(pos_sentences)}')
print(f'neg size: {len(neg_sentences)}')

pos size: 69267
neg size: 902401


## Preprocessing Functions

In [24]:
pos_sentences_processed = []
neg_sentences_processed = []
pos_labels = []
neg_labels = []

n_broken_sent = 0
n_pos_no_label = 0

def text_cleaning_upper(text):
    text = re.sub('[^A-Za-z]+', ' ', str(text)).strip() # remove unnecessary literals

    # remove extra spaces
    text = re.sub("\s+"," ", text)

    return text.strip()

def convert_tokens(text):
    if is_acronym(text):
        return 'ACRONYM'
    return text

def is_acronym(text):
    if len(text) < 4:
        return False
    if text.isupper():
        return True

def is_text_broken(tokens):
    # Some texts are like 'p a dsdv a d a ds f b', remove them
    if len(tokens) == 0:
        return True

    if len(tokens) < 50:
        return False

    one_char_token_ratio = len([l for l in tokens if len(l) == 1]) / len(tokens)
    return one_char_token_ratio > 0.2

def split_to_smaller_sent(tokens, s_size, overlap_size):
    # output sentences will be s_size + overlap_size long
    small_sents = []

    if len(tokens) <= s_size:
        return [tokens]

    n_parts = len(tokens) // s_size
    if len(tokens) % s_size != 0:
        n_parts += 1

    i_part = 0
    end_i = 0
    while end_i < len(tokens):
        start_i = i_part * s_size
        if i_part > 0:
            start_i -= overlap_size

        end_i = min(len(tokens), start_i + s_size)

        small_sents.append(tokens[start_i: end_i])
        i_part += 1

    return small_sents

def join_tuple_tokens(tuples):
    return ' '.join([t[1] for t in tuples])

def get_index(lst, el):
    idx = []
    for i, lst_el in enumerate(lst):
        if el in lst_el:
            idx.append(i)

    return idx

def process_pos_sentence(sentence):
    global n_broken_sent
    global last_doc_labels

    bert_sentence = text_cleaning_upper(sentence)
    label_sentence = bert_sentence.lower()

    if is_text_broken(label_sentence.split(' ')): # Can't use bert cleaning for this, because all punc.s are padded with spaces
        n_broken_sent += 1
        return
    
    bert_tokens = bert_sentence.split(' ')
    ### STEP 1: Split into fixed sized sentences ###
    for small_sentence_tokens in split_to_smaller_sent(bert_tokens, s_size = 64, overlap_size = 20):

        small_bert_sentence = ' '.join(small_sentence_tokens)

        # Need to remove punc.s and uppercase letters to find labels
        small_label_sentence = small_bert_sentence.lower()

        has_label = False
        sent_labels = []
        ### STEP 2: Match labels ###
        # Check if contains labels
        for clean_label in existing_labels:
            if clean_label in small_label_sentence:
                has_label = True

                # Remove label from the text, to only match the largest label
                small_label_sentence = small_label_sentence.replace(clean_label, '')
                sent_labels.append(clean_label)

        small_sent_targets = ['O' for _ in range(len(small_sentence_tokens))]

        if has_label:
            # Tokenize labels for matching
            sent_label_tokens = [l.split(' ') for l in sent_labels]

            # Get index, token tuples for clean tokens. Indices are for raw tokens
            small_sent_tuples = [(i, token.lower()) for i, token in enumerate(small_sentence_tokens) if text_cleaning_upper(token) != '']

            ### STEP 3: Set corresponding targets for each label ###
            # Target: (B, I, O), Label: adni
            for l in sent_labels:
                l_tokens = l.split(' ')
                small_sent_joined = [join_tuple_tokens(small_sent_tuples[i: i + len(l_tokens)]) for i in range(len(small_sent_tuples) - len(l_tokens) + 1)]

                label_start_idx = get_index(small_sent_joined, l) # list of indices
                for label_start_i in label_start_idx:
                    label_end_i = label_start_i + len(l_tokens) - 1

                    target_start_i = small_sent_tuples[label_start_i][0]
                    target_end_i = small_sent_tuples[label_end_i][0]

                    # Do not use the same tokens for multiple labels
                    #small_sent_tuples = small_sent_tuples[:label_start_i] + small_sent_tuples[label_end_i:]

                    try:
                        if small_sent_targets[target_start_i] == 'O': # If not was already labeled
                            small_sent_targets[target_start_i] = 'B'
                            if target_end_i - target_start_i > 0:
                                for i in range(target_start_i+1, target_end_i+1):
                                    small_sent_targets[i] = 'I'

                    except Exception as e:
                        print('DEBUG')
                        print(small_sentence_tokens)
                        print(len(small_sentence_tokens))
                        print(len(small_sent_targets))
                        print(target_start_i)
                        print(small_sent_joined)
                        print('DEBUG')
                        raise e
        
        ### STEP 4: Add sentence output to lists ###
        #assert has_label
        if has_label:
            pos_sentences_processed.append([convert_tokens(t) for t in small_sentence_tokens])
            pos_labels.append(small_sent_targets)
        else:
            neg_sentences_processed.append([convert_tokens(t) for t in small_sentence_tokens])
            neg_labels.append(small_sent_targets)

def process_neg_sentence(sentence):
    global n_broken_sent
    
    bert_sentence = text_cleaning_upper(sentence)
    label_sentence = bert_sentence.lower()

    if is_text_broken(label_sentence.split(' ')): # Can't use bert cleaning for this, because all punc.s are padded with spaces
        n_broken_sent += 1
        return

    bert_tokens = bert_sentence.split(' ')
    
    ### STEP 1: Split into fixed sized sentences ###
    for small_sentence_tokens in split_to_smaller_sent(bert_tokens, s_size = 64, overlap_size = 20):
        small_sent_targets = ['O' for _ in range(len(bert_tokens))]

        neg_sentences_processed.append([convert_tokens(t) for t in small_sentence_tokens])
        neg_labels.append(small_sent_targets)

## Process Data

In [25]:
assert len(pos_sentences) > 0

pos_sentences_processed = []
neg_sentences_processed = []
pos_labels = []
neg_labels = []

n_pos_no_label = 0
n_broken_sent = 0

for sent in tqdm(pos_sentences):
    process_pos_sentence(sent)

for sent in tqdm(neg_sentences):
    process_neg_sentence(sent)

import pickle

with open(f'data/bert_ner_data/pos.pkl', 'wb') as f:
    pickle.dump(pos_sentences_processed, f)

with open(f'data/bert_ner_data/neg.pkl', 'wb') as f:
    pickle.dump(neg_sentences_processed, f)

with open(f'data/bert_ner_data/pos_labels.pkl', 'wb') as f:
    pickle.dump(pos_labels, f)

with open(f'data/bert_ner_data/neg_labels.pkl', 'wb') as f:
    pickle.dump(neg_labels, f)


print('')
print(f'broken sentences: {n_broken_sent}')
print(f'n_pos_no_label: {n_pos_no_label}')
print(f'pos_proc size: {len(pos_sentences_processed)}')
print(f'neg_proc size: {len(neg_sentences_processed)}')

100%|████████████████████████████████████████████████████████████████████████████| 69267/69267 [11:47<00:00, 97.90it/s]
100%|███████████████████████████████████████████████████████████████████████| 902401/902401 [00:32<00:00, 28021.11it/s]



broken sentences: 3970
n_pos_no_label: 0
pos_proc size: 69983
neg_proc size: 933496


## Load Processed Data

In [3]:
import pickle

with open(f'data/bert_ner_data/pos.pkl', 'rb') as f:
    pos_sentences_processed = pickle.load(f)

with open(f'data/bert_ner_data/neg.pkl', 'rb') as f:
    neg_sentences_processed = pickle.load(f)

with open(f'data/bert_ner_data/pos_labels.pkl', 'rb') as f:
    pos_labels = pickle.load(f)

with open(f'data/bert_ner_data/neg_labels.pkl', 'rb') as f:
    neg_labels = pickle.load(f)

print(f'pos size: {len(pos_sentences_processed)}')
print(f'neg size: {len(neg_sentences_processed)}')
print(f'pos label size: {len(pos_labels)}')
print(f'neg label size: {len(neg_labels)}')

pos size: 69983
neg size: 933496
pos label size: 69983
neg label size: 933496


## Augmentation

In [4]:
def replace_target(x, lst):
    if x['label'].iloc[0] == 'O':
        # if not a dataset name, do not augment
        lst.append(x)
    else:
        random_name_tokens = random.choice(existing_labels).split(' ')
        random_name_tokens = [r[0].upper() + r[1:] if not r.lower() in connection_tokens else r for r in random_name_tokens]

        new_x = pd.DataFrame()
        # Replace tokens
        new_x['token'] = random_name_tokens
        new_x['label'] = 'I'
        new_x.loc[new_x.index == 0, 'label'] = 'B'
        lst.append(new_x)

def augment_sentence(tokens, labels, augment_chance = 0.8):
    if random.uniform(0,1) > augment_chance:
        # No augmentation
        return tokens, labels

    df_pieces = []
    sent_df = pd.DataFrame({'token': tokens, 'label': labels})
    sent_df['label_o'] = sent_df.label == 'O'

    gb = sent_df.groupby((sent_df['label_o'].shift() != sent_df['label_o']).cumsum())
    for name, group in gb:
        replace_target(group, df_pieces)

    sent_df = pd.concat(df_pieces, ignore_index = True, axis = 0)

    return list(sent_df.token.values), list(sent_df.label.values)

In [5]:
pos_sentences_processed_aug = []
pos_labels_aug = []

for _ in range(5):
    for s_tokens, s_labels in tqdm(zip(pos_sentences_processed, pos_labels), total = len(pos_labels)):
        aug_tokens, aug_labels = augment_sentence(s_tokens, s_labels)
        pos_sentences_processed_aug.append(aug_tokens)
        pos_labels_aug.append(aug_labels)

pos_sentences_processed = pos_sentences_processed_aug
pos_sentences_processed = [' '.join(sent_tokens) for sent_tokens in pos_sentences_processed]
neg_sentences_processed = [' '.join(sent_tokens) for sent_tokens in neg_sentences_processed]
pos_labels = pos_labels_aug
pos_labels = [1 for _ in pos_labels]
neg_labels = [0 for _ in neg_labels]

100%|███████████████████████████████████████████████████████████████████████████| 69983/69983 [03:39<00:00, 319.05it/s]
100%|███████████████████████████████████████████████████████████████████████████| 69983/69983 [03:37<00:00, 322.22it/s]
100%|███████████████████████████████████████████████████████████████████████████| 69983/69983 [03:37<00:00, 321.51it/s]
100%|███████████████████████████████████████████████████████████████████████████| 69983/69983 [03:37<00:00, 322.47it/s]
100%|███████████████████████████████████████████████████████████████████████████| 69983/69983 [03:34<00:00, 325.70it/s]


In [6]:
from sklearn.model_selection import train_test_split
import numpy as np

neg_size = 350000
neg_idx = np.random.permutation(len(neg_labels))
neg_sentences_processed = [neg_sentences_processed[i] for i in neg_idx[:neg_size]]
neg_labels = [neg_labels[i] for i in neg_idx[:neg_size]]

sentences = pos_sentences_processed + neg_sentences_processed
labels = pos_labels + neg_labels

assert len(sentences) == len(labels)

idx = np.random.permutation(len(sentences))
sentences = [sentences[i] for i in idx]
labels = [labels[i] for i in idx]

with open(f'data/bert_ner_data/train_sentences.pkl', 'wb') as f:
    pickle.dump(sentences, f)

with open(f'data/bert_ner_data/train_labels.pkl', 'wb') as f:
    pickle.dump(labels, f)

## Load Training Data

In [7]:
import pickle

with open(f'data/bert_ner_data/train_sentences.pkl', 'rb') as f:
    sentences = pickle.load(f)

with open(f'data/bert_ner_data/train_labels.pkl', 'rb') as f:
    labels = pickle.load(f)
    
    
SENTENCE_TOKEN_SIZE = 20

def shorten_sentence(text):
    tokens = text.split(' ')
    return ' '.join(tokens[:SENTENCE_TOKEN_SIZE])

sentences = [shorten_sentence(s) for s in sentences]

## Training

In [9]:
import os
import math
import random
import csv
import sys
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.metrics import classification_report
import statistics as stats
from bert_sklearn import BertClassifier

model = BertClassifier(bert_model='scibert-scivocab-cased',
                        validation_fraction= 0.15,
                        max_seq_length=64,
                        train_batch_size=4,
                        warmup_proportion=0.1,
                        gradient_accumulation_steps=1,
                        epochs = 1
                        )

Building sklearn text classifier...


In [11]:
pd.Series(sentences).iloc[400]

'In this study we used ACRONYM data that were previously collected across sites'

In [12]:
model.fit(pd.Series(sentences), pd.Series(labels))

Loading scibert-scivocab-cased model...
Defaulting to linear classifier/regressor
Loading Pytorch checkpoint
train data size: 594928, validation data size: 104987


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:1005.)
  next_m.mul_(beta1).add_(1 - beta1, grad)
Training  : 100%|████████████████████████████████████████████████| 148732/148732 [8:00:58<00:00,  5.15it/s, loss=0.441]
Validating: 100%|████████████████████████████████████████████████████████████████| 13124/13124 [12:10<00:00, 17.97it/s]

Epoch 1, Train loss: 0.4411, Val loss: 0.4821, Val accy: 88.24%





BertClassifier(bert_model='scibert-scivocab-cased', do_lower_case=False,
               epochs=1, label_list=array([0, 1], dtype=int64),
               max_seq_length=64, train_batch_size=4, validation_fraction=0.15)

In [13]:
# save model to disk
savefile='data/sklearn_bert_classification.bin'
model.save(savefile)

In [None]:
d