<img src="../docs/sa_logo.png" width="250" align="left">

# Question Answering NER with HuggingFace 

## Introduction

This tutorial shows an example of solving a `Named Entity Recognition` (NER) task with [SuperAnnotate](https://www.superannotate.com/) and [HuggingFace](https://huggingface.co/).

The main goal of this tutorial is to show how to annotate part of the data with `SuperAnnotate` tools, build a model with `HuggingFace` that automatically annotates the rest of the data, and upload the new annotations to the [SuperAnnotate platform](https://app.superannotate.com/). These automatically generated annotations can then be checked and corrected manually.

All experiments described in this tutorial were run on the [Legal NER](https://paperswithcode.com/dataset/legal-ner) dataset, a corpus of 46,545 annotated legal named entities mapped to 14 legal entity types, designed for named entity recognition in Indian court judgments.

![](../docs/legal-ner/folders_legal_ner.png)

The tutorial assumes that we have a partially annotated dataset of texts.
The data is stored in an S3 bucket and split into two parts:
* **train** (~40%) $-$ annotated data for training
* **unlabeled** (~60%) $-$ data that will be annotated by the model

These folders are connected to an existing project on the [SuperAnnotate platform](https://app.superannotate.com/), and the train subset has already been annotated manually.

![](../docs/legal-ner/new_lner_example_train.png)

In the examples below we use the `SuperAnnotate SDK`, the `Boto3 SDK`, and `HuggingFace`.
Some parts of the code are adapted from examples in the [SuperAnnotate](https://doc.superannotate.com/docs/getting-started), [Boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html), and [HuggingFace](https://huggingface.co/) documentation.

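In this tutorial we solve the Named Entity Recognition problem as a Question Answering problem: for every entity type the model is asked a question (for example, "What is the court?") and answers it by extracting a span from the text. The algorithm was introduced in [QaNER: Prompting Question Answering Models for Few-shot Named Entity Recognition](https://arxiv.org/abs/2203.01543).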

Some parts of the code in this tutorial are based on [this unofficial implementation](https://github.com/dayyass/QaNER) of the QaNER algorithm.
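To make the reformulation concrete, here is a minimal, self-contained sketch (not the tutorial's own code) that turns one annotated text into QaNER-style question/answer triples: one question per entity type, one triple per matching span, and a single triple with an empty answer when the type does not occur. The question template matches the one used later in this notebook; the `to_qa_triples` helper, the `(start, end, label)` tuple format, and the sample sentence with its offsets are made up for illustration.

```python
from typing import Dict, List, Tuple

# Hypothetical span format: (start_char, end_char_exclusive, label)
Span = Tuple[int, int, str]


def to_qa_triples(text: str, spans: List[Span],
                  prompt_mapper: Dict[str, str]) -> List[Tuple[str, str, str]]:
    """Return (context, question, answer_text) triples; '' means 'no answer'."""
    triples = []
    for tag, name in prompt_mapper.items():
        question = f"What is the {name}?"
        answers = [text[s:e] for s, e, label in spans if label == tag]
        if not answers:
            triples.append((text, question, ""))      # negative example
        for answer in answers:
            triples.append((text, question, answer))  # one triple per annotated span
    return triples


text = "The appeal was heard by the Supreme Court on 5 May 2010."
spans = [(28, 41, "COURT"), (45, 55, "DATE")]
mapper = {"COURT": "court", "DATE": "date", "JUDGE": "judge"}

for _, question, answer in to_qa_triples(text, spans, mapper):
    print(question, "->", repr(answer))
# What is the court? -> 'Supreme Court'
# What is the date? -> '5 May 2010'
# What is the judge? -> ''
```

Negative examples (questions with no answer) matter: they teach the model to point at the `[CLS]` token when an entity type is absent.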

In this tutorial we will go through the following steps:

$\textbf{1.}$ [Environmental setup](#environmental_setup)

$\textbf{1.1}$ [User Variables Setup](#user_variables)

$\textbf{1.2}$ [Constants Setup](#constants_setup)

$\textbf{2.}$ [Download documents and labels from SuperAnnotate](#download_data)

$\textbf{2.1}$ [Get links to all files in S3 bucket](#list_all_files_s3)

$\textbf{2.2}$ [Download files](#download_files)

$\textbf{2.3}$ [Download labels from SuperAnnotate](#download_labels_from_sa)
   
$\textbf{3.}$ [Prepare data for BERT NER model](#prepare_data_for_bert_model)

$\textbf{4.}$ [Train model](#train_model)

$\textbf{5.}$ [Get predictions for unlabeled texts](#get_predictions_for_unlabeled_texts)

$\textbf{6.}$ [Make annotations in SuperAnnotate format](#make_annotations_sa_format)

$\textbf{7.}$ [Upload new annotations to SuperAnnotate platform](#upload_new_annotations_to_sa_platform)


## 1. Environmental setup
<a id='environmental_setup'></a>

In [None]:
! pip install superannotate==4.4.7 #SA SDK installation
! pip install boto3 # install boto3 client
! pip install transformers # HuggingFace transformers
! pip install seqeval # model evaluation

In [None]:
import boto3
import glob
import json
import os
from collections import defaultdict, namedtuple

import numpy as np
import torch
import transformers
from sklearn.model_selection import train_test_split
from superannotate import SAClient
from tqdm.notebook import tqdm

### 1.1 User Variables Setup
<a id='user_variables'></a>

In [None]:
#SuperAnnotate SDK token
SA_TOKEN = "ADD_YOUR_TOKEN"

In [None]:
SA_PROJECT_NAME = "ADD_SUPERANNOTATE_PROJECT_NAME"

### 1.2 Constants Setup
<a id='constants_setup'></a>

SuperAnnotate Python SDK functions work within the team scope of the platform, so a team-level authorization is required.

To authorize the package in a given team scope, get the authorization token from the team settings page.

In [None]:
sa_client = SAClient(token=SA_TOKEN) ## SuperAnnotate client

## 2. Download documents and labels from SuperAnnotate
<a id='download_data'></a>

The data shown on the SuperAnnotate page is actually stored in an AWS S3 bucket.
Here we provide the name of this bucket.

In [None]:
bucket_name = "ADD_YOUR_BUCKET_NAME" # bucket where the data is stored

We also need a client to work with AWS S3.

In [None]:
s3_client = boto3.client('s3') ## S3 client


### 2.1. Get links to all files in S3 bucket
<a id='list_all_files_s3'></a>

The texts shown on the SuperAnnotate page are stored in the S3 bucket.
We download them to the local machine to train our model for legal entity recognition.

First we need the keys of all these objects.
Since `list_objects_v2` returns at most 1,000 objects per call, we list them iteratively.

In [None]:
subset_names = ['train', 'unlabeled']

data_links_dict = {'train': [],
                   'unlabeled': []}

BUCKET_FOLDER_PATH = 'path/to/data'  # S3 key prefix of the dataset, without leading/trailing slashes

for subset_name in subset_names:
    print("Processing", subset_name)
    start_key = ''
    while True:
        response = s3_client.list_objects_v2(Bucket=bucket_name,
                                             Prefix=f'{BUCKET_FOLDER_PATH}/{subset_name}/',
                                             StartAfter=start_key)
        # 'Contents' is absent when no (more) objects match the prefix
        objects = response.get('Contents', [])
        for obj in objects:
            data_links_dict[subset_name].append(obj['Key'])
        print(f"\t{len(data_links_dict[subset_name])} files in {subset_name}")
        # At most 1,000 keys come back per call; IsTruncated tells whether more remain
        if not objects or not response.get('IsTruncated', False):
            break
        start_key = objects[-1]['Key']
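The manual `StartAfter` loop above can also be replaced by a boto3 paginator, which follows the continuation tokens automatically. This is a minimal sketch of that alternative: `get_paginator("list_objects_v2")` and `paginate(Bucket=..., Prefix=...)` are standard boto3 client methods, while the `list_keys` helper and its `.txt` suffix filter are illustrative assumptions.

```python
from typing import List


def list_keys(s3_client, bucket: str, prefix: str, suffix: str = ".txt") -> List[str]:
    """Collect all object keys under `prefix` that end with `suffix`."""
    paginator = s3_client.get_paginator("list_objects_v2")
    keys = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # 'Contents' is missing from a page when nothing matches the prefix
        for obj in page.get("Contents", []):
            if obj["Key"].endswith(suffix):
                keys.append(obj["Key"])
    return keys


# Usage with the objects defined earlier in this notebook (not executed here):
# data_links_dict = {name: list_keys(s3_client, bucket_name, f"{BUCKET_FOLDER_PATH}/{name}/")
#                    for name in subset_names}
```

Filtering by suffix at listing time also keeps folder placeholders and non-text files out of `data_links_dict`.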

### 2.2. Download files
<a id='download_files'></a>

Now we use these keys to download all the text files from the S3 bucket.

In [None]:
for subset_name in subset_names:
    print(f"Loading {subset_name} docs")
    save_dir = f'./{subset_name}_sa_docs'
    os.makedirs(save_dir, exist_ok=True)
    for file_key in tqdm(data_links_dict[subset_name]):
        if not file_key.endswith('.txt'):
            continue  # download only the text documents
        filename = os.path.basename(file_key)
        s3_client.download_file(Bucket=bucket_name,
                                Key=file_key,
                                Filename=os.path.join(save_dir, filename))

### 2.3 Download labels from SuperAnnotate
<a id='download_labels_from_sa'></a>

Now we can download labels from SuperAnnotate for the train texts that were annotated manually. The annotations will be downloaded in [SuperAnnotate format](https://doc.superannotate.com/docs/sdk-export-annotations).

In [None]:
sa_response = sa_client.get_annotations(project=f'{SA_PROJECT_NAME}/train',
                                        items=[os.path.basename(x) for x in data_links_dict['train']
                                               if x.endswith('.txt')])

annotations = [i['instances'] for i in sa_response]

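The model is asked one question per entity type, so we map every SuperAnnotate class name to the short natural-language name used in the question prompt.

In [None]: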
prompt_mapper = {'CASE_NUMBER': 'case number',
 'COURT': 'court',
 'DATE':'date',
 'GPE':'location',
 'JUDGE':'judge',
 'LAWYER':'lawyer',
 'ORG':'organization',
 'OTHER_PERSON':'other person',
 'PETITIONER':'petitioner',
 'PRECEDENT':'precedent',
 'PROVISION':'provision',
 'RESPONDENT':'respondent',
 'STATUTE':'statute',
 'WITNESS':'witness'}
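Each entity tag is turned into a question, and at prediction time `get_top_valid_spans` recovers the tag by stripping the question template and inverting `prompt_mapper`. The round trip is sketched below; the two helper names are hypothetical, and the inversion only works because every prompt name in `prompt_mapper` is unique.

```python
def tag_to_question(tag: str, prompt_mapper: dict) -> str:
    # Same template as in Dataset._prepare_dataset
    return f"What is the {prompt_mapper[tag]}?"


def question_to_tag(question: str, prompt_mapper: dict) -> str:
    # Inverse lookup, as done in get_top_valid_spans
    inv_prompt_mapper = {v: k for k, v in prompt_mapper.items()}
    return inv_prompt_mapper[question.split("What is the ")[-1].rstrip("?")]


mapper = {"CASE_NUMBER": "case number", "OTHER_PERSON": "other person"}
q = tag_to_question("OTHER_PERSON", mapper)
print(q)                           # What is the other person?
print(question_to_tag(q, mapper))  # OTHER_PERSON
```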

## 3. Prepare data for BERT NER model
<a id='prepare_data_for_bert_model'></a>

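We use the pretrained `bert-base-uncased` tokenizer and model (set in `BERT_MODEL_NAME` in section 4).

Now we can create the classes for our dataset. `Dataset` asks one question per entity type (for example, "What is the court?") about every text and creates one instance per annotated span of that type, or a single instance with an empty answer if the text contains no such span. `Collator` tokenizes the question/context pairs and converts the character-level answer boundaries into the token positions (`start_positions`, `end_positions`) that the QA model learns to predict.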

In [None]:
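from collections import namedtuple

# One QA example: the text, the question for one entity type, and the answer span dict
Instance = namedtuple("Instance", ["context", "question", "answer"])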

class Dataset(torch.utils.data.Dataset):
    def __init__(self,
                 qa_sentences,
                 qa_labels,
                 prompt_mapper):
        super().__init__()
        self.prompt_mapper = prompt_mapper
        self.dataset = self._prepare_dataset(qa_sentences=qa_sentences,
                                             qa_labels=qa_labels)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx: int):
        return self.dataset[idx]

    def _prepare_dataset(self,
                         qa_sentences,
                         qa_labels):

        dataset = []
        for sentence, labels in tqdm(zip(qa_sentences, qa_labels),desc="prepare_dataset"):
            for label_tag, label_name in self.prompt_mapper.items():
                question_prompt = f"What is the {label_name}?"

                answer_list = []
                for span in labels:
                    if span['label'] == label_tag:
                        answer_list.append(span)

                if len(answer_list) == 0:
                    empty_span = {"token" : "",
                                  "label" : "O",
                                  "start_context_char_pos" : 0,
                                  "end_context_char_pos" : 0}
                    instance = Instance(
                        context=sentence,
                        question=question_prompt,
                        answer=empty_span,
                    )
                    dataset.append(instance)
                else:
                    for answer in answer_list:
                        instance = Instance(
                            context=sentence,
                            question=question_prompt,
                            answer=answer,
                        )
                        dataset.append(instance)

        return dataset

    
class Collator:
    def __init__(self, tokenizer, tokenizer_kwargs):
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs

    def __call__(self, batch):
        context_list = []
        question_list = []
        start_end_context_char_pos_list = []

        for instance in batch:
            context_list.append(instance.context)
            question_list.append(instance.question)
            start_end_context_char_pos_list.append(
                [
                    instance.answer["start_context_char_pos"],
                    instance.answer["end_context_char_pos"],
                ]
            )

        tokenized_batch = self.tokenizer(question_list, context_list, **self.tokenizer_kwargs)

        offset_mapping_batch = tokenized_batch["offset_mapping"].numpy().tolist()

        assert len(offset_mapping_batch) == len(start_end_context_char_pos_list)

        start_tok_pos_list, end_tok_pos_list = [], []
        for offset_mapping, start_end_char_pos in zip(offset_mapping_batch, start_end_context_char_pos_list):
            if start_end_char_pos == [0, 0]:
                start_tok_pos_list.append(0)
                end_tok_pos_list.append(0)
            else:
                start_tok_pos, end_tok_pos = self.ch_to_tok_bounds(offset_mapping=offset_mapping,
                                                                   start_end_char_pos=start_end_char_pos)
                start_tok_pos_list.append(start_tok_pos)
                end_tok_pos_list.append(end_tok_pos)

        tokenized_batch["start_positions"] = torch.LongTensor(start_tok_pos_list)
        tokenized_batch["end_positions"] = torch.LongTensor(end_tok_pos_list)

        tokenized_batch["instances"] = batch

        return tokenized_batch

    @staticmethod
    def ch_to_tok_bounds(offset_mapping, start_end_char_pos):
        start_context_char_pos, end_context_char_pos = start_end_char_pos
        assert end_context_char_pos >= start_context_char_pos
        
        done = False
        special_tokens_cnt = 0
        for i, token_boundaries in enumerate(offset_mapping):
            if token_boundaries == [0, 0]:
                special_tokens_cnt += 1
                continue
            if special_tokens_cnt == 2:
                start_token_pos, end_token_pos = token_boundaries
                if start_token_pos == start_context_char_pos:
                    res_start_token_pos = i
                if end_token_pos == end_context_char_pos:
                    res_end_token_pos = i  
                    done = True
                    break
        if special_tokens_cnt > 2:
            res_end_token_pos = len(offset_mapping) - 1  
            res_start_token_pos = 0
            done = True
            
        assert done
        return res_start_token_pos, res_end_token_pos
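The trickiest part of the collator is `ch_to_tok_bounds`, which maps a character-level answer span to token indices by scanning the tokenizer's offset mapping after the second special token. The sketch below reproduces the idea on a hand-written offset mapping, so no tokenizer download is needed. The offsets are illustrative and assume a BERT-style `[CLS] question [SEP] context [SEP]` layout in which special tokens get `[0, 0]`; unlike the notebook's method, this hypothetical helper returns `None` instead of falling back to the whole sequence when the span cannot be aligned.

```python
from typing import List, Optional, Tuple


def char_to_token_bounds(offset_mapping: List[List[int]],
                         start_char: int, end_char: int) -> Optional[Tuple[int, int]]:
    """Return (start_token, end_token) for a context span, or None if not aligned."""
    special_seen = 0
    start_tok = None
    for i, (tok_start, tok_end) in enumerate(offset_mapping):
        if [tok_start, tok_end] == [0, 0]:
            special_seen += 1          # [CLS], [SEP] and padding all map to [0, 0]
            continue
        if special_seen != 2:
            continue                   # still inside the question
        if tok_start == start_char:
            start_tok = i
        if tok_end == end_char and start_tok is not None:
            return start_tok, i
    return None                        # truncated away or not on token boundaries


# Question "what is the court ?" followed by the context "by the Supreme Court"
offsets = [[0, 0],                                      # [CLS]
           [0, 4], [5, 7], [8, 11], [12, 17], [17, 18], # what is the court ?
           [0, 0],                                      # [SEP]
           [0, 2], [3, 6], [7, 14], [15, 20],           # by the supreme court
           [0, 0]]                                      # [SEP]
print(char_to_token_bounds(offsets, 7, 20))  # (9, 10)
```

Note that question tokens can also start at character 0, which is why the scan only starts matching after the second `[0, 0]` entry.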

Now we load the train texts downloaded from the S3 bucket, pair each text with its manually annotated spans, and split them into train and test samples.

In [None]:
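TRAIN_DOCS_FOLDER = './train_sa_docs'

# Map each SuperAnnotate item name to its list of annotated instances
annotations_by_name = {a['metadata']['name']: a['instances'] for a in sa_response}

texts, spans = [], []
for path in data_links_dict['train']:
    name = os.path.basename(path)
    if not name.endswith('.txt'):
        continue  # only .txt documents were downloaded
    with open(os.path.join(TRAIN_DOCS_FOLDER, name)) as f:
        text = f.read()
    texts.append(text)
    # Convert instances to the span dicts used by Dataset (assumes 'end' is exclusive)
    spans.append([{"token": text[inst['start']:inst['end']],
                   "label": inst['className'],
                   "start_context_char_pos": inst['start'],
                   "end_context_char_pos": inst['end']}
                  for inst in annotations_by_name.get(name, [])])

In [None]:
train_qa_sents, test_qa_sents, train_qa_labels, test_qa_labels = train_test_split(texts, spans, test_size=0.3, random_state=42)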

## 4. Train model
<a id='train_model'></a>

Now we can actually create the datasets for model training.

In [None]:
train_dataset = Dataset(qa_sentences=train_qa_sents,
                        qa_labels=train_qa_labels,
                        prompt_mapper=prompt_mapper)


test_dataset = Dataset(qa_sentences=test_qa_sents,
                       qa_labels=test_qa_labels,
                       prompt_mapper=prompt_mapper)

Let's now implement the training loop and all the functions we need for it.

In [None]:
def compute_metrics(spans_true_batch,
                    spans_pred_batch_top_1,
                    prompt_mapper): 
    
    metrics = {}

    entity_mapper = {"O": 0}
    for entity_tag in prompt_mapper:
        entity_mapper[entity_tag] = len(entity_mapper)

    ner_confusion_matrix = np.zeros((len(entity_mapper), len(entity_mapper)))
    confusion_matrix_true_denominator = np.zeros(len(entity_mapper))
    confusion_matrix_pred_denominator = np.zeros(len(entity_mapper))

    for span_true, span_pred in zip(spans_true_batch, spans_pred_batch_top_1):
        span_pred = span_pred[0]
        i = entity_mapper[span_true['label']]
        j = entity_mapper[span_pred['label']]
        confusion_matrix_true_denominator[i] += 1
        confusion_matrix_pred_denominator[j] += 1
        if span_true == span_pred:
            ner_confusion_matrix[i, j] += 1
            
    assert (confusion_matrix_true_denominator.sum() == confusion_matrix_pred_denominator.sum())

    ner_confusion_matrix_diag = np.diag(ner_confusion_matrix)

    accuracy = np.nan_to_num(ner_confusion_matrix_diag.sum() / confusion_matrix_true_denominator.sum())
    precision_per_entity_type = np.nan_to_num(ner_confusion_matrix_diag / confusion_matrix_pred_denominator)
    recall_per_entity_type = np.nan_to_num(ner_confusion_matrix_diag / confusion_matrix_true_denominator)
    f1_per_entity_type = np.nan_to_num(2 * precision_per_entity_type * recall_per_entity_type
                                        / (precision_per_entity_type + recall_per_entity_type))

    metrics["accuracy"] = accuracy

    for label_tag, idx in entity_mapper.items():
        metrics[f"precision_tag_{label_tag}"] = precision_per_entity_type[idx]
        metrics[f"recall_tag_{label_tag}"] = recall_per_entity_type[idx]
        metrics[f"f1_tag_{label_tag}"] = f1_per_entity_type[idx]

    metrics["precision_macro"] = precision_per_entity_type.mean()
    metrics["recall_macro"] = recall_per_entity_type.mean()
    metrics["f1_macro"] = f1_per_entity_type.mean()

    metrics["precision_weighted"] = np.average(precision_per_entity_type,
                                               weights=confusion_matrix_true_denominator)
    metrics["recall_weighted"] = np.average(recall_per_entity_type,
                                            weights=confusion_matrix_true_denominator)
    metrics["f1_weighted"] = np.average(f1_per_entity_type,
                                        weights=confusion_matrix_true_denominator)
    return metrics
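For reference, the per-type scores that `compute_metrics` computes for one batch can be written as follows. For entity type $k$ (the $K$ types include `O`, the empty answer), $m_k$ is the number of instances whose predicted span dict exactly equals the true one and has type $k$, $p_k$ is the number of predictions of type $k$, and $t_k$ is the number of true answers of type $k$; any $0/0$ becomes $0$ through `np.nan_to_num`. The batch values are then averaged over batches in `train_epoch` and `evaluate_epoch`.

$$
P_k = \frac{m_k}{p_k}, \qquad R_k = \frac{m_k}{t_k}, \qquad F1_k = \frac{2\,P_k R_k}{P_k + R_k}
$$

$$
\text{accuracy} = \frac{\sum_k m_k}{\sum_k t_k}, \qquad
P_{\text{macro}} = \frac{1}{K}\sum_{k} P_k, \qquad
P_{\text{weighted}} = \frac{\sum_k t_k P_k}{\sum_k t_k}
$$

Recall and F1 are macro- and weight-averaged in the same way.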

In [None]:
def get_top_valid_spans(context_list,
                        question_list,
                        prompt_mapper,
                        inputs,
                        outputs,
                        offset_mapping_batch,
                        n_best_size = 1,
                        max_answer_length = 100):

    batch_size = len(offset_mapping_batch)

    inv_prompt_mapper = {v: k for k, v in prompt_mapper.items()}

    assert batch_size == len(context_list)
    assert batch_size == len(question_list)
    assert batch_size == len(inputs["input_ids"])
    assert batch_size == len(inputs["token_type_ids"])
    assert batch_size == len(outputs["start_logits"])
    assert batch_size == len(outputs["end_logits"])

    top_valid_spans_batch = []

    for i in range(batch_size):
        context = context_list[i]

        offset_mapping = offset_mapping_batch[i].cpu().numpy()
        mask = inputs["token_type_ids"][i].bool().cpu().numpy()
        offset_mapping[~mask] = [0, 0]
        offset_mapping = [
            (span if span != [0, 0] else None) for span in offset_mapping.tolist()
        ]

        start_logits = outputs["start_logits"][i].cpu().numpy()
        end_logits = outputs["end_logits"][i].cpu().numpy()

        start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
        end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
        top_valid_spans = []

        for start_index, end_index in zip(start_indexes, end_indexes):
            if (start_index >= len(offset_mapping)
                or end_index >= len(offset_mapping)
                or offset_mapping[start_index] is None
                or offset_mapping[end_index] is None
            ):
                continue
            if (end_index < start_index) or (end_index - start_index + 1 > max_answer_length):
                continue
            # The character span runs from the start of the start token to the end of the end token
            start_char = offset_mapping[start_index][0]
            end_char = offset_mapping[end_index][1]
            span = {"token" : context[start_char:end_char],
                    "label" : inv_prompt_mapper[question_list[i].split("What is the ")[-1].rstrip("?")],
                    "start_context_char_pos" : start_char,
                    "end_context_char_pos" : end_char}
            top_valid_spans.append(span)
        top_valid_spans_batch.append(top_valid_spans)
    return top_valid_spans_batch
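`get_top_valid_spans` pairs the i-th best start index with the i-th best end index. A common alternative, similar in spirit to the post-processing in HuggingFace's extractive question answering examples, scores every valid (start, end) pair by the sum of its logits. The stdlib-only sketch below illustrates that alternative; it is not what this notebook runs, and the `best_spans` name and the toy logits are hypothetical.

```python
from typing import List, Sequence, Tuple


def best_spans(start_logits: Sequence[float], end_logits: Sequence[float],
               context_mask: Sequence[bool], n_best: int = 1,
               max_answer_length: int = 100) -> List[Tuple[int, int, float]]:
    """Return up to n_best (start, end, score) token spans lying inside the context."""
    candidates = []
    for s in range(len(start_logits)):
        if not context_mask[s]:
            continue
        # Only consider ends at or after the start, within the length limit
        for e in range(s, min(s + max_answer_length, len(end_logits))):
            if context_mask[e]:
                candidates.append((s, e, start_logits[s] + end_logits[e]))
    candidates.sort(key=lambda c: c[2], reverse=True)
    return candidates[:n_best]


# Toy example: tokens 0-3 are special/question tokens, 4-7 are context tokens
start = [5.0, 0.1, 0.2, 0.0, 0.3, 4.0, 0.1, 0.2]
end = [5.0, 0.1, 0.0, 0.2, 0.1, 0.5, 3.5, 0.4]
mask = [False, False, False, False, True, True, True, True]
print(best_spans(start, end, mask))  # [(5, 6, 7.5)]
```

This sketch ignores the "no entity" option; to mirror the notebook's behavior of treating a `[CLS]` prediction as an empty answer, one would also compare the best score against `start_logits[0] + end_logits[0]`.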

In [None]:
def train_epoch(model,
                dataloader,
                optimizer,
                device,
                epoch):
    model.train()
    epoch_loss = []
    batch_metrics_list = defaultdict(list)
    for i, inputs in tqdm(enumerate(dataloader), total=len(dataloader)):
        optimizer.zero_grad()
        instances_batch = inputs.pop("instances")
        context_list, question_list = [], []
        for instance in instances_batch:
            context_list.append(instance.context)
            question_list.append(instance.question)
        inputs = inputs.to(device)
        offset_mapping_batch = inputs.pop("offset_mapping")
        outputs = model(**inputs)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
        with torch.no_grad():
            model.eval()
            outputs_inference = model(**inputs)
            model.train()
        spans_pred_batch_top_1 = get_top_valid_spans(context_list=context_list,
                                                     question_list=question_list,
                                                     prompt_mapper=dataloader.dataset.prompt_mapper,
                                                     inputs=inputs,
                                                     outputs=outputs_inference,
                                                     offset_mapping_batch=offset_mapping_batch,
                                                     n_best_size=1,
                                                     max_answer_length=100)

        for idx in range(len(spans_pred_batch_top_1)):
            if not spans_pred_batch_top_1[idx]:
                empty_span = {"token" : "",
                              "label" : "O",
                              "start_context_char_pos" : 0,
                              "end_context_char_pos" : 0}
                spans_pred_batch_top_1[idx] = [empty_span]

        spans_true_batch = [instance.answer for instance in instances_batch]

        batch_metrics = compute_metrics(spans_true_batch=spans_true_batch,
                                        spans_pred_batch_top_1=spans_pred_batch_top_1,
                                        prompt_mapper=dataloader.dataset.prompt_mapper)

        for metric_name, metric_value in batch_metrics.items():
            batch_metrics_list[metric_name].append(metric_value)

    avg_loss = np.mean(epoch_loss)
    print(f"Train loss: {avg_loss}\n")

    for metric_name, metric_value_list in batch_metrics_list.items():
        metric_value = np.mean(metric_value_list)
        print(f"Train {metric_name}: {metric_value}\n")


def evaluate_epoch(model, dataloader, device, epoch):
    model.eval()

    epoch_loss = []
    batch_metrics_list = defaultdict(list)

    with torch.no_grad():

        for i, inputs in tqdm(enumerate(dataloader),total=len(dataloader)):

            instances_batch = inputs.pop("instances")

            context_list, question_list = [], []
            for instance in instances_batch:
                context_list.append(instance.context)
                question_list.append(instance.question)

            inputs = inputs.to(device)
            offset_mapping_batch = inputs.pop("offset_mapping")

            outputs = model(**inputs)
            loss = outputs.loss

            epoch_loss.append(loss.item())
            spans_pred_batch_top_1 = get_top_valid_spans(context_list=context_list,
                                                         question_list=question_list,
                                                         prompt_mapper=dataloader.dataset.prompt_mapper,
                                                         inputs=inputs,
                                                         outputs=outputs,
                                                         offset_mapping_batch=offset_mapping_batch,
                                                         n_best_size=1,
                                                         max_answer_length=100)

            for idx in range(len(spans_pred_batch_top_1)):
                if not spans_pred_batch_top_1[idx]:
                    empty_span = {"token" : "",
                                  "label" : "O",
                                  "start_context_char_pos" : 0,
                                  "end_context_char_pos" : 0}
                    spans_pred_batch_top_1[idx] = [empty_span]
            spans_true_batch = [instance.answer for instance in instances_batch]
            batch_metrics = compute_metrics(spans_true_batch=spans_true_batch,
                                            spans_pred_batch_top_1=spans_pred_batch_top_1,
                                            prompt_mapper=dataloader.dataset.prompt_mapper)
            for metric_name, metric_value in batch_metrics.items():
                batch_metrics_list[metric_name].append(metric_value)
        avg_loss = np.mean(epoch_loss)
        print(f"Test loss:  {avg_loss}\n")

        for metric_name, metric_value_list in batch_metrics_list.items():
            metric_value = np.mean(metric_value_list)
            print(f"Test {metric_name}: {metric_value}\n")

In [None]:
def train(n_epochs, 
          model,
          train_dataloader,
          test_dataloader,
          optimizer,
          device):
    for epoch in range(n_epochs):
        print(f"Epoch [{epoch+1} / {n_epochs}]\n")
        train_epoch(model=model,
                    dataloader=train_dataloader,
                    optimizer=optimizer,
                    device=device,
                    epoch=epoch)
        evaluate_epoch(model=model,
                       dataloader=test_dataloader,
                       device=device,
                       epoch=epoch)

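Now we can train the model. Note that the tokenizer settings below truncate every input to at most 512 tokens, cutting only the context (`truncation="only_second"`), so entities located beyond that limit are not seen during training or prediction.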

In [None]:
BERT_MODEL_NAME = 'bert-base-uncased'
PATH_TO_SAVE_MODEL = 'qaner-legalner-bert-base-uncased3'
BATCH_SIZE = 4
LEARNING_RATE = 1e-5 
EPOCHS = 3

In [None]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained BERT tokenizer and BERT with a span-prediction (question answering) head
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
model = AutoModelForQuestionAnswering.from_pretrained(BERT_MODEL_NAME).to(device)

tokenizer_kwargs = {"max_length": 512,
                    "truncation": "only_second",
                    "padding": True,
                    "return_tensors": "pt",
                    "return_offsets_mapping": True}

In [None]:
collator = Collator(tokenizer=tokenizer, tokenizer_kwargs=tokenizer_kwargs)

train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True,
                                               collate_fn=collator)

test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=False,
                                              collate_fn=collator)

optimizer = torch.optim.Adam(model.parameters(),
                             lr=LEARNING_RATE)

train(n_epochs=EPOCHS,
      model=model,
      train_dataloader=train_dataloader,
      test_dataloader=test_dataloader,
      optimizer=optimizer,
      device=device)

model.save_pretrained(PATH_TO_SAVE_MODEL)
tokenizer.save_pretrained(PATH_TO_SAVE_MODEL)

## 5. Get predictions for unlabeled texts
<a id='get_predictions_for_unlabeled_texts'></a>

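After training, the model can be evaluated on the test data; span-based metrics can be obtained, for example, with [seqeval](https://huggingface.co/spaces/evaluate-metric/seqeval).
To label new texts, `predict` asks the model one question about a text, keeps up to `n_best_size` candidate spans, and merges overlapping or adjacent candidates of the same entity type with `merge_spans`.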
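seqeval scores IOB2 tag sequences; a closely related way to think about strict span-level scoring is exact matching of `(start, end, label)` triples. The helper below is a minimal stdlib sketch of that idea, not seqeval's implementation; its name and the toy spans are hypothetical.

```python
from typing import Dict, Iterable, Tuple

SpanKey = Tuple[int, int, str]  # (start_char, end_char, label)


def span_prf(true_spans: Iterable[SpanKey], pred_spans: Iterable[SpanKey]) -> Dict[str, float]:
    """Micro precision/recall/F1 where a prediction counts only on an exact match."""
    true_set, pred_set = set(true_spans), set(pred_spans)
    tp = len(true_set & pred_set)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(true_set) if true_set else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


true = [(28, 41, "COURT"), (45, 55, "DATE")]
pred = [(28, 41, "COURT"), (45, 50, "DATE"), (0, 3, "ORG")]
print(span_prf(true, pred))
```

To score several documents at once, prefix each triple with a document id so that identical offsets from different texts are not merged by the sets.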

In [None]:
def merge_spans(txt, first_span, second_span):
    # Merge two overlapping or adjacent spans of the same label into one span
    start = first_span['start_context_char_pos']
    end = max(first_span['end_context_char_pos'], second_span['end_context_char_pos'])
    return {"token" : txt[start:end],
            "start_context_char_pos" : start,
            "end_context_char_pos" : end,
            "label" : first_span['label']}

def predict(context, 
            question,
            prompt_mapper,
            model,
            tokenizer,
            tokenizer_kwargs,
            n_best_size = 1,
            max_answer_length = 100):
    inputs = tokenizer([question], [context], **tokenizer_kwargs).to(model.device)
    offset_mapping_batch = inputs.pop("offset_mapping")

    with torch.no_grad():
        outputs = model(**inputs)

    predicted_spans = get_top_valid_spans(context_list=[context],
                                          question_list=[question],
                                          prompt_mapper=prompt_mapper,
                                          inputs=inputs,
                                          outputs=outputs,
                                          offset_mapping_batch=offset_mapping_batch,
                                          n_best_size=n_best_size,
                                          max_answer_length=max_answer_length)[0]
    
    # Sort candidates by position and merge overlapping/adjacent spans of the same label
    predicted_spans.sort(key=lambda x: x['start_context_char_pos'])
    merged_spans = []
    for span in predicted_spans:
        if merged_spans:
            last_span = merged_spans[-1]
            if (span['start_context_char_pos'] <= last_span['end_context_char_pos'] + 1
                    and span['label'] == last_span['label']):
                merged_spans[-1] = merge_spans(context, last_span, span)
                continue
        merged_spans.append(span)

    return merged_spans
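`predict` sorts the candidate spans and merges those of the same label that overlap or are separated by at most one character. Below is a self-contained sketch of that merging step on plain dicts with the notebook's span keys; the `merge_adjacent_spans` and `make_span` helpers, the sample text, and the spans are made up for illustration, and the merged text is rebuilt directly from the character range `text[start:end]`.

```python
from typing import Any, Dict, List

Span = Dict[str, Any]


def merge_adjacent_spans(text: str, spans: List[Span], max_gap: int = 1) -> List[Span]:
    """Merge same-label spans that overlap or are at most `max_gap` characters apart."""
    merged: List[Span] = []
    for span in sorted(spans, key=lambda s: s["start_context_char_pos"]):
        if merged:
            last = merged[-1]
            if (span["label"] == last["label"]
                    and span["start_context_char_pos"] <= last["end_context_char_pos"] + max_gap):
                start = last["start_context_char_pos"]
                end = max(last["end_context_char_pos"], span["end_context_char_pos"])
                merged[-1] = {"token": text[start:end], "label": last["label"],
                              "start_context_char_pos": start, "end_context_char_pos": end}
                continue
        merged.append(dict(span))
    return merged


def make_span(text: str, start: int, end: int, label: str) -> Span:
    return {"token": text[start:end], "label": label,
            "start_context_char_pos": start, "end_context_char_pos": end}


text = "heard by the Supreme Court of India"
spans = [make_span(text, 21, 29, "COURT"),   # "Court of"
         make_span(text, 13, 20, "COURT"),   # "Supreme"
         make_span(text, 30, 35, "COURT")]   # "India"
print([s["token"] for s in merge_adjacent_spans(text, spans)])  # ['Supreme Court of India']
```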

Let's load the unlabeled documents downloaded from the S3 bucket and get predictions for them.

In [None]:
unlabeled_texts = []
names = []
for filename in glob.glob('./unlabeled_sa_docs/*.txt'):
    with open(filename) as f:
        unlabeled_texts.append(f.read())
        names.append(os.path.basename(filename)) 

In [None]:
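model.eval()  # disable dropout for inference

predicted_spans = []  # one list of predicted spans per unlabeled text
for text in tqdm(unlabeled_texts):
    text_spans = []
    for class_tag, class_name in prompt_mapper.items():
        text_spans.extend(predict(context=text,
                                  question=f"What is the {class_name}?",
                                  prompt_mapper=prompt_mapper,
                                  model=model,
                                  tokenizer=tokenizer,
                                  tokenizer_kwargs=tokenizer_kwargs,
                                  n_best_size=10,
                                  max_answer_length=100))
    predicted_spans.append(text_spans)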

## 6. Make annotations in SuperAnnotate format
<a id='make_annotations_sa_format'></a>

Based on the model's predictions, we now create annotations in SuperAnnotate format so that they can be uploaded to the platform.

In [None]:
new_annotations = []
for spans, name in zip(predicted_spans, names):
    entities = []
    for span in spans:
        entities.append({"type": "entity",
                         "className": span['label'],
                         "start": span['start_context_char_pos'],
                         "end": span['end_context_char_pos'],
                         "attributes": []})
    new_annotations.append({'instances': entities,
                            'metadata': {'name' : name}})

In [None]:
import json

ANNOTATIONS_FOLDER = 'PATH/TO/LOCAL/DIR' # local folder to store .json files with annotations
os.makedirs(ANNOTATIONS_FOLDER, exist_ok=True)
for annotation in new_annotations:
    filename = annotation['metadata']['name']
    with open(os.path.join(ANNOTATIONS_FOLDER, f'{filename}.json'), 'w') as f:
        json.dump(annotation, f)

## 7. Upload new annotations to SuperAnnotate platform
<a id='upload_new_annotations_to_sa_platform'></a>

Now we can upload the annotations generated in the previous step back to SuperAnnotate.

In [None]:
def read_json(filename):
    with open(filename) as f:
        data = json.load(f)
    return data 

In [None]:
outputs = []
files = sorted(f for f in os.listdir(ANNOTATIONS_FOLDER) if f.endswith('.json'))
files_per_step = 500  # upload annotations in batches of 500 files

for start in range(0, len(files), files_per_step):
    batch = [read_json(os.path.join(ANNOTATIONS_FOLDER, f))
             for f in files[start:start + files_per_step]]

    outputs.append(sa_client.upload_annotations(project=f'{SA_PROJECT_NAME}/unlabeled', annotations=batch))
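Uploading in fixed-size batches only needs a chunking helper. Stepping `range` by the batch size, as in this minimal sketch (the `chunks` name is hypothetical), never produces an empty trailing batch when the number of files is an exact multiple of the batch size.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def chunks(items: Sequence[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of at most `size` items."""
    if size <= 0:
        raise ValueError("size must be positive")
    for start in range(0, len(items), size):
        yield list(items[start:start + size])


files = [f"doc_{i}.txt.json" for i in range(1000)]
print([len(batch) for batch in chunks(files, 500)])  # [500, 500]
```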

Now we can open the unlabeled folder on the SuperAnnotate platform and see the predictions made by our model.


![](../docs/legal-ner/labeled_unlabeled.png)

All files in the unlabeled folder have changed their status.

![](../docs/legal-ner/lner_unlabeled_example.png)