## **NOTE**:

1. This notebook submission does not include outputs: the distillation process can take more than 8 hours to complete for a single dataset (we used two: SQuAD and FEVER). The notebook was restructured after the distillation datasets were created, to make it easier to follow.

2. The datasets are available on Hugging Face. Please be aware that running this notebook requires a Gemini API key and will incur a monetary cost for every request.

## Installations and Imports

In [None]:
%%capture
# Pinned `datasets` version. Presumably chosen because prepare_fever loads FEVER
# with trust_remote_code=True (script-based loading) — confirm before upgrading.
!pip install datasets==3.6.0

In [None]:
from google.colab import drive, userdata
from google import genai
from google.genai import types
from google.genai.errors import APIError
from typing import List, Dict, Any, Optional, Set
import time
from random import choice
import requests
import os
import json
from pydantic import BaseModel, Field, constr
import textwrap
import pandas as pd
import numpy as np
import random
from datasets import load_dataset, Dataset, DatasetDict
from tqdm.auto import tqdm

In [None]:
# Mount Google Drive; the distillation checkpoints below are written under
# /content/drive/MyDrive/checkpoints/FactGuard/.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy Colab secrets into the process environment: HF_TOKEN for Hugging Face Hub
# access (load/push), GEMINI_API_KEY for the google-genai client, which is
# created later without an explicit api_key argument.
os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')
os.environ['GEMINI_API_KEY'] = userdata.get('GEMINI_API_KEY')

## Helper Functions

In [None]:
def checkpoint(dataset: pd.DataFrame, path: str, filename: str):
    """Append ``dataset`` to the CSV at ``{path}/{filename}``, creating it if needed.

    The header row is emitted only when the file does not exist yet or is
    empty, so repeated calls accumulate rows beneath a single header.
    The DataFrame index is never written.
    """
    target = f"{path}/{filename}"
    os.makedirs(path, exist_ok=True)
    # Evaluate before to_csv opens (and thereby creates) the file.
    is_fresh_file = (not os.path.exists(target)) or os.path.getsize(target) == 0
    dataset.to_csv(target, mode='a', header=is_fresh_file, index=False)

def load_checkpoint(path: str, filename: str) -> Optional[pd.DataFrame]:
    """Read the checkpoint CSV at ``{path}/{filename}``.

    Returns:
        The parsed DataFrame, or ``None`` when the file does not exist.
    """
    filepath = f"{path}/{filename}"
    if not os.path.exists(filepath):
        print(f"No checkpoint found at {filepath}. Starting fresh.")
        return None
    print(f"Checkpoint found! Loading data from {filepath}")
    return pd.read_csv(filepath)

def initialize_checkpoint_df(
    checkpoint_df: Optional[pd.DataFrame],
    expected_columns: List[str]
) -> pd.DataFrame:
    """Return ``checkpoint_df`` unchanged if it holds data.

    Otherwise return a fresh, empty DataFrame with ``expected_columns``. This
    covers a missing (``None``) checkpoint and an empty one.
    """
    has_data = checkpoint_df is not None and not checkpoint_df.empty
    return checkpoint_df if has_data else pd.DataFrame(columns=expected_columns)

## [FEVER] Knowledge Distillation

In [None]:
## Used to keep training and test sets consistent
def prepare_fever(
    repo_id: str = "FEVER",
    test_size: float = 0.2,
    seed: int = 42,
) -> DatasetDict:
    """Build the binary FEVER train/test split and push it to the Hugging Face Hub.

    Loads FEVER ``v1.0`` and keeps only the ``train`` rows labelled
    ``SUPPORTS`` or ``REFUTES``. It splits them with a fixed seed, so the
    partition is reproducible, and pushes the result to ``repo_id``.

    Args:
        repo_id: Hub repository to push to (defaults to the original ``"FEVER"``).
        test_size: Fraction of rows assigned to the ``test`` split.
        seed: Seed passed to ``train_test_split`` for a reproducible shuffle.

    Returns:
        The ``DatasetDict`` (``train`` / ``test``) that was pushed.
    """
    fever_datasets = load_dataset("fever", "v1.0", trust_remote_code=True)
    target_labels = ["SUPPORTS", "REFUTES"]
    filtered_train_dataset = fever_datasets['train'].filter(
        lambda example: example['label'] in target_labels
    )
    # NOTE(review): if one claim `id` appears on several rows (e.g. one per
    # evidence annotation), this row-level split can put the same claim in
    # both train and test — confirm whether de-duplicating by `id` first is wanted.
    split_datasets = filtered_train_dataset.train_test_split(
        test_size=test_size,
        seed=seed,
    )
    # Rebuild explicitly so exactly the 'train' and 'test' splits are pushed.
    final_dataset_dict = DatasetDict({
        'train': split_datasets['train'],
        'test': split_datasets['test'],
    })
    final_dataset_dict.push_to_hub(repo_id)
    return final_dataset_dict

