# Installation

In [None]:
!pip install pytorch-crf

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
import random
import re
import json
import gc
from tqdm import tqdm

from transformers import AutoModelForTokenClassification, AutoTokenizer, AutoConfig
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchcrf import CRF

In [None]:
# Checkpoint name used for the tokenizer, the config and (later) the model.
MODEL_TYPE = 'roberta-base'
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): add_prefix_space=True presumably makes a label tokenized on its
# own produce the same ids as the same words inside a sentence, which the
# sub-sequence matching in label_text relies on -- confirm against the
# RoBERTa tokenizer documentation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_TYPE, add_prefix_space=True)
# NOTE(review): add_prefix_space is a tokenizer option; its effect on AutoConfig
# is unclear from this file -- confirm it is needed here.
config = AutoConfig.from_pretrained(MODEL_TYPE, add_prefix_space=True)
# Both the tokenizer files and the model config are written to ./tokenizer.
tokenizer.save_pretrained("./tokenizer/")
config.save_pretrained('./tokenizer')

# Preprocessing the data

In [None]:
# Load the artefacts written by the pre-processing notebook.
# `df` is later read through its 'chunks', 'labels' and 'ids' columns.
df = pd.read_csv('../input/k/lichena/k/lichena/coleridge-pre-processing/data.csv')
# Distinct values of the `labels` column, as a plain Python list.
labels_list = pd.read_csv('../input/k/lichena/k/lichena/coleridge-pre-processing/labels.csv').labels.unique().tolist()
# Both dicts are replaced by the JSON contents below; the empty defaults only
# matter if loading fails part-way.
label_freq = {}
y_true = {}
# label -> frequency; used for the held-out label split and for sample weights.
with open('../input/k/lichena/k/lichena/coleridge-pre-processing/frequencies.json', 'r') as f:
    label_freq = json.load(f)

# NOTE(review): y_true is loaded but not referenced in the visible part of this
# notebook -- confirm whether it is still needed.
with open('../input/k/lichena/k/lichena/coleridge-pre-processing/y_true.json', 'r') as f:
    y_true = json.load(f)

In [None]:
# {k: v for k, v in sorted(label_freq.items(), key=lambda item: item[1], reverse=True)}

In [None]:
def clean_text(txt):
    """Return `txt` with every run of characters outside [A-Za-z0-9()] replaced by one space.

    The input is converted with ``str()`` first, so non-string values (for
    example ``None`` or NaN) are cleaned as their string representation.
    Leading and trailing spaces are removed and whitespace runs are collapsed.
    """
    # Raw strings: '\s' in a plain literal is an invalid escape sequence.
    text = re.sub(r'[^A-Za-z0-9()]+', ' ', str(txt)).strip()
    return re.sub(r'\s+', ' ', text)

def clean_label_result(label):
    """Normalise a label or prediction for scoring.

    Every run of non-alphanumeric characters becomes a single space, then the
    result is lower-cased and stripped of surrounding spaces.
    """
    alnum_only = re.sub('[^A-Za-z0-9]+', ' ', str(label))
    lowered = alnum_only.lower()
    return lowered.strip()

In [None]:
# Tokenize every label and build a map of label -> tensor of token ids.
def tokenize_labels(labels):
    """Map each label string to a tensor of its token ids.

    The first and last ids produced by the tokenizer are sliced off
    (``input_ids[1:-1]``), leaving only the ids of the label text itself.
    A label that occurs more than once simply overwrites its own entry.
    """
    return {
        name: torch.tensor(tokenizer(name, max_length=512).input_ids[1:-1])
        for name in labels
    }

# Labels are cleaned with clean_text before tokenization.
cleaned_labels = list(map(clean_text, labels_list))
label_set = tokenize_labels(cleaned_labels)

