In [None]:
!pip install datasets --no-index --find-links=file:///kaggle/input/coleridge-packages/packages/datasets
!pip install ../input/coleridge-packages/seqeval-1.2.2-py3-none-any.whl
!pip install ../input/coleridge-packages/tokenizers-0.10.1-cp37-cp37m-manylinux1_x86_64.whl
!pip install ../input/coleridge-packages/transformers-4.5.0.dev0-py3-none-any.whl

In [None]:
import os
import re
import json
import time
import datetime
import random
import glob
import importlib

import numpy as np
import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, \
AutoModelForMaskedLM, Trainer, TrainingArguments, pipeline

sns.set()

# Seed each RNG source the notebook uses, in a fixed order, so reruns are repeatable.
for seed_fn, seed_value in ((random.seed, 200), (np.random.seed, 300), (torch.manual_seed, 2021)):
    seed_fn(seed_value)

# Request deterministic cuDNN kernels and turn off cuDNN's algorithm autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Competition training annotations; only the label columns
# (dataset_title, dataset_label, cleaned_label) are used below.
train_path = '../input/coleridgeinitiative-show-us-the-data/train.csv'
train = pd.read_csv(train_path)

In [None]:
sample_submission_path = '../input/coleridgeinitiative-show-us-the-data/sample_submission.csv'
sample_submission = pd.read_csv(sample_submission_path)

paper_test_folder = '../input/coleridgeinitiative-show-us-the-data/test'


def load_paper(paper_id, folder=paper_test_folder):
    """
    Load one test publication from ``<folder>/<paper_id>.json``.

    Later cells iterate the result as a sequence of sections, each with a
    ``'text'`` key. The file is decoded as UTF-8 explicitly, so the result does
    not depend on the platform's default locale encoding.
    """
    with open(os.path.join(folder, f'{paper_id}.json'), 'r', encoding='utf-8') as f:
        return json.load(f)


# One entry per Id in the sample submission, keyed by that Id.
papers = {paper_id: load_paper(paper_id) for paper_id in sample_submission['Id']}

In [None]:
# Vocabulary of known dataset names: every value from the three label columns,
# stringified and lower-cased, de-duplicated.
LABEL_COLUMNS = ['dataset_title', 'dataset_label', 'cleaned_label']
all_labels = {str(value).lower() for column in LABEL_COLUMNS for value in train[column]}

print(f'No. different labels: {len(all_labels)}')

In [None]:
_NON_ALNUM_RUN = re.compile('[^A-Za-z0-9]+')
_SPACE_RUN = re.compile(' +')


def clean_text(txt):
    """
    Lower-case ``str(txt)``, replace each run of non-alphanumeric ASCII
    characters with a single space, and strip the ends.
    """
    lowered = str(txt).lower()
    return _NON_ALNUM_RUN.sub(' ', lowered).strip()


def totally_clean_text(txt):
    """Apply ``clean_text`` and then collapse any run of spaces into one."""
    return _SPACE_RUN.sub(' ', clean_text(txt))

In [None]:
def match_known_labels(paper, labels=all_labels):
    """
    Return the known labels that literally occur in ``paper``, '|'-joined.

    ``paper`` is iterated as sections carrying a ``'text'`` key. A label counts
    as found when it is a substring of either the lower-cased full text or its
    ``totally_clean_text`` form. Matches are reported in ``clean_text`` form.

    Matching is plain substring search, with no word-boundary check.

    Labels that clean to an empty string are dropped, so the output never has
    empty '|' entries. The matches are sorted, so the string does not depend on
    set iteration order, which varies with the per-process str hash seed.
    """
    raw_text = '. '.join(section['text'] for section in paper).lower()
    normalized_text = totally_clean_text(raw_text)

    matches = {
        clean_text(label)
        for label in labels
        if label in raw_text or label in normalized_text
    }
    matches.discard('')
    return '|'.join(sorted(matches))


literal_preds = [match_known_labels(papers[paper_id]) for paper_id in sample_submission['Id']]

In [None]:
# Peek at the first few literal-match predictions ('|'-joined labels per paper).
literal_preds[:5]

In [None]:
PRETRAINED_PATH = '../input/coleridge-mlm-model/mlm-model'
TOKENIZER_PATH = '../input/coleridge-mlm-model/model_tokenizer'

MAX_LENGTH = 64  # sentences with more words than this are split into windows
OVERLAP = 25     # words shared by consecutive windows of a split sentence

PREDICT_BATCH = 64  # a higher value requires higher GPU memory usage

DATASET_SYMBOL = '$' # this symbol represents a dataset name
NONDATA_SYMBOL = '#' # this symbol represents a non-dataset name

