# Launching LLM

# Initialize

In [None]:
# @title Download Dependency

# Mount Google Drive at /content/drive; the notebook's caches, datasets, prompts
# and results live under /content/drive/MyDrive/llmserver/.
from google.colab import drive
drive.mount('/content/drive')

# Install runtime dependencies (IPython shell magics; runs only in a notebook).
# NOTE(review): pyarrow is pinned to 15.0.0, presumably for compatibility with
# the parquet files read/written below — confirm before unpinning.
!pip install pyarrow==15.0.0
!pip install xformers -q
# NOTE(review): the Hugging Face CLI normally ships with `huggingface_hub`
# (e.g. `pip install huggingface_hub[cli]`); confirm the separate
# `huggingface-cli` package is really what is intended here.
!pip install huggingface-cli -q
!pip install accelerate -q
# bitsandbytes provides the 4-bit quantization used when loading the judge model.
!pip install -i https://pypi.org/simple/ bitsandbytes -q
!pip install --upgrade transformers -q
# NOTE(review): trl is not imported anywhere in this notebook as shown.
!pip install --upgrade trl -q
!pip install sqlitedict -q
!pip install fastparquet -q
!pip install scikit-learn -q
!pip install sentence_transformers -q
!pip install openai -q
!pip install rouge-score -q


In [None]:
# @title Login to Huggingface
# Interactive Hugging Face token prompt. Presumably required to download the
# gated meta-llama/* checkpoints used as judges below — confirm.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# @title Global Variables

import os
from sqlitedict import SqliteDict
import hashlib
import json

task = "iclr23" # @param ["iclr23", "peer_grading"]
# Root of all persistent notebook state on the mounted Google Drive.
working_dir = "/content/drive/MyDrive/llmserver/"
# Derive the per-task directory from working_dir so the two can never drift apart.
task_dir = working_dir + f"task_{task}/"
llmgen_cache_dir = working_dir + "cache_llmgen/"
logprobs_cache_dir = working_dir + "cache_logprobs/"
hfmodels_cache_dir = working_dir + "cache_hfmodels/"
finetune_cache_dir = working_dir + "cache_finetune/"
dataset_dir = task_dir + "dataset/"
result_dir = task_dir + "result/"
prompt_dir = task_dir + "prompt/"

# 3 for the "iclr23" task, 9 for any other task.
max_review = 3 if task == "iclr23" else 9

current_directory = os.getcwd()
print(f"current_directory: {current_directory}")
print(f"task_directory: {task_dir}")

# Load every prompt template: <prompt_dir>/<name>.txt becomes prompts["<name>"].
prompts = {}
for filename in os.listdir(prompt_dir):
    if filename.endswith('.txt'):
        with open(prompt_dir + filename, 'r', encoding='utf-8') as f:
            prompts[filename[:-4]] = f.read()

# Cache key function for caching the output
def generate_cache_key(params):
    '''
    Return a deterministic cache key for a request.

    The parameters are serialised to JSON with sorted keys, so dicts that differ
    only in key order map to the same key, and the UTF-8 bytes are hashed with
    SHA-256.

    params: dict, JSON-serialisable request parameters.
    Returns: str, 64-character hexadecimal SHA-256 digest.
    '''
    canonical_bytes = json.dumps(params, sort_keys=True).encode('utf-8')
    digest = hashlib.sha256(canonical_bytes)
    return digest.hexdigest()

# @markdown - Select a debug mode (0 for no log, 1 for light log, 2 for detailed log)
# When debug_mode > 1, get_logprob prints the rendered chat prompt and the per-token log-probs.
debug_mode = 1 # @param [0, 1, 2]

In [None]:
# @title Define the LLM Wrapper
import logging
import os
import os.path as osp
import torch
import torch.cuda
import torch.backends.cudnn
import argparse
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
)
from typing import List, Dict
import time

# Number CUDA devices by PCI bus order and expose only device 0 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

logger = logging.getLogger(__name__)