In [None]:
# Split labels into train/test groups: hold out up to 150 distinct labels,
# each with a known frequency below 300, for the test set.
# Sampling is done WITHOUT replacement. Drawing one label at a time with
# replacement could pick the same label repeatedly (duplicates in
# test_label_set, inflated num_tests) and would never terminate when no label
# is eligible.
eligible_labels = [
    label for label in label_set
    if label in label_freq and label_freq[label] < 300
]
test_label_set = random.sample(eligible_labels, min(150, len(eligible_labels)))
# Total frequency of the held-out labels.
num_tests = sum(label_freq[label] for label in test_label_set)
# Kept for backward compatibility: number of held-out labels actually drawn.
x = len(test_label_set)

In [None]:
# Persist the held-out labels so the random split can be inspected or reused.
pd.Series(test_label_set).to_csv('test_label_set.csv', index=False)

In [None]:
def generate_dataset(df):
    """Split the rows of `df` into train/test positives and negatives.

    A row with labels (a "|"-separated string in the 'labels' column) is a
    positive. It goes to the test split when any of its labels is in the
    module-level `test_label_set`, otherwise to the train split. A row whose
    'labels' value is NaN is a negative and is assigned to train or test by a
    fair coin flip.

    Returns a 10-tuple of lists, in this order:
    train_tp_texts, train_tp_labels, train_fp_texts, train_tp_ids,
    train_fp_ids, test_tp_texts, test_tp_labels, test_fp_texts, test_tp_ids,
    test_fp_ids.
    """
    # Build the lookup once: set membership instead of scanning a list per label.
    held_out = set(test_label_set)
    train_tp_texts = []
    train_tp_labels = []
    train_fp_texts = []
    train_tp_ids = []
    train_fp_ids = []
    test_tp_texts = []
    test_tp_labels = []
    test_fp_texts = []
    test_tp_ids = []
    test_fp_ids = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        chunk, labels = row['chunks'], row['labels']
        if labels == labels:  # NaN is the only value that differs from itself
            row_labels = labels.split("|")
            if held_out.isdisjoint(row_labels):
                train_tp_texts.append(chunk)
                train_tp_labels.append(row_labels)
                train_tp_ids.append(row['ids'])
            else:
                test_tp_texts.append(chunk)
                test_tp_labels.append(row_labels)
                test_tp_ids.append(row['ids'])
        elif random.random() > 0.5:
            train_fp_texts.append(chunk)
            train_fp_ids.append(row['ids'])
        else:
            test_fp_texts.append(chunk)
            test_fp_ids.append(row['ids'])
    return train_tp_texts, train_tp_labels, train_fp_texts, train_tp_ids, train_fp_ids, test_tp_texts, test_tp_labels, test_fp_texts, test_tp_ids, test_fp_ids
train_tp_texts, train_tp_labels, train_fp_texts, train_tp_ids, train_fp_ids, test_tp_texts, test_tp_labels, test_fp_texts, test_tp_ids, test_fp_ids = generate_dataset(df)

In [None]:
def label_text(labels, parent_array):
    """Tag every occurrence of each label's token sequence inside `parent_array`.

    Parameters
    ----------
    labels : iterable of 1-D tensors
        Token-id sequences to look for.
    parent_array : 1-D tensor
        Token ids of the text being labelled.

    Returns
    -------
    A float tensor of ``len(parent_array)`` where the first token of each
    match is 1, the remaining tokens of the match are 2 and everything else
    is 0. Later matches overwrite earlier ones where they overlap.
    """
    total = len(parent_array)
    idxs = torch.zeros(total)
    for sub_array in labels:
        sub_len = len(sub_array)
        # An empty label has nothing to match (indexing its first token would
        # raise); a label longer than the text cannot occur in it.
        if sub_len == 0 or sub_len > total:
            continue
        # Every contiguous window of sub_len tokens, compared in one shot.
        windows = parent_array.unfold(0, sub_len, 1)
        hits = (windows == sub_array).all(dim=1)
        for start in torch.nonzero(hits).flatten().tolist():
            idxs[start] = 1
            idxs[start + 1:start + sub_len] = 2
    return idxs

