In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging as log
# Configure the root logger at DEBUG level. This is why DEBUG records from
# third-party libraries (urllib3, git, wandb, ...) appear in the cell outputs
# further down in this notebook.
log.basicConfig(level=log.DEBUG)

In [2]:
import sys
# Add the parent directory to the import path; presumably this is what lets
# the `baselines` package below resolve — verify against the repo layout.
sys.path.append('../')
# NOTE(review): star import — every public name of baselines.utils lands in
# the notebook namespace, so the origin of helpers used later is not visible here.
from baselines.utils import *
import os

# CUDA_LAUNCH_BLOCKING=1 is the CUDA switch for synchronous kernel launches,
# commonly enabled while debugging device-side errors.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Weights & Biases settings supplied through environment variables.
# NOTE(review): the TrainingArguments in this notebook use report_to='none',
# so these may have no effect on Trainer logging — confirm they are still needed.
os.environ['WANDB_ENTITY'] = 'contract-nli-db'
os.environ['WANDB_PROJECT'] = 'contract-nli-metric'
os.environ['WANDB_LOG_MODEL'] = 'end'

In [3]:
from ipynb.fs.defs.contract_nli_bert_train import *

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpzo8acc1k
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpzo8acc1k/_remote_module_non_scriptable.py


In [4]:
import torch

# Use the GPU when PyTorch reports one as available, otherwise stay on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cpu')

In [5]:
# Override a few entries of `cfg` for this notebook. `cfg` is not defined in
# the visible cells; presumably it comes from one of the star imports above —
# verify.
# Alternative backbone, kept for reference:
# cfg['model_name'] = 'nlpaueb/legal-bert-base-uncased'
cfg['model_name'] = 'bert-base-uncased'
# Directory later passed to the wandb artifact download.
cfg['trained_model_dir'] = './scratch/shu7bh/contract_nli/trained_model/'
# NOTE(review): nothing in the visible cells reads cfg['batch_size'] (it is not
# passed to TrainingArguments) — confirm whether it is meant to take effect.
cfg['batch_size'] = 32
# Bare expression: display the resulting configuration.
cfg

{'raw_data_dir': '../dataset/',
 'train_path': 'train.json',
 'test_path': 'test.json',
 'dev_path': 'dev.json',
 'model_name': 'bert-base-uncased',
 'max_length': 512,
 'models_save_dir': './scratch/shu7bh/contract_nli/models',
 'dataset_dir': './scratch/shu7bh/contract_nli/dataset',
 'results_dir': './scratch/shu7bh/contract_nli/results',
 'trained_model_dir': './scratch/shu7bh/contract_nli/trained_model/',
 'batch_size': 32}

In [6]:
# Make sure the model and dataset output directories exist (no-op when they already do).
from pathlib import Path
for dir_key in ("models_save_dir", "dataset_dir"):
    Path(cfg[dir_key]).mkdir(parents=True, exist_ok=True)

In [7]:
# Load the tokenizer for the checkpoint named in cfg['model_name'].
tokenizer = AutoTokenizer.from_pretrained(cfg['model_name'])

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


In [8]:
from icecream import ic

In [9]:
# Read the three raw splits, in train / dev / test order.
train_data, dev_data, test_data = (
    load_data(os.path.join(cfg['raw_data_dir'], cfg[split_key]))
    for split_key in ('train_path', 'dev_path', 'test_path')
)

# The hypotheses are taken from the training file and reused for every split.
hypothesis = get_hypothesis(train_data)

# Keep only the document list of each split, truncated to the first two documents.
# NOTE(review): the [:2] slice looks like a debugging subset — confirm before
# reporting metrics computed from this notebook.
train_data, dev_data, test_data = (
    split['documents'][:2] for split in (train_data, dev_data, test_data)
)

# icecream is switched off while the datasets are built (the ic(...) call just
# below is therefore silent as well) and switched back on afterwards.
ic.disable()

ic(len(train_data), len(dev_data), len(test_data))
train_dataset, dev_dataset, test_dataset = (
    NLIDataset(documents, tokenizer, hypothesis, [1000, 1100, 1200, 1500], 50)
    for documents in (train_data, dev_data, test_data)
)

