In [None]:
# Reload edited local modules automatically before running each cell
%reload_ext autoreload
%autoreload 2

In [1]:
# Make the parent of the working directory importable, so that the local
# `src` package can be found
import sys
from pathlib import Path
root_path = str(Path.cwd().parent)
if sys.path.count(root_path) == 0:
    sys.path.append(root_path)

# Import external libraries
from typing import Tuple
from scandeval import load_dataset
from tqdm.auto import tqdm
import pandas as pd
import torch
from datasets import Dataset
from transformers import (AutoModelForTokenClassification, 
                          AutoConfig,
                          AutoTokenizer,
                          DataCollatorForTokenClassification,
                          TrainingArguments,
                          Trainer,
                          EarlyStoppingCallback)

# Import local scripts
from src import (ner_preprocess_data, ner_compute_metrics, NER_LABELS)

In [2]:
def get_trainer(df: pd.DataFrame) -> Tuple[Trainer, Dataset, Dataset]:
    '''Build a HuggingFace `Trainer` for NER fine-tuning from a dataframe.

    The dataframe is converted to a `Dataset`, preprocessed with
    `ner_preprocess_data` and split into a training part and a validation
    part (10% of the documents) using a fixed seed.

    Note that `model` and `tokenizer` are read from the notebook's global
    namespace, so they have to be defined before this function is called.

    Args:
        df (pd.DataFrame):
            Dataframe with the columns `doc`, `tokens` and `ner_tags`.

    Returns:
        tuple:
            The trainer, the training dataset and the validation dataset.
    '''
    # Convert dataframe to HuggingFace Dataset
    dataset_dct = dict(doc=df.doc,
                       tokens=df.tokens,
                       orig_labels=df.ner_tags)
    dataset = Dataset.from_dict(dataset_dct)

    # Tokenize and align labels
    dataset = ner_preprocess_data(dataset, tokenizer)

    # `warmup_steps` is a number of optimizer steps, so it has to be a whole
    # number rather than a float.
    # NOTE(review): this is roughly the number of training *examples*, not the
    # number of batches per epoch - confirm that such a long warmup is intended.
    warmup_steps = int(len(dataset) * 0.9)

    # Set up training arguments. The number of epochs is deliberately large,
    # as the early stopping callback below decides when training ends
    training_args = TrainingArguments(
        output_dir='.',
        evaluation_strategy='epoch',
        logging_strategy='epoch',
        save_strategy='epoch',
        report_to='none',
        save_total_limit=1,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        learning_rate=2e-5,
        num_train_epochs=1000,
        warmup_steps=warmup_steps,
        gradient_accumulation_steps=1,
        load_best_model_at_end=True
    )

    # Split the dataset into a training and validation dataset
    split = dataset.train_test_split(0.1, seed=4242)

    # Set up data collator for feeding the data into the model
    data_collator = DataCollatorForTokenClassification(tokenizer)

    # Set up early stopping callback
    early_stopping = EarlyStoppingCallback(early_stopping_patience=2)

    # Initialise the Trainer object
    trainer = Trainer(model=model,
                      args=training_args,
                      train_dataset=split['train'],
                      eval_dataset=split['test'],
                      tokenizer=tokenizer,
                      data_collator=data_collator,
                      compute_metrics=ner_compute_metrics,
                      callbacks=[early_stopping])

    # Return the trainer, the training dataset and the validation dataset
    return trainer, split['train'], split['test']

In [3]:
def load_model(model_id: str, id2label: list):
    '''Load a pretrained token classification model and its tokenizer.

    Args:
        model_id (str):
            Identifier of the pretrained model, passed on to the
            `from_pretrained` methods.
        id2label (list):
            The label names, ordered such that the position of a label in
            the list is its integer ID.

    Returns:
        tuple:
            The model, configured with the given labels, and the tokenizer.
    '''
    # The config expects dictionaries mapping IDs to labels and vice versa
    label_kwargs = dict(num_labels=len(id2label),
                        id2label=dict(enumerate(id2label)),
                        label2id={lbl: idx for idx, lbl in enumerate(id2label)})
    config = AutoConfig.from_pretrained(model_id, **label_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForTokenClassification.from_pretrained(model_id,
                                                            config=config)
    return model, tokenizer

## Load datasets and model

In [None]:
dataset_names = ['dane', 'norne-nb', 'norne-nn', 'suc3', 'wikiann-is', 'wikiann-fo']

# Load every dataset only once, and join elements 0 and 2 of the returned
# splits column-wise (presumably the training features and training labels -
# TODO confirm against scandeval's `load_dataset`)
all_datasets = dict()
for name in dataset_names:
    splits = load_dataset(name)
    all_datasets[name] = pd.concat((splits[0], splits[2]), axis=1)

In [None]:
# Load the pretrained model and tokenizer; `get_trainer` reads both of these
# from the global namespace
model, tokenizer = load_model('NbAiLab/nb-bert-base', NER_LABELS)

## Concatenating all the datasets

In [None]:
# Stack all the datasets on top of each other, with a fresh consecutive index
fully_concatenated = pd.concat(all_datasets.values(),
                               axis=0,
                               ignore_index=True)
print(f'There are {len(fully_concatenated):,} documents in the dataset.')
fully_concatenated.head()

In [None]:
# Build the trainer, along with the training and validation datasets, from
# all the documents
trainer, train, val = get_trainer(fully_concatenated)

In [None]:
# Fine-tune the model; the trainer built by `get_trainer` has an early
# stopping callback attached
trainer.train()

In [None]:
# Evaluate on the validation dataset (the same dataset that `get_trainer`
# passed to the trainer as its `eval_dataset`)
trainer.evaluate(val)

## Ensuring equal language contribution

In [None]:
# Subsample every dataset down to the size of the smallest one, so that all
# datasets contribute equally many documents. The sampling is seeded to make
# the subsample reproducible across runs.
min_length = min(len(df) for df in all_datasets.values())
subsampled = (pd.concat([df.sample(min_length, random_state=4242)
                         for df in all_datasets.values()], axis=0)
                .reset_index(drop=True))
print(f'There are {len(subsampled):,} documents in the dataset.')
subsampled.head()

In [None]:
# Start from a fresh pretrained model: `get_trainer` reads the global `model`,
# which has otherwise already been fine-tuned on the concatenated dataset
# above, making the two experiments incomparable
model, tokenizer = load_model('NbAiLab/nb-bert-base', NER_LABELS)
trainer, train, val = get_trainer(subsampled)

In [None]:
# Fine-tune the model on the subsampled dataset
trainer.train()

In [None]:
# Evaluate on the validation dataset returned by the latest `get_trainer` call
trainer.evaluate(val)