### Definitions

In [None]:
# Teacher-output schema for one FEVER claim. FEVERDistilledBatch (which nests
# this model) is turned into a JSON schema with model_json_schema() and sent to
# the Gemini API as the response schema. The Field descriptions therefore act
# as instructions to the model.
# Documentation is kept in `#` comments on purpose: pydantic would copy a class
# docstring into the schema's "description".
# NOTE: responses are parsed with json.loads rather than validated through this
# model, so the constr(min_length=1) constraints are not enforced locally.
class FEVERDistillationPoint(BaseModel):
    # The input claim, echoed back by the model.
    claim: constr(min_length=1) = Field(description="The original FEVER claim statement.")
    # 'SUPPORTS' or 'REFUTES', as given in the input.
    label: constr(min_length=1) = Field(description="The original FEVER label: 'SUPPORTS' or 'REFUTES'.")
    # Generated evidence paragraph.
    context: constr(min_length=1) = Field(description="The comprehensive evidence/document that fully and explicitly verifies or contradicts the claim.")
    # Generated explanation of the verdict.
    rationale: constr(min_length=1) = Field(description="A step-by-step reasoning/rationale generated by the teacher model explaining the verdict based on the context.")
    # Boolean form of `label` (True for SUPPORTS).
    verdict: bool = Field(description="True if the claim is SUPPORTS, False if REFUTES.")

# Top-level response schema for a FEVER batch. The teacher must return a JSON
# object whose 'results' list holds one FEVERDistillationPoint per input claim,
# in input order. Results are matched back to FEVER ids by position.
class FEVERDistilledBatch(BaseModel):
    results: List[FEVERDistillationPoint] = Field(description="A list containing the structured fact-checking analysis for every claim in the input.")

In [None]:
# Column order of the FEVER checkpoint CSV. Batches are appended to that file
# and its header is written only once, so this list must stay stable for the
# lifetime of a checkpoint file.
FEVER_OUTPUT_COLUMNS = ['original_fever_id', 'claim', 'label', 'context', 'rationale', 'verdict']

### Helper Functions

In [None]:
def _construct_fever_prompt_and_config(
    fever_batch_data: List[Dict[str, str]],
    teacher_model_name: str
) -> tuple[str, str, types.GenerateContentConfig]:
    """Assemble the system instruction, user prompt and generation config for a FEVER batch.

    Each record is rendered as a numbered block with its claim and original
    label; missing keys show as ``'N/A'``. The config requests JSON output
    constrained to the ``FEVERDistilledBatch`` schema at temperature 0.7.

    Args:
        fever_batch_data: Records with ``'claim'`` and ``'label'`` keys.
        teacher_model_name: Accepted for signature symmetry; not used here.

    Returns:
        ``(system_instruction, prompt, config)``.
    """
    system_instruction = textwrap.dedent("""
        You are an expert fact-checking training data synthesizer. Your task is to generate the appropriate 'context' (evidence) and a 'rationale' for each provided FEVER claim, focusing only on 'SUPPORTS' and 'REFUTES' labels.

        **FOR 'SUPPORTS' CLAIMS:** Generate a complete paragraph for the 'context' that **fully and explicitly verifies** the claim. The 'rationale' must clearly state how the context proves the claim. Set 'verdict' to true.
        **FOR 'REFUTES' CLAIMS:** Generate a complete paragraph for the 'context' that **fully and explicitly contradicts** the claim. The 'rationale' must clearly state which specific fact in the context refutes the claim. Set 'verdict' to false.

        All generated contexts must be coherent and informative, and all claims, labels, and rationales must be non-empty strings.
    """).strip()

    # Keep this literal's indentation as-is: dedent runs after interpolation,
    # so the rendered text depends on it when a claim contains newlines.
    input_entries = [
        textwrap.dedent(f"""
            --- FEVER Entry {position} ---
            CLAIM: {record.get('claim', 'N/A')}
            ORIGINAL LABEL: {record.get('label', 'N/A')}
        """).strip()
        for position, record in enumerate(fever_batch_data, start=1)
    ]

    instructions = (
        f"Process the following {len(fever_batch_data)} binary FEVER entries. Your final output MUST be a single JSON object that strictly conforms to the FEVERDistilledBatch schema. "
        "The 'results' list must contain one fully generated entry for each input claim. Do not include any text outside the JSON object.\n"
    )
    prompt = instructions + "\n\n".join(input_entries)

    config = types.GenerateContentConfig(
        system_instruction=system_instruction,
        response_mime_type="application/json",
        response_schema=FEVERDistilledBatch.model_json_schema(),
        temperature=0.7,
    )
    return system_instruction, prompt, config