In [None]:
class ColeridgeDataset(Dataset):
    """Dataset mixing labelled chunks (positives) with randomly drawn unlabelled chunks.

    Indices ``0 .. len(tp_texts) - 1`` return the positives in order. Every
    index beyond that returns a negative drawn uniformly at random from
    `fp_texts`, so the same index can yield different negatives on each access.
    The total length is ``int(len(tp_texts) / tp_frac)``.
    """

    def __init__(self, tp_texts, tp_labels, fp_texts, tp_ids, fp_ids, mode='train', tp_frac=0.5):
        self.tp_texts = tp_texts      # chunks that contain at least one label
        self.tp_labels = tp_labels    # list of label lists, parallel to tp_texts
        self.fp_texts = fp_texts      # chunks without any label
        self.tp_frac = tp_frac        # fraction of the dataset made of positives
        self.tp_ids = tp_ids          # ids parallel to tp_texts
        self.fp_ids = fp_ids          # ids parallel to fp_texts
        self.mode = mode              # stored only; not read by this class

    def __len__(self):
        return int(1/self.tp_frac * len(self.tp_texts))

    def __getitem__(self, idx):
        """Return ``(text, labels, id, weight)`` for position `idx`.

        Positives are weighted by the mean inverse frequency of their labels
        (looked up in the module-level `label_freq`); negatives get an empty
        label list and a fixed weight of 0.1.
        """
        labels = []
        if idx < len(self.tp_texts):
            text = self.tp_texts[idx]
            labels = self.tp_labels[idx]
            _id = self.tp_ids[idx]
            weight = [1 / label_freq[l] for l in labels]
            weight = sum(weight) / len(weight)
        else:
            index = random.randint(0, len(self.fp_texts)-1)
            text = self.fp_texts[index]
            # Use the sampled position for the id as well: indexing fp_ids with
            # `idx` returned the id of a different chunk or ran out of range.
            _id = self.fp_ids[index]
            weight = 0.1
        return text, labels, _id, weight

def collate_fn(batch):
    """Turn a list of ``(text, labels, id, weight)`` items into model inputs.

    Texts are tokenized together (padded, truncated to 512 tokens). For each
    item that has labels, the label token sequences are located in the item's
    token ids with `label_text`, giving per-token tags 0/1/2; items without
    labels keep all-zero tags.

    Returns ``(input_ids, attention_mask, tags as LongTensor, ids, weights)``.
    """
    texts, ids, weights = [], [], []
    for sample in batch:
        texts.append(sample[0])
        ids.append(sample[2])
        weights.append(sample[3])

    encoding = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=512)
    input_ids = encoding.input_ids
    n_rows, n_cols = input_ids.shape[0], input_ids.shape[1]
    tags = torch.zeros(n_rows, n_cols)

    for row, sample in enumerate(batch):
        titles = sample[1]
        if not titles:
            continue
        # The literal 'labels' (a stray header value) is never looked up.
        title_tokens = [label_set[title] for title in titles if title != 'labels']
        tags[row] = label_text(title_tokens, input_ids[row])

    return input_ids, encoding.attention_mask, tags.type(torch.LongTensor), ids, torch.tensor(weights)

In [None]:
# With tp_frac=0.8 the dataset length is int(len(train_tp_texts) / 0.8), so the
# indices beyond the positives are filled with randomly sampled negatives.
train_dataset = ColeridgeDataset(train_tp_texts, train_tp_labels, train_fp_texts, train_tp_ids, train_fp_ids, mode='train', tp_frac=0.8)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

# With tp_frac=0.2 the test set is five times the number of positives; the
# extra indices are randomly sampled negatives.
# NOTE(review): shuffle=True is also set for the evaluation loader -- confirm
# that this is intended.
test_dataset = ColeridgeDataset(test_tp_texts, test_tp_labels, test_fp_texts, test_tp_ids, test_fp_ids, mode='test', tp_frac=0.2)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

