In [1]:
!pip install datasets --no-index --find-links=file:///kaggle/input/coleridge-packages/packages/datasets
!pip install ../input/coleridge-packages/seqeval-1.2.2-py3-none-any.whl
!pip install ../input/coleridge-packages/tokenizers-0.10.1-cp37-cp37m-manylinux1_x86_64.whl
!pip install ../input/coleridge-packages/transformers-4.5.0.dev0-py3-none-any.whl

Looking in links: file:///kaggle/input/coleridge-packages/packages/datasets
Processing /kaggle/input/coleridge-packages/packages/datasets/datasets-1.5.0-py3-none-any.whl
Processing /kaggle/input/coleridge-packages/packages/datasets/xxhash-2.0.0-cp37-cp37m-manylinux2010_x86_64.whl
Processing /kaggle/input/coleridge-packages/packages/datasets/tqdm-4.49.0-py2.py3-none-any.whl
Processing /kaggle/input/coleridge-packages/packages/datasets/huggingface_hub-0.0.7-py3-none-any.whl
Installing collected packages: tqdm, xxhash, huggingface-hub, datasets
  Attempting uninstall: tqdm
    Found existing installation: tqdm 4.56.2
    Uninstalling tqdm-4.56.2:
      Successfully uninstalled tqdm-4.56.2
Successfully installed datasets-1.5.0 huggingface-hub-0.0.7 tqdm-4.49.0 xxhash-2.0.0
Processing /kaggle/input/coleridge-packages/seqeval-1.2.2-py3-none-any.whl
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2
Processing /kaggle/input/coleridge-packages/tokenizers-

In [2]:
import os
import re
import json
import time
import datetime
import random
import glob
import importlib

import numpy as np
import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, \
AutoModelForMaskedLM, Trainer, TrainingArguments, pipeline

sns.set()
random.seed(123)
np.random.seed(456)
torch.manual_seed(2021)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
PRETRAINED_PATH = '../input/coleridge-bert-mlmv4/mlm-model'
TOKENIZER_PATH = '../input/coleridge-bert-mlmv4/model_tokenizer'

MAX_LEN = 64
OVERLAP = 20
PREDICT_BATCH = 32

PREDICT_BATCH = 32

DATASET_SYMBOL = '$'
NONDATA_SYMBOL = '#'

In [4]:
submission_path = '../input/coleridgeinitiative-show-us-the-data/sample_submission.csv'
submission = pd.read_csv(submission_path)

In [5]:
test_folder = '../input/coleridgeinitiative-show-us-the-data/test'
paper_text = {}
for p_id in submission['Id']:
    with open(f'{test_folder}/{p_id}.json','r') as f:
        text = json.load(f)
        paper_text[p_id] = text

In [6]:
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
model = AutoModelForMaskedLM.from_pretrained(PRETRAINED_PATH)

mlm = pipeline(
    'fill-mask', 
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1
)

mask = mlm.tokenizer.mask_token

In [7]:
def clean_text(txt):
    return re.sub('[^A-Za-z0-9]+', ' ', str(txt).lower()).strip()

def jaccard_similarity(s1, s2):
    l1 = s1.split(" ")
    l2 = s2.split(" ")    
    intersection = len(list(set(l1).intersection(l2)))
    union = (len(l1) + len(l2)) - intersection
    return float(intersection) / union

def clean_paper_sentence(s):
    s = re.sub('[^A-Za-z0-9]+', ' ', str(s)).strip()
    s = re.sub(' +', ' ', s)
    return s

def shorten_sentences(sentences):
    short_sentences = []
    for sentence in sentences:
        words = sentence.split()
        if len(words) > MAX_LEN:
            for p in range(0, len(words), MAX_LEN - OVERLAP):
                short_sentences.append(' '.join(words[p:p+MAX_LEN]))
        else:
            short_sentences.append(sentence)
    return short_sentences

connection_tokens = {'s', 'of', 'and', 'in', 'on', 'for', 'data', 'dataset'}
def find_mask_candidates(sentence):
    def candidate_qualified(words):
        while len(words) and words[0].lower() in connection_tokens:
            words = words[1:]
        while len(words) and words[-1].lower() in connection_tokens:
            words = words[:-1]
        
        return len(words) >= 2
    
    candidates = []
    
    phrase_start, phrase_end = -1, -1
    for id in range(1, len(sentence)):
        word = sentence[id]
        if word[0].isupper() or word in connection_tokens:
            if phrase_start == -1:
                phrase_start = phrase_end = id
            else:
                phrase_end = id
        else:
            if phrase_start != -1:
                if candidate_qualified(sentence[phrase_start:phrase_end+1]):
                    candidates.append((phrase_start, phrase_end))
                phrase_start = phrase_end = -1
    
    if phrase_start != -1:
        if candidate_qualified(sentence[phrase_start:phrase_end+1]):
            candidates.append((phrase_start, phrase_end))
    
    return candidates

In [8]:
# all_test_data = []

# for paper_id in submission['Id']:
#     paper = paper_text[paper_id]
    
# #     sentences = set([clean_paper_sentence(sentence) for section in paper 
# #                      for sentence in section['text'].split('.')
# #                     ])
# #     sentences = shorten_sentences(sentences)
# #     sentences = [sentence for sentence in sentences if len(sentence) > 10]
# #     sentences = [sentence for sentence in sentences if any(word in sentence.lower() for word in ['data', 'study'])]
# #     sentences = [sentence.split() for sentence in sentences]

