In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging as log
log.basicConfig(level=log.DEBUG)

In [2]:
import sys
sys.path.append('../')
from baselines.utils import *
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

os.environ['WANDB_ENTITY'] = 'contract-nli-db'
os.environ['WANDB_PROJECT'] = 'contract-nli-metric'
os.environ['WANDB_LOG_MODEL'] = 'end'

In [3]:
import torch

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE

device(type='cuda')

In [4]:
cfg['model_name'] = 'bert-base-uncased'
cfg['trained_model_dir'] = '/scratch/shu7bh/contract_nli/trained_model/'
cfg['batch_size'] = 32
cfg

{'raw_data_dir': '../dataset/',
 'train_path': 'train.json',
 'test_path': 'test.json',
 'dev_path': 'dev.json',
 'model_name': 'bert-base-uncased',
 'max_length': 512,
 'models_save_dir': '/scratch/shu7bh/contract_nli/models',
 'dataset_dir': '/scratch/shu7bh/contract_nli/dataset',
 'results_dir': '/scratch/shu7bh/contract_nli/results',
 'trained_model_dir': '/scratch/shu7bh/contract_nli/trained_model/',
 'batch_size': 32}

In [5]:
# create dir if not exists
from pathlib import Path
Path(cfg["models_save_dir"]).mkdir(parents=True, exist_ok=True)
Path(cfg["dataset_dir"]).mkdir(parents=True, exist_ok=True)

In [6]:
tokenizer = AutoTokenizer.from_pretrained(cfg['model_name'])

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


In [7]:
from icecream import ic

In [8]:
def get_hypothesis_idx(hypothesis_name):
    return int(hypothesis_name.split('-')[-1])

In [9]:
from torch.utils.data import Dataset
import random
import torch

