# NLI Validation

In [None]:
# Install the libraries this notebook imports (quiet mode), then print the
# GPU status so the runtime's accelerator can be checked before loading the model.
!pip install -q transformers torch
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive; the input and output TSV paths used
# later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from transformers import pipeline

# Build the NLI classifier used by is_not_contradiction(): the
# facebook/bart-large-mnli checkpoint wrapped in a text-classification pipeline.
# device=0 requests the first GPU (NOTE(review): this presumably fails on a
# CPU-only runtime — confirm a GPU is attached before running).
nli = pipeline("text-classification", model="facebook/bart-large-mnli", device=0)
print("Model loaded")

In [None]:
# Load original and expanded documents
def _read_tsv_pairs(path):
    """Yield (doc_id, text) for each well-formed line of a two-column TSV file.

    Each line is stripped of surrounding whitespace and split on the first
    tab only, so tabs inside the text column are preserved. Lines that do
    not contain a tab after stripping are skipped silently.
    """
    # Explicit UTF-8 so decoding does not depend on the runtime's locale default.
    with open(path, encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t', 1)
            if len(parts) == 2:
                yield parts[0], parts[1]


# doc_id -> original text; if an id repeats, the last occurrence wins.
originals = dict(_read_tsv_pairs('/content/drive/MyDrive/hqf_de/collection_100k.tsv'))

# (doc_id, expanded text) pairs, kept in file order.
expanded = list(_read_tsv_pairs('/content/drive/MyDrive/hqf_de/expanded_100k.tsv'))

print(f"Originals: {len(originals):,}, Expanded: {len(expanded):,}")

In [None]:
import re
from tqdm import tqdm

def get_expansions(original, expanded_text):
    """Return the sentences that were appended to ``original``.

    If ``expanded_text`` does not begin with ``original`` verbatim, nothing
    can be attributed to the expansion and an empty list is returned.
    Otherwise the remainder is split after sentence-ending punctuation
    (``.``, ``!``, ``?``) followed by whitespace, and only fragments longer
    than 15 characters (after trimming) are kept.
    """
    prefix_len = len(original)
    if expanded_text[:prefix_len] != original:
        return []

    tail = expanded_text[prefix_len:].strip()
    long_enough = []
    for fragment in re.split(r'(?<=[.!?])\s+', tail):
        fragment = fragment.strip()
        # Very short fragments are discarded.
        if len(fragment) > 15:
            long_enough.append(fragment)
    return long_enough

def is_not_contradiction(premise, hypothesis):
    """Return True if hypothesis doesn't contradict premise.

    Runs the module-level ``nli`` text-classification pipeline on the
    (premise, hypothesis) pair and checks the top predicted label.

    Args:
        premise: The original document text.
        hypothesis: A candidate added sentence to check against the premise.

    Returns:
        True when the top label is anything other than "contradiction"
        (compared case-insensitively), False otherwise.
    """
    # Pass the two texts as a real pair instead of one pre-joined string.
    # With a single string, truncation=True trims from the end by default, so
    # an over-long premise would push the hypothesis out of the model input.
    # As a pair, truncation trims the longer sequence and keeps the hypothesis.
    result = nli({"text": premise, "text_pair": hypothesis}, truncation=True)
    # The pipeline may return either a one-element list or a bare dict for a
    # single input, depending on the input form; accept both shapes.
    top = result[0] if isinstance(result, list) else result
    return top['label'].lower() != 'contradiction'

In [None]:
output_path = '/content/drive/MyDrive/hqf_de/validated_100k.tsv'

# Running totals of expansion sentences that passed / failed the NLI check.
kept, removed = 0, 0

# Explicit UTF-8 so the written file does not depend on the runtime's locale default.
with open(output_path, 'w', encoding='utf-8') as out:
    for doc_id, exp_text in tqdm(expanded):
        orig = originals.get(doc_id, "")
        if not orig:
            # No original text to validate against: pass the expanded text
            # through unchanged.
            out.write(f"{doc_id}\t{exp_text}\n")
            continue

        # NOTE: when exp_text does not start with orig, get_expansions()
        # returns [] and only the original text is written for this document.
        expansions = get_expansions(orig, exp_text)
        valid = [e for e in expansions if is_not_contradiction(orig, e)]

        kept += len(valid)
        removed += len(expansions) - len(valid)

        # Original text followed by the surviving sentences, space-separated.
        final = " ".join([orig, *valid])
        out.write(f"{doc_id}\t{final}\n")

print(f"\nKept: {kept:,}, Removed: {removed:,}")
total = kept + removed
# Only report a rate when at least one sentence was checked; this avoids both
# a division by zero and printing a stray blank line.
if total > 0:
    print(f"Keep rate: {kept / total * 100:.1f}%")