# Setup

In [1]:
import os
import sys
from tqdm import tqdm
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List,Tuple, Optional
import json
#
# Put the parent of the notebook's working directory (the project root, per the
# printed output) first on sys.path so the project-local `config` and
# `definitions` modules imported below can be found.
module_path = os.path.abspath(os.path.join('../'))
if module_path not in sys.path:
    sys.path.insert(0,module_path)
print('module_path : ', module_path)
# Order matters: these two imports only resolve after the sys.path change above.
from config import SciFactT5Config
import definitions
# Also expose the directory that contains the project root, so the package-style
# `T5ParEvo.` imports further down resolve. NOTE(review): presumably the sibling
# `multivers` package is found through this path as well - confirm.
sys.path.append(os.path.dirname(definitions.PROJECT_VARS.ROOT_DIR))
print('definitions.PROJECT_VARS.ROOT_DIR : ', definitions.PROJECT_VARS.ROOT_DIR)
print('os.path.dirname(definitions.PROJECT_VARS.ROOT_DIR) : ',os.path.dirname(definitions.PROJECT_VARS.ROOT_DIR))
#

from T5ParEvo.src.data.data import Claim, Label, ClaimPredictions,GoldDataset
from T5ParEvo.src.linguistic.ner_abbr import NEREntity, Abbreviation
from multivers import util
from multivers.data_r import ClaimDataLoaderGenerator, get_dataloader, DataLoaderGenerator
from multivers.model_r import MultiVerSModel

from T5ParEvo.target_system.multivers.multivers_interface import PredictionParams, ModelPredictor

module_path :  /home/qudratealahyratu/research/nlp/fact_checking/my_work/T5ParEvo
root dir :  /home/qudratealahyratu/research/nlp/fact_checking/my_work/T5ParEvo
definitions.PROJECT_VARS.ROOT_DIR :  /home/qudratealahyratu/research/nlp/fact_checking/my_work/T5ParEvo
os.path.dirname(definitions.PROJECT_VARS.ROOT_DIR) :  /home/qudratealahyratu/research/nlp/fact_checking/my_work
/home/qudratealahyratu/research/nlp/fact_checking/my_work/T5ParEvo


  "The `@auto_move_data` decorator is deprecated in v1.3 and will be removed in v1.5."


In [2]:
# Training split of the target dataset (corpus + training claims). This split is
# meant for training use only; `cfg` (project config) is reused by later cells.
cfg= SciFactT5Config()
ds_train = GoldDataset(cfg.target_dataset.loc_target_dataset_corpus,
                    cfg.target_dataset.loc_target_dataset_train)
# NOTE(review): 39 is a hard-coded claim selector (claim id or position? - confirm
# against GoldDataset.get_claim); claim_train is not used later in this notebook as shown.
claim_train = ds_train.get_claim(39)

In [3]:

# MultiVerS inference settings. The checkpoint path can be overridden through the
# MULTIVERS_CHECKPOINT environment variable instead of editing this cell; when the
# variable is unset, the original local checkpoint path is used.
MULTIVERS_CHECKPOINT = os.environ.get(
    "MULTIVERS_CHECKPOINT",
    "/home/qudratealahyratu/research/nlp/fact_checking/my_work/multivers/checkpoints/scifact.ckpt",
)
params = PredictionParams(
    checkpoint_path=MULTIVERS_CHECKPOINT,
    output_file=None,  # NOTE(review): presumably None means no prediction JSONL is written - confirm
    batch_size=5,
    device=0,  # presumably a CUDA device index - confirm in multivers_interface
    num_workers=4,
    no_nei=False,
    force_rationale=False,
    debug=False,
)
# Evidence corpus file, taken from the target-dataset config.
corpus_file = cfg.target_dataset.loc_target_dataset_corpus

# Evaluation claims (test split); MultiVerS predictions below are made for these.
gold_claims = Claim.load_claims_from_file(cfg.target_dataset.loc_target_dataset_test)


In [4]:
# Reduce the test claims to unique ones (presumably deduplicated by claim id -
# see Claim.get_unique_claims), then run MultiVerS once over them.
unique_gold_claims = Claim.get_unique_claims(gold_claims)
# Predict for unique claims: build the dataloader over claims + corpus, then run the model.
dataloader_generator = DataLoaderGenerator(params, unique_gold_claims, corpus_file)
dataloader = dataloader_generator.get_dataloader_by_claims()
predictor = ModelPredictor(params, dataloader)
# Iterable of per-claim predictions; each item is indexed by 'id' in the next cell.
prediction_formatted = predictor.run()

Some weights of the model checkpoint at allenai/longformer-large-4096 were not used when initializing LongformerModel: ['lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias']
- This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  stream(template_mgs % msg_args)
100%|██████████| 594/594 [05:03<00:00,  1.96it/s]


In [5]:
# Wrap every raw MultiVerS prediction in a ClaimPredictions object, attaching the
# gold claim that has the same id as the prediction.
claim_org_predictions: List[ClaimPredictions] = [
    ClaimPredictions.from_formatted_prediction(
        raw_prediction,
        gold_claim=Claim.get_claim_by_id(gold_claims, raw_prediction['id']),
    )
    for raw_prediction in tqdm(prediction_formatted, desc="Formatting predictions")
]

Formatting predictions: 100%|██████████| 297/297 [00:00<00:00, 50060.61it/s]


In [6]:
# Inspect the gold claim attached to the first formatted prediction.
claim_org_predictions[0].gold

Example 7: 10-20% of people with severe mental disorder receive no treatment in low and middle income countries.

In [7]:
import copy

# Hand-written paraphrase of the first claim ("%" -> "percent", "severe" -> "high").
# Work on a deep copy so the original prediction object is not mutated.
par_claim_1 = copy.deepcopy(claim_org_predictions[0])
par_claim_1.gold.claim = '10-20 percent of people with high mental disorder receive no treatment in low and middle income countries.'



## Entailment

In [8]:
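# NLI helpers defined inline: a label enum, the torch.hub model config, and a checker
# that runs the NLI model on a claim pair in both directions.
# NOTE(review): these may duplicate T5ParEvo.src.linguistic.entailment (NliLabels,
# EntailmentModel); consider importing from there to avoid the two copies drifting.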
from T5ParEvo.src.data.data import Claim, ParaphrasedClaim
import torch
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
from enum import Enum


class NliLabels(Enum):
    """Three-way NLI label set.

    Member values match the class indices used by EntailmentChecker's default
    ``label_mapping`` (0 = contradiction, 1 = neutral, 2 = entailment).

    Deliberately NOT decorated with ``@dataclass``: an Enum declares no dataclass
    fields, so a generated ``__eq__`` compares two empty field tuples and makes every
    member equal to every other (``CONTRADICTION == ENTAILMENT``), and the generated
    ``__hash__`` gives all members the same hash. Enum members are already immutable
    singletons that compare by identity.
    """

    CONTRADICTION = 0
    NEUTRAL = 1
    ENTAILMENT = 2

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    @property
    def description(self):
        """Short human-readable explanation of this label."""
        descriptions = {
            "CONTRADICTION": "The sentences have opposing meanings.",
            "NEUTRAL": "The sentences are not related in any specific way.",
            "ENTAILMENT": "The sentences have the same meaning or one implies the other."
        }
        return descriptions[self.name]

    @classmethod
    def from_string(cls, label_name):
        """Look up a member by name, case-insensitively (e.g. ``"entailment"``).

        Raises:
            KeyError: if ``label_name`` (upper-cased) is not a member name.
        """
        return cls[label_name.upper()]
    
@dataclass
class EntailmentModel:
    """torch.hub coordinates of the NLI model used for entailment checks.

    EntailmentChecker loads it as ``torch.hub.load(model_repo, model_name)``.
    """

    # Hub repository that provides the entry point.
    model_repo: str = field(default='pytorch/fairseq')
    # Entry-point / model name inside ``model_repo``.
    model_name: str = field(default='roberta.large.mnli')


@dataclass
class EntailmentChecker:
    """Bidirectional NLI checker backed by a ``torch.hub`` model.

    ``__post_init__`` loads ``model_config.model_repo`` / ``model_config.model_name``
    with ``torch.hub.load``, moves the model to ``device`` and puts it in eval mode.
    Every check classifies the sentence pair in both orders
    (original -> paraphrase and paraphrase -> original).

    Attributes:
        model: The loaded hub model. Not an ``__init__`` argument; set in
            ``__post_init__``.
        model_config: torch.hub repo and model name to load.
        label_mapping: Maps the argmax class index of the ``'mnli'`` head to an
            ``NliLabels`` member.
        device: Device the model is moved to. The default is chosen once, when the
            class body executes: ``'cuda'`` if available, else ``'cpu'``.
    """

    model: torch.nn.Module = field(init=False)
    # default_factory gives every checker its own EntailmentModel. Passing one
    # instance as ``default=`` shares mutable state across all checkers, and on
    # Python >= 3.11 dataclasses rejects such an unhashable default with ValueError.
    model_config: EntailmentModel = field(default_factory=EntailmentModel)
    label_mapping: Dict[int, NliLabels] = field(
        default_factory=lambda: {0: NliLabels.CONTRADICTION, 1: NliLabels.NEUTRAL, 2: NliLabels.ENTAILMENT}
    )
    device: str = field(default='cuda' if torch.cuda.is_available() else 'cpu')

    def __post_init__(self):
        self.model = self._load_model()
        self.model.to(self.device)
        self.model.eval()  # inference only

    def _load_model(self) -> torch.nn.Module:
        """Load the configured model through ``torch.hub.load``."""
        return torch.hub.load(self.model_config.model_repo, self.model_config.model_name)

    def check_entailment_by_paraphrased_claim(self, paraphrased_claim: ParaphrasedClaim) -> None:
        """Store a bidirectionally agreed NLI label on ``paraphrased_claim`` (in place).

        Classifies (original, paraphrase) and (paraphrase, original). If both
        directions give the same label, ``paraphrased_claim.nli_label`` is set to that
        label; otherwise it is set to ``None``.
        """
        original_text = paraphrased_claim.original_claim.claim
        paraphrase_text = paraphrased_claim.paraphrased_claim.claim
        _, label_org_gen = self._get_labels(original_text, paraphrase_text)
        _, label_gen_org = self._get_labels(paraphrase_text, original_text)

        # Enum members are singletons, so compare by identity. Unlike ``==`` this
        # stays correct even if NliLabels' ``__eq__`` is replaced (a @dataclass
        # decorator on the enum, for instance, makes every member compare equal).
        paraphrased_claim.nli_label = label_org_gen if label_org_gen is label_gen_org else None

    def check_entailment_by_claims(self, original_claim: Claim, paraphrased_claim: Claim) -> Dict[str, Union[NliLabels, int]]:
        """Run the NLI model on a claim pair in both directions.

        Args:
            original_claim (Claim): Claim whose ``.claim`` text comes first in the
                forward (org -> gen) pass.
            paraphrased_claim (Claim): Claim whose ``.claim`` text comes first in the
                backward (gen -> org) pass.

        Returns:
            Dict[str, Union[NliLabels, int]]: ``'nli_label_org_gen'`` / ``'nli_label_gen_org'``
            hold the ``NliLabels`` for original->paraphrase and paraphrase->original;
            ``'nli_val_org_gen'`` / ``'nli_val_gen_org'`` hold the matching integer
            class indices.
        """
        val_org_gen, label_org_gen = self._get_labels(original_claim.claim, paraphrased_claim.claim)
        val_gen_org, label_gen_org = self._get_labels(paraphrased_claim.claim, original_claim.claim)

        return {
            'nli_label_org_gen': label_org_gen,
            'nli_label_gen_org': label_gen_org,
            'nli_val_org_gen': val_org_gen,
            'nli_val_gen_org': val_gen_org,
        }

    def print_label_mapping(self):
        """Print each ``class index: label name`` pair of ``label_mapping``."""
        for key, value in self.label_mapping.items():
            print(f'{key}: {value.name}')

    def _get_labels(self, sentence1: str, sentence2: str) -> List[Union[int, NliLabels]]:
        """Classify the ordered pair (``sentence1``, ``sentence2``) with the ``'mnli'`` head.

        For MNLI the first segment is presumably the premise and the second the
        hypothesis.

        Returns:
            ``[class_index, label]`` where ``class_index`` is the argmax over the head's
            output and ``label`` is ``label_mapping[class_index]``.
        """
        tokens = self.model.encode(sentence1, sentence2)
        # Pure inference: skip autograd bookkeeping (saves memory and time; outputs unchanged).
        with torch.no_grad():
            head_output = self.model.predict('mnli', tokens)
        class_index = head_output.argmax(dim=1).item()
        return [class_index, self.label_mapping[class_index]]

# Load the NLI model once (torch.hub), then classify the first original claim
# against its hand-written paraphrase in both directions.
checker = EntailmentChecker()
entailment_labels = checker.check_entailment_by_claims(
    original_claim=claim_org_predictions[0].gold,
    paraphrased_claim=par_claim_1.gold,
)


Using cache found in /home/qudratealahyratu/.cache/torch/hub/pytorch_fairseq_main
2023-06-24 22:52:16 | INFO | fairseq.file_utils | loading archive file http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz from cache at /home/qudratealahyratu/.cache/torch/pytorch_fairseq/7685ba8546f9a5ce1a00c7a6d7d44f7e748d22681172f0f391c3d48f487c801c.74e37d47306b3cc51c5f8d335022a392c29f1906c8cd9e9cd3446d7422cf55d8
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  caller_stack_depth=caller_stack_depth + 1,
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
  See {url} for more information"""
'config' is validated against ConfigStore schema with the same name.
This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
  state = load_

ANTLR runtime and generated code versions disagree: 4.9.3!=4.8
ANTLR runtime and generated code versions disagree: 4.9.3!=4.8


The strict flag in the compose API is deprecated.
See https://hydra.cc/docs/upgrades/0.11_to_1.0/strict_mode_flag_deprecated for more info.

  """
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  caller_stack_depth=caller_stack_depth + 1,
'config' is validated against ConfigStore schema with the same name.
This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
  **kwargs,
2023-06-24 22:52:20 | INFO | fairseq.tasks.masked_lm | dictionary: 50264 types
2023-06-24 22:52:32 | INFO | fairseq.models.roberta.model | {'_name': None, 'common': {'_name': None, 'no_progress_bar': False, 'log_interval': 100, 'log_format': 'json', 'log_file': None, 'tensorboard_logdir': None, 'wandb_project': None, 'azureml_logging': False, 'seed': 8, 'cpu': False, 'tpu': False, 'bf16': False, 'memory_eff

In [9]:
# Bidirectional NLI result for the hand-written paraphrase (labels + class indices).
entailment_labels

{'nli_label_org_gen': ENTAILMENT,
 'nli_label_gen_org': ENTAILMENT,
 'nli_val_org_gen': 2,
 'nli_val_gen_org': 2}