class NLIDataset(Dataset):
    def __init__(self, documents, tokenizer, hypothesis, context_sizes, surround_character_size):
        label_dict = get_labels()
        self.tokenizer = tokenizer

        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[SPAN]']})

        data_points = []
        contexts = [{}]

        for context_size in context_sizes:
            for i, doc in enumerate(documents):
                char_idx = 0
                while char_idx < len(doc['text']):
                    ic(char_idx)
                    document_spans = doc['spans']
                    cur_context = {
                        'doc_id': i,
                        'start_char_idx': char_idx,
                        'end_char_idx': char_idx + context_size,
                        'spans' : [],
                    }

                    for j, (start, end) in enumerate(document_spans):
                        if end <= char_idx:
                            continue

                        cur_context['spans'].append({
                            'start_char_idx': max(start, char_idx),
                            'end_char_idx': min(end, char_idx + context_size),
                            'marked': start >= char_idx and end <= char_idx + context_size,
                            'span_id': j
                        })

                        if end > char_idx + context_size:
                            break

                    if cur_context == contexts[-1]:
                        char_idx = cur_context['end_char_idx'] - surround_character_size
                        continue

                    contexts.append(cur_context)
                    if len(cur_context['spans']) == 1 and not cur_context['spans'][0]['marked']:
                        char_idx = cur_context['end_char_idx'] - surround_character_size
                    else:
                        char_idx = cur_context['spans'][-1]['start_char_idx'] - surround_character_size

        contexts.pop(0)

        for nda_name, nda_desc in hypothesis.items():
            for i, context in enumerate(contexts):

                nli_label = label_dict[documents[context['doc_id']]['annotation_sets'][0]['annotations'][nda_name]['choice']]

                if nli_label == label_dict['NotMentioned'] and random.random() > 0.04:
                    continue

                if nli_label == label_dict['Entailment'] and random.random() > 0.34:
                    continue

                data_point = {}
                data_point['hypotheis'] = nda_desc
                cur_premise = ""
                data_point['marked_beg'] = context['spans'][0]['marked']
                data_point['marked_end'] = context['spans'][-1]['marked']
                doc_id = context['doc_id']
                hypothesis_id = get_hypothesis_idx(nda_name)
                span_ids = []

                if len(context['spans']) == 1:
                    data_point['marked_end'] = True

                span_labels = []

                for span in context['spans']:
                    val = int(span['span_id'] in documents[context['doc_id']]['annotation_sets'][0]['annotations'][nda_name]['spans'])

                    if val == 0 and random.random() > 0.3:
                        continue

                    if span['marked']:
                        span_labels.append(val)
                        span_ids.append(span['span_id'])

                    cur_premise += ' [SPAN] '
                    cur_premise += documents[context['doc_id']]['text'][span['start_char_idx']:span['end_char_idx']]

                evidence = any(span_labels)

                data_point['premise'] = cur_premise

                # nli_label = label_dict[documents[context['doc_id']]['annotation_sets'][0]['annotations'][nda_name]['choice']]

                if not evidence and nli_label != label_dict['NotMentioned']:
                    continue

                data_point['nli_label'] = torch.tensor(nli_label, dtype=torch.long)
                data_point['span_labels'] = torch.tensor(span_labels, dtype=torch.long)
                data_point['doc_id'] = torch.tensor(doc_id, dtype=torch.long)
                data_point['hypothesis_id'] = torch.tensor(hypothesis_id, dtype=torch.long)
                data_point['span_ids'] = torch.tensor(span_ids, dtype=torch.long)

                data_points.append(data_point)

        self.data_points = data_points
        self.span_token_id = self.tokenizer.convert_tokens_to_ids('[SPAN]')

    def __len__(self):
        return len(self.data_points)

    def __getitem__(self, idx):
        tokenized_data = self.tokenizer(
            [self.data_points[idx]['hypotheis']],
            [self.data_points[idx]['premise']],
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        tokenized_data['input_ids'] = tokenized_data['input_ids'].squeeze()
        tokenized_data['attention_mask'] = tokenized_data['attention_mask'].squeeze()
        tokenized_data['token_type_ids'] = tokenized_data['token_type_ids'].squeeze()

        span_indices = torch.where(tokenized_data['input_ids'] == self.span_token_id)[0]

        if not self.data_points[idx]['marked_beg']:
            span_indices = span_indices[1:]
        
        if not self.data_points[idx]['marked_end'] or tokenized_data['attention_mask'][-1] == 0:
            span_indices = span_indices[:-1]
        
        span_ids = self.data_points[idx]['span_ids']
        span_ids = span_ids[:len(span_indices)]

        return {
            'input_ids': tokenized_data['input_ids'],
            'attention_mask': tokenized_data['attention_mask'],
            'token_type_ids': tokenized_data['token_type_ids'],
            'span_indices': span_indices,
            'nli_label': self.data_points[idx]['nli_label'],
            'span_labels': self.data_points[idx]['span_labels'][:len(span_indices)],
            'data_for_metrics': {
                'doc_id': self.data_points[idx]['doc_id'],
                'hypothesis_id': self.data_points[idx]['hypothesis_id'],
                'span_ids': span_ids,
            }
        }

In [10]:
train_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['train_path']))
dev_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['dev_path']))
test_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['test_path']))

hypothesis = get_hypothesis(train_data)

train_data = train_data['documents']
dev_data = dev_data['documents']
test_data = test_data['documents']

# train_data = train_data[:100]
# dev_data = dev_data[:100]
# test_data = test_data[:100]

ic.disable()

ic(len(train_data), len(dev_data), len(test_data))
train_dataset = NLIDataset(train_data, tokenizer, hypothesis, [1000, 1100, 1200, 1500], 50)
dev_dataset = NLIDataset(dev_data, tokenizer, hypothesis, [1000, 1100, 1200, 1500], 50)
test_dataset = NLIDataset(test_data, tokenizer, hypothesis, [1000, 1100, 1200, 1500], 50)

ic.enable()

del train_data
del dev_data
del test_data
del hypothesis

In [11]:
print(len(train_dataset))
print(len(dev_dataset))
print(len(test_dataset))


16166
2471
4799


In [12]:
# zero_ct = 0
# one_ct = 0

# for x in train_dataset:
#     if x['nli_label'] == 1 or x['nli_label'] == 2:
#         for i in x['span_labels']:
#             if i == 0:
#                 zero_ct += 1
#             else:
#                 one_ct += 1

# ic(zero_ct, one_ct)

# zero_ct = 0
# one_ct = 0

