In [3]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging as log
log.basicConfig(level=log.DEBUG)

In [4]:
import sys
sys.path.append('../')
from baselines.utils import *
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

os.environ['WANDB_ENTITY'] = 'contract-nli-db'
os.environ['WANDB_PROJECT'] = 'contract-nli-metric'
os.environ['WANDB_LOG_MODEL'] = 'end'

In [5]:
from ipynb.fs.defs.contract_nli_bert_train import *

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmps5sbe_m2
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmps5sbe_m2/_remote_module_non_scriptable.py


In [6]:
import torch

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE

device(type='cuda')

In [7]:
# cfg['model_name'] = 'nlpaueb/legal-bert-base-uncased'
cfg['model_name'] = 'bert-base-uncased'
cfg['trained_model_dir'] = './scratch/shu7bh/contract_nli/trained_model/'
cfg['batch_size'] = 32
cfg

{'raw_data_dir': '../dataset/',
 'train_path': 'train.json',
 'test_path': 'test.json',
 'dev_path': 'dev.json',
 'model_name': 'bert-base-uncased',
 'max_length': 512,
 'models_save_dir': './scratch/shu7bh/contract_nli/models',
 'dataset_dir': './scratch/shu7bh/contract_nli/dataset',
 'results_dir': './scratch/shu7bh/contract_nli/results',
 'trained_model_dir': './scratch/shu7bh/contract_nli/trained_model/',
 'batch_size': 32}

In [8]:
# create dir if not exists
from pathlib import Path
Path(cfg["models_save_dir"]).mkdir(parents=True, exist_ok=True)
Path(cfg["dataset_dir"]).mkdir(parents=True, exist_ok=True)

In [9]:
tokenizer = AutoTokenizer.from_pretrained(cfg['model_name'])

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


In [10]:
from icecream import ic

In [11]:
dev_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['dev_path']))
test_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['test_path']))

hypothesis = get_hypothesis(dev_data)

dev_data = dev_data['documents']
test_data = test_data['documents']

dev_data = dev_data[:2]
test_data = test_data[:2]

ic.disable()

ic(len(dev_data), len(test_data))
dev_dataset = NLIDataset(dev_data, tokenizer, hypothesis, [1000, 1100, 1200, 1500], 50)
test_dataset = NLIDataset(test_data, tokenizer, hypothesis, [1000, 1100, 1200, 1500], 50)

ic.enable()

del dev_data
del test_data
del hypothesis

  "metadata": {},


In [12]:
print(len(dev_dataset))
print(len(test_dataset))


3026
2346


In [13]:
# nli_weights, span_weight = get_class_weights(train_dataset)

In [14]:
# nli_weights, span_weight

In [15]:
from sklearn.metrics import precision_recall_curve
import numpy as np
def get_micro_average_precision_at_recall(y_true, y_pred, recall_level):
    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    return np.interp(recall_level, recall[::-1], precision[::-1])

In [16]:
# Import numpy and sklearn.metrics
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import precision_score
def calculate_micro_average_precision(y_true, y_pred):
    """Calculate the micro average precision score.

    Args:
        y_true (np.array): True labels.
        y_pred (np.array): Predicted labels.

    Returns:
        float: Micro average precision score.
    """
    # Get the number of classes
    num_classes = len(np.unique(y_true))
    
    if num_classes == 0:
        return 0.0

    # initialize the average precision score
    average_precision = 0.0

    # loop over all classes
    for class_idx in range(num_classes):
        # get the indices for this class
        y_true_indices = np.where(y_true == class_idx)
        # calculate the average precision score for this class
        average_precision += ic(precision_score(
            y_true[y_true_indices], y_pred[y_true_indices], average="micro"
        ))

    # return the average over all classes
    return average_precision / num_classes

