In [None]:
# For better error messages when 'device side assert'
!export CUDA_LAUNCH_BLOCKING=1
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [307]:
import numpy as np # linear algebra
import pandas as pd
import json
import re
import os
import pandas as pd
import pickle
import math
from tqdm import tqdm,trange
from keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import Progbar
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score,accuracy_score
import torch
from transformers import BertForTokenClassification, AdamW, BertTokenizerFast, AutoModelForTokenClassification, AutoTokenizer
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
#from rich.console import Console
#from rich.progress import track
from tqdm import tqdm
from transformers import BertTokenizerFast,BertForTokenClassification

import random

In [308]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
print(f'Using {device} device')

Using cpu device


In [309]:
def timer(func):
    """
    Record execution time of any function with timer decorator
    Usage: just decorate a function when building it, the 
    decorator will be called every time the function is executed.
    # build the function
    @timer
    def some_function(some_arg):
        # do_something
        return 'foo'
        
    # call it
    some_function('boo')
    # output:
    >> Function 'some_function' finished after 0.01 seconds.
    """

    def wrapper(*args, **kwargs):
        start = time.time()
        results = func(*args, **kwargs)
        duration = time.time() - start
        print("Function '{}' finished after {:.4f} seconds."\
              .format(func.__name__, duration))
        return results
    return wrapper

# Data Preparation

In [310]:
class Pipeline:

    def __init__(self, configs):
        self.configs = configs
        self.tokenizer = None

    def set_tokenizer(self, tokenizer):
        """
        Set a custom tokenizer to be used when running the pipeline.
        If not, this will default to BERTTokenizerFast in `run()`
        """
        self.tokenizer = tokenizer

    def tokenize_and_preserve_labels(self, tupled_sentence):
        tokenized_sentence = []
        labels = []

        for (word, label) in tupled_sentence:

            # Tokenize the word
            tokenized_word = self.tokenizer.tokenize(word)
            tokenized_sentence.extend(tokenized_word)
            
            # Repeat the label for words that are broken up into several tokens
            labels.extend([label]*len(tokenized_word))
            
        # Add the tokenized word and its label to the final tokenized word list
        return list(zip(tokenized_sentence, labels))

    def add_start_end_tokens(self, tupled_sentence):
        tupled_sentence.insert(0, ('[CLS]', 'O'))
        tupled_sentence.append(('[SEP]', 'O'))
        return tupled_sentence

    def add_padding(self, tokenized_sentences, labels):
        # Note that this implicitly converts to an array of objects (strings)
        
        padded_sentences = pad_sequences(
            tokenized_sentences, 
            value='[PAD]', 
            dtype=object, 
            maxlen=self.configs.MAX_LENGTH, 
            truncating='post', 
            padding='post')

        padded_labels = pad_sequences(
            labels, 
            value='O', 
            dtype=object, 
            maxlen=self.configs.MAX_LENGTH, 
            truncating='post', 
            padding='post')
        
        return padded_sentences, padded_labels

    def get_attention_mask(self, input_ids, ignore_tokens=[0,101,102]):
        """
        Compute the attention marks for the tokens in `input_ids`, which is
        assumed to be a list (batch) of lists (sentences) of integer tokens.
        Tokens that should be masked out can be specified using the 
        `ignore_tokens` parameter. By default, these are supposed to be 0, 101,
        and 102, representing [PAD], [CLS], and [SEP] tokens, respectively.
        """

        return [
            [ float(token not in ignore_tokens) for token in sent ] 
                for sent in input_ids
        ]

    def run(self, ner_data):
        """
        Run extracted sentence data through the pipeline.
        """

        console = Console()

        # Initialize tokenizer
        if not self.tokenizer:
            self.tokenizer = BertTokenizerFast.from_pretrained(
                'bert-base-cased', do_lower_case=False)
            console.log('Initialized default BERT tokenizer')
        else:
            console.log('Using custom tokenizer')

        # Tokenize into known tokens
        ner_data = [
            self.tokenize_and_preserve_labels(sentence) for sentence in 
                track(ner_data, description='Tokenizing words...')
        ]
        console.log('Tokenized words')

        with console.status("[bold green]Running pipeline...") as status:

            # Add [CLS] and [SEP] tokens to beginning and end
            ner_data = [
                self.add_start_end_tokens(sentence)
                    for sentence in ner_data
            ]
            console.log('Added [CLS] and [SEP] tokens')

            # Get only sentences, not labels
            tokenized_sentences = [
                [token_label_tuple[0] for token_label_tuple in sent]
                    for sent in ner_data
            ]

            # Get only labels, not sentences
            labels = [
                [token_label_tuple[1] for token_label_tuple in sent] 
                    for sent in ner_data 
            ]

            # Pad sentences and labels 
            padded_sentences, padded_labels = self.add_padding(
                tokenized_sentences, labels)
            console.log('Padded sentences and labels')

            # Convert to integer ids
            input_ids = [
                self.tokenizer.convert_tokens_to_ids(text) 
                    for text in padded_sentences
            ]
            tags = [
                self.tokenizer.convert_tokens_to_ids(text) 
                    for text in padded_labels
            ]
            console.log('Converted to integer ids')

            # Compute attention mask from input tokens
            attention_mask = self.get_attention_mask(
                input_ids,
                # Only ignore [PAD] tokens (integer 0)
                ignore_tokens=[0]
            )
            
            console.log('Computed attention mask')

        if self.configs.SAVE:
            ParseUtils.save_file(
                {
                    'input_ids': input_ids, 
                    'tags': tags,
                    'attention_mask': attention_mask
                },
                self.configs.DATA_PATH,
                self.configs.TOKENIZED_FILENAME
            )

        return input_ids, tags, attention_mask

    def load_outputs(self):
        """
        Recover the outputs of a previously completed run from storage.
        """
        output_dict = ParseUtils.load_file(
            self.configs.DATA_PATH,
            self.configs.TOKENIZED_FILENAME,
        )

        return output_dict['input_ids'], \
               output_dict['tags'], \
               output_dict['attention_mask']

    def extract(self):
        ner_data = ParseUtils.extract(
            max_len = self.configs.MAX_LENGTH,
            overlap = self.configs.OVERLAP,
            max_sample = self.configs.MAX_SAMPLE,
            max_text_tokens = self.configs.MAX_TEXT_TOKENS,
            train_df_path = self.configs.TRAIN_DF_PATH,
            train_data_path = self.configs.TRAIN_DATA_PATH,
            ignore_label_case = self.configs.IGNORE_LABEL_CASE,
            exclude_non_exact_label_match = self.configs.EXCLUDE_NON_EXACT_LABEL_MATCH
        )

        # Write data to file
        if self.configs.SAVE:
            ParseUtils.save_extracted(
                ner_data, 
                self.configs.DATA_PATH, 
                self.configs.EXTRACTED_FILENAME
            )

        return ner_data

    def load_extracted(self):
        return ParseUtils.load_extracted(
            self.configs.DATA_PATH, 
            self.configs.EXTRACTED_FILENAME
        )


