In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
from semantic_text_splitter import TextSplitter
import tiktoken
import os
import sys
sys.path.append('..')
import random
from tqdm import tqdm
from itertools import chain

from src.retriever import DenseRetriever
from tasks.longbench_utils import qa_f1_score
# tiktoken encoding used later in the notebook to count evidence tokens when
# reporting context compression rates.
encoder = tiktoken.get_encoding("cl100k_base")
# Splitter used by get_evidence() to cut the long context into retrieval passages.
# NOTE(review): 512 is presumably the per-chunk token capacity — confirm against
# the semantic_text_splitter documentation.
text_splitter = TextSplitter.from_tiktoken_model("gpt-3.5-turbo", 512)

  from .autonotebook import tqdm as notebook_tqdm
08/06/2024 16:13:35 - INFO - faiss.loader - Loading faiss with AVX512 support.
08/06/2024 16:13:35 - INFO - faiss.loader - Successfully loaded faiss with AVX512 support.


### initialize hyper-parameters

In [2]:
# Keyword arguments forwarded to AutoModelForCausalLM.from_pretrained().
# "ultragist_ratio" is only passed to the memory model; it is not passed to the
# generator model.
# NOTE(review): presumably consumed by the checkpoint's remote code — confirm.
model_kwargs = {
        "device_map": "cuda:0",
        "attn_implementation": "sdpa",
        "torch_dtype": 	torch.bfloat16,
        "trust_remote_code": True,
        "ultragist_ratio": [4]
    }

# Keyword arguments forwarded to AutoTokenizer.from_pretrained() for both tokenizers.
# Left padding matters: generate()/pred() strip the prompt by slicing off the
# first input_ids.shape[1] columns of the generated batch.
tokenizer_kwargs = {
        "padding_side": "left",
        "trust_remote_code": True,
    }

# Keyword arguments forwarded to .generate() for both the memory model and the generator.
# NOTE(review): with do_sample=False the sampling knobs below are presumably
# inert and only override checkpoint defaults — confirm.
generation_kwargs = {
    "max_new_tokens": 128,
    "do_sample": False,
    "temperature": 1.0,
    "top_k": 0,
    "top_p": 1.0
}

# Prompt templates, filled in with str.format(); {context} is the article (or
# the retrieved evidence) and {input} is the question.
prompts = {
  # Key-point summary of the article. Not referenced by the cells visible in this notebook.
  "summarization": "You are provided with a long article. Your task is to generate a concise summary by listing the key points of the long article.\n\n\
  ### Instructions:\n\n1. Long Article: {context}\n2. Output: Generate a list of key points, each separated by a newline, with numeric order.\n\n\
  ### Requirements:\n\n- The key points should be short and high-level.\n- Ensure that the key points convey the most important information and main events of the long article.\n",
  # Used by get_clues(): asks for clue *texts* extracted from the article.
  "clues_1": "You are given a long article and a question. After a quick read-through, you have a rough memory of the article. To answer the question effectively, \
  you need to recall and extract specific details from the article. Your task is to find and retrieve the relevant clue texts from the article that will help answer the question.\
  \n\n### Inputs:\n- **Long Article:** {context}\n- **Question:** {input}\n\n### Requirements:\n1. You have a general understanding of the article. Your task is to identify and \
  extract one or more specific clue texts from the article that are relevant to the question.\n2. Output only the extracted clue texts. For multiple sentences, separate them with a newline.\n",
  # Used by get_clues(): asks for clue *questions* to search the article with.
  "clues_2": "You are provided with a long article and a question. After a quick read-through, you have a rough memory of the article. To better answer the question,\
  you need to recall specific details within the article. Your task is to generate precise clue questions that can help locate the necessary information.\n\n### \
  Inputs:\n- **Long Article:** {context}\n- **Question:** {input}\n\n### Requirements:\n1. You have a general understanding of the article. Your task is to write \
  one or more precise clue questions to search for supporting evidence in the article.\n2. Output only the clue questions. For multiple questions, separate them with a newline.\n ",
  # Final question-answering prompt given to the generator together with the retrieved evidence.
  "qa": "You are given a long text. You're required to read the long text and answer the questions.\n\nNow the long text begins. \n\n{context}\n\nNow the \
  long text ends.\n\nAnswer the following questions.\n\n{input}",
}



### initialize models

In [None]:
# Checkpoint of the memory model; replace the placeholder with a real path or hub id.
model_path = "path_to_checkpoint"
# Model that writes the final answers from the retrieved evidence.
generator_path = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            **tokenizer_kwargs)

model = AutoModelForCausalLM.from_pretrained(
            model_path,
            **model_kwargs,
        ).eval()