In [None]:
# Smoke test: pull one batch through the loader and print the tag row of its
# first item, then stop.
# NOTE(review): `x` is assigned here but not used by this cell.
x = 0
for batch in train_dataloader:
    input_ids, attention_mask, labels, ids, weight = batch
    print(labels[0])
    break

# Importing and defining the model

In [None]:
# Token classifier with three output classes, matching the 0/1/2 tags
# produced by label_text.
model = AutoModelForTokenClassification.from_pretrained(MODEL_TYPE, num_labels=3)
# CRF layer over the same three tags; batch_first=True means its inputs are
# laid out as (batch, sequence, tags).
crf = CRF(3, batch_first=True)
# To resume from a saved checkpoint instead, uncomment the lines below.
# checkpoint = torch.load('../input/coleridge-ner/checkpoint.pt', map_location=DEVICE)
# model = checkpoint['model']
# crf = checkpoint['crf']

# Training the model

In [None]:
from nltk.corpus import stopwords

In [None]:
# English stop words plus a few special-token spellings; words in this list
# are trimmed from both ends of a predicted title.
stops = stopwords.words('english')
stops.extend(('[sep]', '[PAD]', 'PAD', 'pad'))

def clean_front(split):
    """Drop leading stop words from a list of words.

    Returns the list starting at the first word that is not in `stops`, or
    None when every word is a stop word (or the list is empty).
    """
    first_kept = next(
        (pos for pos, word in enumerate(split) if word not in stops),
        None,
    )
    if first_kept is None:
        return None
    return split[first_kept:]
        
def clean_back(split):
    """Drop trailing stop words from a list of words.

    Returns the list up to and including the last word that is not in
    `stops`, or None when every word is a stop word (or the list is empty).
    """
    end = len(split)
    while end > 0:
        if split[end - 1] not in stops:
            return split[:end]
        end -= 1
    return None
        
def clean_result(title):
    """Tidy a decoded title; return None when nothing usable is left.

    Steps, in order: remove the '<s>' and '</s>' markers, split on whitespace,
    drop the last word if it has at most two characters, trim stop words from
    the front and then from the back, and join the rest with single spaces.
    """
    words = title.replace('<s>', '').replace('</s>', '').split()
    if words and len(words[-1]) <= 2:
        words = words[:-1]
    if words:
        words = clean_front(words)
    if words:
        words = clean_back(words)
    if not words:
        return None
    return ' '.join(words)

In [None]:
def compute_fbeta(y_true,
                  y_pred,
                  beta = 0.5) -> tuple:
    """Compute the Jaccard-based micro FBeta score.

    Parameters
    ----------
    y_true : sequence of lists of str
        Ground-truth strings, one list per document.
    y_pred : sequence of lists of str
        Predicted strings, one list per document, parallel to `y_true`.
    beta : float
        Weight of recall relative to precision.

    Returns
    -------
    (fbeta_score, fp_list, fn_list, tp_list)
        The score plus the strings counted as false positives, false
        negatives and true positives. The score is 0.0 when there is nothing
        to score (no ground truth and no predictions).

    Each ground truth is matched, in sorted order, to the remaining prediction
    with the highest word-level Jaccard similarity; a match needs a similarity
    of at least 0.5 and consumes that prediction. Unmatched predictions are
    false positives.

    References
    ----------
    - https://www.kaggle.com/c/coleridgeinitiative-show-us-the-data/overview/evaluation
    """

    fp_list = []
    tp_list = []
    fn_list = []
    def _jaccard_similarity(str1: str, str2: str) -> float:
        a = set(str1.split())
        b = set(str2.split())
        c = a.intersection(b)
        union_size = len(a) + len(b) - len(c)
        # Two strings without any words have no overlap to speak of.
        if union_size == 0:
            return 0.0
        return float(len(c)) / union_size

    tp = 0  # true positive
    fp = 0  # false positive
    fn = 0  # false negative
    for ground_truth_list, predicted_string_list in zip(y_true, y_pred):
        # Work on a sorted copy so matched predictions can be removed.
        remaining = sorted(predicted_string_list)
        if len(ground_truth_list) == 0 and len(remaining) > 0:
            fp += len(remaining)
            fp_list += remaining
        else:
            for ground_truth in sorted(ground_truth_list):
                if len(remaining) == 0:
                    fn += 1
                    fn_list.append(ground_truth)
                else:
                    similarity_scores = [
                        _jaccard_similarity(ground_truth, predicted_string)
                        for predicted_string in remaining
                    ]
                    matched_idx = np.argmax(similarity_scores)
                    if similarity_scores[matched_idx] >= 0.5:
                        tp_list.append(remaining.pop(matched_idx))
                        tp += 1
                    else:
                        fn_list.append(ground_truth)
                        fn += 1
            fp += len(remaining)
            fp_list += remaining

    tp *= (1 + beta ** 2)
    fn *= beta ** 2
    denominator = tp + fp + fn
    # Nothing to score at all: report 0.0 instead of dividing by zero.
    fbeta_score = tp / denominator if denominator else 0.0
    print('fp: ',fp, '\ttp: ',tp, '\tfn: ', fn)
    return fbeta_score, fp_list, fn_list, tp_list