class PipelineConfigs:

    def __init__(
        self,
        DATA_PATH,
        SAVE,
        EXTRACTED_FILENAME,
        TOKENIZED_FILENAME,
        MAX_SAMPLE,
        MAX_LENGTH = 64,
        OVERLAP = 20,
        MAX_TEXT_TOKENS=200000,
        IGNORE_LABEL_CASE=True,
        EXCLUDE_NON_EXACT_LABEL_MATCH=True
    ):

        # Maximum number of words for each sentence
        self.MAX_LENGTH = MAX_LENGTH

        # If a sentence exceeds MAX_LENGTH, we split it to multiple sentences 
        # with overlapping
        self.OVERLAP = OVERLAP

        # During development, you may want to only load part of the data. Leave
        # uninitialized during production
        self.MAX_SAMPLE = MAX_SAMPLE

        self.DATA_PATH = DATA_PATH
        #self.DATA_PATH = \
        #    os.path.join(
        #        os.path.join(
        #            os.path.dirname(
        #                os.path.dirname(
        #                    os.path.dirname(
        #                        os.path.dirname(__file__)
        #                    )
        #                )
        #            ),
        #            'data'
        #        ), 
        #        'coleridgeinitiative-show-us-the-data'
        #    )
        self.TRAIN_DATA_PATH = os.path.join(self.DATA_PATH, 'train')
        self.TRAIN_DF_PATH = os.path.join(self.DATA_PATH, 'train.csv')
        self.TEST_DATA_PATH = os.path.join(self.DATA_PATH, 'test')

        # If SAVE is true, will save the extracted and/or the tokenized data
        # under the provided filename(s)
        self.SAVE = SAVE
        self.EXTRACTED_FILENAME = EXTRACTED_FILENAME
        self.TOKENIZED_FILENAME = TOKENIZED_FILENAME
        # Maximum amount of tokens in training texts. Longer texts will be discarded
        self.MAX_TEXT_TOKENS = MAX_TEXT_TOKENS
        # Whether the tagger should ignore the case of the label when matching labels to the text
        self.IGNORE_LABEL_CASE = IGNORE_LABEL_CASE
        # Whether to exclude texts that do not have a single one-on-one (case insensitve) label match
        self.EXCLUDE_NON_EXACT_LABEL_MATCH = EXCLUDE_NON_EXACT_LABEL_MATCH


