# setup

## imports

In [None]:
import os
from dataclasses import dataclass

import networkx as nx
import plotly.graph_objects as go
import plotly.io as pio
import psycopg
from germanetpy.filterconfig import Filterconfig
from germanetpy.frames import Frames
from germanetpy.germanet import Germanet
from germanetpy.path_based_relatedness_measures import PathBasedRelatedness
from germanetpy.synset import WordCategory, WordClass
from psycopg.sql import SQL, Identifier, Literal

## global vars

In [None]:
# Switch for the `if enable_test:` smoke checks placed after the helper definitions.
enable_test = True
# Make plotly's "notebook" renderer the default used when figures are shown.
pio.renderers.default = "notebook"

In [None]:
# Notebook-wide GermaNet object, built from the data directory /veld/input/.
germanet = Germanet("/veld/input/")

In [None]:
# PostgreSQL connection settings come from the environment; a variable that is
# not set yields None.
POSTGRES_HOST = os.getenv("POSTGRES_HOST")
POSTGRES_PORT = os.getenv("POSTGRES_PORT")
POSTGRES_USER = os.getenv("POSTGRES_USER")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
POSTGRES_DB = os.getenv("POSTGRES_DB")
print(f"{POSTGRES_HOST=}")
print(f"{POSTGRES_PORT=}")
print(f"{POSTGRES_USER=}")
# Cell output is saved with the notebook, so only show whether a password is
# set (masked) instead of echoing the secret itself.
print(f"POSTGRES_PASSWORD={'***' if POSTGRES_PASSWORD else POSTGRES_PASSWORD!r}")
print(f"{POSTGRES_DB=}")

## DB

In [None]:
# Notebook-wide PostgreSQL connection, opened in autocommit mode with the
# settings read from the environment.
conn = psycopg.connect(
    host=POSTGRES_HOST,
    port=POSTGRES_PORT,
    dbname=POSTGRES_DB,
    user=POSTGRES_USER,
    password=POSTGRES_PASSWORD,
    autocommit=True,
)
cur = conn.cursor()
# Ask the server for its version as a quick connectivity check.
print(cur.execute("SELECT version()").fetchone())

In [None]:
class DB:
    """Small fluent builder that accumulates a psycopg SQL statement and runs it.

    Builder methods append to ``self.query`` and return the instance so calls
    can be chained; ``execute`` sends the accumulated text to the cursor and
    clears the builder again.
    """

    class CreateColumn:
        """Description of one column, consumed by ``DB.create_table``."""

        def __init__(self, name, col_type, primary_key=False):
            # name: column name, rendered as a quoted identifier
            # col_type: SQL type text, inserted verbatim (not escaped), so it
            #           must never come from untrusted input
            # primary_key: when true, " PRIMARY KEY" follows the column type
            self.name = name
            self.col_type = col_type
            self.primary_key = primary_key

    def __init__(self, cur):
        """Bind the builder to cursor *cur* and start with an empty query."""
        self.cur = cur
        self.reset()

    def reset(self):
        """Discard whatever has been accumulated so far."""
        self.query = SQL("")

    def execute(self):
        """Print the accumulated query, run it on the cursor, then clear it.

        The builder is cleared even when the cursor raises, so a failed
        statement is not silently prepended to the next one that is built.
        """
        try:
            print(self.query.as_string())
            self.cur.execute(self.query)
        finally:
            self.reset()

    def __repr__(self):
        return self.query.as_string()

    def _begin_statement(self):
        """Put a separator between a pending statement and the one about to be added."""
        # A fresh builder holds exactly SQL(""); anything else means a
        # statement is already pending and would otherwise run into the next.
        if self.query != SQL(""):
            self.query += SQL("; ")

    def drop_table(self, table_name):
        """Append ``DROP TABLE IF EXISTS <table_name>`` and return self."""
        self._begin_statement()
        self.query += SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(table_name))
        return self

    def create_table(self, name, columns):
        """Append a ``CREATE TABLE`` statement and return self.

        name: table name, rendered as a quoted identifier.
        columns: iterable of ``DB.CreateColumn``; definitions are emitted in
            iteration order, separated by ", ".
        """
        column_defs = []
        for col in columns:
            column_def = SQL("{column_name} {column_type}").format(
                column_name=Identifier(col.name),
                column_type=SQL(col.col_type),
            )
            if col.primary_key:
                column_def += SQL(" PRIMARY KEY")
            column_defs.append(column_def)
        self._begin_statement()
        self.query += SQL("CREATE TABLE {table_name} ({columns})").format(
            table_name=Identifier(name),
            columns=SQL(", ").join(column_defs),
        )
        return self

    def select_from(self, table_name, columns, where_conditions=None):
        """Start a ``SELECT`` statement and return self.

        NOTE(review): unfinished. Only the SELECT keyword is emitted;
        table_name, columns and where_conditions are not used yet, so the
        resulting query is not executable on its own.
        """
        self._begin_statement()
        self.query += SQL("SELECT ")
        return self