# NOTE(review): cache_dir is a machine-specific absolute path — adjust for your environment.
retriever = DenseRetriever(
            encoder="BAAI/bge-m3",
            pooling_method="cls",
            dense_metric="cos",
            query_max_length=128,
            key_max_length=512,
            cache_dir="/share/shared_models",
            hits=3,
        )

generator_tokenizer = AutoTokenizer.from_pretrained(
            generator_path,
            **tokenizer_kwargs)
# pred() tokenizes with padding=True, which needs a pad token; reuse EOS for it.
generator_tokenizer.pad_token = generator_tokenizer.eos_token

# Build a filtered copy instead of popping from model_kwargs: mutating the shared
# dict made this cell non-idempotent (a second run loaded the memory model without
# "ultragist_ratio" and then raised KeyError on the pop).
generator_model_kwargs = {
    key: value for key, value in model_kwargs.items() if key != "ultragist_ratio"
}
generator_model = AutoModelForCausalLM.from_pretrained(
            generator_path,
            **generator_model_kwargs,
        ).eval()

### Functions

In [9]:
def generate(prompts):
    """Generate one completion per prompt with the memory model.

    Args:
        prompts: A prompt string or a list of prompt strings. Each prompt is
            wrapped as a single user turn with the memory tokenizer's chat template.

    Returns:
        list[str]: Decoded completions in the same order as ``prompts``, with the
        prompt tokens sliced off and special tokens skipped.
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    to_encode = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": p}],
            tokenize=False,
            add_generation_prompt=True)
        for p in prompts
    ]

    # NOTE(review): add_special_tokens=False presumably avoids duplicating the
    # special tokens already emitted by the chat template — confirm.
    inputs = tokenizer(
        to_encode, add_special_tokens=False, return_tensors="pt", padding=True
    ).to(model.device)

    try:
        outputs = model.generate(
                    **inputs,
                    **generation_kwargs,
                    pad_token_id=tokenizer.eos_token_id)
        # Drop the (padded) prompt columns so only newly generated tokens remain.
        outputs = outputs[:, inputs["input_ids"].shape[1]:]
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)
    finally:
        # Reset the model's memory even when generation fails (e.g. out of memory
        # on a long context), so the next call does not start from leftover state.
        model.memory.reset()

def load_jsonl(path):
    """Read a JSON Lines file into a list.

    Args:
        path: Path of a UTF-8 text file holding one JSON document per line.

    Returns:
        list: The parsed documents, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    rtn = []
    # The context manager closes the handle even when a line fails to parse.
    with open(path, encoding="utf-8") as f:
        for line in tqdm(f):
            # Tolerate blank lines (e.g. an empty line at the end of the file).
            if line.strip():
                rtn.append(json.loads(line))
    return rtn

def show_sample(line):
    """Print a sample's context length, query and first reference answer.

    Args:
        line: A sample mapping with the keys "length", "input" and "answers"
            (a sequence; only its first element is shown).
    """
    rows = (
        ("Context Length: ", lambda: line["length"]),
        ("Input Query: ", lambda: line["input"]),
        ("Answer: ", lambda: line["answers"][0]),
    )
    for label, fetch in rows:
        # Each value is looked up right before its own row is printed.
        print(label, fetch())
    print("===" * 20)
    
def get_clues(sample):
    """Ask the memory model for clues that help answer the sample's question.

    Runs the "clues_1" (clue texts) and "clues_2" (clue questions) prompts, one
    generate() call each, and splits both outputs into lines.

    Args:
        sample: A mapping with the keys "context" and "input".

    Returns:
        list[str]: Unique clue lines with more than three whitespace-separated
        words, in first-seen order.
    """
    outputs = [
        generate(prompts[key].format(context=sample["context"], input=sample["input"]))[0]
        for key in ("clues_1", "clues_2")
    ]

    candidates = []
    for text in outputs:
        candidates.extend(text.split("\n"))
    # Drop lines too short to be a useful retrieval query.
    kept = [sent for sent in candidates if len(sent.split()) > 3]
    # dict.fromkeys de-duplicates while keeping first-seen order; list(set(...))
    # made the clue order depend on the per-process string hash seed.
    return list(dict.fromkeys(kept))

