# setup

## imports

In [None]:
import os
import random
from dataclasses import dataclass
from textwrap import dedent

import networkx as nx
import plotly.graph_objects as go
import plotly.io as pio
import psycopg
from germanetpy.filterconfig import Filterconfig
from germanetpy.frames import Frames
from germanetpy.germanet import Germanet
from germanetpy.path_based_relatedness_measures import PathBasedRelatedness
from germanetpy.synset import WordCategory, WordClass
from psycopg.sql import SQL, Identifier, Literal

## global vars

In [None]:
# Switch for the "if SET_TEST:" smoke-test cells that follow the function definitions.
SET_TEST = True
# Use the "notebook" renderer as plotly's default when figures are shown.
pio.renderers.default = "notebook"
# Fixed seed so that the random.shuffle-based randomization below is reproducible.
random.seed(42)

In [None]:
# GermaNet API object built from the data found under /veld/input/; used by all germanet helpers below.
germanet = Germanet("/veld/input/")

In [None]:
# Read the PostgreSQL connection settings from the environment (each one is None when unset).
POSTGRES_HOST = os.getenv("POSTGRES_HOST")
POSTGRES_PORT = os.getenv("POSTGRES_PORT")
POSTGRES_USER = os.getenv("POSTGRES_USER")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
POSTGRES_DB = os.getenv("POSTGRES_DB")
print(f"{POSTGRES_HOST=}")
print(f"{POSTGRES_PORT=}")
print(f"{POSTGRES_USER=}")
# Never echo the secret into the (often shared/committed) notebook output;
# only show whether a password is set at all.
print(f"POSTGRES_PASSWORD={'***' if POSTGRES_PASSWORD else POSTGRES_PASSWORD!r}")
print(f"{POSTGRES_DB=}")

## DB

In [None]:
# Open one connection to the PostgreSQL server described by the POSTGRES_* settings.
conn = psycopg.connect(
    host=POSTGRES_HOST,
    port=POSTGRES_PORT,
    dbname=POSTGRES_DB,
    user=POSTGRES_USER,
    password=POSTGRES_PASSWORD,
)
# Autocommit mode: no explicit conn.commit() calls are made anywhere in this notebook.
conn.autocommit = True
# Single cursor that is handed to the Query helper below.
cur = conn.cursor()
# Connectivity check: print the row returned by "SELECT version()".
cur.execute("SELECT version()")
print(cur.fetchone())

In [None]:
class Query:
    """Fluent helper that composes a psycopg SQL statement and runs it on one cursor.

    Typical use::

        query("SELECT {col} FROM {tab}", col=Identifier("c"), tab=Identifier("t")).execute().fetchall()

    The composed statement is stored on the instance, so every call of the
    instance replaces the previously composed statement.
    """

    def __init__(self, cur):
        """Remember the cursor ``cur`` that all statements are executed on."""
        self.cur = cur
        # Defined up front so that as_sql() never hits a missing attribute
        # when it is called before any statement has been composed.
        self.query = None

    def __call__(self, query, enable_print=True, **kwargs):
        """Compose ``query`` with ``kwargs`` through ``SQL(query).format`` and store it.

        Args:
            query: SQL template text with ``{name}`` placeholders.
            enable_print: when true, print the composed statement.
            **kwargs: placeholder values; a ``Query`` value is replaced by its
                composed statement (``as_sql()``), anything else is passed on
                unchanged.

        Returns:
            ``self``, to allow chaining ``.execute()`` / ``.fetchall()``.
        """
        kwargs_cleaned = {
            key: value.as_sql() if isinstance(value, Query) else value
            for key, value in kwargs.items()
        }
        self.query = SQL(query).format(**kwargs_cleaned)
        if enable_print:
            # Pass the cursor as adaptation context: psycopg releases before 3.2
            # require the argument, newer ones accept it as well.
            print(self.query.as_string(self.cur))
        return self

    def as_sql(self):
        """Return the most recently composed statement (``None`` if there is none yet)."""
        return self.query

    def execute(self, data=None):
        """Execute the stored statement and return ``self``.

        ``data`` is only forwarded when it is truthy; ``None`` or an empty
        container runs the statement without parameters.
        """
        if data:
            self.cur.execute(self.query, data)
        else:
            self.cur.execute(self.query)
        return self

    def executemany(self, data):
        """Run the stored statement once per parameter set in ``data``; return ``self``."""
        self.cur.executemany(self.query, data)
        return self

    def fetchall(self):
        """Return all remaining rows of the cursor's last result."""
        return self.cur.fetchall()


# Shared Query helper bound to the notebook's cursor; called by the word-embedding functions below.
query = Query(cur=cur)

# individual functions

### print_sample_of_related

In [None]:
def print_sample_of_related(word_similarity_dict, num_steps):
    """Print an evenly spaced sample of the items of ``word_similarity_dict``.

    Items are visited in the dict's iteration order. Every ``step``-th item is
    printed, where ``step = len(dict) // (num_steps - 1)``, and the last item
    is always printed as well.

    Args:
        word_similarity_dict: mapping whose ``(key, value)`` items are printed.
        num_steps: approximate number of items to print.
    """
    total = len(word_similarity_dict)
    # Clamp both the divisor and the step to >= 1: num_steps == 1, or a dict
    # with fewer entries than num_steps - 1, would otherwise end in a modulo
    # or division by zero.
    step_divider = max(1, total // max(1, num_steps - 1))
    for i, d in enumerate(word_similarity_dict.items()):
        if i % step_divider == 0 or i == total - 1:
            print(d)

## germanet

### relatedness_calculator

In [None]:
# Path-based relatedness calculator over GermaNet, configured with category WordCategory.nomen.
relatedness_calculator = PathBasedRelatedness(
    germanet=germanet,
    category=WordCategory.nomen,
)


if SET_TEST:
    # Smoke test: relate "Trompete" to "Flöte" and to "Haus".
    # .pop() takes a single synset out of each lookup result.
    w1 = germanet.get_synsets_by_orthform("Trompete").pop()
    w2 = germanet.get_synsets_by_orthform("Flöte").pop()
    w3 = germanet.get_synsets_by_orthform("Haus").pop()
    # First measure: leacock_chodorow.
    w1_w2 = relatedness_calculator.leacock_chodorow(w1, w2)
    w1_w3 = relatedness_calculator.leacock_chodorow(w1, w3)
    print(w1_w2)
    print(w1_w3)
    # Second measure: simple_path (the measure used by get_average_synset_similarity).
    w1_w2 = relatedness_calculator.simple_path(w1, w2)
    w1_w3 = relatedness_calculator.simple_path(w1, w3)
    print(w1_w2)
    print(w1_w3)

### get_average_synset_similarity

In [None]:
def get_average_synset_similarity(word_1, word_2):
    """Return the mean ``simple_path`` relatedness over all synset pairs of two words.

    Every synset found for ``word_1`` is paired with every synset found for
    ``word_2``; pairs that the relatedness calculator cannot score are skipped.

    Args:
        word_1: orthographic form looked up in GermaNet.
        word_2: orthographic form looked up in GermaNet.

    Returns:
        The arithmetic mean of the pairwise values, or ``None`` when no pair
        could be scored (e.g. one of the words has no synsets).
    """
    synset_list_1 = germanet.get_synsets_by_orthform(word_1)
    synset_list_2 = germanet.get_synsets_by_orthform(word_2)
    path_distance_list = []
    for synset_1 in synset_list_1:
        for synset_2 in synset_list_2:
            try:
                path_distance = relatedness_calculator.simple_path(synset_1, synset_2)
            except Exception:
                # Deliberate best effort: an unscorable pair is skipped.
                # (Presumably pairs outside the calculator's word category - TODO confirm.)
                # Unlike a bare "except:", this lets KeyboardInterrupt/SystemExit
                # through, so a long-running loop can still be interrupted.
                continue
            path_distance_list.append(path_distance)
    if not path_distance_list:
        return None
    return sum(path_distance_list) / len(path_distance_list)


if SET_TEST:
    # Smoke test: print the averaged similarity for a few hand-picked word pairs,
    # ending with two identical-word pairs.
    print(get_average_synset_similarity("Frau", "Gattin"))
    print(get_average_synset_similarity("Frau", "Mann"))
    print(get_average_synset_similarity("Frau", "Küche"))
    print(get_average_synset_similarity("Frau", "Kind"))
    print(get_average_synset_similarity("Mann", "Kind"))
    print(get_average_synset_similarity("Frau", "Mathematik"))
    print(get_average_synset_similarity("Mann", "Mathematik"))
    print(get_average_synset_similarity("Frau", "Frau"))
    print(get_average_synset_similarity("Gattin", "Gattin"))

### get_all_words_of_germanet

In [None]:
def get_all_words_of_germanet():
    """Return the set of distinct ``orthform`` values of all lexical units of all GermaNet synsets."""
    return {
        lex_unit.orthform
        for synset in germanet.synsets.values()
        for lex_unit in synset.lexunits
    }


if SET_TEST:
    # Kept as a notebook global: reused as the candidate set in the create_word_similarity_germanet test.
    all_germanet_words_set = get_all_words_of_germanet()
    print(len(all_germanet_words_set))

### create_word_similarity_germanet

In [None]:
def create_word_similarity_germanet(word_a, word_set):
    """Score every word of ``word_set`` against ``word_a`` and rank the results.

    Args:
        word_a: reference word.
        word_set: iterable of candidate words; ``word_a`` itself is skipped.

    Returns:
        Dict mapping each candidate to its ``get_average_synset_similarity``
        value, ordered from highest to lowest. Candidates without a value
        (``None``) are left out.
    """
    scored_words = []
    for candidate in word_set:
        if word_a == candidate:
            continue
        similarity = get_average_synset_similarity(word_a, candidate)
        if similarity is None:
            continue
        scored_words.append((candidate, similarity))
    # Negated key gives a descending, stable ordering.
    scored_words.sort(key=lambda pair: -pair[1])
    return dict(scored_words)


if SET_TEST:
    # Builds word_similarity_germanet_dict for "Tisch"; the later filter/normalize test cells reuse it.
    sample_word = "Tisch"
    word_similarity_germanet_dict = create_word_similarity_germanet(sample_word, all_germanet_words_set)
    print_sample_of_related(word_similarity_germanet_dict, 30)

### create_lexeme_graph

In [None]:
def create_lexeme_graph(limit=None, debug=False):
    """Build an undirected graph that links the orthforms of related synsets.

    For every synset and every synset it is related to, each orthform of the
    first is connected to each orthform of the second. Pairs involving the
    orthform "GNROOT" are skipped. Edges carry a ``label`` made of the two
    synset ids joined by "-".

    Args:
        limit: when not ``None``, only the first ``limit`` synsets are processed.
        debug: when true, print the traversed synsets, relation keys and lexunits.

    Returns:
        The populated ``nx.Graph``.
    """
    graph = nx.Graph()
    synsets = list(germanet.synsets.values())
    if limit is not None:
        synsets = synsets[:limit]
    for synset in synsets:
        if debug:
            print(synset)
        for relation_key, related_synsets in synset.relations.items():
            if debug:
                print("\t", relation_key)
            for related in related_synsets:
                if debug:
                    print("\t\t", related)
                    print("\t\t\t", related.lexunits)
                for lex_unit in synset.lexunits:
                    for related_lex_unit in related.lexunits:
                        source = lex_unit.orthform
                        target = related_lex_unit.orthform
                        if source == "GNROOT" or target == "GNROOT":
                            continue
                        graph.add_edge(source, target, label=synset.id + "-" + related.id)
    return graph


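# Perhaps an unnecessary function, so its smoke test is disabled for now.
# NOTE: the "experiments" cells at the end of the notebook use `g`; re-enable this block before running them.
# if SET_TEST:
#     g = create_lexeme_graph(limit=1000, debug=False)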

### plot_graph

In [None]:
def plot_graph(g, sub_node_list=None, traversal_limit=None):
    """Draw ``g`` as an interactive plotly figure and show it.

    Args:
        g: networkx graph. An edge's optional ``label`` attribute is drawn at
            the edge midpoint and used as hover text of the edge.
        sub_node_list: optional start nodes. When given (and non-empty), only
            the nodes reachable from them are plotted.
        traversal_limit: optional maximum number of hops from each start
            node; ``None`` keeps everything reachable.

    Returns:
        None. The figure is displayed through ``fig.show()``.
    """
    if sub_node_list:
        sub_node_rel_set = set()
        for sub_node in sub_node_list:
            # cutoff=None is networkx' "no depth limit", so traversal_limit
            # can be passed through unchanged.
            sub_node_rel_dict = nx.single_source_shortest_path_length(g, sub_node, cutoff=traversal_limit)
            sub_node_rel_set.update(sub_node_rel_dict)
        g = g.subgraph(sub_node_rel_set).copy()  # make a copy to avoid view issues
    # Fixed seed keeps the layout identical between runs.
    pos = nx.spring_layout(g, seed=42)

    # edges
    edge_x = []
    edge_y = []
    edge_text = []
    edge_label_x = []
    edge_label_y = []
    edge_label_text = []
    for u, v, data in g.edges(data=True):
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        label = data.get("label", "")
        # Each edge contributes three points (start, end, None as line break).
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        # Hover text is matched to points by position, so it needs the same
        # three-entries-per-edge layout; one entry per edge would attach the
        # labels to the wrong segments.
        edge_text.extend([label, label, None])
        # midpoint for edge label
        edge_label_x.append((x0 + x1) / 2)
        edge_label_y.append((y0 + y1) / 2)
        edge_label_text.append(label)

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=2, color="#888"),
        hoverinfo="text",
        text=edge_text,
        mode="lines",
    )

    # edge labels as separate trace
    edge_label_trace = go.Scatter(
        x=edge_label_x,
        y=edge_label_y,
        mode="text",
        text=edge_label_text,
        textposition="middle center",
        hoverinfo="none",
        textfont=dict(color="black", size=12),
    )

    # nodes
    node_x = []
    node_y = []
    node_text = []
    for node in g.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        text=node_text,
        textposition="top center",
        hoverinfo="text",
        marker=dict(size=20, color="lightblue", line=dict(width=2, color="DarkSlateGrey")),
    )

    # build figure
    fig = go.Figure(
        data=[edge_trace, edge_label_trace, node_trace],
        layout=go.Layout(
            width=1000,
            height=1000,
            title="Interactive Graph",
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )
    fig.show()


# Perhaps an unnecessary function, so its smoke test is disabled for now.
# if SET_TEST:
#     plot_graph(g, sub_node_list=["unwirklich"], traversal_limit=2)

### get_all_paths

In [None]:
def get_all_paths(g, l1, l2):
    """Return the paths found by ``nx.node_disjoint_paths`` between ``l1`` and ``l2`` as a list.

    NOTE(review): despite the function name, this delegates to the
    node-disjoint path search, not to an "all simple paths" enumeration.
    """
    disjoint_paths = nx.node_disjoint_paths(g, l1, l2)
    return list(disjoint_paths)


# Perhaps an unnecessary function, so its smoke test is disabled for now.
# if SET_TEST:
#     for p in get_all_paths(g, "unwirklich", "fantastisch"):
#         print(p)

## word embeddings

### get_all_words_of_embeddings

In [None]:
def get_all_words_of_embeddings(embeddings_table, word_column_name):
    """Fetch the word column of an embeddings table.

    Args:
        embeddings_table: table name (quoted as SQL identifier).
        word_column_name: column name (quoted as SQL identifier).

    Returns:
        The rows returned by ``fetchall()`` for the single selected column.
    """
    select_words = query(
        "SELECT {word_column_name} FROM {embeddings_table}",
        enable_print=False,
        word_column_name=Identifier(word_column_name),
        embeddings_table=Identifier(embeddings_table),
    )
    select_words.execute()
    return select_words.fetchall()


if SET_TEST:
    # Smoke test: number of rows in the "word" column of table "word2vec__m4".
    print(len(get_all_words_of_embeddings("word2vec__m4", "word")))

### get_word_embedding_similarity

In [None]:
def get_word_embedding_similarity(
    word_search,
    embeddings_table,
    word_column_name,
    embeddings_column_name,
    number_results=10,
    order_by_closest=True,
):
    """Query the database function ``get_related`` for words related to ``word_search``.

    All arguments are passed to ``get_related`` as SQL literals under the
    parameter names of the same name.

    Returns:
        Dict built from the returned two-column rows: first column as key,
        second column as value (word -> similarity), in result order.
    """
    sql_template = dedent(
        """\
        SELECT * FROM get_related(
            word_search := {word_search},
            embeddings_table := {embeddings_table},
            word_column_name := {word_column_name},
            embeddings_column_name := {embeddings_column_name},
            number_results := {number_results},
            order_by_closest := {order_by_closest}
        )
        """
    )
    related_query = query(
        sql_template,
        enable_print=False,
        word_search=Literal(word_search),
        embeddings_table=Literal(embeddings_table),
        word_column_name=Literal(word_column_name),
        embeddings_column_name=Literal(embeddings_column_name),
        number_results=Literal(number_results),
        order_by_closest=Literal(order_by_closest),
    )
    rows = related_query.execute().fetchall()
    return dict(rows)


if SET_TEST:
    # Smoke test against table "word2vec__m4"; number_results is passed as None
    # (presumably "no limit" inside get_related - confirm against the SQL function).
    # The result is reused by the later filter/normalize test cells.
    word_similarity_embedding_dict = get_word_embedding_similarity("tisch", "word2vec__m4", "word", "embedding", None)
    print_sample_of_related(word_similarity_embedding_dict, 30)

# composite functions

## find_overlap_words

In [None]:
def find_overlap_words(embeddings_table="word2vec__m4", word_column_name="word"):
    """Return the words that occur both in an embeddings table and in GermaNet.

    Embedding words are passed through ``str.capitalize()`` before the
    comparison, so the returned words use that capitalized spelling.

    Args:
        embeddings_table: embeddings table to read the words from.
        word_column_name: name of the word column in that table.

    Returns:
        Set of words present in both resources.
    """
    words_embeddings = {
        row[0].capitalize() for row in get_all_words_of_embeddings(embeddings_table, word_column_name)
    }
    words_germanet = get_all_words_of_germanet()
    return words_embeddings.intersection(words_germanet)


# Computed unconditionally (not guarded by SET_TEST): words present in both resources.
overlap_word_set = find_overlap_words()
print(len(overlap_word_set))

## filter_for_overlapping_words

In [None]:
def filter_for_overlapping_words(word_similarity_germanet_dict, word_similarity_embedding_dict):
    """Restrict both similarity dicts to the words they have in common.

    Embedding words are matched against GermaNet words through their
    ``str.capitalize()`` form; the returned embedding dict keeps the original
    spelling of its keys.

    Returns:
        Tuple ``(germanet_filtered, embedding_filtered)``; each dict preserves
        the item order of its input.
    """
    capitalized_embedding_words = {word.capitalize() for word in word_similarity_embedding_dict}
    shared_words = capitalized_embedding_words & set(word_similarity_germanet_dict)
    germanet_filtered = {
        word: sim for word, sim in word_similarity_germanet_dict.items() if word in shared_words
    }
    embedding_filtered = {
        word: sim for word, sim in word_similarity_embedding_dict.items() if word.capitalize() in shared_words
    }
    return germanet_filtered, embedding_filtered


if SET_TEST:
    # Sizes before filtering ...
    print(len(word_similarity_germanet_dict))
    print(len(word_similarity_embedding_dict))
    word_similarity_germanet_filtered_dict, word_similarity_embedding_filtered_dict = filter_for_overlapping_words(
        word_similarity_germanet_dict,
        word_similarity_embedding_dict,
    )
    # ... and after restricting both dicts to their common words.
    print(len(word_similarity_germanet_filtered_dict))
    print(len(word_similarity_embedding_filtered_dict))

## normalize_word_similarities

In [None]:
def normalize_word_similarities(word_sim_dict):
    """Min-max scale the similarity values of ``word_sim_dict`` to the range [0, 1].

    Args:
        word_sim_dict: mapping word -> numeric similarity.

    Returns:
        New dict with the same keys in the same order. The smallest value maps
        to 0.0 and the largest to 1.0. An empty input yields an empty dict;
        when all values are identical (no spread to scale by) every word maps
        to 0.0.
    """
    if not word_sim_dict:
        return {}
    sim_min = min(word_sim_dict.values())
    sim_max = max(word_sim_dict.values())
    scale = sim_max - sim_min
    if scale == 0:
        # Degenerate case: avoid dividing by zero.
        return {word_other: 0.0 for word_other in word_sim_dict}
    return {word_other: (sim - sim_min) / scale for word_other, sim in word_sim_dict.items()}


if SET_TEST:
    # Normalize both filtered dicts so that their values become comparable, and print a sample of each.
    word_similarity_germanet_filtered_normalized_dict = normalize_word_similarities(word_similarity_germanet_filtered_dict)
    print_sample_of_related(word_similarity_germanet_filtered_normalized_dict, 30)
    print("----------------------------------------------------")
    word_similarity_embedding_filtered_normalized_dict = normalize_word_similarities(word_similarity_embedding_filtered_dict)
    print_sample_of_related(word_similarity_embedding_filtered_normalized_dict, 30)

## randomize_sim_dict

In [None]:
def randomize_sim_dict(word_sim_dict):
    """Return a copy of ``word_sim_dict`` whose values are randomly reassigned to its keys.

    Key order is kept; the values are shuffled with ``random.shuffle`` (module
    level RNG) and re-attached to the keys in that order. The input dict is
    not modified.
    """
    shuffled_sims = list(word_sim_dict.values())
    random.shuffle(shuffled_sims)
    return dict(zip(word_sim_dict.keys(), shuffled_sims))


if SET_TEST:
    # Shuffled variant of the embedding similarities; compared against the real ones in the calculate_sim_diff test.
    word_similarity_embedding_filtered_normalized_randomized_dict = randomize_sim_dict(word_similarity_embedding_filtered_normalized_dict)

## calculate_sim_diff

In [None]:
def calculate_sim_diff(word_similarity_germanet_dict, word_similarity_embedding_dict, top_num=100):
    """Return the mean absolute difference between embedding and GermaNet similarities.

    Args:
        word_similarity_germanet_dict: word -> similarity, keyed by the
            capitalized spelling.
        word_similarity_embedding_dict: word -> similarity; each key is looked
            up in the GermaNet dict through ``str.capitalize()``.
        top_num: number of leading embedding items to compare. A falsy value
            (``None`` or ``0``) compares all items.

    Returns:
        The mean of ``abs(sim_embedding - sim_germanet)`` over the compared
        words, or ``None`` when there is nothing to compare.

    Raises:
        KeyError: if a compared embedding word has no entry in
            ``word_similarity_germanet_dict``.
    """
    word_sim_list = list(word_similarity_embedding_dict.items())
    if top_num:
        word_sim_list = word_sim_list[:top_num]
    if not word_sim_list:
        # A mean over zero items is undefined; avoid the division by zero.
        return None
    sum_diff = sum(
        abs(sim_embedding - word_similarity_germanet_dict[word.capitalize()])
        for word, sim_embedding in word_sim_list
    )
    return sum_diff / len(word_sim_list)


if SET_TEST:
    top_num = 1000
    # Mean difference between the GermaNet and the embedding similarities ...
    sim_diff = calculate_sim_diff(
        word_similarity_germanet_filtered_normalized_dict,
        word_similarity_embedding_filtered_normalized_dict,
        top_num=top_num,
    )
    print(sim_diff)
    # ... and the same measure against the shuffled embedding similarities, for comparison.
    sim_diff = calculate_sim_diff(
        word_similarity_germanet_filtered_normalized_dict,
        word_similarity_embedding_filtered_normalized_randomized_dict,
        top_num=top_num,
    )
    print(sim_diff)

# analysis

# experiments

In [None]:
# NOTE(review): within this notebook `g` is only assigned in the commented-out
# create_lexeme_graph test, so this cell raises NameError unless a graph is built first.
for p in get_all_paths(g, "Hütte", "Haus"):
    print(p)

In [None]:
"haus" in g

## germanet

In [None]:
# Print every synset GermaNet returns for the orthform "Bank".
b = germanet.get_synsets_by_orthform("Bank")
for ss in b:
    print(ss)

In [None]:
# Print every synset GermaNet returns for the orthform "Bankinstitut".
b = germanet.get_synsets_by_orthform("Bankinstitut")
for ss in b:
    print(ss)

In [None]:
# Inspect the second entry of `b` and its lexical units.
# NOTE(review): `b` is whatever the last executed cell assigned; assumes it holds at least two synsets.
print(b[1])
print(b[1].lexunits)

In [None]:
# Display the result of get_all_orthforms() for the lexical unit with id "l9381".
germanet.get_lexunit_by_id("l9381").get_all_orthforms()

In [None]:
# For the first "Sitzmöbel" synset: print each relation key followed by the synsets stored under it.
b = germanet.get_synsets_by_orthform("Sitzmöbel")
for k, v in b[0].relations.items():
    print(k)
    for other in v:
        print(other)
    print("\n")

In [None]:
# Display incoming_relations of the second entry of `b`.
# NOTE(review): depends on which cell assigned `b` last; assumes at least two synsets.
b[1].incoming_relations

In [None]:
# Filter GermaNet's synsets with a case-insensitive Filterconfig for "orange".
Filterconfig("orange", ignore_case=True).filter_synsets(germanet)

In [None]:
# Display the word classes reported as possible for WordCategory.nomen.
WordCategory.get_possible_word_classes(WordCategory.nomen)

In [None]:
# Re-load the "Bank" synsets into `b` for the inspection cells that follow.
b = germanet.get_synsets_by_orthform("Bank")
for ss in b:
    print(ss)

In [None]:
# Select the second "Bank" synset as `ss` for the following cells and display it.
ss = b[1]
ss

In [None]:
# Display the relations mapping of the selected synset.
ss.relations

In [None]:
# Display the relation keys present on the selected synset.
list(ss.relations.keys())

In [None]:
# Keep the relation keys in a list so that they can be addressed by position below.
l = list(ss.relations.keys())
l

In [None]:
# NOTE(review): the variable names assume that the first relation key is the hypernym
# relation and the second one the hyponym relation; this depends on the key order of
# ss.relations - confirm against the output of the previous cell.
ss_hyper_set = ss.relations[l[0]]
ss_hypo_set = ss.relations[l[1]]
print(ss_hyper_set)
print(ss_hypo_set)

In [None]:
# Pick one related synset from each collection by position.
# NOTE(review): index 2 assumes at least three entries; if these are sets (as the
# names suggest), the positional order is not guaranteed.
ss_hyper = list(ss_hyper_set)[0]
ss_hypo = list(ss_hypo_set)[2]

In [None]:
# simple_path relatedness between the selected synset and the picked `ss_hypo` synset.
relatedness_calculator.simple_path(ss, ss_hypo)

In [None]:
# simple_path relatedness between the selected synset and the picked `ss_hyper` synset.
relatedness_calculator.simple_path(ss, ss_hyper)

In [None]:
# Check the Python type of a lexical unit's orthform.
# Note: this rebinds `l` (previously the list of relation keys) to a lexical unit.
ss_list = germanet.get_synsets_by_orthform("Bank")
l = ss_list[0].lexunits[0]
type(l.orthform)

In [None]:
# Same inspection as earlier: relation keys and related synsets of the first "Sitzmöbel" synset.
b = germanet.get_synsets_by_orthform("Sitzmöbel")
for k, v in b[0].relations.items():
    print(k)
    for other in v:
        print(other)
    print("\n")

In [None]:
# Look up the synsets for the "GNROOT" orthform (the orthform that create_lexeme_graph leaves out of the graph).
germanet.get_synsets_by_orthform("GNROOT")