In [None]:
from groq import Client
import os

# Read the Groq key from the environment rather than hard-coding a secret in the
# notebook. Falls back to "" (the previous hard-coded value) when the variable is unset.
api_key = os.environ.get("GROQ_API_KEY", "")
client = Client(api_key=api_key)

In [None]:
def prompt_template(content):
    """Wrap *content* as a one-message chat history sent with the ``user`` role."""
    user_message = {"role": "user", "content": content}
    return [user_message]

In [None]:
from faithfulness import faithfulness_prompt
from hallucination import hallucination_prompt
from answer_relevancy import answer_relevancy_prompt
from ratelimit import limits, sleep_and_retry

RATE_LIMIT = 30  # maximum number of judge requests ...
TIME_PERIOD = 180  # ... allowed within this many seconds

# Settings shared by every LLM-as-judge metric below.
JUDGE_MODEL = "llama3-70b-8192"
JUDGE_TEMPERATURE = 0.2


@sleep_and_retry
@limits(calls=RATE_LIMIT, period=TIME_PERIOD)
def api_call(client, message, model, **kwargs):
    """Send one chat-completion request, throttled to RATE_LIMIT calls per TIME_PERIOD seconds.

    When the budget is exhausted, ``sleep_and_retry`` blocks until a call is allowed
    again instead of raising.

    Args:
        client: Object exposing ``chat.completions.create`` (the groq ``Client``).
        message: Chat history, e.g. the list returned by ``prompt_template``.
        model: Model identifier passed through to the API.
        **kwargs: Extra request options (e.g. ``temperature``) forwarded unchanged.

    Returns:
        The text content of the first returned choice.
    """
    response = client.chat.completions.create(messages=message, model=model, **kwargs)
    return response.choices[0].message.content


def _judge(judge, prompt, model, temperature):
    """Send one user *prompt* to the judge model through the rate-limited ``api_call``."""
    # Routing every judge request through api_call makes all metrics share one rate-limit budget.
    return api_call(judge, prompt_template(prompt), model, temperature=temperature)


def faithfulness(judge, query, context, output, *, model=JUDGE_MODEL, temperature=JUDGE_TEMPERATURE):
    """Return the judge's raw text verdict on the faithfulness of *output*.

    Args:
        judge: Chat client used as the judge.
        query: The user question.
        context: Retrieved documents the answer should be grounded in.
        output: The answer(s) to grade.
        model: Judge model identifier.
        temperature: Sampling temperature for the judge.
    """
    return _judge(judge, faithfulness_prompt(context, query, output), model, temperature)


def hallucination(judge, query, context, output, *, model=JUDGE_MODEL, temperature=JUDGE_TEMPERATURE):
    """Return the judge's raw text verdict on hallucination in *output*.

    Arguments are the same as for ``faithfulness``.
    """
    return _judge(judge, hallucination_prompt(context, query, output), model, temperature)


def answer_relevancy(judge, query, output, *, model=JUDGE_MODEL, temperature=JUDGE_TEMPERATURE):
    """Return the judge's raw text verdict on how relevant *output* is to *query* (no context used)."""
    return _judge(judge, answer_relevancy_prompt(query, output), model, temperature)


def completeness(judge, query, context, output, *, model=JUDGE_MODEL, temperature=JUDGE_TEMPERATURE):
    """Return the judge's raw text verdict on the completeness of *output*.

    Arguments are the same as for ``faithfulness``.
    """
    # NOTE(review): ``completeness_prompt`` is not imported next to the other prompt
    # builders (faithfulness/hallucination/answer_relevancy); unless it is defined
    # elsewhere, calling this raises NameError — confirm and add the import.
    return _judge(judge, completeness_prompt(context, query, output), model, temperature)

In [None]:
from datasets import load_dataset

# Load the "tatqa" configuration of the RAGBench dataset.
dataset = load_dataset(path="rungalileo/ragbench", name="tatqa")

In [None]:
import pandas as pd
# Materialize the training split as a pandas DataFrame for group-by processing.
df = pd.DataFrame(data=dataset["train"])

### `tokenSHAP` algorithm

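The `tokenSHAP` function estimates how much each word of a prompt contributes to the model's representation of the input. It uses a leave-one-word-out approximation inspired by SHAP (SHapley Additive exPlanations) and compares mean input embeddings rather than model outputs. Below is a mathematical breakdown of the process: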

**Input Representation**:
    - The input consists of a `prompt` and a `context`. These are combined into a single string, with a newline between the context and the prompt label:
      $$
      \text{full\_input} = \text{"Context: "} + \text{context} + \langle\text{newline}\rangle + \text{"Prompt: "} + \text{prompt}
      $$
    - The input is tokenized into `input_ids` using a tokenizer:
      $$
      \text{input\_ids} = \text{tokenizer}(\text{full\_input})
      $$

**Embedding Extraction**:
    - The embeddings for the input tokens are extracted from the model:
      $$
      \mathbf{E} = \text{model.get\_input\_embeddings()}(\text{input\_ids})
      $$
    - The baseline embedding, $\mathbf{b}$, is computed as the mean of all token embeddings of the full input:
      $$
      \mathbf{b} = \frac{1}{m} \sum_{i=1}^{m} \mathbf{E}_i
      $$
      where $m$ is the number of tokens in the full input (context and prompt together), not only in the prompt.

**Token Removal and Cosine Similarity**:
    - The prompt is split on single spaces into words. For each word $i$, a modified prompt is created by removing the $i$-th word, and it is recombined with the unchanged context:
      $$
      \text{modified\_prompt} = \text{prompt without word } i
      $$
    - The embeddings for the modified prompt are computed:
      $$
      \mathbf{E}_{\text{modified}} = \text{model.get\_input\_embeddings()}(\text{modified\_input\_ids})
      $$
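    - The cosine similarity between the mean embedding of the modified input, $\bar{\mathbf{E}}_{\text{modified}}$, and the baseline embedding is calculated:
      $$
      \text{cosine\_similarity} = \frac{\bar{\mathbf{E}}_{\text{modified}} \cdot \mathbf{b}}{\|\bar{\mathbf{E}}_{\text{modified}}\| \, \|\mathbf{b}\|}, \qquad \bar{\mathbf{E}}_{\text{modified}} = \frac{1}{m'} \sum_{k=1}^{m'} \mathbf{E}_{\text{modified},k}
      $$
      where $m'$ is the number of tokens in the modified input.
    - The similarity score is stored in `withouts` for the removed word and appended to `withes` for every remaining word.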

**SHAP Value Calculation**:
    - For each word $i$, the SHAP value is the difference between the average similarity over the runs in which the word was kept (`withes`) and the similarity of the run in which it was removed (`withouts`):
      $$
      \text{SHAP}_i = \frac{\sum \text{withes}[i]}{\text{len(withes}[i])} - \text{withouts}[i]
      $$

**Normalization**:
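    - The SHAP values are scaled to unit Euclidean (L2) norm, so their squares (not the values themselves) sum to 1. The normalization factor is computed as:
      $$
      \text{norm\_factor} = \sqrt{\sum_{i=1}^n \text{SHAP}_i^2}
      $$
      where $n$ is the number of words in the prompt.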
    - Each SHAP value is divided by the normalization factor:
      $$
      \text{SHAP}_i = \frac{\text{SHAP}_i}{\text{norm\_factor}}
      $$

    - The function returns a dictionary mapping each distinct word of the prompt to its normalized SHAP value. A word that appears more than once gets a single entry:
      $$
      \text{output} = \{ \text{word}_i : \text{SHAP}_i \}
      $$


In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
from collections import defaultdict

# (tokenizer, model) pairs keyed by model name, so repeated tokenSHAP calls do not
# reload the checkpoint from disk each time.
_TOKENSHAP_CACHE = {}


def _load_tokenshap_model(model_name):
    """Return the cached ``(tokenizer, model)`` pair for *model_name*, loading it on first use."""
    if model_name not in _TOKENSHAP_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Only the input-embedding layer of this model is used below.
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        _TOKENSHAP_CACHE[model_name] = (tokenizer, model)
    return _TOKENSHAP_CACHE[model_name]


def _mean_input_embedding(tokenizer, model, text):
    """Return the mean input-embedding vector of *text* as a 1-D float32 NumPy array."""
    embedding_layer = model.get_input_embeddings()
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    # The (possibly quantized) model may live on a GPU; keep ids on the same device.
    input_ids = input_ids.to(embedding_layer.weight.device)
    with torch.no_grad():
        embeddings = embedding_layer(input_ids).squeeze(0)
    # Upcast before averaging and move to CPU: .numpy() fails on CUDA tensors.
    return embeddings.float().mean(dim=0).cpu().numpy()


def tokenSHAP(prompt, context, model_name="PrunaAI/SweatyCrayfish-llama-3-8b-quantized-bnb-4bit-smashed"):
    """
    Estimate per-word importance of *prompt* with a leave-one-word-out SHAP approximation.

    The baseline is the mean input embedding of ``"Context: {context}\\nPrompt: {prompt}"``.
    For each space-separated word, the word is removed and the cosine similarity between
    the new mean embedding and the baseline is recorded. A word's score is the average
    similarity over the runs that kept it minus the similarity of the run that removed it.
    Scores are finally scaled to unit L2 norm.

    Args:
        prompt (str): The initial query.
        context (str): The context for the query.
        model_name (str, optional): Hugging Face model id whose input embeddings are used;
            defaults to a 4-bit quantized Llama 3 8B checkpoint. Loaded once and cached.

    Returns:
        dict: Maps each word of the prompt (split on single spaces) to its normalized SHAP
        value. Repeated words share one key; for them ``withouts`` keeps the similarity of
        the last occurrence's removal. If every raw score is 0, the zeros are returned
        unnormalized instead of dividing by zero.
    """
    tokenizer, model = _load_tokenshap_model(model_name)

    # Baseline: average embedding of the full (context + prompt) input.
    baseline = _mean_input_embedding(tokenizer, model, f"Context: {context}\nPrompt: {prompt}").reshape(1, -1)

    words = prompt.split(" ")
    withouts, withes = {}, defaultdict(list)

    for i, word in enumerate(words):
        modified_prompt = " ".join(words[:i] + words[i + 1:])
        modified_mean = _mean_input_embedding(
            tokenizer, model, f"Context: {context}\nPrompt: {modified_prompt}"
        )
        cs = cosine_similarity(modified_mean.reshape(1, -1), baseline)[0][0]
        withouts[word] = cs
        # Every other word was still present in this run.
        for j, other in enumerate(words):
            if j != i:
                withes[other].append(cs)

    shaps = {}
    for word in words:
        kept = withes[word]
        withs_avg = sum(kept) / len(kept) if kept else 0
        shaps[word] = withs_avg - withouts[word]

    norm_factor = np.sqrt(np.sum(np.array(list(shaps.values())) ** 2))
    if norm_factor > 0:  # avoid 0/0 -> NaN when all scores are zero
        for word in shaps:
            shaps[word] /= norm_factor

    return shaps

In [None]:
import json
from tqdm import tqdm
import numpy as np
import shutil

STEP = 25  # write a checkpoint every STEP groups
BATCHES = 4
# Only groups whose position i in df.groupby("id") satisfies FIRST_GROUP < i <= LAST_GROUP are judged.
FIRST_GROUP = 51
LAST_GROUP = 450


def _parse_faithfulness(raw_json, n_answers):
    """Return one faithfulness score per answer (None when missing) from the judge's JSON reply.

    Raises:
        json.JSONDecodeError: if *raw_json* is not valid JSON.
    """
    parsed = json.loads(raw_json)
    scores = []
    for j in range(n_answers):
        entry = parsed.get(f"answer_{j}", {})
        # A malformed (non-object) entry counts as missing instead of crashing the run.
        scores.append(entry.get("faithfulness", None) if isinstance(entry, dict) else None)
    return scores


def _results_frame(batch_data):
    """Build the per-response results DataFrame (columns id/question/documents/faithfulness)."""
    return pd.DataFrame({
        "id": batch_data["ids"],
        "question": batch_data["questions"],
        "documents": batch_data["documents"],
        "faithfulness": batch_data["faithfulness"],
    })


def _shap_frame(shap_data):
    """Pivot long-format SHAP rows into one row per (id, answer_idx) and one column per token."""
    return pd.DataFrame(shap_data).pivot_table(
        index=["id", "answer_idx"],
        columns="token",
        values="shap_value",
        aggfunc="first",
    ).reset_index()


# tokenSHAP(question, context) does not depend on the answer index or the batch,
# so compute it at most once per group id.
shap_cache = {}

for batch_num in range(BATCHES):
    # NOTE(review): every batch walks the same FIRST_GROUP..LAST_GROUP window, so the judge
    # is queried BATCHES times for the same groups (only the file suffix differs) — confirm
    # this repetition is intended.
    batch_data = {"ids": [], "questions": [], "documents": [], "faithfulness": []}
    shap_data = {"id": [], "answer_idx": [], "token": [], "shap_value": []}
    rows_processed = 0  # responses handled in this batch, across all checkpoints

    for i, (_, group) in enumerate(tqdm(df.groupby("id"))):
        if not (FIRST_GROUP < i <= LAST_GROUP):
            continue

        group_id = group["id"].iloc[0]
        question = group["question"].iloc[0]
        context = group["documents"].iloc[0]
        responses = group["response"].tolist()
        n_answers = len(group)

        # Ask the LLM judge for a faithfulness score per response.
        raw_json = faithfulness(client, question, context, responses)
        # Keep only the outermost {...} span; the judge may wrap its JSON in prose.
        raw_json = raw_json[raw_json.find("{"): raw_json.rfind("}") + 1]

        try:
            faithfulnesses = _parse_faithfulness(raw_json, n_answers)
        except json.JSONDecodeError:
            print(f"Error decoding JSON for group {group_id}: {raw_json}")
            faithfulnesses = [None] * len(responses)
        else:
            # Attribute each unfaithful (score == 0) response to the question's words.
            for j, score in enumerate(faithfulnesses):
                if score != 0:
                    continue
                if group_id not in shap_cache:
                    shap_cache[group_id] = tokenSHAP(question, context)
                for token, value in shap_cache[group_id].items():
                    shap_data["id"].append(group_id)
                    shap_data["answer_idx"].append(j)
                    shap_data["token"].append(token)
                    shap_data["shap_value"].append(value)

        # One row per response; list-valued documents are stored as JSON strings.
        document_value = json.dumps(context) if isinstance(context, list) else context
        batch_data["ids"].extend([group_id] * n_answers)
        batch_data["questions"].extend([question] * n_answers)
        batch_data["documents"].extend([document_value] * n_answers)
        batch_data["faithfulness"].extend(faithfulnesses)
        rows_processed += n_answers

        # Periodic checkpoint.
        if i % STEP == 0 and i > 0:
            print(f"Processed {i} groups. Current batch size: {len(batch_data['ids'])}")
            chunk = (i // STEP) - 1
            _results_frame(batch_data).to_parquet(f"hallucination_{chunk}_{batch_num}.parquet")
            # shap_data is not reset here, so each shap_results file is a cumulative
            # snapshot of the batch so far; skip the pivot when there is nothing to save.
            if shap_data["id"]:
                _shap_frame(shap_data).to_parquet(f"shap_results_{chunk}_{batch_num}.parquet")
            # Results rows are written in disjoint chunks: reset after each checkpoint.
            batch_data = {k: [] for k in batch_data}

    # Final save for rows collected since the last checkpoint.
    if batch_data["ids"]:
        _results_frame(batch_data).to_parquet(f"hallucination_final_{batch_num}.parquet")

    if shap_data["id"]:
        _shap_frame(shap_data).to_parquet(f"shap_results_batch_{batch_num}.parquet")

    print(f"Batch {batch_num} completed. Total processed: {rows_processed}")