def pred(prompts):
    """Generate one answer per prompt with the generator model.

    Args:
        prompts: A prompt string or a list of prompt strings. Each prompt is
            wrapped as a single user turn with the generator tokenizer's chat
            template.

    Returns:
        list[str]: Decoded completions in the same order as ``prompts``, with the
        prompt tokens sliced off and special tokens skipped.
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    # Format with the generator's own chat template. The memory model's tokenizer
    # was used here before, which is only correct when both models happen to share
    # a template.
    to_encode = [
        generator_tokenizer.apply_chat_template(
            [{"role": "user", "content": p}],
            tokenize=False,
            add_generation_prompt=True)
        for p in prompts
    ]

    inputs = generator_tokenizer(
        to_encode, add_special_tokens=False, return_tensors="pt", padding=True
    ).to(generator_model.device)

    outputs = generator_model.generate(
                **inputs,
                **generation_kwargs,
                # Pad id must come from the generator's vocabulary, not the memory model's.
                pad_token_id=generator_tokenizer.eos_token_id)
    # Drop the (padded) prompt columns so only newly generated tokens remain.
    outputs = outputs[:, inputs["input_ids"].shape[1]:]
    return generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)

def get_evidence(queries, context):
    """Retrieve evidence passages for a memory-guided run and a plain-RAG baseline.

    The context is split into chunks and indexed after calling
    retriever.remove_all(). The first search uses every query; the second uses
    only the last query, with ``hits`` set to the number of passages the first
    search selected.

    Args:
        queries: A query string or a list of query strings; the last entry is the
            one used for the baseline search.
        context: The long text to split and index.

    Returns:
        tuple[list[str], list[str]]: ``(mem_results, rag_results)`` — stripped
        passages ordered by ascending chunk index.
    """
    retriever.remove_all()
    if isinstance(queries, str):
        queries = [queries]

    retrieval_corpus = text_splitter.chunks(context)
    print(f"{len(retrieval_corpus)} passages will be indexed...")
    retriever.add(retrieval_corpus)

    # Search with all queries, pool the hits, and de-duplicate them.
    _, hits_per_query = retriever.search(queries=queries)
    pooled = set()
    for hit_row in hits_per_query:
        pooled.update(hit_row.tolist())
    # Negative indices are discarded; sorting restores document order so that
    # neighbouring chunks stay next to each other.
    mem_indices = sorted(idx for idx in pooled if idx > -1)
    print("Mem Selected indices: ", mem_indices)
    mem_results = [retrieval_corpus[idx].strip() for idx in mem_indices]

    # Baseline: search with the last query only, asking for as many passages as
    # the pooled search selected.
    _, baseline_hits = retriever.search(
        queries=queries[-1:], hits=len(mem_results))
    rag_indices = sorted(idx for idx in baseline_hits[0].tolist() if idx > -1)
    print("RAG Selected indices: ", rag_indices)
    rag_results = [retrieval_corpus[idx].strip() for idx in rag_indices]

    return mem_results, rag_results

### load data

In [5]:
# Load the dev split (one JSON document per line). The path is relative to the
# notebook's working directory. Later cells read the "length", "input",
# "answers" and "context" keys of each sample.
dev_data = load_jsonl(
    "legal.dev.jsonl")

438it [00:00, 439.71it/s]


In [15]:
# Draw one dev example at random.
# NOTE(review): no seed is set, so every run picks a different sample.
sample_id = random.randint(0, len(dev_data) - 1)
sample = dev_data[sample_id]

show_sample(sample)
clues = get_clues(sample)
print("Clues from Memory:\n", "\n".join(clues))
print("==="*20)

# The original question goes last: get_evidence() searches with all queries for
# the memory-guided result and with the final query alone for the RAG baseline.
clues.append(sample["input"])
mem_evidence, rag_evidence = get_evidence(clues, sample["context"])
mem_results = pred(prompts["qa"].format(context="\n\n".join(mem_evidence), input=sample["input"]))
rag_results = pred(prompts["qa"].format(context="\n\n".join(rag_evidence), input=sample["input"]))

# Scale to percent *before* rounding: round(x, 3) * 100 printed float artifacts
# such as 5.800000000000001.
for method, evidence in (("MemRAG", mem_evidence), ("RAG", rag_evidence)):
    evidence_tokens = len(encoder.encode("\n\n".join(evidence)))
    print(f"Context Compression Rate for {method}: ", round(100 * evidence_tokens / sample["length"], 1), "%")

# Score each prediction against the first reference answer.
reference_answer = sample["answers"][0]
for method, answers in (("MemRAG", mem_results), ("RAG", rag_results)):
    print(f"{method} score: ", round(qa_f1_score(answers[0], reference_answer), 3))

Context Length:  56441
Input Query:  What are the representations and warranties of the Company regarding its qualification, organization, and subsidiaries?
Answer:  The Company represents and warrants that it and its subsidiaries are duly organized, validly existing, and in good standing under the laws of their respective jurisdictions, have all requisite power and authority to own their properties and assets, and are duly qualified to do business and in good standing in each jurisdiction where required.
Clues from Memory:
 The Company and each of its Subsidiaries is a legal entity duly organized, validly existing and in good standing (or the equivalent thereof, if applicable, in each case, with respect to the jurisdictions that recognize the concept of good standing or any equivalent thereof) under the Laws of the jurisdiction of its incorporation, organization or formation, as applicable, and has all requisite corporate or similar power and authority to own, lease and operate its pr