# setup

## imports

In [None]:
import json
import os
import random
from dataclasses import dataclass
from textwrap import dedent

import networkx as nx
import plotly.graph_objects as go
import plotly.io as pio
import psycopg
from germanetpy.filterconfig import Filterconfig
from germanetpy.frames import Frames
from germanetpy.germanet import Germanet
from germanetpy.path_based_relatedness_measures import PathBasedRelatedness
from germanetpy.synset import WordCategory, WordClass
from psycopg.sql import SQL, Identifier, Literal

## global vars

In [None]:
# Switch for the inline smoke tests that follow most function definitions.
SET_TEST = True
# Embeddings table and column names used by the smoke tests; also read
# directly by create_word_sim_dicts and calculate_diffs.
TEST_EMBEDDINGS_TABLE = "word2vec__m4"
TEST_WORD_COLUMN = "word"
TEST_EMBEDDINGS_COLUMN = "embedding"

# Output files: JSON cache of the diff statistics, and the rendered figure.
OUT_DIFF_STATS_JSON_PATH = "/veld/output/diff_stats.json"
OUT_DIFF_STATS_PNG_PATH = "/veld/output/diff_stats.png"
OUT_DIFF_STATS_HTML_PATH = "/veld/output/diff_stats.html"
# Table queried by get_embeddings_metadata.
EMBEDDINGS_METADATA_TABLE = "metadata_static_embedding_models"

In [None]:
# Use plotly's "notebook" renderer for fig.show().
pio.renderers.default = "notebook"
# Fixed seed so the random.sample / random.shuffle calls below are repeatable.
random.seed(42)

In [None]:
# Load the GermaNet data found under /veld/input/; shared by all helpers below.
germanet = Germanet("/veld/input/")

In [None]:
# Connection settings come from the environment (None when a variable is unset).
POSTGRES_HOST = os.getenv("POSTGRES_HOST")
POSTGRES_PORT = os.getenv("POSTGRES_PORT")
POSTGRES_USER = os.getenv("POSTGRES_USER")
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
POSTGRES_DB = os.getenv("POSTGRES_DB")
print(f"{POSTGRES_HOST=}")
print(f"{POSTGRES_PORT=}")
print(f"{POSTGRES_USER=}")
# Cell output is stored with the notebook, so only show whether a password is
# set, never the secret itself.
print(f"POSTGRES_PASSWORD={'***' if POSTGRES_PASSWORD else POSTGRES_PASSWORD!r}")
print(f"{POSTGRES_DB=}")

## DB

In [None]:
# One connection and one cursor, shared as globals by all query helpers below.
conn = psycopg.connect(
    host=POSTGRES_HOST,
    port=POSTGRES_PORT,
    dbname=POSTGRES_DB,
    user=POSTGRES_USER,
    password=POSTGRES_PASSWORD,
)
# Autocommit: statements are not grouped into one explicit transaction.
conn.autocommit = True
cursor = conn.cursor()
# Connectivity check: print the server version row.
cursor.execute("SELECT version()")
print(cursor.fetchone())

# individual functions

### print_sample_of_related

