In [None]:
!pip install ../input/pytorchcrf/pytorch_crf-0.7.2-py3-none-any.whl

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
import random
import re
import json
from tqdm import tqdm

from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchcrf import CRF

from nltk.corpus import stopwords

In [None]:
# Directory holding the test publications (one file per publication).
test_files_path = '../input/coleridgeinitiative-show-us-the-data/test'
# Path of every directory entry, in the order os.listdir reports them.
files = [f'{test_files_path}/{entry}' for entry in os.listdir(test_files_path)]

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Tokenizer files are read from the attached dataset directory.
tokenizer = AutoTokenizer.from_pretrained(
    '../input/k/lichena/coleridge-ner/tokenizer/',
    add_prefix_space=True,
)

In [None]:
def clean_text(txt):
    """Replace each run of characters other than ASCII letters, digits and
    parentheses with a single space.

    `txt` is passed through `str()` first, so non-string values are accepted.
    Letter case is preserved.
    """
    disallowed_run = re.compile('[^A-Za-z()0-9]+')
    return disallowed_run.sub(' ', str(txt))
def clean_label(txt):
    """Replace each run of characters other than ASCII letters and digits with
    a single space, then lower-case the result.

    `txt` is passed through `str()` first, so non-string values are accepted.
    Leading/trailing spaces produced by the substitution are not stripped.
    """
    spaced = re.sub('[^A-Za-z0-9]+', ' ', str(txt))
    return spaced.lower()

In [None]:
df = pd.read_csv('../input/coleridgeinitiative-show-us-the-data/train.csv')

# One '|'-joined string of dataset labels per publication Id.
data = df.groupby('Id')['dataset_label'].apply('|'.join)
id_input, labels_input = data.index.values, data.values

# Distinct training labels, each normalised with clean_text (case is kept).
label_set = [clean_text(label) for label in df['dataset_label'].unique().tolist()]

In [None]:
def chunk_text(full_text, length=250, overlap=25):
    """Split text into overlapping chunks of whitespace-separated words.

    Consecutive chunks start `length - overlap` words apart, so each chunk
    repeats the last `overlap` words of the previous one. The final chunk may
    be shorter than `length`.

    Args:
        full_text: Text to split; words are separated with `str.split()`.
        length: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of chunks, each a single-space-joined string. Empty when
        `full_text` contains no words.

    Raises:
        ValueError: If `overlap >= length`. The start index would then never
            advance, which previously made the loop run forever.
    """
    step = length - overlap
    if step <= 0:
        raise ValueError(
            'length must be greater than overlap (got length=%r, overlap=%r)'
            % (length, overlap)
        )
    words = full_text.split()
    return [
        ' '.join(words[start:start + length])
        for start in range(0, len(words), step)
    ]

In [None]:
def encode(filename, min_length=5):
    """Read one publication JSON file and split its text into chunks.

    The file must contain a list of sections, each providing 'section_title'
    and 'text'. All sections are concatenated, normalised with `clean_text`,
    scanned for known labels from `label_set`, and split with `chunk_text`.

    Args:
        filename: Path of the JSON file. Its base name without the extension
            is used as the document id.
        min_length: Not used by this function; kept so existing calls that
            pass it keep working.

    Returns:
        A tuple `(txts, ids, doc_id, naive_results)`:
            txts: list of text chunks for the document.
            ids: list with `doc_id` repeated once per chunk.
            doc_id: the document id.
            naive_results: lower-cased labels from `label_set` that occur
                verbatim (case-sensitively) in the cleaned text, without
                duplicates.
    """
    # splitext removes whatever extension is present rather than assuming a
    # five-character '.json' suffix.
    doc_id = os.path.splitext(os.path.basename(filename))[0]

    # Explicit encoding so the result does not depend on the platform default.
    with open(filename, 'r', encoding='utf-8') as f:
        sections = json.load(f)

    full_text = clean_text(''.join(
        ' ' + section['section_title'] + ' ' + section['text']
        for section in sections
    ))

    naive_results = []
    for label in label_set:
        # Stored values are lower-cased, so the duplicate check must compare
        # the lower-cased form too; comparing the original-case label let
        # duplicates through.
        lowered = label.lower()
        if label in full_text and lowered not in naive_results:
            naive_results.append(lowered)

    txts = chunk_text(full_text)
    ids = [doc_id] * len(txts)
    return txts, ids, doc_id, naive_results

