# Imports & Preamble

In [None]:
!pip install -qU --no-warn-conflicts transformers --no-index --find-links=file:///kaggle/input/coleridge-packages
!pip install -qU --no-warn-conflicts tokenizers --no-index --find-links=file:///kaggle/input/coleridge-packages
!pip install -qU --no-warn-conflicts datasets --no-index --find-links=file:///kaggle/input/coleridge-packages
!pip install -qU --no-warn-conflicts fsspec --no-index --find-links=file:///kaggle/input/coleridge-packages

In [None]:
from __future__ import annotations
import os
import pandas as pd
import json
import re
from typing import Iterable
from tqdm.notebook import tqdm

from transformers import (
    BigBirdTokenizerFast,
)
from datasets import (
    Dataset,
    Features,
    Sequence,
    Value,
    ClassLabel,
)

from coleridge_helpers import (
    clean_text,
    get_text_as_word_array,
    find_last_period_in_string_array,
    get_tags_for_snippet,
    get_snippets_from_paper,
)

# Load & Preprocess Data

In [None]:
# Root of the competition data; the training papers live in its "train/"
# sub-directory and the label metadata in "train.csv".
dataset_path = "../input/coleridgeinitiative-show-us-the-data/"
trainfiles_path = f"{dataset_path}train/"
train_metadata = pd.read_csv(f"{dataset_path}train.csv")

## Create Context Snippets

In [None]:
all_snippets = []

# Loop through all the files and create a collection of snippets (training examples).
# These snippets should provide as much context as possible. So we make them as close
# as we can to the maximum length BigBird will accept while keeping sections intact.

# os.listdir() returns entries in arbitrary order, so iterate in sorted order to
# make the resulting snippet order (and everything derived from it) reproducible.
for filename in tqdm(sorted(os.listdir(trainfiles_path))):
    snippets = get_snippets_from_paper(f"{trainfiles_path}{filename}")
    all_snippets.extend(snippets)


## Find Labels in Snippets and Create Corresponding Tags

In [None]:
# Create token and tag arrays for each snippet in a dataframe

# N.B. the tokens here are words and punctuation, not the subword tokens
# that will later be created by the BigBird Tokenizer

# We need to reduce the size of our dataset in order for a single
# training epoch to finish within Kaggle's 9hr limit
FRAC_OF_NEGATIVE_EXAMPLES_TO_KEEP = 0.05
# Fixed seed so the same negative snippets are kept on every run.
NEGATIVE_SAMPLING_SEED = 42

unique_labels = train_metadata["dataset_label"].unique()

# Compile one pattern per label up front rather than rebuilding it for every
# (snippet, label) pair.
# - re.escape() makes labels match literally, even when they contain regex
#   metacharacters such as parentheses.
# - The lookarounds require that the label is not glued to a neighbouring word
#   character. Unlike \b they still work when the label itself starts or ends
#   with punctuation.
label_patterns = {
    label: re.compile(rf"(?<!\w){re.escape(label)}(?!\w)")
    for label in unique_labels
}

rows = []
rows_for_snippets_without_datasets = []

for snippet in tqdm(all_snippets):

    tokens = get_text_as_word_array(snippet)

    # Every known dataset label that occurs verbatim in this snippet.
    found_labels = {
        label
        for label, pattern in label_patterns.items()
        if pattern.search(snippet)
    }

    # One tag per token, derived from the labels found in the snippet.
    tags = get_tags_for_snippet(tokens, found_labels)

    row = {"tokens": tokens, "tags": tags}

    # Keep snippets with and without dataset mentions apart so that only the
    # negatives are down-sampled below.
    if found_labels:
        rows.append(row)
    else:
        rows_for_snippets_without_datasets.append(row)

snippets_negative_samples = pd.DataFrame(rows_for_snippets_without_datasets)
snippets_positive_samples = pd.DataFrame(rows)
snippets_df = pd.concat(
    [
        snippets_positive_samples,
        snippets_negative_samples.sample(
            frac=FRAC_OF_NEGATIVE_EXAMPLES_TO_KEEP,
            random_state=NEGATIVE_SAMPLING_SEED,
        ),
    ],
    ignore_index=True,
)

## Create a HuggingFace Dataset

In [None]:
# Integer id used for each tag string.
label2id = {"O": 0, "B": 1, "I": 2}


def _encode_tags(tags):
    """Return the integer ids for a list of tag strings."""
    return [label2id[tag] for tag in tags]


# Replace the string tags of every snippet with their integer ids.
snippets_df["tags"] = snippets_df["tags"].apply(_encode_tags)

# Schema of the dataset: a list of word strings plus one class label per word.
token_feature = Sequence(Value("string"))
tag_feature = Sequence(ClassLabel(names=["O", "B", "I"]))
features = Features({"tokens": token_feature, "tags": tag_feature})

dataset = Dataset.from_pandas(snippets_df, features=features)

# Tokenizer loaded from the local copy of the BigBird model files.
tokenizer = BigBirdTokenizerFast.from_pretrained(
    "../input/huggingfacebigbirdrobertabase"
)

def tokenize_and_align_labels(examples):
    """Tokenize a batch of word-split examples and align the tags to the tokens.

    Uses the module-level ``tokenizer``.

    Args:
        examples: Batch with a "tokens" entry (one list of words per example)
            and a "tags" entry (one list of per-word tag ids per example).

    Returns:
        The tokenizer output for the batch, extended with a "labels" entry
        that holds one label per produced token for every example.
    """
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )

    labels = []
    for i, word_tags in enumerate(examples["tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        # Special tokens have a word id that is None; they get -100 so they
        # are automatically ignored in the loss function. Every other token
        # (first or subsequent piece of a word) takes the tag of its word.
        labels.append(
            [
                -100 if word_idx is None else word_tags[word_idx]
                for word_idx in word_ids
            ]
        )

    tokenized_inputs["labels"] = labels
    return tokenized_inputs


# Tokenize the whole dataset batch-wise, then shuffle it with a fixed seed.
tokenized_dataset = dataset.map(
    tokenize_and_align_labels,
    batched=True,
).shuffle(seed=42)

In [None]:
# Write the tokenized dataset to "tokenized_dataset.json".
# We need to limit the batch size to stop the kernel running out of memory.
tokenized_dataset.to_json("tokenized_dataset.json", batch_size=64)