In [None]:
!pip install flashrag-dev --pre --upgrade
!pip install -U datasets
!pip install sentence-transformers==3.4.1
!pip install vllm==0.7.3

In [None]:
!pip install flashinfer-python==0.2.5
!pip install litellm==1.68.2
!pip install numpy==1.26.4

In [None]:
!python3 -m nltk.downloader stopwords
!python3 -m spacy download en_core_web_sm

### Pull code from GitHub

In [None]:
# Remove the pre-existing `sample_data` and `.config` entries (presumably the
# hosted runtime's defaults — confirm) before cloning: `git clone <url> .`
# refuses to clone into a non-empty directory.
# NOTE(review): this cell is not re-runnable — a second run fails because the
# working directory already contains the checkout.
!rm -rf sample_data .config
!git clone https://github.com/th-chew/cs605_project .

### You will need to create a Hugging Face account to download the Llama 3.2 models. After creating an account, generate a token to be used below.

In [None]:
!huggingface-cli login --token hf_4ZkXJYbqjWQdVZcXzYwzYwzYwzYwzYwzYwz

In [None]:
# Download the Llama 3.2 3B Instruct checkpoint into a local directory that
# mirrors the repository id (requires the Hugging Face login above).
!huggingface-cli download meta-llama/Llama-3.2-3B-Instruct --local-dir meta-llama/Llama-3.2-3B-Instruct

In [None]:
# Download the `loganchew/rbft_llama3` checkpoint into a local directory that
# mirrors the repository id (presumably the RbFT fine-tuned generator used by
# the 'rbft_llama' scenarios below — confirm against configs/eval.yaml).
!huggingface-cli download loganchew/rbft_llama3 --local-dir loganchew/rbft_llama3

### Restart the session after installing the packages so that the newly installed numpy version is the one that gets imported

In [None]:
import json, argparse, datasets
import numpy as np
from dataclasses import dataclass
from eval.prompt import NaivePromptTemplate, MCQPromptTemplate
from eval.metrics import metric_dict
from flashrag.config import Config
from flashrag.utils import get_generator
from sentence_transformers import CrossEncoder


def load_data(data_path, args):
    """Load the sample questions and attach (optionally attacked) passages.

    Reads ``sample.json`` under ``data_path`` as a ``datasets`` JSON dataset
    and ``posp.json`` as one JSON object per line (``qid`` -> ``pos_psgs``).
    Depending on ``args.passage_attack`` ('neg', 'nsy', 'cf' or 'mix'), the
    matching ``negp.json`` / ``nsyp.json`` / ``cfp.json`` files are read too.

    Every example receives a ``psgs`` field holding its first ``args.topk``
    positive passages. The positions chosen by ``args.attack_position`` and
    ``args.tau`` are then overwritten with the attack passage stored at the
    same index for that ``qid``.

    Args:
        data_path: Directory containing the JSON files listed above.
        args: Object providing ``topk``, ``tau``, ``attack_position``
            ('random', 'top' or 'bottom') and ``passage_attack``.

    Returns:
        The ``train`` split of the loaded dataset after the mapping.

    Raises:
        NotImplementedError: While mapping an example, if
            ``args.attack_position`` is not one of the supported values.
    """
    def read_passages(file_name, field):
        # One JSON object per line; keep only the requested field, keyed by qid.
        table = {}
        with open(data_path + file_name) as handle:
            for raw in handle:
                record = json.loads(raw)
                table[record['qid']] = record[field]
        return table

    def pick_attack():
        # 'mix' draws a fresh attack type for every replaced passage;
        # any other setting always uses the configured type.
        if args.passage_attack == 'mix':
            return np.random.choice(['neg', 'nsy', 'cf'])
        return args.passage_attack

    def load_psgs(item):
        qid = item['qid']
        passages = positive_psgs[qid][:args.topk]
        total = len(passages)

        # Decide which positions get attacked.
        if args.attack_position == 'random':
            # Each position is attacked independently when its draw is <= tau.
            draws = np.random.rand(total)
            targets = [pos for pos, draw in enumerate(draws) if draw <= args.tau]
        elif args.attack_position == 'top':
            # Only the leading round(tau * total) positions are attacked.
            cutoff = round(args.tau * total)
            targets = [pos for pos in range(total) if pos < cutoff]
        elif args.attack_position == 'bottom':
            # The leading round((1 - tau) * total) positions are left alone.
            kept = round((1 - args.tau) * total)
            targets = [pos for pos in range(total) if pos >= kept]
        else:
            raise NotImplementedError

        # Swap in the attack passage that sits at the same index.
        for pos in targets:
            attack = pick_attack()
            passages[pos] = psgs_dict[attack][qid][pos]

        item['psgs'] = passages
        return item

    dataset = datasets.load_dataset(
        'json',
        data_files=data_path + "/sample.json",
    )['train']

    positive_psgs = read_passages("/posp.json", 'pos_psgs')

    # Attack pools that are not requested stay empty.
    psgs_dict = {}
    for kind, file_name, field in (
        ("neg", "/negp.json", 'neg_psgs'),
        ("nsy", "/nsyp.json", 'nsy_psgs'),
        ("cf", "/cfp.json", 'cf_psgs'),
    ):
        requested = args.passage_attack == kind or args.passage_attack == 'mix'
        psgs_dict[kind] = read_passages(file_name, field) if requested else {}

    return dataset.map(load_psgs)