class Wrapper:
    '''
    Load a Hugging Face causal LM and its tokenizer for scoring/generation.

    After construction the instance exposes ``tokenizer``, ``model`` (in eval
    mode), ``pipeline`` (a "text-generation" pipeline over the same model and
    tokenizer), the flags ``use_cuda`` / ``no_cuda``, ``debug_mode`` and
    ``is_chat_model`` (stored as given), a ``usage`` token counter, and several
    tool slots initialised to ``None``.

    model_name: str, Hub id or local path passed to ``from_pretrained``.
    is_chat_model: stored as-is.
    debug_mode: stored as-is.
    load_4bit: bool, value of ``load_in_4bit`` in the bitsandbytes config
        (NF4, bfloat16 compute); the config is only built when CUDA is used.
    use_local_model: bool, if True weights/tokenizer are cached under
        ``hfmodels_cache_dir``; otherwise the default Hugging Face cache is used.
    '''
    def __init__(
        self,
        model_name,
        is_chat_model,
        debug_mode=False,
        load_4bit=False,
        use_local_model = False
    ):
        wrapper_cache_dir = hfmodels_cache_dir if use_local_model else None
        self.model_name = model_name
        # CUDA is disabled either explicitly (CUDA_VISIBLE_DEVICES set to "") or
        # because torch sees no GPU. An unset variable means "all GPUs visible",
        # so it must not count as disabled; falling back to CPU avoids a crash
        # inside bitsandbytes on CPU-only runtimes.
        self.no_cuda = (
            os.environ.get("CUDA_VISIBLE_DEVICES") == ""
            or not torch.cuda.is_available()
        )
        self.use_cuda = not self.no_cuda
        self.calculator = None
        self.wiki_tool = None
        self.wiki_error_max = None
        self.search_cache_start = None
        self.search_cache_save = None
        self.usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }
        START_TIME = time.perf_counter()
        print("Start loading {}...".format(model_name))
        if self.use_cuda:
            from transformers import BitsAndBytesConfig
            # NOTE(review): this config is passed even when load_4bit is False.
            # In some transformers versions a BitsAndBytesConfig with neither
            # load_in_4bit nor load_in_8bit set may still be routed to the 4-bit
            # quantizer — confirm. Passing None instead would also require the
            # callers to move inputs to model.device themselves.
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=load_4bit,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        else:
            quantization_config = None
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=wrapper_cache_dir
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            cache_dir=wrapper_cache_dir,
            device_map="auto" if self.use_cuda else None,
            quantization_config=quantization_config,
        )
        self.model.eval()
        print("Done with {:.2f} seconds.".format(time.perf_counter() - START_TIME))
        self.debug_mode = debug_mode
        self.is_chat_model = is_chat_model
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

# Loading LLM Wrapper

In [None]:
# @title Load wrapper from disk {"form-width":"25%"}

judge = "llama3.1-8b-finetune-new" # @param ["llama3-8b", "llama3.1-8b", "llama2-7b", "llama3.1-8b-finetune-new"]
load_in_4bit = True # @param {type:"boolean"}
use_local_model = False # @param {type:"boolean"}


# Form aliases -> Hugging Face Hub ids, or a path under finetune_cache_dir for
# the fine-tuned judge. A value not listed here is used verbatim as the model name.
judge_model_names = {
    "llama3-8b": "meta-llama/Meta-Llama-3-8B-Instruct",
    "llama3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "llama2-7b": "meta-llama/Llama-2-7b-chat-hf",
    "llama3.1-8b-finetune-new": finetune_cache_dir + "Meta-Llama-3.1-8B-bnb-4bit-finetune-peer-review-judge",
}
judge = judge_model_names.get(judge, judge)

generator = Wrapper(
    judge,
    is_chat_model=True,
    load_4bit=load_in_4bit,
    use_local_model=use_local_model,
)

# Logprobs (GEM / GEMS / BARTscore)

In [None]:
# @title Define the function for logprobs

# Persistent cache for query_logprob results: one SQLite file per judge model
# (suffixed "-4bit" when quantized), with '/' in the name replaced by '--'.
logprobs_db = SqliteDict(f"{logprobs_cache_dir}logprobs_{(judge + ('-4bit' if load_in_4bit else '')).replace('/', '--')}.sqlite", autocommit=True)