ic.enable()

# The raw structures are no longer needed once the datasets exist.
del train_data, dev_data, test_data, hypothesis

  "metadata": {},


In [10]:
# One line per split: number of examples in the train, dev and test datasets.
print(len(train_dataset), len(dev_dataset), len(test_dataset), sep='\n')


1734
3026
2346


In [11]:
# nli_weights, span_weight = get_class_weights(train_dataset)

In [17]:
# nli_weights, span_weight

In [18]:
from sklearn.metrics import precision_recall_curve
import numpy as np
def get_micro_average_precision_at_recall(y_true, y_pred, recall_level):
    """Read the precision off the precision-recall curve at ``recall_level``.

    The curve is computed with ``precision_recall_curve`` and the precision is
    linearly interpolated between the curve points with ``np.interp``.

    Args:
        y_true: Ground-truth labels, as accepted by ``precision_recall_curve``.
        y_pred: Scores for the positive class, as accepted by
            ``precision_recall_curve``.
        recall_level: Recall value (or array of values) to evaluate at.

    Returns:
        The interpolated precision at ``recall_level``.
    """
    precisions, recalls, _thresholds = precision_recall_curve(y_true, y_pred)
    # Both arrays are reversed together so that recall is passed to np.interp
    # as the x-coordinates in the opposite order to how the curve returns it.
    recall_axis = recalls[::-1]
    precision_axis = precisions[::-1]
    return np.interp(recall_level, recall_axis, precision_axis)

In [19]:
# Import numpy and sklearn.metrics
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import precision_score
def calculate_micro_average_precision(y_true, y_pred):
    """Average a per-class ``precision_score`` over the classes present in ``y_true``.

    For every distinct label found in ``y_true``, the samples carrying that
    true label are selected and ``precision_score(..., average="micro")`` is
    computed on that subset. The per-class values are then averaged with
    equal weight.

    The loop runs over the labels that actually occur, not over
    ``range(number_of_labels)``, so label sets that are not exactly
    ``0..k-1`` (for example ``{0, 2}``) are handled correctly.

    NOTE(review): each subset contains a single true class, so the per-class
    value behaves more like a recall than a precision — confirm this is the
    intended metric.

    Args:
        y_true (np.array): True labels.
        y_pred (np.array): Predicted labels, aligned with ``y_true``.

    Returns:
        float: Mean of the per-class scores, or ``0.0`` when ``y_true`` is empty.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Distinct labels as plain Python scalars, in sorted order.
    class_labels = np.unique(y_true).tolist()
    if not class_labels:
        return 0.0

    average_precision = 0.0
    for class_label in class_labels:
        # Positions of the samples whose true label is this class.
        y_true_indices = np.where(y_true == class_label)
        average_precision += ic(precision_score(
            y_true[y_true_indices], y_pred[y_true_indices], average="micro"
        ))

    # Equal-weight mean over the classes that were seen.
    return average_precision / len(class_labels)

In [20]:
from sklearn.metrics import f1_score
def calculate_f1_score_for_class(y_true, y_pred, class_idx):
    """Calculate the F1 score of a single class.

    The score is computed over ALL samples with ``labels=[class_idx]``, so both
    the false negatives (samples of the class predicted as something else) and
    the false positives (other samples predicted as the class) are counted.
    Filtering to the rows whose true label is ``class_idx`` first would hide
    every false positive and would average in the F1 of whatever other labels
    were predicted on those rows.

    Args:
        y_true (np.array): True labels.
        y_pred (np.array): Predicted labels, aligned with ``y_true``.
        class_idx (int): Label of the class to score.

    Returns:
        float: F1 score for the given class.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Restricting `labels` to one entry makes the "macro" average equal to the
    # F1 of that single class.
    return f1_score(y_true, y_pred, labels=[class_idx], average="macro")