In [311]:
class ParseUtils:

    @staticmethod
    def count_in_json(json_id, label, train_data_path):
        path_to_json = os.path.join(train_data_path, (json_id + '.json'))
        count_dict = {}
        with open(path_to_json, 'r') as f:
            json_decode = json.load(f)
            for data in json_decode:
                heading = data.get('section_title')
                content = data.get('text')
                count_dict[heading] = content.count(heading)
        return count_dict

    @staticmethod
    def shorten_sentences(sentences, max_len, overlap):
        short_sentences = []
        for sentence in sentences:
            words = sentence.split()
            if len(words) > max_len:
                for p in range(0, len(words), max_len - overlap):
                    short_sentences.append(' '.join(words[p:p + max_len]))
            else:
                short_sentences.append(sentence)
        return short_sentences

    @staticmethod
    def clean_training_text(txt):
        """
        similar to the default clean_text function but without lowercasing.
        """
        txt = re.sub('[^A-Za-z0-9]+', ' ', str(txt)).strip()

        return txt

    @staticmethod
    def find_sublist(big_list, small_list):
        all_positions = []
        for i in range(len(big_list) - len(small_list) + 1):
            if small_list == big_list[i:i + len(small_list)]:
                all_positions.append(i)

        return all_positions

    @staticmethod
    def tag_sentence(sentence, labels, ignore_case):  # requirement: both sentence and
        re_flags = re.IGNORECASE if ignore_case else None

        # labels are already cleaned
        sentence_words = sentence.split()

        if labels is not None and any(re.findall(f'\\b{label}\\b', sentence,
                                                 flags=re_flags) for label in labels):  # positive sample
            nes = ['O'] * len(sentence_words)
            for label in labels:
                label_words = label.split()

                if ignore_case:
                    nocase_label_words = list(map(lambda word: word.lower(), label_words))
                    nocase_sentence_words = list(map(lambda word: word.lower(), sentence_words))
                    all_pos = ParseUtils.find_sublist(nocase_sentence_words, nocase_label_words)
                else:
                    all_pos = ParseUtils.find_sublist(sentence_words, label_words)

                for pos in all_pos:
                    nes[pos] = 'B'
                    for i in range(pos + 1, pos + len(label_words)):
                        nes[i] = 'I'

            return True, list(zip(sentence_words, nes))

        else:  # negative sample
            nes = ['O'] * len(sentence_words)
            return False, list(zip(sentence_words, nes))

    @staticmethod
    def read_append_return(filename, train_data_path, output='text'):
        """
        Function to read json file and then return the text data from them and append to the dataframe

        Basically parse json but then from https://www.kaggle.com/prashansdixit/coleridge-initiative-eda-baseline-model
        """
        json_path = os.path.join(train_data_path, (filename + '.json'))
        headings = []
        contents = []
        combined = []
        with open(json_path, 'r') as f:
            json_decode = json.load(f)
            for data in json_decode:
                headings.append(data.get('section_title'))
                contents.append(data.get('text'))
                combined.append(data.get('section_title'))
                combined.append(data.get('text'))

        all_headings = ' '.join(headings)
        all_contents = ' '.join(contents)
        all_data = '. '.join(combined)

        if output == 'text':
            return all_contents
        elif output == 'head':
            return all_headings
        else:
            return all_data

    @staticmethod
    def save_extracted(ner_data, data_path, file_name):
        with open(os.path.join(data_path, file_name), 'w') as f:
            for row in ner_data:
                words, nes = list(zip(*row))
                row_json = {'tokens': words, 'tags': nes}
                json.dump(row_json, f)
                f.write('\n')

    @staticmethod
    def load_extracted(data_path, file_name):
        ner_data = []
        with open(os.path.join(data_path, file_name), 'r') as f:
            for line in f.readlines():
                # Each line is formatted in JSON format, e.g.
                # { "tokens" : ["A", "short", "sentence"],
                #   "tags"   : ["0", "0", "0"] }
                sentence = json.loads(line)

                # From the tokens and tags, we create a list of
                # tuples of the form
                # [ ("A", "0"), ("short", "0"), ("sentence", "0")]
                sentence_tuple_list = [
                    (token, tag) for token, tag
                    in zip(sentence["tokens"], sentence["tags"])
                ]

                # Each of these parsed sentences becomes an entry
                # in our overall data list
                ner_data.append(sentence_tuple_list)

        return ner_data

    @staticmethod
    def load_auxiliary_datasets(data_path, file_name):
        with open(os.path.join(data_path, file_name), 'r') as f:
            return f.read().split('\n')

    @staticmethod
    def load_tokenized_auxiliary_datasets(data_path, file_name):
        with open(os.path.join(data_path, file_name), 'r') as f:
            datasets = f.read().split('\n')
            return [[int(item) for item in row.split(',') if item != ''] for row in datasets if len(row) > 0]

    @staticmethod
    def save_file(output, data_path, file_name):
        with open(os.path.join(data_path, file_name), 'wb') as f:
            pickle.dump(output, f)

    @staticmethod
    def load_file(data_path, file_name):
        with open(os.path.join(data_path, file_name), 'rb') as f:
            output = pickle.load(f)
        return output

    @staticmethod
    def all_labels_mentioned(data):
        """
        Method that can be applied to a dataframe and check, for all dataset labels, if they occur in the text at least
        once. Case insensitive
        """
        labels = data['dataset_label'].split("|")
        return all(list(map(lambda label: data['text'].lower().count(label.lower()) > 0, labels)))

    @staticmethod
    def extract(
            max_len,
            overlap,
            max_sample,
            max_text_tokens,
            train_df_path,
            train_data_path,
            ignore_label_case,
            exclude_non_exact_label_match

    ):
        """
        Reads the training data from storage using the train.csv file as well
        as all json files inside the train folder, and computes a list,
        where each element is a sentence. Each sentence is itself a list,
        consisting of tuples, where the first element is the word (token) and
        the second is the label (tag).

        This is an example of the data list returned:

            ner_data = [
                ...
                [
                    ("This", "0"),
                    ("is", "0"),
                    ("New", "LOC"),
                    ("York", "LOC"),
                ],
                ...
            ]

        If `save` is True, the data will be stored on disk in the DATA_PATH
        directory in a single text file, where each line is in JSON format, e.g.

            { "tokens" : ["A", "short", "sentence"], "tags" : ["0", "0", "0"] }
        """

        # Read data in CSV file
        train = pd.read_csv(train_df_path)
        train = train[:max_sample]
        print(f'Found {len(train)} raw training rows')

        # Group rows by publication ID
        train = train.groupby('Id').agg({
            'pub_title': 'first',
            'dataset_title': '|'.join,
            'dataset_label': '|'.join,
            'cleaned_label': '|'.join
        }).reset_index()
        print(f'Found {len(train)} unique training rows')

        print('Loading texts, this might take a while...')
        # Read texts for text length analysis
        train['text'] = train['Id'].apply(lambda ID: ParseUtils.read_append_return(ID, train_data_path))
        train['text_token_length'] = train['text'].apply(lambda text: len(text))

        # Remove texts that have more tokens than max_text_tokens
        train = train[train['text_token_length'] <= max_text_tokens]
        print(f'Removed texts exceeding max length, {len(train)} training rows left')

        if exclude_non_exact_label_match:
            # Count label mentions in text
            train["all_labels_mentioned"] = train.apply(ParseUtils.all_labels_mentioned, axis=1)

            # Remove texts that have 0 label count for at least 1 label
            train = train[train['all_labels_mentioned']]
            print(f'Removed texts that had at least one label with 0 exact (case insensitive) matches in the text, '
                  f'{len(train)} training rows left')

        # Read individual papers by ID from storage
        papers = {}
        for paper_id in train['Id'].unique():
            with open(f'{train_data_path}/{paper_id}.json', 'r') as f:
                paper = json.load(f)
                papers[paper_id] = paper

        cnt_pos, cnt_neg = 0, 0  # number of sentences that contain/not contain labels
        ner_data = []

        pbar = tqdm(total=len(train))
        for i, id, dataset_label in train[['Id', 'dataset_label']].itertuples():
            # paper
            paper = papers[id]

            # labels
            labels = dataset_label.split('|')
            labels = [ParseUtils.clean_training_text(label) for label in labels]

            # sentences
            sentences = set([
                ParseUtils.clean_training_text(sentence)
                for section in paper
                for sentence in section['text'].split('.')
            ])
            sentences = ParseUtils.shorten_sentences(
                sentences, max_len, overlap)

            # only accept sentences with length > 10 chars
            sentences = [sentence for sentence in sentences if len(sentence) > 10]

            # positive sample
            for sentence in sentences:
                is_positive, tags = ParseUtils.tag_sentence(sentence, labels, ignore_label_case)
                if is_positive:
                    cnt_pos += 1
                else:
                    cnt_neg += 1
                ner_data.append(tags)

            # process bar
            pbar.update(1)
            pbar.set_description(f"Training data size: {cnt_pos} positives + {cnt_neg} negatives")

        # shuffling
        # random.shuffle(ner_data)

        return ner_data