def generate(generator, prompt_template, query_dataset, rerank=False, top_k_rerank=5):
    """Build one multiple-choice prompt per query and return the predictions.

    Args:
        generator: Object exposing ``generate(prompts)``.
        prompt_template: Object exposing ``get_string(question=...,
            retrieval_result=..., multiple_choice=...)``.
        query_dataset: Mapping-like object with ``'query'``, ``'psgs'`` and
            ``'multiple_choice'`` columns of equal length.
        rerank: When true, each query's passages are scored with the
            ``cross-encoder/ms-marco-MiniLM-L6-v2`` CrossEncoder and sorted
            from highest to lowest score.
        top_k_rerank: Number of passages kept per query after reranking
            (ignored when ``rerank`` is false).

    Returns:
        Whatever ``generator.generate`` returns for the list of prompts.
    """
    if rerank:
        cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')
        passages_per_query = []
        for question, candidates in zip(query_dataset['query'], query_dataset['psgs']):
            relevance = cross_encoder.predict([[question, passage] for passage in candidates])
            # Candidate positions ordered from most to least relevant.
            order = sorted(range(len(relevance)), key=lambda pos: relevance[pos], reverse=True)
            passages_per_query.append([candidates[pos] for pos in order[:top_k_rerank]])
    else:
        passages_per_query = query_dataset['psgs']

    input_prompts = []
    rows = zip(query_dataset['query'], passages_per_query, query_dataset['multiple_choice'])
    for question, passages, choices in rows:
        input_prompts.append(
            prompt_template.get_string(
                question=question,
                retrieval_result=passages,
                # The choices are rendered one per line.
                multiple_choice="\n".join(choices),
            )
        )

    return generator.generate(input_prompts)


def eval(args):
    """Evaluate one RAG configuration on the 'truthful' subset and save the results.

    Builds a ``Config`` from ``args.config_file`` (overriding ``save_note``,
    ``gpu_id`` and ``generator_model``), seeds NumPy with ``config['seed']``,
    loads the data through ``load_data`` and keeps only the examples whose
    ``from`` field equals ``'truthful'``. Predictions come either from the
    RobustRAG keyword-aggregation path (``args.use_robustrag``) or from the
    configured generator with the MCQ prompt template. Each metric listed in
    ``config['metrics']`` is scored against the ``answer_alphabet`` column.

    The first line written to ``args.output_file`` is ``{"result": [...]}``;
    every following line holds one example's query, answer and prediction.

    Note: this function shadows the builtin ``eval``; the name is kept so
    existing callers continue to work.

    Args:
        args: Object providing ``config_file``, ``gpu_id``, ``generator_model``,
            ``data_path``, ``use_robustrag``, ``rerank``, ``topk_rerank``,
            ``output_file``, the fields read by ``load_data`` and, for the
            RobustRAG path, ``model_name``, ``alpha`` and ``beta``.
    """
    import os  # function-scope import, like the lazy RobustRAG imports below

    config_dict = {
        "save_note": "eval",
        "gpu_id": args.gpu_id,
        "generator_model": args.generator_model,
    }
    config = Config(args.config_file, config_dict)
    print(config)
    np.random.seed(config['seed'])
    dataset = load_data(args.data_path, args)

    dataset = dataset.filter(lambda example: example['from'] == 'truthful')

    scorers = [metric_dict[metric](config) for metric in config['metrics']]

    use_robustrag = args.use_robustrag
    rerank = args.rerank
    topk_rerank = args.topk_rerank

    if use_robustrag:
        # Imported lazily so the plain generator path does not need these modules.
        from src.models import create_model
        from src.defense import KeywordAgg

        preds = []
        print("Using RobustRAG")
        llm = create_model(args.model_name, cache_path=None)
        model = KeywordAgg(llm, relative_threshold=args.alpha, absolute_threshold=args.beta, longgen=False, certify_save_path='')

        if rerank:
            ranked_psgs = []
            reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')
            for query, psgs in zip(dataset['query'], dataset['psgs']):
                pairs = [[query, psg] for psg in psgs]
                scores = reranker.predict(pairs)
                # Passage positions ordered from highest to lowest score.
                top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
                ranked_psgs.append([psgs[i] for i in top_indices[:topk_rerank]])
            dataset_psgs = ranked_psgs
        else:
            dataset_psgs = dataset['psgs']

        for query, psgs, mcq_choice in zip(dataset['query'], dataset_psgs, dataset['multiple_choice']):
            query_hints = model.generate_query_hints(query, psgs)
            query_prompt = llm.wrap_mcq_prompt_flashrag(query, psgs, hints=query_hints, multiple_choice=mcq_choice)
            response = llm.query(query_prompt)
            # Keep only the first letter of the answer. Strip BEFORE slicing:
            # taking response[0] first turns a reply with leading whitespace
            # into an empty prediction.
            # Assumes llm.query returns a string — TODO confirm.
            preds.append(response.strip()[:1].lower() if response else "")

    else:
        generator = get_generator(config)
        prompt_template = MCQPromptTemplate(config)

        preds = generate(generator, prompt_template, dataset, rerank, topk_rerank)

    eval_results = [scorer.calculate_metric(preds, dataset['answer_alphabet'])[0] for scorer in scorers]

    # Create the output directory up front so a missing folder cannot discard
    # the predictions after the expensive generation step.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(args.output_file, 'w') as fw:
        fw.write(json.dumps({'result': eval_results}) + "\n")
        for i, (q, a, p) in enumerate(zip(dataset['query'], dataset['answer'], preds)):
            fw.write(json.dumps({i: {'query': q, 'answer': a, 'pred': p}}) + '\n')

