In [30]:
import json
import pathlib
import inflect

import numpy as np
import networkx as nx

from collections import defaultdict
from semantic_memory import taxonomy
from transformers import AutoTokenizer

In [31]:
# Target model: PaliGemma. Only its tokenizer is loaded here; it is used below
# to test whether a word is represented by a single token in the vocabulary.

tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-mix-224")

def check_in_vocab(word):
    """Return True if `word` tokenizes to exactly one token.

    The word is tokenized twice with the module-level `tokenizer`: once with a
    leading space and once as-is. A single-token result for either form counts
    as "in vocabulary".

    Args:
        word: the string to look up.

    Returns:
        bool: True when at least one of the two forms yields exactly one token.
    """
    # Same order as before: space-prefixed form first, bare form second.
    surface_forms = (f" {word}", word)
    token_counts = [
        len(tokenizer.tokenize(form, add_special_tokens=False))
        for form in surface_forms
    ]
    return 1 in token_counts
    
# inflect engine; used further down (via p.plural) to generate plural noun forms.
p = inflect.engine()

In [3]:
def read_json(filename):
    """Load a JSON file and return the parsed object.

    Args:
        filename: path to the JSON file (str or path-like).

    Returns:
        The decoded JSON content (dict, list, ...).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    # JSON text is UTF-8; be explicit instead of depending on the platform's
    # locale encoding, which would garble non-ASCII entries on some systems.
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

In [4]:
# noun_hypernyms: iterated below as noun -> sequence of hypernyms.
# final_entities: used below for membership tests when building the tree
# (presumably the curated GQA entity set -- confirm against the data file).
noun_hypernyms = read_json("../data/gqa_entities/noun-hypernyms.json")
final_entities = read_json("../data/gqa_entities/entity_set.json")

In [46]:
# For every concept, collect all hypernym chains observed above it.
hypernym_paths = defaultdict(set)

for noun, hypernyms in noun_hypernyms.items():
    hypernym_paths[noun].add(tuple(hypernyms))
    # each hypernym is the child of the next one, so the remainder of the
    # chain is a valid path for that hypernym too
    for i in range(len(hypernyms) - 1):
        hypernym_paths[hypernyms[i]].add(tuple(hypernyms[i + 1:]))

# store only the longest paths
# Ties on length are broken by comparing the paths themselves. With key=len
# alone, max() returns whichever tied path the set yields first, and set order
# over strings changes between kernel restarts (hash randomization), so the
# selected path was not reproducible.
longest_paths = {}
for noun, paths in hypernym_paths.items():
    longest_paths[noun] = max(paths, key=lambda path: (len(path), path))

In [52]:
# make sure they are in the model
# longest_paths_model = {}
# for noun, path in longest_paths.items():
#     final_path = []
#     if check_in_vocab(noun):
#         for concept in path:
#             if check_in_vocab(concept):
#                 final_path.append(concept)
#         # if len(final_path) >= 1:
#         longest_paths_model[noun] = final_path


# now store the unique hypernym pairs: concept -> its immediate hypernym
# (the first element of its longest path)
hypernym_pairs = {}
for noun, path in longest_paths.items():
    if path:
        hypernym_pairs[noun] = path[0]
    else:
        # No hypernym recorded for this concept: report it and leave it
        # unmapped. An explicit check replaces the former bare `except:`,
        # which would also have swallowed unrelated errors and interrupts.
        print(noun, path)

def get_hypernym_model(word):
    """Return the closest hypernym of `word` that passes `check_in_vocab`.

    Walks upward through `hypernym_pairs` starting at `word`. A concept with
    no entry in `hypernym_pairs` is treated as a direct child of "entity".
    The first hypernym accepted by `check_in_vocab` is returned.

    The walk always terminates: "entity" is returned as a last resort when
    the root itself fails the vocabulary check, or when the chain loops back
    onto a concept already visited. (The earlier recursive version recursed
    without bound in both situations.)

    Args:
        word: concept whose hypernym should be resolved.

    Returns:
        str: an in-vocabulary hypernym, or "entity" as the fallback root.
    """
    root = "entity"
    visited = {word}
    current = word
    while True:
        hypernym = hypernym_pairs.get(current, root)
        if check_in_vocab(hypernym):
            return hypernym
        # Nothing above the root, and no point re-walking a cycle.
        if hypernym == root or hypernym in visited:
            return root
        visited.add(hypernym)
        current = hypernym

# Resolve an in-vocabulary hypernym for every in-vocabulary concept, whether
# it appears in hypernym_pairs as a child or only as a parent.
# The concepts are visited in sorted order: iterating the raw set would make
# the insertion order of final_pairs (and of everything built from it) change
# between kernel restarts, because set order over strings is not stable.
final_pairs = {}
for noun in sorted(set(hypernym_pairs.keys()).union(hypernym_pairs.values())):
    if check_in_vocab(noun):
        final_pairs[noun] = get_hypernym_model(noun)

In [56]:
# Spot-check a single entry of the resolved concept -> hypernym mapping.
final_pairs['monitor']

'device'

In [57]:
# Build the taxonomy with "entity" as its root node.
# NOTE(review): indexing Tree[...] appears to create a node on first access
# (the container exposes a default_factory) -- confirm against taxonomy.Nodeset.
Tree = taxonomy.Nodeset(taxonomy.Node)
root = Tree['entity']

# Populate the tree: link each concept to its hypernym, but only when both
# ends of the pair belong to final_entities.
for concept, hypernym in final_pairs.items():
    if concept not in final_entities or hypernym not in final_entities:
        continue
    node, parent = Tree[concept], Tree[hypernym]
    node.add_parent(parent)
    parent.add_child(node)

# Any node other than the root that still has no parent is a top-level
# concept: attach it directly under the root.
for value, node in Tree.items():
    if value != "entity" and node.parent is None:
        node.add_parent(root)
        root.add_child(node)

# Disable the default factory so that later lookups cannot accidentally add nodes.
Tree.default_factory = None

In [None]:
# Represent the taxonomy as a directed graph with one parent -> child edge per
# node, in the way Park et al. stored it for their paper, but for PaliGemma.
G = nx.DiGraph()
G.add_edges_from(
    (node.parent.value, entry)
    for entry, node in Tree.items()
    # nodes whose path has a single element contribute no edge
    if len(node.path()) > 1
)

# The merging step is skipped: this taxonomy is already pretty small.



In [None]:
# Taken directly from the Park et al. code.
# vocab_set holds the vocabulary's token strings for fast membership tests.
vocab = tokenizer.get_vocab()
vocab_set = set(vocab)

def _noun_to_gemma_vocab_elements(word):
    """Return the vocabulary tokens that correspond to surface forms of `word`.

    Four candidates are built from the lower-cased word: singular, capitalized
    singular, plural (from the inflect engine `p`) and capitalized plural.
    Each is prefixed with "▁" and intersected with `vocab_set`.

    Args:
        word: the noun to look up; case is ignored.

    Returns:
        set: the "▁"-prefixed candidates that are present in `vocab_set`.
    """
    singular = word.lower()
    base_forms = (singular, p.plural(singular))
    # Candidate order: word, Word, plural, Plural.
    candidates = [
        "▁" + variant
        for form in base_forms
        for variant in (form, form.capitalize())
    ]
    return vocab_set.intersection(candidates)

In [None]:
# saving convention: data/taxonomies/custom/<model_family>

path = "../data/taxonomies/custom/gemma/"
# Join file names with pathlib rather than f'{path}/...': `path` already ends
# with a slash, so string formatting produced ".../gemma//items.json".
out_dir = pathlib.Path(path)
out_dir.mkdir(exist_ok=True, parents=True)

# items.json holds one JSON object per line: {concept: [vocabulary tokens
# collected over every node returned by node.descendants()]}.
with open(out_dir / 'items.json', 'w', encoding='utf-8') as f:
    for key, node in Tree.items():
        words = []
        for w in node.descendants():
            # The helper returns a set; sort it so the token order written
            # to disk is identical from one run to the next.
            words.extend(sorted(_noun_to_gemma_vocab_elements(w.value)))

        f.write(json.dumps({key : words}) + "\n")

nx.write_adjlist(G, str(out_dir / "hypernym_graph.adjlist"))