def find_assistant_indexes_llama3(li_token):
    '''
    Return the token positions of the first assistant turn in a Llama-3 chat.

    The span starts right after the first
    ``<|start_header_id|> assistant <|end_header_id|>`` header and runs up to
    and including the first ``<|eot_id|>`` after it, so the end-of-turn token
    is part of the span.

    li_token: list of str, tokens (e.g. from ``convert_ids_to_tokens``).
    Returns: list of int, sorted ascending.
    Raises: ValueError (a subclass of Exception) if the assistant header or
        its closing ``<|eot_id|>`` is missing.
    '''
    header = ["<|start_header_id|>", "assistant", "<|end_header_id|>"]
    width = len(header)
    start = next(
        (i + width for i in range(len(li_token) - width + 1)
         if li_token[i:i + width] == header),
        None,
    )
    if start is None:
        raise ValueError("Not found: assistant header.")
    # Search for the end only after a header was found; the end is exclusive,
    # hence i + 1 so that <|eot_id|> itself is included.
    end = next(
        (i + 1 for i in range(start, len(li_token)) if li_token[i] == "<|eot_id|>"),
        None,
    )
    if end is None:
        raise ValueError("Not found: <|eot_id|> after the assistant header.")
    return list(range(start, end))

def get_logprob(generator, messages):
    '''
    Return the summed log-probability the judge assigns to the assistant turn.

    The whole conversation is rendered with the tokenizer's chat template and
    scored in one teacher-forced forward pass. The sum covers the positions
    returned by ``find_assistant_indexes_llama3``.

    generator: object with ``tokenizer`` and ``model`` attributes (a Wrapper).
    messages: list of {"role", "content"} dicts ending with the assistant turn.
    Returns: float, sum of per-token natural-log probabilities.
    '''
    if debug_mode > 1:
        print(generator.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False))
    # Tokenize through the chat template directly: tokenizing the rendered
    # string would prepend a second BOS token.
    tokenized_msg = generator.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_tensors="pt")
    input_ids = tokenized_msg[0]
    tokenized_tokens = generator.tokenizer.convert_ids_to_tokens(input_ids)

    # Pure inference: skip building an autograd graph, and feed the ids on the
    # model's device.
    with torch.no_grad():
        logits = generator.model(tokenized_msg.to(generator.model.device)).logits[0]
    # log_softmax stays finite where log(softmax(.)) underflows to -inf.
    log_probs = torch.log_softmax(logits.cpu().float(), dim=-1)
    # The logits at position t predict token t+1.
    next_ids = input_ids[1:]
    token_log_probs = log_probs[torch.arange(len(next_ids)), next_ids]
    # The first token has no prediction; give it log-prob 0.
    logprobs = [0.] + token_log_probs.tolist()

    assistant_indexes = find_assistant_indexes_llama3(tokenized_tokens)
    if debug_mode > 1:
        for i in assistant_indexes:
            print(tokenized_tokens[i], logprobs[i])

    return sum(logprobs[i] for i in assistant_indexes)


In [None]:
# @title Query logprobs "log Pr[target | source, synopsis]"

def query_logprob(source, target, synopsis = "Not available", use_cache = True):
    '''
    Cached log Pr[target | source, synopsis] under the judge model.

    The user turn is the "gen_review_user" prompt filled with ``synopsis`` and
    ``source``. ``target`` is scored as the assistant reply, using
    "gen_review_system" as the system prompt.

    use_cache: bool, if False always recompute (the result is still stored).
    Returns: float, the value computed by get_logprob.
    '''
    assert(synopsis is not None)

    system_text = prompts["gen_review_system"]
    user_text = prompts["gen_review_user"].format(synopsis, source)
    messages = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": target},
    ]

    # The key depends on the judge model and quantization, plus the exact conversation.
    key = generate_cache_key({
        "query": "query_logprob",
        "model": judge,
        "load_in_4bit": load_in_4bit,
        "messages": messages,
    })

    cached = use_cache and (key in logprobs_db)
    if not cached:
        logprobs_db[key] = get_logprob(generator, messages)
    return logprobs_db[key]