# for x in dev_dataset:
#     if x['nli_label'] == 1 or x['nli_label'] == 2:
#         for i in x['span_labels']:
#             if i == 0:
#                 zero_ct += 1
#             else:
#                 one_ct += 1

# ic(zero_ct, one_ct)

# zero_ct = 0
# one_ct = 0

# for x in test_dataset:
#     if x['nli_label'] == 1 or x['nli_label'] == 2:
#         for i in x['span_labels']:
#             if i == 0:
#                 zero_ct += 1
#             else:
#                 one_ct += 1

# ic(zero_ct, one_ct)

In [13]:
# # how many datapoints have no evidence
# from collections import Counter

# ic(Counter([x['nli_label'].item() for x in train_dataset]))
# ic(Counter([x['nli_label'].item() for x in dev_dataset]))
# ic(Counter([x['nli_label'].item() for x in test_dataset]))

In [14]:
from transformers import PreTrainedModel, PretrainedConfig

class ContractNLIConfig(PretrainedConfig):
    def __init__(self, lambda_ = 1, bert_model_name = cfg['model_name'], num_labels = len(get_labels()), ignore_index = get_labels()['Ignore'], **kwargs):
        super().__init__(**kwargs)
        self.bert_model_name = bert_model_name
        self.num_labels = num_labels
        self.ignore_index = ignore_index
        self.lambda_ = lambda_

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmppy05uriu
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmppy05uriu/_remote_module_non_scriptable.py


In [15]:
from transformers import AutoModel
from torch import nn

class ContractNLI(PreTrainedModel):
    config_class = ContractNLIConfig

    def __init__(self, config):
        super().__init__(config)
        self.bert = AutoModel.from_pretrained(config.bert_model_name)
        self.bert.resize_token_embeddings(self.bert.config.vocab_size + 1, pad_to_multiple_of=8)
        self.bert.eval()
        for param in self.bert.parameters():
            param.requires_grad = False

        self.embedding_dim = self.bert.config.hidden_size
        self.num_labels = config.num_labels
        self.lambda_ = config.lambda_
        self.nli_criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_index)
        self.span_criterion = nn.BCEWithLogitsLoss()

        self.span_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim // 2),
            nn.ReLU(),
            nn.Linear(self.embedding_dim // 2, 1)
        )

        self.nli_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim // 2),
            nn.ReLU(),
            nn.Linear(self.embedding_dim // 2, self.num_labels)
        )

        # initialize weights
        self.init_weights()

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # use the same initialization as bert
            module.weight.data.normal_(mean=0.0, std=self.bert.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, input_ids, attention_mask, token_type_ids, span_indices):
        outputs = self.bert(input_ids, attention_mask, token_type_ids, output_hidden_states=True).hidden_states[-1]

        gather = torch.gather(outputs, 1, span_indices.unsqueeze(2).expand(-1, -1, outputs.shape[-1]))

        masked_gather = gather[span_indices != 0]
        span_logits = self.span_classifier(masked_gather)
        nli_logits = self.nli_classifier(outputs[:, 0, :])

        return span_logits, nli_logits

In [16]:
from sklearn.metrics import precision_recall_curve
import numpy as np
def get_micro_average_precision_at_recall(y_true, y_pred, recall_level):
    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    return np.interp(recall_level, recall[::-1], precision[::-1])

In [17]:
# Import numpy and sklearn.metrics
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import precision_score
def calculate_micro_average_precision(y_true, y_pred):
    """Calculate the micro average precision score.

    Args:
        y_true (np.array): True labels.
        y_pred (np.array): Predicted labels.

    Returns:
        float: Micro average precision score.
    """
    # Get the number of classes
    num_classes = len(np.unique(y_true))

    # initialize the average precision score
    average_precision = 0.0

    # loop over all classes
    for class_idx in range(num_classes):
        # get the indices for this class
        y_true_indices = np.where(y_true == class_idx)
        # calculate the average precision score for this class
        average_precision += ic(precision_score(
            y_true[y_true_indices], y_pred[y_true_indices], average="micro"
        ))

    # return the average over all classes
    return average_precision / num_classes

