In [1]:
# Move up one directory (to the repository root, as the printed path shows) so the
# relative ./output/... paths used in later cells resolve.
%cd ../

/home/qwj/code/HippoRAG


In [2]:
import os

# Expose only GPUs 4 and 5 to this kernel. CUDA reads this variable when it is
# initialised, so it has to be set before the model is loaded further down.
os.environ.update({"CUDA_VISIBLE_DEVICES": "4,5"})

In [3]:
from functools import lru_cache
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd

In [4]:
# relations.tsv: tab-separated, no header row; column 1 is used as the relation name.
# Relation phrases are free text, so pandas' default NA parsing is disabled: otherwise
# phrases such as "NA", "null" or "None" would become NaN floats inside the corpus
# that is later handed to the embedder.
RELATIONS_PATH = "./output/musique_gpt/relations.tsv"
all_relations = pd.read_csv(
    RELATIONS_PATH,
    sep='\t',
    header=None,
    dtype={1: str},
    keep_default_na=False,
)
all_relations_names = all_relations[1].tolist()
len(all_relations_names)


22222

In [5]:
from gritlm import GritLM
class GritWrapper:
    """Dense retriever over a list of short strings (here: relation names) using GritLM.

    Usage: call ``index(corpus)`` once, then ``retrieve_one_query(query)`` as often as needed.
    Documents are embedded with a bare ``<|embed|>`` prompt; queries are embedded with
    ``self.query_instruction``, which carries a task instruction.
    """

    def __init__(
        self,
        model_path: str = "/data/qwj/model/GritLM-7B",
        instruction: str = "Given a relation, retrieve the relevant relations",
    ) -> None:
        """Load the GritLM model and build the query-side prompt.

        Args:
            model_path: Checkpoint location passed to ``GritLM``.
            instruction: Task description placed in the user turn of every query prompt.
        """
        self.retrieval_model = GritLM(model_path, torch_dtype="auto")
        # Query prompt: user turn holding the instruction, then the embed marker.
        self.query_instruction = "<|user|>\n" + instruction + "\n<|embed|>\n"
        self.docs: List[str] = []
        self.doc_embedding_mat = None  # set by index()

    def index(self, corpus: List[str]) -> None:
        """Embed ``corpus`` and store it for later retrieval.

        Args:
            corpus: Documents to search over; row ``i`` of the embedding matrix
                corresponds to ``corpus[i]``.
        """
        # Copy so that later mutation of the caller's list cannot desynchronise
        # ``self.docs`` from the rows of ``self.doc_embedding_mat``.
        self.docs = list(corpus)
        self.doc_embedding_mat = self.retrieval_model.encode_corpus(self.docs, instruction='<|embed|>\n')

    def retrieve_one_query(
        self, query: str, top_k: Optional[int] = None
    ) -> Tuple[List[float], List[str]]:
        """Rank the indexed documents by dot-product similarity to ``query``.

        Args:
            query: Query text, embedded with ``self.query_instruction``.
            top_k: If given, keep only the ``top_k`` best matches; ``None`` keeps all.

        Returns:
            ``(scores, docs)``: two parallel lists in descending score order.
            Scores are the NumPy scalar values produced by the dot product.

        Raises:
            RuntimeError: If ``index`` has not been called.
            ValueError: If ``top_k`` is negative.
        """
        doc_mat = getattr(self, "doc_embedding_mat", None)
        if doc_mat is None:
            raise RuntimeError("call index() before retrieve_one_query()")
        if top_k is not None and top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        query_embedding = self.retrieval_model.encode(query, instruction=self.query_instruction)
        # Flatten so both (d,) and (1, d) embeddings yield a 1-D score vector;
        # a (1, d) embedding would otherwise give an (n, 1) matrix and a meaningless argsort.
        query_embedding = np.asarray(query_embedding).reshape(-1)
        query_doc_scores = np.dot(doc_mat, query_embedding)

        order = np.argsort(query_doc_scores)[::-1]
        if top_k is not None:
            order = order[:top_k]
        sorted_scores = [query_doc_scores[idx] for idx in order]
        sorted_docs = [self.docs[idx] for idx in order]
        return sorted_scores, sorted_docs

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import spacy
from nltk.stem import PorterStemmer

# Shared stemmer used to normalise lemmas for the content-word overlap check.
stemmer = PorterStemmer()

# Load the English spaCy pipeline; for Chinese text 'zh_core_web_sm' could be used instead.
SPACY_MODEL = "en_core_web_sm"
nlp = spacy.load(SPACY_MODEL)


def extract_content_words(query):
    """Return the stemmed lemmas of the content words in ``query``.

    Content words are tokens whose spaCy POS tag is NOUN, VERB, ADJ or ADV.
    The result keeps token order and may contain duplicates.
    """
    content_pos = ['NOUN', 'VERB', 'ADJ', 'ADV']
    stems = []
    for tok in nlp(query):
        if tok.pos_ in content_pos:
            stems.append(stemmer.stem(tok.lemma_))
    return stems