In [None]:
def precision_at_recall(y_true, y_scores, recall_threshold):
    """Best precision achievable while keeping recall at or above a target.

    The precision-recall curve is computed with ``precision_recall_curve`` and
    only the operating points whose recall is at least ``recall_threshold`` are
    considered. Picking the point with the *nearest* recall instead could
    select a point below the target and report a precision the model only
    reaches at lower recall.

    Args:
        y_true: Binary ground-truth labels.
        y_scores: Scores for the positive class, aligned with ``y_true``.
        recall_threshold (float): Minimum recall that must be reached.

    Returns:
        float: Highest precision among the points with
        ``recall >= recall_threshold``; ``0.0`` when no point reaches it.
    """
    precision, recall, _ = precision_recall_curve(y_true, y_scores)
    reachable = recall >= recall_threshold
    if not reachable.any():
        # Nothing on the curve satisfies the recall requirement.
        return 0.0
    return precision[reachable].max()

In [21]:
from transformers import TrainingArguments

# Arguments handed to the trainer built further down. In the visible cells
# only `trainer.evaluate()` is called, so the training-related settings below
# matter only if `train()` is run as well.
# NOTE(review): no per-device batch size is set here and cfg['batch_size'] is
# not passed in — confirm the evaluation batch size is the intended one.
training_args = TrainingArguments(
    auto_find_batch_size=True,
    output_dir=cfg['results_dir'],   # output directory
    num_train_epochs=10,            # total number of training epochs
    gradient_accumulation_steps=4,   # number of updates steps to accumulate before performing a backward/update pass
    logging_strategy='epoch',
    # Step-based alternatives, currently disabled in favour of per-epoch evaluation/saving:
    # eval_steps=0.25,
    # save_steps=0.25,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,
    load_best_model_at_end=True,
    # fp16=True,
    # Batch keys treated as labels; these are the keys the custom evaluate()
    # below pops from each batch before calling the model.
    label_names=['nli_label', 'span_labels', 'data_for_metrics'],
    report_to='none',
)

In [22]:
# Display the directory the trained-model artifact will be downloaded into.
cfg['trained_model_dir']

'./scratch/shu7bh/contract_nli/trained_model/'

In [23]:
from transformers import EarlyStoppingCallback

In [24]:
import wandb
# Fetch the trained checkpoint from Weights & Biases through its API client.
api = wandb.Api()
# NOTE(review): the artifact path and version are hard-coded; consider moving
# them into `cfg` so a different checkpoint can be evaluated without editing this cell.
artifact = api.artifact('contract-nli-db/contract-nli/model-3xtagn3w:v0', type='model')
# Download into cfg['trained_model_dir']; the returned path is used below.
artifact_dir = artifact.download(cfg['trained_model_dir'])

# Restore the ContractNLI model from the downloaded files and move it to DEVICE.
model = ContractNLI.from_pretrained(artifact_dir).to(DEVICE)

DEBUG:git.cmd:Popen(['git', 'version'], cwd=/home/ayush/ANLP/Project/Contract-NLI/source_code, universal_newlines=False, shell=None, istream=None)
DEBUG:git.cmd:Popen(['git', 'version'], cwd=/home/ayush/ANLP/Project/Contract-NLI/source_code, universal_newlines=False, shell=None, istream=None)
DEBUG:wandb.docker.auth:Trying paths: ['/home/ayush/.docker/config.json', '/home/ayush/.dockercfg']
DEBUG:wandb.docker.auth:No config file found
DEBUG:sentry_sdk.errors:[Tracing] Create new propagation context: {'trace_id': '60faf9a9a51643e5b9bf1e5a84fe0729', 'span_id': '8fe7168d135a87e1', 'parent_span_id': None, 'dynamic_sampling_context': None}
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): api.wandb.ai:443
DEBUG:urllib3.connectionpool:https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 253
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): api.wandb.ai:443
DEBUG:urllib3.connectionpool:https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 362
DEBUG:urllib3.connect

In [25]:
# Display the path returned by the artifact download.
artifact_dir

'./scratch/shu7bh/contract_nli/trained_model/'

In [51]:
from transformers import Trainer
from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score
from tqdm import tqdm
import numpy as np