In [312]:
data_path ='../input/showusthedata-tokenized'

In [313]:
configs = PipelineConfigs(
    DATA_PATH = data_path,
    MAX_LENGTH = 64,
    OVERLAP = 20,
    MAX_SAMPLE = None,
    SAVE = False,
    EXTRACTED_FILENAME = 'train_ner.data',
    TOKENIZED_FILENAME = 'train_ner.data.scibert-tokenized',
)

In [314]:
pipeline = Pipeline(configs)

In [315]:
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_cased', do_lower_case=False)

In [316]:
o_id = tokenizer.convert_tokens_to_ids("O")
i_id = tokenizer.convert_tokens_to_ids("I")

In [317]:
input_ids, tags, attention_mask = pipeline.load_outputs()
#mention_positions = pipeline.load_mention_pos()
auxiliary_datasets = ParseUtils.load_tokenized_auxiliary_datasets(data_path, 'auxiliary_datasets_scibert_tokenized.txt')

In [318]:
mention_positions = [[i_word for i_word, tag in enumerate(sentence) if tag != o_id] for i_sentence, sentence in enumerate(tags)]

In [319]:
# Split multiple mentions in sentence
for i, sentence in enumerate(mention_positions):
    if len(sentence) > 0:
        indices = []
        previous = sentence[0]
        for sent_i, item in enumerate(sentence[1:]):
            if item != previous + 1:
                indices.append(sent_i+1)
            previous = item
        
        indices.append(len(sentence))
        
        new_list = []
        prev_index = 0
        for index in indices:
            new_list.append(sentence[prev_index:index])
            prev_index = index
            
        mention_positions[i] = new_list

In [320]:
tag_values=np.unique(tags)

In [321]:
tag2id = {t:i for i, t in enumerate(np.unique(tags))}
tags = [[tag2id[tag] for tag in sent] for sent in tags]

In [322]:
# Train-validation split
val_size = 0.1
permutation = np.random.RandomState(seed=2018).permutation(len(input_ids))
val_ind = math.floor(val_size * len(input_ids))

tr_inputs, tr_tags, tr_masks, tr_mention_pos = [input_ids[idx] for idx in permutation[val_ind:]], [tags[idx] for idx in permutation[val_ind:]], [attention_mask[idx] for idx in permutation[val_ind:]], [mention_positions[idx] for idx in permutation[val_ind:]]
val_inputs, val_tags, val_masks = [input_ids[idx] for idx in permutation[:val_ind]], [tags[idx] for idx in permutation[:val_ind]], [attention_mask[idx] for idx in permutation[:val_ind]],

In [323]:
# tr_inputs, tr_tags, tr_masks, tr_mention_pos  = input_ids, tags, attention_mask, mention_positions

In [324]:
tr_inputs = torch.tensor(tr_inputs).to(device)
val_inputs = torch.tensor(val_inputs).to(device)
tr_tags = torch.tensor(tr_tags).to(device)
val_tags = torch.tensor(val_tags).to(device)
tr_masks = torch.tensor(tr_masks).to(device)
val_masks = torch.tensor(val_masks).to(device)

auxiliary_datasets = [torch.tensor(dataset).to(device) for dataset in auxiliary_datasets]

In [325]:
class BERTDataset(torch.utils.data.Dataset):
    def __init__(self, input_ids, tags, attention_mask, mention_positions, auxiliary_datasets, transform=None):

        self.input_ids = input_ids
        self.tags = tags
        self.attention_mask = attention_mask
        self.mention_positions = mention_positions
        self.auxiliary_datasets = auxiliary_datasets
        self.transform = transform

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        sample =  self.input_ids[idx], self.tags[idx], self.attention_mask[idx]
        
        if self.transform:
            sample = self.transform(sample, self.mention_positions[idx], self.auxiliary_datasets)

        return sample

In [326]:
class MentionSwapAugment(object):

    def __init__(self, chance, o_id, i_id):
        self.chance = chance
        self.o_id = o_id
        self.i_id = i_id

    def __call__(self, sample, mention_positions, auxiliary_datasets):
        input_ids, tags, attention_mask = sample
        if random.random() <= self.chance and len(mention_positions) > 0:
            for ment_pos in mention_positions:
                if len(ment_pos) > 1:
                    new_dataset = auxiliary_datasets[random.randrange(0, len(auxiliary_datasets))]


                    new_dataset_len = len(new_dataset)
                    if len(new_dataset) == len(ment_pos):
                        input_ids[ment_pos] = new_dataset
                    else:
                        length_diff = new_dataset_len - len(ment_pos)