In [18]:
from sklearn.metrics import f1_score
def calculate_f1_score_for_class(y_true, y_pred, class_idx):
    """Calculate the F1 score for a given class.

    Args:
        y_true (np.array): True labels.
        y_pred (np.array): Predicted labels.
        class_idx (int): Index of the class.

    Returns:
        float: F1 score for the given class.
    """
    # get the indices for the given class
    y_true_indices = np.where(y_true == class_idx)
    # calculate the F1 score for the given class
    return f1_score(
        y_true[y_true_indices], y_pred[y_true_indices], average="macro"
    )

In [19]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    auto_find_batch_size=True,
    output_dir=cfg['results_dir'],   # output directory
    num_train_epochs=10,            # total number of training epochs
    gradient_accumulation_steps=4,   # number of updates steps to accumulate before performing a backward/update pass
    logging_strategy='epoch',
    # eval_steps=0.25,
    # save_steps=0.25,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=True,
    label_names=['nli_label', 'span_labels', 'data_for_metrics'],
    report_to='none',
)

In [20]:
cfg['trained_model_dir']

'/scratch/shu7bh/contract_nli/trained_model/'

In [21]:
from transformers import EarlyStoppingCallback

In [38]:
import wandb
api = wandb.Api()
# artifact = api.artifact('contract-nli-db/contract-nli/model-3xmty1sv:v0', type='model')
# artifact = api.artifact('contract-nli-db/contract-nli/model-ma3khgd5:v0', type='model')
artifact = api.artifact('contract-nli-db/contract-nli/model-q6yqzpj9:v0', type='model')
# artifact = api.artifact('contract-nli-db/contract-nli/model-fbypywbc:v0', type='model')



# artifact = api.artifact('contract-nli-db/contract-nli/model-ayhbizbq:v0', type='model')
# artifact = api.artifact('contract-nli-db/contract-nli/model-xa10s0tb:v0', type='model')
# artifact = api.artifact('contract-nli-db/contract-nli/model-s3qw7z3d:v0', type='model')
artifact_dir = artifact.download(cfg['trained_model_dir'])

model = ContractNLI.from_pretrained(artifact_dir).to(DEVICE)

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): api.wandb.ai:443
DEBUG:urllib3.connectionpool:https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 253
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): api.wandb.ai:443
DEBUG:urllib3.connectionpool:https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 364
DEBUG:urllib3.connectionpool:https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 513
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): storage.googleapis.com:443
DEBUG:urllib3.connectionpool:https://storage.googleapis.com:443 "GET /wandb-production.appspot.com/contract-nli-db/contract-nli/q6yqzpj9/artifact/625752025/wandb_manifest.json?Expires=1698929226&GoogleAccessId=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com&Signature=k1CjNfv4oY2X149O9ac4PxAwhUuyMtEBOF05f8QDBJzVNSSMMD8%2BpEANrbGRUKQQlHoGwMmRMSazhWbu6vDLvhjdKS5nKMDEpTnRYTLPF70ShU447i9TA%2BP0cV895IAxxCpKOb6aLjuit5hBtZlxll3J5%2BLKfKhiE%2FZgTc%2BqmnbDEFpwnZ3gwtt5

In [39]:
artifact_dir

'/scratch/shu7bh/contract_nli/trained_model/'

In [58]:
from transformers import Trainer
from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score
from tqdm import tqdm
import numpy as np

