This notebook illustrates how to use Masked Language Modeling for this competition.

Observation: most dataset names consist only of capitalized words plus a few stopwords such as `on`, `in`, and `and` (e.g. `Early Childhood Longitudinal Study`, `Trends in International Mathematics and Science Study`).

Thus, one approach to find the datasets is: 
- Locate all sequences of capitalized words (these sequences may contain some stopwords).
- Mark each sequence with one of 2 special symbols indicating whether or not it is a dataset name: in the code below a sequence is wrapped as `@ <sequence> $` (dataset) or `@ <sequence> #` (not a dataset).
- Train the model on a Masked Language Modeling task in which it must predict the masked symbol.

The code below shows how to train a model for that purpose with the help of the Hugging Face `transformers` and `datasets` libraries.

In [None]:
# Rows kept from each training source: a small integer (e.g. 50) for quick
# experiments, None to use everything (production).
MAX_SAMPLE = None

# Install packages

In [None]:
!pip install datasets --no-index --find-links=file:///kaggle/input/coleridge-packages/packages/datasets
!pip install ../input/coleridge-packages/seqeval-1.2.2-py3-none-any.whl
!pip install ../input/coleridge-packages/tokenizers-0.10.1-cp37-cp37m-manylinux1_x86_64.whl
!pip install ../input/coleridge-packages/transformers-4.5.0.dev0-py3-none-any.whl

# Import

In [None]:
import os
import re
import json
import time
import datetime
import random
import glob
import importlib

import numpy as np
import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, \
AutoModelForMaskedLM, Trainer, TrainingArguments, pipeline, AutoConfig, \
AutoModelForSequenceClassification, DataCollatorWithPadding

sns.set()
# Seed every RNG the notebook uses: `random`, numpy's global state (used by
# np.random.rand / np.random.choice during sample generation) and torch (weights
# initialised by from_pretrained before the Trainer applies its own seed).
random.seed(123)
np.random.seed(456)
torch.manual_seed(789)

# True: build '@ phrase $/#' samples and train a masked LM.
# False: build short context windows and train a sequence classifier.
MASKEDLM = True
# True: run spaCy NER to propose extra negative candidates.
NER = False

In [None]:
def filter_label(label_list):
    """
    Keep only mention strings that look like proper multi-word dataset names.

    A label is dropped when it
    - has fewer than two words (this includes empty / whitespace-only strings),
    - has exactly two words that are not both all-uppercase,
    - contains any word starting with a lowercase letter, or
    - contains a soft hyphen character (U+00AD).

    Returns the surviving labels in their original order.
    """
    out = []
    for label in label_list:
        words = label.split()
        if len(words) < 2:
            continue
        if len(words) == 2 and not all(word.isupper() for word in words):
            continue
        if any(word[0].islower() for word in words):
            continue
        if '\xad' in label:
            continue
        out.append(label)
    return out

import itertools
rccDf = pd.read_json('../input/rccfull/train_test/data_set_citations.json')[['publication_id','mention_list']]
rccDf['mention_list'] = rccDf['mention_list'].map(filter_label)
# Every distinct mention (across all publications) that survived filter_label.
newLabels = list(set(itertools.chain.from_iterable(rccDf['mention_list'])))

In [None]:
def clean_text(txt):
    """Lower-case str(txt), turn each run of non-alphanumeric characters into one space, and strip the ends."""
    lowered = str(txt).lower()
    return re.sub('[^A-Za-z0-9]+', ' ', lowered).strip()

def clean_text_multiple(line):
    """Apply clean_text to every '|'-separated field of `line`, keeping the '|' separators."""
    fields = line.split('|')
    return '|'.join(map(clean_text, fields))
    