In [None]:
def lcs(X, Y):
    """Return the longest common substring of X and Y, or None if there is none.

    When several substrings share the maximum length, the one that ends first
    in X (and, for the same end in X, first in Y) is returned.
    """
    x_len, y_len = len(X), len(Y)

    # Length of the best run found so far and the position in X just past
    # its last character.
    best_len = 0
    best_end = 0

    # previous[j] is the length of the common suffix of X[:i-1] and Y[:j];
    # only the previous row of the table is needed to compute the next one.
    previous = [0] * (y_len + 1)
    for i in range(1, x_len + 1):
        current = [0] * (y_len + 1)
        for j in range(1, y_len + 1):
            if X[i - 1] == Y[j - 1]:
                run = previous[j - 1] + 1
                current[j] = run
                # Strict comparison keeps the earliest longest run.
                if run > best_len:
                    best_len = run
                    best_end = i
        previous = current

    # No character in common at all.
    if best_len == 0:
        return None

    return ''.join(X[best_end - best_len:best_end])

In [None]:
def score_results(titles, scores, ids, threshold = 0.95):
    """Score predicted titles against the held-out test labels.

    Parameters
    ----------
    titles, scores, ids : parallel lists
        Decoded title, its confidence and the id of the document it came from.
    threshold : float
        Predictions with a score at or below this value are discarded.

    Returns
    -------
    (fbeta, fp_list, fn_list, tp_list) as returned by `compute_fbeta`.

    Only cleaned predictions of at least two words are kept. Within one
    document a prediction that contains an existing one replaces it, and a
    prediction contained in an existing one is dropped.
    """
    # Every test document gets an entry, in order of first appearance, so that
    # documents without predictions are still scored.
    all_ids = pd.Series(test_tp_ids + test_fp_ids).unique()
    results = {_id: [] for _id in all_ids}
    actual = {}
    for idx, _id in enumerate(test_tp_ids):
        actual.setdefault(_id, []).extend(
            clean_label_result(l) for l in test_tp_labels[idx]
        )
    for idx, _id in enumerate(tqdm(ids)):
        title = titles[idx]
        if len(title) <= 2 or scores[idx] <= threshold:
            continue
        cleaned = clean_result(clean_label_result(title))
        if not cleaned or ' ' not in cleaned or cleaned in results[_id]:
            continue
        is_new = True
        merged = []
        for element in results[_id]:
            if element in cleaned:
                # The new prediction is a longer version of this one.
                is_new = False
                element = cleaned
            elif cleaned in element:
                # Already covered by a longer prediction.
                is_new = False
            # When the new prediction absorbs several shorter ones, keep it
            # once: duplicates would be counted as extra false positives.
            if element not in merged:
                merged.append(element)
        if is_new:
            merged.append(cleaned)
        results[_id] = merged
    y_pred = []
    target = []
    for _id in all_ids:
        y_pred.append(results[_id])
        target.append(actual.get(_id, []))
    fbeta, fp_list, fn_list, tp_list = compute_fbeta(target, y_pred)
    return fbeta, fp_list, fn_list, tp_list