In [17]:
from sklearn.metrics import f1_score
def calculate_f1_score_for_class(y_true, y_pred, class_idx):
    """Calculate the F1 score for a given class.

    Args:
        y_true (np.array): True labels.
        y_pred (np.array): Predicted labels.
        class_idx (int): Index of the class.

    Returns:
        float: F1 score for the given class.
    """
    # get the indices for the given class
    y_true_indices = np.where(y_true == class_idx)
    # calculate the F1 score for the given class
    return f1_score(
        y_true[y_true_indices], y_pred[y_true_indices], average="macro"
    )

In [18]:
def precision_at_recall(y_true, y_scores, recall_threshold):
    precision, recall, threshold = precision_recall_curve(y_true, y_scores)
    idx = (np.abs(recall - recall_threshold)).argmin()  # Find nearest recall value to threshold
    ic(threshold[idx])
    return precision[idx]

In [19]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    auto_find_batch_size=True,
    output_dir=cfg['results_dir'],   # output directory
    num_train_epochs=10,            # total number of training epochs
    gradient_accumulation_steps=4,   # number of updates steps to accumulate before performing a backward/update pass
    logging_strategy='epoch',
    # eval_steps=0.25,
    # save_steps=0.25,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,
    load_best_model_at_end=True,
    # fp16=True,
    label_names=['nli_label', 'span_labels', 'data_for_metrics'],
    report_to='none',
)

In [20]:
cfg['trained_model_dir']

'./scratch/shu7bh/contract_nli/trained_model/'

In [21]:
from transformers import EarlyStoppingCallback

In [22]:
import wandb
api = wandb.Api()
artifact = api.artifact('contract-nli-db/contract-nli/model-ax48o5vj:v0', type='model')
artifact_dir = artifact.download(cfg['trained_model_dir'])

model = ContractNLI.from_pretrained(artifact_dir).to(DEVICE)

DEBUG:git.cmd:Popen(['git', 'version'], cwd=/home2/shu7bh/Courses/ANLP/Project/Contract-NLI/source_code, universal_newlines=False, shell=None, istream=None)
DEBUG:git.cmd:Popen(['git', 'version'], cwd=/home2/shu7bh/Courses/ANLP/Project/Contract-NLI/source_code, universal_newlines=False, shell=None, istream=None)
DEBUG:wandb.docker.auth:Trying paths: ['/home2/shu7bh/.docker/config.json', '/home2/shu7bh/.dockercfg']
DEBUG:wandb.docker.auth:No config file found
DEBUG:sentry_sdk.errors:[Tracing] Create new propagation context: {'trace_id': 'a363d3dc253643eeb5a5f0332cb648ea', 'span_id': 'a61d9abe991486c9', 'parent_span_id': None, 'dynamic_sampling_context': None}


DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): api.wandb.ai:443
DEBUG:urllib3.connectionpool:https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 253
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): api.wandb.ai:443
DEBUG:urllib3.connectionpool:https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 362
DEBUG:urllib3.connectionpool:https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 521
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): storage.googleapis.com:443
DEBUG:urllib3.connectionpool:https://storage.googleapis.com:443 "GET /wandb-production.appspot.com/contract-nli-db/contract-nli/ax48o5vj/artifact/628170767/wandb_manifest.json?Expires=1699125232&GoogleAccessId=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com&Signature=wX4jPDQMeWVQdQjlxAF7x%2F9K5zGvAFwk2kDWIuJ1asByPhUWnuZrXvuXnE3f7UupNq%2F9FJWuYZletEA4358GeYO%2BJaBSttfiy5L1x6zXoEo95HApZ1isM9E%2BFWRBFH3G2AfUWu0aaDNf5VP05qJrZA6wLEtyKYpzF%2BTwBbIVrHiGFwoWWppYENps

In [23]:
artifact_dir

'./scratch/shu7bh/contract_nli/trained_model/'

In [24]:
from transformers import Trainer
from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score
from tqdm import tqdm
import numpy as np