class ContractNLITrainer(Trainer):
    def __init__(self, *args, data_collator=None, **kwargs):
        super().__init__(*args, data_collator=data_collator, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        span_label = inputs.pop('span_labels')
        nli_label = inputs.pop('nli_label')
        inputs.pop('data_for_metrics')

        outputs = model(**inputs)
        span_logits, nli_logits = outputs[0], outputs[1]

        true_span_labels = []
        pred_span_labels = []

        start_index = 0
        end_index = 0

        for i, span_index_row in enumerate(inputs['span_indices']):
            # find first zero index
            first_zero_index = torch.where(span_index_row == 0)[0]

            if len(first_zero_index) == 0:
                first_zero_index = len(span_index_row)
            else:
                first_zero_index = first_zero_index[0].item()
            end_index += first_zero_index

            if nli_label[i] != get_labels()['Ignore']:
                if nli_label[i] != get_labels()['NotMentioned']:
                    true_span_labels.extend(span_label[start_index:end_index].tolist())
                    pred_span_labels.extend(span_logits[start_index:end_index].tolist())

            start_index = end_index

        true_span_labels = torch.tensor(true_span_labels, dtype=torch.float32, device=DEVICE)
        pred_span_labels = torch.tensor(pred_span_labels, dtype=torch.float32, device=DEVICE)

        true_span_labels = true_span_labels.view(-1)
        pred_span_labels = pred_span_labels.view(-1)

        if len(true_span_labels) == 0 or len(pred_span_labels) != len(true_span_labels):
            span_loss = torch.tensor(0, dtype=torch.float32, device=DEVICE)
        else:
            span_loss = self.model.span_criterion(pred_span_labels, true_span_labels)

        nli_loss = self.model.nli_criterion(nli_logits, nli_label)

        if torch.isnan(nli_loss):
            nli_loss = torch.tensor(0, dtype=torch.float32, device=DEVICE)

        if torch.isnan(span_loss):
            span_loss = torch.tensor(0, dtype=torch.float32, device=DEVICE)

        loss = span_loss + self.model.lambda_ * nli_loss

        if loss.item() == 0:
            loss = torch.tensor(0, dtype=torch.float32, device=DEVICE, requires_grad=True)

        return (loss, outputs) if return_outputs else loss

    @staticmethod
    def collate_fn(features):
        span_indices_list = [feature['span_indices'] for feature in features]
        max_len = max([len(span_indices) for span_indices in span_indices_list])
        span_indices_list = [torch.cat([span_indices, torch.zeros(max_len - len(span_indices), dtype=torch.long)]) for span_indices in span_indices_list]

        span_ids_list = [feature['data_for_metrics']['span_ids'] for feature in features]
        max_len = max([len(span_ids) for span_ids in span_ids_list])
        span_ids_list = [torch.cat([span_ids, torch.zeros(max_len - len(span_ids), dtype=torch.long)]) for span_ids in span_ids_list]

        input_ids = torch.stack([feature['input_ids'] for feature in features])
        attention_mask = torch.stack([feature['attention_mask'] for feature in features])
        token_type_ids = torch.stack([feature['token_type_ids'] for feature in features])
        span_indices = torch.stack(span_indices_list)
        nli_label = torch.stack([feature['nli_label'] for feature in features])
        span_label = torch.cat([feature['span_labels'] for feature in features], dim=0)
        data_for_metrics = {
            'doc_id': torch.stack([feature['data_for_metrics']['doc_id'] for feature in features]),
            'hypothesis_id': torch.stack([feature['data_for_metrics']['hypothesis_id'] for feature in features]),
            'span_ids': torch.stack(span_ids_list),
        }

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'span_indices': span_indices,
            'nli_label': nli_label,
            'span_labels': span_label,
            'data_for_metrics': data_for_metrics,
        }

    def evaluate(self, eval_dataset=None, ignore_keys=None):
        ic(eval_dataset)
        self.model.eval()
        self.dataloader = ic(self.get_eval_dataloader(eval_dataset))
        
        true_span_labels = []
        pred_span_labels = []
        
        true_nli_labels = []
        pred_nli_labels = []

        for inputs in tqdm(self.dataloader):
            inputs = self._prepare_inputs(inputs)
            ic(inputs['input_ids'].shape)
            eval_span_labels = inputs.pop('span_labels')
            eval_nli_labels = inputs.pop('nli_label')
            inputs.pop('data_for_metrics')

            with torch.no_grad():
                outputs = self.model(**inputs)
                span_logits, nli_logits = outputs[0], outputs[1]

                span_preds = ic(torch.sigmoid(span_logits)) >= 0.475
                span_preds = torch.tensor(span_preds.squeeze(1), dtype=torch.long)
                nli_preds = torch.argmax(torch.softmax(nli_logits, dim=1), dim=1)

                start_index = 0
                end_index = 0

                for i, span_index_row in enumerate(inputs['span_indices']):
                    # find first zero index
                    first_zero_index = torch.where(span_index_row == 0)[0]

                    if len(first_zero_index) == 0:
                        first_zero_index = len(span_index_row)
                    else:
                        first_zero_index = first_zero_index[0].item()
                    end_index += first_zero_index

                    if eval_nli_labels[i] != get_labels()['NotMentioned']:
                        true_span_labels.extend(eval_span_labels[start_index:end_index].tolist())
                        pred_span_labels.extend(span_preds[start_index:end_index].tolist())

                    true_nli_labels.append(eval_nli_labels[i].item())
                    pred_nli_labels.append(nli_preds[i].item())

                    start_index = end_index


                # true_span_labels.extend(eval_span_labels.tolist())
                # pred_span_labels.extend(span_preds.tolist())
                
                # true_nli_labels.extend(eval_nli_labels.tolist())
                # pred_nli_labels.extend(nli_preds.tolist())
            
            # zip and print the true_span_labels and pred_span_labels together
            # ic(list(zip(true_span_labels, pred_span_labels)))
            # ic(list(zip(true_nli_labels, pred_nli_labels)))

        eval_nli_acc = accuracy_score(true_nli_labels, pred_nli_labels)
        
        ic.enable()

        ic(len(true_span_labels), len(pred_span_labels))

        # print any span label if pred_span_label is 1
        ic(sum(pred_span_labels), sum(true_span_labels))

        mAP = calculate_micro_average_precision(torch.tensor(true_span_labels), torch.tensor(pred_span_labels))        

        # mAP = average_precision_score(torch.tensor(true_span_labels), torch.tensor(pred_span_labels))
        precision_at_80_recall = get_micro_average_precision_at_recall(torch.tensor(true_span_labels), torch.tensor(pred_span_labels), 0.8)
        f1_score_for_entailment = calculate_f1_score_for_class(torch.tensor(true_nli_labels), torch.tensor(pred_nli_labels), get_labels()['Entailment'])
        f1_score_for_contradiction = calculate_f1_score_for_class(torch.tensor(true_nli_labels), torch.tensor(pred_nli_labels), get_labels()['Contradiction'])
        
        return {
            'mAP' : mAP,
            'precision_at_80_recall' : precision_at_80_recall,
            'nli_acc': eval_nli_acc,
            'f1_score_for_entailment': f1_score_for_entailment,
            'f1_score_for_contradiction': f1_score_for_contradiction
        }