In [None]:
def get_titles(input_ids, pred, score, _id):
    """Decode every predicted span of a batch into a title string.

    Parameters
    ----------
    input_ids : 2-D tensor of token ids, one row per batch item.
    pred : 2-D tensor of tags with the same layout; 0 = outside, 1 = first
        token of a span, any other positive tag = continuation.
        # assumes tags are never negative -- TODO confirm against the decoder
    score : 2-D tensor of per-token confidences with the same layout.
    _id : sequence of document ids, one per batch item.

    Returns
    -------
    (titles, scores, ids): parallel lists holding, for each span, the decoded
    text, the mean confidence of its tokens and the id of its document.

    A span ends at the next tag 0, at the next tag 1 (which opens a new span)
    or at the end of the row.
    """
    titles = []
    scores = []
    ids = []

    def flush(doc_id, toks, tok_scores):
        # Record one finished span.
        titles.append(tokenizer.decode(toks).strip())
        scores.append((sum(tok_scores) / len(tok_scores)).item())
        ids.append(doc_id)

    for idx in range(input_ids.shape[0]):
        tmp_toks = []
        tmp_scores = []
        for row in range(input_ids.shape[1]):
            tag = pred[idx][row]
            if tag > 0:
                # A "first token" tag while a span is open closes that span.
                if tmp_toks and tag == 1:
                    flush(_id[idx], tmp_toks, tmp_scores)
                    tmp_toks = []
                    tmp_scores = []
                tmp_toks.append(input_ids[idx][row])
                tmp_scores.append(score[idx][row])
            elif tmp_toks:
                flush(_id[idx], tmp_toks, tmp_scores)
                tmp_toks = []
                tmp_scores = []
        # A span that runs up to the last position has no following tag 0 to
        # close it; without this it would be silently dropped.
        if tmp_toks:
            flush(_id[idx], tmp_toks, tmp_scores)
    return titles, scores, ids

In [None]:
def train_fn(dataloader, model, optimizer, scheduler):
    """Run one training epoch and return the mean per-batch loss.

    For every batch the model's logits are used as CRF emissions; the loss is
    the negative CRF log-likelihood of each sequence, multiplied by the
    per-sample weight and summed over the batch. The scheduler, when given,
    is stepped after every optimizer step. The loss is printed every 50 steps.
    """
    gc.collect()
    model.train()
    running_loss = 0
    for step, (input_ids, attention_mask, labels, ids, weight) in enumerate(tqdm(dataloader)):
        gc.collect()

        # Move the batch to the training device; the CRF takes a uint8 mask.
        labels = labels.to(DEVICE)
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.type(torch.uint8).to(DEVICE)
        weight = weight.to(DEVICE)

        emissions = model(input_ids, attention_mask=attention_mask).logits
        log_likelihood = crf(emissions, labels, mask=attention_mask, reduction='none')
        loss = (-log_likelihood * weight).sum()

        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()
        optimizer.zero_grad()

        batch_loss = loss.detach().item()
        if step % 50 == 0:
            print("Step = ",step,"loss = ",batch_loss)
        running_loss += batch_loss
    return running_loss / len(dataloader)