#                         print(tokenizer.convert_ids_to_tokens(input_ids))
#                         print(tags)
#                         print(attention_mask)
#                         print(ment_pos)
                    
                        if length_diff < 0:
                            return self.augment_with_smaller_label(input_ids, tags, attention_mask, new_dataset, ment_pos, length_diff, new_dataset_len)
                        else:
                            if ment_pos[0] + new_dataset_len <= len(input_ids):
                                return self.augment_with_bigger_label(input_ids, tags, attention_mask, new_dataset, ment_pos, length_diff, new_dataset_len)

                    
        return input_ids, tags, attention_mask
    
    
    def augment_with_smaller_label(self, input_ids, tags, attention_mask, new_dataset, mention_positions, length_diff, new_dataset_len):
        # Set new tokens
        new_token_positions = mention_positions[:new_dataset_len]
        input_ids[new_token_positions] = new_dataset

        # Roll back rest of sentence and tags
        input_ids[new_token_positions[-1]+1:length_diff] = input_ids[new_token_positions[-1] + 1 - length_diff:].clone()
        tags[new_token_positions[-1]+1:length_diff] = tags[new_token_positions[-1] + 1 - length_diff:].clone()
        attention_mask[new_token_positions[-1]+1:length_diff] = attention_mask[new_token_positions[-1] + 1 - length_diff:].clone()

        # Fill ends (padding and O token)
        input_ids[length_diff:] = 0
        tags[length_diff:] = self.o_id
        attention_mask[length_diff:] = 0

#         print(tokenizer.convert_ids_to_tokens(input_ids))
#         print(tags)
#         print(attention_mask)

#         print("Smaller tag filled in")
#         raise Exception("FINISH")
        
        return input_ids, tags, attention_mask
    
    
    def augment_with_bigger_label(self, input_ids, tags, attention_mask, new_dataset, mention_positions, length_diff, new_dataset_len):
        # New token positions
        new_token_positions = torch.arange(mention_positions[0], mention_positions[0]+new_dataset_len)


        # Roll sentence away
        input_ids[new_token_positions[-1]+1:] = input_ids[mention_positions[-1]+1:-length_diff].clone()
        tags[new_token_positions[-1]+1:] = tags[mention_positions[-1]+1:-length_diff].clone()
        attention_mask[new_token_positions[-1]+1:] = attention_mask[mention_positions[-1]+1:-length_diff].clone()

        # Set new tokens
        input_ids[new_token_positions] = new_dataset
        
        # Continue with I token!
        tags[mention_positions[-1]+1:mention_positions[-1]+1+length_diff] = self.i_id
        attention_mask[mention_positions[-1]+1:mention_positions[-1]+1+length_diff] = 1

#         print(tokenizer.convert_ids_to_tokens(input_ids))
#         print(tags)
#         print(attention_mask)

#         print("Bigger tag filled in")
#         raise Exception("FINISH")
        
        return input_ids, tags, attention_mask
        

In [327]:
BATCH_SIZE = 128

train_dataset = BERTDataset(tr_inputs, tr_tags, tr_masks, tr_mention_pos, auxiliary_datasets, transform=MentionSwapAugment(0.5, o_id=tag2id[o_id], i_id=tag2id[i_id]))
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=BATCH_SIZE)

# Better Evaluation Metric
The following code cell implements an improved F1 score for NER evaluation.

For more details, please see http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/

The code is taken from https://raw.githubusercontent.com/davidsbatista/NER-Evaluation/master/ner_evaluation/ner_eval.py

In [328]:
import logging
from collections import namedtuple
from copy import deepcopy

logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level="DEBUG",
)

Entity = namedtuple("Entity", "e_type start_offset end_offset")

class Evaluator():

    def __init__(self, true, pred, tags):
        """
        """

        if len(true) != len(pred):
            raise ValueError("Number of predicted documents does not equal true")

        self.true = true
        self.pred = pred
        self.tags = tags

        # Setup dict into which metrics will be stored.

        self.metrics_results = {
            'correct': 0,
            'incorrect': 0,
            'partial': 0,
            'missed': 0,
            'spurious': 0,
            'possible': 0,
            'actual': 0,
            'precision': 0,
            'recall': 0,
        }

        # Copy results dict to cover the four schemes.

        self.results = {
            'strict': deepcopy(self.metrics_results),
            'ent_type': deepcopy(self.metrics_results),
            'partial':deepcopy(self.metrics_results),
            'exact':deepcopy(self.metrics_results),
            }

        # Create an accumulator to store results

        self.evaluation_agg_entities_type = {e: deepcopy(self.results) for e in tags}


    def evaluate(self):

        logging.info(
            "Imported %s predictions for %s true examples",
            len(self.pred), len(self.true)
        )

        for true_ents, pred_ents in zip(self.true, self.pred):

            # Check that the length of the true and predicted examples are the
            # same. This must be checked here, because another error may not
            # be thrown if the lengths do not match.

            if len(true_ents) != len(pred_ents):
                raise ValueError("Prediction length does not match true example length")

            # Compute results for one message

            tmp_results, tmp_agg_results = compute_metrics(
                collect_named_entities(true_ents),
                collect_named_entities(pred_ents),
                self.tags
            )

            # Cycle through each result and accumulate

            # TODO: Combine these loops below:

            for eval_schema in self.results:

                for metric in self.results[eval_schema]:

                    self.results[eval_schema][metric] += tmp_results[eval_schema][metric]

            # Calculate global precision and recall

            self.results = compute_precision_recall_wrapper(self.results)

            # Aggregate results by entity type

            for e_type in self.tags:

                for eval_schema in tmp_agg_results[e_type]:

                    for metric in tmp_agg_results[e_type][eval_schema]:

                        self.evaluation_agg_entities_type[e_type][eval_schema][metric] += tmp_agg_results[e_type][eval_schema][metric]

                # Calculate precision recall at the individual entity level

                self.evaluation_agg_entities_type[e_type] = compute_precision_recall_wrapper(self.evaluation_agg_entities_type[e_type])

        return self.results, self.evaluation_agg_entities_type


