In [None]:
import json
import random
import pandas as pd
from src.column import Column
from src.experiments.load_experiment import Experiment, load_experiment_df
from src.models.claim_recall import ClaimRecall
from src.models.manual_annotations import EvaluationBatch, Generation, SentenceWithCitations, TargetCitation

# Row numbers that have already been turned into a sample. Despite the name,
# this holds the integer row positions drawn at random, not Experiment members.
used_experiments: list[int] = []
# Output of EvaluationBatch.model_dump() for every accepted sample.
samples: list[dict] = []

# Maps each post-hoc experiment to the base experiment whose dataframe is
# loaded instead when looking up claim recall.
TO_BASE_MAP: dict[Experiment, Experiment] = {
    Experiment.POST_HOC_LLAMA_8B: Experiment.BASE_LLAMA_8B,
    Experiment.POST_HOC_MISTRAL_7B: Experiment.BASE_MISTRAL_7B,
    Experiment.POST_HOC_SAUL_7B: Experiment.BASE_SAUL_7B,
    Experiment.POST_HOC_LLAMA_70B: Experiment.BASE_LLAMA_70B,
}

def get_claim_recall(e: Experiment, row_number: int):
    """Return the claim-recall score stored for one row of an experiment.

    Experiments listed in ``TO_BASE_MAP`` are looked up in the dataframe of
    their mapped base experiment; every other experiment is looked up in its
    own dataframe.

    Args:
        e: Experiment whose claim recall is requested.
        row_number: Index label of the row, resolved with ``df.loc``.

    Returns:
        The ``claim_recall`` attribute of the parsed ``ClaimRecall`` payload,
        or ``0`` when the cell is NaN or the payload cannot be parsed.

    Raises:
        KeyError: If ``row_number`` or the claim-recall column is not present
            in the dataframe (raised by pandas, not handled here).
    """
    # Fall back to the experiment itself when it has no base-experiment mapping.
    df, _ = load_experiment_df(TO_BASE_MAP.get(e, e))

    # NOTE(review): this is a label-based lookup (loc), while
    # get_random_samples_for_annotation selects the same row by position
    # (iloc). The two agree only if the dataframe has a default 0..n-1 index
    # -- confirm against load_experiment_df.
    raw_claim_recall = df.loc[row_number][Column.CLAIM_RECALL]

    if pd.isna(raw_claim_recall):
        # No claim recall recorded for this row: treat as zero.
        return 0

    try:
        return ClaimRecall.model_validate_json(raw_claim_recall).claim_recall
    except Exception:
        # Best effort: an unparsable payload counts as zero. Catching
        # Exception (not a bare except) lets KeyboardInterrupt and SystemExit
        # propagate instead of being silently swallowed.
        return 0

def generation_from_row(row, e: Experiment):
    """Build a ``Generation`` for experiment *e* from one result row.

    Args:
        row: Row holding at least the generated-answer column; the
            generated-citations column is optional.
        e: Experiment the row belongs to; its ``value`` is stored on the
            resulting ``Generation``.

    Returns:
        A ``Generation`` whose ``sentences_with_citations`` is the list of
        validated ``SentenceWithCitations`` decoded from the row's JSON
        citations, or ``None`` when the row has no citations column.
    """
    sentences = None
    if Column.GENERATED_CITATIONS in row:
        # The citations cell is a JSON-encoded list of sentence objects.
        decoded = json.loads(row[Column.GENERATED_CITATIONS])
        sentences = [SentenceWithCitations.model_validate(item) for item in decoded]

    return Generation(
        experiment=e.value,
        answer=row[Column.GENERATED_ANSWER],
        sentences_with_citations=sentences,
    )