# Sentence Embedding (BERTscore)

In [None]:
# @title Define the function for sentence embedding

# Embedding model: dunzhang/stella_en_400M_v5 (see its Hugging Face model card for usage details).

# This model supports two prompts: "s2p_query" and "s2s_query" for sentence-to-passage and sentence-to-sentence tasks, respectively.

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import re

# Sentence encoder for the BERTScore-style metric. It is moved to the GPU with
# .cuda(), so a CUDA device is required. trust_remote_code=True runs code from the model repo.
sentence_embedding_model = SentenceTransformer("dunzhang/stella_en_400M_v5", trust_remote_code=True).cuda()
# Short model name used in cache keys and the cache filename.
bertscore_model = "stella_en_400M_v5"
bertscore_db = SqliteDict(f"{logprobs_cache_dir}bertscore_{(bertscore_model).replace('/', '--')}.sqlite", autocommit=True)

def get_bertscore(source, target, use_cache = True):
    '''
    Sentence-level BERTScore-style F1 between two lists of sentences.

    Source sentences are embedded with the "s2s_query" prompt and target
    sentences without a prompt; pairs are compared with the model's
    ``similarity`` function. Precision averages, over source sentences, the
    best match in ``target``. Recall averages, over target sentences, the best
    match in ``source``. The result is their harmonic mean, cached in ``bertscore_db``.

    source: list of str, candidate sentences.
    target: list of str, reference sentences.
    use_cache: bool, if False always recompute (the result is still stored).
    Returns: float. It is 0.0 when either list is empty or precision + recall == 0.
    '''
    assert(type(source) is list)
    assert(type(target) is list)

    # Nothing to align: an empty similarity matrix would make .max() fail.
    if not source or not target:
        return 0.0

    params = {
        "query": "query_bertscore",
        "model": bertscore_model,
        "source": source,
        "target": target
    }
    key = generate_cache_key(params)
    if use_cache and (key in bertscore_db): return bertscore_db[key]

    query_embeddings = sentence_embedding_model.encode(source, prompt_name="s2s_query")
    doc_embeddings = sentence_embedding_model.encode(target)

    # similarities has shape (len(source), len(target)).
    similarities = sentence_embedding_model.similarity(query_embeddings, doc_embeddings)
    precision = float(similarities.max(dim=1)[0].mean())
    recall = float(similarities.max(dim=0)[0].mean())
    denominator = precision + recall
    f1 = 2 * (precision * recall) / denominator if denominator != 0 else 0.0

    bertscore_db[key] = f1
    return f1

def split_sentences(text):
    '''
    Split text into sentences at whitespace that follows '.', '!' or '?'.

    Newlines and carriage returns are first turned into spaces. Each piece is
    stripped, and empty pieces are dropped.
    '''
    flattened = text.translate(str.maketrans({'\n': ' ', '\r': ' '}))
    pieces = (piece.strip() for piece in re.split(r'(?<=[.!?])\s+', flattened))
    return [piece for piece in pieces if piece]

def query_bertscore(source, target):
    '''
    BERTScore-style F1 between two texts, computed over their sentences.

    Both strings are split with split_sentences and passed to get_bertscore.
    '''
    for text in (source, target):
        assert(type(text) is str)
    source_sentences = split_sentences(source)
    target_sentences = split_sentences(target)
    return get_bertscore(source_sentences, target_sentences)


# LLM evaluation (LM-examiner)

In [None]:
# @title Define the function for lm-examiner

import openai
import threading
import re

# OpenRouter exposes an OpenAI-compatible endpoint. The key is read from the
# OPENROUTER_API_KEY environment variable so no secret has to be hard-coded in
# the notebook; the "API_KEY" placeholder remains the fallback.
openrouter_client = openai.OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ.get("OPENROUTER_API_KEY", "API_KEY"),
)
openrouter_chat = openrouter_client.chat.completions