In [None]:
def generate_fever_distillation_data(
    fever_batch_data: List[Dict[str, str]],
    teacher_model_name: str = 'gemini-2.5-pro'
) -> Dict[str, List[Dict[str, Any]]]:
    """Ask the teacher model to synthesize context and rationale for a batch of FEVER claims.

    Args:
        fever_batch_data: Records with ``'claim'`` and ``'label'`` keys.
        teacher_model_name: Gemini model id to call.

    Returns:
        The parsed JSON response, expected to be ``{'results': [...]}``.
        ``{'results': []}`` is returned for an empty batch, and on any API or
        parsing failure, so the caller can skip the batch without halting.
    """
    if not fever_batch_data:
        return {'results': []}

    system_instruction, prompt, config = _construct_fever_prompt_and_config(
        fever_batch_data, teacher_model_name
    )

    print(f"-> Sending request for {len(fever_batch_data)} entries to teacher model ({teacher_model_name})...")

    # No explicit key: the client is expected to pick up GEMINI_API_KEY from the environment.
    client = genai.Client()

    try:
        # The system instruction already travels in `config`; repeating it in
        # `contents` would send it a second time as user input on every call.
        response = client.models.generate_content(
            model=teacher_model_name,
            contents=prompt,
            config=config,
        )
        if not response.text:
            # e.g. a blocked/empty candidate; json.loads(None) would raise an opaque TypeError.
            raise ValueError("Teacher model returned an empty response.")
        return json.loads(response.text)

    except Exception as e:
        # Deliberate best-effort per batch: log and return an empty result so a
        # single failed request does not abort a multi-hour distillation run.
        print(f"API Call Validation Exception: {e}. Returning empty results for this batch.")
        return {'results': []}

In [None]:
def _process_and_flatten_fever_output(
    distilled_batch_output: Dict[str, List[Dict[str, Any]]], # Dictionary with 'results' list
    batch_df: pd.DataFrame
) -> pd.DataFrame:
    """Pair teacher-model results with the FEVER ids of the rows that produced them.

    Results carry no id, so they are matched to ``batch_df`` rows by position.
    That pairing is only trustworthy when the model returned exactly one result
    per input row. On a count mismatch the whole batch is discarded, rather
    than risk attaching ids to the wrong claims. Its ids then stay out of the
    checkpoint, so a later run picks them up again.

    Args:
        distilled_batch_output: Parsed response, expected ``{'results': [...]}``.
        batch_df: Input rows in the order they were sent; must have an ``'id'`` column.

    Returns:
        DataFrame with ``FEVER_OUTPUT_COLUMNS``; empty when nothing usable came back.
    """
    if not isinstance(distilled_batch_output, dict) or not distilled_batch_output.get('results'):
        print("API response missing 'results' list or list is empty.")
        return pd.DataFrame(columns=FEVER_OUTPUT_COLUMNS)

    claims_list = distilled_batch_output['results']
    if len(claims_list) != len(batch_df):
        print(
            f"Error: API returned {len(claims_list)} results for {len(batch_df)} claims; "
            "positional alignment is unreliable. Discarding batch."
        )
        return pd.DataFrame(columns=FEVER_OUTPUT_COLUMNS)

    batch_results = []
    for original_fever_id, distilled_point in zip(batch_df['id'], claims_list):
        try:
            batch_results.append({
                'original_fever_id': original_fever_id,
                'claim': distilled_point['claim'],
                'label': distilled_point['label'],
                'context': distilled_point['context'],
                'rationale': distilled_point['rationale'],
                'verdict': distilled_point['verdict']
            })
        except (KeyError, TypeError) as e:
            # Missing field or a non-object entry: drop just this claim.
            print(f"Data Structure Error for FEVER ID {original_fever_id}. Skipping claim. Error: {e}")

    return pd.DataFrame(batch_results, columns=FEVER_OUTPUT_COLUMNS)