In [None]:
results = {}
texts, ids = [], []
for path in tqdm(files):
    chunk_texts, chunk_ids, doc_id, _naive_matches = encode(path)
    texts.extend(chunk_texts)
    ids.extend(chunk_ids)
    # Every document starts with an empty prediction list; the label matches
    # returned by encode are not used here.
    results[doc_id] = []

In [None]:
class ColeridgeDataset(Dataset):
    """Dataset pairing each text chunk with the id of its source document."""

    def __init__(self, texts, ids):
        # Parallel sequences: ids[i] is the document id for texts[i].
        self.texts, self.ids = texts, ids

    def __len__(self):
        """Return the number of text chunks."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the `(text, id)` pair stored at position `idx`."""
        text = self.texts[idx]
        doc_id = self.ids[idx]
        return text, doc_id

def collate_fn(batch):
    """Tokenize a batch of `(text, id)` samples.

    Texts are tokenized together with padding and truncation to at most 512
    tokens, returned as PyTorch tensors.

    Returns:
        `(input_ids, attention_mask, ids)` where `ids` is the list of sample
        ids in batch order.
    """
    texts, ids = [], []
    for sample in batch:
        texts.append(sample[0])
        ids.append(sample[1])
    encoding = tokenizer(
        texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512,
    )
    return encoding.input_ids, encoding.attention_mask, ids

In [None]:
# Wrap the collected chunks and their ids; batches are tokenized in collate_fn.
dataset = ColeridgeDataset(texts=texts, ids=ids)
dataloader = DataLoader(
    dataset,
    batch_size=128,
    collate_fn=collate_fn,
)

In [None]:
def get_titles(input_ids, pred, score, _id):
    """Turn per-token tag predictions into decoded title strings.

    Within each row, consecutive tokens whose tag is > 0 form one span. A tag
    equal to 1 inside an open span closes it and starts a new one
    (presumably 1 is the "begin" tag and larger values are "inside" tags --
    confirm against the training label scheme). A tag that is not positive
    closes the open span.

    Fix over the previous version: a span still open when the last column is
    reached is now emitted. Before, the end-of-row check lived in a branch
    that was only evaluated for non-positive tags, so such spans were
    silently discarded.

    Args:
        input_ids: 2-D token ids, indexed as `input_ids[idx][row]`; must
            expose `.shape`.
        pred: Per-token tags with the same layout as `input_ids`. Assumed to
            be non-negative tag indices.
        score: Per-token scores with the same layout; elements must support
            `sum(...) / n` followed by `.item()` (e.g. 0-d tensors).
        _id: Sequence giving the document id of each row.

    Returns:
        `(titles, scores, ids)`, three parallel lists: the decoded and
        stripped span text, the mean token score of the span, and the
        document id of the row it came from.
    """
    titles = []
    scores = []
    ids = []

    def emit(sample_idx, toks, tok_scores):
        # Record one finished span.
        titles.append(tokenizer.decode(toks).strip())
        scores.append((sum(tok_scores) / len(tok_scores)).item())
        ids.append(_id[sample_idx])

    n_samples, n_tokens = input_ids.shape[0], input_ids.shape[1]
    for idx in range(n_samples):
        tmp_toks = []
        tmp_scores = []
        for row in range(n_tokens):
            tag = pred[idx][row]
            if tag > 0:
                if tmp_toks and tag == 1:
                    # A new span begins right after the current one.
                    emit(idx, tmp_toks, tmp_scores)
                    tmp_toks = []
                    tmp_scores = []
                tmp_toks.append(input_ids[idx][row])
                tmp_scores.append(score[idx][row])
            elif tmp_toks:
                # Non-positive tag: the open span ends here.
                emit(idx, tmp_toks, tmp_scores)
                tmp_toks = []
                tmp_scores = []
        if tmp_toks:
            # Span ran up to the final column; emit it instead of dropping it.
            emit(idx, tmp_toks, tmp_scores)
    return titles, scores, ids

