In [None]:
import numpy as np
import torch

import matplotlib.pyplot as plt

# Stack both style sheets in one call; later entries override earlier ones
# wherever they set the same rcParam.
plt.style.use(["ggplot", "seaborn-v0_8-colorblind"])

In [None]:
# # SpaCy uses nltk to handle wordnet
# !pip install nltk
# import nltk
# nltk.download('wordnet')
# nltk.download('brown')
import nltk
from nltk.corpus import brown
from nltk.corpus import wordnet as wn

# 40-character horizontal rules for separating printed output.
line_break = 40 * "-"
big_line_break = 40 * "="
# Disabled alternative frequency source: lower-cased Brown-corpus word counts.
# word_frequency = nltk.FreqDist(w.lower() for w in brown.words())


In [None]:
# !pip install wordfreq
from wordfreq import word_frequency

In [None]:
# sudo apt-get install mysql-client
# sudo apt-get install mysql-server
# sudo apt-get install libmysqlclient-dev
# !pip install mysqlclient==2.1.1 pattern

from pattern.en import singularize

In [None]:
## load neuron-concept similarity for ranking words
# NOTE(review): torch.load unpickles the file, so only point it at trusted data.
# map_location="cpu" lets a checkpoint saved from a GPU run load on any machine
# and keeps the tensors usable by matplotlib in the cells below.
similarities = torch.cat(
    [
        # presumably one record per layer, going by the file name -- TODO confirm
        layer_record["similarities"]
        for layer_record in torch.load(
            "my_data/all_layer_similarities.pt", map_location="cpu"
        )
    ]
)

# One vocabulary entry per line. Lines are only stripped, never dropped, so the
# list positions stay aligned with the indices used to look words up later.
with open("data/20k.txt", encoding="utf-8") as f:
    vocabulary = [line.strip() for line in f]

In [None]:
# Clamp negative similarities to zero, then reduce over dim 0 with max.
# (Despite the name, this is a max, not a sum.)
sim_sum = similarities.clip(min=0, max=torch.inf).max(dim=0).values

plt.hist(sim_sum, bins=100)
(sim_sum.shape, sim_sum.quantile(0.4))

In [None]:
# Indices into `vocabulary`, ordered from highest to lowest score.
argsort = sim_sum.argsort(descending=True)

# Swap the slice for argsort[:30] to inspect the top of the ranking instead.
print("bottom 30:")
[vocabulary[idx] for idx in argsort[-30:]]

## Word list used to grab synset objects

In [None]:
# Alternative word sources (disabled):
#   - data/nouns_and_adjectives.txt, one word per line
#   - the full `vocabulary`

# use the top 10k highly-activated words from resnet50
words = [vocabulary[idx] for idx in argsort[:10000]]

# Drop single- and double-letter words, singularize the rest, and de-duplicate.
# dict.fromkeys keeps the first occurrence of each key, so ranking order survives.
words = list(
    dict.fromkeys(singularize(candidate) for candidate in words if len(candidate) > 2)
)

len(words)

In [None]:
# Sanity check: singularized head lemma of every synset returned for "a".
# Written as a list expression so the notebook actually displays the values.
[singularize(synset.name().split(".")[0]) for synset in wn.synsets("a")]

In [None]:
def prefer_exact_match(synsets, word):
    """Reorder synsets so those whose head lemma equals ``word`` come first.

    Both ``word`` and each synset's head lemma (the part of ``synset.name()``
    before the first ".") are singularized before being compared. Relative
    order inside the matching and the non-matching group is preserved.

    Returns a new list; ``synsets`` itself is left untouched.
    """
    target = singularize(word)

    def _is_mismatch(candidate):
        return singularize(candidate.name().split(".")[0]) != target

    # sorted() is stable and False < True, so exact matches float to the front
    # while every candidate keeps its original relative position.
    return sorted(synsets, key=_is_mismatch)