#     content = '. '.join(section['text'] for section in paper)
#     sentences = set([clean_paper_sentence(sentence) for sentence in content.split('.')])
#     sentences = shorten_sentences(sentences) # make sentences short
#     sentences = [sentence for sentence in sentences if len(sentence) > 10] # only accept sentences with length > 10 chars
#     sentences = [sentence.split() for sentence in sentences]
    
#     # mask
#     test_data = []
#     for sentence in sentences:
#         for phrase_start, phrase_end in find_mask_candidates(sentence):
#             dt_point = sentence[:phrase_start] + [mask] + sentence[phrase_end+1:]
#             test_data.append((' '.join(dt_point), ' '.join(sentence[phrase_start:phrase_end+1]))) # (masked text, phrase)
    
#     all_test_data.append(test_data)

In [9]:
all_test_data = []

for paper_id in submission['Id']:
    paper = paper_text[paper_id]
    
#     sentences = set([clean_paper_sentence(sentence) for section in paper 
#                      for sentence in section['text'].split('.')
#                     ])
#     sentences = shorten_sentences(sentences)
#     sentences = [sentence for sentence in sentences if len(sentence) > 10]
#     sentences = [sentence for sentence in sentences if any(word in sentence.lower() for word in ['data', 'study'])]
#     sentences = [sentence.split() for sentence in sentences]

    content = '. '.join(section['text'] for section in paper)
    sentences = set([clean_paper_sentence(sentence) for sentence in content.split('.')])
    sentences = shorten_sentences(sentences) # make sentences short
    sentences = [sentence for sentence in sentences if len(sentence) > 10] # only accept sentences with length > 10 chars
    sentences = [sentence.split() for sentence in sentences]
    
    # mask
    test_data = []
    for sentence in sentences:
        for phrase_start, phrase_end in find_mask_candidates(sentence):
            dt_point = sentence[:phrase_start] + [mask] + sentence[phrase_end+1:]
            test_data.append((' '.join(dt_point), ' '.join(sentence[phrase_start:phrase_end+1]))) # (masked text, phrase)
    
    all_test_data.append(test_data)

In [10]:
# all_test_data = []

# # pbar = tqdm(total = len(sample_submission))
# for paper_id in sample_submission['Id']:
#     # load paper
#     paper = papers[paper_id]
    
#     # extract sentences
#     sentences = set([clean_paper_sentence(sentence) for section in paper 
#                      for sentence in section['text'].split('.')
#                     ])
#     sentences = shorten_sentences(sentences) # make sentences short
#     sentences = [sentence for sentence in sentences if len(sentence) > 10] # only accept sentences with length > 10 chars
#     sentences = [sentence for sentence in sentences if any(word in sentence.lower() for word in ['data', 'study'])]
#     sentences = [sentence.split() for sentence in sentences] # sentence = list of words
    
#     # mask
#     test_data = []
#     for sentence in sentences:
#         for phrase_start, phrase_end in find_mask_candidates(sentence):
#             dt_point = sentence[:phrase_start] + [mask] + sentence[phrase_end+1:]
#             test_data.append((' '.join(dt_point), ' '.join(sentence[phrase_start:phrase_end+1]))) # (masked text, phrase)
    
#     all_test_data.append(test_data)
    
#     # process bar
# #     pbar.update(1)

### Predict

In [11]:
pred_labels = []

pbar = tqdm(total = len(all_test_data))
for test_data in all_test_data:
    pred_bag = set()
    
    if len(test_data):
        texts, phrases = list(zip(*test_data))
        mlm_pred = []
        for p_id in range(0, len(texts), PREDICT_BATCH):
            batch_texts = texts[p_id:p_id+PREDICT_BATCH]
            batch_pred = mlm(list(batch_texts), targets=[f' {DATASET_SYMBOL}', f' {NONDATA_SYMBOL}'])
            
            if len(batch_texts) == 1:
                batch_pred = [batch_pred]
            
            mlm_pred.extend(batch_pred)
        
        for (result1, result2), phrase in zip(mlm_pred, phrases):
            if (result1['score'] > result2['score']*1.5 and result1['token_str'] == DATASET_SYMBOL) or\
               (result2['score'] > result1['score']*1.5 and result2['token_str'] == NONDATA_SYMBOL):
                pred_bag.add(clean_text(phrase))
    
    # filter labels by jaccard score 
    filtered_labels = []
    
    for label in sorted(pred_bag, key=len, reverse=True):
        if len(filtered_labels) == 0 or all(jaccard_similarity(label, got_label) < 0.75 for got_label in filtered_labels):
            filtered_labels.append(label)
            
    pred_labels.append('|'.join(filtered_labels))
    pbar.update(1)

submission['PredictionString'] = pred_labels
submission.to_csv('submission.csv', index=False)

100%|██████████| 4/4 [02:04<00:00, 30.26s/it]

In [12]:
submission['PredictionString'].head()

0    cohorts for heart and aging research in genomi...
1    center for education statistics institute of e...
2    dataset data management in arcgis|s office for...
3                           food access research atlas
Name: PredictionString, dtype: object