cache_lock = threading.Lock()  # serialises the cache reads/writes in call_api

def call_api(compl, params, cache_db, use_cache = True):
    '''
    Call ``compl.create(**params)`` and memoize the response in ``cache_db``.

    compl: an object with a ``create`` method (e.g. the chat completions client).
    params: dict, request parameters; also the cache key source.
    cache_db: dict-like cache or None to disable caching.
    use_cache: bool, if False skip the cache lookup and always call the API
        (used for retries after an error response).
    Returns: the API response object.
    '''
    key = generate_cache_key(params)
    if use_cache and cache_db is not None:
        with cache_lock:
            if key in cache_db:
                return cache_db[key]

    response = compl.create(**params)

    # Only persist successful responses. A cached error would be replayed on
    # every later call, including retries.
    if cache_db is not None and getattr(response, "choices", None) is not None:
        with cache_lock:
            cache_db[key] = response
    return response


def simple_call_api(system_text, user_text, model, cache_db, max_tokens=4000, temperature=0):
    '''
    Send a single-turn chat request through OpenRouter and return the reply text.

    system_text: str, system prompt. It is omitted from the messages when empty.
    user_text: str, user message.
    model: str, OpenRouter model id.
    cache_db: dict-like response cache, or None.
    Returns: str, content of the first choice.
    Raises: Exception if the request still fails after one uncached retry.
    '''
    compl = openrouter_chat
    messages = [{"role": "user", "content": user_text}]
    if system_text != "":
        messages.insert(0, {"role": "system", "content": system_text})
    params = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "frequency_penalty": 0,
        "presence_penalty": 0
    }

    response = call_api(compl, params, cache_db)
    if getattr(response, "choices", None) is None:  # the provider returned an error payload
        # The error field may be missing or not a dict, and its code may not be
        # an int, so read it defensively instead of crashing while reporting.
        error = getattr(response, "error", None)
        error_code = error.get("code") if isinstance(error, dict) else error
        print("Error occurs (%s)" % (error_code,))
        print("Try again!")
        response = call_api(compl, params, cache_db, use_cache = False)
        if getattr(response, "choices", None) is None:
            print(response)
            raise Exception("Error while calling api")
        print("Success")
    return response.choices[0].message.content

# Judge model for the LMexaminer metric (an OpenRouter model id) and its
# response cache; '/' in the id is replaced by '--' in the filename.
lmexaminer_model = "openai/gpt-4o"
lmexaminer_db = SqliteDict(f"{logprobs_cache_dir}lmexaminer_{(lmexaminer_model).replace('/', '--')}.sqlite", autocommit=True)

def get_lmexaminer(source, ref):
    '''
    Ask the LMexaminer judge to evaluate review ``source`` against ``ref``.

    The user prompt is "eval_review_user" filled with (ref, source), and the
    system prompt is "eval_review_system".
    Returns: str, the judge's raw reply.
    '''
    assert(type(source) is str)
    assert(type(ref) is str)
    system_text = prompts["eval_review_system"]
    user_text = prompts["eval_review_user"].format(ref, source)
    return simple_call_api(system_text, user_text, lmexaminer_model, lmexaminer_db)

def query_lmexaminer(source, ref):
    '''
    Numeric LMexaminer score for review ``source`` given reference ``ref``.

    The first number followed by ``</overall_score>`` (after an
    "overall_score"/"Overall Score" label) is parsed from the judge's reply.
    If no score is found, a message and the source are printed and 5.0 is
    returned as a fallback.
    '''
    assert isinstance(source, str)
    assert isinstance(ref, str)

    reply = get_lmexaminer(source, ref)
    found = re.search(r"(?:overall_score|Overall Score)\W*\s*([-+]?\d*\.?\d+)\s*</overall_score>", reply)
    if found is None:
        print("Error: score not found")
        print(source)
        return 5.0
    return float(found.group(1))


# Longest Common Subsequence (ROUGE-L)

In [None]:
from rouge_score import rouge_scorer

rougel_db = SqliteDict(f"{logprobs_cache_dir}rougel.sqlite", autocommit=True)