# Rebuild the RCC table from disk so this cell can be re-run on its own.
rccDf = pd.read_json('../input/rccfull/train_test/data_set_citations.json')[['publication_id','mention_list']]
rccDf['mention_list'] = rccDf['mention_list'].map(filter_label)
rccDf['mention_list'] = rccDf['mention_list'].map('|'.join)
# Align column names with the Coleridge train.csv so the two tables can be concatenated.
rccDf = rccDf.rename(columns = {'publication_id':'Id', 'mention_list':'dataset_label'})
rccDf['cleaned_label'] = rccDf['dataset_label'].map(clean_text_multiple)
rccDf['dataset_title'] = rccDf['dataset_label']
rccDf['pub_title'] = ''
# Drop publications with no surviving label. Boolean-mask selection does not
# depend on the index being a RangeIndex.
rccDf = rccDf[rccDf['dataset_label'] != '']

In [None]:
# Inspect the cleaned RCC citation rows.
rccDf

In [None]:
# Sentence windowing: at most MAX_LENGTH words per sample; consecutive windows
# of a long sentence share OVERLAP words.
MAX_LENGTH, OVERLAP = 64, 20

# Marker appended after a phrase: '$' = dataset name, '#' = not a dataset name.
DATASET_SYMBOL, NONDATA_SYMBOL = '$', '#'

# Truthy -> hold out part of the corpus for validation.
VAL = 1

# Load data

In [None]:
# train
train_path = '../input/coleridgeinitiative-show-us-the-data/train.csv'
paper_train_folder = '../input/coleridgeinitiative-show-us-the-data/train'

train = pd.read_csv(train_path)
train = train[:MAX_SAMPLE]
rccDf = rccDf[:MAX_SAMPLE]
# Coleridge ids are strings; RCC publication ids are not (presumably ints). The
# extraction loop relies on this to choose between JSON and plain-text sources.
train = pd.concat([rccDf, train], ignore_index=True)


# Group by publication, training labels should have the same form as expected output.
train = train.groupby('Id').agg({
    'pub_title': 'first',
    'dataset_title': '|'.join,
    'dataset_label': '|'.join,
    'cleaned_label': '|'.join
}).reset_index()

print('train size: ', len(train))

# NOTE(review): allow_pickle=True unpickles the file contents; only load trusted inputs.
existing_labels = set(np.load('../input/showdata-labels1/existing_labels.npy', allow_pickle = True).tolist())
# Word lists for the Jaccard filters in find_negative_candidates, which lower-cases
# the candidate words before comparing. Lower-case the labels too; otherwise the
# capitalised newLabels could never overlap with a candidate.
existing_labels_list = [lbl.lower().split() for lbl in existing_labels]
existing_labels_list = existing_labels_list + [lbl.lower().split() for lbl in newLabels]

In [None]:
# Inspect the merged training table (one row per Id after the groupby).
train

In [None]:
extDf = pd.read_csv('../input/training-set-found-labels/FNE_plusAbbrevs.csv')
# Multi-word phrases flagged as datasets (Dataset == 1); non-string cells are skipped.
extDatasets = [
    dat for dat in extDf.loc[extDf['Dataset'] == 1, 'Text'].tolist()
    if isinstance(dat, str) and len(dat.split()) > 1
]
# Merge with the other label sources and sort. A plain list(set(...)) has a
# per-process order (string hashing), which would make the np.random.choice in
# replaceWithExt irreproducible despite the fixed numpy seed.
extDatasets = sorted(set(extDatasets + newLabels + list(existing_labels)))

In [None]:
# Number of distinct dataset names that replaceWithExt can substitute in.
len(extDatasets)

# Prepare data for MLM training

### Auxiliary functions

In [None]:
def clean_paper_sentence(s):
    """
    Case-preserving counterpart of clean_text: every run of characters outside
    [A-Za-z0-9] becomes a single space, then the ends are stripped.
    """
    # A single substitution suffices: each run of separators (existing spaces
    # included) already collapses to exactly one ' ', so no double spaces remain.
    return re.sub('[^A-Za-z0-9]+', ' ', str(s)).strip()