class ContractNLIMetricTrainer(ContractNLITrainer):
    """ContractNLITrainer whose ``evaluate`` reports span-level and NLI-level metrics.

    Span probabilities are collected per ``doc_id-hypothesis_id-span_id`` key
    over the whole evaluation set; when a key occurs more than once its
    probabilities are averaged before the span metrics are computed.
    """

    def __init__(self, *args, data_collator=None, **kwargs):
        """Forward all arguments unchanged to ``ContractNLITrainer``."""
        super().__init__(*args, data_collator=data_collator, **kwargs)

    def evaluate(self, eval_dataset=None, ignore_keys=None):
        """Run the model over the evaluation data and compute the metrics.

        Args:
            eval_dataset: Dataset handed to ``get_eval_dataloader``; ``None``
                is passed through unchanged.
            ignore_keys: Accepted for signature compatibility; not used here.

        Returns:
            dict: ``mAP``, ``precision_at_80_recall``, ``nli_acc``,
            ``f1_score_for_entailment`` and ``f1_score_for_contradiction``.
        """
        ic(eval_dataset)
        self.model.eval()
        self.dataloader = ic(self.get_eval_dataloader(eval_dataset))

        eval_nli_labels = []
        eval_nli_preds = []

        # These accumulators must live OUTSIDE the batch loop: the span metrics
        # are computed after the loop from everything collected here. Creating
        # them per batch would keep only the last batch's spans and would raise
        # a NameError for an empty dataloader.
        probs_per_span = {}
        true_labels_per_span = {}

        for inputs in tqdm(self.dataloader):
            inputs = self._prepare_inputs(inputs)
            ic(inputs['input_ids'].shape)
            span_labels = inputs.pop('span_labels')
            nli_labels = inputs.pop('nli_label')
            data_for_metrics = inputs.pop('data_for_metrics')
            ic(nli_labels)
            ic(span_labels)
            ic(data_for_metrics)

            # Positions (along the first dimension) whose span label is not -1;
            # positions labelled -1 are skipped.
            span_indices_to_consider = torch.where(span_labels != -1)[0]
            ic(span_indices_to_consider)

            with torch.no_grad():
                outputs = self.model(**inputs)
                span_logits, nli_logits = outputs[0], outputs[1]

                # Flatten both so one flat index addresses a label and its logit.
                span_labels = span_labels.float().view(-1)
                span_logits = span_logits.float().view(-1)

                # Running total of real (non-padding) span ids over the rows
                # handled so far, i.e. the exclusive upper bound of the flat
                # indices that belong to the rows up to and including row i.
                indices_considered = 0

                for i, span_index_row in enumerate(data_for_metrics['span_ids']):
                    ic(span_index_row)
                    # Rows of span ids are padded with -1; the first -1 marks
                    # how many real span ids this row holds.
                    first_minus_one_index = torch.where(span_index_row == -1)[0]
                    ic(first_minus_one_index)
                    if len(first_minus_one_index) == 0:
                        first_minus_one_index = len(span_index_row)
                    else:
                        first_minus_one_index = first_minus_one_index[0].item()

                    current_index = first_minus_one_index  # spans in this row
                    indices_considered += current_index
                    ic(indices_considered)
                    ic(current_index)
                    cnt = 0  # flat indices consumed by this row

                    for span_index in span_indices_to_consider:
                        if span_index < indices_considered:
                            cnt += 1
                            # Offset of the span inside its own row.
                            value_index = span_index - (indices_considered - current_index)
                            doc_id = data_for_metrics['doc_id'][i]
                            hypothesis_id = data_for_metrics['hypothesis_id'][i]
                            span_id = data_for_metrics['span_ids'][i][value_index]
                            key = str(doc_id)+ '-' + str(hypothesis_id)+ '-' + str(span_id)
                            true_labels_per_span[key] = span_labels[span_index]
                            probs_per_span.setdefault(key, []).append(
                                torch.sigmoid(span_logits[span_index])
                            )
                        else:
                            # Remaining indices belong to later rows.
                            break

                    # Drop the indices consumed by this row.
                    span_indices_to_consider = span_indices_to_consider[cnt:]

                nli_preds = torch.argmax(torch.softmax(nli_logits, dim=1), dim=1)

                eval_nli_labels.extend(nli_labels.cpu().numpy())
                eval_nli_preds.extend(nli_preds.cpu().numpy())

        eval_nli_acc = accuracy_score(eval_nli_labels, eval_nli_preds)
        eval_span_labels = []
        eval_span_preds = []

        # One entry per distinct span key: its label and the mean of all the
        # probabilities collected for it.
        for key in true_labels_per_span:
            eval_span_labels.append(true_labels_per_span[key].item())
            eval_span_preds.append(torch.mean(torch.stack(probs_per_span[key])).item())

        # Re-enables icecream globally so the two summaries below are printed
        # even when the caller disabled it before evaluating.
        ic.enable()
        ic(len(eval_span_labels), len(eval_span_preds))

        ic(sum(eval_span_labels), sum(eval_span_preds))

        # Mean of the average precision of both span classes. The scores are
        # probabilities of label 1, so label 0 has to be ranked by the
        # complementary probability when it is treated as the positive class.
        negative_class_scores = [1.0 - prob for prob in eval_span_preds]
        mAP = (
            average_precision_score(eval_span_labels, negative_class_scores, pos_label=0)
            + average_precision_score(eval_span_labels, eval_span_preds, pos_label=1)
        ) / 2

        precision_at_80_recall = precision_at_recall(torch.tensor(eval_span_labels), torch.tensor(eval_span_preds), 0.8)
        f1_score_for_entailment = calculate_f1_score_for_class(torch.tensor(eval_nli_labels), torch.tensor(eval_nli_preds), get_labels()['Entailment'])
        f1_score_for_contradiction = calculate_f1_score_for_class(torch.tensor(eval_nli_labels), torch.tensor(eval_nli_preds), get_labels()['Contradiction'])

        return {
            'mAP' : mAP,
            'precision_at_80_recall' : precision_at_80_recall,
            'nli_acc': eval_nli_acc,
            'f1_score_for_entailment': f1_score_for_entailment,
            'f1_score_for_contradiction': f1_score_for_contradiction
        }