In [None]:
def distill_fever_to_dataset(
    input_dataset: Dataset,
    batch_size: int = 50,
    teacher_model_name: str = 'gemini-2.5-flash',
    checkpoint_dir: str = 'distillation_checkpoints',
    checkpoint_filename: str = 'fever_distilled_data.csv'
) -> pd.DataFrame:
    """Distil binary FEVER claims through the teacher model, resuming from a CSV checkpoint.

    Steps:

    - keep SUPPORTS/REFUTES rows and drop repeated FEVER ids;
    - skip ids already present in the checkpoint;
    - shuffle the rest into batches of at most ``batch_size``;
    - append each batch's results to the checkpoint.

    The loop halts on the first unexpected exception; everything checkpointed
    up to that point is kept.

    Args:
        input_dataset: FEVER split with ``id``, ``claim`` and ``label`` columns.
        batch_size: Upper bound on claims per teacher request.
        teacher_model_name: Gemini model id to call.
        checkpoint_dir: Directory holding the checkpoint CSV.
        checkpoint_filename: Checkpoint CSV file name.

    Returns:
        The checkpoint contents de-duplicated by ``claim``. If nothing was left
        to process, the current checkpoint is returned instead. An empty
        DataFrame is returned for empty input or when no checkpoint file
        exists at the end.
    """
    if not input_dataset:
        return pd.DataFrame()

    df_full = input_dataset.to_pandas()

    df_binary = df_full[df_full['label'].isin(['SUPPORTS', 'REFUTES'])].copy()
    print(f"Filtered out 'NOT ENOUGH INFO'. {len(df_binary)} claims remaining for distillation.")

    # Only 'claim' and 'label' are sent to the teacher. A FEVER id repeated on
    # several rows (assumed to carry the same claim — confirm) would otherwise
    # be generated, and billed, once per row.
    n_rows = len(df_binary)
    df_binary = df_binary.drop_duplicates(subset=['id'])
    if len(df_binary) < n_rows:
        print(f"Dropped {n_rows - len(df_binary)} duplicate FEVER ID rows.")

    current_checkpoint_df = initialize_checkpoint_df(
        load_checkpoint(checkpoint_dir, checkpoint_filename), FEVER_OUTPUT_COLUMNS
    )

    processed_fever_ids: Set[int] = set(current_checkpoint_df['original_fever_id'].astype(int).tolist())
    print(f"Current checkpoint size (claims): {len(current_checkpoint_df)}")

    df_unprocessed = df_binary[~df_binary['id'].isin(processed_fever_ids)].copy()
    print(f"Removed {len(df_binary) - len(df_unprocessed)} already processed FEVER IDs.")

    if df_unprocessed.empty:
        print("All FEVER samples already processed. Returning final dataset.")
        return current_checkpoint_df

    df_unprocessed = df_unprocessed.sample(frac=1).reset_index(drop=True)
    total_samples = len(df_unprocessed)
    batch_indices = np.array_split(df_unprocessed.index, np.ceil(total_samples / batch_size))
    num_batches = len(batch_indices)
    print(f"FEVER IDs remaining: {total_samples} in {num_batches} batches.")

    for i, indices in tqdm(enumerate(batch_indices), total=num_batches, desc="Distilling FEVER Batches"):
        batch_df = df_unprocessed.loc[indices]
        fever_batch_input = batch_df[['claim', 'label']].to_dict('records')

        try:
            distilled_batch_output = generate_fever_distillation_data(
                fever_batch_data=fever_batch_input,
                teacher_model_name=teacher_model_name
            )
            batch_results_df = _process_and_flatten_fever_output(
                distilled_batch_output, batch_df
            )
            # Persist after every batch so an interrupted run can resume.
            checkpoint(batch_results_df, checkpoint_dir, checkpoint_filename)

        except Exception as e:
            print(f"Critical Error in FEVER Batch {i+1}. Halting distillation. Error: {e}")
            break

    try:
        final_df = pd.read_csv(f"{checkpoint_dir}/{checkpoint_filename}")
        final_df = final_df.drop_duplicates(subset=['claim']).reset_index(drop=True)
        print(f"FEVER Distillation finished. Total final processed samples: {len(final_df)}")
        return final_df
    except FileNotFoundError:
        print("Final checkpoint file not found. Returning empty DataFrame.")
        return pd.DataFrame()

### Execution