In [None]:
def eval_fn(dataloader, model):
    """Run the model over every batch and collect the predicted titles.

    Uses the module-level `crf` to decode tag sequences from the model's
    logits and `get_titles` to turn them into strings. Gradients are disabled
    and the model is put in eval mode.

    Args:
        dataloader: Iterable yielding `(input_ids, attention_mask, ids)`.
        model: Callable returning an object with a `.logits` attribute.

    Returns:
        `(titles, scores, ids)` accumulated over all batches.
    """
    all_titles, all_scores, all_ids = [], [], []
    model.eval()
    with torch.no_grad():
        for input_ids, attention_mask, batch_ids in tqdm(dataloader):
            input_ids = input_ids.to(DEVICE)
            attention_mask = attention_mask.type(torch.uint8).to(DEVICE)
            emissions = model(input_ids, attention_mask=attention_mask).logits
            # Decoded without a mask, so every position (padding included)
            # receives a tag and the result is rectangular.
            pred = torch.tensor(crf.decode(emissions))
            # Per-token confidence: largest softmax value along what is
            # assumed to be the label axis (last axis of emissions, moved to
            # dim 1 by the permute) -- TODO confirm emissions layout.
            probabilities = torch.softmax(emissions.permute(0, 2, 1), dim=1)
            score = torch.max(probabilities, dim=1).values
            batch_titles, batch_scores, batch_title_ids = get_titles(
                input_ids, pred, score, batch_ids
            )
            all_titles.extend(batch_titles)
            all_scores.extend(batch_scores)
            all_ids.extend(batch_title_ids)
    return all_titles, all_scores, all_ids

In [None]:
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source.
checkpoint = torch.load(
    '../input/k/lichena/coleridge-ner/checkpoint.pt',
    map_location=DEVICE,
)
model, crf = checkpoint['model'], checkpoint['crf']
# Move both the token classifier and the CRF layer onto the inference device.
for module in (model, crf):
    module.to(DEVICE)

In [None]:
# Run inference over every chunk.
# NOTE: this rebinds `ids` -- from here on it holds one document id per
# extracted title (as returned by eval_fn), no longer one id per text chunk.
titles, scores, ids = eval_fn(dataloader, model)

In [None]:
def lcs(X, Y):
    """Return the longest common contiguous substring of X and Y.

    Uses a dynamic-programming table in which cell (i, j) holds the length of
    the longest common suffix of X[:i] and Y[:j]. When several substrings
    share the maximum length, the one found first while scanning X from left
    to right (and Y from left to right within that) is returned.

    Returns:
        The substring as a `str`, or None when X and Y share no element.
    """
    rows, cols = len(X), len(Y)
    # Row 0 and column 0 stay zero and act as sentinels for the recurrence.
    suffix_len = [[0] * (cols + 1) for _ in range(rows + 1)]

    best_len = 0   # length of the longest match seen so far
    best_end = 0   # exclusive end index of that match within X

    for i in range(1, rows + 1):
        for j in range(1, cols + 1):
            if X[i - 1] != Y[j - 1]:
                continue  # cell keeps its initial 0
            run = suffix_len[i - 1][j - 1] + 1
            suffix_len[i][j] = run
            if run > best_len:
                best_len = run
                best_end = i

    if best_len == 0:
        return None

    # The match occupies the best_len elements of X ending just before best_end.
    return ''.join(X[best_end - best_len:best_end])