In [52]:
# Build the metric trainer around the checkpoint downloaded above.
# Only `model` is passed: Trainer expects either `model` or `model_init`, not
# both, and a `model_init` would replace the downloaded weights with a freshly
# initialised model as soon as `train()` is called.
trainer = ContractNLIMetricTrainer(
    model=model,                          # the downloaded ContractNLI model to evaluate
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=dev_dataset,            # evaluation dataset
    data_collator=ContractNLIMetricTrainer.collate_fn,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.001)],
)



In [57]:
# logits = torch.tensor([10, 0.2, -0.3])
# sigmoid_logits = torch.sigmoid(logits)

# mean_sigmoid_logits = torch.mean(sigmoid_logits)
# mean_logits = torch.mean(logits)
# sigmoid_mean_logits = torch.sigmoid(mean_logits)
# ic.enable()
# ic(mean_sigmoid_logits)
# ic(sigmoid_mean_logits)

ic| mean_sigmoid_logits: tensor(0.6584)


ic| sigmoid_mean_logits: tensor(0.9644)


tensor(0.9644)

In [61]:
# from sklearn.metrics import precision_recall_curve

# def precision_at_recall(y_true, y_scores, recall_threshold):
#     precision, recall, threshold = precision_recall_curve(y_true, y_scores)
#     idx = (np.abs(recall - recall_threshold)).argmin()  # Find nearest recall value to threshold
#     # ic(threshold[idx])
#     return precision[idx]

# y_true = np.array([0, 1, 1, 0, 1, 1, 1, 0, 0, 1])
# y_scores = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.7, 0.8, 0.9])
# recall_threshold = 0.8
# precision_at_recall(y_true, y_scores, recall_threshold)

# # calculate average precision score for a given class
# from sklearn.metrics import average_precision_score

# y_true = np.array([0, 1, 1, 0, 1, 1, 1, 0, 0, 1])
# y_scores = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.7, 0.8, 0.9])

