# modules

## setup

### imports

In [None]:
import os
import pickle
import random
import time
from functools import partial
from typing import TypeAlias

import hnswlib
import hunspell
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import psycopg
import spacy
from gensim.models import Word2Vec
from joblib import Memory
from pgvector.psycopg import register_vector
from psycopg.sql import SQL, Identifier, Placeholder
from scipy.linalg import orthogonal_procrustes
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity

### config

In [None]:
# --- run switches ---
# TEST gates the smoke-test snippets that follow the definitions in this file.
TEST = True
# RESET_DB additionally gates the heavier rebuild steps (table reset, model loading,
# alignment and diff computation) inside those snippets.
RESET_DB = False

# --- input / storage locations ---
MODELS_WORD2VEC_FOLDER = "/veld/input/models/word2vec/"
MODELS_FASTTEXT_FOLDER = "/veld/input/models/fasttext/"
MODELS_GLOVE_FOLDER = "/veld/input/models/glove/"
TEXTS_FOLDER = "/veld/input/texts/"
CACHE_FOLDER = "/veld/storage/cache/"

# NOTE(review): presumably HNSW index build parameters (hnswlib is imported) — confirm against their users.
INDEX_EF_CONSTRUCTION = 100
INDEX_M = 16

# Default [first, last] decade identifiers for create_decades_list (both ends inclusive).
USABLE_DECADES = [156, 191]

# Seconds to sleep before each plot is built.
PLOT_SLEEP = 2

# --- database connection settings ---
# NOTE(review): credentials are hard-coded here; consider reading them from the environment.
DB_NAME = "postgres_db"
DB_USER = "postgres_user"
DB_PASSWORD = "postgres_password"
DB_HOST = "veld_step_7_run_embeddings_sql_server"
DB_PORT = "5432"

# Column sets of the embedding tables and the diff tables.
COLS_EMBEDDING = ["lemma", "occurrence_count", "embedding"]
COLS_DIFF = ["lemma", "occurrence_count", "diff"]

# Sample lemmas for the smoke tests, sorted case-insensitively.
TEST_LEMMA_RANGED_LIST = sorted(["d", "und", "gehen", "wohnen", "Fürst"], key=str.lower)
TEST_LEMMA_CLOSE_LIST = sorted(["gehen", "laufen", "wandern", "wohnen", "trinken"], key=str.lower)

# Reuse CACHE_FOLDER so the cache location is defined in exactly one place.
memory = Memory(location=CACHE_FOLDER, verbose=0)
nlp = spacy.load("de_core_news_sm")
hunspell_check = hunspell.HunSpell("/usr/share/hunspell/de_DE.dic", "/usr/share/hunspell/de_DE.aff")
random.seed(42)
pio.renderers.default = "iframe"

## helpers

### is_word

In [None]:
def is_word(word):
    """Return whether *word* passes the word filter.

    Anything ``int()`` can parse is rejected, as is any single-character
    token other than "d". Everything else is decided by the hunspell
    dictionary check (``hunspell_check.spell``).
    """
    try:
        int(word)
    except ValueError:
        # Not a number: continue with the length and dictionary checks.
        pass
    else:
        return False
    if len(word) == 1 and word != "d":
        return False
    return hunspell_check.spell(word)

### create_decades_list

In [None]:
def create_decades_list(decade_start: int = USABLE_DECADES[0], decade_end: int = USABLE_DECADES[1]):
    """Return every integer from *decade_start* to *decade_end*, both inclusive, as a list."""
    return list(range(decade_start, decade_end + 1))


if TEST:
    # Small decade range reused by the smoke tests further down.
    decade_list_test = create_decades_list(183, 187)
    print(decade_list_test)

### create_occurrence_dict