# Build the ROUGE-L scorer once and reuse it, instead of re-creating it on every call.
_rougel_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

def query_rougel(source, target, use_cache = True):
    '''
    Cached ROUGE-L F-measure of candidate ``source`` against reference ``target``.

    use_cache: bool, if False always recompute (the result is still stored).
    Returns: float, the ROUGE-L F-measure.
    '''
    params = {
        "query": "query_rougel",
        "source": source,
        "target": target
    }
    key = generate_cache_key(params)
    if use_cache and (key in rougel_db): return rougel_db[key]

    # rouge_score's signature is score(target, prediction), with the reference first.
    # The F-measure is symmetric, so previously cached values remain valid.
    scores = _rougel_scorer.score(target, source)

    f_measure = scores['rougeL'].fmeasure  # precision / recall / fmeasure
    rougel_db[key] = f_measure
    return f_measure


# BLEU Score

In [None]:
import nltk
from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize

# word_tokenize needs the Punkt sentence models. Newer NLTK releases load them
# from the 'punkt_tab' resource instead of the pickled 'punkt' one, so fetch
# both to work across NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

def query_bleu(source, target):
    '''
    Sentence BLEU of hypothesis ``source`` against the single reference ``target``.

    Both texts are tokenized with NLTK's word_tokenize. sentence_bleu is called
    with its defaults (uniform 4-gram weights, no smoothing).
    '''
    reference_tokens = word_tokenize(target)
    hypothesis_tokens = word_tokenize(source)
    return sentence_bleu([reference_tokens], hypothesis_tokens)


# Judging Model

In [None]:
# Short aliases -> dataset identifiers. "human_reviewN" are the human reviews;
# the others are OpenRouter model ids. Each value, with '/' replaced by '--',
# locates dataset_<value>.parquet in dataset_dir.
# NOTE(review): the "mixtral-7b" key maps to a Mistral 7B model; the key looks
# like a misnomer but is kept because it is used as an identifier.
model_names = {
        "review1": "human_review1",
        "review2": "human_review2",
        "review3": "human_review3",
        "gpt4omini": "openai/gpt-4o-mini",
        "gpt4o": "openai/gpt-4o",
        "gpt4turbo": "openai/gpt-4-turbo",
        "gpt35turbo": "openai/gpt-3.5-turbo-0125",
        "claude2": "anthropic/claude-2",
        "claude3haiku": "anthropic/claude-3-haiku",
        "claude3sonnet": "anthropic/claude-3-sonnet",
        "claude3opus": "anthropic/claude-3-opus",
        "geminipro1": "google/gemini-pro",
        "geminipro15": "google/gemini-pro-1.5",
        "llama3-8b": "meta-llama/llama-3-8b-instruct",
        "llama3-70b": "meta-llama/llama-3-70b-instruct",
        "mixtral-7b": "mistralai/mistral-7b-instruct:nitro",
        "mixtral-8x7b": "mistralai/mixtral-8x7b-instruct",
        "mixtral-8x22b": "mistralai/mixtral-8x22b-instruct",
        "o1": "openai/o1-preview",
        "o1mini": "openai/o1-mini",
        "llama31-8b": "meta-llama/llama-3.1-8b-instruct",
        "llama31-70b": "meta-llama/llama-3.1-70b-instruct",
        "llama31-405b": "meta-llama/llama-3.1-405b-instruct",
        "mistral-large": "mistralai/mistral-large",
        "mistral-small": "mistralai/mistral-small",
        "mistral-tiny": "mistralai/mistral-tiny",
    }

In [None]:
# @title  {"form-width":"30%"}
import pandas as pd
from tqdm.notebook import tqdm
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
import numpy as np