# ic(average_precision_score(y_true, y_scores, pos_label=0))
# ic(average_precision_score(y_true, y_scores))

ic| average_precision_score(y_true, y_scores, pos_label=0): 0.4988095238095238
ic| average_precision_score(y_true, y_scores): 0.6763888888888888


0.6763888888888888

In [53]:
# Silence icecream before evaluating, then run the custom evaluate() on the
# trainer's default evaluation dataset (no dataset is passed explicitly).
ic.disable()
# ic.enable()
results = trainer.evaluate()
# Bare expression: display the returned metrics.
results

  3%|▎         | 12/379 [01:49<56:00,  9.16s/it]


KeyboardInterrupt: 

In [40]:
# span_index_row = torch.tensor([2, 5, -1, -1, -1])
# first_minus_one_index = torch.where(span_index_row == -1)[0]
# ic(first_minus_one_index)

ic| first_minus_one_index: tensor([2, 3, 4])


tensor([2, 3, 4])

In [None]:
# from transformers import Trainer
# from sklearn.metrics import accuracy_score
# from sklearn.metrics import average_precision_score
# from tqdm import tqdm
# import numpy as np

# class ContractNLITrainer(Trainer):
#     def _init_(self, *args, data_collator=None, **kwargs):
#         super()._init_(*args, data_collator=data_collator, **kwargs)

#     def compute_loss(self, model, inputs, return_outputs=False):
#         span_label = inputs.pop('span_labels')
#         nli_label = inputs.pop('nli_label')
#         inputs.pop('data_for_metrics')

#         outputs = model(**inputs)
#         span_logits, nli_logits = outputs[0], outputs[1]

#         true_span_labels = []
#         pred_span_labels = []

#         start_index = 0
#         end_index = 0

#         for i, span_index_row in enumerate(inputs['span_indices']):
#             # find first zero index
#             first_zero_index = torch.where(span_index_row == 0)[0]

#             if len(first_zero_index) == 0:
#                 first_zero_index = len(span_index_row)
#             else:
#                 first_zero_index = first_zero_index[0].item()
#             end_index += first_zero_index

#             if nli_label[i] != get_labels()['Ignore']:
#                 if nli_label[i] != get_labels()['NotMentioned']:
#                     true_span_labels.extend(span_label[start_index:end_index].tolist())
#                     pred_span_labels.extend(span_logits[start_index:end_index].tolist())

#             start_index = end_index

#         true_span_labels = torch.tensor(true_span_labels, dtype=torch.float32, device=DEVICE)
#         pred_span_labels = torch.tensor(pred_span_labels, dtype=torch.float32, device=DEVICE)

#         true_span_labels = true_span_labels.view(-1)
#         pred_span_labels = pred_span_labels.view(-1)

#         if len(true_span_labels) == 0 or len(pred_span_labels) != len(true_span_labels):
#             span_loss = torch.tensor(0, dtype=torch.float32, device=DEVICE)
#         else:
#             span_loss = self.model.span_criterion(pred_span_labels, true_span_labels)

#         nli_loss = self.model.nli_criterion(nli_logits, nli_label)

#         if torch.isnan(nli_loss):
#             nli_loss = torch.tensor(0, dtype=torch.float32, device=DEVICE)

#         if torch.isnan(span_loss):
#             span_loss = torch.tensor(0, dtype=torch.float32, device=DEVICE)

#         loss = span_loss + self.model.lambda_ * nli_loss

#         if loss.item() == 0:
#             loss = torch.tensor(0, dtype=torch.float32, device=DEVICE, requires_grad=True)

#         return (loss, outputs) if return_outputs else loss

#     @staticmethod
#     def collate_fn(features):
#         span_indices_list = [feature['span_indices'] for feature in features]
#         max_len = max([len(span_indices) for span_indices in span_indices_list])
#         span_indices_list = [torch.cat([span_indices, torch.zeros(max_len - len(span_indices), dtype=torch.long)]) for span_indices in span_indices_list]