def collect_named_entities(tokens):
    """
    Creates a list of Entity named-tuples, storing the entity type and the start and end
    offsets of the entity.

    :param tokens: a list of tags
    :return: a list of Entity named-tuples
    """

    named_entities = []
    start_offset = None
    end_offset = None
    ent_type = None

    for offset, token_tag in enumerate(tokens):

        if token_tag == 'O':
            if ent_type is not None and start_offset is not None:
                end_offset = offset - 1
                named_entities.append(Entity(ent_type, start_offset, end_offset))
                start_offset = None
                end_offset = None
                ent_type = None

        elif ent_type is None:
            ent_type = token_tag[2:]
            start_offset = offset

        elif ent_type != token_tag[2:] or (ent_type == token_tag[2:] and token_tag[:1] == 'B'):

            end_offset = offset - 1
            named_entities.append(Entity(ent_type, start_offset, end_offset))

            # start of a new entity
            ent_type = token_tag[2:]
            start_offset = offset
            end_offset = None

    # catches an entity that goes up until the last token

    if ent_type is not None and start_offset is not None and end_offset is None:
        named_entities.append(Entity(ent_type, start_offset, len(tokens)-1))

    return named_entities


def compute_metrics(true_named_entities, pred_named_entities, tags):


    eval_metrics = {'correct': 0, 'incorrect': 0, 'partial': 0, 'missed': 0, 'spurious': 0, 'precision': 0, 'recall': 0}

    # overall results
    
    evaluation = {
        'strict': deepcopy(eval_metrics),
        'ent_type': deepcopy(eval_metrics),
        'partial': deepcopy(eval_metrics),
        'exact': deepcopy(eval_metrics)
    }

    # results by entity type

    evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags}

    # keep track of entities that overlapped

    true_which_overlapped_with_pred = []

    # Subset into only the tags that we are interested in.
    # NOTE: we remove the tags we don't want from both the predicted and the
    # true entities. This covers the two cases where mismatches can occur:
    #
    # 1) Where the model predicts a tag that is not present in the true data
    # 2) Where there is a tag in the true data that the model is not capable of
    # predicting.

    true_named_entities = [ent for ent in true_named_entities if ent.e_type in tags]
    pred_named_entities = [ent for ent in pred_named_entities if ent.e_type in tags]

    # go through each predicted named-entity

    for pred in pred_named_entities:
        found_overlap = False

        # Check each of the potential scenarios in turn. See
        # http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/
        # for scenario explanation.

        # Scenario I: Exact match between true and pred

        if pred in true_named_entities:
            true_which_overlapped_with_pred.append(pred)
            evaluation['strict']['correct'] += 1
            evaluation['ent_type']['correct'] += 1
            evaluation['exact']['correct'] += 1
            evaluation['partial']['correct'] += 1

            # for the agg. by e_type results
            evaluation_agg_entities_type[pred.e_type]['strict']['correct'] += 1
            evaluation_agg_entities_type[pred.e_type]['ent_type']['correct'] += 1
            evaluation_agg_entities_type[pred.e_type]['exact']['correct'] += 1
            evaluation_agg_entities_type[pred.e_type]['partial']['correct'] += 1

        else:

            # check for overlaps with any of the true entities

            for true in true_named_entities:

                pred_range = range(pred.start_offset, pred.end_offset)
                true_range = range(true.start_offset, true.end_offset)

                # Scenario IV: Offsets match, but entity type is wrong

                if true.start_offset == pred.start_offset and pred.end_offset == true.end_offset \
                        and true.e_type != pred.e_type:

                    # overall results
                    evaluation['strict']['incorrect'] += 1
                    evaluation['ent_type']['incorrect'] += 1
                    evaluation['partial']['correct'] += 1
                    evaluation['exact']['correct'] += 1

                    # aggregated by entity type results
                    evaluation_agg_entities_type[true.e_type]['strict']['incorrect'] += 1
                    evaluation_agg_entities_type[true.e_type]['ent_type']['incorrect'] += 1
                    evaluation_agg_entities_type[true.e_type]['partial']['correct'] += 1
                    evaluation_agg_entities_type[true.e_type]['exact']['correct'] += 1

                    true_which_overlapped_with_pred.append(true)
                    found_overlap = True

                    break

                # check for an overlap i.e. not exact boundary match, with true entities

                elif find_overlap(true_range, pred_range):

                    true_which_overlapped_with_pred.append(true)

                    # Scenario V: There is an overlap (but offsets do not match
                    # exactly), and the entity type is the same.
                    # 2.1 overlaps with the same entity type

                    if pred.e_type == true.e_type:

                        # overall results
                        evaluation['strict']['incorrect'] += 1
                        evaluation['ent_type']['correct'] += 1
                        evaluation['partial']['partial'] += 1
                        evaluation['exact']['incorrect'] += 1

                        # aggregated by entity type results
                        evaluation_agg_entities_type[true.e_type]['strict']['incorrect'] += 1
                        evaluation_agg_entities_type[true.e_type]['ent_type']['correct'] += 1
                        evaluation_agg_entities_type[true.e_type]['partial']['partial'] += 1
                        evaluation_agg_entities_type[true.e_type]['exact']['incorrect'] += 1

                        found_overlap = True

                        break

                    # Scenario VI: Entities overlap, but the entity type is
                    # different.

                    else:
                        # overall results
                        evaluation['strict']['incorrect'] += 1
                        evaluation['ent_type']['incorrect'] += 1
                        evaluation['partial']['partial'] += 1
                        evaluation['exact']['incorrect'] += 1

                        # aggregated by entity type results
                        # Results against the true entity

                        evaluation_agg_entities_type[true.e_type]['strict']['incorrect'] += 1
                        evaluation_agg_entities_type[true.e_type]['partial']['partial'] += 1
                        evaluation_agg_entities_type[true.e_type]['ent_type']['incorrect'] += 1
                        evaluation_agg_entities_type[true.e_type]['exact']['incorrect'] += 1

                        # Results against the predicted entity

                        # evaluation_agg_entities_type[pred.e_type]['strict']['spurious'] += 1

                        found_overlap = True

                        break

            # Scenario II: Entities are spurious (i.e., over-generated).

            if not found_overlap:

                # Overall results

                evaluation['strict']['spurious'] += 1
                evaluation['ent_type']['spurious'] += 1
                evaluation['partial']['spurious'] += 1
                evaluation['exact']['spurious'] += 1

                # Aggregated by entity type results

                # NOTE: when pred.e_type is not found in tags
                # or when it simply does not appear in the test set, then it is
                # spurious, but it is not clear where to assign it at the tag
                # level. In this case, it is applied to all target_tags
                # found in this example. This will mean that the sum of the
                # evaluation_agg_entities will not equal evaluation.

                for true in tags:                    

                    evaluation_agg_entities_type[true]['strict']['spurious'] += 1
                    evaluation_agg_entities_type[true]['ent_type']['spurious'] += 1
                    evaluation_agg_entities_type[true]['partial']['spurious'] += 1
                    evaluation_agg_entities_type[true]['exact']['spurious'] += 1

    # Scenario III: Entity was missed entirely.

    for true in true_named_entities:
        if true in true_which_overlapped_with_pred:
            continue
        else:
            # overall results
            evaluation['strict']['missed'] += 1
            evaluation['ent_type']['missed'] += 1
            evaluation['partial']['missed'] += 1
            evaluation['exact']['missed'] += 1

            # for the agg. by e_type
            evaluation_agg_entities_type[true.e_type]['strict']['missed'] += 1
            evaluation_agg_entities_type[true.e_type]['ent_type']['missed'] += 1
            evaluation_agg_entities_type[true.e_type]['partial']['missed'] += 1
            evaluation_agg_entities_type[true.e_type]['exact']['missed'] += 1

    # Compute 'possible', 'actual' according to SemEval-2013 Task 9.1 on the
    # overall results, and use these to calculate precision and recall.

    for eval_type in evaluation:
        evaluation[eval_type] = compute_actual_possible(evaluation[eval_type])

    # Compute 'possible', 'actual', and precision and recall on entity level
    # results. Start by cycling through the accumulated results.

    for entity_type, entity_level in evaluation_agg_entities_type.items():

        # Cycle through the evaluation types for each dict containing entity
        # level results.

        for eval_type in entity_level:

            evaluation_agg_entities_type[entity_type][eval_type] = compute_actual_possible(
                entity_level[eval_type]
            )

    return evaluation, evaluation_agg_entities_type