# Sentence windows advance by MAX_LENGTH - OVERLAP words, and batches by
# PREDICT_BATCH texts; a non-positive step would crash or silently skip input.
assert 0 <= OVERLAP < MAX_LENGTH, 'OVERLAP must be in [0, MAX_LENGTH)'
assert PREDICT_BATCH > 0, 'PREDICT_BATCH must be positive'
assert DATASET_SYMBOL != NONDATA_SYMBOL, 'the two target symbols must differ'
NONDATA_SYMBOL = '#' # this symbol represents a non-dataset name

In [None]:
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True)
model = AutoModelForMaskedLM.from_pretrained(PRETRAINED_PATH)

# First CUDA device when one is available; -1 otherwise (CPU in the pipeline API).
pipeline_device = 0 if torch.cuda.is_available() else -1

mlm = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=pipeline_device)

In [None]:
def jaccard_similarity(s1, s2):
    """
    Word-level Jaccard similarity |A ∩ B| / |A ∪ B| of two strings.

    Each string is split on whitespace and treated as a set of words, so repeated
    words count once in both the intersection and the union. (Mixing a set
    intersection with a list-length union, as before, over-counted the union when
    a word repeated.) Two empty strings are identical and score 1.0.
    """
    words_1 = set(s1.split())
    words_2 = set(s2.split())
    union = words_1 | words_2
    if not union:
        return 1.0
    return len(words_1 & words_2) / len(union)

def clean_paper_sentence(s):
    """
    Case-preserving counterpart of ``clean_text``.

    Split ``str(s)`` on runs of non-alphanumeric ASCII characters and rejoin the
    non-empty pieces with single spaces. The result has no leading, trailing or
    repeated spaces, and capitalisation is kept (the masking step relies on it).
    """
    tokens = re.split('[^A-Za-z0-9]+', str(s))
    return ' '.join(token for token in tokens if token)

def shorten_sentences(sentences, max_length=None, overlap=None):
    """
    Split sentences with more than ``max_length`` words into overlapping windows.

    sentences:  iterable of strings.
    max_length: maximum words per window; defaults to the global MAX_LENGTH.
    overlap:    words shared by consecutive windows; defaults to the global OVERLAP.

    Sentences with at most ``max_length`` words are returned unchanged. Longer
    ones are re-joined with single spaces into windows that start every
    ``max_length - overlap`` words. Windowing stops once a window reaches the
    last word, because any later window would be wholly contained in it.

    Raises ValueError if ``max_length - overlap`` is not positive.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    if overlap is None:
        overlap = OVERLAP
    step = max_length - overlap
    if step <= 0:
        raise ValueError(f'overlap ({overlap}) must be smaller than max_length ({max_length})')

    short_sentences = []
    for sentence in sentences:
        words = sentence.split()
        if len(words) <= max_length:
            short_sentences.append(sentence)
            continue
        for start in range(0, len(words), step):
            short_sentences.append(' '.join(words[start:start + max_length]))
            if start + max_length >= len(words):
                break
    return short_sentences

connection_tokens = {'s', 'of', 'and', 'in', 'on', 'for', 'data', 'dataset'}
def find_mask_candidates(sentence):
    """
    Find word spans of ``sentence`` (a list of words) that may be dataset names.

    Scanning starts at index 1, so the first word is never part of a span
    (presumably because sentence-initial words are capitalised regardless).
    A word is eligible if its first character is upper-case, or if it is
    exactly (case-sensitively) one of ``connection_tokens``.

    Every maximal run of consecutive eligible words is a candidate. It is kept
    when at least 2 words remain after stripping connection words
    (case-insensitively) from both ends of the run.

    Returns a list of inclusive ``(start, end)`` word indices, in sentence order.
    Each span is the whole run, so it may still begin or end with a
    connection word.
    """
    def _is_eligible(word):
        return word[0].isupper() or word in connection_tokens

    def _core_length(words):
        # Length of the run after trimming connection words from both ends.
        lo, hi = 0, len(words)
        while lo < hi and words[lo].lower() in connection_tokens:
            lo += 1
        while hi > lo and words[hi - 1].lower() in connection_tokens:
            hi -= 1
        return hi - lo

    runs = []
    run_start = None
    for pos in range(1, len(sentence)):
        if _is_eligible(sentence[pos]):
            if run_start is None:
                run_start = pos
            continue
        if run_start is not None:
            runs.append((run_start, pos - 1))
            run_start = None
    if run_start is not None:
        runs.append((run_start, len(sentence) - 1))

    return [(start, end) for start, end in runs if _core_length(sentence[start:end + 1]) >= 2]

In [None]:
# Placeholder string from the pipeline's tokenizer; it replaces each candidate
# phrase in the masked sentences built below.
mask = mlm.tokenizer.mask_token

In [None]:
def build_masked_examples(paper):
    """
    Turn one publication into ``(masked text, phrase)`` pairs for the fill-mask model.

    paper: sequence of sections, each with a ``'text'`` key.

    Steps:
    1. Split each section on '.', clean each piece with ``clean_paper_sentence``,
       and de-duplicate.
    2. Window long sentences with ``shorten_sentences``.
    3. Keep pieces longer than 10 characters that contain 'data' or 'study'
       (case-insensitive substring).
    4. For every candidate span from ``find_mask_candidates``, replace the span
       with the mask token.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order. A set would
    # make the order depend on the per-process str hash seed.
    sentences = list(dict.fromkeys(
        clean_paper_sentence(sentence)
        for section in paper
        for sentence in section['text'].split('.')
    ))
    sentences = shorten_sentences(sentences)  # make sentences short
    sentences = [sentence for sentence in sentences if len(sentence) > 10]  # only accept sentences with length > 10 chars
    sentences = [sentence for sentence in sentences if any(word in sentence.lower() for word in ['data', 'study'])]

    examples = []
    for words in (sentence.split() for sentence in sentences):
        for phrase_start, phrase_end in find_mask_candidates(words):
            masked_words = words[:phrase_start] + [mask] + words[phrase_end + 1:]
            phrase = ' '.join(words[phrase_start:phrase_end + 1])
            examples.append((' '.join(masked_words), phrase))  # (masked text, phrase)
    return examples