#         span_ids_list = [feature['data_for_metrics']['span_ids'] for feature in features]
#         max_len = max([len(span_ids) for span_ids in span_ids_list])
#         # span_ids_tensors = [torch.tensor(span_ids) for span_ids in span_ids_list]

#         # Padding
#         # span_ids_list_padded = torch.nn.utils.rnn.pad_sequence(span_ids_tensors, batch_first=True, padding_value=-1)
        
#         span_ids_list = [torch.cat([span_ids, torch.full((max_len - len(span_ids),), -1)]) for span_ids in span_ids_list]
        
#         # span_ids_list = [torch.cat([span_ids, pad_tensor]) for span_ids in span_ids_list]
#         # ic(span_ids_list)
#         # span_ids_list = [torch.cat([span_ids, torch.zeros(max_len - len(span_ids), dtype=torch.long)]) for span_ids in span_ids_list]

#         input_ids = torch.stack([feature['input_ids'] for feature in features])
#         attention_mask = torch.stack([feature['attention_mask'] for feature in features])
#         token_type_ids = torch.stack([feature['token_type_ids'] for feature in features])
#         span_indices = torch.stack(span_indices_list)
#         nli_label = torch.stack([feature['nli_label'] for feature in features])
#         span_label = torch.cat([feature['span_labels'] for feature in features], dim=0)
#         data_for_metrics = {
#             'doc_id': torch.stack([feature['data_for_metrics']['doc_id'] for feature in features]),
#             'hypothesis_id': torch.stack([feature['data_for_metrics']['hypothesis_id'] for feature in features]),
#             'span_ids': torch.stack(span_ids_list),
#         }

#         return {
#             'input_ids': input_ids,
#             'attention_mask': attention_mask,
#             'token_type_ids': token_type_ids,
#             'span_indices': span_indices,
#             'nli_label': nli_label,
#             'span_labels': span_label,
#             'data_for_metrics': data_for_metrics,
#         }

#     def evaluate(self, eval_dataset=None, ignore_keys=None):
#         ic(eval_dataset)
#         self.model.eval()
#         self.dataloader = ic(self.get_eval_dataloader(eval_dataset))
        
#         # true_span_labels = []
#         # pred_span_labels = []
        
#         true_nli_labels = []
#         pred_nli_labels = []

#         probs_per_span = {}
#         true_labels_per_span = {}

#         for inputs in tqdm(self.dataloader):
#             inputs = self._prepare_inputs(inputs)
#             ic(inputs['input_ids'].shape)
#             eval_span_labels = inputs.pop('span_labels')
#             ic(eval_span_labels)
#             eval_nli_labels = inputs.pop('nli_label')
#             data_for_metrics = inputs.pop('data_for_metrics')
#             ic(data_for_metrics)

#             with torch.no_grad():
#                 outputs = self.model(**inputs)
#                 span_logits, nli_logits = outputs[0], outputs[1]
#                 # span_preds = ic(torch.sigmoid(span_logits)) >= 0.499
#                 # ic(span_logits)
                
#                 # for i in range(len(data_for_metrics['doc_id'])):
#                 #     # ic(point)
#                 #     for j in range(len(data_for_metrics['span_ids'][i])):
#                 #         key = str(data_for_metrics['doc_id'][i])+ '-' + str(data_for_metrics['hypothesis_id'][i])+ '-' + str(data_for_metrics['span_ids'][i][j])
#                 #         true_labels_per_span[key] = eval_span_labels[i]

#                 #         if key in probs_per_span:
#                 #             probs_per_span[key].append(torch.sigmoid(span_logits[i]))
#                 #         else:
#                 #             probs_per_span[key] = [torch.sigmoid(span_logits[i])]

#                 # span_preds = torch.tensor(span_preds.squeeze(1), dtype=torch.long)
#                 nli_preds = torch.argmax(torch.softmax(nli_logits, dim=1), dim=1)


#                 start_index = 0
#                 end_index = 0

#                 for i, span_index_row in enumerate(inputs['span_indices']):
#                     # find first zero index
#                     # ic(i)
#                     # ic(span_index_row)
#                     first_zero_index = torch.where(span_index_row == 0)[0]