@dataclass
class Args:
    """Settings consumed by ``load_data`` and ``eval`` for one evaluation run."""

    # Directory holding sample.json, posp.json and the attack passage files.
    data_path: str = "data/e5/test_data"
    # Forwarded to the config as "gpu_id".
    gpu_id: str = "0"
    # Number of positive passages kept per query before any attack.
    topk: int = 5
    # Whether passages are reranked with the CrossEncoder before prompting.
    rerank: bool = False
    # Number of passages kept per query after reranking.
    topk_rerank: int = 5
    # Attack strength: per-passage probability in 'random' mode, attacked
    # fraction of the passage list in 'top' / 'bottom' mode.
    tau: float = 0.0
    config_file: str = "configs/eval.yaml"
    output_file: str = "output/output.txt"
    # One of "random", "top", "bottom".
    attack_position: str = "random"
    # One of None, "neg", "nsy", "cf", "mix".
    passage_attack: str = None
    # Not read by the code visible in this notebook cell.
    prompt_template: str = None

    # RobustRAG parameters
    # Passed to create_model() on the RobustRAG path.
    model_name: str = 'llama3b'
    defense_method: str = 'keyword'
    # Relative and absolute thresholds handed to KeywordAgg.
    alpha: float = 0.3
    beta: float = 3.0
    eta: float = 0.0
    corruption_size: int = 1

    use_robustrag: bool = False

    # Generator key forwarded to the config as "generator_model". eval() reads
    # this attribute, so it needs a default; 'llama' is the value the driver
    # script uses for its baseline scenarios.
    generator_model: str = 'llama'



args = Args()
args.config_file = "configs/eval.yaml"
args.prompt_template = None

args.data_path = "data/e5/test"
args.gpu_id = "0"
args.topk = 10
# Number of passages kept after CrossEncoder reranking.
args.topk_rerank = 5
args.tau = 0.0

# passage_attack is one of None, "mix", "cf", "nsy", "neg":
#   "cf" = counterfactual, "nsy" = noisy, "neg" = irrelevant,
#   "mix" = a random pick among the three for every attacked passage.
# attack_position is one of "random", "top", "bottom".
#   Note that "random" with tau=0 means no attack at all.
# Randomness is seeded from the config file, so runs are deterministic.
args.attack_position = "random"
args.passage_attack = None

attack_params_str = f"{args.passage_attack}_{args.attack_position}_{args.tau}"

# RobustRAG model used unless a scenario overrides it (this is separate from
# the generator_model used on the non-RobustRAG path).
default_model_name = args.model_name

# (output tag, use_robustrag, rerank, generator_model, RobustRAG model_name)
scenarios = [
    # Scenario 0: vanilla RAG.
    ("vanilla_rag", False, False, "llama", default_model_name),
    # Scenario 1: vanilla RAG with CrossEncoder reranking.
    ("crossencoder_rag", False, True, "llama", default_model_name),
    # Scenario 2: RobustRAG with CrossEncoder reranking.
    ("crossencoder_robustrag", True, True, "llama", default_model_name),
    # Scenario 3: RbFt generator with CrossEncoder reranking.
    ("crossencoder_rbft", False, True, "rbft_llama", default_model_name),
    # Scenario 4: RobustRAG on top of the RbFt model with CrossEncoder reranking.
    ("crossencoder_robustrag_rbft", True, True, "rbft_llama", "rbft_llama"),
]

for tag, use_robustrag, rerank, generator_model, model_name in scenarios:
    args.use_robustrag = use_robustrag
    args.rerank = rerank
    args.generator_model = generator_model
    args.model_name = model_name
    # Every file name uses the same "<tag>_<attack params>" pattern; the last
    # scenario previously lacked the separating underscore.
    args.output_file = f"output/truthful_{tag}_{attack_params_str}.json"
    eval(args)