In [2]:
# Move up one directory so the relative paths used below (output/, ./src, ./data) resolve.
%cd ../

/home/qwj/code/HippoRAG


In [3]:
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple
from abc import ABC, abstractmethod
from typing import List, Tuple
from colbert import Searcher, Indexer
from colbert.data import Queries
from colbert.infra import RunConfig, Run, ColBERTConfig
from transformers import AutoModel, AutoTokenizer
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [30]:
# Load the MuSiQue queries with their named-entity extraction output (tab-separated)
# and preview the first rows.
musique_ner = pd.read_csv("output/musique_queries.named_entity_output.tsv", sep='\t')
musique_ner.head()

Unnamed: 0.1,Unnamed: 0,id,paragraphs,question,question_decomposition,answer,answer_aliases,answerable,triples
0,0,2hop__13548_13529,"[{'idx': 0, 'title': 'Lionel Messi', 'paragrap...",When was the person who Messi's goals in Copa ...,"[{'id': 13548, 'question': ""To whom was Messi'...",June 1982,[],True,"{""named_entities"": [""Messi"", ""Copa del Rey"", ""..."
1,1,3hop1__9285_5188_23307,"[{'idx': 0, 'title': 'Member states of NATO', ...",What month did the Tripartite discussions begi...,"[{'id': 9285, 'question': 'What was the nobili...",mid-June,[],True,"{""named_entities"": [""Britain"", ""France"", ""Wars..."
2,2,2hop__766973_770570,"[{'idx': 0, 'title': 'Minsk Region', 'paragrap...",What county is Erik Hort's birthplace a part of?,"[{'id': 766973, 'question': 'Erik Hort >> plac...",Rockland County,"['Rockland County, New York']",True,"{""named_entities"": [""Erik Hort""]}"
3,3,2hop__170823_120171,"[{'idx': 0, 'title': 'Blast Corps', 'paragraph...",What year did the publisher of Labyrinth end?,"[{'id': 170823, 'question': 'Labyrinth >> publ...",1986,[],True,"{""named_entities"": [""Labyrinth""]}"
4,4,2hop__511454_120259,"[{'idx': 0, 'title': 'Kavangoland', 'paragraph...",When was Lady Godiva's birthplace abolished?,"[{'id': 511454, 'question': 'Lady Godiva >> pl...",918,[],True,"{""named_entities"": [""Lady Godiva's birthplace""]}"


In [36]:
# Look up the row whose question text matches this string exactly
# (`@q` refers to the Python variable `q` inside the query expression).
q = "This is an instance of who was named commander in chief of texas forces by the new national government of texas National Forest?"
musique_ner.query('question == @q')

Unnamed: 0.1,Unnamed: 0,id,paragraphs,question,question_decomposition,answer,answer_aliases,answerable,triples
751,751,2hop__83460_456238,"[{'idx': 0, 'title': 'Texian Army', 'paragraph...",This is an instance of who was named commander...,"[{'id': 83460, 'question': 'who was named comm...",United States National Forest,"['forest', 'forests', 'Forest', 'National Fore...",True,"{""named_entities"": [""commander in chief"", ""tex..."


In [31]:
# Widen the column display so full question texts are readable, then show the
# id and question of every row whose question contains "about".
pd.set_option('display.max_colwidth', 300)
musique_ner.loc[musique_ner['question'].str.contains("about"), ['id', 'question']]

Unnamed: 0,id,question
22,2hop__159215_779396,Where was the person who wrote about the rioting being a dividing factor in Birmingham educated?
54,2hop__14_59409,Who was Beyonce's husband talking about in the the song Cry?
327,3hop1__622497_160088_821792,Who wrote a book about growing up in the same nationality as the man who produced The Wild Women of Chastity Gulch?
444,2hop__50199_59409,Who was the person talking at the beginning of Thriller by Fall Out Boy talking about in Song Cry?
458,2hop__347_59409,Who was the person who acquired the parent company of the music service Beyonce owns part of talking about in the song Cry?
537,2hop__69_59409,"Who was the artist who is associated with Beyonce's premiere solo recording talking about in the song ""Cry""?"
648,2hop__80_59409,"Who was the artist who did a duet with Beyonce in the single ""Deja Vu"" talking about in the song ""Cry""?"


In [4]:
# Collect every sub-question from each row's `question_decomposition` into one flat list.
all_seed_qs = []
for i, row in musique_ner.iterrows():
    # NOTE(review): eval() executes arbitrary code found in the TSV cell. The column
    # appears to hold a Python-literal list of dicts; consider ast.literal_eval if the
    # file is not fully trusted.
    decomposed_qs = [q['question'] for q in eval(row["question_decomposition"])]
    all_seed_qs.extend(decomposed_qs)

## Build the knowledge base

In [4]:
# Make the project's ./src directory importable so that `kb_utils` can be found.
import sys
sys.path.append('./src')
from kb_utils import ExperimentConfig, KnowledgeBase

# Experiment settings for the MuSiQue knowledge base; the artefacts are expected
# under `base_dir`.
config = ExperimentConfig(
    dataset='musique',
    graph_type='facts_and_sim',
    retrieval_model_name='colbertv2',
    extraction_model_name='gpt-3.5-turbo-1106',
    base_dir='./output/musique_gpt'
)
# Build the knowledge base from the configuration and print its summary repr.
kb = KnowledgeBase.build_from_config(config)
print(kb)

building knowledge graph: 100%|██████████| 298594/298594 [00:01<00:00, 165166.09it/s]

KnowledgeBase(91729 entities, 22222 relations, 298594 triplets)





In [6]:
# Write every relation name of the knowledge base to a TSV file, one
# `<row index>\t<relation>` line per relation and no header row.
all_relations = pd.DataFrame(list(kb.relations_to_id.keys()))
all_relations.to_csv("./output/musique_gpt/relations.tsv", sep='\t', index=True, header=False)

## Retriever


In [5]:
import spacy
# from nltk.corpus import wordnet as wn
from nltk.stem import PorterStemmer
# Shared Porter stemmer used by the content-word helpers below.
stemmer = PorterStemmer()

# Load the English model; for Chinese text, 'zh_core_web_sm' can be used instead.
nlp = spacy.load('en_core_web_sm')


def extract_content_words(query):
    """Return the stemmed lemmas of the content words in ``query``.

    Tokens are kept only when their part of speech is NOUN, VERB, ADJ or ADV;
    each kept token's lemma is then passed through the module-level stemmer.
    The result preserves token order and may contain duplicates.
    """
    content_pos = ['NOUN', 'VERB', 'ADJ', 'ADV']
    stems = []
    for tok in nlp(query):
        if tok.pos_ in content_pos:
            stems.append(stemmer.stem(tok.lemma_))
    return stems

def contains_content_word(text, lemmatized_content_words):
    """Tell whether ``text`` shares at least one stemmed word with the given list.

    ``text`` is tokenised, its alphabetic tokens are lemmatised and lower-cased,
    and the lemmas are stemmed; the function returns True as soon as one entry
    of ``lemmatized_content_words`` is found among those stems.
    """
    alpha_lemmas = {tok.lemma_.lower() for tok in nlp(text) if tok.is_alpha}
    text_stems = {stemmer.stem(lemma) for lemma in alpha_lemmas}
    for word in lemmatized_content_words:
        if word in text_stems:
            return True
    return False

In [1]:
import os

# Restrict this process to GPUs 4 and 5.
# NOTE(review): this presumably has to run before any CUDA initialisation to take effect — confirm cell order.
os.environ["CUDA_VISIBLE_DEVICES"]="4,5"

In [4]:
class RetrieverBase(ABC):
    """Abstract interface shared by the retrievers in this notebook.

    Concrete subclasses must implement all three methods below before they
    can be instantiated.
    """

    @classmethod
    @abstractmethod
    def build_index(cls, collection_path: str, *args, **kwargs):
        """Build a search index for the collection stored at ``collection_path``.

        Declared as a class method (the first parameter is ``cls``) so that an
        index can be built without first constructing a retriever.
        """

    @abstractmethod
    def retrieve_one_query(self, query: str) -> List[float]:
        """Return relevance scores for a single ``query``."""

    @abstractmethod
    def retrieve_many_top_one(self, query_list: List[str]) -> Tuple[List[int], List[str], List[float]]:
        """Return the best match for each query in ``query_list``.

        Returns:
            Three parallel lists: ids of the top hits, their texts, and their scores.
        """


In [5]:
class ColBertRetriever(RetrieverBase):
    """Retriever backed by a ColBERT index named ``nbits_2``.

    The index is looked up under ``root``/``experiment``; build it first with
    :meth:`build_index`.
    """

    # Name shared by build_index() and the Searcher so both refer to the same index.
    INDEX_NAME = "nbits_2"

    def __init__(self, root, experiment) -> None:
        """Open the existing index for ``experiment`` under ``root``."""
        super().__init__()
        with Run().context(RunConfig(nranks=1, experiment=experiment, root=root)):
            config = ColBERTConfig(root=root)
            self.searcher = Searcher(index=self.INDEX_NAME, config=config, verbose=0)
        # Number of indexed passages; used to size the dense score vector.
        self.corpus_size = len(self.searcher.collection.data)

    @classmethod
    def build_index(cls, collection_path: str, root_path: str, experiment: str, checkpoint_path: str = None):
        """Index the TSV collection at ``collection_path`` with 2-bit compression.

        Args:
            collection_path: Path to the collection; must end with ``tsv``.
            root_path: ColBERT root directory for experiment artefacts.
            experiment: Experiment name under ``root_path``.
            checkpoint_path: ColBERT checkpoint; defaults to ``exp/colbertv2.0``.

        Raises:
            NotImplementedError: If the collection is not a TSV file.

        Note:
            An existing index with the same name is overwritten.
        """
        if checkpoint_path is None:
            checkpoint_path = "exp/colbertv2.0"
        if not collection_path.endswith('tsv'):
            raise NotImplementedError(f"only .tsv collections are supported, got: {collection_path}")
        with Run().context(RunConfig(nranks=1, experiment=experiment, root=root_path)):
            config = ColBERTConfig(
                nbits=2,
                root=root_path,
            )
            indexer = Indexer(checkpoint=checkpoint_path, config=config)
            indexer.index(name=cls.INDEX_NAME, collection=collection_path, overwrite=True)

    def retrieve_one_query(self, query: str):
        """Return a dense score vector of length ``corpus_size`` for ``query``.

        Entry ``i`` holds the score of passage ``i``; passages that the search
        does not return keep a score of 0.
        """
        query_doc_scores = np.zeros(self.corpus_size)
        # Searcher.search takes the raw query text and returns (pids, ranks, scores).
        doc_ids, _, scores = self.searcher.search(query, k=self.corpus_size)
        for doc_id, score in zip(doc_ids, scores):
            query_doc_scores[doc_id] = score
        return query_doc_scores

    def _late_interaction_score(self, encoded_query, doc_text):
        """Score ``doc_text`` against an already-encoded query.

        For each query token take the maximum dot product over the document
        tokens, then sum these maxima.
        """
        encoded_doc = self.searcher.checkpoint.docFromText([doc_text]).float()
        return encoded_query[0].matmul(encoded_doc[0].T).max(dim=1).values.sum().detach().cpu().numpy()

    def get_colbert_max_score(self, query):
        """Return the score of ``query`` matched against itself encoded as a document.

        Used as the denominator when normalising retrieval scores.
        """
        encoded_query = self.searcher.encode([query], full_length_search=False)
        return self._late_interaction_score(encoded_query, query)

    def retrieve_many_top_one(self, query_list: List[str]) -> Tuple[List[int], List[str], List[float]]:
        """Retrieve the single best passage for every query.

        Returns:
            Three parallel lists: passage ids, passage texts, and the score of
            each top passage divided by the query's self-match score.
        """
        phrase_ids = []
        phrases = []
        max_scores = []

        for query in query_list:
            # Encode once and reuse for both the self-match and the top-hit score.
            encoded_query = self.searcher.encode([query], full_length_search=False)
            # Score obtained when the query is also used as the document.
            max_score = self._late_interaction_score(encoded_query, query)

            top_ids, _, _ = self.searcher.search(query, k=1)
            phrase_id = top_ids[0]
            phrase = self.searcher.collection.data[phrase_id]
            # Re-score the hit from its text rather than using the score reported by search().
            real_score = self._late_interaction_score(encoded_query, phrase)

            phrase_ids.append(phrase_id)
            phrases.append(phrase)
            max_scores.append(real_score / max_score)
        return phrase_ids, phrases, max_scores

In [9]:
# Index the relation-name TSV with ColBERT under the "musique_gpt" experiment.
ColBertRetriever.build_index("output/musique_gpt/relations.tsv", root_path="./data/lm_vectors/colbert/musique_gpt", experiment="musique_gpt")



[Jul 01, 13:17:18] #> Note: Output directory ./data/lm_vectors/colbert/musique_gpt/musique_gpt/indexes/nbits_2 already exists


[Jul 01, 13:17:18] #> Will delete 1 files already at ./data/lm_vectors/colbert/musique_gpt/musique_gpt/indexes/nbits_2 in 20 seconds...
#> Starting...
nranks = 1 	 num_gpus = 2 	 device=0
{
    "query_token_id": "[unused0]",
    "doc_token_id": "[unused1]",
    "query_token": "[Q]",
    "doc_token": "[D]",
    "ncells": null,
    "centroid_score_threshold": null,
    "ndocs": null,
    "load_index_with_mmap": false,
    "index_path": null,
    "index_bsize": 64,
    "nbits": 2,
    "kmeans_niters": 20,
    "resume": false,
    "similarity": "cosine",
    "bsize": 64,
    "accumsteps": 1,
    "lr": 1e-5,
    "maxsteps": 400000,
    "save_every": null,
    "warmup": 20000,
    "warmup_bert": null,
    "relu": false,
    "nway": 64,
    "use_ib_negatives": true,
    "reranker": false,
    "distillation_alpha": 1.0,
    "ignore_scores": false,
    "model_name": 



  Iteration 19 (0.27 s, search 0.20 s): objective=36869.1 imbalance=1.414 nsplit=0       
[Jul 01, 13:17:57] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Jul 01, 13:17:58] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[0.039, 0.04, 0.042, 0.038, 0.039, 0.04, 0.038, 0.039, 0.039, 0.04, 0.041, 0.038, 0.039, 0.04, 0.039, 0.045, 0.037, 0.037, 0.039, 0.038, 0.041, 0.042, 0.037, 0.04, 0.038, 0.039, 0.037, 0.04, 0.037, 0.038, 0.04, 0.039, 0.041, 0.036, 0.039, 0.036, 0.035, 0.039, 0.039, 0.039, 0.041, 0.039, 0.04, 0.041, 0.039, 0.04, 0.039, 0.042, 0.04, 0.037, 0.041, 0.037, 0.041, 0.04, 0.038, 0.039, 0.042, 0.039, 0.042, 0.04, 0.038, 0.039, 0.04, 0.043, 0.041, 0.041, 0.04, 0.043, 0.036, 0.039, 0.041, 0.039, 0.04, 0.039, 0.039, 0.041, 0.039, 0.039, 0.04, 0.041, 0.039, 0.039, 0.042, 0.04, 0.04, 0.037, 0.039, 0.04, 0.038, 0.039, 0.038, 0.04, 0.038, 0.043, 0.041, 0.043, 0.045, 0.038, 0

0it [00:00, ?it/s]

[Jul 01, 13:18:01] [0] 		 #> Saving chunk 0: 	 22,222 passages and 136,573 embeddings. From #0 onward.
[Jul 01, 13:18:01] [0] 		 #> Checking all files were saved...
[Jul 01, 13:18:01] [0] 		 Found all files!
[Jul 01, 13:18:01] [0] 		 #> Building IVF...
[Jul 01, 13:18:01] [0] 		 #> Loading codes...
[Jul 01, 13:18:01] [0] 		 Sorting codes...
[Jul 01, 13:18:01] [0] 		 Getting unique codes...
[Jul 01, 13:18:01] #> Optimizing IVF to store map from centroids to list of pids..
[Jul 01, 13:18:01] #> Building the emb2pid mapping..
[Jul 01, 13:18:01] len(emb2pid) = 136573
[Jul 01, 13:18:02] #> Saved optimized IVF to ./data/lm_vectors/colbert/musique_gpt/musique_gpt/indexes/nbits_2/ivf.pid.pt
[Jul 01, 13:18:02] [0] 		 #> Saving the indexing metadata to ./data/lm_vectors/colbert/musique_gpt/musique_gpt/indexes/nbits_2/metadata.json ..


1it [00:03,  3.24s/it]
100%|██████████| 1/1 [00:00<00:00, 1906.50it/s]
100%|██████████| 4096/4096 [00:00<00:00, 134641.64it/s]


#> Joined...


In [6]:
# Open the relation index for searching; uses the same root and experiment names as the build_index call.
retriever = ColBertRetriever(root="./data/lm_vectors/colbert/musique_gpt", experiment="musique_gpt")

[Jul 03, 08:07:33] #> Loading collection...
0M 
[Jul 03, 08:07:44] #> Loading codec...
[Jul 03, 08:07:44] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Jul 03, 08:07:45] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Jul 03, 08:07:45] #> Loading IVF...
[Jul 03, 08:07:45] #> Loading doclens...


100%|██████████| 1/1 [00:00<00:00, 1170.29it/s]

[Jul 03, 08:07:45] #> Loading codes and residuals...



100%|██████████| 1/1 [00:00<00:00, 48.28it/s]


In [15]:
def get_similar_relations(query: str, top_n: int = 10, verbose: bool = True):
    """Search the relation index for ``query`` and return the 100 best hits.

    Args:
        query: Relation phrase to look up.
        top_n: Number of hits to print when ``verbose`` is true.
        verbose: When true, print ``"<relation>, <score>"`` for the first
            ``top_n`` hits, with each score divided by the query's self-match
            score.

    Returns:
        ``(relation_names, scores)``: the hit texts and their raw search
        scores, best first. Later cells unpack this pair.
    """
    rel_ids, _, scores = retriever.searcher.search(query, k=100)
    relation_names = [retriever.searcher.collection.data[rel_id] for rel_id in rel_ids]
    if verbose:
        max_score = retriever.get_colbert_max_score(query)
        print("\n".join(f"{n}, {s/max_score:.4f}" for n, s in zip(relation_names[:top_n], scores[:top_n])))
    return relation_names, scores

In [17]:
# Probe the ColBERT relation index with a few hand-picked relation phrases.
for query in [
    "birthplace", "immediately north of", "is located",
    "named after", "is capital of", "the headquarters of",
]:
    get_similar_relations(query)

birthplace of, 0.9035
has its birthplace in, 0.8115
tito s birthplace in, 0.7541
birth place, 0.7286
born in, 0.7140
born at, 0.7097
birth name was, 0.7064
places birth of, 0.7010
is the place of birth of, 0.7005
place of birth is, 0.6983
directly north of, 0.7827
is to the north of, 0.7493
located north of, 0.7461
located just north of, 0.7450
located about north of, 0.7413
is situated approximately north of, 0.7408
is located north of, 0.7350
is located just north of, 0.7334
located to the north of, 0.7323
extends north of, 0.7275
is located, 0.9926
is located in, 0.9761
is located on, 0.9631
is located at, 0.9631
is located approximately, 0.9580
is located on the, 0.9563
is located about, 0.9557
is located within, 0.9523
is located around, 0.9500
is located at a distance of, 0.9495
named after, 0.9912
is named after, 0.9646
are named after, 0.9586
was named after, 0.9440
named it after, 0.9429
named after the song, 0.9021
camp david is named after, 0.8749
renamed after, 0.8418
renam

In [19]:
# Probe the ColBERT relation index with the full list of relation phrases,
# printing an upper-cased banner before each result block.
for query in [
    "birthplace", "immediately north of", "is located",
    "named after", "is capital of", "the headquarters of",
    "an example of", "a part of", "follow", "an instance of",
    "named commander in chief", "a member of", "the cast member of",
    'influenced', 'occurred in', 'caused', 'helped',
    'the symbol of', 'the owner of', 'went to',
    'about', 'talk about', 'wrote a book about',
]:
    print(query.upper(), "======")
    get_similar_relations(query)
    

birthplace of, 0.9035
has its birthplace in, 0.8115
tito s birthplace in, 0.7541
birth place, 0.7286
born in, 0.7140
born at, 0.7097
birth name was, 0.7064
places birth of, 0.7010
is the place of birth of, 0.7005
place of birth is, 0.6983
directly north of, 0.7827
is to the north of, 0.7493
located north of, 0.7461
located just north of, 0.7450
located about north of, 0.7413
is situated approximately north of, 0.7408
is located north of, 0.7350
is located just north of, 0.7334
located to the north of, 0.7323
extends north of, 0.7275
is located, 0.9926
is located in, 0.9761
is located on, 0.9631
is located at, 0.9631
is located approximately, 0.9580
is located on the, 0.9563
is located about, 0.9557
is located within, 0.9523
is located around, 0.9500
is located at a distance of, 0.9495
named after, 0.9912
is named after, 0.9646
are named after, 0.9586
was named after, 0.9440
named it after, 0.9429
named after the song, 0.9021
camp david is named after, 0.8749
renamed after, 0.8418
renam

In [16]:


def get_similar_relations_v2(query: str):
    """Print the top hits for ``query`` that share no content word with it.

    The 100 best hits are fetched from the ColBERT relation index, and every
    hit that has a stemmed content word in common with the query is dropped.
    The first ten survivors are printed as ``<original rank>\\t <relation>\\t <score>``,
    where the rank is 1-based and the score is divided by the query's
    self-match score.
    """
    hit_ids, _, hit_scores = retriever.searcher.search(query, k=100)
    hit_names = [retriever.searcher.collection.data[hit_id] for hit_id in hit_ids]
    self_score = retriever.get_colbert_max_score(query)
    query_stems = extract_content_words(query)

    # (0-based original rank, name, score) for every hit without lexical overlap.
    novel_hits = [
        (rank, name, score)
        for rank, (name, score) in enumerate(zip(hit_names, hit_scores))
        if not contains_content_word(name, query_stems)
    ]

    lines = [
        f"{rank+1}\t {name}\t {score/self_score:.4f}"
        for rank, name, score in novel_hits[:10]
    ]
    print("\n".join(lines))

In [17]:
# Show the hits for 'is located' that share no content word with the query.
get_similar_relations_v2('is located')

48	 is situated in	 0.8808
66	 is situated approximately north of	 0.8627
74	 is situated near	 0.8576
86	 is where	 0.8513
89	 is situated on	 0.8491


In [18]:
# Run the content-word-filtered lookup for the full list of relation phrases,
# printing an upper-cased banner before each result block.
for query in [
    "birthplace", "immediately north of", "is located",
    "named after", "is capital of", "the headquarters of",
    "an example of", "a part of", "follow", "an instance of",
    "named commander in chief", "a member of", "the cast member of",
    'influenced', 'occurred in', 'caused', 'helped',
    'the symbol of', 'the owner of', 'went to',
    'about', 'talk about', 'wrote a book about',
]:
    print(query.upper(), "======")
    get_similar_relations_v2(query)
    

4	 birth place	 0.7286
5	 born in	 0.7140
6	 born at	 0.7097
7	 birth name was	 0.7064
8	 places birth of	 0.7010
9	 is the place of birth of	 0.7005
10	 place of birth is	 0.6983
11	 born on	 0.6945
12	 birth name	 0.6945
13	 birth date	 0.6940
43	 located at the northern end of	 0.6533
47	 is to the northwest of	 0.6443
60	 is located northwest of	 0.6194
61	 located to the northwest of	 0.6188
62	 located northwest of	 0.6178
69	 is located west northwest of	 0.6082
70	 located northeast of	 0.6024
71	 located to the northwest west of	 0.6013
72	 are located northeast of	 0.5997
73	 is located west of	 0.5944
48	 is situated in	 0.8808
66	 is situated approximately north of	 0.8627
74	 is situated near	 0.8576
86	 is where	 0.8513
89	 is situated on	 0.8491
8	 renamed after	 0.8418
9	 renamed itself after	 0.8211
10	 founded after	 0.7853
13	 established after	 0.7592
14	 born after	 0.7576
19	 termed after	 0.7304
20	 dubbed after	 0.7266
21	 after	 0.7233
22	 started after	 0.7201

In [20]:
# NOTE(review): the unpacking assumes get_similar_relations returns a
# (relation_names, scores) pair — confirm against its current definition.
relation_names, scores = get_similar_relations('immediately north of')
relation_names[:10]

['directly north of',
 'is to the north of',
 'located north of',
 'located just north of',
 'located about north of',
 'is situated approximately north of',
 'is located north of',
 'is located just north of',
 'located to the north of',
 'extends north of']

In [21]:
# NOTE(review): the unpacking assumes get_similar_relations returns a
# (relation_names, scores) pair — confirm against its current definition.
relation_names, scores = get_similar_relations('is located')
relation_names[:10]

['is located',
 'is located in',
 'is located on',
 'is located at',
 'is located approximately',
 'is located on the',
 'is located about',
 'is located within',
 'is located around',
 'is located at a distance of']

In [22]:
# NOTE(review): the unpacking assumes get_similar_relations returns a
# (relation_names, scores) pair — confirm against its current definition.
relation_names, scores = get_similar_relations('named after')
relation_names[:10]

['named after',
 'is named after',
 'are named after',
 'was named after',
 'named it after',
 'named after the song',
 'camp david is named after',
 'renamed after',
 'renamed itself after',
 'founded after']

In [23]:
# NOTE(review): the unpacking assumes get_similar_relations returns a
# (relation_names, scores) pair — confirm against its current definition.
# The query string below ends with a trailing space.
relation_names, scores = get_similar_relations('the capital of ')
relation_names[:10]

['capital of',
 'is the capital of',
 'is capital of',
 'is the capital city of',
 'declared as capital of',
 'current capital of',
 'capital',
 'capital and largest city is',
 'capital in',
 'capital is']

In [25]:
# NOTE(review): the unpacking assumes get_similar_relations returns a
# (relation_names, scores) pair — confirm against its current definition.
relation_names, scores = get_similar_relations('the headquarters of')
relation_names[:10]

['headquarters of',
 'is headquarters of',
 'headquarters at',
 'headquarters in',
 'headquarters is',
 'headquarters',
 'headquarters based in',
 'headquarters on',
 'headquarters located in',
 'headquarters located at']

In [47]:
# Print the ten best ColBERT hits for each phrase on one line.
# NOTE(review): requires a get_similar_relations that returns (relation_names, scores).
for query in [
    'an example of',
    'a part of',
    'follow',
    'an instance of',
    'named commander in chief',
    'a member of',
    'the cast member of',
]:
    relation_names, scores = get_similar_relations(query)
    print(query, relation_names[:10])

an example of ['is an example of', 'an example of', 'example of', 'was an example of', 'such as', 'requires proof of age such as', 'includes frazioni such as', 'is surrounded by soil associations such as', 'divided into major branches such as', 'is surrounded by rock formations such as']
a part of ['part of', 'is a part of', 'is part of', 'forms a part of', 'is part of the scope of', 'are part of', 'form part of', 'considered as part of', 'forms part of', 'part of scheme in']
follow ['follow', 'follows', 'followed', 'followed in', 'followed from', 'is followed by', 'follows course to', 'follows up with', 'followed by', 'follow up to']
an instance of ['is an example of', 'an example of', 'has instances of', 'example of', 'was an example of', 'such as', 'is observer of', 'observed', 'attested in', 'attested by']
named commander in chief ['commander in chief of', 'is commander in chief of', 'commander of', 'commander in', 'commander', 'chief of', 'had commander', 'was commander of', 'beca

In [58]:
# Print the ten best ColBERT hits for each phrase on one line.
# NOTE(review): requires a get_similar_relations that returns (relation_names, scores).
for query in [
    'influenced',
    'occurred in',
    'caused',
    'helped',
    'the symbol of',
    'the owner of',
    'went to',
]:
    relation_names, scores = get_similar_relations(query)
    print(query, relation_names[:10])

influenced ['influenced', 'influenced by', 'influences include', 'influences', 'inspired', 'was influenced by', 'inspired by', 'inspired from', 'combined influences from', 'influenced the demographics of']
occurred in ['occurred in', 'occurred at', 'occurred within', 'occurred during', 'occurred on', 'occurred from', 'occurred around', 'has occurred in', 'occurred', 'occurred during the term of']
caused ['caused', 'caused by', 'cause of', 'causes', 'cause', 'are caused by', 'was caused by', 'caused damage of', 'caused major damage in', 'cause of death']
helped ['helped', 'helped by', 'helps', 'helped found', 'helped pressure', 'helped develop', 'helped form', 'helped to develop', 'helps bring together', 'helped bring']
the symbol of ['symbol of', 'is symbol of', 'was a symbol of', 'symbol is', 'symbolises', 'symbolizes', 'symbolize', 'symbolized by', 'signifies', 'denoted by']
the owner of ['owner of', 'is the owner of', 'is owner of', 'owner', 'proprietor of', 'owner is', 'owns', 'fou

In [36]:
# Print the ten best ColBERT hits for each "about"-style phrase on one line.
# NOTE(review): requires a get_similar_relations that returns (relation_names, scores).
for query in [
    'about',
    'talk about',
    'wrote a book about',
]:
    relation_names, scores = get_similar_relations(query)
    print(query, relation_names[:10])

about ['is about', 'about', 'are about', 'regarding', 'and', 'is located about', 'has information about', 'exists about', 'defines', 'meaning']
talk about ['said in interview about', 'talks to', 'discussed about', 'discuss', 'speaking', 'argues about', 'discussing', 'discusses', 'talked with', 'delivered speech about']
wrote a book about ['writes novels about', 'wrote book on', 'wrote books on', 'is a book written by', 'is a book about', 'wrote about', 'is a novel written by', 'has written books on', 'writes about', 'written about']


## Contriever

In [7]:
class Contriever:
    """Dense retriever using a Hugging Face encoder with mean pooling.

    Documents are embedded once by :meth:`index`. A query is then scored
    against every document by the dot product of L2-normalised embeddings.
    """

    def __init__(self, retrieval_model_name, device: str = None) -> None:
        """Load the encoder and tokenizer from ``retrieval_model_name``.

        Args:
            retrieval_model_name: Model name or local path accepted by
                ``from_pretrained``.
            device: Torch device for the encoder; defaults to ``'cuda'`` when
                available, otherwise ``'cpu'``.
        """
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.retrieval_model = AutoModel.from_pretrained(retrieval_model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
        # Populated by index().
        self.docs = []
        self.doc_embedding_mat = None

    def index(self, corpus: List[str], batch_size: int = 256):
        """Embed ``corpus`` in batches and store the embeddings on the CPU.

        Sets ``self.docs`` to ``corpus`` and ``self.doc_embedding_mat`` to a
        tensor with one normalised embedding row per document.

        Raises:
            ValueError: If ``corpus`` is empty or ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if len(corpus) == 0:
            raise ValueError("corpus must not be empty")

        all_embeddings = []
        # range() has a length, so tqdm derives the exact number of batches itself
        # (a floor-divided total under-counts whenever there is a partial last batch).
        for start in tqdm(range(0, len(corpus), batch_size)):
            batch = corpus[start:start + batch_size]
            all_embeddings.append(self.get_embedding_with_mean_pooling(batch).cpu())

        self.docs = corpus
        self.doc_embedding_mat = torch.cat(all_embeddings, dim=0)

    def _mean_pooling(self, token_embeddings, mask):
        """Average the token embeddings over the positions where ``mask`` is non-zero."""
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
        return sentence_embeddings

    def get_embedding_with_mean_pooling(self, input_str):
        """Return L2-normalised, mean-pooled embeddings for a string or list of strings.

        The result has one row per input text and stays on the encoder's device.
        """
        with torch.no_grad():
            encoding = self.tokenizer(input_str, return_tensors='pt', padding=True, truncation=True)
            input_ids = encoding['input_ids'].to(self.device)
            attention_mask = encoding['attention_mask'].to(self.device)
            outputs = self.retrieval_model(input_ids, attention_mask=attention_mask)
            embeddings = self._mean_pooling(outputs[0], attention_mask)
            # Divide every row by its own L2 norm.
            return embeddings / torch.linalg.norm(embeddings, dim=1, keepdim=True)

    def retrieve_one_query(self, query: str) -> Tuple[List[float], List[str]]:
        """Rank every indexed document against ``query``.

        Returns:
            ``(sorted_scores, sorted_docs)``: two parallel lists covering the
            whole corpus, highest score first.

        Raises:
            RuntimeError: If :meth:`index` has not been called yet.
        """
        doc_matrix = getattr(self, 'doc_embedding_mat', None)
        if doc_matrix is None:
            raise RuntimeError("index() must be called before retrieve_one_query()")

        query_embedding = self.get_embedding_with_mean_pooling(query).cpu().numpy()
        # (n_docs, dim) x (dim, 1) -> one score per document.
        query_doc_scores = np.dot(np.asarray(doc_matrix), query_embedding.T)[:, 0]
        sorted_indices = np.argsort(query_doc_scores)[::-1]
        sorted_scores = list(query_doc_scores[sorted_indices])
        sorted_docs = [self.docs[idx] for idx in sorted_indices]

        return sorted_scores, sorted_docs

In [8]:
# Reload the relation names from the TSV. With header=None the columns are
# labelled by position, so column 1 holds the relation text (see the preview).
# NOTE(review): this rebinds `all_relations`, which an earlier cell also assigns.
all_relations = pd.read_csv("./output/musique_gpt/relations.tsv", sep='\t', header=None)
all_relations.head()

Unnamed: 0,0,1
0,0,forced to unite to save
1,1,walks out on
2,2,entire forest was divided between
3,3,is a neighbourhood in
4,4,organizes church services inspired by


In [9]:
# Extract the relation texts (column 1) as a plain list and show how many there are.
all_relations_names = all_relations[1].tolist()
len(all_relations_names)

22222

In [17]:
# List the local Contriever model directory.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
!ls /data/qwj/model/contriever

config.json  special_tokens_map.json  tokenizer.json
README.md    tokenizer_config.json    vocab.txt


In [10]:
# Load the Contriever encoder from the local directory and embed every relation name.
contriever = Contriever("/data/qwj/model/contriever")
contriever.index(all_relations_names)

87it [00:03, 22.16it/s]                        


In [11]:
# Sanity check: the embedding matrix should have one row per indexed relation.
contriever.doc_embedding_mat.shape

torch.Size([22222, 768])

In [11]:
def get_similar_relations_v3(query: str, retriever):
    """Print the hits of ``retriever`` for ``query``, unfiltered and then filtered.

    Output, in order:
      1. a banner line containing the query;
      2. the ten best hits as ``<relation>, <score>``;
      3. a separator line;
      4. up to ten hits that share no stemmed content word with the query,
         as ``<1-based original rank>\\t <relation>\\t <score>``.

    ``retriever`` must provide ``retrieve_one_query(query)`` returning
    ``(scores, relation_names)`` sorted best first. Scanning stops as soon as
    the tenth filtered hit has been printed.
    """
    print(query, "=============")
    scores, relation_names = retriever.retrieve_one_query(query)
    query_words_l = extract_content_words(query)

    top_lines = [f"{n}, {s:.4f}" for n, s in zip(relation_names[:10], scores[:10])]
    print("\n".join(top_lines))
    print("\t=============")

    shown = 0
    for rank, (name, score) in enumerate(zip(relation_names, scores), start=1):
        if contains_content_word(name, query_words_l):
            continue
        print(f"{rank}\t {name}\t {score:.4f}")
        shown += 1
        if shown >= 10:
            return

In [17]:
# Run the Contriever lookup (unfiltered and content-word-filtered) for the full
# list of relation phrases, printing an upper-cased banner before each block.
for query in [
    "birthplace", "immediately north of", "is located",
    "named after", "is capital of", "the headquarters of",
    "an example of", "a part of", "follow", "an instance of",
    "named commander in chief", "a member of", "the cast member of",
    'influenced', 'occurred in', 'caused', 'helped',
    'the symbol of', 'the owner of', 'went to',
    'about', 'talk about', 'wrote a book about',
]:
    print(query.upper(), "======")
    get_similar_relations_v3(query, contriever)

birthplace of, 0.9312
is hometown of, 0.8237
has its birthplace in, 0.7746
home of, 0.7614
was location of, 0.7122
is home of, 0.7107
is seat of, 0.7100
is the home of, 0.7010
was the location of, 0.7005
home to, 0.6925
2	 is hometown of	 0.8237
4	 home of	 0.7614
5	 was location of	 0.7122
6	 is home of	 0.7107
7	 is seat of	 0.7100
8	 is the home of	 0.7010
9	 was the location of	 0.7005
10	 home to	 0.6925
11	 is the seat of	 0.6897
12	 is location of	 0.6865
directly north of, 0.8527
north of, 0.7335
occupies the land directly east of, 0.6898
located north of, 0.6828
south of, 0.6739
is located north of, 0.6677
is north of, 0.6645
located north northwest of, 0.6544
west of, 0.6499
located west of, 0.6428
3	 occupies the land directly east of	 0.6898
5	 south of	 0.6739
9	 west of	 0.6499
10	 located west of	 0.6428
12	 located south of	 0.6426
13	 to the west of	 0.6423
16	 is located west of	 0.6371
17	 located east of	 0.6344
18	 is south of	 0.6331
19	 northwest of	 0.6331
is lo

In [26]:
# Print the ten best Contriever hits for each phrase on one line.
for query in [
    "birthplace", "immediately north of", "is located",
    "named after", "is capital of", "the headquarters of",
]:
    sorted_scores, sorted_docs = contriever.retrieve_one_query(query)
    print(query, sorted_docs[:10])

birthplace ['birthplace of', 'is hometown of', 'has its birthplace in', 'home of', 'was location of', 'is home of', 'is seat of', 'is the home of', 'was the location of', 'home to']
immediately north of ['directly north of', 'north of', 'occupies the land directly east of', 'located north of', 'south of', 'is located north of', 'is north of', 'located north northwest of', 'west of', 'located west of']
is located ['is located', 'is located in', 'located', 'located in', 'is located on', 'is situated', 'is located near', 'is located between', 'located on', 'is located on the']
named after ['named after', 'is named after', 'was named after', 'is named', 'renamed after', 'is named for', 'named for', 'is named in', 'named in honor of', 'was named in honor of']
is capital of ['is capital of', 'is the capital of', 'is the capital city of', 'is largest city of', 'serves as capital of', 'is the largest city of', 'capital of', 'was capital of', 'is administrative capital of', 'is a city of']
the 

In [27]:
# Print the ten best Contriever hits for each phrase on one line.
for query in [
    "an example of", "a part of", "follow", "an instance of",
    "named commander in chief", "a member of", "the cast member of",
]:
    sorted_scores, sorted_docs = contriever.retrieve_one_query(query)
    print(query, sorted_docs[:10])

an example of ['an example of', 'example of', 'is an example of', 'was an example of', 'illustrates', 'is a manifestation of', 'exemplified by', 'is a form of', 'are of', 'product of']
a part of ['part of', 'is a part of', 'was a part of', 'are part of', 'is part of', 'became a part of', 'was part of', 'forms a part of', 'were part of', 'has part']
follow ['follow', 'required to follow', 'follow up to', 'expected to follow', 'had following on', 'up', 'followed suit', 'follow up game to', 'follows', 'failed to follow through on']
an instance of ['an example of', 'example of', 'is an example of', 'was an example of', 'has instances of', 'is a manifestation of', 'is a form of', 'was a result of', 'result of', 'is typical of']
named commander in chief ['commander in chief of', 'commander in', 'assumed command of', 'commander of', 'is commander in chief of', 'became commander of', 'placed in command of', 'commander', 'rose to command of', 'was commander of']
a member of ['member of', 'membe

In [28]:
# Print the ten best Contriever hits for each phrase on one line.
for query in [
    'influenced',
    'occurred in',
    'caused',
    'helped',
    'the symbol of',
    'the owner of',
    'went to',
]:
    sorted_scores, sorted_docs = contriever.retrieve_one_query(query)
    print(query, sorted_docs[:10])

influenced ['influenced', 'influenced by', 'was influenced by', 'influences', 'inspired', 'influential on', 'reflected the influence of', 'influential in', 'had influence in', 'influences include']
occurred in ['occurred in', 'occurred', 'occurred during', 'occurred between', 'occurred after', 'occurred near', 'occurred on', 'occurred at', 'occurred from', 'occurred around']
caused ['caused', 'was caused by', 'caused by', 'caused damage of', 'cause', 'caused the death of', 'affected', 'occurred', 'prevented', 'resulted from']
helped ['helped', 'helped form', 'helped found', 'helped bring', 'helped establish', 'helped develop', 'helped to develop', 'helped to produce', 'aided in', 'helped lead']
the symbol of ['symbol of', 'is symbol of', 'was a symbol of', 'symbolize', 'symbolizes', 'symbolized by', 'has symbol', 'symbolises', 'symbol is', 'logo symbolized']
the owner of ['owner of', 'owner', 'is the owner of', 'is owner of', 'owner is', 'has owner', 'owns', 'became owner of', 'founder

In [32]:
# Print the ten best Contriever hits for each "about"-style phrase on one line.
for query in [
    'about',
    'talk about',
    'wrote a book about',
]:
    sorted_scores, sorted_docs = contriever.retrieve_one_query(query)
    print(query, sorted_docs[:10])

about ['about', 'is about', 'are about', 'discussed about', 'read about', 'found out about', 'heard about', 'study about', 'story about', 'wrote about']
talk about ['about', 'discussed about', 'discussing', 'discuss', 'read about', 'is about', 'heard about', 'are about', 'talks to', 'discussion of']
wrote a book about ['wrote book on', 'wrote biography of', 'wrote about', 'wrote books on', 'writes novels about', 'written about', 'has written about', 'is a book about', 'made documentary about', 'writes about']


## BERT with [CLS] pooling


In [6]:
class BERTCLS:
    """Dense retriever that ranks documents by cosine similarity of [CLS] embeddings.

    Every text is encoded with a Hugging Face encoder, the hidden state of the
    first token ([CLS]) is taken as the text embedding and L2-normalised, so the
    dot product between a query and a document embedding is their cosine similarity.

    Attributes set by ``index``:
        docs: the indexed corpus (same list object that was passed in).
        doc_embedding_mat: CPU tensor of shape ``(len(corpus), hidden_size)``.
    """

    # Number of documents encoded per forward pass in ``index``.
    INDEX_BATCH_SIZE = 256

    def __init__(self, retrieval_model_name, device=None) -> None:
        """Load the encoder and its tokenizer.

        Args:
            retrieval_model_name: model name or local path understood by
                ``AutoModel.from_pretrained`` / ``AutoTokenizer.from_pretrained``.
            device: device to run the encoder on. Defaults to ``'cuda'`` when a
                GPU is available and to ``'cpu'`` otherwise.
        """
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.retrieval_model = AutoModel.from_pretrained(retrieval_model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)

    def _embed(self, texts):
        """Return L2-normalised [CLS] embeddings for a string or a list of strings.

        The result has one row per input text and lives on ``self.device``.
        """
        with torch.no_grad():
            # NOTE: truncation=True only takes effect when the tokenizer knows a
            # maximum length; otherwise the tokenizer warns and does not truncate.
            encoding = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
            input_ids = encoding['input_ids'].to(self.device)
            attention_mask = encoding['attention_mask'].to(self.device)
            outputs = self.retrieval_model(input_ids, attention_mask=attention_mask)
            embeddings = self._cls_pooling(outputs[0], attention_mask)
            # Row-wise L2 normalisation so dot products are cosine similarities.
            return embeddings / torch.linalg.norm(embeddings, dim=1, keepdim=True)

    def index(self, corpus: List[str]):
        """Encode ``corpus`` in batches and store the embeddings for retrieval.

        Args:
            corpus: documents to index. Must be non-empty, because concatenating
                an empty list of batches fails in ``torch.cat``.
        """
        batch_size = self.INDEX_BATCH_SIZE
        all_embeddings = []

        # No explicit ``total``: tqdm derives it from len(range(...)), which counts
        # the final partial batch (a floor division would under-count by one).
        for start in tqdm(range(0, len(corpus), batch_size)):
            batch = corpus[start:start + batch_size]
            # Move each batch to CPU right away to keep accelerator memory bounded.
            all_embeddings.append(self._embed(batch).cpu())

        self.docs = corpus
        self.doc_embedding_mat = torch.cat(all_embeddings, dim=0)

    def _cls_pooling(self, token_embeddings, mask):
        """Select the first-token ([CLS]) vector of every sequence; ``mask`` is unused."""
        return token_embeddings[:, 0, :]

    def get_embedding_with_mean_pooling(self, input_str):
        """Return the normalised embedding of ``input_str`` as a ``(1, hidden)`` tensor.

        Despite its name this method uses [CLS] pooling, not mean pooling; the
        name is kept so that existing callers keep working.
        """
        return self._embed(input_str)

    def retrieve_one_query(self, query: str) -> Tuple[List[float], List[str]]:
        """Rank all indexed documents against ``query``.

        Returns:
            ``(sorted_scores, sorted_docs)``: two parallel lists ordered from the
            highest to the lowest cosine similarity. ``index`` must have been
            called beforehand.
        """
        query_embedding = self.get_embedding_with_mean_pooling(query).cpu().numpy()
        # Convert explicitly instead of relying on numpy's implicit tensor coercion.
        doc_matrix = np.asarray(self.doc_embedding_mat)
        query_doc_scores = np.dot(doc_matrix, query_embedding.T)
        query_doc_scores = query_doc_scores.T[0]
        sorted_indices = np.argsort(query_doc_scores)[::-1]
        sorted_scores = [query_doc_scores[idx] for idx in sorted_indices]
        sorted_docs = [self.docs[idx] for idx in sorted_indices]

        return sorted_scores, sorted_docs

In [7]:
# Local checkpoint directory of the BERT model used for the [CLS]-embedding retriever.
bert_model_path = '/data/MODELS/bert-base-uncased'
bert_sf = BERTCLS(bert_model_path)

In [8]:
# Read the relation table (tab-separated, no header row) and keep the values of
# column 1 as a plain list; the cell displays how many entries were loaded.
relations_path = "./output/musique_gpt/relations.tsv"
all_relations = pd.read_csv(relations_path, sep='\t', header=None)
relation_column = all_relations[1]
all_relations_names = relation_column.to_list()
len(all_relations_names)

22222

In [9]:
# Encode every relation name with the BERT retriever, then display the shape of
# the embedding matrix it stored.
bert_sf.index(corpus=all_relations_names)
relation_embeddings = bert_sf.doc_embedding_mat
relation_embeddings.shape

  0%|          | 0/86 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
87it [00:03, 22.51it/s]                        


torch.Size([22222, 768])

In [12]:
# Run `get_similar_relations_v3` (defined in an earlier cell) for every probe
# phrase against the BERT retriever, printing an upper-cased banner per phrase.
all_probe_queries = [
    "birthplace",
    "immediately north of",
    "is located",
    "named after",
    "is capital of",
    "the headquarters of",
    "an example of",
    "a part of",
    "follow",
    "an instance of",
    "named commander in chief",
    "a member of",
    "the cast member of",
    'influenced', 'occurred in', 'caused', 'helped', 'the symbol of', 'the owner of', 'went to',
    'about', 'talk about', 'wrote a book about',
]
for query in all_probe_queries:
    banner = query.upper()
    print(banner, "======")
    get_similar_relations_v3(query, bert_sf)

residence, 0.9544
inhabited, 0.9501
place, 0.9467
preserve, 0.9467
teammate, 0.9467
inherited, 0.9465
successor to, 0.9455
governed, 0.9455
organized, 0.9448
entertained, 0.9448
1	 residence	 0.9544
2	 inhabited	 0.9501
3	 place	 0.9467
4	 preserve	 0.9467
5	 teammate	 0.9467
6	 inherited	 0.9465
7	 successor to	 0.9455
8	 governed	 0.9455
9	 organized	 0.9448
10	 entertained	 0.9448
directly north of, 0.9925
to the east of, 0.9815
located to the west of, 0.9787
to the west of, 0.9787
located north east of, 0.9783
located to the northwest of, 0.9777
located north west of, 0.9776
located about north of, 0.9770
located north of, 0.9769
located to the east of, 0.9769
2	 to the east of	 0.9815
3	 located to the west of	 0.9787
4	 to the west of	 0.9787
6	 located to the northwest of	 0.9777
10	 located to the east of	 0.9769
11	 located south west of	 0.9765
13	 located to the southeast of	 0.9759
14	 located to the northwest west of	 0.9754
15	 located west of	 0.9754
16	 located southerl

In [12]:
# Location / naming phrases: print the first ten relations of the ranking that
# the BERT [CLS] retriever returns for each one.
location_queries = [
    "birthplace",
    "immediately north of",
    "is located",
    "named after",
    "is capital of",
    "the headquarters of",
]
for query in location_queries:
    sorted_scores, sorted_docs = bert_sf.retrieve_one_query(query)
    top_docs = sorted_docs[:10]
    print(query, top_docs)

birthplace ['residence', 'inhabited', 'place', 'preserve', 'teammate', 'inherited', 'successor to', 'governed', 'organized', 'entertained']
immediately north of ['directly north of', 'to the east of', 'located to the west of', 'to the west of', 'located north east of', 'located to the northwest of', 'located north west of', 'located about north of', 'located north of', 'located to the east of']
is located ['is located', 'is situated', 'is located on', 'is located at', 'is located over', 'is situated on', 'is located about', 'is located in', 'is located within', 'was located on']
named after ['named after', 'named for', 'named after the song', 'named in honor of', 'is named in', 'renamed in honor of', 'was named on', 'was named in', 'is named by', 'was named for']
is capital of ['is capital of', 'is principal city of', 'is the capital of', 'is the capital city of', 'is the center of', 'was the capital of', 'is a principal city of', 'is administrative capital of', 'is administrative cent

In [13]:
# "about"-style phrases: print the first ten relations of the ranking that the
# BERT [CLS] retriever returns for each one.
about_queries = ['about', 'talk about', 'wrote a book about']
for query in about_queries:
    sorted_scores, sorted_docs = bert_sf.retrieve_one_query(query)
    top_docs = sorted_docs[:10]
    print(query, top_docs)

about ['about', 'of', 'regarding', 'describes', 'discusses', 'describe', 'is', 'describing', 'experiencing', 'boasts']
talk about ['accounted for', 'downplayed', 'focus on', 'emphasis on', 'promised friendship to', 'business was', 'settled with', 'vote for', 'time when', 'set up by']
wrote a book about ['wrote book on', 'wrote books on', 'has a book called', 'conducted a study on', 'was a pioneer in', 'enjoyed a career in', 'pursued a career in', 'is an author of', 'wrote about', 'is the author of']


In [14]:
# Membership / instance phrases: print the first ten relations of the ranking
# that the BERT [CLS] retriever returns for each one.
membership_queries = [
    "an example of",
    "a part of",
    "follow",
    "an instance of",
    "named commander in chief",
    "a member of",
    "the cast member of",
]
for query in membership_queries:
    sorted_scores, sorted_docs = bert_sf.retrieve_one_query(query)
    top_docs = sorted_docs[:10]
    print(query, top_docs)

an example of ['an example of', 'example of', 'exemplified by', 'is an example of', 'main principle summarized by', 'given additional characteristics by', 'discussed by', 'equivalent to', 'visualized by', 'metaphor for']
a part of ['not part of', 'now part of', 'completely changed', 'broke apart with', 'no longer part of', 'filled by', 'erased from memories of', 'discarded into', 'broke apart in', 'shattered by']
follow ['follow', 'join', 'allow', 'take', 'followed', 'read', 'pursue', 'argue', 'put', 'use']
named commander in chief ['named president of', 'became commander of', 'appointed president of', 'placed in command of', 'served as commanding officer of', 'was commander of', 'named as president of', 'appointed manager of', 'served as general officer commanding', 'appointed to serve']
a member of ['member of', 'was a member of', 'not a member of', 'selected as a member of', 'editor of', 'life member of', 'publisher of', 'a branch of', 'researcher at', 'former fellow at']
the cast m

In [15]:
# Verb-like phrases: print the first ten relations of the ranking that the
# BERT [CLS] retriever returns for each one.
verb_queries = ['influenced', 'occurred in', 'caused', 'helped', 'the symbol of', 'the owner of', 'went to']
for query in verb_queries:
    sorted_scores, sorted_docs = bert_sf.retrieve_one_query(query)
    top_docs = sorted_docs[:10]
    print(query, top_docs)

influenced ['influenced', 'inspired', 'nicknamed', 'dominating', 'dominated', 'affect', 'resembled', 'headed', 'sparked', 'deadline']
occurred in ['occurred in', 'occurred during', 'occurred on', 'occurred from', 'occurred within', 'occurred after', 'occurred at', 'occurred near', 'flourished in', 'originated at']
caused ['caused', 'associated', 'causes', 'triggered', 'preventing', 'planted', 'encouraged', 'treated', 'suggested', 'prevents']
helped ['helped', 'helped pressure', 'needed', 'enjoyed', 'respected', 'ravaged', 'resented', 'brought', 'resembled', 'connected']
the symbol of ['at the center of', 'bound for', 'a branch of', 'framed by', 'hub for', 'centered on', 'important for', 'in the center of', 'visible in', 'banner was']
the owner of ['publisher of', 'leader of', 'head of', 'one of the scribes of', 'not a member of', 'target of', 'is the owner of', 'called the mother of', 'editor of', 'obliged to']
went to ['went to', 'goes to', 'traveled to', 'travelled to', 'moved to', '

: 