@lru_cache(maxsize=None)
def _stemmed_alpha_words(text):
    """Return the frozenset of stems of the lower-cased lemmas of alphabetic tokens in ``text``.

    Cached because the filtering loop re-checks the same relation names for every
    query, and re-running the spaCy pipeline on an identical string is wasted work.
    NOTE: the cache assumes ``nlp`` and ``stemmer`` are not reassigned afterwards;
    call ``_stemmed_alpha_words.cache_clear()`` if they are.
    """
    doc = nlp(text)
    lemmas = {token.lemma_.lower() for token in doc if token.is_alpha}
    return frozenset(stemmer.stem(lemma) for lemma in lemmas)


def contains_content_word(text, lemmatized_content_words):
    """Return True if ``text`` shares at least one stem with ``lemmatized_content_words``.

    Args:
        text: Candidate string (e.g. a relation name).
        lemmatized_content_words: Stems to look for, as produced by ``extract_content_words``
            (despite the name, these are stemmed lemmas).
    """
    text_words_stemmed = _stemmed_alpha_words(text)
    return any(word in text_words_stemmed for word in lemmatized_content_words)

In [7]:
def get_similar_relations_v3(
    query: str, retriever, top_n: int = 10, verbose: bool = True
) -> List[Tuple[str, float]]:
    """Retrieve relations similar to ``query`` that share no content-word stem with it.

    Candidates come from ``retriever`` in descending score order. Any candidate
    containing one of the query's content-word stems is skipped, so the result only
    holds neighbours that do not simply repeat the query's words.

    Args:
        query: Relation phrase to search for.
        retriever: Object whose ``retrieve_one_query(query)`` returns
            ``(scores, names)`` as parallel lists in descending score order.
        top_n: How many raw hits to print, and the maximum number of kept relations.
        verbose: If True, print the raw top-``top_n`` hits and every kept relation
            (with its 1-based rank in the unfiltered list).

    Returns:
        Up to ``top_n`` ``(name, score)`` pairs that passed the filter, best first.
    """
    if verbose:
        print(query, "=============")
    scores, relation_names = retriever.retrieve_one_query(query)
    query_stems = extract_content_words(query)

    if verbose:
        print("\n".join([f"{n}, {s:.4f}" for n, s in zip(relation_names[:top_n], scores[:top_n])]))
        print("\t=============")

    filtered_relations = []
    for rank, (name, score) in enumerate(zip(relation_names, scores), start=1):
        if len(filtered_relations) >= top_n:
            break
        if contains_content_word(name, query_stems):
            continue
        filtered_relations.append((name, score))
        if verbose:
            print(f"{rank}\t {name}\t {score:.4f}")
    return filtered_relations

In [8]:
# Load GritLM and embed every relation name once; this is the slow cell
# (one encoding pass over the full relation list).
retriever = GritWrapper()
retriever.index(corpus=all_relations_names)

Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.68it/s]


Created GritLM: torch.bfloat16 dtype, mean pool, unified mode, bbcc attn


Batches: 100%|██████████| 87/87 [01:35<00:00,  1.10s/it]


In [9]:
# Relation phrases to probe; for each, print the raw top hits followed by the
# hits that do not repeat the query's content words.
relation_queries = [
    "birthplace",
    "immediately north of",
    "is located",
    "named after",
    "is capital of",
    "the headquarters of",
    "an example of",
    "a part of",
    "follow",
    "an instance of",
    "named commander in chief",
    "a member of",
    "the cast member of",
    'influenced', 'occurred in', 'caused', 'helped', 'the symbol of', 'the owner of', 'went to',
    'about', 'talk about', 'wrote a book about',
]
for query in relation_queries:
    print(query.upper(), "======")
    get_similar_relations_v3(query, retriever)

birth place, 0.7803
birthplace of, 0.7530
place of birth is, 0.7222
is the place of birth of, 0.6992
places birth of, 0.6832
was born near, 0.6276
was born in, 0.6177
born near, 0.6165
is hometown of, 0.5959
has its birthplace in, 0.5958
1	 birth place	 0.7803
3	 place of birth is	 0.7222
4	 is the place of birth of	 0.6992
5	 places birth of	 0.6832
6	 was born near	 0.6276
7	 was born in	 0.6177
8	 born near	 0.6165
9	 is hometown of	 0.5959
11	 born in	 0.5946
12	 is born in	 0.5768
directly north of, 0.7394
located just north of, 0.7278
located north of, 0.7213
situated north of, 0.7163
lies just north of, 0.7144
located to the north of, 0.7090
lies north of, 0.7044
lies to the north of, 0.6975
located about north of, 0.6924
is located just north of, 0.6923
29	 situated northwest of	 0.6177
30	 located northwest of	 0.6173
39	 is northwest of	 0.6038
40	 located to the northwest of	 0.6015
43	 northwest of	 0.5941
44	 located to the northwest west of	 0.5936
45	 is to the northwest