# One list of examples per paper, in sample_submission['Id'] order.
all_test_data = [
    build_masked_examples(papers[paper_id])
    for paper_id in tqdm(sample_submission['Id'], desc='masking')
]

In [None]:
DATASET_SCORE_RATIO = 1.5       # '$' must be more than this many times as likely as '#'
JACCARD_DEDUP_THRESHOLD = 0.75  # drop a label at least this similar to a longer kept one


def predict_dataset_phrases(test_data):
    """
    Classify one paper's masked phrases with the fill-mask model.

    test_data: list of ``(masked text, phrase)`` pairs built by the masking step.
    Returns the set of ``clean_text(phrase)`` values the model labels as datasets.

    A phrase is accepted when the model's score for DATASET_SYMBOL exceeds
    DATASET_SCORE_RATIO times its score for NONDATA_SYMBOL. Scores are looked up
    by token string, so the test does not depend on the order in which the
    pipeline returns the two targets. (The previous positional check contained
    a second clause that could never fire on score-sorted output. Read
    literally, that clause would have accepted phrases the model considers
    non-datasets.)
    """
    if not test_data:
        return set()

    texts, phrases = zip(*test_data)
    targets = [f' {DATASET_SYMBOL}', f' {NONDATA_SYMBOL}']

    predictions = []
    for start in range(0, len(texts), PREDICT_BATCH):
        batch = list(texts[start:start + PREDICT_BATCH])
        batch_pred = mlm(batch, targets=targets)
        # For a single input the pipeline returns that input's results directly,
        # not wrapped in an outer list.
        if len(batch) == 1:
            batch_pred = [batch_pred]
        predictions.extend(batch_pred)

    found = set()
    for candidates, phrase in zip(predictions, phrases):
        # Strip defensively in case the decoded token keeps the targets' leading space.
        scores = {candidate['token_str'].strip(): candidate['score'] for candidate in candidates}
        dataset_score = scores.get(DATASET_SYMBOL, 0.0)
        nondata_score = scores.get(NONDATA_SYMBOL, 0.0)
        if dataset_score > nondata_score * DATASET_SCORE_RATIO:
            found.add(clean_text(phrase))
    return found


def dedupe_by_jaccard(labels, threshold=JACCARD_DEDUP_THRESHOLD):
    """
    Greedily keep labels, longest first. A label is skipped when its word-level
    Jaccard similarity to any already-kept label is >= ``threshold``.

    Length ties are broken alphabetically, so the result does not depend on set
    iteration order. Returns the kept labels as a list in that order.
    """
    kept = []
    for label in sorted(labels, key=lambda text: (-len(text), text)):
        if all(jaccard_similarity(label, other) < threshold for other in kept):
            kept.append(label)
    return kept


# One '|'-joined prediction string per paper, in all_test_data order.
pred_mlm_labels = [
    '|'.join(dedupe_by_jaccard(predict_dataset_phrases(test_data)))
    for test_data in tqdm(all_test_data, desc='predicting')
]

In [None]:
# Peek at the first few MLM-based predictions ('|'-joined labels per paper).
pred_mlm_labels[:5]

In [None]:
# Use the literal label matches whenever a paper has any. Fall back to the MLM
# prediction only when the literal-match string is empty.
final_predictions = [
    literal_match or mlm_pred
    for literal_match, mlm_pred in zip(literal_preds, pred_mlm_labels)
]

In [None]:
# Attach one prediction string per paper. Row order matches
# sample_submission['Id'], which the prediction lists were built from.
# Then write the competition submission file.
sample_submission['PredictionString'] = final_predictions
sample_submission.to_csv('submission.csv', index=False)