In [None]:
# English stop words plus special-token spellings that should never start or
# end a predicted title.
stops = stopwords.words('english')
stops.extend(['[sep]', '[PAD]', '[pad]', 'pad', 'PAD'])


def clean_front(split):
    """Drop leading tokens that appear in `stops`.

    Returns the slice starting at the first token not in `stops`, or None
    when there is no such token (including an empty list).
    """
    position = 0
    while position < len(split):
        if split[position] not in stops:
            return split[position:]
        position += 1
    return None
        
def clean_back(split):
    """Drop trailing tokens that appear in `stops`.

    Returns the slice ending at the last token not in `stops`, or None when
    there is no such token (including an empty list).
    """
    end = len(split)
    while end > 0:
        if split[end - 1] not in stops:
            return split[:end]
        end -= 1
    return None
        
def clean_result(title):
    """Tidy a predicted title.

    Steps, in order: remove the literal markers '<s>' and '</s>', split on
    whitespace, drop the final token if it has two characters or fewer, then
    trim leading and trailing stop tokens with `clean_front` / `clean_back`.

    Returns:
        The remaining tokens joined by single spaces, or None when nothing
        is left.
    """
    for marker in ('<s>', '</s>'):
        title = title.replace(marker, '')
    tokens = title.split()
    if tokens and len(tokens[-1]) <= 2:
        tokens = tokens[:-1]
    if tokens:
        tokens = clean_front(tokens)
    if tokens:
        tokens = clean_back(tokens)
    if not tokens:
        return None
    return ' '.join(tokens)

In [None]:
def _merge_title(existing, candidate):
    """Return `existing` with `candidate` merged in, free of duplicates.

    Rules (substring tests are plain `in` on the strings):
      * every kept title that is contained in `candidate` is superseded by
        it; `candidate` is written once, at the position of the first
        superseded title;
      * if `candidate` is contained in a longer kept title and supersedes
        nothing, the list is returned unchanged;
      * otherwise `candidate` is appended.

    The earlier inline loop overwrote *each* superseded title with
    `candidate`, which left the same title in the list several times.
    """
    merged = []
    placed = False   # candidate already written into `merged`
    covered = False  # candidate is part of a longer kept title
    for kept in existing:
        if kept in candidate:
            if not placed:
                merged.append(candidate)
                placed = True
        else:
            if candidate in kept:
                covered = True
            merged.append(kept)
    if not (placed or covered):
        merged.append(candidate)
    return merged


# Score cut-off stored with the checkpoint; assign a different value here to
# experiment with another threshold.
threshold = checkpoint['thresh']
print(checkpoint['thresh'])
print(checkpoint['epoch'])
for idx, _id in enumerate(tqdm(ids)):
    # setdefault creates the list for an unseen id *and* still processes the
    # current title; previously that title was skipped.
    doc_titles = results.setdefault(_id, [])
    title = clean_label(titles[idx])
    # Keep only multi-word titles scoring above the threshold.
    if title and len(title) > 2 and scores[idx] > threshold and ' ' in title:
        cleaned = clean_result(title)
        if cleaned:
            results[_id] = _merge_title(doc_titles, cleaned)

In [None]:
import csv

# Name of the submission file, written to the current working directory.
filename = "submission.csv"

# newline='' is what the csv module requires of files it writes to, so that
# it alone controls line endings (avoids doubled '\r' on some platforms).
with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
    csvwriter = csv.writer(csvfile)

    # Header row.
    csvwriter.writerow(["Id", "PredictionString"])

    # One row per document. '|'.join([]) is '', so a document without
    # predictions gets an empty PredictionString without a special case.
    csvwriter.writerows(
        [doc_id, '|'.join(found)] for doc_id, found in results.items()
    )

In [None]:
# results

In [None]:
# titles
# results
# run 51 was pretty good