In [59]:
trainer = ContractNLITrainer(
    model=model,                          # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=dev_dataset,            # evaluation dataset
    data_collator=ContractNLITrainer.collate_fn,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.001)],
)

In [60]:
ic.disable()
results = trainer.evaluate()
results

  span_preds = torch.tensor(span_preds.squeeze(1), dtype=torch.long)
100%|██████████| 309/309 [00:50<00:00,  6.14it/s]
ic| len(true_span_labels): 2797, len(pred_span_labels): 2797
ic| sum(pred_span_labels): 1404, sum(true_span_labels): 1284
ic| precision_score(
        y_true[y_true_indices], y_pred[y_true_indices], average="micro"
    ): 0.4930601454064772
ic| precision_score(
        y_true[y_true_indices], y_pred[y_true_indices], average="micro"
    ): 0.4961059190031153


{'mAP': 0.49458303220479627,
 'precision_at_80_recall': 0.4569360182099378,
 'nli_acc': 0.5880210441116956,
 'f1_score_for_entailment': 0.2794649313087491,
 'f1_score_for_contradiction': 0.1797752808988764}

In [42]:
ic.disable()
results = trainer.evaluate()
results

  span_preds = torch.tensor(span_preds.squeeze(1), dtype=torch.long)
100%|██████████| 309/309 [00:33<00:00,  9.11it/s]
ic| len(true_span_labels): 2797, len(pred_span_labels): 2797
ic| sum(pred_span_labels): 18, sum(true_span_labels): 1284
ic| precision_score(
        y_true[y_true_indices], y_pred[y_true_indices], average="micro"
    ): 0.9940515532055518
ic| precision_score(
        y_true[y_true_indices], y_pred[y_true_indices], average="micro"
    ): 0.007009345794392523


{'mAP': 0.5005304494999722,
 'precision_at_80_recall': 0.467308418683884,
 'nli_acc': 0.5880210441116956,
 'f1_score_for_entailment': 0.2794649313087491,
 'f1_score_for_contradiction': 0.1797752808988764}