# Query builder instance bound to the notebook-wide cursor `cur`.
db = DB(cur)

# individual functions

## germanet

In [None]:
# Path-based relatedness calculator for the noun category; max_len and
# max_depth are handed to germanetpy as given.
relatedness_calculator = PathBasedRelatedness(
    germanet=germanet, category=WordCategory.nomen, max_len=35, max_depth=20
)


if enable_test:
    # pop() takes one synset from the lookup result of each word form.
    w1, w2, w3 = (
        germanet.get_synsets_by_orthform(orthform).pop()
        for orthform in ("Trompete", "Flöte", "Haus")
    )
    w1_w2 = relatedness_calculator.simple_path(w1, w2)
    w1_w3 = relatedness_calculator.simple_path(w1, w3)
    print(w1_w2, w1_w3, sep="\n")

## get_average_synset_path

In [None]:
def get_average_synset_path(word_1, word_2):
    """Return the mean ``simple_path`` value over all synset pairs of two words.

    Both words are looked up after ``str.capitalize()`` (first character
    upper-cased, the rest lower-cased). Every synset of the first word is
    paired with every synset of the second one.

    Returns:
        The arithmetic mean of the pairwise values, or ``None`` when there is
        no synset pair at all or when ``simple_path`` fails for any pair.
    """
    synset_list_1 = germanet.get_synsets_by_orthform(word_1.capitalize())
    synset_list_2 = germanet.get_synsets_by_orthform(word_2.capitalize())
    try:
        path_distance_list = [
            relatedness_calculator.simple_path(synset_1, synset_2)
            for synset_1 in synset_list_1
            for synset_2 in synset_list_2
        ]
    except Exception:
        # Best effort: a single pair that cannot be compared invalidates the
        # whole average. Catching Exception rather than using a bare except
        # keeps KeyboardInterrupt/SystemExit working while a long run is going.
        return None
    if not path_distance_list:
        return None
    return sum(path_distance_list) / len(path_distance_list)


if enable_test:
    for test_word_1, test_word_2 in (
        ("Frau", "Gattin"),
        ("Frau", "Mann"),
        ("Frau", "Küche"),
        ("Frau", "Kind"),
        ("Mann", "Kind"),
        ("Frau", "Mathematik"),
        ("Mann", "Mathematik"),
        ("Frau", "Frau"),
        ("Gattin", "Gattin"),
    ):
        print(get_average_synset_path(test_word_1, test_word_2))

## create_lexeme_graph