#                     if len(first_zero_index) == 0:
#                         first_zero_index = len(span_index_row)
#                     else:
#                         first_zero_index = first_zero_index[0].item()
#                     end_index += first_zero_index

#                     if start_index == end_index:
#                         continue
#                     # ic(start_index, end_index)
#                     if eval_nli_labels[i] != get_labels()['NotMentioned']:
#                         for j in range(start_index, end_index):
#                             ic(j)
#                             ic(data_for_metrics['span_ids'][i])
#                             key = str(data_for_metrics['doc_id'][i])+ '-' + str(data_for_metrics['hypothesis_id'][i])+ '-' + str(data_for_metrics['span_ids'][i][j-start_index])
#                             true_labels_per_span[key] = eval_span_labels[j]

#                             if key in probs_per_span:
#                                 probs_per_span[key].append(torch.sigmoid(span_logits[j]))
#                                 # probs_per_span[key].append(span_logits[j])
#                             else:
#                                 probs_per_span[key] = [torch.sigmoid(span_logits[j])]
#                                 # probs_per_span[key] = [span_logits[j]]
#                         # true_span_labels.extend(eval_span_labels[start_index:end_index].tolist())
#                         # pred_span_labels.extend(span_preds[start_index:end_index].tolist())

#                     true_nli_labels.append(eval_nli_labels[i].item())
#                     pred_nli_labels.append(nli_preds[i].item())

#                     start_index = end_index


#                 # true_span_labels.extend(eval_span_labels.tolist())
#                 # pred_span_labels.extend(span_preds.tolist())
                
#                 # true_nli_labels.extend(eval_nli_labels.tolist())
#                 # pred_nli_labels.extend(nli_preds.tolist())
            
#             # zip and print the true_span_labels and pred_span_labels together
#             # ic(list(zip(true_span_labels, pred_span_labels)))
#             # ic(list(zip(true_nli_labels, pred_nli_labels)))

#         pred_span_labels = []
#         true_span_labels = []

#         for key, probs in probs_per_span.items():
#             probs = torch.stack(probs).mean(dim=0)
#             # probs = torch.sigmoid(torch.stack(probs).mean(dim=0))
#             pred_span_labels.append(int(probs >= 0.50))
#             true_span_labels.append(true_labels_per_span[key])

#         pred_span_labels = torch.tensor(pred_span_labels, dtype=torch.long)
#         true_span_labels = torch.tensor(true_span_labels, dtype=torch.long)

#         ic(pred_span_labels)
#         ic(true_span_labels)

#         eval_nli_acc = accuracy_score(true_nli_labels, pred_nli_labels)
        
#         ic.enable()

#         ic(len(true_span_labels), len(pred_span_labels))

#         # print any span label if pred_span_label is 1
#         ic(sum(pred_span_labels), sum(true_span_labels))

#         # mAP = calculate_micro_average_precision(torch.tensor(true_span_labels), torch.tensor(pred_span_labels))    
#         # calculate micro average precision (mAP)
        
#         mAP = average_precision_score(torch.tensor(true_span_labels), torch.tensor(pred_span_labels), average='micro')    

#         # mAP = average_precision_score(torch.tensor(true_span_labels), torch.tensor(pred_span_labels))
#         precision_at_80_recall = get_micro_average_precision_at_recall(torch.tensor(true_span_labels), torch.tensor(pred_span_labels), 0.8)
#         f1_score_for_entailment = calculate_f1_score_for_class(torch.tensor(true_nli_labels), torch.tensor(pred_nli_labels), get_labels()['Entailment'])
#         f1_score_for_contradiction = calculate_f1_score_for_class(torch.tensor(true_nli_labels), torch.tensor(pred_nli_labels), get_labels()['Contradiction'])
        
#         return {
#             'mAP' : mAP,
#             'precision_at_80_recall' : precision_at_80_recall,
#             'nli_acc': eval_nli_acc,
#             'f1_score_for_entailment': f1_score_for_entailment,
#             'f1_score_for_contradiction': f1_score_for_contradiction
#         }