# Candidate reviewers to evaluate (keys of model_names). To add LLM reviewers,
# uncomment any of the lines below; each line is independent of the others.
candidates = ["review2"]
# candidates += ["gpt4omini", "gpt4o", "gpt35turbo"]
# candidates += ["claude2", "claude3haiku", "claude3sonnet", "claude3opus"]
# candidates += ["geminipro1", "geminipro15"]
# candidates += ["llama3-8b", "llama3-70b"]
# candidates += ["llama31-8b", "llama31-70b"]
# candidates += ["mixtral-8x7b", "mixtral-8x22b"]
# candidates += ["mistral-large","mistral-small","mistral-tiny"]
# candidates += ["o1mini", "o1"]

# Perturbations of the candidate review to score ("non" = the unmodified review).
manipulations = ["non", "restate", "degrade", "meaningless_elongation", "llama3_restatement", "degradation_v2", "abstract_review"]

# manipulations = ["non"]

# Reference reviews: a single review, or the other human reviews. "allother"
# averages over them separately; "allother_flat" stacks them into one table.
ben = "allother_flat" # @param ["review1", "review2", "review3", "allother", "allother_flat"]

benchmark = model_names[ben] if "allother" not in ben.lower() else ben

# methods = ["BERTscore","LMexaminer","GEM","GEMS","BARTscore_F1","BARTscore_precision","BARTscore_recall","ROUGE-L","BLEU"] #

methods = ["GEM","GEMS"]

# Maximum number of papers taken from each dataset.
hard_prefix = 300 # @param {type:"integer"}

# NOTE(review): test_mode is not read anywhere in this notebook as shown.
test_mode = 0 # @param {type:"integer"}

save_to_parquet = True # @param {type:"boolean"}

# Per "method-judge": {candidate: [mean, standard error]} for the "non" manipulation.
res_dict = {}

run_list = []



# Queue every (candidate, method, manipulation) combination, with candidates
# varying slowest and manipulations fastest.
run_list.extend(
    (can, method, mani)
    for can in candidates
    for method in methods
    for mani in manipulations
)


# Loop-invariant inputs: the paper table, the manipulation -> dataset column
# mapping, and the display name of the log-prob judge.
df_paper = pd.read_parquet(f"{dataset_dir}dataset_paper.parquet",engine='fastparquet')
manipulation_names = {
    "non":"review",
    "elongate":"review_elongation",
    "restate":"review_restatement",
    "degrade":"review_degradation",
    "meaningless_elongation":"meaningless_elongation",
    "llama3_restatement":"llama3_restatement",
    "degradation_v2":"degradation_v2",
    "abstract_review":"abstract_review",
}
logprob_judge_name = judge + ('-4bit' if load_in_4bit else '')