In [None]:
# FEVER checkpoint location on the mounted Drive.
CHECKPOINT_DIR = '/content/drive/MyDrive/checkpoints/FactGuard/'
CHECKPOINT_FILENAME = 'fever_distilled_data.csv'
# Full path of the same file. The checkpoint helpers take dir + filename
# instead, and this constant is not read by the cells shown.
CHECKPOINT_PATH_FULL = CHECKPOINT_DIR + CHECKPOINT_FILENAME

In [None]:
# Load the pre-split binary FEVER dataset from the Hub (presumably the one
# pushed by prepare_fever — confirm). This keeps the TRAIN / TEST partition
# identical across runs.
fever_datasets = load_dataset("rickpereira/FEVER")

In [None]:
# Distil the FEVER train split. Results are appended to the checkpoint CSV
# after every batch, so re-running this cell resumes where a previous run stopped.
distilled_dataset = distill_fever_to_dataset(
    fever_datasets['train'],
    checkpoint_dir = CHECKPOINT_DIR,
    checkpoint_filename = CHECKPOINT_FILENAME)

### Push to HF

In [None]:
# Push the distilled FEVER checkpoint to the Hub as a single 'train' split.
fever_checkpoint_df = load_checkpoint(CHECKPOINT_DIR, CHECKPOINT_FILENAME)
if fever_checkpoint_df is None:
    raise FileNotFoundError(
        f"No FEVER checkpoint at {CHECKPOINT_DIR}/{CHECKPOINT_FILENAME}; run the distillation cell first."
    )
# The checkpoint is append-only and may repeat claims. Apply the same
# claim-level de-duplication distill_fever_to_dataset uses for its final result.
# reset_index stops Dataset.from_pandas from adding an index column.
fever_checkpoint_df = fever_checkpoint_df.drop_duplicates(subset=['claim']).reset_index(drop=True)
# Separate name so the train/test `fever_datasets` loaded above is not overwritten.
fever_distilled_dataset = Dataset.from_pandas(fever_checkpoint_df)
fever_dataset_dict = DatasetDict({
    "train": fever_distilled_dataset
})
fever_dataset_dict.push_to_hub("factguard_fever_distilled_datasets")

## [SQuAD] Knowledge Distillation

### Definitions

In [None]:
# Column order of the SQuAD checkpoint CSV.
# NOTE: `verdict` is collected by _process_and_flatten_batch_output but is not
# listed here, so it is dropped when that function builds its DataFrame. Per
# the prompt, `label` ('TRUE'/'FALSE') carries the same information.
# Batches are appended under a header that is written only once, so changing
# this list would misalign rows appended to an existing checkpoint.
FINAL_OUTPUT_COLUMNS = ['original_squad_id', 'claim', 'label', 'rationale', 'context']

In [None]:
# Teacher-output schema for one SQuAD-derived claim; two are requested per
# record, one TRUE and one FALSE. Nested in BatchFactCheckResult, whose JSON
# schema is sent to the Gemini API as the response schema, so the Field
# descriptions double as instructions.
# `#` comments are used because pydantic would copy a class docstring into
# the schema.
# Unlike the FEVER schema there is no `context` field; the original SQuAD
# context is attached when results are flattened.
class FactCheckDataPoint(BaseModel):
    # Generated single-sentence statement.
    claim: constr(min_length=1) = Field(description="The generated statement provided for fact-checking.")
    # 'TRUE' or 'FALSE'.
    label: constr(min_length=1) = Field(description="The final classification of the claim's accuracy: 'TRUE' or 'FALSE'.")
    # One-sentence justification citing the context.
    rationale: constr(min_length=1) = Field(description="A concise, direct, and single-sentence explanation of why the claim is true or false, citing the relevant part of the CONTEXT.")
    # Boolean form of `label`.
    verdict: bool = Field(description="True if the claim is factually correct, False otherwise.")

# Top-level response schema for a SQuAD batch: a flat 'results' list that the
# prompt asks to contain exactly two claims per input record, in input order.
class BatchFactCheckResult(BaseModel):
    results: List[FactCheckDataPoint] = Field(description="A list containing the fact-checking entry.")

### Helper Functions

In [None]:
def _first_squad_answer(answers: Dict[str, Any]) -> str:
    """Return the first gold answer text, or '' when the record lists no answer."""
    texts = answers.get('text', [''])
    # An empty answer list (e.g. unanswerable SQuAD v2 questions) would make [0] raise IndexError.
    return texts[0] if len(texts) > 0 else ''


