This notebook runs inference on a fine-tuned Google BigBird model for the Coleridge Initiative Show US the Data competition.  

To fit data processing, training and inference within Kaggle's 9-hour timeout limit,  
I separated out the data preparation, here: https://www.kaggle.com/danieldorosz/show-us-the-data-bigbird-dataprep  
and the model fine-tuning, here: https://www.kaggle.com/danieldorosz/show-us-the-data-bigbird-fine-tuning  

A chunk of the logic is farmed out to a coleridge-helpers utility script (imported below as `coleridge_helpers`).   

The main intuition behind this effort was that I wanted to include as much context as possible in my training examples, 
and to keep related context together. The training data already provides a demarcation of context in the form of 
sections, so I created contextual 'snippets' as my training examples. Each snippet 
contains one or more sections, such that each training example gets as close as possible to BigBird's maximum of 4096
tokens without breaking up any sections. If a single section is longer than the training example limit, I break it up 
at the last period prior to the limit.  
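
As a rough illustration of the packing described above, here is a minimal sketch. It is not the code from the helper script, which may well differ. The names `pack_sections` and `split_long_section` are hypothetical, and for simplicity the sketch counts whitespace-separated words, whereas the real pipeline would count BigBird tokens.

```python
from typing import List


def split_long_section(words: List[str], limit: int) -> List[List[str]]:
    """Split an over-long section, cutting at the last period before the limit."""
    pieces = []
    while len(words) > limit:
        window = words[:limit]
        # Index of the last word in the window that ends a sentence;
        # fall back to a hard cut at the limit if there is no period at all.
        cut = max(
            (i for i, word in enumerate(window) if word.endswith(".")),
            default=limit - 1,
        )
        pieces.append(words[: cut + 1])
        words = words[cut + 1 :]
    if words:
        pieces.append(words)
    return pieces


def pack_sections(sections: List[str], limit: int = 4096) -> List[str]:
    """Greedily pack whole sections into snippets of at most `limit` words."""
    snippets = []
    current: List[str] = []
    for section in sections:
        # A section that fits within the limit comes back as a single piece,
        # so it is never broken up.
        for piece in split_long_section(section.split(), limit):
            if current and len(current) + len(piece) > limit:
                snippets.append(" ".join(current))
                current = []
            current.extend(piece)
    if current:
        snippets.append(" ".join(current))
    return snippets


example = ["a b c.", "d e.", "f g h. i j k l m"]
print(pack_sections(example, limit=6))
# ['a b c. d e.', 'f g h.', 'i j k l m']
```

In the toy example the first two sections share a snippet because together they fit within the limit, while the third section is too long and gets cut after its last period.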

The code is very much a rough-and-ready first draft, so please don't judge me ;-) There is much to improve that 
I didn't have time for. This mainly serves as a baseline to assess the score I could expect from this kind of approach.

I ran fine-tuning a couple of times using the last checkpoint from the first (timed-out) run as input to the next.

# Imports & Preamble

In [None]:
!pip install -qU --no-warn-conflicts transformers --no-index --find-links=file:///kaggle/input/coleridge-packages
!pip install -qU --no-warn-conflicts tokenizers --no-index --find-links=file:///kaggle/input/coleridge-packages

In [None]:
import os
import pandas as pd
from tqdm.notebook import tqdm
import torch
from transformers import (
    BigBirdForTokenClassification,
    BigBirdConfig,
    BigBirdTokenizerFast,
)

from coleridge_helpers import (
    clean_text,
    get_snippets_from_paper,
    find_datasets_by_literal_matching,
)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
dataset_path = "../input/coleridgeinitiative-show-us-the-data/"
testfiles_path = dataset_path + "test/"

In [None]:
# create a set of dataset titles we will use for literal string matching

train_metadata = pd.read_csv(dataset_path + "train.csv")

train_titles = set(train_metadata["dataset_title"].unique())
train_titles = {title.lower() for title in train_titles}

train_labels = set(train_metadata["dataset_label"].unique())
train_labels = {label.lower() for label in train_labels}

train_datasets = train_titles.union(train_labels)

extra_gov_datasets = set(pd.read_csv("../input/bigger-govt-dataset-list/data_set_800.csv")["title"].to_list())
extra_gov_datasets = {dataset for dataset in extra_gov_datasets if not dataset.startswith("blog |")}

all_datasets = extra_gov_datasets.union(train_datasets)

# Instantiate Pretrained BigBird Model & Tokenizer

In [None]:
# BigBird roberta-base
model_class, tokenizer_class, pretrained_weights = (BigBirdForTokenClassification, BigBirdTokenizerFast, '../input/huggingfacebigbirdrobertabase')

tokenizer = tokenizer_class.from_pretrained(pretrained_weights)

label_list = ["O", "B", "I"]
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}

def get_pretrained_model(checkpoint=pretrained_weights):
    config = BigBirdConfig(attention_type="block_sparse", gradient_checkpointing=True, num_labels=3, id2label=id2label, label2id=label2id)
    return model_class.from_pretrained(checkpoint, config=config)

# Make Predictions

In [None]:
model = get_pretrained_model("../input/coleridgemodelcheckpoint")
model = model.to(device)
model.eval()

with torch.no_grad():

    rows = []
    for filename in tqdm(os.listdir(testfiles_path)):
        filepath = f"{testfiles_path}{filename}"
        
        # do string match first
        found_labels = find_datasets_by_literal_matching(filepath, all_datasets)
        
        snippets = get_snippets_from_paper(filepath)
        for snippet in snippets:
            encoded_snippet = tokenizer(snippet, truncation=True, return_tensors="pt")[
                "input_ids"
            ]
            encoded_snippet = encoded_snippet.to(device)
            model_outputs = model(encoded_snippet)
            
            snippet_preds = model_outputs["logits"][0]
            token_ids = encoded_snippet.squeeze()

            # Remove ignored index (special tokens)
            cleaned_snippet_preds = [
                token_preds
                for token_preds, token_id in zip(snippet_preds, token_ids)
                if token_id not in tokenizer.all_special_ids
            ]

            predicted_tags = [
                label_list[token_preds.argmax(0)]
                for token_preds in cleaned_snippet_preds
            ]

            tokenized_snippet = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=True)

            label_tokens = []
            for token, tag in zip(tokenized_snippet, predicted_tags):
                if tag == "B":
                    label_tokens.append(token)
                elif tag == "I" and len(label_tokens) > 0:
                    label_tokens.append(token)
                else:
                    if len(label_tokens) > 0:
                        found_label = tokenizer.convert_tokens_to_string(label_tokens)
                        found_labels.add(clean_text(found_label))
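                        label_tokens = []

            # flush a label that runs to the very end of the snippet
            if len(label_tokens) > 0:
                found_label = tokenizer.convert_tokens_to_string(label_tokens)
                found_labels.add(clean_text(found_label))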

        prediction_string = "|".join(sorted(found_labels))

        # the paper Id is the filename without its ".json" extension
        rows.append({"Id": os.path.splitext(filename)[0], "PredictionString": prediction_string})
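
The tag-to-label step in the loop above can be looked at in isolation. Below is a minimal standalone sketch of B/I/O span decoding on plain Python lists, with no model or tokenizer involved. The function name `decode_bio_spans` is hypothetical, and the sketch makes one choice of its own: a "B" tag that arrives while a span is still open closes that span and starts a new one.

```python
from typing import List


def decode_bio_spans(tokens: List[str], tags: List[str]) -> List[List[str]]:
    """Group tokens into spans according to their B/I/O tags."""
    spans = []
    current: List[str] = []
    for token, tag in zip(tokens, tags):
        if tag == "B":
            # A new span begins; close any span that is still open.
            if current:
                spans.append(current)
            current = [token]
        elif tag == "I" and current:
            # Continue the open span.
            current.append(token)
        else:
            # "O", or a stray "I" with no open span: close whatever is open.
            if current:
                spans.append(current)
            current = []
    # Close a span that runs to the end of the input.
    if current:
        spans.append(current)
    return spans


tokens = ["the", "adni", "data", "and", "nhanes", "survey"]
tags = ["O", "B", "I", "O", "B", "I"]
print(decode_bio_spans(tokens, tags))
# [['adni', 'data'], ['nhanes', 'survey']]
```

Each returned span could then be turned back into text with `tokenizer.convert_tokens_to_string` and passed through `clean_text`, as the notebook does.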

# Make Submission

In [None]:
submission_df = pd.DataFrame(rows)
submission_df.to_csv("submission.csv", index=False)