for can, method, manipulation in run_list:
    # GEM / GEMS compare summaries ("*_summary" / "review_summary" columns).
    preprocess = method in ("GEM", "GEMS")
    results_file_name_prefix = f"{can}_{manipulation}_{ben}_"
    results_file_name = results_file_name_prefix + f"{method}{'-finetune' if 'finetune' in judge else ''}_{preprocess}_{hard_prefix}.parquet"

    candidate = model_names[can]
    candidate_col = manipulation_names[manipulation] + ("_summary" if preprocess else "")

    df_candidate = pd.read_parquet(f"{dataset_dir}dataset_{candidate.replace('/', '--')}.parquet",engine='fastparquet')
    df_result = pd.DataFrame(columns=["paper_id", "method", "judge", "candidate", "benchmark", "score"])

    if benchmark == "allother":
        # Score against each other human review separately, then average.
        all_benchmarks = []
        for i in [1, 3]:
            if f"human_review{i}" != candidate:
                all_benchmarks.append(pd.read_parquet(f"{dataset_dir}dataset_human_review{i}.parquet",engine='fastparquet'))
        assert(len(all_benchmarks)==2)
    elif benchmark == "allother_flat":
        # Stack the first hard_prefix rows of each other human review into one
        # table; row i then refers to paper i % hard_prefix. This assumes each
        # review table has at least hard_prefix rows (the paper_id check below
        # raises otherwise).
        all_benchmarks = []
        for i in [1, 3]: # review 1 and review 3 are peers
            if f"human_review{i}" != candidate:
                all_benchmarks.append(pd.read_parquet(f"{dataset_dir}dataset_human_review{i}.parquet",engine='fastparquet')[:hard_prefix])
        all_benchmarks = [pd.concat(all_benchmarks, ignore_index=True)]
    else:
        all_benchmarks = [pd.read_parquet(f"{dataset_dir}dataset_{benchmark.replace('/', '--')}.parquet",engine='fastparquet')]

    su = []

    if benchmark == "allother_flat":
        rang = len(all_benchmarks[0])
    else:
        rang = min(min(len(df_candidate),len(all_benchmarks[0])),hard_prefix)

    for i in tqdm(range(rang), desc=f"Processing papers ({can}) ({method}) ({manipulation})"):
        row_paper = df_paper.iloc[i % hard_prefix]
        row_candidate = df_candidate.iloc[i % hard_prefix]
        sum_score = 0.
        for df_benchmark in all_benchmarks:
            row_benchmark = df_benchmark.iloc[i]
            # The paper, candidate and benchmark rows must all refer to the same paper.
            if not (row_candidate["paper_id"] == row_paper["paper_id"] and row_benchmark["paper_id"] == row_paper["paper_id"]):
                raise Exception(row_candidate["paper_id"], row_paper["paper_id"], row_benchmark["paper_id"])

            target = row_benchmark["review_summary" if preprocess else "review"]
            source = row_candidate[candidate_col]

            if method == "GEM" or method == "GEMS":
                # GEMS conditions on the abstract; GEM uses no synopsis.
                synopsis = row_paper["abstract"] if method == "GEMS" else "Not available"
                score = query_logprob(source,target,synopsis) - query_logprob("Not available",target,synopsis)
                name_judge = logprob_judge_name
            elif method == "GEM-raw":
                # Truncate the target identically in both terms, so the
                # difference is taken over the same text.
                truncated_target = target[:10000]
                score = query_logprob(source[:10000],truncated_target) - query_logprob("Not available",truncated_target)
                name_judge = logprob_judge_name
            elif method == "BERTscore":
                score = query_bertscore(source, target)
                name_judge = "stella_en_400M_v5"
            elif method == "LMexaminer":
                paper = row_paper["parsed_text"]
                score = query_lmexaminer(source, paper)
                name_judge = "openai/gpt-4o"
            elif method == "ROUGE-L":
                score = query_rougel(source, target)
                name_judge = "non"
            elif method == "BLEU":
                score = query_bleu(source, target)
                name_judge = "non"
            elif method == "BARTscore_F1":
                score = (query_logprob(source[:10000],target[:10000]) + query_logprob(target[:10000],source[:10000])) / 2.0
                name_judge = logprob_judge_name
            elif method == "BARTscore_recall":
                score = query_logprob(source[:10000],target[:10000])
                name_judge = logprob_judge_name
            elif method == "BARTscore_precision":
                score = query_logprob(target[:10000],source[:10000])
                name_judge = logprob_judge_name
            else:
                raise ValueError(f"Unknown method: {method}")
            sum_score += score

        avg_score = sum_score / len(all_benchmarks)
        df_result.loc[len(df_result)] = [row_paper["paper_id"], method, name_judge, candidate, benchmark, avg_score]
        su.append(avg_score)
        sys.stdout.write("\r{:s} index = {:d}, score = {:.4g}, avg = {:.4g}".format(method, i, avg_score, np.mean(su)))

    if manipulation == "non":
        # Mean over papers and its standard error (sample std / sqrt(n)).
        res_dict.setdefault(f"{method}-{judge}", {})[candidate] = [np.mean(su), np.std(su, ddof=1) / np.sqrt(len(su))]
    if save_to_parquet:
        df_result.to_parquet(f"{result_dir}{results_file_name}", index=False)

# res_dict is only filled for the "non" manipulation, so test it directly.
# Checking the loop variable would reflect only the last run and skip saving
# whenever "non" was not the final manipulation.
if res_dict:
    with open(f"{result_dir}abstract_result{'_ft' if 'finetune' in judge else ''}.json", 'w', encoding='utf-8') as json_file:
        json.dump(res_dict, json_file)