In [None]:
def eval_fn(dataloader, model):
    """Predict titles for every batch of `dataloader`.

    Returns ``(titles, scores, ids)``: parallel lists with the decoded title,
    its mean token confidence and the id of the document it was found in.
    """
    gc.collect()
    model.eval()
    titles = []
    scores = []
    ids = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            gc.collect()
            input_ids, attention_mask, _labels, _id, _weight = batch
            input_ids = input_ids.to(DEVICE)
            attention_mask = attention_mask.type(torch.uint8).to(DEVICE)
            emissions = model(input_ids, attention_mask=attention_mask).logits
            # Decode with the attention mask, as in training, so padding
            # positions are never tagged as part of a title. Masked decoding
            # returns one tag list per sequence, cut at its real length; pad
            # each with the "outside" tag 0 to get a rectangular tensor again.
            seq_len = emissions.shape[1]
            decoded = crf.decode(emissions, mask=attention_mask)
            pred = torch.tensor(
                [tags + [0] * (seq_len - len(tags)) for tags in decoded]
            )
            # Confidence of a token = highest softmax probability over the tags.
            score = torch.softmax(emissions, dim=-1).max(dim=-1).values
            _titles, _scores, _ids = get_titles(input_ids, pred, score, _id)
            titles += _titles
            scores += _scores
            ids += _ids
    return titles, scores, ids

In [None]:
from transformers import get_scheduler

NUM_EPOCHS = 3
model.to(DEVICE)
crf.to(DEVICE)
# The CRF has its own parameters, so both modules are handed to one optimizer.
params = list(model.parameters()) + list(crf.parameters())
optimizer = torch.optim.AdamW(params, lr=3e-5, weight_decay=0.1)

# The scheduler is stepped once per batch, so the schedule spans every batch
# of every epoch.
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    # Warm up over 10% of ONE epoch's batches (not of the whole run).
    # NOTE(review): this value is a float -- confirm get_scheduler accepts it.
    num_warmup_steps=0.1 * len(train_dataloader),
    num_training_steps=num_training_steps
)

In [None]:
# Best score seen so far and the number of epochs since it last improved.
# Both must live OUTSIDE the epoch loop: re-initialising them every epoch made
# each epoch look like a new best and kept early stopping from ever firing.
best_fbeta = -1
flag = 0
for e in range(NUM_EPOCHS):
    try:
        train_loss = train_fn(train_dataloader, model, optimizer, lr_scheduler)
        gc.collect()
        titles, scores, ids = eval_fn(test_dataloader, model)
        gc.collect()
        # Pick the confidence threshold with the best score for this epoch;
        # on ties the higher threshold wins.
        thresholds = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        thresh_fbeta = 0
        thresh = 0
        fp_list = []
        fn_list = []
        tp_list = []
        for t in thresholds:
            fbeta, t_fp_list, t_fn_list, t_tp_list = score_results(titles, scores, ids, t)
            if fbeta >= thresh_fbeta:
                thresh_fbeta = fbeta
                thresh = t
                # Keep the error lists of the winning threshold, not of
                # whichever threshold happened to be evaluated last.
                fp_list, fn_list, tp_list = t_fp_list, t_fn_list, t_tp_list
        print("fbeta:\t" + str(thresh_fbeta))
        if thresh_fbeta >= best_fbeta:
            best_fbeta = thresh_fbeta
            flag = 0
            # NOTE(review): the whole model and CRF objects are pickled here;
            # loading the checkpoint needs the same class definitions available.
            checkpoint = {
                'fbeta': best_fbeta,
                'thresh': thresh,
                'model': model,
                'crf': crf,
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': e
            }
            torch.save(checkpoint, "./checkpoint.pt")
            pd.Series(fp_list).to_csv('fp_list.csv', index=False)
            pd.Series(fn_list).to_csv('fn_list.csv', index=False)
            pd.Series(tp_list).to_csv('tp_list.csv', index=False)
        else:
            # Stop after more than two epochs in a row without improvement.
            flag += 1
            if (flag > 2):
                break
    except Exception as exception:
        # Best effort: report the failure and save the current weights so the
        # run is not lost, then carry on with the next epoch.
        # NOTE(review): this overwrites ./checkpoint.pt, including a better
        # checkpoint from an earlier epoch -- consider a separate file name.
        print(exception)
        checkpoint = {
            'fbeta': 0,
            'thresh': 0,
            'model': model,
            'crf': crf,
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': e
        }
        torch.save(checkpoint, "./checkpoint.pt")