In [None]:
def create_lexeme_graph(limit=None, debug=False):
    """Build an undirected graph whose nodes are orthforms of lexical units.

    For every synset and every synset it is related to, each lexical unit of
    the first is connected to each lexical unit of the second. The edge label
    is "<synset id>-<related synset id>". Pairs involving the orthform
    "GNROOT" are left out.

    limit: when not None, only the first `limit` synsets are processed.
    debug: when true, print the synsets, relation keys and related lexical
        units while walking through them.
    """
    graph = nx.Graph()
    synsets = list(germanet.synsets.values())
    if limit is not None:
        synsets = synsets[:limit]
    for synset in synsets:
        if debug:
            print(synset)
        for relation_key, related_synsets in synset.relations.items():
            if debug:
                print("\t", relation_key)
            for related_synset in related_synsets:
                if debug:
                    print("\t\t", related_synset)
                    print("\t\t\t", related_synset.lexunits)
                for lexunit in synset.lexunits:
                    for related_lexunit in related_synset.lexunits:
                        source_form = lexunit.orthform
                        target_form = related_lexunit.orthform
                        if source_form == "GNROOT" or target_form == "GNROOT":
                            continue
                        graph.add_edge(
                            source_form,
                            target_form,
                            label=synset.id + "-" + related_synset.id,
                        )
    return graph


if enable_test:
    # Small graph from the first 1000 synsets, used by the smoke checks below.
    g = create_lexeme_graph(limit=1000, debug=False)

## plot_graph