class ContractNLIMetricTrainer(ContractNLITrainer):
    def __init__(self, *args, data_collator=None, **kwargs):
        super().__init__(*args, data_collator=data_collator, **kwargs)

    def evaluate(self, eval_dataset=None, ignore_keys=None):
        self.model.eval()
        self.dataloader = ic(self.get_eval_dataloader(eval_dataset))

        eval_nli_labels = []
        eval_nli_preds = []
        true_labels_per_span = {}
        probs_per_span = {}

        nli_metrics = {}

        for inputs in tqdm(self.dataloader):
            inputs = self._prepare_inputs(inputs)
            span_labels = inputs.pop('span_labels')
            nli_labels = inputs.pop('nli_label')
            data_for_metrics = inputs.pop('data_for_metrics')

            span_indices_to_consider = torch.where(span_labels != -1)[0]

            with torch.no_grad():
                outputs = self.model(**inputs)
                span_logits, nli_logits = outputs[0], outputs[1]

                span_labels = span_labels.float()
                span_logits = span_logits.float()
                
                span_labels = span_labels.view(-1)
                span_logits = span_logits.view(-1)

                # start_index = 0
                
                indices_considered = 0 # total number of span indices considered

                # find the corresponding span index in data_for_metrics['span_ids'] considering -1 to be padding index
                # ic(span_index)
                for i, span_index_row in enumerate(data_for_metrics['span_ids']):
                    current_index = 0 # current row's first -1 index
                    ic(span_index_row)
                    first_minus_one_index = torch.where(span_index_row == -1)[0]
                    ic(first_minus_one_index)
                    if len(first_minus_one_index) == 0:
                        first_minus_one_index = len(span_index_row)
                    else:
                        first_minus_one_index = first_minus_one_index[0].item()

                    key = str(data_for_metrics['doc_id'][i].item())+ '-' + str(data_for_metrics['hypothesis_id'][i].item())

                    # mask span_labels and span_logits for the current row
                    mask = span_labels[indices_considered:indices_considered+first_minus_one_index] != -1
                    span_logits_masked = span_logits[indices_considered:indices_considered+first_minus_one_index][mask]

                    spans_contribution = torch.sum(torch.sigmoid(span_logits_masked)) / (len(span_logits_masked)) 

                    if key in nli_metrics:
                        nli_metrics[key]['spans_contribution'].append(spans_contribution)
                        nli_metrics[key]['nli_logits'].append(nli_logits[i])
                    else:
                        nli_metrics[key] = {}
                        nli_metrics[key]['true_nli_labels'] = nli_labels[i]
                        nli_metrics[key]['spans_contribution'] = [spans_contribution]
                        nli_metrics[key]['nli_logits'] = [nli_logits[i]]
                    
                    current_index = first_minus_one_index
                    indices_considered += current_index
                    
                    ic(indices_considered)
                    ic(current_index)
                    cnt = 0 # count to keep track of the number of span indices added in dictionary
                    
                    for span_index in span_indices_to_consider:

                        if span_index < indices_considered:
                            cnt += 1
                            value_index = span_index - (indices_considered - current_index)
                            doc_id = data_for_metrics['doc_id'][i]
                            hypothesis_id = data_for_metrics['hypothesis_id'][i]
                            span_id = data_for_metrics['span_ids'][i][value_index]
                            key = str(doc_id)+ '-' + str(hypothesis_id)+ '-' + str(span_id)
                            true_labels_per_span[key] = span_labels[span_index]
                            if key in probs_per_span:
                                probs_per_span[key].append(torch.sigmoid(span_logits[span_index]))
                                # probs_per_span[key].append(span_logits[value_index])
                            else:
                                probs_per_span[key] = [torch.sigmoid(span_logits[span_index])]
                                # probs_per_span[key] = [span_logits[value_index]]
                        else: 
                            break 
                    
                    span_indices_to_consider = span_indices_to_consider[cnt:]

                # eval_span_preds = torch.tensor(eval_span_preds.squeeze(1), dtype=torch.long)

                nli_preds = torch.argmax(torch.softmax(nli_logits, dim=1), dim=1)
                eval_nli_labels.extend(nli_labels.cpu().numpy())
                eval_nli_preds.extend(nli_preds.cpu().numpy())

        eval_span_labels = []
        eval_span_preds = []

        for key in true_labels_per_span:
            eval_span_labels.append(true_labels_per_span[key].item())
            eval_span_preds.append(torch.mean(torch.stack(probs_per_span[key])).item())

        ##### For NLI probablities #####

        # for key in nli_metrics:
        #     nli_metrics[key]['nli_logits'] = torch.stack(nli_metrics[key]['nli_logits'])
        #     nli_metrics[key]['spans_contribution'] = torch.stack(nli_metrics[key]['spans_contribution'])

        #     span_sum = torch.sum(nli_metrics[key]['spans_contribution'])
        #     spans_contribution = nli_metrics[key]['spans_contribution'].transpose(0, -1) @ nli_metrics[key]['nli_logits']

        #     eval_nli_preds.append(torch.argmax(torch.softmax(spans_contribution/span_sum, dim=0)).item())
        #     eval_nli_labels.append(nli_metrics[key]['true_nli_labels'].item())

        ##### END #####

        eval_nli_acc = accuracy_score(eval_nli_labels, eval_nli_preds)

        ic.enable()
        ic(list(zip(eval_span_labels, eval_span_preds)))
        ic(len(eval_span_labels), len(eval_span_preds))
        ic(sum(eval_span_labels), sum(eval_span_preds))

        # find threshold for 80% recall
        # precision, recall, thresholds = precision_recall_curve(eval_span_labels, eval_span_preds)


        mAP = (average_precision_score(eval_span_labels, eval_span_preds, pos_label=0) + average_precision_score(eval_span_labels, eval_span_preds, pos_label=1))/2

        # mAP = average_precision_score(torch.tensor(true_span_labels), torch.tensor(pred_span_labels))
        precision_at_80_recall = precision_at_recall(torch.tensor(eval_span_labels), torch.tensor(eval_span_preds), 0.8)
        f1_score_for_entailment = calculate_f1_score_for_class(torch.tensor(eval_nli_labels), torch.tensor(eval_nli_preds), get_labels()['Entailment'])
        f1_score_for_contradiction = calculate_f1_score_for_class(torch.tensor(eval_nli_labels), torch.tensor(eval_nli_preds), get_labels()['Contradiction'])
        
        return {
            'mAP' : mAP,
            'precision_at_80_recall' : precision_at_80_recall,
            'nli_acc': eval_nli_acc,
            'f1_score_for_entailment': f1_score_for_entailment,
            'f1_score_for_contradiction': f1_score_for_contradiction
        }