def find_overlap(true_range, pred_range):
    """Find the overlap between two ranges

    Find the overlap between two ranges. Return the overlapping values if
    present, else return an empty set().

    Examples:

    >>> find_overlap((1, 2), (2, 3))
    2
    >>> find_overlap((1, 2), (3, 4))
    set()
    """

    true_set = set(true_range)
    pred_set = set(pred_range)

    overlaps = true_set.intersection(pred_set)

    return overlaps


def compute_actual_possible(results):
    """
    Takes a result dict that has been output by compute metrics.
    Returns the results dict with actual, possible populated.

    When the results dicts is from partial or ent_type metrics, then
    partial_or_type=True to ensure the right calculation is used for
    calculating precision and recall.
    """

    correct = results['correct']
    incorrect = results['incorrect']
    partial = results['partial']
    missed = results['missed']
    spurious = results['spurious']

    # Possible: number annotations in the gold-standard which contribute to the
    # final score

    possible = correct + incorrect + partial + missed

    # Actual: number of annotations produced by the NER system

    actual = correct + incorrect + partial + spurious

    results["actual"] = actual
    results["possible"] = possible

    return results


def compute_precision_recall(results, partial_or_type=False):
    """
    Takes a result dict that has been output by compute metrics.
    Returns the results dict with precison and recall populated.

    When the results dicts is from partial or ent_type metrics, then
    partial_or_type=True to ensure the right calculation is used for
    calculating precision and recall.
    """

    actual = results["actual"]
    possible = results["possible"]
    partial = results['partial']
    correct = results['correct']

    if partial_or_type:
        precision = (correct + 0.5 * partial) / actual if actual > 0 else 0
        recall = (correct + 0.5 * partial) / possible if possible > 0 else 0

    else:
        precision = correct / actual if actual > 0 else 0
        recall = correct / possible if possible > 0 else 0

    results["precision"] = precision
    results["recall"] = recall

    return results


def compute_precision_recall_wrapper(results):
    """
    Wraps the compute_precision_recall function and runs on a dict of results
    """

    results_a = {key: compute_precision_recall(value, True) for key, value in results.items() if
                 key in ['partial', 'ent_type']}
    results_b = {key: compute_precision_recall(value) for key, value in results.items() if
                 key in ['strict', 'exact']}

    results = {**results_a, **results_b}

    return results



# Defining Model

In [329]:
model_path = '../input/pretrained-language-models/SciBERT_scivocab_cased'
#finetuned_path = '../input/show-us-the-data-bert-weights/SciBERT_Finetuned_5Eps'

In [330]:
model = AutoModelForTokenClassification.from_pretrained(
    model_path,
    num_labels=3,
    output_attentions = False,
    output_hidden_states = False
).to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at ../input/pretrained-language-models/SciBERT_scivocab_cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [331]:
# Full finetuning to tune all model parameters
# Otherwise, only train classifier
FULL_FINETUNING = True

if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]

In [332]:
optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=3e-5,
    eps=1e-8
)

In [333]:
from transformers import get_linear_schedule_with_warmup
epochs = 7
max_grad_norm = 1.0