def get_random_samples_for_annotation(experiments: list[Experiment]):
    """Try to draw one random question and record it as an annotation sample.

    A row position is drawn at random, skipping positions already listed in
    the module-level ``used_experiments`` and a fixed exclusion set. The same
    position is then read from every experiment's dataframe. When all checks
    pass, an ``EvaluationBatch`` is built, dumped with ``model_dump()`` and
    appended to the module-level ``samples`` list, and the position is
    appended to ``used_experiments``.

    The attempt is abandoned without recording anything when:
      * ``experiments`` is empty;
      * any experiment's row lacks the generated-citations or
        generated-answer column, or holds NaN in either of them;
      * any generation has fewer than 2 citations in total;
      * all experiments report the same claim recall for the row.

    Args:
        experiments: Experiments to compare on the same question. The
            question, target answer and target citations are taken from the
            row of the first experiment.

    Returns:
        None. Results are communicated through ``samples`` and
        ``used_experiments``.
    """
    if not experiments:
        # Nothing to compare; rows[0] below would otherwise raise IndexError.
        return

    # Row positions that must never be drawn.
    # NOTE(review): presumably questions that were already annotated or are
    # otherwise unusable -- confirm with the author.
    excluded_rows = {
        544, 161, 806, 304, 688, 178, 957, 323, 325, 336,
        469, 733, 94, 872, 123, 237, 244, 378, 635,
    }

    r = random.randint(0, 1100)
    while r in used_experiments or r in excluded_rows:
        r = random.randint(0, 1100)

    rows = []
    for e in experiments:
        df, _ = load_experiment_df(e)
        row = df.iloc[r]
        if (
            Column.GENERATED_CITATIONS not in row
            or Column.GENERATED_ANSWER not in row
            or pd.isna(row[Column.GENERATED_ANSWER])
            # A NaN citations cell would make json.loads raise a TypeError
            # later on, so skip the sample the same way as for a NaN answer.
            or pd.isna(row[Column.GENERATED_CITATIONS])
        ):
            return

        rows.append(row)

    # Question-level fields are read from the first experiment's row.
    reference_row = rows[0]
    question = reference_row[Column.QUESTION]
    target_answer = reference_row[Column.TARGET_ANSWER]
    target_citations = [
        TargetCitation.model_validate(c)
        for c in json.loads(reference_row[Column.TARGET_CITATIONS])
    ]

    generations = [generation_from_row(row, e) for row, e in zip(rows, experiments)]

    # Require at least 2 citations in total for every generated answer.
    for g in generations:
        nr_citations = sum(len(s.citations) for s in g.sentences_with_citations)
        if nr_citations < 2:
            return

    # The experiments must not all agree: require at least 2 distinct
    # claim-recall values across them.
    claim_recalls = [get_claim_recall(e, r) for e in experiments]
    if len(set(claim_recalls)) < 2:
        return

    evaluation_batch = EvaluationBatch(
        axis="approach",
        question_number=r,
        question=question,
        answer=target_answer,
        citations=target_citations,
        generations=generations,
        annotation=None,
    )
    samples.append(evaluation_batch.model_dump())
    used_experiments.append(r)

# Experiments compared against each other for every drawn question.
experiments = [
    Experiment.POST_HOC_MISTRAL_7B,
    Experiment.RAG_GTR_k_10_MISTRAL_7B,
    Experiment.LLATRIEVAL_GTR_k_10_MISTRAL_7B,
    Experiment.RARR_MISTRAL_7B,
]
# Alternative selection, kept for switching by hand:
# experiments = [Experiment.RAG_GTR_k_10_MISTRAL_7B, Experiment.RAG_GTR_k_10_SAUL_7B, Experiment.RAG_GTR_k_10_LLAMA_8B, Experiment.RAG_GTR_k_10_LLAMA_70B]

# Each call makes one attempt and may record nothing; retry up to 1000 times
# and stop as soon as exactly one sample has been collected.
for _attempt in range(1000):
    get_random_samples_for_annotation(experiments)
    if len(samples) == 1:
        break

# Prints "[]" when no attempt produced a sample.
print(json.dumps(samples[:1], indent=4))