In [7]:
# Print the raw (unfiltered) top-10 retrieved relation names for each query.
# Only the ranked names are used, so the scores are not bound to a global.
for query in [
    "birthplace",
    "immediately north of",
    "is located",
    "named after",
    "is capital of",
    "the headquarters of"
]:
    sorted_docs = retriever.retrieve_one_query(query)[1]
    print(query, sorted_docs[:10])

birthplace ['birth place', 'birthplace of', 'place of birth is', 'is the place of birth of', 'places birth of', 'was born near', 'was born in', 'born near', 'is hometown of', 'has its birthplace in']
immediately north of ['directly north of', 'located just north of', 'located north of', 'situated north of', 'lies just north of', 'located to the north of', 'lies north of', 'lies to the north of', 'located about north of', 'is located just north of']
is located ['is located', 'is located in', 'located in', 'located', 'located at', 'is located on', 'located on', 'is located at', 'is located on the', 'was located in']
named after ['named after', 'was named after', 'is named after', 'named it after', 'named for', 'was named for', 'are named after', 'is named for', 'was named in honor of', 'was named in honour of']
is capital of ['is capital of', 'is the capital of', 'is the capital city of', 'capital of', 'serves as capital of', 'capital city is', 'capital is located in', 'was capital of', 

In [8]:
# Print the raw (unfiltered) top-10 retrieved relation names for each query.
# Only the ranked names are used, so the scores are not bound to a global.
for query in [
    "an example of",
    "a part of",
    "follow",
    "an instance of",
    "named commander in chief",
    "a member of",
    "the cast member of"
]:
    sorted_docs = retriever.retrieve_one_query(query)[1]
    print(query, sorted_docs[:10])

an example of ['an example of', 'example of', 'is an example of', 'was an example of', 'exemplified by', 'sample of', 'has instances of', 'such as', 'one of', 'for illustration of']
a part of ['part of', 'is a part of', 'forms a part of', 'is part of', 'forms part of', 'formed part of', 'form part of', 'was a part of', 'was part of', 'are part of']
follow ['follow', 'follows', 'followed', 'followed in', 'followed by', 'followed from', 'is followed by', 'followed into', 'had following on', 'follower of']
an instance of ['has instances of', 'is an example of', 'an example of', 'was an example of', 'example of', 'is a subclass of', 'is a subtype of', 'is a subset of', 'is a type of', 'is an entity of']
named commander in chief ['commander in chief of', 'is commander in chief of', 'became commander of', 'was commander of', 'placed in command of', 'named as president of', 'was a commander of', 'named president of', 'was military commander in', 'named to head']
a member of ['is member of', '

In [9]:
# Print the raw (unfiltered) top-10 retrieved relation names for each query.
# Only the ranked names are used, so the scores are not bound to a global.
for query in ['influenced', 'occurred in', 'caused', 'helped', 'the symbol of', 'the owner of', 'went to']:
    sorted_docs = retriever.retrieve_one_query(query)[1]
    print(query, sorted_docs[:10])

influenced ['influenced', 'influenced by', 'was influenced by', 'had influence in', 'faced influence from', 'was influential in', 'influences', 'exerted influence over', 'had influence over', 'came under influence of']
occurred in ['occurred in', 'occurred at', 'happened in', 'occurred during', 'occurred on', 'took place in', 'occurred within', 'occurred', 'occurred near', 'occur in']
caused ['caused', 'caused by', 'was caused by', 'resulted from', 'brought about by', 'reportedly caused', 'caused damage of', 'cause of', 'are caused by', 'led to by']
helped ['helped', 'helped by', 'provided assistance for', 'received help from', 'aided in', 'gave aid to', 'helped lead', 'helped bring', 'offered help to', 'helped organize']
the symbol of ['symbol of', 'symbol is', 'has symbol', 'is symbol of', 'symbolized by', 'was a symbol of', 'symbolizes', 'symbolises', 'symbolize', 'logo symbolized']
the owner of ['owner of', 'is the owner of', 'is owner of', 'owner is', 'owner', 'proprietor of', 'ha

In [10]:
# Print the raw (unfiltered) top-10 retrieved relation names for each query.
# Only the ranked names are used, so the scores are not bound to a global.
for query in ['about', 'talk about', 'wrote a book about']:
    sorted_docs = retriever.retrieve_one_query(query)[1]
    print(query, sorted_docs[:10])

about ['about', 'are about', 'is about', 'exists about', 'regarding', 'related to', 'has information about', 'around', 'written about', 'contains information about']
talk about ['discussed about', 'talks to', 'spoke of', 'discussing', 'discusses', 'discussion of', 'discuss', 'speaks of', 'discussed', 'talked with']
wrote a book about ['wrote book on', 'wrote books on', 'wrote about', 'has written books on', 'has written about', 'published book entitled', 'written about', 'authored books on', 'is a book about', 'published book in']
