This notebook is a template for using Literal Matching + Masked Language Modeling to identify datasets in papers.

The BERT model was trained in a separate notebook, [BERT - Masked Dataset Modeling](https://www.kaggle.com/khubchandani/bert-masked-dataset-modeling), and the resulting trained model is available at https://www.kaggle.com/khubchandani/bert-mlm-v6

This notebook was forked from [[Coleridge] Predict with Masked Dataset Modeling](https://www.kaggle.com/tungmphung/coleridge-predict-with-masked-dataset-modeling) during the early stages of the competition. The external dataset used is at <https://www.kaggle.com/khubchandani/data-set-800-2>. The dataset and the way it is used are derived from https://www.kaggle.com/mlconsult/isin-big-dataset and the datasets used therein.

The approach is:
- Locate all the sequences of capitalized words (these sequences may contain some stopwords),
- Replace each sequence with one of 2 special symbols (e.g. $ and #), indicating whether that sequence represents a dataset name or not.
- Have the model learn the MLM task.

# Install packages

In [None]:
!pip install datasets --no-index --find-links=file:///kaggle/input/coleridge-packages/packages/datasets
!pip install ../input/coleridge-packages/seqeval-1.2.2-py3-none-any.whl
!pip install ../input/coleridge-packages/tokenizers-0.10.1-cp37-cp37m-manylinux1_x86_64.whl
!pip install ../input/coleridge-packages/transformers-4.5.0.dev0-py3-none-any.whl

# Import

In [None]:
import os
import re
import json
import time
import datetime
import random
import glob
import importlib

import numpy as np
import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, \
AutoModelForMaskedLM, Trainer, TrainingArguments, pipeline

import string

import spacy
nlp1 = spacy.load('en')

sns.set()
random.seed(123)
np.random.seed(456)
torch.manual_seed(2021)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Load data

In [None]:
# Training labels (CSV) and the folder that read_json_pub reads
# '<filename>.json' publication files from by default.
train_path = '../input/coleridgeinitiative-show-us-the-data/train.csv'
train_data_path = '../input/coleridgeinitiative-show-us-the-data/train'
train = pd.read_csv(train_path)

In [None]:
# The submission template; its 'Id' column lists the papers to predict for.
sample_submission_path = '../input/coleridgeinitiative-show-us-the-data/sample_submission.csv'
sample_submission = pd.read_csv(sample_submission_path)

# Parse every listed paper's JSON file once and keep it in memory by id.
paper_test_folder = '../input/coleridgeinitiative-show-us-the-data/test'
papers = {}
for paper_id in sample_submission['Id']:
    paper_file = f'{paper_test_folder}/{paper_id}.json'
    with open(paper_file, 'r') as handle:
        papers[paper_id] = json.load(handle)

In [None]:
def read_json_pub(filename, train_data_path=train_data_path, output='text'):
    """Load one publication's JSON file and return its content as one string.

    The file is read as a list of section objects; only the
    ``section_title`` and ``text`` entries of each section are used.

    Args:
        filename: Publication id, i.e. the file name without ``.json``.
        train_data_path: Directory that contains the JSON file.
        output: ``'text'`` returns all section texts joined by a space,
            ``'head'`` returns all section titles joined by a space, and
            any other value returns titles and texts interleaved and
            joined by ``'. '``.

    Returns:
        The requested concatenation as a ``str``.
    """
    json_path = os.path.join(train_data_path, filename + '.json')
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(json_path, 'r', encoding='utf-8') as f:
        sections = json.load(f)

    headings = []
    contents = []
    combined = []
    for section in sections:
        # A missing or null field would make str.join raise TypeError,
        # so such fields are treated as empty strings.
        heading = section.get('section_title') or ''
        text = section.get('text') or ''
        headings.append(heading)
        contents.append(text)
        combined.append(heading)
        combined.append(text)

    if output == 'text':
        return ' '.join(contents)
    if output == 'head':
        return ' '.join(headings)
    return '. '.join(combined)

In [None]:
def clean_text(txt):
    """Lower-case *txt* and turn each run of non-alphanumerics into one space.

    The input is converted with ``str`` first; leading and trailing spaces
    are removed from the result.
    """
    lowered = str(txt).lower()
    spaced = re.sub('[^A-Za-z0-9]+', ' ', lowered)
    return spaced.strip()

def totally_clean_text(txt):
    """Run clean_text on *txt*, then squeeze repeated spaces into one."""
    return re.sub(' +', ' ', clean_text(txt))

def text_cleaning(text):
    '''
    Normalize text for literal matching.

    Converts the input to a lower-case string, replaces every run of
    characters other than ASCII letters and digits (special characters,
    emojis, whitespace) with a single space, and strips the ends.

    text - Sentence that needs to be cleaned

    Returns the cleaned string.
    '''
    # A single substitution does all the work: each maximal run of
    # non-alphanumerics becomes exactly one space, so the result can contain
    # neither repeated spaces nor emojis. Separate space-collapsing and
    # emoji-stripping passes would never match anything.
    return re.sub('[^A-Za-z0-9]+', ' ', str(text).lower()).strip()

# Masked Dataset Modeling

### Paths and Hyperparameters

In [None]:
# Alternative model/tokenizer locations, kept commented out for reference.
#PRETRAINED_PATH = '../input/coleridge-mlm-model/mlm-model'
#TOKENIZER_PATH = '../input/coleridge-mlm-model/model_tokenizer'

# Active locations: handed to AutoModelForMaskedLM.from_pretrained and
# AutoTokenizer.from_pretrained respectively.
PRETRAINED_PATH = '../input/bert-mlm-v6/mlm-model'
TOKENIZER_PATH = '../input/bert-mlm-v6/model_tokenizer'

#PRETRAINED_PATH = '../input/bert-masked-dataset-modeling/mlm-model'
#TOKENIZER_PATH = '../input/bert-masked-dataset-modeling/model_tokenizer'

#PRETRAINED_PATH = '../input/k/khubchandani/bert-masked-dataset-modeling/mlm-model'
#TOKENIZER_PATH = '../input/k/khubchandani/bert-masked-dataset-modeling/model_tokenizer'

# Word-count limit per sentence piece; shorten_sentences splits anything longer.
MAX_LENGTH = 64
# Consecutive pieces start MAX_LENGTH - OVERLAP words apart, so they share
# up to OVERLAP words.
OVERLAP = 20

# Number of masked texts passed to the fill-mask pipeline per call.
PREDICT_BATCH = 32 # a higher value requires higher GPU memory usage

DATASET_SYMBOL = '$' # this symbol represents a dataset name
NONDATA_SYMBOL = '#' # this symbol represents a non-dataset name

# Transform data to MLM format

### Load model and tokenizer

In [None]:
# Restore the masked-LM weights and the matching tokenizer from disk.
model = AutoModelForMaskedLM.from_pretrained(PRETRAINED_PATH)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True)

# Fill-mask pipeline over that model: device index 0 when CUDA is
# available, -1 otherwise.
mlm = pipeline(
    task='fill-mask',
    tokenizer=tokenizer,
    model=model,
    device=(0 if torch.cuda.is_available() else -1),
)

### Auxiliary functions

In [None]:
def jaccard_similarity(s1, s2):
    """Return the Jaccard similarity of the tokens of two strings.

    Both strings are split on single spaces (case-sensitively) and compared
    as sets, so a token repeated inside one string counts once.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        ``|A & B| / |A | B|`` as a float in ``[0.0, 1.0]``.
    """
    tokens1 = set(s1.split(" "))
    tokens2 = set(s2.split(" "))
    intersection = len(tokens1 & tokens2)
    # The union must be measured on the sets: using the raw list lengths
    # inflates the denominator whenever a token is repeated.
    # str.split(" ") never returns an empty list, so the union is >= 1.
    union = len(tokens1 | tokens2)
    return float(intersection) / union

def jaccard(str1, str2):
    """Return the case-insensitive Jaccard similarity of two strings' words.

    Both strings are lower-cased and split on whitespace; the score is the
    size of the intersection of the two word sets divided by the size of
    their union.

    Args:
        str1: First string.
        str2: Second string.

    Returns:
        A float in ``[0.0, 1.0]``. Two strings without any words are
        treated as identical (``1.0``) instead of dividing by zero.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    union_size = len(a | b)
    if union_size == 0:
        # Both inputs are empty or whitespace-only.
        return 1.0
    return float(len(a & b)) / union_size

def clean_paper_sentence(s):
    """
    Case-preserving counterpart of clean_text: every run of characters
    other than ASCII letters and digits becomes one space and the ends are
    stripped.
    """
    alnum_only = re.sub('[^A-Za-z0-9]+', ' ', str(s))
    return re.sub(' +', ' ', alnum_only.strip())

def shorten_sentences(sentences):
    """
    Break over-long sentences into overlapping pieces.

    A sentence with at most MAX_LENGTH whitespace-separated words is kept
    unchanged. A longer one is replaced by pieces of up to MAX_LENGTH words
    whose starting positions are MAX_LENGTH - OVERLAP words apart.
    Returns the resulting list of strings.
    """
    pieces = []
    for sentence in sentences:
        tokens = sentence.split()
        if len(tokens) <= MAX_LENGTH:
            pieces.append(sentence)
            continue
        stride = MAX_LENGTH - OVERLAP
        pieces.extend(
            ' '.join(tokens[start:start + MAX_LENGTH])
            for start in range(0, len(tokens), stride)
        )
    return pieces

# Lower-case words that may appear inside a capitalized phrase.
connection_tokens = {'s','of','and','in','on','for'}

def find_mask_candidates(sentence, min_words=4):
    """
    Extract masking candidates for Masked Dataset Modeling from $sentence.

    A candidate is a maximal run of consecutive words in which every word
    either starts with an upper-case letter or is one of the connection
    words ($connection_tokens). A run qualifies when, after ignoring the
    connection words at its beginning and end, at least $min_words words
    remain.

    Connection words at the edges are ignored only for that count; the
    returned span still covers the whole run, including them.

    The scan starts at index 1, so the first word of the sentence is never
    part of a candidate.
    NOTE(review): presumably this skips ordinary sentence-initial
    capitalization - confirm before changing.

    Args:
        sentence: Sequence of words (e.g. the result of ``str.split()``).
        min_words: Minimum number of words a run must keep after trimming
            its edge connection words. Defaults to 4.

    Returns:
        A list of ``(start, end)`` index pairs into $sentence, both
        inclusive, in order of appearance.
    """
    def candidate_qualified(words):
        # Count the words left once edge connection words are ignored.
        first, last = 0, len(words)
        while first < last and words[first].lower() in connection_tokens:
            first += 1
        while last > first and words[last - 1].lower() in connection_tokens:
            last -= 1
        return last - first >= min_words

    candidates = []
    phrase_start = -1  # -1 means "not currently inside a run"
    for idx in range(1, len(sentence)):
        word = sentence[idx]
        # word[:1] rather than word[0] so that an empty word ends the run
        # instead of raising IndexError.
        if word[:1].isupper() or word in connection_tokens:
            if phrase_start == -1:
                phrase_start = idx
            continue

        # The word at idx ends the current run, which spans up to idx - 1.
        if phrase_start != -1:
            if candidate_qualified(sentence[phrase_start:idx]):
                candidates.append((phrase_start, idx - 1))
            phrase_start = -1

    # A run that is still open reaches the last word of the sentence.
    if phrase_start != -1 and candidate_qualified(sentence[phrase_start:]):
        candidates.append((phrase_start, len(sentence) - 1))

    return candidates

### Transform

In [None]:
# Mask token of the pipeline's tokenizer; it is substituted for each
# candidate phrase when the test texts are built.
mask = mlm.tokenizer.mask_token

In [None]:
# One entry per paper, in sample_submission order: a list of
# (masked text, original phrase) pairs.
all_test_data = []

for paper_id in sample_submission['Id']:
    # load paper
    paper = papers[paper_id]

    # Unique cleaned fragments of the paper, split on , . ; and :
    fragments = {
        clean_paper_sentence(piece)
        for section in paper
        for piece in re.split("[,.;:]", section['text'])
        #for piece in section['text'].split('.')
    }

    test_data = []
    for fragment in shorten_sentences(fragments):  # make fragments short
        # only accept fragments longer than 21 characters
        if len(fragment) <= 21:
            continue
        # ... that mention at least one of these keywords
        lowered = fragment.lower()
        if not any(keyword in lowered for keyword in ['survey','study','database','catalog','dataset']):
            continue

        words = fragment.split()
        for phrase_start, phrase_end in find_mask_candidates(words):
            masked_words = words[:phrase_start] + [mask] + words[phrase_end+1:]
            phrase = ' '.join(words[phrase_start:phrase_end+1])
            test_data.append((' '.join(masked_words), phrase)) # (masked text, phrase)

    all_test_data.append(test_data)

### Predict

In [None]:
# One '|'-joined label string per paper, aligned with all_test_data.
pred_mlm_labels = []

# Iterating through tqdm directly closes the progress bar when the loop ends.
for test_data in tqdm(all_test_data):
    pred_bag = set()

    if test_data:
        texts, phrases = zip(*test_data)
        mlm_pred = []
        for p_id in range(0, len(texts), PREDICT_BATCH):
            batch_texts = list(texts[p_id:p_id+PREDICT_BATCH])
            batch_pred = mlm(batch_texts, targets=[f' {DATASET_SYMBOL}', f' {NONDATA_SYMBOL}'])

            # For a single text the result is wrapped so that every entry of
            # mlm_pred is one per-text pair of results.
            if len(batch_texts) == 1:
                batch_pred = [batch_pred]

            mlm_pred.extend(batch_pred)

        for (result1, result2), phrase in zip(mlm_pred, phrases):
            # NOTE(review): if the pipeline returns the two targets sorted by
            # descending score, the second clause can never be true; it would
            # also accept a phrase when the non-dataset symbol dominates.
            # Left unchanged - confirm the intent before editing.
            if (result1['score'] > result2['score']*1.6 and result1['token_str'] == DATASET_SYMBOL) or\
               (result2['score'] > result1['score']*1.6 and result2['token_str'] == NONDATA_SYMBOL):
                pred_bag.add(clean_text(phrase))

    # filter labels by jaccard score
    filtered_labels = []

    # Longest labels first. Equal-length labels are ordered alphabetically so
    # the outcome does not depend on the set's run-to-run iteration order.
    for label in sorted(pred_bag, key=lambda text: (-len(text), text)):
        if all(jaccard(label, got_label) < 0.75 for got_label in filtered_labels):#0.75
            filtered_labels.append(label)

    pred_mlm_labels.append('|'.join(filtered_labels))

In [None]:
pred_mlm_labels[:5]

In [None]:
start_time = time.time()
from fuzzywuzzy import fuzz
def get_ratio(name, name1):
    """Return fuzzywuzzy's token-set ratio of the two strings."""
    return fuzz.token_set_ratio(name, name1)

# Second de-duplication pass over the MLM labels of every paper: going from
# the longest label to the shortest, a label is kept only if its token-set
# ratio against every label already kept is below 100.
final = []
for preds in pred_mlm_labels:
    filtered = []
    for label in sorted(preds.split('|'), key=len, reverse=True):
        if all(get_ratio(label, got) < 100 for got in filtered):
            filtered.append(label)
    final.append('|'.join(filtered))
print("--- %s seconds ---" % (time.time() - start_time))

In [None]:
pred_mlm_labels = final

In [None]:
pred_mlm_labels[:5]

In [None]:
# External table used for literal matching; is_in_big reads its 'title' column.
df2=pd.read_csv('../input/data-set-800-2/data_set_800_4.csv')

In [None]:
def is_in_big(to_append):
    """Add every title from df2 that literally occurs in a paper.

    The paper ``to_append[0]`` is loaded from the module-level folder
    ``path`` with read_json_pub and normalized with text_cleaning. Each
    value of ``df2['title']`` found as a substring of that text is
    appended, after clean_text, to the '|'-separated label string
    ``to_append[1]``. The labels are then ordered by length and exact
    word-set duplicates (jaccard == 1.0) are dropped.

    NOTE(review): titles are matched as stored, without cleaning, against
    the cleaned text, so df2 presumably already holds cleaned titles -
    verify against the CSV.

    Args:
        to_append: Two-item list ``[paper_id, labels]``; ``labels`` is a
            '|'-separated string (``''`` for none). Modified in place.

    Returns:
        The same list, with ``to_append[1]`` replaced by the final
        '|'-separated labels.
    """
    paper_text = text_cleaning(str(read_json_pub(to_append[0], path)))

    # Iterate the column directly: building a Series per row with iterrows()
    # is far slower and only the title is needed.
    for title in df2['title']:
        query_string = str(title)
        if query_string not in paper_text:
            continue
        label = clean_text(query_string)
        if to_append[1] == '':
            to_append[1] = label
        elif label not in to_append[1]:
            # Substring test against the joined string, so a label contained
            # in an already collected label is skipped as well.
            to_append[1] += '|' + label

    # Shortest first; keep a label only if no kept label has the same words.
    filtered = []
    for label in sorted(to_append[1].split('|'), key=len):
        if all(jaccard(label, got) < 1.0 for got in filtered):
            filtered.append(label)
    to_append[1] = '|'.join(filtered)
    return to_append

def submit(chunk):
    """Run literal matching for every paper of a submission chunk.

    Args:
        chunk: DataFrame with an 'Id' column of paper ids.

    Returns:
        A DataFrame whose columns are the module-level ``column_names``,
        holding one ``[Id, labels]`` row from is_in_big per paper, in the
        order of ``chunk``.
    """
    # Collect the rows first and build the frame once; enlarging a DataFrame
    # row by row with .loc re-allocates on every insertion.
    rows = [is_in_big([paper_id, '']) for paper_id in chunk['Id']]
    return pd.DataFrame(rows, columns=column_names)

def literals(chunk):
    """Return the is_in_big label string of every row of *chunk*, in order."""
    return [is_in_big([row['Id'], ''])[1] for _, row in chunk.iterrows()]

In [None]:
import multiprocessing as mp
num_processes = mp.cpu_count()
print(num_processes)
# Ceiling division, never below 1: a plain floor division yields 0 when there
# are fewer papers than processes, and range() rejects a step of 0.
chunk_size = max(1, -(-sample_submission.shape[0] // num_processes))
print(chunk_size)
chunks = [sample_submission.iloc[i:i + chunk_size,:] for i in range(0, sample_submission.shape[0], chunk_size)]
# Globals read by is_in_big / submit; assigned before the pool is created.
path = paper_test_folder
column_names = ["Id", "PredictionString"]

start_time = time.time()
if chunks:
    # The context manager releases the worker processes even if a task fails.
    # pool.map returns results in the order of `chunks`, which keeps the
    # predictions aligned with sample_submission.
    with mp.Pool(processes=num_processes) as pool:
        submission = pd.concat(pool.map(submit, chunks))
else:
    # No papers at all: pd.concat would raise on an empty list.
    submission = pd.DataFrame(columns=column_names)
literal_preds = submission["PredictionString"].tolist()
print("--- %s seconds ---" % (time.time() - start_time))
submission

## Aggregate final predictions and write submission file

In [None]:
# Per paper: use the literal-matching labels when there are any, otherwise
# fall back to the MLM labels.
final_predictions = [
    literal_match or mlm_pred
    for literal_match, mlm_pred in zip(literal_preds, pred_mlm_labels)
]
#final_predictions = pred_mlm_labels # when string matching solution is to be omitted

In [None]:
final_predictions[:5]

In [None]:
# Store the final labels in the template and write it without the index column.
sample_submission['PredictionString'] = final_predictions
sample_submission.to_csv('submission.csv', index=False)

In [None]:
sample_submission.head()