# Total number of training steps is number of batches * number of epochs.
total_steps = len(train_dataloader) * epochs

# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

In [334]:
for param in model.base_model.parameters():
    param.requires_grad = True

# Training

In [335]:
# Put the attention mask on CPU
val_masks = val_masks.cpu().numpy()

In [336]:
loss_values, validation_loss_values = [], []
n_train_steps = len(train_dataloader)
n_valid_steps = len(valid_dataloader)

for epoch_idx in range(epochs):
    # ========================================
    #               Training
    # ========================================
    # Perform one full pass over the training set.

    print(f'Epoch {epoch_idx+1}/{epochs}')
    
    # Put the model into training mode.
    model.train()
    # Reset the total loss for this epoch.
    total_loss = 0
    
    # Training loop
    pbar = Progbar(n_train_steps)
    for step, batch in enumerate(train_dataloader):
        #print(f'Train step {step}/{n_train_steps}\r', end='')
        
        # add batch to gpu
        #batch = tuple(t.to('cpu') for t in batch)
        b_input_ids, b_labels, b_input_mask  = batch
        # Always clear any previously calculated gradients before performing a backward pass.
        model.zero_grad()
        # forward pass
        # This will return the loss (rather than the model output)
        # because we have provided the `labels`.
        outputs = model(b_input_ids,
                        attention_mask=b_input_mask, 
                        labels=b_labels)
        # get the loss
        loss = outputs.loss
        # Perform a backward pass to calculate the gradients.
        loss.backward()
        # track train loss
        total_loss += loss.item()
        # Clip the norm of the gradient
        # This is to help prevent the "exploding gradients" problem.
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        # update parameters
        optimizer.step()
        # Update the learning rate.
        scheduler.step()
        
        pbar.update(step+1), [('Train loss', loss.item())])
  
    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_dataloader)
    #print("Train Loss: {:.8f}".format(avg_train_loss))

    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)


    # ========================================
    #               Validation
    # ========================================
    # After the completion of each training epoch, measure our performance on
    # our validation set.

    # Put the model into evaluation mode
    model.eval()
    # Reset the validation loss for this epoch.
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0
    predictions , true_labels = [], []
    
    pbar = Progbar(n_valid_steps)
    for step, batch in enumerate(valid_dataloader):

        #print(f'Valid step {step}/{n_valid_steps}\r', end='')

        #batch = tuple(t.to('cpu') for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        # Telling the model not to compute or store gradients,
        # saving memory and speeding up validation
        with torch.no_grad():
            # Forward pass, calculate logit predictions.
            # This will return the logits rather than the loss because we have not provided labels.
            outputs = model(b_input_ids,
                            attention_mask=b_input_mask, 
                            labels=b_labels)
        # Move logits and labels to CPU
        logits = outputs[1].detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        # Calculate the accuracy for this batch of test sentences.
        loss_val = outputs[0].mean().item()
        eval_loss += loss_val
        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
        true_labels.extend(label_ids)

        pbar.update(step+1, [('Valid loss', loss_val)])
       
    eval_loss = eval_loss / len(valid_dataloader)
    validation_loss_values.append(eval_loss)
    #print("Valid Loss: {:.8f}".format(eval_loss))
    
    # Remove padding from predictions
    masked_preds = []
    for i in range(len(predictions)):
        masked_preds.append(
            predictions[i][
                :len(np.argwhere(val_masks[0]))
            ]
        )

    # Remove padding from labels
    masked_labels = []
    for i in range(len(true_labels)):
        masked_labels.append(
            true_labels[i][
                :len(np.argwhere(val_masks[0]))
            ]
        )
    
    # Flatten both
    flat_masked_preds = np.reshape(masked_preds, -1)
    flat_masked_labels = np.reshape(masked_labels, -1)
    
    # Compute accuracy
    acc = accuracy_score(flat_masked_preds, flat_masked_labels)
    print("Valid Accuracy: {:2.2f}%".format(100*acc))
    
    # Compute ordinary F1
    f1 = f1_score(flat_masked_preds, flat_masked_labels, average=None)
    print("Valid F1-Score: {:2.2f}%, {:2.2f}%, {:2.2f}%".format(*f1*100))
        
    # Compute improved F1
    ids2tag = {0:'B-DATA', 1:'I-DATA', 2:'O'} # Need to convert to String IDs
    masked_preds_strids = [ [ ids2tag[token] for token in sent ] for sent in masked_preds ]
    masked_labels_strids = [ [ ids2tag[token] for token in sent ] for sent in masked_labels ]
    evaluator = Evaluator(masked_labels_strids, masked_preds_strids, ['DATA'])
    results, _ = evaluator.evaluate() # We don't need the second, aggregated results 
                                      # because we only have one kind of entity
    # Add ordinary metrics to results
    results['ordinary'] = { 
        'accuracy': acc, 
        'f1_B': f1[1],
        'f1_I': f1[0],
        'f1_O': f1[2]
    }

    print("SemEval Metrics:")
    for key in ['strict', 'exact', 'ent_type', 'partial']:
        precision = results[key]['precision']
        recall = results[key]['recall']
        imp_f1 = 2 * ( (precision*recall) / (precision+recall) )
        print('  '+key+':')
        print("    Precision: {:2.2f}%".format(precision*100))
        print("    Recall:    {:2.2f}%".format(recall*100))
        print("    F1-Score:  {:2.2f}%".format(imp_f1*100))
        results[key]['f1'] = imp_f1
    
    model.save_pretrained(f"/kaggle/working/SciBERT_Finetuned_(augmentation)_{epoch_idx}Eps")
    with open(f'metrics_{epoch_idx}.json', 'w') as f:
        json.dump(results, f)

Epoch 1/7
Epoch 2/7
Epoch 3/7
Epoch 4/7
Epoch 5/7
Epoch 6/7
Epoch 7/7