In [None]:
@memory.cache
def create_occurrence_dict(decade) -> dict:
    """Count how often each whitespace-separated token occurs in a decade's text file.

    Reads ``TEXTS_FOLDER + str(decade) + ".txt"`` and prints a few summary
    statistics. The call is wrapped by ``memory.cache``.

    Args:
        decade: identifier used to build the file name.

    Returns:
        A dict mapping each lemma to its occurrence count.
    """
    print("create_occurrence_dicts: start: decade:", decade)
    lemma_occurrence_count_dict = {}
    total_occurrence_count = 0
    # Explicit encoding so that reading does not depend on the process locale.
    with open(TEXTS_FOLDER + str(decade) + ".txt", "r", encoding="utf-8") as f:
        for line in f:
            # str.split() without arguments already discards the trailing newline.
            for lemma in line.split():
                lemma_occurrence_count_dict[lemma] = lemma_occurrence_count_dict.get(lemma, 0) + 1
                total_occurrence_count += 1
    lemma_count = len(lemma_occurrence_count_dict)
    if lemma_count:
        occurrence_avg = total_occurrence_count / lemma_count
        # The median must be taken from the sorted counts (upper median for an even
        # number of lemmas); indexing the unsorted values returns an arbitrary count.
        occurrence_median = sorted(lemma_occurrence_count_dict.values())[lemma_count // 2]
    else:
        # Empty input: avoid dividing by zero / indexing an empty list.
        occurrence_avg = 0
        occurrence_median = 0
    print("create_occurrence_dicts: uniqe lemma_count:", lemma_count)
    print("create_occurrence_dicts: total_occurrence_count:", total_occurrence_count)
    print("create_occurrence_dicts: occurrence_avg:", occurrence_avg)
    print("create_occurrence_dicts: occurrence_median:", occurrence_median)
    return lemma_occurrence_count_dict


if TEST:
    # One occurrence dict per test decade, keyed by decade.
    decade_lemma_occurrence_count_dict = {}
    for decade in decade_list_test:
        decade_lemma_occurrence_count_dict[decade] = create_occurrence_dict(decade)

### sort_dict_by_value

In [None]:
def sort_dict_by_value(key_value_dict: dict, desc=True) -> dict:
    """Return a new dict with the items of *key_value_dict* ordered by value.

    Args:
        key_value_dict: mapping whose values are mutually comparable.
        desc: truthy for largest value first, falsy for smallest first.

    Returns:
        A dict whose insertion order follows the sorted values. Items with
        equal values keep their original relative order.
    """
    # ``reverse`` keeps the sort stable and works for any comparable value type,
    # unlike multiplying the value by -1, which is only meaningful for numbers.
    return dict(sorted(key_value_dict.items(), key=lambda item: item[1], reverse=bool(desc)))


if TEST:
    # Smoke check: ascending order by value, so the keys print as y, x, z.
    print(
        sort_dict_by_value(
            {
                "x": 3,
                "y": 2,
                "z": 4,
            },
            desc=False,
        )
    )

## plotting

### plot_tsne_from_labels_embeddings

In [None]:
def plot_tsne_from_labels_embeddings(
    label_embedding_list: list[tuple[str, np.ndarray]],
    title: str | None = None,
    height: int | None = None,
    width: int | None = None,
    rotation_degree: int | None = None,
    perplexity: int | None = None,
):
    """Reduce labelled embeddings to 2D with t-SNE and show them as a labelled scatter plot.

    Args:
        label_embedding_list: iterable of ``(label, embedding)`` pairs.
        title: plot title.
        height: figure height; 800 when None.
        width: figure width; 800 when None.
        rotation_degree: when truthy, the 2D points are rotated clockwise
            around the origin by this many degrees before plotting.
        perplexity: forwarded to ``calculate_tsne``.
    """
    labels = []
    embeddings = []
    for label, embedding in label_embedding_list:
        labels.append(label)
        embeddings.append(embedding)
    reduced_vectors_tsne = calculate_tsne(embeddings, perplexity)

    if rotation_degree:
        # The angle is negated, so a positive rotation_degree turns the points clockwise.
        angle_rad = np.deg2rad(-rotation_degree)
        cos_a = np.cos(angle_rad)
        sin_a = np.sin(angle_rad)
        rotation_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
        reduced_vectors_tsne = reduced_vectors_tsne @ rotation_matrix.T

    if height is None:
        height = 800
    if width is None:
        width = 800
    time.sleep(PLOT_SLEEP)
    fig = px.scatter(
        x=reduced_vectors_tsne[:, 0],
        y=reduced_vectors_tsne[:, 1],
        text=labels,
        height=height,
        width=width,
        title=title,
    )
    # The t-SNE axes carry no interpretable scale, so titles and tick labels are hidden.
    fig.update_layout(xaxis=dict(title=None, showticklabels=False), yaxis=dict(title=None, showticklabels=False))
    fig.update_traces(
        marker=dict(size=10),
        textposition="bottom center",
        textfont=dict(size=12),
    )
    fig.show()

# if TEST:
#     plot_tsne_from_labels_embeddings(query_generic("word2vec__185", lemma_list=["gehen", "laufen", "essen", "Haus", "Philosophie"], select_cols=["lemma", "embedding"]))

### plot_tsne_from_lemma_and_related

In [None]:
def plot_tsne_from_lemma_and_related(
    table_name,
    lemma: str = None,
    n: int = 100,
    title: str = None,
    height: int = None,
    width: int = None,
    rotation_degree: int = None,
):
    """Plot the t-SNE projection of the *n* lemmas most similar to *lemma* in *table_name*.

    The neighbours come from ``query_related``; the plot options are passed
    on to ``plot_tsne_from_labels_embeddings``.
    """
    related = query_related(table_name, lemma, select_cols=["lemma", "embedding"], n=n)
    # Forward the plot options; previously they were accepted but silently dropped.
    plot_tsne_from_labels_embeddings(
        related,
        title=title,
        height=height,
        width=width,
        rotation_degree=rotation_degree,
    )

### plot_2d_scatter

In [None]:
def plot_2d_scatter(data: dict | list[list], title: str = None, draw_line=False, show_x_labels=True, xaxis_range=None, yaxis_range=None):
    """Show a scatter plot of key/value pairs.

    Args:
        data: either a mapping (keys on x, values on y) or an iterable of
            ``(x, y)`` pairs.
        title: plot title.
        draw_line: connect the markers with lines when true.
        show_x_labels: hide the x tick labels when false.
        xaxis_range: optional range applied to the x axis.
        yaxis_range: optional range applied to the y axis.
    """
    time.sleep(PLOT_SLEEP)
    # isinstance also accepts dict subclasses, which an exact type check
    # would have wrongly iterated as pairs.
    pairs = data.items() if isinstance(data, dict) else data
    key_list = []
    value_list = []
    for key, value in pairs:
        key_list.append(key)
        value_list.append(value)
    fig = px.scatter(x=key_list, y=value_list, title=title)
    if draw_line:
        fig.update_traces(mode="lines+markers")
    fig.update_layout(xaxis_title=None, yaxis_title=None)
    if not show_x_labels:
        fig.update_xaxes(showticklabels=False)
    if xaxis_range:
        fig.update_layout(xaxis_range=xaxis_range)
    if yaxis_range:
        fig.update_layout(yaxis_range=yaxis_range)
    fig.show()

### plot_merged_diff_table

In [None]:
def plot_merged_diff_table(merged_diff_table):
    """Plot, for every row, column 1 as point "A:<lemma>" and column 2 as point "B:<lemma>"."""
    points = []
    for row in merged_diff_table:
        lemma = row[0]
        points.extend((("A:" + lemma, row[1]), ("B:" + lemma, row[2])))
    plot_2d_scatter(points)

### plot_lemma_and_decade

In [None]:
def plot_lemma_and_decade(decade_list, lemma_list, perplexity=None, title=None):
    """Plot how the embeddings of the given lemmas move across decades.

    All embeddings returned by ``query_lemma_over_decades`` are reduced
    together with t-SNE. Each lemma gets one line through its points, and
    every point is labelled ``"<decade>:<lemma>"``.

    Args:
        decade_list: decades to query.
        lemma_list: lemmas to query.
        perplexity: t-SNE perplexity; when None, ``calculate_tsne`` picks a default.
        title: plot title; empty when None.
    """
    time.sleep(PLOT_SLEEP)

    # prepare data: group the embeddings per lemma so each lemma forms one contiguous run
    lemma_decade_embedding = {}
    for decade, lemma, embedding in query_lemma_over_decades(decade_list, lemma_list, print_query=False):
        lemma_decade_embedding.setdefault(lemma, {})[decade] = embedding
    global_labels_list = []
    global_embeddings_list = []
    group_end_position_list = []
    for lemma, decade_embedding_dict in lemma_decade_embedding.items():
        for decade, embedding in decade_embedding_dict.items():
            global_labels_list.append(str(decade) + ":" + lemma)
            global_embeddings_list.append(embedding)
        # remember where this lemma's run ends so it can be sliced out after the reduction
        group_end_position_list.append(len(global_embeddings_list))
    # Pass the caller's perplexity through; it used to be overwritten unconditionally.
    lemma_embeddings_reduced_array = calculate_tsne(global_embeddings_list, perplexity=perplexity)

    # create plot: one line trace per lemma, then a single marker+text trace for all points
    fig = go.Figure()
    group_start_position = 0
    for group_end_position in group_end_position_list:
        lemma_respective_embeddings = lemma_embeddings_reduced_array[group_start_position:group_end_position]
        fig.add_trace(
            go.Scatter(
                x=lemma_respective_embeddings[:, 0],
                y=lemma_respective_embeddings[:, 1],
                mode="lines",
            )
        )
        group_start_position = group_end_position
    fig.add_trace(
        go.Scatter(
            x=lemma_embeddings_reduced_array[:, 0],
            y=lemma_embeddings_reduced_array[:, 1],
            mode="markers+text",
            text=global_labels_list,
            textposition="top center",
        )
    )
    fig.update_layout(
        title=title if title is not None else "",
        showlegend=False,
        width=800,
        height=800,
    )
    fig.show()

# if TEST:
#     plot_lemma_and_decade(decade_list_test, ["Haus"])

### plot_average_diff

In [None]:
def plot_average_diff(decade_list, avg_diff_table, order_desc=True, lemma_min_occurrence=None):
    """Plot the ``avg`` column of *avg_diff_table*, ordered by ``avg``.

    When *lemma_min_occurrence* is given, only lemmas present in at least
    that many of the ``word2vec__<decade>`` tables of *decade_list* are kept.
    """
    rows = query_generic(avg_diff_table, select_cols=["lemma", "avg"], order_by="avg", order_desc=order_desc, print_query=False)
    if lemma_min_occurrence is not None:
        candidate_lemmas = {row[0] for row in rows}
        # number of decades in which each candidate lemma is present
        decades_present = {}
        for decade in decade_list:
            lemmas_in_decade = set(query_generic(f"word2vec__{decade}", select_cols=["lemma"], print_query=False))
            for lemma in candidate_lemmas & lemmas_in_decade:
                decades_present[lemma] = decades_present.get(lemma, 0) + 1
        rows = [row for row in rows if decades_present.get(row[0], 0) >= lemma_min_occurrence]
    plot_2d_scatter(rows)

### plot_lemma_diff

In [None]:
def plot_lemma_diff(decade_list, lemma, diff_table_prefix, yaxis_range=None):
    """Plot the ``diff`` value of *lemma* across consecutive diff tables as a line.

    Table names are ``<diff_table_prefix>__<a>_<b>``. ``b`` walks through
    *decade_list* starting at index 2 when the prefix contains
    "trajectory", otherwise at index 1. ``a`` starts at ``decade_list[0]``
    and grows by one per step. Tables without a row for *lemma* are skipped.
    """
    offset = 2 if "trajectory" in diff_table_prefix else 1
    first_decade = decade_list[0]
    points = []
    for step, decade_b in enumerate(decade_list[offset:]):
        decade_a = first_decade + step
        diffs = query_generic(f"{diff_table_prefix}__{decade_a}_{decade_b}", lemma_list=[lemma], select_cols=["diff"], print_query=False)
        if diffs:
            points.append((f"{decade_a}-{decade_b}", diffs[0]))
    plot_2d_scatter(points, draw_line=True, yaxis_range=yaxis_range)

## DB functions

### connect_db

In [None]:
def connect_db(create_cursor=True):
    """Open an autocommit connection using the DB_* settings.

    Returns:
        ``(conn, cursor)``. ``cursor`` is None when *create_cursor* is false;
        otherwise it has already been used to print the server version.
    """
    conn = psycopg.connect(
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        host=DB_HOST,
        port=DB_PORT,
    )
    conn.autocommit = True
    if not create_cursor:
        return conn, None
    cursor = conn.cursor()
    cursor.execute("SELECT version();")
    print("connected to:", cursor.fetchone())
    return conn, cursor


conn, cursor = connect_db()

# Register the pgvector type on the connection, then take a fresh cursor from it.
register_vector(conn)
cursor = conn.cursor()

### reset_db

In [None]:
def reset_db(decade_list):
    """Drop every table in the ``public`` schema and recreate one empty embeddings table per decade."""
    cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
    cursor.execute("SELECT tablename FROM pg_tables WHERE schemaname = 'public';")
    existing_tables = [row[0] for row in cursor.fetchall()]
    for existing_table in existing_tables:
        drop_table(existing_table)
    # for model in ["word2vec", "fasttext", "glove"]:
    embedding_models = ["word2vec"]
    for decade in decade_list:
        for model in embedding_models:
            create_embeddings_table(f"{model}__{decade}")


if TEST and RESET_DB:
    reset_db(decade_list_test)

### register_vector_conn

In [None]:
def register_vector_conn():
    """Replace the module-level connection and cursor with fresh ones.

    Closes the current global ``cursor`` and ``conn``, opens a new connection
    via ``connect_db`` (without its version-printing cursor), registers the
    pgvector type on it and creates a new cursor.

    Returns:
        The new ``(conn, cursor)`` pair, which is also stored in the globals.
    """
    global conn
    global cursor
    # Close the old handles first; the globals are rebound below.
    cursor.close()
    conn.close()
    conn, _ = connect_db(create_cursor=False)
    register_vector(conn)
    cursor = conn.cursor()
    return conn, cursor


# Reconnect once at import time so the globals refer to a vector-enabled connection.
conn, cursor = register_vector_conn()

### drop_table

In [None]:
def drop_table(table_name):
    """Drop *table_name* (with CASCADE) if it exists."""
    print("drop_table: table_name:", table_name)
    # Compose the identifier with psycopg.sql so it is quoted and escaped correctly,
    # matching how insert_to_db and query_generic reference tables.
    cursor.execute(SQL("DROP TABLE IF EXISTS {table_name} CASCADE;").format(table_name=Identifier(table_name)))


if TEST:
    drop_table("test")

### create_embeddings_table

In [None]:
def create_embeddings_table(table_name):
    """Create the embeddings table *table_name* if it does not exist.

    Columns: ``lemma`` (text primary key), ``occurrence_count`` (integer) and
    ``embedding`` (non-null 300-dimensional pgvector column).
    """
    print("create_embeddings_table: table_name:", table_name)
    # The table name is quoted via Identifier, consistent with insert_to_db and
    # query_generic, so all functions address exactly the same table.
    query = SQL(
        "CREATE TABLE IF NOT EXISTS {table_name} ("
        "lemma TEXT PRIMARY KEY, "
        "occurrence_count INTEGER, "
        "embedding VECTOR(300) not null"
        ");"
    ).format(table_name=Identifier(table_name))
    cursor.execute(query)


if TEST:
    create_embeddings_table("test")

### create_diff_table

In [None]:
def create_diff_table(table_name):
    """Create the diff table *table_name* if it does not exist.

    Columns: ``lemma`` (text primary key), ``occurrence_count`` (integer) and
    ``diff`` (real).
    """
    print("create_diff_table: table_name:", table_name)
    # Quote the table name via Identifier, consistent with insert_to_db and query_generic.
    query = SQL(
        "CREATE TABLE IF NOT EXISTS {table_name} ("
        "lemma TEXT PRIMARY KEY, "
        "occurrence_count INTEGER, "
        "diff REAL"
        ");"
    ).format(table_name=Identifier(table_name))
    cursor.execute(query)


if TEST:
    create_diff_table("test_diff")

### create_merged_diff_table

In [None]:
def create_merged_diff_table(table_name):
    """Create the merged diff table *table_name* if it does not exist.

    Columns: ``lemma`` (text primary key) and the reals ``diff_a``,
    ``diff_b`` and ``diff_diff``.
    """
    # Log under this function's own name (the prefix used to say "create_diff_table").
    print("create_merged_diff_table: table_name:", table_name)
    query = SQL(
        "CREATE TABLE IF NOT EXISTS {table_name} (lemma TEXT PRIMARY KEY, diff_a REAL, diff_b REAL, diff_diff REAL);"
    ).format(table_name=Identifier(table_name))
    cursor.execute(query)


if TEST:
    # Exercise this function; the smoke check used to call create_diff_table by mistake.
    create_merged_diff_table("test_merged_diff")

### create_avg_table

In [None]:
def create_avg_table(table_name, avg_is_int=True):
    """Create the average table *table_name* if it does not exist.

    Columns: ``lemma`` (text primary key) and ``avg``, which is INTEGER when
    *avg_is_int* is true and REAL otherwise.
    """
    # Log under this function's own name (the prefix used to say "create_diff_table").
    print("create_avg_table: table_name:", table_name)
    avg_type = SQL("INTEGER") if avg_is_int else SQL("REAL")
    query = SQL("CREATE TABLE IF NOT EXISTS {table_name} (lemma TEXT PRIMARY KEY, avg {avg_type});").format(
        table_name=Identifier(table_name),
        avg_type=avg_type,
    )
    cursor.execute(query)


if TEST:
    create_avg_table("test_avg")

### insert_to_db

In [None]:
def insert_to_db(table_name, table_data, cols=COLS_EMBEDDING):
    """Bulk-insert *table_data* rows into *table_name*.

    Each row supplies one value per entry in *cols*. Rows whose ``lemma``
    already exists in the table are skipped (``ON CONFLICT ... DO NOTHING``).
    The composed statement is printed before it is executed.
    """
    print("insert_to_db: table_name:", table_name, "len(table_data):", len(table_data))
    column_sql = SQL(", ").join(Identifier(col) for col in cols)
    placeholder_sql = SQL(", ").join(Placeholder() for _ in cols)
    statement = SQL("INSERT INTO {table_name} ({cols}) VALUES ({values}) ON CONFLICT(lemma) DO NOTHING").format(
        table_name=Identifier(table_name),
        cols=column_sql,
        values=placeholder_sql,
    )
    print(statement.as_string())
    cursor.executemany(statement, table_data)


if TEST:
    # Two rows with 300-dimensional dummy embeddings for the "test" table.
    table_data = [
        ["gehen", 25, [4.3, 1.2, 0.3] * 100],
        ["laufen", 17, [3.2, 1.7, 2.5] * 100],
    ]
    insert_to_db("test", table_data)

### load_word2vec_to_db

In [None]:
def load_word2vec_to_db(decade, lemma_occurrence_count_dict):
    """Load the decade's word2vec model and insert its accepted lemmas into ``word2vec__<decade>``.

    Only vocabulary entries accepted by ``is_word`` are inserted. Each
    embedding is divided by its L2 norm, and the occurrence count is looked
    up in *lemma_occurrence_count_dict*.
    """
    print("load_word2vec_to_db: start: decade:", decade)
    keyed_vectors = Word2Vec.load(f"{MODELS_WORD2VEC_FOLDER}{decade}.bin").wv
    rows = []
    for lemma in keyed_vectors.index_to_key:
        if not is_word(lemma):
            continue
        vector = keyed_vectors[lemma]
        rows.append((lemma, lemma_occurrence_count_dict[lemma], vector / np.linalg.norm(vector)))
    insert_to_db(f"word2vec__{decade}", rows)


if TEST and RESET_DB:
    for decade, lemma_occurrence_count_dict in decade_lemma_occurrence_count_dict.items():
        load_word2vec_to_db(decade, lemma_occurrence_count_dict)

### load_glove_to_db

In [None]:
def load_glove_to_db(decade, lemma_occurrence_count_dict):
    """Parse the decade's GloVe text file into a ``{token: vector}`` dict and return it.

    Each line is split on single spaces: the first field is the token, the
    remaining fields are parsed as floats.

    NOTE: despite its name, this function writes nothing to the database
    and does not use *lemma_occurrence_count_dict*.
    """
    vectors = {}
    with open(f"{MODELS_GLOVE_FOLDER}{decade}_vector.txt", "r") as f:
        for line in f:
            token, *components = line.rstrip().split(" ")
            vectors[token] = np.array([float(component) for component in components])
    return vectors


if TEST and RESET_DB:
    load_glove_to_db(185, decade_lemma_occurrence_count_dict[185])

### query_generic

In [None]:
def query_generic(table_name, lemma_list=None, select_cols=COLS_EMBEDDING, order_by=None, order_desc=False, print_query=True):
    """Select *select_cols* from *table_name*, optionally filtered and ordered.

    Args:
        table_name: table to read.
        lemma_list: when non-empty, only rows whose lemma is in the list are
            returned; None or an empty list means no filter.
        select_cols: columns to select.
        order_by: optional column to order by.
        order_desc: order descending when *order_by* is set.
        print_query: print the composed SQL before running it.

    Returns:
        A list of row tuples, or a flat list of values when exactly one
        column is selected.
    """
    columns_sql = SQL(", ").join(Identifier(col) for col in select_cols)
    query = SQL("SELECT {select_cols} FROM {table_name}").format(select_cols=columns_sql, table_name=Identifier(table_name))
    params = {}
    if lemma_list:
        where_sql = SQL("WHERE lemma = ANY({lemma_list})").format(lemma_list=Placeholder("lemma_list"))
        query = query + SQL(" ") + where_sql
        params["lemma_list"] = lemma_list
    if order_by:
        direction_sql = SQL(" DESC") if order_desc else SQL("")
        query = query + SQL(" ORDER BY {order_col}").format(order_col=Identifier(order_by)) + direction_sql
    if print_query:
        print(query.as_string())
    cursor.execute(query=query, params=params)
    rows = cursor.fetchall()
    if len(select_cols) != 1:
        return rows
    # single column: unwrap the 1-tuples
    return [row[0] for row in rows]


if TEST:
    print(len(query_generic("word2vec__185")))
    print(len(query_generic("word2vec__185", ["gehen"])))
    print(len(query_generic("word2vec__185", ["gehen", "laufen"])))
    print(len(query_generic("word2vec__185", ["gehen"], ["lemma"], "lemma", order_desc=True)))

### query_related

In [None]:
def query_related(table_name, lemma, n=10, select_cols=["lemma", "cos_sim"]):
    """Return the *n* rows of *table_name* most cosine-similar to *lemma*.

    The similarity is computed in SQL as ``1 - (a.embedding <=> b.embedding)``
    against the embedding stored for *lemma*; the row of *lemma* itself is
    excluded.

    Args:
        table_name: embeddings table to search.
        lemma: reference lemma.
        n: maximum number of rows.
        select_cols: columns to return, chosen from ``lemma``, ``embedding``
            and ``cos_sim``.

    Returns:
        A list of row tuples ordered by descending ``cos_sim``.
    """
    # The lemma and limit are bound as parameters and the identifiers are quoted,
    # so lemmas containing quote characters no longer break (or alter) the statement.
    query = SQL(
        """
        WITH similarities AS (
            SELECT a.lemma, a.embedding, 1 - (a.embedding <=> b.embedding) AS cos_sim
            FROM {table_name} a
            CROSS JOIN (
                SELECT embedding FROM {table_name}
                WHERE lemma = {lemma}
            ) AS b
            WHERE lemma != {lemma}
        )
        SELECT {select_cols}
        FROM similarities
        ORDER BY cos_sim DESC
        LIMIT {n};
    """
    ).format(
        table_name=Identifier(table_name),
        select_cols=SQL(", ").join(Identifier(col) for col in select_cols),
        lemma=Placeholder("lemma"),
        n=Placeholder("n"),
    )
    cursor.execute(query, {"lemma": lemma, "n": n})
    return cursor.fetchall()


if TEST:
    print(query_related("word2vec__185", "gehen"))
    print(query_related("word2vec__185", "Frau", n=5))
    print(query_related("word2vec__185", "Frau", n=1, select_cols=["embedding"])[0][0].shape)

### query_mutual_lemmas

In [None]:
def query_mutual_lemmas(table_list, include_count=True, print_query=True):
    """Return the lemmas present in every table of *table_list*.

    The tables are inner-joined on ``lemma``.

    Args:
        table_list: table names; the first one starts the join chain.
        include_count: also select the sum of the tables' ``occurrence_count`` columns.
        print_query: print the composed SQL before running it.

    Returns:
        A list of ``(lemma, summed_occurrence_count)`` tuples when
        *include_count* is true, otherwise a flat list of lemmas.
    """
    # Compose the statement from quoted identifiers (as insert_to_db / query_generic do)
    # instead of concatenating raw table names into the SQL text.
    select_part = SQL("lemma")
    if include_count:
        count_sum = SQL(" + ").join(Identifier(table, "occurrence_count") for table in table_list)
        select_part = select_part + SQL(", ") + count_sum
    join_part = SQL("{first_table}").format(first_table=Identifier(table_list[0]))
    for table in table_list[1:]:
        join_part = join_part + SQL(" INNER JOIN {table} USING (lemma)").format(table=Identifier(table))
    query = SQL("SELECT {select_part} FROM {join_part};").format(select_part=select_part, join_part=join_part)
    if print_query:
        print("query_mutual_lemmas: query:", query.as_string())
    cursor.execute(query)
    result = cursor.fetchall()
    if not include_count:
        result = [r[0] for r in result]
    return result


if TEST:
    # (tables to join, include_count) combinations
    table_list_list = [
        [["word2vec__183"], True],
        [["word2vec__183", "word2vec__184"], True],
        [["word2vec__183", "word2vec__184", "word2vec__185"], False],
    ]
    for table_list, include_count in table_list_list:
        r = query_mutual_lemmas(table_list, include_count)
        print(len(r))
        print(r[0:50])

### query_average_create_table

In [None]:
def query_average_create_table(table_prefix, table_kind, decade_list, select_cols, table_name_avg=None, round_value=False):
    """Average a per-lemma value over a sequence of tables and optionally store the result.

    A window of three consecutive decades ``(a, b, c)`` slides over
    *decade_list*, which needs at least three entries. The table read at
    each step depends on *table_kind*:

    - ``"embedding"``: ``<prefix>__<a>``
    - ``"diff__cos_sim"`` / ``"diff__relative"``: ``<prefix>__<kind>__<a>_<b>``
    - ``"diff__trajectory"``: ``<prefix>__<kind>__<a>_<c>``

    NOTE(review): because the window always spans three decades, ``a`` never
    reaches the last two entries of *decade_list* — confirm this is intended
    for the "embedding", "diff__cos_sim" and "diff__relative" kinds.

    Args:
        table_prefix: prefix of the table names.
        table_kind: one of the kinds listed above; anything else raises.
        decade_list: list of decades.
        select_cols: exactly two columns, read as ``(lemma, value)``.
        table_name_avg: when truthy, the averages are written to this table,
            which is dropped and recreated first.
        round_value: truncate each average with ``int()``.

    Returns:
        A dict ``{lemma: average}`` ordered by ``sort_dict_by_value``.
    """
    values_per_lemma = {}
    decade_a, decade_b, decade_c = decade_list[0], decade_list[1], decade_list[2]
    # The trailing None only pads the iteration so the final window is processed too.
    for upcoming_decade in decade_list[3:] + [None]:
        if table_kind == "embedding":
            table_name = f"{table_prefix}__{decade_a}"
        elif table_kind in ("diff__cos_sim", "diff__relative"):
            table_name = f"{table_prefix}__{table_kind}__{decade_a}_{decade_b}"
        elif table_kind == "diff__trajectory":
            table_name = f"{table_prefix}__{table_kind}__{decade_a}_{decade_c}"
        else:
            raise Exception("no table_kind specified")
        for lemma, value in query_generic(table_name, select_cols=select_cols):
            values_per_lemma.setdefault(lemma, []).append(value)
        # slide the window one decade forward
        decade_a, decade_b, decade_c = decade_b, decade_c, upcoming_decade
    averages = {}
    for lemma, values in values_per_lemma.items():
        mean = sum(values) / len(values)
        averages[lemma] = int(mean) if round_value else mean
    averages = sort_dict_by_value(averages)
    if table_name_avg:
        drop_table(table_name_avg)
        create_avg_table(table_name_avg, avg_is_int=round_value)
        insert_to_db(table_name_avg, list(averages.items()), cols=["lemma", "avg"])
    return averages


if TEST:
    lemma_count_average_dict = query_average_create_table(
        "word2vec",
        table_kind="embedding",
        decade_list=decade_list_test,
        select_cols=["lemma", "occurrence_count"],
        table_name_avg="word2vec__occurrence_count__avg",
        round_value=True,
    )
    print(list(lemma_count_average_dict.items())[:10])

### query_over_mutual

In [None]:
def query_over_mutual(table_name_list, select_cols=COLS_EMBEDDING):
    """Iterate in lockstep over the rows of the lemmas shared by all given tables.

    Every table is queried for the mutual lemmas, ordered by ``lemma``, and
    the per-table results are zipped.

    Returns:
        An iterator of tuples holding one row per table, in the order of
        *table_name_list*. It is empty when the tables share no lemma.
    """
    print("query_over_mutual: start: table_name_list:", table_name_list)
    common_lemma = query_mutual_lemmas(table_name_list, include_count=False)
    if not common_lemma:
        # query_generic treats an empty lemma list as "no filter"; querying anyway
        # would return every row of every table and zip unrelated lemmas together.
        return zip()
    embeddings_table_list = [
        query_generic(table_name, common_lemma, select_cols=select_cols, order_by="lemma")
        for table_name in table_name_list
    ]
    return zip(*embeddings_table_list)


if TEST:
    for count, x in enumerate(query_over_mutual(["word2vec__183", "word2vec__184"])):
        a = x[0]
        b = x[1]
        print(a[0:2], a[2].shape, b[0:2], b[2].shape)
        if count == 10:
            break

### query_lemma_over_decades

In [None]:
def query_lemma_over_decades(decade_list, lemma_list, print_query=True):
    """Collect ``(decade, lemma, embedding)`` for *lemma_list* from each ``word2vec__<decade>`` table."""
    return [
        (decade, lemma, embedding)
        for decade in decade_list
        for lemma, embedding in query_generic(f"word2vec__{decade}", lemma_list, select_cols=["lemma", "embedding"], print_query=print_query)
    ]


if TEST:
    print(len(query_lemma_over_decades([184,185,186], ["gehen", "essen"])))

## difference analysis functions

### calculate_cos_sim

In [None]:
def calculate_cos_sim(embedding_a: np.ndarray, embedding_b: np.ndarray) -> float:
    """Return the cosine similarity: the dot product divided by the product of the two L2 norms."""
    norm_product = np.linalg.norm(embedding_a) * np.linalg.norm(embedding_b)
    return np.dot(embedding_a, embedding_b) / norm_product

### calculate_cos_distance

In [None]:
def calculate_cos_distance(embedding_a: np.ndarray, embedding_b: np.ndarray) -> float:
    """Return the cosine distance, i.e. one minus ``calculate_cos_sim`` of the two embeddings."""
    similarity = calculate_cos_sim(embedding_a, embedding_b)
    return 1 - similarity

### calculate_tsne

In [None]:
def calculate_tsne(embeddings, perplexity=None):
    """Reduce *embeddings* to two dimensions with t-SNE (fixed ``random_state=42``).

    Args:
        embeddings: sequence of equally sized vectors.
        perplexity: t-SNE perplexity. When None, 40 is used, capped at
            ``len(embeddings) - 1``.

    Returns:
        The array returned by ``TSNE.fit_transform`` for two components.
    """
    if perplexity is None:
        # scikit-learn requires perplexity < n_samples, so the default of 40 has to be
        # capped for every input with 40 or fewer samples, not only for fewer than 6.
        perplexity = min(40, len(embeddings) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    embeddings_reduced = tsne.fit_transform(np.array(embeddings))
    return embeddings_reduced


if TEST:
    embeddings_reduced = calculate_tsne(query_generic("word2vec__185", TEST_LEMMA_CLOSE_LIST, select_cols=["embedding"]))
    print(embeddings_reduced)

### calculate_trajectory_between_vectors

In [None]:
def calculate_trajectory_between_vectors(vector_a, vector_b, vector_c):
    """Return the dot product of the two steps ``a - b`` and ``b - c``."""
    return np.dot(vector_a - vector_b, vector_b - vector_c)


if TEST:
    # Three (a, b, c) triples sharing the same a -> b step; only c varies.
    # The expected dot products are 2, 0 and -2 respectively.
    vector_list_list = [
        (np.array([1, 2]), np.array([2, 3]), np.array([2, 5])),
        (np.array([1, 2]), np.array([2, 3]), np.array([2, 3])),
        (np.array([1, 2]), np.array([2, 3]), np.array([1, 2])),
    ]
    for vector_list in vector_list_list:
        print(calculate_trajectory_between_vectors(*vector_list))

### calculate_procrustes_alignment

In [None]:
def calculate_procrustes_alignment(table_name_a, table_name_b, table_aligned_name_b):
    """Align the embeddings of table B to table A and store them in a new table.

    An orthogonal Procrustes rotation is fitted on the lemmas both tables
    share. Each shared embedding is weighted by ``log1p`` of the mean
    occurrence count of the two tables. The rotation is then applied to
    every embedding of table B, each result is rescaled to unit L2 norm,
    and the rows are written to *table_aligned_name_b* (dropped and
    recreated first).
    """
    print("calculate_procrustes_alignment: start: (table_name_a, table_name_b, table_aligned_name_b):", (table_name_a, table_name_b, table_aligned_name_b))
    weighted_rows_a = []
    weighted_rows_b = []
    for row_a, row_b in query_over_mutual([table_name_a, table_name_b]):
        # rows are (lemma, occurrence_count, embedding)
        weight = np.log1p((row_a[1] + row_b[1]) / 2)
        weighted_rows_a.append(row_a[2] * weight)
        weighted_rows_b.append(row_b[2] * weight)
    overlap_a = np.stack(weighted_rows_a)
    overlap_b = np.stack(weighted_rows_b)

    # fit the rotation that maps B's shared rows onto A's, then rotate all of B
    rotation, _ = orthogonal_procrustes(overlap_b, overlap_a)
    rows_b = query_generic(table_name_b, select_cols=COLS_EMBEDDING, order_by="lemma")
    aligned = np.stack([row[2] for row in rows_b]) @ rotation
    aligned_unit = aligned / np.linalg.norm(aligned, axis=1, keepdims=True)
    print("calculate_procrustes_alignment: matrix_b_aligned.shape:", aligned.shape)
    db_insertion_data = [[row[0], row[1], vector] for row, vector in zip(rows_b, aligned_unit)]
    drop_table(table_aligned_name_b)
    create_embeddings_table(table_aligned_name_b)
    insert_to_db(table_aligned_name_b, db_insertion_data)


if TEST:
    # Smoke check: compare cosine similarities of a few lemmas between neighbouring
    # decades before and after alignment.
    if RESET_DB:
        # Chain the alignment: 185 is aligned to 184, then 186 to the aligned 185.
        calculate_procrustes_alignment("word2vec__184", "word2vec__185", "word2vec__185__aligned")
        calculate_procrustes_alignment("word2vec__185__aligned", "word2vec__186", "word2vec__186__aligned")
    embeddings_table_184 = query_generic(
        "word2vec__184",
        TEST_LEMMA_RANGED_LIST,
        select_cols=["lemma", "embedding"],
        order_by="lemma",
    )
    embeddings_table_185 = query_generic(
        "word2vec__185",
        TEST_LEMMA_RANGED_LIST,
        select_cols=["lemma", "embedding"],
        order_by="lemma",
    )
    embeddings_table_186 = query_generic(
        "word2vec__186",
        TEST_LEMMA_RANGED_LIST,
        select_cols=["lemma", "embedding"],
        order_by="lemma",
    )
    embeddings_table_185_aligned = query_generic(
        "word2vec__185__aligned",
        TEST_LEMMA_RANGED_LIST,
        select_cols=["lemma", "embedding"],
        order_by="lemma",
    )
    embeddings_table_186_aligned = query_generic(
        "word2vec__186__aligned",
        TEST_LEMMA_RANGED_LIST,
        select_cols=["lemma", "embedding"],
        order_by="lemma",
    )
    # NOTE(review): the zip pairs rows purely by position. This assumes every test lemma
    # exists in all five tables and that the database's ORDER BY lemma matches the
    # case-insensitive Python sort of TEST_LEMMA_RANGED_LIST — confirm, otherwise the
    # printed lemma and the compared embeddings can be misaligned.
    for lemma, e_184, e_185, e_186, e_185_aligned, e_186_aligned in zip(
        TEST_LEMMA_RANGED_LIST,
        embeddings_table_184,
        embeddings_table_185,
        embeddings_table_186,
        embeddings_table_185_aligned,
        embeddings_table_186_aligned,
    ):
        print(lemma)
        print("184-185:", calculate_cos_sim(e_184[1], e_185[1]))
        print("185-186:", calculate_cos_sim(e_185[1], e_186[1]))
        print("184-185_aligned:", calculate_cos_sim(e_184[1], e_185_aligned[1]))
        print("185_aligned-186_aligned:", calculate_cos_sim(e_185_aligned[1], e_186_aligned[1]))

### calculate_cos_sim_between_tables

In [None]:
def calculate_cos_sim_between_tables(table_name_a, table_name_b, table_name_diff):
    """Store the cosine similarity between two tables for every shared lemma.

    For each lemma present in both tables, one row is written to
    *table_name_diff* (dropped and recreated first): the lemma, the
    truncated mean of both occurrence counts, and the cosine similarity of
    the two embeddings as ``diff``.
    """
    print(
        "calculate_cos_sim_between_tables: start: (table_name_a, table_name_b, table_name_diff)",
        (table_name_a, table_name_b, table_name_diff),
    )
    # rows are (lemma, occurrence_count, embedding)
    db_insertion_data = [
        (
            row_a[0],
            int((row_a[1] + row_b[1]) / 2),
            calculate_cos_sim(row_a[2], row_b[2]),
        )
        for row_a, row_b in query_over_mutual([table_name_a, table_name_b])
    ]
    print("calculate_cos_sim_between_tables: len(db_insertion_data):", len(db_insertion_data))
    drop_table(table_name_diff)
    create_diff_table(table_name_diff)
    insert_to_db(table_name_diff, db_insertion_data, cols=["lemma", "occurrence_count", "diff"])


if TEST:
    # NOTE(review): the table is named "...184_185" but, when RESET_DB is set, it is
    # filled from the 185/186 aligned tables — confirm which pair is intended.
    diff_table_name = "word2vec__diff__cos_sim__184_185"
    if RESET_DB:
        calculate_cos_sim_between_tables("word2vec__185__aligned", "word2vec__186__aligned", diff_table_name)
    diff_table = query_generic(
        diff_table_name,
        select_cols=["lemma", "diff"],
        order_by="diff",
        order_desc=True,
    )
    plot_2d_scatter(diff_table, diff_table_name)

### calculate_trajectory_between_tables

In [None]:
def calculate_trajectory_between_tables(table_name_a, table_name_b, table_name_c, table_name_diff):
    """Store the trajectory value across three tables for every shared lemma.

    For each lemma present in all three tables, one row is written to
    *table_name_diff* (dropped and recreated first): the lemma, the
    truncated mean of the three occurrence counts, and
    ``calculate_trajectory_between_vectors`` of the three embeddings as ``diff``.
    """
    print(
        "calculate_trajectory_between_tables: start: (table_name_a, table_name_b, table_name_c, table_name_diff):",
        (table_name_a, table_name_b, table_name_c, table_name_diff),
    )
    db_insertion_data = []
    # rows are (lemma, occurrence_count, embedding)
    for lemma_embedding_a, lemma_embedding_b, lemma_embedding_c in query_over_mutual([table_name_a, table_name_b, table_name_c]):
        db_insertion_data.append(
            (
                lemma_embedding_a[0],
                int((lemma_embedding_a[1] + lemma_embedding_b[1] + lemma_embedding_c[1]) / 3),
                calculate_trajectory_between_vectors(lemma_embedding_a[2], lemma_embedding_b[2], lemma_embedding_c[2]),
            )
        )
    # Log under this function's own name (the prefix used to say "calculate_cos_sim_between_tables").
    print("calculate_trajectory_between_tables: len(db_insertion_data):", len(db_insertion_data))
    drop_table(table_name_diff)
    create_diff_table(table_name_diff)
    insert_to_db(table_name_diff, db_insertion_data, cols=["lemma", "occurrence_count", "diff"])


if TEST:
    diff_table_name = "word2vec__diff__trajectory__184_186"
    if RESET_DB:
        calculate_trajectory_between_tables("word2vec__184", "word2vec__185__aligned", "word2vec__186__aligned", diff_table_name)
    diff_table = query_generic(
        diff_table_name,
        select_cols=["lemma", "diff"],
        order_by="diff",
        order_desc=True,
    )
    plot_2d_scatter(diff_table, diff_table_name)

### calculate_relative_diff_between_tables

In [None]:
def calculate_relative_diff_between_tables(table_name_a, table_name_b, table_name_diff):
    """Store, per shared lemma, how much its neighbourhood similarities differ between two tables.

    For every lemma present in both tables, the 100 most similar lemmas are
    fetched from each table via ``query_related``. ``diff`` is the sum of
    absolute cosine-similarity differences over the neighbours that appear
    in both result sets. The rows are written to *table_name_diff* (dropped
    and recreated first), with half of the summed occurrence count
    (truncated) as ``occurrence_count``.
    """
    print("calculate_relative_diff_between_tables: start: (table_name_a, table_name_b):", (table_name_a, table_name_b))
    db_insertion_data = []
    for lemma, occurrence_count_sum in query_mutual_lemmas([table_name_a, table_name_b], include_count=True):
        # neighbour lemma -> cosine similarity, per table
        similarity_dict_a = dict(query_related(table_name_a, lemma, n=100))
        similarity_dict_b = dict(query_related(table_name_b, lemma, n=100))
        common_related_lemma = set(similarity_dict_a.keys()) & set(similarity_dict_b.keys())
        diff = 0
        for lemma_related in common_related_lemma:
            diff += abs(similarity_dict_a[lemma_related] - similarity_dict_b[lemma_related])
        db_insertion_data.append((lemma, int(occurrence_count_sum / 2), diff))
    # Log under this function's own name (the prefix used to say "calculate_cos_sim_between_tables").
    print("calculate_relative_diff_between_tables: len(db_insertion_data):", len(db_insertion_data))
    drop_table(table_name_diff)
    create_diff_table(table_name_diff)
    insert_to_db(table_name_diff, db_insertion_data, cols=["lemma", "occurrence_count", "diff"])


if TEST:
    diff_table_name = "word2vec__diff__relative__184_185"
    if RESET_DB:
        calculate_relative_diff_between_tables("word2vec__184", "word2vec__185", diff_table_name)
    diff_table = query_generic(
        diff_table_name,
        select_cols=["lemma", "diff"],
        order_by="diff",
        order_desc=False,
    )
    plot_2d_scatter(diff_table, diff_table_name)

## compare diff methods

### create_random_diff_table

In [None]:
def create_random_diff_table(table_name_from, table_name_random_a, table_name_random_b):
    """Create two diff tables holding independent random diffs for the entries of an existing table.

    Every entry returned by ``query_generic(table_name_from, select_cols=["lemma"])``
    gets one random diff per target table, drawn with ``random.randint`` from
    -1000..1000 and divided by 1000 (so values lie in [-1, 1] in steps of 0.001).
    The occurrence count is always stored as 0. Both target tables are dropped and
    recreated before the rows are inserted.

    NOTE(review): each entry yielded by ``query_generic`` is stored unchanged in the
    lemma position; if it yields row tuples rather than plain strings, that is a
    1-tuple -- confirm against ``query_generic`` / ``insert_to_db``.

    Args:
        table_name_from: table whose ``lemma`` column supplies the entries.
        table_name_random_a: name of the first random diff table.
        table_name_random_b: name of the second random diff table.
    """
    lower_bound, upper_bound = -1000, 1000

    def draw_diff():
        return random.randint(lower_bound, upper_bound) / 1000

    rows_a, rows_b = [], []
    for entry in query_generic(table_name_from, select_cols=["lemma"]):
        # Draw for table a first, then for table b, once per entry.
        rows_a.append((entry, 0, draw_diff()))
        rows_b.append((entry, 0, draw_diff()))

    targets = ((table_name_random_a, rows_a), (table_name_random_b, rows_b))
    for target_name, _ in targets:
        drop_table(target_name)
    for target_name, _ in targets:
        create_diff_table(target_name)
    for target_name, target_rows in targets:
        insert_to_db(target_name, target_rows, cols=COLS_DIFF)


if TEST:
    # Runs whenever TEST is set (it is not gated on RESET_DB), so both random
    # tables are dropped and refilled on every execution of this cell.
    table_name_a = "test__diff__random__a"
    table_name_b = "test__diff__random__b"
    create_random_diff_table("word2vec__diff__cos_sim__184_185", table_name_a, table_name_b)

### normalize_diff_table

In [None]:
def normalize_diff_table(diff_table_name, diff_table_name_new, invert=False):
    """Rescale the ``diff`` column of a diff table linearly to [-1, 1] and store the result as a new table.

    The smallest diff in the source table maps to -1 and the largest to 1; with
    ``invert=True`` the sign of every rescaled value is flipped. ``lemma`` and
    ``occurrence_count`` are copied unchanged. The target table is dropped and
    recreated.

    Edge cases:
        * Empty source table: the target table is recreated empty and nothing is inserted.
        * All diffs identical: there is no range to scale by, so every row gets 0.0
          regardless of ``invert``.

    Args:
        diff_table_name: name of the diff table to read.
        diff_table_name_new: name of the normalized table to write.
        invert: flip the sign of the normalized values.
    """
    print("normalize_diff_table: start: table_name_diff:", diff_table_name)
    # Materialize once: the rows are traversed twice (min/max, then rescaling).
    diff_rows = list(query_generic(diff_table_name, select_cols=COLS_DIFF, order_by="diff"))
    diff_table_new = []
    if diff_rows:
        diff_values = [diff_row[2] for diff_row in diff_rows]
        min_diff = min(diff_values)
        max_diff = max(diff_values)
        print("normalize_diff_table: (min_diff, max_diff):", (min_diff, max_diff))
        diff_range = max_diff - min_diff
        # A zero range would divide by zero; such tables are mapped to 0.0 below.
        scale = 2 / diff_range if diff_range else None
        for diff_row in diff_rows:
            if scale is None:
                value = 0.0
            else:
                value = ((diff_row[2] - min_diff) * scale) - 1
                if invert:
                    value *= -1
            diff_table_new.append((diff_row[0], diff_row[1], value))
    print("normalize_diff_table: len(diff_table_new):", len(diff_table_new))
    drop_table(diff_table_name_new)
    create_diff_table(diff_table_name_new)
    if diff_table_new:
        insert_to_db(diff_table_name_new, diff_table_new, cols=COLS_DIFF)


if TEST:
    # Each entry: (source diff table, normalized target table, invert flag).
    # Only the relative diff is inverted.
    table_list = [
        ("word2vec__diff__cos_sim__184_185", "word2vec__diff__cos_sim__normalized__184_185", False),
        ("word2vec__diff__trajectory__184_186", "word2vec__diff__trajectory__normalized__184_186", False),
        ("word2vec__diff__relative__184_185", "word2vec__diff__relative__normalized__184_185", True),
        ("test__diff__random__a", "test__diff__random__normalized__a", False),
        ("test__diff__random__b", "test__diff__random__normalized__b", False),
    ]
    # Normalize everything first, then plot, so the plots appear together.
    for table in table_list:
        normalize_diff_table(*table)
    for table in table_list:
        # Pass the table name as second argument, as the other plot_2d_scatter
        # call sites in this notebook do; without it the five plots cannot be
        # told apart.
        plot_2d_scatter(
            query_generic(table[1], select_cols=["lemma", "diff"], order_by="diff", order_desc=True),
            table[1],
        )

### merge_normalized_diff_tables

In [None]:
def merge_normalized_diff_tables(table_name_a, table_name_b, diff_compared_table_name):
    """Join two diff tables on their mutual rows and store both diffs plus their absolute difference.

    ``query_over_mutual`` yields one pair of ``(lemma, diff)`` rows per mutual entry.
    For each pair a row ``(lemma, diff_a, diff_b, |diff_a - diff_b|)`` is written to
    ``diff_compared_table_name``, which is dropped and recreated first. The lemma is
    taken from the row of the first table.

    Args:
        table_name_a: name of the first diff table.
        table_name_b: name of the second diff table.
        diff_compared_table_name: name of the merged table to write.
    """
    table_names = (table_name_a, table_name_b, diff_compared_table_name)
    print("merge_normalized_diff_tables: start: (table_name_a, table_name_b, diff_compared_table_name)", table_names)

    mutual_row_pairs = query_over_mutual([table_name_a, table_name_b], select_cols=["lemma", "diff"])
    merged_rows = [
        (row_a[0], row_a[1], row_b[1], abs(row_a[1] - row_b[1]))
        for row_a, row_b in mutual_row_pairs
    ]
    print("merge_normalized_diff_tables: len(db_insertion_data):", len(merged_rows))

    drop_table(diff_compared_table_name)
    create_merged_diff_table(diff_compared_table_name)
    merged_cols = ["lemma", "diff_a", "diff_b", "diff_diff"]
    insert_to_db(diff_compared_table_name, merged_rows, cols=merged_cols)


if TEST:
    # Each entry: (normalized diff table a, normalized diff table b, merged target table).
    # The three word2vec entries compare the diff methods pairwise; the last entry
    # merges the two random tables.
    # NOTE(review): presumably the random pair serves as a baseline for what
    # unrelated diffs look like -- confirm.
    table_list = [
        (
            "word2vec__diff__cos_sim__normalized__184_185",
            "word2vec__diff__trajectory__normalized__184_186",
            "word2vec__merged__cos_sim_trajectory__184_186",
        ),
        (
            # NOTE(review): both inputs cover 184_185 while the target name says
            # 184_186 -- confirm the intended name.
            "word2vec__diff__cos_sim__normalized__184_185",
            "word2vec__diff__relative__normalized__184_185",
            "word2vec__merged__cos_relative__184_186",
        ),
        (
            "word2vec__diff__trajectory__normalized__184_186",
            "word2vec__diff__relative__normalized__184_185",
            "word2vec__merged__trajectory_relative__184_186",
        ),
        (
            "test__diff__random__normalized__a",
            "test__diff__random__normalized__b",
            "test__merged__random__a_b",
        ),
    ]
    # Runs whenever TEST is set (not gated on RESET_DB): every merged table is
    # dropped and rebuilt before plotting.
    for table in table_list:
        merge_normalized_diff_tables(*table)
    for table in table_list:
        # table[2] is the merged table; rows are ordered by ascending diff_diff.
        plot_merged_diff_table(
            query_generic(
                table[2],
                select_cols=["lemma", "diff_a", "diff_b", "diff_diff"],
                order_by="diff_diff",
                order_desc=False,
            )
        )

## aggregate functions

### load_decade

In [None]:
def load_decade(decade):
    """Load one decade into the DB.

    Passes the result of ``create_occurrence_dict(decade)`` on to
    ``load_word2vec_to_db`` for the same decade.

    Args:
        decade: decade identifier as used in the table names (e.g. 185).
    """
    load_word2vec_to_db(decade, create_occurrence_dict(decade))


# Smoke test; gated on RESET_DB as well because load_decade writes to the DB.
if TEST and RESET_DB:
    load_decade(185)

### load_decade_pair

In [None]:
def load_decade_pair(decade_a, decade_b, a_is_aligned=True):
    """Load two decades, align the second onto the first and compute their pairwise diffs.

    Steps, in order:
        1. ``load_decade`` for ``decade_a`` and then ``decade_b``.
        2. ``calculate_procrustes_alignment`` with the reference table of ``decade_a``
           and the raw table of ``decade_b``, writing ``word2vec__{decade_b}__aligned``.
        3. ``calculate_cos_sim_between_tables`` on the reference table of ``decade_a``
           and the aligned table of ``decade_b``.
        4. ``calculate_relative_diff_between_tables`` on the two raw tables.

    Args:
        decade_a: first decade of the pair.
        decade_b: second decade of the pair.
        a_is_aligned: if true, ``word2vec__{decade_a}__aligned`` is used as the
            reference table for ``decade_a``; otherwise the raw table
            ``word2vec__{decade_a}`` is used.
    """
    for decade in (decade_a, decade_b):
        load_decade(decade)

    raw_a, raw_b = f"word2vec__{decade_a}", f"word2vec__{decade_b}"
    reference_a = f"{raw_a}__aligned" if a_is_aligned else raw_a
    aligned_b = f"{raw_b}__aligned"
    pair_suffix = f"{decade_a}_{decade_b}"

    calculate_procrustes_alignment(reference_a, raw_b, aligned_b)
    calculate_cos_sim_between_tables(reference_a, aligned_b, f"word2vec__diff__cos_sim__{pair_suffix}")
    calculate_relative_diff_between_tables(raw_a, raw_b, f"word2vec__diff__relative__{pair_suffix}")


# Smoke test; a_is_aligned=False makes the raw table of 185 the alignment
# reference instead of word2vec__185__aligned.
if TEST and RESET_DB:
    load_decade_pair(185, 186, a_is_aligned=False)

### create_average_diff_tables

In [None]:
def create_average_diff_tables(decade_list):
    """Create one average table per diff kind via ``query_average_create_table``.

    For each of the kinds ``diff__cos_sim``, ``diff__trajectory`` and
    ``diff__relative`` the target table is named
    ``word2vec__{kind}__avg__{first}_{last}``, where ``first`` and ``last`` are the
    first and last element of ``decade_list``.

    Args:
        decade_list: non-empty sequence of decades; passed through unchanged to
            ``query_average_create_table``.
    """
    table_prefix = "word2vec"
    decade_span = f"{decade_list[0]}_{decade_list[-1]}"
    for table_kind in ("diff__cos_sim", "diff__trajectory", "diff__relative"):
        query_average_create_table(
            table_prefix=table_prefix,
            table_kind=table_kind,
            decade_list=decade_list,
            select_cols=["lemma", "diff"],
            table_name_avg=f"{table_prefix}__{table_kind}__avg__{decade_span}",
            round_value=False,
        )


# Smoke test; writes the word2vec__diff__*__avg__184_186 tables, hence gated on RESET_DB.
if TEST and RESET_DB:
    create_average_diff_tables(create_decades_list(184, 186))

### load_decade_from_list

In [None]:
def load_decade_from_list(decade_list):
    """Load a chain of decades, align them consecutively and compute all diff tables.

    The first pair is loaded with the raw table of the first decade as alignment
    reference. Every later decade is then aligned onto its already aligned
    predecessor by ``load_decade_pair``. For each window of three consecutive
    decades ``(a, b, c)`` a trajectory table ``word2vec__diff__trajectory__{a}_{c}``
    is computed, and finally the average tables are created.

    The trajectory is computed on the aligned tables, matching the TEST cell of
    ``calculate_trajectory_between_tables``. The first decade of the list is the
    alignment reference and has no ``__aligned`` table, so its raw table is used.
    NOTE(review): this assumes the raw per-decade tables are not directly
    comparable across decades -- confirm.

    Args:
        decade_list: sequence of at least three decades, in chronological order.

    Raises:
        ValueError: if fewer than three decades are given.
    """
    if len(decade_list) < 3:
        raise ValueError(f"load_decade_from_list: need at least 3 decades, got {len(decade_list)}")

    first_decade = decade_list[0]
    load_decade_pair(first_decade, decade_list[1], a_is_aligned=False)

    for decade_a, decade_b, decade_c in zip(decade_list, decade_list[1:], decade_list[2:]):
        # Produces word2vec__{decade_c}__aligned (and the pairwise diffs for b/c).
        load_decade_pair(decade_b, decade_c)
        if decade_a == first_decade:
            table_name_a = f"word2vec__{decade_a}"
        else:
            table_name_a = f"word2vec__{decade_a}__aligned"
        calculate_trajectory_between_tables(
            table_name_a,
            f"word2vec__{decade_b}__aligned",
            f"word2vec__{decade_c}__aligned",
            f"word2vec__diff__trajectory__{decade_a}_{decade_c}",
        )
    create_average_diff_tables(decade_list)

# Smoke test on the shortest chain the function accepts (three decades).
if TEST and RESET_DB:
    load_decade_from_list(create_decades_list(184, 186))

# load_all (warning: may delete entire DB!)

In [None]:
# Full (re)load over every decade returned by create_decades_list().
decade_list = create_decades_list()
# Destructive: reset_db runs before the reload, so only enable RESET_DB when the
# existing DB content may be discarded.
if RESET_DB:
    reset_db(decade_list)
    load_decade_from_list(decade_list)

# Analysis

## global changes analysis

In [None]:
# Re-created here so the analysis cells below can run without the load_all cell.
decade_list = create_decades_list()

### average changes

In [None]:
# The table names hard-code the 156_191 span (cf. USABLE_DECADES); they have to
# match the `..__avg__{first}_{last}` names written by create_average_diff_tables
# for decade_list.
plot_average_diff(decade_list, "word2vec__diff__cos_sim__avg__156_191", lemma_min_occurrence=10)

In [None]:
plot_average_diff(decade_list, "word2vec__diff__trajectory__avg__156_191", lemma_min_occurrence=10)

In [None]:
# Ascending order for the relative diff, unlike the two plots above.
# NOTE(review): presumably because a low relative diff indicates more change -- confirm.
plot_average_diff(decade_list, "word2vec__diff__relative__avg__156_191", lemma_min_occurrence=10, order_desc=False)

### changes per lemma

In [None]:
# Per-lemma diff over the decades. The third argument is a table-name prefix
# without a decade suffix. The cos_sim plots fix the y axis to [0, 1]; the
# trajectory plots leave yaxis_range at its default.
plot_lemma_diff(decade_list, "Aktion", "word2vec__diff__cos_sim", yaxis_range=[0,1])

In [None]:
plot_lemma_diff(decade_list, "Aktion", "word2vec__diff__trajectory")

In [None]:
plot_lemma_diff(decade_list, "Zeitung", "word2vec__diff__cos_sim", yaxis_range=[0, 1])

In [None]:
plot_lemma_diff(decade_list, "d", "word2vec__diff__cos_sim", yaxis_range=[0,1])

In [None]:
plot_lemma_diff(decade_list, "Adam", "word2vec__diff__cos_sim", yaxis_range=[0,1])

In [None]:
plot_lemma_diff(decade_list, "d", "word2vec__diff__trajectory")

In [None]:
plot_lemma_diff(decade_list, "Adam", "word2vec__diff__trajectory")

### trajectories

In [None]:
# The second argument is a list of lemmas; here only "Zeitung" is plotted.
plot_lemma_and_decade(decade_list, ["Zeitung"])

# DB Close

In [None]:
# cursor.close()
# conn.close()