In [None]:
def print_sample_of_related(word_sim_dict, num_steps):
    """Print about ``num_steps`` evenly spaced ``(word, similarity)`` items.

    Items are taken in the dict's iteration order. The first and the last item
    are always included when ``num_steps >= 2``.

    Args:
        word_sim_dict: mapping of word to similarity.
        num_steps: requested number of printed items. ``1`` prints only the
            first item; values below 1 print nothing.
    """
    total = len(word_sim_dict)
    if total == 0 or num_steps < 1:
        return
    if num_steps == 1:
        # Avoids the division by (num_steps - 1) == 0 below.
        sample_indices = {0}
    else:
        # A step of at least 1 makes a dict shorter than num_steps print in
        # full instead of collapsing every index onto 0.
        step = max(1, total // (num_steps - 1))
        sample_indices = {i * step for i in range(num_steps - 1)}
        sample_indices.add(total - 1)
    for i, item in enumerate(word_sim_dict.items()):
        if i in sample_indices:
            print(item)

## germanet

### relatedness_calculator

In [None]:
# Path-based relatedness measures, configured for the noun category
# (WordCategory.nomen). Shared as a global by get_average_synset_similarity.
relatedness_calculator = PathBasedRelatedness(
    germanet=germanet,
    category=WordCategory.nomen,
)


# Smoke test: score "Trompete" against "Flöte" and against "Haus" with two
# different measures and print the raw values.
if SET_TEST:
    # .pop() picks a single synset out of those found for each word.
    w1 = germanet.get_synsets_by_orthform("Trompete").pop()
    w2 = germanet.get_synsets_by_orthform("Flöte").pop()
    w3 = germanet.get_synsets_by_orthform("Haus").pop()
    w1_w2 = relatedness_calculator.leacock_chodorow(w1, w2)
    w1_w3 = relatedness_calculator.leacock_chodorow(w1, w3)
    print(w1_w2)
    print(w1_w3)
    w1_w2 = relatedness_calculator.simple_path(w1, w2)
    w1_w3 = relatedness_calculator.simple_path(w1, w3)
    print(w1_w2)
    print(w1_w3)

### get_average_synset_similarity

In [None]:
def get_average_synset_similarity(word_1, word_2):
    """Average the simple-path relatedness over all synset pairs of two words.

    Every synset found for ``word_1`` is paired with every synset found for
    ``word_2`` and scored with ``relatedness_calculator.simple_path``.

    Args:
        word_1: orthographic form looked up in GermaNet.
        word_2: orthographic form looked up in GermaNet.

    Returns:
        The mean score of all pairs that could be scored, or ``None`` when no
        pair produced a score (including when a word has no synsets).
    """
    synset_list_1 = germanet.get_synsets_by_orthform(word_1)
    synset_list_2 = germanet.get_synsets_by_orthform(word_2)
    path_distance_list = []
    for synset_1 in synset_list_1:
        for synset_2 in synset_list_2:
            try:
                path_distance_list.append(relatedness_calculator.simple_path(synset_1, synset_2))
            except Exception:
                # Best effort: pairs the calculator rejects are skipped
                # (presumably synsets outside its configured category - verify).
                # `Exception` rather than a bare `except` so that a kernel
                # interrupt can still stop the long loops calling this function.
                continue
    if not path_distance_list:
        return None
    return sum(path_distance_list) / len(path_distance_list)


# Smoke test: print the averaged similarity for a few word pairs, including
# two pairs of a word with itself.
if SET_TEST:
    print(get_average_synset_similarity("Frau", "Gattin"))
    print(get_average_synset_similarity("Frau", "Mann"))
    print(get_average_synset_similarity("Frau", "Küche"))
    print(get_average_synset_similarity("Frau", "Kind"))
    print(get_average_synset_similarity("Mann", "Kind"))
    print(get_average_synset_similarity("Frau", "Mathematik"))
    print(get_average_synset_similarity("Mann", "Mathematik"))
    print(get_average_synset_similarity("Frau", "Frau"))
    print(get_average_synset_similarity("Gattin", "Gattin"))

### get_all_words_of_germanet

In [None]:
def get_all_words_of_germanet():
    """Return the set of orthographic forms of all lexical units in GermaNet."""
    return {
        lexunit.orthform
        for synset in germanet.synsets.values()
        for lexunit in synset.lexunits
    }


# Smoke test: build the GermaNet vocabulary (reused by later smoke tests) and
# print its size.
if SET_TEST:
    all_germanet_words_set = get_all_words_of_germanet()
    print(len(all_germanet_words_set))

### create_word_sim_germanet

In [None]:
def create_word_sim_germanet(word_a, word_set):
    """Score ``word_a`` against every other word of ``word_set`` via GermaNet.

    Words for which ``get_average_synset_similarity`` returns ``None`` are
    left out, as is ``word_a`` itself.

    Returns:
        dict mapping word to similarity, ordered from highest to lowest
        similarity.
    """
    scored_pairs = []
    for candidate in word_set:
        if candidate == word_a:
            continue
        similarity = get_average_synset_similarity(word_a, candidate)
        if similarity is None:
            continue
        scored_pairs.append((candidate, similarity))
    scored_pairs.sort(key=lambda pair: -pair[1])
    return dict(scored_pairs)


# Smoke test: rank the whole GermaNet vocabulary against "Tisch" and print a
# sample. The resulting dict is reused by later smoke tests.
if SET_TEST:
    sample_word = "Tisch"
    word_sim_germanet_dict = create_word_sim_germanet(sample_word, all_germanet_words_set)
    print_sample_of_related(word_sim_germanet_dict, 30)

### create_lexeme_graph

In [None]:
def create_lexeme_graph(limit=None, debug=False):
    """Build an undirected graph linking the words of related synsets.

    For each synset and each synset it is related to, every orthographic form
    of the first is connected to every orthographic form of the second. Edges
    touching the form "GNROOT" are skipped. Each edge carries a ``label`` of
    the form ``"<synset id>-<related synset id>"``.

    Args:
        limit: optional number of synsets to process (taken from the start of
            ``germanet.synsets``); ``None`` processes all of them.
        debug: when true, print each synset, relation key and related synset.

    Returns:
        The resulting ``nx.Graph``.
    """
    graph = nx.Graph()
    synsets = list(germanet.synsets.values())
    if limit is not None:
        synsets = synsets[:limit]
    for synset in synsets:
        if debug:
            print(synset)
        for relation_key, related_synsets in synset.relations.items():
            if debug:
                print("\t", relation_key)
            for related_synset in related_synsets:
                if debug:
                    print("\t\t", related_synset)
                    print("\t\t\t", related_synset.lexunits)
                for lexunit in synset.lexunits:
                    source_form = lexunit.orthform
                    for related_lexunit in related_synset.lexunits:
                        target_form = related_lexunit.orthform
                        if source_form != "GNROOT" and target_form != "GNROOT":
                            graph.add_edge(
                                source_form,
                                target_form,
                                label=synset.id + "-" + related_synset.id,
                            )
    return graph


# This function may be unnecessary, so its smoke test is disabled for now.
# if SET_TEST:
#     g = create_lexeme_graph(limit=1000, debug=False)

### plot_graph

In [None]:
def plot_graph(g, sub_node_list=None, traversal_limit=None):
    """Show ``g``, or a neighbourhood of it, as an interactive plotly figure.

    Args:
        g: networkx graph. An edge attribute ``label`` (if present) is drawn
            at the edge midpoint and used as the edge hover text.
        sub_node_list: optional start nodes. When given, only nodes reachable
            from at least one of them are drawn.
        traversal_limit: optional maximum path length from each start node;
            ``None`` means no limit.
    """
    if sub_node_list:
        reachable_nodes = set()
        for sub_node in sub_node_list:
            # cutoff=None is networkx's default (no limit), so the optional
            # limit can be passed through unconditionally.
            reachable_nodes.update(nx.single_source_shortest_path_length(g, sub_node, cutoff=traversal_limit))
        g = g.subgraph(reachable_nodes).copy()  # copy: an independent graph instead of a view
    pos = nx.spring_layout(g, seed=42)

    # edges
    edge_x = []
    edge_y = []
    edge_text = []
    edge_label_x = []
    edge_label_y = []
    edge_label_text = []
    for u, v, data in g.edges(data=True):
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        label = data.get("label", "")
        # Each edge contributes three points (start, end, gap marker) ...
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        # ... so the hover text needs three entries per edge as well, otherwise
        # plotly pairs point i with the label of edge i instead of edge i // 3.
        edge_text.extend([label, label, ""])
        # midpoint for edge label
        edge_label_x.append((x0 + x1) / 2)
        edge_label_y.append((y0 + y1) / 2)
        edge_label_text.append(label)

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=2, color="#888"),
        hoverinfo="text",
        text=edge_text,
        mode="lines",
    )

    # edge labels as separate trace
    edge_label_trace = go.Scatter(
        x=edge_label_x,
        y=edge_label_y,
        mode="text",
        text=edge_label_text,
        textposition="middle center",
        hoverinfo="none",
        textfont=dict(color="black", size=12),
    )

    # nodes
    node_x = []
    node_y = []
    node_text = []
    for node in g.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        text=node_text,
        textposition="top center",
        hoverinfo="text",
        marker=dict(size=20, color="lightblue", line=dict(width=2, color="DarkSlateGrey")),
    )

    # build figure
    fig = go.Figure(
        data=[edge_trace, edge_label_trace, node_trace],
        layout=go.Layout(
            width=1000,
            height=1000,
            title="Interactive Graph",
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )
    fig.show()


# This function may be unnecessary, so its smoke test is disabled for now.
# if SET_TEST:
#     plot_graph(g, sub_node_list=["unwirklich"], traversal_limit=2)

### get_all_paths

In [None]:
def get_all_paths(g, l1, l2):
    """Return the node-disjoint paths between ``l1`` and ``l2`` in ``g`` as a list.

    Delegates to ``nx.node_disjoint_paths``, so despite the function name the
    result is limited to paths that networkx reports as node-disjoint.
    """
    disjoint_paths = nx.node_disjoint_paths(g, l1, l2)
    return list(disjoint_paths)


# This function may be unnecessary, so its smoke test is disabled for now.
# if SET_TEST:
#     for p in get_all_paths(g, "unwirklich", "fantastisch"):
#         print(p)

## word embeddings

### get_embeddings_metadata

In [None]:
def get_embeddings_metadata():
    """Load the embeddings metadata table as a column-oriented dict.

    Returns:
        dict mapping every column name of ``EMBEDDINGS_METADATA_TABLE`` except
        ``model_id`` to the list of that column's values. Index ``i`` of every
        list refers to the same table row. Empty dict when no columns are
        found for the table.
    """
    column_query = SQL(
        "SELECT column_name FROM information_schema.columns "
        "WHERE table_name = {table_name} ORDER BY ordinal_position"
    ).format(table_name=Literal(EMBEDDINGS_METADATA_TABLE))
    cursor.execute(column_query)
    column_names = [row[0] for row in cursor.fetchall() if row[0] != "model_id"]
    if not column_names:
        return {}

    # Fetch all columns with ONE statement. Separate per-column SELECTs without
    # ORDER BY give no guarantee that position i is the same row in every list,
    # and callers line the lists up by index.
    data_query = SQL("SELECT {columns} FROM {table_name}").format(
        columns=SQL(", ").join(Identifier(name) for name in column_names),
        table_name=Identifier(EMBEDDINGS_METADATA_TABLE),
    )
    cursor.execute(data_query)
    rows = cursor.fetchall()
    return {name: [row[i] for row in rows] for i, name in enumerate(column_names)}


# Smoke test: print each metadata column as a (name, values) pair.
if SET_TEST:
    metadata = get_embeddings_metadata()
    for col in metadata.items():
        print(col)

### get_all_words_of_embeddings

In [None]:
def get_all_words_of_embeddings(embeddings_table, word_column_name):
    """Fetch the word column of an embeddings table.

    Args:
        embeddings_table: table name (quoted as an SQL identifier).
        word_column_name: column name (quoted as an SQL identifier).

    Returns:
        The rows exactly as returned by ``cursor.fetchall()``; callers read
        the word from index 0 of each row.
    """
    select_words = SQL("SELECT {word_column_name} FROM {embeddings_table}")
    cursor.execute(
        select_words.format(
            embeddings_table=Identifier(embeddings_table),
            word_column_name=Identifier(word_column_name),
        )
    )
    word_rows = cursor.fetchall()
    return word_rows


# Smoke test: load the vocabulary of the test embeddings table (reused by
# later smoke tests) and print its size.
if SET_TEST:
    words_embeddings = get_all_words_of_embeddings(TEST_EMBEDDINGS_TABLE, TEST_WORD_COLUMN)
    print(len(words_embeddings))

### get_word_embedding_similarity

In [None]:
def get_word_embedding_similarity(
    word_search, embeddings_table, word_column_name, embeddings_column_name, number_results=10, order_by_closest=True
):
    """Query the database function ``get_related`` for words related to ``word_search``.

    All arguments are passed to ``get_related`` as SQL literals under the
    parameter names of the same name.

    Returns:
        dict built from the returned two-column rows (word, similarity), in
        the order the database returned them.
    """
    statement = dedent(
        """\
        SELECT * FROM get_related(
            word_search := {word_search},
            embeddings_table := {embeddings_table},
            word_column_name := {word_column_name},
            embeddings_column_name := {embeddings_column_name},
            number_results := {number_results},
            order_by_closest := {order_by_closest}
        )
        """
    )
    arguments = {
        "word_search": Literal(word_search),
        "embeddings_table": Literal(embeddings_table),
        "word_column_name": Literal(word_column_name),
        "embeddings_column_name": Literal(embeddings_column_name),
        "number_results": Literal(number_results),
        "order_by_closest": Literal(order_by_closest),
    }
    cursor.execute(SQL(statement).format(**arguments))
    related_rows = cursor.fetchall()
    return {related_word: score for related_word, score in related_rows}


# Smoke test: fetch related words for "tisch" (number_results passed as None)
# and print a sample. The resulting dict is reused by later smoke tests.
if SET_TEST:
    word_sim_embedding_dict = get_word_embedding_similarity("tisch", TEST_EMBEDDINGS_TABLE, TEST_WORD_COLUMN, TEST_EMBEDDINGS_COLUMN, None)
    print_sample_of_related(word_sim_embedding_dict, 30)

# composite functions

## find_overlap_words

In [None]:
def find_overlap_words(words_germanet, words_embeddings):
    """Return the words present in both vocabularies.

    Args:
        words_germanet: iterable of GermaNet word forms.
        words_embeddings: iterable of rows whose first element is the word.
            Each word is passed through ``str.capitalize()`` before matching.

    Returns:
        set of matching words in their capitalized form.
    """
    capitalized_embedding_words = {row[0].capitalize() for row in words_embeddings}
    return capitalized_embedding_words.intersection(words_germanet)


# Smoke test: print how many words GermaNet and the test embeddings share.
if SET_TEST:
    overlap_word_set = find_overlap_words(all_germanet_words_set, words_embeddings)
    print(len(overlap_word_set))

## filter_for_overlapping_words

In [None]:
def filter_for_overlapping_words(word_sim_germanet_dict, word_sim_embedding_dict):
    """Restrict both similarity dicts to the words they have in common.

    Embedding words are matched against GermaNet words after
    ``str.capitalize()``. Both results keep the entry order and the original
    keys of their input dict.

    Returns:
        tuple ``(filtered_germanet_dict, filtered_embedding_dict)``.
    """
    shared_words = {word.capitalize() for word in word_sim_embedding_dict} & word_sim_germanet_dict.keys()
    germanet_kept = {word: sim for word, sim in word_sim_germanet_dict.items() if word in shared_words}
    embedding_kept = {
        word: sim for word, sim in word_sim_embedding_dict.items() if word.capitalize() in shared_words
    }
    return germanet_kept, embedding_kept


# Smoke test: print the dict sizes before and after filtering to shared words.
if SET_TEST:
    print(len(word_sim_germanet_dict))
    print(len(word_sim_embedding_dict))
    word_sim_germanet_filtered_dict, word_sim_embedding_filtered_dict = filter_for_overlapping_words(
        word_sim_germanet_dict,
        word_sim_embedding_dict,
    )
    print(len(word_sim_germanet_filtered_dict))
    print(len(word_sim_embedding_filtered_dict))

## normalize_word_similarities

In [None]:
def normalize_word_similarities(word_sim_dict):
    """Min-max scale the similarities of ``word_sim_dict`` into ``[0, 1]``.

    Args:
        word_sim_dict: mapping of word to numeric similarity.

    Returns:
        New dict with the same keys in the same order; the smallest input
        value maps to 0 and the largest to 1. An empty input yields an empty
        dict. When all values are equal (including a single entry) there is no
        spread to scale by, so every word maps to ``0.0``.
    """
    if not word_sim_dict:
        return {}
    sim_min = min(word_sim_dict.values())
    sim_max = max(word_sim_dict.values())
    scale = sim_max - sim_min
    if scale == 0:
        # Avoids dividing by zero for constant input.
        return {word: 0.0 for word in word_sim_dict}
    return {word: (sim - sim_min) / scale for word, sim in word_sim_dict.items()}


# Smoke test: normalize both filtered dicts and print a sample of each.
if SET_TEST:
    word_sim_germanet_filtered_normalized_dict = normalize_word_similarities(word_sim_germanet_filtered_dict)
    print_sample_of_related(word_sim_germanet_filtered_normalized_dict, 30)
    print("----------------------------------------------------")
    word_sim_embedding_filtered_normalized_dict = normalize_word_similarities(word_sim_embedding_filtered_dict)
    print_sample_of_related(word_sim_embedding_filtered_normalized_dict, 30)

## randomize_sim_dict

In [None]:
def randomize_sim_dict(word_sim_dict):
    """Return a copy of ``word_sim_dict`` whose values are shuffled among the keys.

    Key order is kept; the values are permuted with ``random.shuffle`` (module
    level random state). The input dict is not modified.
    """
    shuffled_sims = list(word_sim_dict.values())
    random.shuffle(shuffled_sims)
    return dict(zip(word_sim_dict.keys(), shuffled_sims))


# Smoke test: build the shuffled variant used as comparison baseline below.
if SET_TEST:
    word_sim_embedding_filtered_normalized_randomized_dict = randomize_sim_dict(word_sim_embedding_filtered_normalized_dict)

## calculate_sim_diff

In [None]:
def calculate_sim_diff(word_sim_germanet_dict, word_sim_embedding_dict, top_num=100):
    """Mean absolute difference between embedding and GermaNet similarities.

    Args:
        word_sim_germanet_dict: word -> similarity; must contain the
            capitalized form of every embedding word that is compared.
        word_sim_embedding_dict: word -> similarity; iterated in dict order.
        top_num: when truthy, only the first ``top_num`` embedding entries are
            compared; otherwise all of them.

    Returns:
        Average of ``abs(embedding_sim - germanet_sim)`` over the compared
        words.
    """
    compared_entries = list(word_sim_embedding_dict.items())
    if top_num:
        compared_entries = compared_entries[:top_num]
    diff_total = 0
    for embedding_word, embedding_sim in compared_entries:
        germanet_sim = word_sim_germanet_dict[embedding_word.capitalize()]
        diff_total = diff_total + abs(embedding_sim - germanet_sim)
    return diff_total / len(compared_entries)


# Smoke test: print the mean difference for the real embedding similarities
# and, for comparison, for the shuffled ones.
if SET_TEST:
    top_num = 1000
    sim_diff = calculate_sim_diff(
        word_sim_germanet_filtered_normalized_dict,
        word_sim_embedding_filtered_normalized_dict,
        top_num=top_num,
    )
    print(sim_diff)
    sim_diff = calculate_sim_diff(
        word_sim_germanet_filtered_normalized_dict,
        word_sim_embedding_filtered_normalized_randomized_dict,
        top_num=top_num,
    )
    print(sim_diff)

## create_word_sim_dicts

In [None]:
def create_word_sim_dicts(word, all_germanet_words_set=None, print_num=None):
    print(f"{word=}")
    if all_germanet_words_set is None:
        all_germanet_words_set = get_all_words_of_germanet()

    # create sims
    word_sim_germanet_dict = create_word_sim_germanet(word, all_germanet_words_set)
    word_sim_embedding_dict = get_word_embedding_similarity(
        word.lower(), TEST_EMBEDDINGS_TABLE, TEST_WORD_COLUMN, TEST_EMBEDDINGS_COLUMN, None
    )

    # filter on overlapping words
    word_sim_germanet_filtered_dict, word_sim_embedding_filtered_dict = filter_for_overlapping_words(
        word_sim_germanet_dict,
        word_sim_embedding_dict,
    )

    # normalize
    word_sim_germanet_filtered_normalized_dict = normalize_word_similarities(word_sim_germanet_filtered_dict)
    if print_num:
        print("word_sim_germanet_filtered_normalized_dict=")
        print_sample_of_related(word_sim_germanet_filtered_normalized_dict, print_num)
    word_sim_embedding_filtered_normalized_dict = normalize_word_similarities(word_sim_embedding_filtered_dict)
    if print_num:
        print("word_sim_embedding_filtered_normalized_dict=")
        print_sample_of_related(word_sim_embedding_filtered_normalized_dict, print_num)

    # create randomized null model
    word_sim_embedding_filtered_normalized_randomized_dict = randomize_sim_dict(word_sim_embedding_filtered_normalized_dict)
    if print_num:
        print("word_sim_embedding_filtered_normalized_randomized_dict=")
        print_sample_of_related(word_sim_embedding_filtered_normalized_randomized_dict, print_num)

    return (
        word_sim_germanet_filtered_normalized_dict,
        word_sim_embedding_filtered_normalized_dict,
        word_sim_embedding_filtered_normalized_randomized_dict,
    )


if SET_TEST:
    (
        word_sim_germanet_dict,
        word_sim_embedding_dict,
        word_sim_embedding_randomized_dict,
    ) = create_word_sim_dicts("Tisch", print_num=10)

## calculate_diffs

In [None]:
def calculate_diffs(embeddings_table_name, diff_top_num=100, print_num=10, count_words=None):
    all_germanet_words_set = get_all_words_of_germanet()
    all_words_embeddings_list = get_all_words_of_embeddings(TEST_EMBEDDINGS_TABLE, TEST_WORD_COLUMN)
    overlap_word_set = find_overlap_words(all_germanet_words_set, all_words_embeddings_list)
    word_list = list(overlap_word_set)
    if count_words:
        word_list = random.sample(sorted(word_list), count_words)
    sim_diff_real_sum = 0
    sim_diff_real_count = 0
    sim_diff_random_sum = 0
    sim_diff_random_count = 0
    for word in word_list:
        (
            word_sim_germanet_dict,
            word_sim_embedding_dict,
            word_sim_embedding_randomized_dict,
        ) = create_word_sim_dicts(word, all_germanet_words_set, print_num=print_num)
        sim_diff_real = calculate_sim_diff(
            word_sim_germanet_dict,
            word_sim_embedding_dict,
            top_num=diff_top_num,
        )
        print(f"sim_diff_real={sim_diff_real}")
        sim_diff_random = calculate_sim_diff(
            word_sim_germanet_dict,
            word_sim_embedding_randomized_dict,
            top_num=diff_top_num,
        )
        print(f"sim_diff_random={sim_diff_random}")

        # total sums
        sim_diff_real_sum += sim_diff_real
        sim_diff_real_count += 1
        sim_diff_random_sum += sim_diff_random
        sim_diff_random_count += 1
        print("----------------------------------------------------")

    sim_diff_real_avg = sim_diff_real_sum / sim_diff_real_count
    sim_diff_random_avg = sim_diff_random_sum / sim_diff_random_count
    print(f"sim_diff_real_avg={sim_diff_real_avg}")
    print(f"sim_diff_random_avg={sim_diff_random_avg}")
    return sim_diff_real_avg, sim_diff_random_avg


# Smoke test on the test table with a tiny sample.
if SET_TEST:
    # embeddings_table_name is a required positional argument of calculate_diffs.
    sim_diff_real_avg, sim_diff_random_avg = calculate_diffs(TEST_EMBEDDINGS_TABLE, diff_top_num=10, count_words=2)

# full analysis

## analyse

In [None]:
def analyse(embeddings_table_name_list):
    """Compute (or load cached) diff statistics for each embeddings table.

    Results are cached in the JSON file at ``OUT_DIFF_STATS_JSON_PATH``, keyed
    by table name; tables already present in the cache are not recomputed.

    Args:
        embeddings_table_name_list: table names to evaluate, in output order.

    Returns:
        tuple ``(diff_real_list, diff_random_list)`` with one entry per table
        name, in the order given.
    """

    def write_cache(stats):
        with open(OUT_DIFF_STATS_JSON_PATH, "w", encoding="utf-8") as f:
            json.dump(stats, f)

    if os.path.exists(OUT_DIFF_STATS_JSON_PATH):
        with open(OUT_DIFF_STATS_JSON_PATH, "r", encoding="utf-8") as f:
            diff_stats_all_dict = json.load(f)
    else:
        diff_stats_all_dict = {}
    diff_real_list = []
    diff_random_list = []
    for embeddings_table_name in embeddings_table_name_list:
        if embeddings_table_name not in diff_stats_all_dict:
            # NOTE(review): diff_top_num=10 / count_words=2 look like
            # smoke-test settings - confirm before a full run.
            sim_diff_real_avg, sim_diff_random_avg = calculate_diffs(embeddings_table_name, diff_top_num=10, count_words=2)
            diff_stats_all_dict[embeddings_table_name] = sim_diff_real_avg, sim_diff_random_avg
            # Persist right away so a failure on a later table does not throw
            # away the results computed so far.
            write_cache(diff_stats_all_dict)
        else:
            sim_diff_real_avg, sim_diff_random_avg = diff_stats_all_dict[embeddings_table_name]
        diff_real_list.append(sim_diff_real_avg)
        diff_random_list.append(sim_diff_random_avg)

    # Final write keeps the original guarantee that the file exists afterwards,
    # even when nothing new was computed.
    write_cache(diff_stats_all_dict)

    return diff_real_list, diff_random_list


# Run the analysis for every table listed in the metadata column
# "vector_table_name" and attach the results as two extra metadata columns.
metadata = get_embeddings_metadata()
diff_real_list, diff_random_list = analyse(metadata["vector_table_name"])
metadata["diff_real"] = diff_real_list
metadata["diff_random"] = diff_random_list

## transform_data_for_visualisation

In [None]:
def transform_data_for_visualisation(metadata):
    """Turn column-oriented metadata into plotly parallel-coordinates dimensions.

    Args:
        metadata: dict mapping column name to a list of values (one per
            model). Must contain every column listed in ``AXIS_COLUMNS`` below.

    Returns:
        list with one dict per axis, each holding ``tickvals``, ``ticktext``,
        ``label`` and ``values``. Numbers are rounded to 3 decimals. A column
        with at least one value that cannot be rounded is treated as
        categorical: each distinct value gets an integer code 1, 2, ... in
        order of first appearance. A tick whose text would be "-1" is shown
        as "null".
    """
    # Columns that become axes, in display order.
    # NOTE(review): "diff_real" / "diff_random" are added to the metadata by
    # the analysis step but are not plotted here - confirm whether intended.
    AXIS_COLUMNS = [
        "training_architecture",
        "vector_table_name",
        # "score",
        "training_vector_size",
        "min_count",
        "window_size",
        "training_epochs",
        "training_duration_minutes",
        "model_data_size",
        "train_data_md5_hash",
    ]

    def try_round(value):
        """Return ``(rounded value, True)``, or ``(value, False)`` if not roundable."""
        try:
            return round(value, 3), True
        except (TypeError, ValueError):
            # Strings, None, ... - explicit exception types instead of a bare
            # `except`, which would also swallow KeyboardInterrupt.
            return value, False

    def create_axis(col_list, col_name):
        """Build the dimension dict for one column."""
        rounded = [try_round(value) for value in col_list]
        is_numeric = all(ok for _, ok in rounded)
        values = [value for value, _ in rounded]

        if is_numeric:
            axis_values = values
            # dict.fromkeys: distinct values in order of first appearance.
            distinct_values = list(dict.fromkeys(values))
            tick_values = distinct_values
        else:
            # Categorical column: map every distinct value to a stand-in number.
            codes = {}
            for value in values:
                codes.setdefault(value, len(codes) + 1)
            axis_values = [codes[value] for value in values]
            distinct_values = list(codes)
            tick_values = list(codes.values())

        tick_labels = ["null" if str(value) == "-1" else str(value) for value in distinct_values]
        return {
            "tickvals": tick_values,
            "ticktext": tick_labels,
            "label": col_name,
            "values": axis_values,
        }

    return [create_axis(metadata[col_name], col_name) for col_name in AXIS_COLUMNS]


# Build the axis definitions and display them as the cell output.
dim_list = transform_data_for_visualisation(metadata)
dim_list

## visualise

In [None]:
def visualise(dim_list):
    """Render the parallel-coordinates figure, show it and save it.

    Args:
        dim_list: list of plotly ``Parcoords`` dimension dicts; the length of
            the first entry's ``"values"`` determines how many lines are
            coloured.

    Side effects:
        Shows the figure and writes it to ``OUT_DIFF_STATS_HTML_PATH`` and
        ``OUT_DIFF_STATS_PNG_PATH``.
    """
    line_count = len(dim_list[0]["values"])
    parcoords = go.Parcoords(
        dimensions=dim_list,
        # One colour index per line, spread over the Rainbow colour scale.
        line=dict(color=list(range(line_count)), colorscale="Rainbow"),
    )
    fig = go.Figure(data=parcoords)
    fig.update_layout(height=700)
    fig.show()
    fig.write_html(OUT_DIFF_STATS_HTML_PATH)
    fig.write_image(OUT_DIFF_STATS_PNG_PATH, width=1200, height=800)


visualise(dim_list)

In [None]:
# NOTE(review): at module level these two names are only assigned by the
# SET_TEST smoke test of calculate_diffs (inside analyse they are locals), so
# with SET_TEST = False this cell presumably fails with NameError - confirm.
sim_diff_real_avg, sim_diff_random_avg

In [None]:
def visualise(vis_data):
    """Show the parallel-coordinates figure without writing any files.

    NOTE(review): this redefines ``visualise``; after this cell runs, the
    earlier definition in this notebook (which also saves HTML and PNG files)
    is no longer reachable under that name - confirm that both are needed.

    Args:
        vis_data: list of plotly ``Parcoords`` dimension dicts.
    """
    # The line colours must be numeric, one per line. Iterating the first
    # dimension dict would yield its key strings instead.
    line_count = len(vis_data[0]["values"])
    fig = go.Figure(
        data=go.Parcoords(
            line={"color": list(range(line_count)), "colorscale": "Rainbow"},
            dimensions=vis_data,
        )
    )
    fig.update_layout(height=700)
    fig.show()


visualise(dim_list)

# experiments

In [None]:
# Experiment: rebuild the three similarity dicts for "Tisch", printing a
# sample of 5 entries each.
(
    word_sim_germanet_dict,
    word_sim_embedding_dict,
    word_sim_embedding_randomized_dict,
) = create_word_sim_dicts("Tisch", print_num=5)

In [None]:
# Print the first 21 entries of the GermaNet similarity dict.
for i, d in enumerate(word_sim_germanet_dict.items()):
    print(d)
    if i == 20:
        break

In [None]:
# Print the first 21 entries of the shuffled embedding similarity dict.
for i, d in enumerate(word_sim_embedding_randomized_dict.items()):
    print(d)
    if i == 20:
        break