In [None]:
import json
import random
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses, InputExample, evaluation

class EvidenceClaimRetriever:
    """Fine-tune a SentenceTransformer bi-encoder to retrieve evidence for claims.

    Typical workflow: ``load_data`` -> ``load_dev_data`` -> ``train`` ->
    ``index_evidence`` -> ``retrieve`` / ``evaluate``.
    """

    def __init__(self, model_name='all-mpnet-base-v2', batch_size=16, num_epochs=3):
        """Load the encoder and store the training hyper-parameters.

        Args:
            model_name: Name or path passed to ``SentenceTransformer``.
            batch_size: Batch size of the training ``DataLoader``.
            num_epochs: Number of epochs passed to ``model.fit``.
        """
        self.model = SentenceTransformer(model_name)
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        # evidence id -> evidence text
        self.evidence_map = {}
        # InputExample pairs built by load_data()
        self.train_examples = []
        # Dev-set structures consumed by InformationRetrievalEvaluator
        self.dev_queries = {}
        self.relevant_docs = {}
        # Retrieval index built by index_evidence()
        self.evidence_ids = []
        self.evidence_embeddings = None
        # True when the index was built with weights that have since changed
        self._index_stale = False
        # Cached list of evidence ids used when sampling negatives
        self._negative_pool = None
        # Pair label given to a claim's listed evidence, by claim label
        self.label_mapping = {
            'SUPPORTS': 1,
            'REFUTES': 1,
            'DISPUTED': 1,
            'NOT_ENOUGH_INFO': 0
        }

    def load_data(self, claims_path, evidence_path):
        """Load the evidence corpus and build training examples from the claims.

        Args:
            claims_path: JSON file mapping claim id -> dict with the keys
                ``claim_text``, ``evidences`` and ``claim_label``.
            evidence_path: JSON file mapping evidence id -> evidence text.
        """
        # JSON is UTF-8; do not depend on the platform default encoding.
        with open(evidence_path, encoding='utf-8') as f:
            self.evidence_map = json.load(f)

        with open(claims_path, encoding='utf-8') as f:
            train_claims = json.load(f)

        # The corpus changed: drop everything derived from the previous one.
        self._negative_pool = None
        self._index_stale = True

        self.train_examples = []
        for claim_data in train_claims.values():
            self._process_claim(claim_data)

    def _process_claim(self, claim_data):
        """Append the positive and sampled negative examples for one claim."""
        claim_text = claim_data['claim_text']
        evidence_ids = claim_data['evidences']
        claim_label = claim_data['claim_label']
        pair_label = self.label_mapping[claim_label]

        # Listed evidence that is missing from the corpus is skipped.
        for eid in evidence_ids:
            if eid in self.evidence_map:
                self.train_examples.append(InputExample(
                    texts=[claim_text, self.evidence_map[eid]],
                    label=pair_label
                ))

        self._add_negative_examples(claim_text, evidence_ids, claim_label)

    def _add_negative_examples(self, claim_text, positive_ids, label):
        """Append randomly sampled non-evidence passages as label-0 examples.

        ``NOT_ENOUGH_INFO`` claims get up to 2 negatives; every other claim
        gets up to 3 negatives per listed evidence id.
        """
        if label == 'NOT_ENOUGH_INFO':
            wanted = 2
        else:
            wanted = 3 * len(positive_ids)

        for eid in self._sample_negative_ids(positive_ids, wanted):
            self.train_examples.append(InputExample(
                texts=[claim_text, self.evidence_map[eid]],
                label=0
            ))

    def _sample_negative_ids(self, positive_ids, wanted):
        """Return up to ``wanted`` distinct evidence ids not in ``positive_ids``.

        Samples from a cached id list instead of rebuilding a set over the
        whole corpus for every claim.
        """
        pool = getattr(self, '_negative_pool', None)
        # Cheap staleness check in case evidence_map was replaced directly.
        if pool is None or len(pool) != len(self.evidence_map):
            pool = list(self.evidence_map)
            self._negative_pool = pool

        excluded = set(positive_ids)
        # Drawing `wanted + len(excluded)` distinct ids guarantees at least
        # `wanted` of them are not excluded (or the whole corpus was drawn).
        draw = min(len(pool), wanted + len(excluded))
        if draw <= 0:
            return []
        picked = [eid for eid in random.sample(pool, draw) if eid not in excluded]
        return picked[:wanted]

    def load_dev_data(self, dev_claims_path):
        """Load a dev claim set in the form InformationRetrievalEvaluator expects.

        Args:
            dev_claims_path: JSON file mapping claim id -> dict with the keys
                ``claim_text`` and ``evidences``.
        """
        with open(dev_claims_path, encoding='utf-8') as f:
            dev_claims = json.load(f)

        self.dev_queries = {}
        self.relevant_docs = {}
        for claim_id, claim_data in dev_claims.items():
            self.dev_queries[claim_id] = claim_data['claim_text']
            self.relevant_docs[claim_id] = set(claim_data['evidences'])

    def train(self, output_dir='./trained_model'):
        """Fine-tune the encoder with MultipleNegativesRankingLoss.

        Only examples labelled 1 are used. The loss ignores labels and treats
        every (claim, passage) pair as a positive, using the other passages in
        the batch as negatives, so label-0 pairs must not be fed to it.
        With the default ``label_mapping`` this also leaves out the evidence
        of ``NOT_ENOUGH_INFO`` claims.

        Args:
            output_dir: Directory passed to ``model.fit`` as ``output_path``.

        Raises:
            ValueError: If there are no label-1 training examples.
        """
        positive_pairs = [
            example for example in (getattr(self, 'train_examples', None) or [])
            if example.label == 1
        ]
        if not positive_pairs:
            raise ValueError(
                "No positive training pairs available; call load_data() first."
            )

        # Validate on the dev set when one was loaded; otherwise train without.
        dev_evaluator = None
        if getattr(self, 'dev_queries', None):
            dev_evaluator = evaluation.InformationRetrievalEvaluator(
                queries=self.dev_queries,
                corpus=self.evidence_map,
                relevant_docs=self.relevant_docs,
                show_progress_bar=True
            )

        train_dataloader = DataLoader(
            positive_pairs,
            shuffle=True,
            batch_size=self.batch_size
        )
        train_loss = losses.MultipleNegativesRankingLoss(self.model)

        # Warm up over 10% of the optimizer steps, not 10% of the examples.
        total_steps = len(train_dataloader) * self.num_epochs

        self.model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            evaluator=dev_evaluator,
            epochs=self.num_epochs,
            evaluation_steps=100,
            warmup_steps=int(total_steps * 0.1),
            output_path=output_dir,
            show_progress_bar=True
        )

        # Embeddings computed before training no longer match the weights.
        self._index_stale = True

    def evaluate(self, claims_path, top_k=5):
        """Print and return mean Recall@k over the claims that list evidence.

        Args:
            claims_path: JSON file mapping claim id -> dict with the keys
                ``claim_text`` and ``evidences``.
            top_k: Number of passages retrieved per claim.

        Returns:
            Mean recall over claims with at least one listed evidence id, or
            0.0 when there is no such claim.
        """
        with open(claims_path, encoding='utf-8') as f:
            claims = json.load(f)

        recall_sum = 0.0
        scored = 0
        for claim_data in claims.values():
            relevant = set(claim_data.get('evidences') or ())
            if not relevant:
                # Recall is undefined without gold evidence; skip the claim.
                continue
            results = self.retrieve(claim_data['claim_text'], top_k)
            retrieved = {eid for eid, _ in results}
            recall_sum += len(relevant & retrieved) / len(relevant)
            scored += 1

        mean_recall = recall_sum / scored if scored else 0.0
        print(f"Recall@{top_k}: {mean_recall:.4f}")
        return mean_recall

    def retrieve(self, claim_text, top_k=5):
        """Return the ``top_k`` most similar evidence passages for a claim.

        The evidence index is built on first use and rebuilt when it was
        marked stale by ``load_data`` or ``train``.

        Args:
            claim_text: The claim to search evidence for.
            top_k: Maximum number of results; capped at the corpus size.

        Returns:
            List of ``(evidence_id, cosine_similarity)`` tuples, best first.
        """
        if (getattr(self, 'evidence_embeddings', None) is None
                or getattr(self, '_index_stale', False)):
            self.index_evidence()
        if not self.evidence_ids or top_k <= 0:
            return []

        # as_tensor avoids copying the whole embedding matrix on every query.
        claim_vec = torch.as_tensor(self.model.encode(claim_text)).reshape(1, -1)
        corpus = torch.as_tensor(self.evidence_embeddings)
        scores = torch.nn.functional.cosine_similarity(claim_vec, corpus, dim=1)

        # topk raises when k exceeds the number of scores.
        best = torch.topk(scores, k=min(top_k, scores.shape[0]))
        return [
            (self.evidence_ids[i], score)
            for i, score in zip(best.indices.tolist(), best.values.tolist())
        ]

    def index_evidence(self):
        """Encode every evidence passage with the current model weights."""
        self.evidence_ids = list(self.evidence_map)
        self.evidence_texts = [self.evidence_map[eid] for eid in self.evidence_ids]
        self.evidence_embeddings = self.model.encode(
            self.evidence_texts,
            show_progress_bar=True,
            batch_size=32
        )
        self._index_stale = False


    

In [None]:
# Build the retriever on the compact MiniLM encoder: batches of 32, 5 epochs.
trainer = EvidenceClaimRetriever(model_name='all-MiniLM-L6-v2', batch_size=32, num_epochs=5)


In [None]:

# 1. Read the evidence corpus and build training examples from the training claims
trainer.load_data(claims_path='data/train-claims.json', evidence_path='data/evidence.json')


In [None]:

# 2. Register the held-out dev claims used for validation during training
trainer.load_dev_data(dev_claims_path='data/dev-claims.json')

In [None]:


# 3. Precompute evidence embeddings
# NOTE: the corpus is encoded with the model's current weights, i.e. before
# the fine-tuning performed in step 4.
trainer.index_evidence()

In [None]:


# 4. Fine-tune the encoder, validating on the dev set; output goes to ./climate_retriever
trainer.train('./climate_retriever')
