In [1]:
import os
from datasets import load_dataset
import torch
import json
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
)
from tqdm import tqdm
import numpy as np
import random
import argparse


In [2]:

def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Applies the same seed to PyTorch (CPU and CUDA seeding calls), NumPy's
    global generator and Python's ``random`` module, and switches cuDNN to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN benchmarking and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    torch.cuda.manual_seed_all(seed)


# Fix the seed once, up front, for the whole notebook run.
seed_everything(42)

In [3]:
# Load the model registries: checkpoint path and maximum prompt length
# (in tokens) per model name. `with` guarantees the file handles are closed.
with open("eval/LongBench/config/model2path.json", "r") as config_file:
    model2path = json.load(config_file)
with open("eval/LongBench/config/model2maxlen.json", "r") as config_file:
    model2maxlen = json.load(config_file)
print(model2path)
print(model2maxlen)

# Model evaluated by the rest of the notebook; must be a key of both registries.
model_name = 'Llama-3-8B-Instruct-Gradient-1048k'


{'Llama-2-7B-32K-Instruct': 'models/Llama-2-7B-32K-Instruct', 'Mistral-7B-Instruct-v0.2': 'models/Mistral-7B-Instruct-v0.2', 'Mistral-7B-Instruct-v0.3': 'models/Mistral-7B-Instruct-v0.3', 'Llama-3-8B-Instruct-Gradient-1048k': 'models/Llama-3-8B-Instruct-Gradient-1048k', 'Meta-Llama-3.1-8B-Instruct': 'models/Meta-Llama-3.1-8B-Instruct'}
{'Mistral-7B-Instruct-v0.2': 31500, 'Mistral-7B-Instruct-v0.3': 31500, 'Llama-3-8B-Instruct-Gradient-1048k': 1047500, 'Llama-2-7B-32K-Instruct': 31500, 'Meta-Llama-3.1-8B-Instruct': 127500}


In [4]:

def load_model_and_tokenizer(path, model_name):
    """Load a causal LM, its tokenizer and its end-of-sequence token ids.

    Args:
        path: Checkpoint location passed to every ``from_pretrained`` call.
        model_name: Accepted for interface compatibility; not used here.

    Returns:
        A ``(model, tokenizer, eos_token_ids)`` tuple. The model is loaded in
        bfloat16 with the "eager" attention implementation and put in eval
        mode. ``eos_token_ids`` is always a list of non-``None`` ids; it is
        empty only if neither the generation config nor the tokenizer
        defines an EOS token.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        path, trust_remote_code=True, use_fast=False
    )
    model = AutoModelForCausalLM.from_pretrained(
        path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation="eager",
    )

    generation_config = GenerationConfig.from_pretrained(path)
    eos_token_ids = generation_config.eos_token_id
    if eos_token_ids is None:
        # Without this fallback the result would be [None], which no
        # generated token id can ever match, so decoding would never stop
        # early on EOS.
        eos_token_ids = tokenizer.eos_token_id
    if not isinstance(eos_token_ids, list):
        eos_token_ids = [eos_token_ids]
    eos_token_ids = [
        token_id for token_id in eos_token_ids if token_id is not None
    ]

    model = model.eval()

    return model, tokenizer, eos_token_ids

# Load the checkpoint registered for the selected model.
checkpoint_path = model2path[model_name]
model, tokenizer, eos_token_ids = load_model_and_tokenizer(
    checkpoint_path, model_name
)

from duo_attn.utils import to_device

# Hand the model to duo_attn together with the index of every visible CUDA
# device. NOTE(review): enable_tp presumably turns on tensor parallelism;
# confirm against duo_attn.utils.to_device.
device_list = list(range(torch.cuda.device_count()))
model = to_device(model, device_list, enable_tp=True)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
# Prompts longer than this many tokens are truncated before generation.
max_length = model2maxlen[model_name]

# LongBench subsets to evaluate, in the order they are run.
datasets = [
    "qasper",
    "multifieldqa_en",
    "hotpotqa",
    "2wikimqa",
    "gov_report",
    "multi_news",
    "trec",
    "triviaqa",
    "samsum",
    "passage_count",
    "passage_retrieval_en",
    "lcc",
    "repobench-p",
]

In [6]:

# Per-dataset prompt template and generation budget (max new tokens).
# `with` guarantees the file handles are closed after parsing.
with open("eval/LongBench/config/dataset2prompt.json", "r") as config_file:
    dataset2prompt = json.load(config_file)
with open("eval/LongBench/config/dataset2maxlen.json", "r") as config_file:
    dataset2maxlen = json.load(config_file)

In [7]:
# Create the prediction output directories. exist_ok=True makes this
# idempotent and avoids the check-then-create race of a separate exists() test.
for pred_root in ("LongBench/pred", "LongBench/pred_e"):
    os.makedirs(pred_root, exist_ok=True)

In [8]:

def build_chat(tokenizer, prompt, model_name):
    """Wrap ``prompt`` in the chat markup required by the given model.

    Only Llama-2 models are wrapped (in ``[INST]...[/INST]``); every other
    model gets the prompt back unchanged. The model-name check is
    case-insensitive so that registry names such as
    ``Llama-2-7B-32K-Instruct`` are recognised, in line with the
    case-insensitive llama-3 check in ``post_process``.

    Args:
        tokenizer: Unused; kept so callers can pass it positionally.
        prompt: Fully formatted task prompt.
        model_name: Name of the model the prompt is built for.

    Returns:
        The prompt, wrapped for Llama-2 models and untouched otherwise.
    """
    if "llama-2" in model_name.lower():
        prompt = f"[INST]{prompt}[/INST]"
    return prompt

def post_process(response, model_name):
    """Trim model-specific trailing artefacts from a decoded response.

    The first matching rule wins:

    * ``xgen`` in the name: strip whitespace, then delete ``Assistant:``.
    * ``internlm`` in the name: keep the text before ``<eoa>``.
    * ``llama-3`` in the lower-cased name: cut at ``.assistant``,
      ``\\n\\nQuestion`` and ``</s>``, then strip.
    * ``Llama-2-7B-32K-Instruct`` in the name: cut at ``(Document``,
      ``\\n\\nQuestion``, ``\\n\\nAnswer`` and ``(Passage``, then strip.

    Any other model name returns ``response`` unchanged.
    """

    def _cut_at(text, markers):
        # Successively keep only what precedes each marker, then strip.
        for marker in markers:
            text = text.split(marker)[0]
        return text.strip()

    if "xgen" in model_name:
        return response.strip().replace("Assistant:", "")
    if "internlm" in model_name:
        return response.split("<eoa>")[0]
    if "llama-3" in model_name.lower():
        return _cut_at(response, (".assistant", "\n\nQuestion", "</s>"))
    if "Llama-2-7B-32K-Instruct" in model_name:
        return _cut_at(
            response,
            ("(Document", "\n\nQuestion", "\n\nAnswer", "(Passage"),
        )
    return response

def get_pred(
    model,
    tokenizer,
    eos_token_ids,
    data,
    max_length,
    max_gen,
    prompt_format,
    dataset,
    model_name,
    decoding_simulation_length,
):
    """Greedy-decode one prediction per example of a LongBench subset.

    For each example the prompt is formatted, truncated in the middle if it
    exceeds ``max_length`` tokens, optionally wrapped by ``build_chat``, and
    then run through the model in two phases: a single prefill call over the
    leading tokens, followed by one call per token for the last
    ``decoding_simulation_length`` prompt tokens. Up to ``max_gen`` tokens are
    then generated greedily (argmax), stopping early at any id in
    ``eos_token_ids``.

    Args:
        model: Callable taking ``input_ids``, ``past_key_values`` and
            ``use_cache`` and returning an object with ``logits`` and
            ``past_key_values``.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        eos_token_ids: Collection of token ids that end generation.
        data: Iterable of examples; each is used as format kwargs and must
            provide ``answers``, ``all_classes`` and ``length``.
        max_length: Maximum prompt length in tokens before truncation.
        max_gen: Maximum number of tokens to generate.
        prompt_format: ``str.format`` template for the prompt.
        dataset: Subset name; decides whether ``build_chat`` is applied.
        model_name: Passed to ``build_chat`` and ``post_process``.
        decoding_simulation_length: Number of trailing prompt tokens fed one
            at a time instead of in the prefill call.

    Returns:
        A list of dicts with keys ``pred``, ``answers``, ``all_classes`` and
        ``length``, one per example.
    """
    # Chat models are better off without build prompts on these tasks.
    raw_prompt_datasets = (
        "trec",
        "triviaqa",
        "samsum",
        "lsht",
        "lcc",
        "repobench-p",
    )

    preds = []
    pbar = tqdm(data)
    for idx, json_obj in enumerate(pbar):
        prompt = prompt_format.format(**json_obj)
        # Truncate in the middle to fit max_length: both the left and the
        # right end of the prompt may contain crucial instructions.
        tokenized_prompt = tokenizer(
            prompt, truncation=False, return_tensors="pt"
        ).input_ids[0]
        if len(tokenized_prompt) > max_length:
            half = int(max_length / 2)
            prompt = tokenizer.decode(
                tokenized_prompt[:half], skip_special_tokens=True
            ) + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
        if dataset not in raw_prompt_datasets:
            prompt = build_chat(tokenizer, prompt, model_name)

        encoded = tokenizer(prompt, truncation=False, return_tensors="pt").to("cuda")
        input_ids = encoded.input_ids
        prompt_len = input_ids.shape[-1]
        pbar.set_description(f"Generating for {idx}, len = {prompt_len}")

        # Keep at least one token in the prefill call. Without the clamp, a
        # prompt no longer than decoding_simulation_length would produce an
        # empty prefill (index 0) or a wrap-around negative slice.
        simulation_start_idx = max(1, prompt_len - decoding_simulation_length)
        with torch.no_grad():
            output = model(
                input_ids=input_ids[:, :simulation_start_idx],
                past_key_values=None,
                use_cache=True,
            )
            past_key_values = output.past_key_values
            # Feed the remaining prompt tokens one by one; this loop is a
            # no-op when nothing is left after the prefill slice.
            for input_id in input_ids[0, simulation_start_idx:]:
                output = model(
                    input_ids=input_id.unsqueeze(0).unsqueeze(0),
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                past_key_values = output.past_key_values

            pred_token_idx = output.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
            generated_content = [pred_token_idx.item()]
            for _ in range(max_gen - 1):
                # Checking before each step also covers the very first
                # generated token, which previously was never tested for EOS.
                if generated_content[-1] in eos_token_ids:
                    break
                output = model(
                    input_ids=pred_token_idx,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                past_key_values = output.past_key_values
                pred_token_idx = output.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
                generated_content.append(pred_token_idx.item())

        pred = tokenizer.decode(generated_content, skip_special_tokens=True)
        pred = post_process(pred, model_name)
        print(f"Prediction: {pred}")
        preds.append(
            {
                "pred": pred,
                "answers": json_obj["answers"],
                "all_classes": json_obj["all_classes"],
                "length": json_obj["length"],
            }
        )
    return preds

In [None]:
# Number of trailing prompt tokens that get_pred feeds to the model one at a
# time (its `decoding_simulation_length` argument).
DECODING_SIMULATION_LENGTH = 50

# The output directory does not depend on the dataset, so create it once.
pred_dir = f"LongBench/pred/{model_name}"
os.makedirs(pred_dir, exist_ok=True)

for dataset in datasets:
    print(dataset)
    data = load_dataset("THUDM/LongBench", dataset, split="test")
    out_path = f"{pred_dir}/{dataset}-full.jsonl"
    prompt_format = dataset2prompt[dataset]
    max_gen = dataset2maxlen[dataset]

    preds = get_pred(
        model,
        tokenizer,
        eos_token_ids,
        data,
        max_length,
        max_gen,
        prompt_format,
        dataset,
        model_name,
        DECODING_SIMULATION_LENGTH,
    )
    print(preds)

    # One JSON object per line; ensure_ascii=False keeps non-ASCII text readable.
    with open(out_path, "w", encoding="utf-8") as f:
        for pred in preds:
            json.dump(pred, f, ensure_ascii=False)
            f.write("\n")

Generating for 0, len = 4000:   0%|          | 0/200 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)
Generating for 1, len = 3331:   0%|          | 1/200 [00:03<12:51,  3.88s/it]

Prediction: The ground truth for fake news is established through manual inspection of the text field within the tweets to label them as containing fake news, or not containing them.


Generating for 2, len = 4361:   1%|          | 2/200 [00:08<14:15,  4.32s/it]

Prediction: The GhostVLAD approach is an extension of the NetVLAD approach that adds ghost clusters to map noisy or irrelevant content into ghost clusters and excludes them during feature aggregation. It is used for language identification and has been shown to outperform other pooling methods by achieving 98.43% F1-Score.


Generating for 3, len = 2977:   2%|▏         | 3/200 [00:15<18:32,  5.65s/it]

Prediction: Their model outperforms previous state-of-the-art methods by 68.8% to 71.8% when applied to the IEMOCAP dataset.


Generating for 4, len = 4439:   2%|▏         | 4/200 [00:19<16:12,  4.96s/it]

Prediction: The article proposes using context tweets as an additional feature for neural network models to better understand the data and improve the accuracy of detecting abusive language. The article also suggests using ensemble models of variant models and features for further improvements.


Generating for 5, len = 5373:   2%|▎         | 5/200 [00:25<16:59,  5.23s/it]

Prediction: They looked at different Facebook pages, including FoxNews, CNN, ESPN, New York Times, Time magazine, Huffington Post Weird News, The Guardian, Cartoon Network, Cooking Light, Home Cooking Adventure, Justin Bieber, Nickelodeon, Spongebob, and Disney. They also used a subset of pages based on their performance on the development set and the observation of emotions distribution on different pages and in the different datasets.


Generating for 6, len = 5635:   3%|▎         | 6/200 [00:30<17:04,  5.28s/it]

Prediction: Yes. The article states that the hashtag segmentation model is language-independent and the authors intend to extend their toolkit to languages other than English as future work. However, the article also mentions that the authors focused on English hashtags in their experiments. Therefore, the answer to this question is "yes".


Generating for 7, len = 6498:   4%|▎         | 7/200 [00:38<19:38,  6.11s/it]

Prediction: The article proposes an evaluation protocol and baseline for the task of concept-map-based MDS. The corpus is also publicly available under a permissive license.