def _get_unprocessed_squad_data(
    input_dataset: Dataset,
    checkpoint_dir: str,
    checkpoint_filename: str
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Split a SQuAD dataset into rows still to distil and the existing checkpoint.

    Adds a ``squad_answer`` column holding the first gold answer text, or ''
    when there is none. Rows whose ``id`` already appears as
    ``original_squad_id`` in the checkpoint are removed.

    Returns:
        ``(df_unprocessed, current_checkpoint_df)``.
    """
    df_full = input_dataset.to_pandas()
    df_full['squad_answer'] = df_full['answers'].apply(_first_squad_answer)

    current_checkpoint_df = initialize_checkpoint_df(
        load_checkpoint(checkpoint_dir, checkpoint_filename), FINAL_OUTPUT_COLUMNS
    )

    processed_squad_ids: Set[str] = set(current_checkpoint_df['original_squad_id'].astype(str).tolist())
    print(f"Current checkpoint size (claims): {len(current_checkpoint_df)}")
    already_processed = df_full['id'].isin(processed_squad_ids)
    df_unprocessed = df_full[~already_processed].copy()
    print(f"Removed {int(already_processed.sum())} already processed SQuAD IDs.")

    return df_unprocessed, current_checkpoint_df

In [None]:
def _process_and_flatten_batch_output(
    distilled_batch_output: List[Dict[str, Any]],
    batch_df: pd.DataFrame
) -> pd.DataFrame:
    """Attach SQuAD ids and contexts to the teacher's flat list of claims.

    The teacher is asked for exactly two claims per input record (one TRUE,
    one FALSE), in input order, so claim ``j`` belongs to row ``j // 2`` of
    ``batch_df``. Claims carry no id of their own. If the count is not exactly
    ``2 * len(batch_df)``, that mapping cannot be trusted and the whole batch
    is discarded, instead of pairing claims with the wrong context. Its ids
    stay out of the checkpoint, so a later run picks them up again.

    Args:
        distilled_batch_output: The ``'results'`` list from the teacher response.
        batch_df: Input rows in the order they were sent; needs ``'id'`` and ``'context'``.

    Returns:
        DataFrame with ``FINAL_OUTPUT_COLUMNS`` (``verdict`` is collected but not
        among those columns, so it is dropped).
    """
    expected = 2 * len(batch_df)
    if len(distilled_batch_output) != expected:
        print(
            f"Error: API returned {len(distilled_batch_output)} claims for {len(batch_df)} entries "
            f"(expected {expected}); positional alignment is unreliable. Discarding batch."
        )
        return pd.DataFrame(columns=FINAL_OUTPUT_COLUMNS)

    ids = batch_df['id'].tolist()
    contexts = batch_df['context'].tolist()
    batch_results = []

    for j, flat_claim in enumerate(distilled_batch_output):
        row = j // 2
        original_id = ids[row]
        try:
            batch_results.append({
                'original_squad_id': original_id,
                'claim': flat_claim['claim'],
                'label': flat_claim['label'],
                'rationale': flat_claim['rationale'],
                'context': contexts[row],
                'verdict': flat_claim['verdict']
            })
        except (KeyError, TypeError) as e:
            # Missing field or non-object entry: only this one claim is dropped.
            print(f"Validation Error for SQuAD ID {original_id}. Skipping claim. Error: {e}")

    return pd.DataFrame(batch_results, columns=FINAL_OUTPUT_COLUMNS)

In [None]:
def _construct_distillation_prompt_and_config(
    squad_batch_data: List[Dict[str, str]],
    teacher_model_name: str
) -> tuple[str, str, types.GenerateContentConfig]:
    """Assemble the system instruction, user prompt and generation config for a SQuAD batch.

    Every record is rendered as a numbered block of context, question and
    correct answer; missing keys show as ``'N/A'``. The model is asked for two
    claims per record, as JSON matching ``BatchFactCheckResult``, at
    temperature 0.1.

    Args:
        squad_batch_data: Records with ``'context'``, ``'question'`` and ``'answer'`` keys.
        teacher_model_name: Accepted for signature symmetry; not used here.

    Returns:
        ``(system_instruction, prompt, config)``.
    """
    system_instruction = textwrap.dedent("""
        You are an expert fact-checking data generator. For each provided record (Context, Question, Answer), you must generate TWO distinct claims and structure them as a list under the 'results' key:

        1.  **Factual Claim:** A new, single sentence that is a concise factual statement directly supported by the CONTEXT and related to the ANSWER. Set 'label' to 'TRUE' and 'verdict' to true.
        2.  **Misleading Claim:** A new, single sentence that directly contradicts a specific fact, date, or entity explicitly stated in the CONTEXT. Set 'label' to 'FALSE' and 'verdict' to false.

        The 'rationale' for the FALSE claim must state the correct fact from the context. Ensure all claims, labels, and rationales are non-empty strings.
    """).strip()

    # Keep this literal's indentation as-is: dedent runs after interpolation,
    # so the rendered text depends on it when a context contains newlines.
    input_entries = [
        textwrap.dedent(f"""
            --- Entry {position} ---
            CONTEXT: {record.get('context', 'N/A')}
            QUESTION: {record.get('question', 'N/A')}
            CORRECT ANSWER: {record.get('answer', 'N/A')}
        """).strip()
        for position, record in enumerate(squad_batch_data, start=1)
    ]

    instructions = (
        f"Process the following {len(squad_batch_data)} entries. Your final output MUST be a single JSON object that strictly conforms to the BatchFactCheckResult schema (which contains a list of claims). "
        "The 'results' list must contain **exactly two** claims per entry: one TRUE and one FALSE. Do not include any text outside the JSON object.\n"
    )
    prompt = instructions + "\n\n".join(input_entries)

    config = types.GenerateContentConfig(
        system_instruction=system_instruction,
        response_mime_type="application/json",
        response_schema=BatchFactCheckResult.model_json_schema(),
        temperature=0.1,
    )
    return system_instruction, prompt, config

In [None]:
def generate_squad_distillation_data(
    squad_batch_data: List[Dict[str, str]],
    teacher_model_name: str = 'gemini-2.5-pro'
) -> Dict[str, Any]:
    """Request two fact-check claims per SQuAD record from the teacher model.

    Args:
        squad_batch_data: Records with ``'context'``, ``'question'`` and ``'answer'`` keys.
        teacher_model_name: Gemini model id to call.

    Returns:
        The parsed JSON response, expected to be ``{'results': [...]}``. An empty
        (falsy) dict is returned when ``squad_batch_data`` is empty.

    Raises:
        ValueError: if the model returns no text, or the text is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``).

        API errors propagate unchanged; distill_squad_to_dataset halts the run
        on any exception.
    """
    if not squad_batch_data:
        return {}

    system_instruction, prompt, config = _construct_distillation_prompt_and_config(
        squad_batch_data, teacher_model_name
    )

    print(f"-> Sending request for {len(squad_batch_data)} entries to teacher model ({teacher_model_name})...")

    # No explicit key: the client is expected to pick up GEMINI_API_KEY from the environment.
    client = genai.Client()
    # The system instruction already travels in `config`; repeating it in
    # `contents` would send it a second time as user input on every call.
    response = client.models.generate_content(
        model=teacher_model_name,
        contents=prompt,
        config=config,
    )

    if not response.text:
        # e.g. a blocked/empty candidate; json.loads(None) would raise an opaque TypeError.
        raise ValueError("Teacher model returned an empty response.")
    return json.loads(response.text)

In [None]:
def distill_squad_to_dataset(
    input_dataset: Dataset,
    batch_size: int = 50,
    teacher_model_name: str = 'gemini-2.5-flash',
    checkpoint_dir: str = 'distillation_checkpoints',
    checkpoint_filename: str = 'squad_distilled_data.csv'
) -> pd.DataFrame:
    """Distil SQuAD records into TRUE/FALSE claims, resuming from a CSV checkpoint.

    Records whose id is already checkpointed are skipped. The rest are shuffled
    into batches of at most ``batch_size``; each batch's claims are appended to
    the checkpoint. Responses without usable results are skipped. Any other
    exception halts the loop; everything checkpointed so far is kept.

    Args:
        input_dataset: SQuAD split with ``id``, ``context``, ``question`` and ``answers``.
        batch_size: Upper bound on records per teacher request.
        teacher_model_name: Gemini model id to call.
        checkpoint_dir: Directory holding the checkpoint CSV.
        checkpoint_filename: Checkpoint CSV file name.

    Returns:
        The checkpoint contents de-duplicated by ``claim``. If nothing was left
        to process, the current checkpoint is returned instead. An empty
        DataFrame is returned for empty input or when no checkpoint file
        exists at the end.
    """
    if not input_dataset:
        return pd.DataFrame()

    df_unprocessed, current_checkpoint_df = _get_unprocessed_squad_data(
        input_dataset, checkpoint_dir, checkpoint_filename
    )

    if df_unprocessed.empty:
        print("All SQuAD samples already processed. Returning final dataset.")
        return current_checkpoint_df

    df_unprocessed = df_unprocessed.sample(frac=1).reset_index(drop=True)
    total_samples = len(df_unprocessed)
    batch_indices = np.array_split(df_unprocessed.index, np.ceil(total_samples / batch_size))
    num_batches = len(batch_indices)
    print(f"SQuAD IDs remaining: {total_samples} in {num_batches} batches.")

    for i, indices in tqdm(enumerate(batch_indices), total=num_batches, desc="Distilling SQuAD Batches"):
        batch_df = df_unprocessed.loc[indices]

        squad_batch_input = batch_df[[
            'context', 'question', 'squad_answer'
        ]].rename(columns={'squad_answer': 'answer'}).to_dict('records')

        try:
            distilled_batch_output = generate_squad_distillation_data(
                squad_batch_data=squad_batch_input,
                teacher_model_name=teacher_model_name
            )

            # A response without a usable 'results' list is skipped like an empty
            # one (as the FEVER pipeline does). Previously a KeyError here halted
            # the whole run.
            results = (
                distilled_batch_output.get('results')
                if isinstance(distilled_batch_output, dict) else None
            )
            if not results:
                print(f"Warning in Batch {i+1}: API returned empty results. Skipping.")
                continue

            batch_results_df = _process_and_flatten_batch_output(results, batch_df)

            # Persist after every batch so an interrupted run can resume.
            checkpoint(batch_results_df, checkpoint_dir, checkpoint_filename)

        except Exception as e:
            print(f"Critical Error in SQuAD Batch {i+1}. Halting distillation. Error: {e}")
            break

    try:
        final_df = pd.read_csv(f"{checkpoint_dir}/{checkpoint_filename}")
        final_df = final_df.drop_duplicates(subset=['claim']).reset_index(drop=True)
        print(f"SQuAD Distillation finished. Total final processed samples: {len(final_df)}")
        return final_df
    except FileNotFoundError:
        print("Final checkpoint file not found. Returning empty DataFrame.")
        return pd.DataFrame()

### Execution

In [None]:
# SQuAD checkpoint location on the mounted Drive (rebinds the FEVER values above).
CHECKPOINT_DIR = '/content/drive/MyDrive/checkpoints/FactGuard/'
CHECKPOINT_FILENAME = 'squad_distilled_data.csv'
# Full path of the same file. The checkpoint helpers take dir + filename
# instead, and this constant is not read by the cells shown.
CHECKPOINT_PATH_FULL = CHECKPOINT_DIR + CHECKPOINT_FILENAME

In [None]:
# Load SQuAD. Only the train split is distilled below; eval_split is not used
# by the cells shown here.
squad_dataset = load_dataset("squad")
train_split = squad_dataset['train']
eval_split = squad_dataset['validation']

In [None]:
# Distil the SQuAD train split (two claims per question). Results are appended
# to the checkpoint CSV after every batch, so re-running this cell resumes
# rather than restarts.
distilled_dataset = distill_squad_to_dataset(
    train_split,
    checkpoint_dir = CHECKPOINT_DIR,
    checkpoint_filename = CHECKPOINT_FILENAME)

### Push to HF

In [None]:
# Push the distilled SQuAD checkpoint to the Hub as a single 'train' split.
squad_checkpoint_df = load_checkpoint(CHECKPOINT_DIR, CHECKPOINT_FILENAME)
if squad_checkpoint_df is None:
    raise FileNotFoundError(
        f"No SQuAD checkpoint at {CHECKPOINT_DIR}/{CHECKPOINT_FILENAME}; run the distillation cell first."
    )
# The checkpoint is append-only and may repeat claims. Apply the same
# claim-level de-duplication distill_squad_to_dataset uses for its final result.
# reset_index stops Dataset.from_pandas from adding an index column.
squad_checkpoint_df = squad_checkpoint_df.drop_duplicates(subset=['claim']).reset_index(drop=True)
# Distinct names so the HF `squad_dataset` loaded above is not overwritten by a DataFrame.
squad_distilled_dataset = Dataset.from_pandas(squad_checkpoint_df)
squad_dataset_dict = DatasetDict({
    "train": squad_distilled_dataset
})
squad_dataset_dict.push_to_hub("factguard_squad_distilled_datasets")