def shorten_sentences(sentences, max_length=None, overlap=None):
    """
    Split sentences longer than `max_length` words into overlapping windows.

    Consecutive windows start `max_length - overlap` words apart. Splitting stops
    at the first window that reaches the end of the sentence, so no trailing
    fragment that is already contained in the previous window is emitted.
    Sentences with at most `max_length` words are returned unchanged.

    max_length / overlap default to the notebook-level MAX_LENGTH / OVERLAP.
    Raises ValueError if max_length - overlap is not positive.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    if overlap is None:
        overlap = OVERLAP
    step = max_length - overlap
    if step <= 0:
        raise ValueError(f'max_length ({max_length}) must exceed overlap ({overlap})')

    short_sentences = []
    for sentence in sentences:
        words = sentence.split()
        if len(words) <= max_length:
            short_sentences.append(sentence)
            continue
        for start in range(0, len(words), step):
            short_sentences.append(' '.join(words[start:start + max_length]))
            if start + max_length >= len(words):
                break  # this window already covers the tail of the sentence
    return short_sentences

def find_sublist(big_list, small_list):
    """
    Return, in increasing order, every index i at which `small_list` occurs
    contiguously in `big_list` (big_list[i:i+len(small_list)] == small_list).
    An empty `small_list` matches at every index 0..len(big_list).
    """
    width = len(small_list)
    last_start = len(big_list) - width
    return [start for start in range(last_start + 1)
            if big_list[start:start + width] == small_list]

def jaccard_similarity_list(l1, l2):
    """
    Return the Jaccard similarity |A & B| / |A | B| of the distinct elements of
    two sequences.

    Both inputs are treated as sets, so repeated elements do not change the
    score. Two empty inputs give 0.0 rather than a division by zero.
    """
    s1, s2 = set(l1), set(l2)
    union = len(s1 | s2)
    if union == 0:
        return 0.0
    return len(s1 & s2) / union

# Words allowed inside a capitalized phrase without ending it (see find_negative_candidates).
connection_tokens = {'s', 'of', 'and', 'in', 'on', 'for', 'data', 'dataset', 'survey', 'study','sequence'}
# Function words trimmed from the start/end of a candidate phrase.
prep_tokens = {'s', 'of', 'and', 'in', 'on', 'for', 'this', 'we', 'their', 'it', 'to'}
# Every known dataset name. Sorted so that iteration order, and therefore the
# generated corpus, is the same in every process.
fullLabels = sorted(set(existing_labels).union(extDatasets))
# Abbreviations: last word of a known name with parentheses removed, kept when it
# is fully upper-case and longer than 2 characters. Blank names are skipped
# (split()[-1] would raise IndexError on them).
abbreviations = sorted({
    word
    for word in (re.sub('[()]', '', label.split()[-1]) for label in fullLabels if label.split())
    if word.isupper() and len(word) > 2
})
def find_negative_candidates(sentence, labels, misc_candidates = None, single_word_prob = 0.0):
    """
    Extract negative samples for Masked Dataset Modeling from a tokenized `sentence`.

    Scanning from index 1 (the sentence-initial word never starts a phrase), a
    run is built from words that start with an uppercase letter or are tokens of
    length > 2 starting with a digit. Words in $connection_tokens keep a run open
    without extending it; any other word closes it. Each closed run has
    $prep_tokens trimmed from both ends and must pass the heuristics in
    `candidate_qualified`, e.g. Jaccard dissimilarity to the ground-truth
    `labels` and to $existing_labels_list.

    Each entry of `misc_candidates` (e.g. NER entity text) is added when its
    lower-cased words have Jaccard similarity < 0.5 to every known label, every
    ground-truth label and every candidate found so far, and when it occurs
    verbatim as a word sequence in `sentence`.

    Parameters
    ----------
    sentence : list of str
    labels : list of list of str
        Ground-truth dataset names, one word list each.
    misc_candidates : list of str, optional
    single_word_prob : float
        Probability of considering a single-word (abbreviation-like) run as a
        candidate; 0.0 disables it.

    Returns
    -------
    list of (start, end) tuples: inclusive word-index spans into `sentence`.
    """
    if misc_candidates is None:
        misc_candidates = []

    def candidate_qualified(words, labels):
        """
        Trim prep_tokens from both ends of `words` and decide whether the rest is
        an acceptable negative. Returns (True, [n_trimmed_start, n_trimmed_end])
        or (False, []).
        """
        # remove leading words that are prep_tokens
        startIdx = 0
        endIdx = 0
        while len(words) and words[0].lower() in prep_tokens:
            words = words[1:]
            startIdx +=1
        # remove trailing words that are prep_tokens
        while len(words) and words[-1].lower() in prep_tokens:
            words = words[:-1]
            endIdx+=1
        # Single-word candidates. The random draw happens for every single-word
        # run whatever single_word_prob is, so tuning it does not shift the numpy
        # random stream used elsewhere.
        if len(words)==1 and np.random.rand()<single_word_prob:
            if (not words[0].isnumeric()) and words[0].isupper() and len(words[0])>=3 and all(words[0] not in label for label in labels) and all(words[0].lower() not in label for label in labels):
                if any(char.isdigit() for char in words[0]):
                    return False, []
                elif words[0] not in abbreviations:
                    return True, [startIdx, endIdx]
                else:
                    return False, []
            else:
                return False, []
        elif len(words)==2:
            if any((word[0].islower() or word.isnumeric() or len(word)<=2) for word in words):
                return False, []
            elif any(word.lower() in prep_tokens for word in words):
                return False,[]
            elif any(' '.join(words) in ' '.join(label) for label in labels):
                return False, []
            else:
                return True, [startIdx, endIdx]
        elif (len(words) <= 2 or \
            any(jaccard_similarity_list(words, label) >= 0.75 for label in labels) or \
            sum([1 for word in words if not word.isnumeric()])<=2):
            return False, []
        elif len(words)==3 and (words[1] == 'and' or all(word.isupper() for word in words)):
            return False, []
        elif any([word.lower() in ['dataset', 'data', 'survey', 'study'] for word in words]):
            return False, []
        elif any(jaccard_similarity_list([word.lower() for word in words], label) >= 0.5 for label in existing_labels_list):
            return False, []
        elif any(jaccard_similarity_list([word.lower() for word in words[0:4]], label) >= 0.5 for label in existing_labels_list):
            return False, []
        elif len(words)==4 and words[-1].isnumeric() and words[1] == 'and':
            # to get rid of references, e.g. Johnson and Johnson 2018
            return False, []
        else:
            return True, [startIdx, endIdx]

    candidates = []

    def _add_if_qualified(start, end):
        # Validate sentence[start:end+1] and store the trimmed span if accepted.
        qualified, trim = candidate_qualified(sentence[start:end + 1], labels)
        if qualified:
            candidates.append((start + trim[0], end - trim[1]))

    phrase_start, phrase_end = -1, -1
    for idx in range(1, len(sentence)):
        word = sentence[idx]
        if word[0].isupper() or (word[0].isnumeric() and len(word) > 2):
            # open a new phrase, or extend the current one
            if phrase_start == -1:
                phrase_start = phrase_end = idx
            else:
                phrase_end = idx
        elif word not in connection_tokens:
            # any other word closes the current phrase
            if phrase_start != -1:
                _add_if_qualified(phrase_start, phrase_end)
                phrase_start = phrase_end = -1

    # a phrase may still be open when the sentence ends
    if phrase_start != -1:
        _add_if_qualified(phrase_start, phrase_end)

    # Extra candidates (e.g. NER entities); all comparisons are case-insensitive.
    lowered_labels = [[w.lower() for w in label] for label in labels]
    for cand in misc_candidates:
        words = cand.split()
        if not words:
            continue
        words_lower = [w.lower() for w in words]
        if any(jaccard_similarity_list(words_lower, label) >= 0.5 for label in existing_labels_list):
            continue
        if any(jaccard_similarity_list(words_lower, label) >= 0.5 for label in lowered_labels):
            continue
        if any(jaccard_similarity_list(words_lower, [w.lower() for w in sentence[s:e + 1]]) >= 0.5
               for s, e in candidates):
            continue
        positions = find_sublist(sentence, words)
        if not positions:
            # the entity text does not line up with the whitespace-split words
            continue
        start = positions[0]
        candidates.append((start, start + len(words) - 1))
    return candidates

# Capitalized words that are never replaced by the long-word placeholder.
_PRE_TOKENIZE_KEEP = ('university', 'initiative', 'international', 'information')

def _replace_cased_words(sentence, upper_token, long_token):
    """
    Shared implementation of pre_tokenize / pre_tokenize_mask.

    `sentence` is a string (split on whitespace) or a sequence of words; the
    input is not modified. Returns a new list in which every all-uppercase word
    becomes `upper_token`, and every other word that starts with an uppercase
    letter, is longer than 8 characters and is not in _PRE_TOKENIZE_KEEP
    (case-insensitive) becomes `long_token`.
    """
    words = sentence.split() if isinstance(sentence, str) else list(sentence)
    out = []
    for word in words:
        if word.isupper():
            out.append(upper_token)
        elif word[:1].isupper() and len(word) > 8 and word.lower() not in _PRE_TOKENIZE_KEEP:
            out.append(long_token)
        else:
            out.append(word)
    return out

def pre_tokenize(sentence):
    """Classifier variant: all-caps words -> '#', long capitalized words -> '$'."""
    return _replace_cased_words(sentence, '#', '$')

def pre_tokenize_mask(sentence):
    """
    MLM variant: all-caps words -> 'XXXX', long capitalized words -> 'ZZZZ'.
    These placeholders differ from the '$' / '#' marker symbols used in MLM samples.
    """
    return _replace_cased_words(sentence, 'XXXX', 'ZZZZ')

def replaceWithExt(phrase, prob = 0.8):
    """
    Augmentation: with probability `prob`, swap a multi-word `phrase` (list of
    words) for the words of a random name drawn from extDatasets. Single-word
    phrases are always returned unchanged. One np.random.rand() draw is taken on
    every call; np.random.choice is drawn only when a swap happens.
    """
    swap = np.random.rand() < prob
    if not swap or len(phrase) <= 1:
        return phrase
    return np.random.choice(extDatasets).split()

# NER Encoding

In [None]:
if NER:
    # Optional spaCy NER; the sample-extraction loop uses its entities as extra
    # negative candidates. Only executed when NER is True.
    #from transformers import pipeline
    #from transformers import AutoModelForTokenClassification, AutoTokenizer
    #modelstrNER = 'squeezebert/squeezebert-mnli'
    #modelNER = AutoModelForTokenClassification.from_pretrained(modelstrNER)
    #tokenizerNER = AutoTokenizer.from_pretrained(modelstrNER)
    #nerPipe = pipeline('ner', grouped_entities = True, device = 0, use_fast = True, model = modelNER, tokenizer = tokenizerNER)
    import spacy
    # NOTE(review): `pipeline=` is not a documented keyword of spacy.load. spaCy v3
    # appears to reject unknown keywords and v2 to ignore them, so this probably
    # does not restrict loading to NER. Use `disable=`/`exclude=` if that is the
    # intent; confirm against the installed spaCy version.
    nerPipe = spacy.load('en_core_web_sm', pipeline=["ner"])

### Extract positive and negative samples

In [None]:
corpus = []        # generated training lines (MLM text or classifier windows)
cnt_pos = 0
cnt_neg = 0
neg_phrase = []    # raw negative phrases, kept for inspection
pos_phrase = []    # positive mentions with dN words of context, kept for inspection
classLabels = []   # 1/0 targets, only filled when MASKEDLM is False
NERDict = {}       # str(paper id) -> per-sentence spaCy entities
dN = 5             # context words kept on each side of a phrase


def _load_paper_text(paper_id):
    """
    Return the full text of one publication.

    String ids are Coleridge papers stored as a JSON list of sections, whose
    'text' fields are joined with '. '. Any other id is read from the RCC
    plain-text files.
    """
    if isinstance(paper_id, str):
        with open(f'{paper_train_folder}/{paper_id}.json', 'r') as f:
            paper = json.load(f)
        return '. '.join(section['text'] for section in paper)
    with open('../input/rccfull/train_test/files/text/'+str(paper_id)+'.txt','r') as f:
        return f.read()


with tqdm(total = len(train)) as pbar:
    for paper_id, dataset_labels in train[['Id', 'dataset_label']].itertuples(index=False):
        # Ground-truth labels as word lists. dict.fromkeys de-duplicates while
        # keeping first-seen order (set order changes between runs because of string
        # hashing). Labels that clean to '' are dropped: find_sublist would match an
        # empty label at every position.
        labels = [clean_paper_sentence(label) for label in dataset_labels.split('|')]
        labels = [label.split() for label in dict.fromkeys(labels) if label]

        content = _load_paper_text(paper_id)
        # Same ordered de-duplication for sentences, so the corpus order (and the
        # 80/20 train/validation split taken from it later) is reproducible.
        sentences = list(dict.fromkeys(clean_paper_sentence(sentence) for sentence in content.split('.')))
        sentences = shorten_sentences(sentences) # make sentences short
        sentences = [sentence for sentence in sentences if len(sentence) > 10] # only accept sentences with length > 10 chars

        text = '.'.join(sentences)
        labels_ = [' '.join(label) for label in labels]
        # Known dataset names (from any source) that occur somewhere in this paper.
        tmpLabels = [label for label in fullLabels if label in text]

        if NER:
            paper_key = str(paper_id)
            if paper_key not in NERDict:
                NERDict[paper_key] = [nerPipe(sentence).ents for sentence in sentences]
            nerEnc = NERDict[paper_key]

        sentences = [sentence.split() for sentence in sentences]

        # positive samples
        for sentence in sentences:
            sentence_str = ' '.join(sentence)
            # Promote known names found in this sentence to labels, unless they
            # equal or are contained in a label we already have.
            for label in tmpLabels:
                if label in sentence_str and not any(label in label2 for label2 in labels_):
                    labels.append(label.split())
                    labels_ = [' '.join(lbl) for lbl in labels]
            for label in labels:
                for pos in find_sublist(sentence, label):
                    end = pos + len(label)
                    pos_phrase.append(sentence[max(pos-dN,0):end+dN])
                    if MASKEDLM:
                        # '@' <name, possibly swapped for another dataset name> '$'
                        dt_point = sentence[:pos] + ['@'] + replaceWithExt(sentence[pos:end]) + [DATASET_SYMBOL] + sentence[end:]
                        dt_point = pre_tokenize_mask(dt_point)
                    else:
                        tmpPhrase = sentence[max(pos-dN,0):pos] + replaceWithExt(sentence[pos:end]) + sentence[end:end+dN]
                        dt_point = pre_tokenize(tmpPhrase)
                        classLabels.append(1)
                    corpus.append(' '.join(dt_point))
                    cnt_pos += 1

        # negative samples
        for ii, sentence in enumerate(sentences):
            sentence_str = ' '.join(sentence)
            # Only mine negatives from sentences that mention data/study/survey.
            if all(w not in sentence_str for w in {'data', 'study', 'survey'}):
                continue
            if NER:
                # spaCy entity types used as extra candidates; NORP = nationalities /
                # religious / political groups in spaCy's English label scheme.
                misc = [ent.text for ent in nerEnc[ii] if ent.label_ in ['GPE', 'ORG', 'NORP']]
            else:
                misc = []
            for phrase_start, phrase_end in find_negative_candidates(sentence, labels, misc_candidates = misc):
                neg_phrase.append(sentence[phrase_start:phrase_end+1])
                if MASKEDLM:
                    # '@' <phrase> '#'
                    dt_point = sentence[:phrase_start] + ['@'] + sentence[phrase_start:phrase_end+1] + [NONDATA_SYMBOL] + sentence[phrase_end+1:]
                    dt_point = pre_tokenize_mask(dt_point)
                else:
                    dt_point = pre_tokenize(sentence[max(phrase_start-dN,0):phrase_end+1+dN])
                    classLabels.append(0)
                corpus.append(' '.join(dt_point))
                cnt_neg += 1

        # process bar
        pbar.update(1)
        pbar.set_description(f'Training data size: {cnt_pos} positives + {cnt_neg} negatives')

if NER:
    import pickle
    # Persist the per-paper NER output.
    with open('nerEnc.pickle', 'wb') as handle:
        pickle.dump(NERDict, handle)


In [None]:
# Debug: corpus lines containing 'NLSF'.
# NOTE(review): in MASKEDLM mode pre_tokenize_mask rewrites all-uppercase words to
# 'XXXX', so this probably finds few or no matches; pos_phrase / neg_phrase keep
# the raw words if that is what should be inspected.
[line for line in corpus[:] if 'NLSF' in line]

In [None]:
# Debug: two-word negative phrases within an ad-hoc slice of neg_phrase.
[x for x in neg_phrase[5000:5100] if len(x)==2]

In [None]:
# Inspect the label set loaded from existing_labels.npy.
existing_labels

In [None]:
# Spot-check an arbitrary window of generated training lines.
corpus[2200:2220]

### Save data to a file

In [None]:
# JSON Lines: one object per sample, {'text': ...} for MLM plus 'label' for classification.
if MASKEDLM:
    rows = ({'text': sentence} for sentence in corpus)
else:
    rows = ({'text': sentence, 'label': lbl} for sentence, lbl in zip(corpus, classLabels))

with open('train_mlm.json', 'w') as f:
    for row in rows:
        f.write(json.dumps(row) + '\n')

# Fine-tune the Transformer

In [None]:
# Total number of generated samples (cnt_pos + cnt_neg).
len(corpus)


In [None]:
def _load_mlm_json(**kwargs):
    """Load train_mlm.json (written above) with the `json` builder; kwargs go to load_dataset."""
    return load_dataset('json', data_files={'train' : 'train_mlm.json'}, **kwargs)

if VAL:
    # Hold out the last 20% of rows (file order) for evaluation.
    datasets = _load_mlm_json(split = 'train[:80%]')
    val_datasets = _load_mlm_json(split = 'train[80%:]')
else:
    datasets = _load_mlm_json()

datasets

### Tokenize and collate data

In [None]:
# Pretrained checkpoint to fine-tune; other candidates are left commented out.
# model_checkpoint = "bert-base-cased"
# model_checkpoint = "roberta-base"
model_checkpoint = 'microsoft/deberta-base'
# model_checkpoint = 'bert-large-cased'

In [None]:
# Tokenizer for the selected checkpoint; use_fast=True requests the fast implementation.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

In [None]:
# Token ids of the marker symbols that DataCollator masks. BPE tokenizers (e.g.
# RoBERTa/DeBERTa) encode a symbol differently after a space, so both spellings
# are looked up. Duplicate ids (tokenizers that ignore the leading space) are
# dropped, so no checkpoint-name heuristic is needed. A symbol that does not
# encode to exactly one token raises instead of silently masking only a fragment.
tokens_to_mask = []
for sym in [DATASET_SYMBOL, NONDATA_SYMBOL, ' '+DATASET_SYMBOL, ' '+NONDATA_SYMBOL]:
    sym_ids = tokenizer.encode(sym, add_special_tokens=False)
    if len(sym_ids) != 1:
        raise ValueError(f'{sym!r} does not encode to a single token for {model_checkpoint}: {sym_ids}')
    if sym_ids[0] not in tokens_to_mask:
        tokens_to_mask.append(sym_ids[0])

In [None]:
# Debug: token ids produced for a sample word.
tokenizer.encode('John')

In [None]:
def tokenize_function(examples):
    """Tokenize a batch of rows from the "text" column; returns the tokenizer's encoding."""
    return tokenizer(examples["text"])

tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=1, remove_columns=["text"])
if VAL:
    # The held-out split only exists when VAL is set.
    val_tokenized_datasets = val_datasets.map(tokenize_function, batched=True, num_proc=1, remove_columns=["text"])

In [None]:
# Inspect the tokenized training data (the "text" column has been removed).
tokenized_datasets

In [None]:
# Debug: token ids produced for another sample word.
tokenizer.encode('English')

In [None]:
# Debug: ids of the marker tokens the collator will mask.
tokens_to_mask

In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

class DataCollator(DataCollatorForLanguageModeling):
    """
    Masked-LM collator that targets the dataset / non-dataset marker tokens.

    Positions holding one of `tokens_to_mask` are selected with probability
    `target_probability`. Other non-special tokens are selected with probability
    `mlm_probability` (0.0 by default, so only the markers are trained).
    Selected positions follow the usual recipe: 80% -> mask token,
    10% -> random token, 10% -> unchanged. Labels are -100 everywhere else.
    """

    def __init__(self, tokenizer, tokens_to_mask, mlm_probability=0.00, target_probability=0.99):
        # Forward mlm_probability so callers can also mask ordinary tokens.
        super(DataCollator, self).__init__(tokenizer=tokenizer, mlm_probability=mlm_probability)
        self.tokens_to_mask = tokens_to_mask
        self.target_probability = target_probability

    def mask_tokens(
        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        Marker tokens are selected with probability self.target_probability.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        # Marker tokens override the base probability (applied after the special-token zeroing).
        for tok in self.tokens_to_mask:
            probability_matrix.masked_fill_(torch.eq(inputs, tok), value=self.target_probability)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time (half of the remaining 20%), we replace masked input tokens with a random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

### Load pre-trained model and fine-tune

In [None]:
# Pick the collator / model pair for the selected training mode.
if not MASKEDLM:
    # Classification: dynamic padding only; targets come from the 'label' column.
    data_collator = DataCollatorWithPadding(tokenizer)
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
else:
    # MLM: the custom collator masks the marker tokens (mlm_probability=0.0 for all others).
    data_collator = DataCollator(tokenizer, tokens_to_mask, mlm_probability=0.0)
    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

In [None]:
training_args = TrainingArguments(
    output_dir="output-mlm" if MASKEDLM else "output-seqClass",
    # Periodic evaluation needs the held-out split, which only exists when VAL is set.
    evaluation_strategy = "steps" if VAL else "no",
    warmup_steps = 2000,
    learning_rate=2e-5,
    weight_decay=0.01,
    save_steps=2000,
    num_train_epochs=2.5,
    report_to="none",
    logging_steps = 500,
    save_total_limit = 1,
    per_device_train_batch_size = 16,
    per_device_eval_batch_size = 32,
    eval_steps = 1000
)

trainer_kwargs = dict(model=model, args=training_args, data_collator=data_collator)
if VAL:
    trainer = Trainer(
        train_dataset=tokenized_datasets,
        eval_dataset=val_tokenized_datasets,
        **trainer_kwargs,
    )
else:
    # Without VAL the data was loaded without a split, so the rows live under "train".
    trainer = Trainer(train_dataset=tokenized_datasets["train"], **trainer_kwargs)

In [None]:
# Fine-tune; checkpoints go to training_args.output_dir every save_steps (at most 1 kept).
trainer.train()

### Save model

In [None]:
# Persist the fine-tuned weights to a mode-specific directory.
model_dir = 'mlm-model' if MASKEDLM else 'seqClass-model'
trainer.model.save_pretrained(model_dir)

### Save tokenizer

In [None]:
# Save the tokenizer and the base checkpoint's config side by side in one directory.
TOKENIZER_DIR = 'model_tokenizer'
config = AutoConfig.from_pretrained(model_checkpoint)
for artifact in (tokenizer, config):
    artifact.save_pretrained(TOKENIZER_DIR)