In [None]:
def plot_graph(g, sub_node_list=None, traversal_limit=None):
    """Draw graph *g* as an interactive plotly figure and show it.

    g: networkx graph. An edge's "label" attribute (empty string when
        missing) is drawn at the edge midpoint and used as edge hover text.
    sub_node_list: optional start nodes. When given and non-empty, only the
        nodes reachable from these start nodes are plotted.
    traversal_limit: optional maximum path length from a start node; None
        means no limit.

    Returns nothing; the figure is displayed through ``fig.show()``.
    """
    if sub_node_list:
        reachable_nodes = set()
        for start_node in sub_node_list:
            # cutoff=None is networkx's own "no limit", so one call covers both cases.
            reachable_nodes.update(
                nx.single_source_shortest_path_length(g, start_node, cutoff=traversal_limit)
            )
        # Copy so the layout works on an independent graph instead of a view of g.
        g = g.subgraph(reachable_nodes).copy()
    # Fixed seed so the layout is the same on every run.
    pos = nx.spring_layout(g, seed=42)

    # edges
    edge_x, edge_y, edge_text = [], [], []
    edge_label_x, edge_label_y, edge_label_text = [], [], []
    for u, v, data in g.edges(data=True):
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        label = data.get("label", "")
        # The trailing None ends the segment so separate edges are not joined.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        # plotly maps text entries to points in order; each edge contributes
        # three points, so it needs three entries to keep hover text aligned.
        edge_text.extend([label, label, ""])
        # midpoint for edge label
        edge_label_x.append((x0 + x1) / 2)
        edge_label_y.append((y0 + y1) / 2)
        edge_label_text.append(label)

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=2, color="#888"),
        hoverinfo="text",
        text=edge_text,
        mode="lines",
    )

    # edge labels as separate trace
    edge_label_trace = go.Scatter(
        x=edge_label_x,
        y=edge_label_y,
        mode="text",
        text=edge_label_text,
        textposition="middle center",
        hoverinfo="none",
        textfont=dict(color="black", size=12),
    )

    # nodes
    node_text = list(g.nodes())
    node_x = [pos[node][0] for node in node_text]
    node_y = [pos[node][1] for node in node_text]

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        text=node_text,
        textposition="top center",
        hoverinfo="text",
        marker=dict(size=20, color="lightblue", line=dict(width=2, color="DarkSlateGrey")),
    )

    # build figure
    fig = go.Figure(
        data=[edge_trace, edge_label_trace, node_trace],
        layout=go.Layout(
            width=1000,
            height=1000,
            title="Interactive Graph",
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )
    fig.show()


if enable_test:
    plot_graph(g, sub_node_list=["unwirklich"], traversal_limit=2)

In [None]:
def get_graph_path(g, l1, l2):
    """Return the node-disjoint paths between nodes l1 and l2 of g as a list."""
    disjoint_paths = nx.node_disjoint_paths(g, l1, l2)
    return [*disjoint_paths]


if enable_test:
    # Print each path found between the two sample nodes on its own line.
    for p in get_graph_path(g, "unwirklich", "fantastisch"):
        print(p)

## word embeddings

# analysis

In [None]:
# Build the graph without a synset limit; this rebinds g, replacing the
# limited graph from the smoke check.
g = create_lexeme_graph(debug=False)

# experiments

## germanet

In [None]:
# Print every synset returned for the orthform "Bank".
b = germanet.get_synsets_by_orthform("Bank")
for ss in b:
    print(ss)

In [None]:
# Print every synset returned for the orthform "Bankinstitut" (rebinds b and ss).
b = germanet.get_synsets_by_orthform("Bankinstitut")
for ss in b:
    print(ss)

In [None]:
# Print the second entry of b and its lexical units.
# NOTE(review): b is whatever the most recently run lookup cell assigned;
# b[1] fails when that result has fewer than two entries.
print(b[1])
print(b[1].lexunits)

In [None]:
# Fetch the lexical unit with id "l9381" and display all of its orthforms.
germanet.get_lexunit_by_id("l9381").get_all_orthforms()

In [None]:
# For the first "Sitzmöbel" synset, print each relation key followed by the
# items stored under it, with a blank gap after every relation.
b = germanet.get_synsets_by_orthform("Sitzmöbel")
for k, v in b[0].relations.items():
    print(k)
    for other in v:
        print(other)
    print("\n")

In [None]:
# Display the incoming relations of the second entry of b.
# NOTE(review): fails when b currently holds fewer than two entries.
b[1].incoming_relations

In [None]:
# Run a synset filter for the string "orange" with ignore_case=True and display the result.
Filterconfig("orange", ignore_case=True).filter_synsets(germanet)

In [None]:
# Display the word classes that are possible for the category WordCategory.nomen.
WordCategory.get_possible_word_classes(WordCategory.nomen)

In [None]:
# Repeat the "Bank" lookup so that b refers to the "Bank" synsets again.
b = germanet.get_synsets_by_orthform("Bank")
for ss in b:
    print(ss)

In [None]:
# Use the second entry of b as the working synset ss and display it.
# NOTE(review): fails when b holds fewer than two entries.
ss = b[1]
ss

In [None]:
# Display the relations stored on ss.
ss.relations

In [None]:
# Display only the relation keys of ss.
list(ss.relations.keys())

In [None]:
# Keep the relation keys in the list l so they can be addressed by position.
l = list(ss.relations.keys())
l

In [None]:
# Fetch the entries stored under the first and the second relation key.
# NOTE(review): the variable names assume l[0] is the hypernym relation and
# l[1] the hyponym relation; confirm against the key list shown before.
ss_hyper_set = ss.relations[l[0]]
ss_hypo_set = ss.relations[l[1]]
print(ss_hyper_set)
print(ss_hypo_set)

In [None]:
# Pick one element from each collection by position: the first of
# ss_hyper_set and the third of ss_hypo_set.
# NOTE(review): index 2 needs at least three elements; if these collections
# are sets, the element found at a given position is not guaranteed to be stable.
ss_hyper = list(ss_hyper_set)[0]
ss_hypo = list(ss_hypo_set)[2]

In [None]:
# Display the simple_path value between ss and the picked ss_hypo.
relatedness_calculator.simple_path(ss, ss_hypo)

In [None]:
# Display the simple_path value between ss and the picked ss_hyper.
relatedness_calculator.simple_path(ss, ss_hyper)

In [None]:
# Display the Python type of `orthform` on the first lexical unit of the
# first "Bank" synset.
# NOTE: this rebinds l, which held the relation keys in the cells before.
ss_list = germanet.get_synsets_by_orthform("Bank")
l = ss_list[0].lexunits[0]
type(l.orthform)

In [None]:
# Same relation dump as the earlier "Sitzmöbel" cell: relation key, then the
# items stored under it, for the first synset of the lookup.
b = germanet.get_synsets_by_orthform("Sitzmöbel")
for k, v in b[0].relations.items():
    print(k)
    for other in v:
        print(other)
    print("\n")

In [None]:
# Look up the synsets for the orthform "GNROOT", the orthform that
# create_lexeme_graph leaves out of its edges.
germanet.get_synsets_by_orthform("GNROOT")