In [25]:
trainer = ContractNLIMetricTrainer(
    model=model,                          # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    # train_dataset=train_dataset,         # training dataset
    eval_dataset=dev_dataset,            # evaluation dataset
    data_collator=ContractNLIMetricTrainer.collate_fn,
)

In [26]:
ic.disable()
# ic.enable()
results = trainer.evaluate()
results

100%|██████████| 379/379 [01:12<00:00,  5.20it/s]
ic| list(zip(eval_span_labels, eval_span_preds)): [(1.0, 0.4390585720539093),
                                                   (1.0, 0.5365281105041504),
                                                   (0.0, 0.02797091193497181),
                                                   (0.0, 0.03854534402489662),
                                                   (0.0, 0.038280781358480453),
                                                   (0.0, 0.03133000433444977),
                                                   (0.0, 0.02188614197075367),
                                                   (0.0, 0.059788208454847336),
                                                   (0.0, 0.11410053074359894),
                                                   (0.0, 0.09806304425001144),
                                                   (0.0, 0.5150961875915527),
                                                   (0.0, 0.48969197273254395),
   

{'mAP': 0.6862852110255822,
 'precision_at_80_recall': 0.31724137931034485,
 'nli_acc': 0.3198942498347654,
 'f1_score_for_entailment': 0.1134901134901135,
 'f1_score_for_contradiction': 0.15954415954415954}