def prefer_nouns(synsets, word):
    """Reorder synsets as nouns, then adjectives, then everything else.

    The part of speech is read from the synset name (``lemma.pos.nn``).
    Adjectives cover both head adjectives (``"a"``) and satellite adjectives
    (``"s"``). Relative order inside each group is preserved.

    ``word`` is unused; it is accepted so the signature matches
    ``prefer_exact_match``.

    Returns a new list.
    """

    def _pos(candidate):
        # Split from the right: the lemma part may itself contain periods.
        return candidate.name().rsplit(".", 2)[1]

    adjective_tags = ("a", "s")
    nouns = [s for s in synsets if _pos(s) == "n"]
    adjectives = [s for s in synsets if _pos(s) in adjective_tags]
    others = [s for s in synsets if _pos(s) != "n" and _pos(s) not in adjective_tags]
    return nouns + adjectives + others


def choose_synset(synsets, word):
    """Pick the synset whose lemmas are collectively the most frequent.

    Candidates are first reordered with ``prefer_exact_match``. Each one is
    then scored as the sum of ``word_frequency(lemma_name.lower(), 'en')`` over
    all of its lemma names. The comparison is strict, so on a tie the earlier
    candidate (exact matches first) wins.

    Returns ``(best_synset, best_score)``, or ``(None, -1)`` when ``synsets``
    is empty.
    """
    candidates = prefer_exact_match(synsets, word)
    #     candidates = prefer_nouns(candidates, word)  # optional: rank nouns first

    best_synset, best_score = None, -1
    for candidate in candidates:
        lemma_names = [lemma.name() for lemma in candidate.lemmas()]
        candidate_score = np.sum(
            [word_frequency(lemma_name.lower(), "en") for lemma_name in lemma_names]
        )
        #         candidate_score += len(lemma_names)/10 # [optional] favor synset with more lemma

        if candidate_score > best_score:
            best_synset, best_score = candidate, candidate_score

    return best_synset, best_score

In [None]:
# word_synsets holds [word, Synset] pairs for the graph cells below;
# word_synsets_write holds the plain-string rows that get exported to CSV.
word_synsets = []
word_synsets_write = []
for i, word in enumerate(words):
    print(f"word{i}", word)

    # Candidate senses: noun synsets first, then adjective synsets.
    synsets = wn.synsets(word, pos=wn.NOUN) + wn.synsets(word, pos=wn.ADJ)
    if not synsets:
        print(f"no synset for {word}")
        print(line_break)
        continue

    synset, score = choose_synset(synsets, word)

    word_synsets.append([word, synset])
    word_synsets_write.append([word, synset.name(), synset.definition()])

    # Log the chosen sense, its definition and its score, one item per line.
    print(synset, synset.definition(), score, line_break, sep="\n")
    

In [None]:
import csv

# Definitions are free text that can contain commas and quotes, so let
# csv.writer quote the fields instead of joining them by hand.
# NOTE: fields are now separated by "," (no trailing space); update any reader
# that split rows on ", ".
with open("my_data/wordnet.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f, lineterminator="\n")
    writer.writerow(["word", "synset", "definition"])
    writer.writerows(word_synsets_write)

## construct graph from synsets

In [None]:
# Preview the first few [word, synset] pairs instead of dumping the whole list.
word_synsets[:20]

In [None]:
import networkx as nx

In [None]:
graph = nx.DiGraph()

for word, synset in word_synsets:
    # A synset can have several hypernym paths; add every one of them.
    for path in synset.hypernym_paths():
        nodes = [step.name() for step in path]
        # Register the nodes first so a single-node path still shows up.
        graph.add_nodes_from(nodes)
        # edge pointing toward more general term
        for specific, general in zip(nodes[1:], nodes[:-1]):
            graph.add_edge(specific, general)


In [None]:
# Keyword arguments passed to nx.draw: small nodes, thin edges.
graph_drawing_style = {"node_size": 2, "width": 0.5}

# Try the style out on a small built-in graph.
G = nx.dodecahedral_graph()
nx.draw(G, **graph_drawing_style)

In [None]:
# Tag every node with the index of its topological generation.
for layer, nodes in enumerate(nx.topological_generations(graph)):
    for name in nodes:
        graph.nodes[name]["layer"] = layer

# Compute the multipartite_layout using the "layer" node attribute
pos = nx.multipartite_layout(graph, subset_key="layer")

In [None]:
# Draw the hypernym graph with the layered layout and the shared style.
nx.draw(
    graph,
    pos=pos,
    **graph_drawing_style,
)

In [None]:
# TODO: save graph and pos to a file; visualize and refine in JS