# setup

## imports

In [None]:
import os
from dataclasses import dataclass
from textwrap import dedent

import numpy as np
import psycopg
import spacy
import torch
import torch.nn.functional as F
from pgvector.psycopg import register_vector
from psycopg.sql import SQL, Identifier, Literal
from sklearn.manifold import TSNE
from spacy.tokens import Doc
from transformers import AutoModel, AutoTokenizer, XLMRobertaModel, XLMRobertaTokenizer

## global vars

In [None]:
# settings
# ENABLE_TEST switches on verbose SQL/token printing and caps the number of
# processed input lines (see the inference cell at the bottom).
ENABLE_TEST = True
# Table names and the input file name come from environment variables.
# os.getenv returns None when a variable is unset; for INPUT_FILE_PATH that
# makes the string concatenation below raise a TypeError at cell execution.
TABLE_NAME_SENTENCES = os.getenv("table_name_sentences")
TABLE_NAME_LEMMAS = os.getenv("table_name_lemmas")
TABLE_NAME_EMBEDDINGS = os.getenv("table_name_embeddings")
INPUT_FILE_PATH = "/veld/input/" + os.getenv("input_file")

# models
# Exactly one MODEL_NAME is active; the commented-out lines are alternative
# checkpoints that were tried.
# NOTE(review): set_up_db creates the embedding column as VECTOR(768); confirm
# that the hidden size of whichever checkpoint is enabled here matches.
# MODEL_NAME = "deepset/gbert-base"
MODEL_NAME = "dbmdz/bert-base-german-cased"
# MODEL_NAME = "FacebookAI/xlm-roberta-large"
# MODEL_NAME = "FacebookAI/roberta-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# tokenizer = XLMRobertaTokenizer.from_pretrained(MODEL_NAME)
# model = XLMRobertaModel.from_pretrained(MODEL_NAME)
# Switch the model to evaluation mode; it is only used for inference here.
model.eval()
# spaCy pipeline used for lemmatisation; its name is read from the environment.
nlp = spacy.load(os.getenv("spacy_model"))

# DB
# Connection parameters are taken from the standard POSTGRES_* variables.
conn = psycopg.connect(
    dbname=os.getenv("POSTGRES_DB"),
    user=os.getenv("POSTGRES_USER"),
    password=os.getenv("POSTGRES_PASSWORD"),
    host=os.getenv("POSTGRES_HOST"),
)
# With autocommit on, every executed statement is committed on its own.
conn.autocommit = True
# Register the pgvector type adapters on this connection.
register_vector(conn)
# A single module-level cursor is shared by all functions in this notebook.
cursor = conn.cursor()
# Connectivity check: print the server version string.
cursor.execute("SELECT version();")
print(cursor.fetchone())


# data structures
@dataclass
class SentenceEmbedded:
    """One input sentence together with its per-token analysis results.

    The three lists are built in parallel by create_sent_embedded and are
    zipped together position by position when the sentence is written to the
    database.
    """

    # numeric id of the sentence (the caller passes the line index of the input file)
    sent_id: int
    # the sentence text as it was passed to the tokenizer
    text: str
    # token strings after "##" word pieces have been merged
    token_list: list
    # one embedding per entry in token_list
    embedding_list: list
    # one lemma string per entry in token_list
    lemma_list: list

# modular functions

## set_up_db

In [None]:
def _print_and_execute(query):
    """Echo a composed SQL statement to stdout, then run it on the shared cursor."""
    print(query.as_string())
    cursor.execute(query)


def set_up_db(vector_dim=None):
    """Drop and recreate the sentences, lemmas and embeddings tables.

    All existing rows in the three tables are lost. Every statement is printed
    before it is executed.

    Args:
        vector_dim: Dimension of the pgvector ``embedding`` column. When left
            as ``None`` it is taken from ``model.config.hidden_size`` so the
            column always matches the checkpoint selected in the config cell
            instead of being fixed to one model size.
    """
    if vector_dim is None:
        vector_dim = model.config.hidden_size

    # Drop order matters: the embeddings table holds foreign keys into the
    # other two, so it has to go first.
    for table_name in [TABLE_NAME_EMBEDDINGS, TABLE_NAME_LEMMAS, TABLE_NAME_SENTENCES]:
        _print_and_execute(
            SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(table_name))
        )

    # sentences
    _print_and_execute(
        SQL(
            dedent(
                """\
                CREATE TABLE {table_name_sentences} (
                    sentence_id INTEGER PRIMARY KEY,
                    text TEXT
                )
                """
            )
        ).format(table_name_sentences=Identifier(TABLE_NAME_SENTENCES))
    )

    # lemmas: the UNIQUE constraint is what the ON CONFLICT clause of the
    # lemma insert relies on.
    _print_and_execute(
        SQL(
            dedent(
                """\
                CREATE TABLE {table_name_lemmas} (
                    lemma_id SERIAL PRIMARY KEY,
                    lemma_text TEXT,
                    CONSTRAINT lemma_text_unique UNIQUE (lemma_text)
                )
                """
            )
        ).format(table_name_lemmas=Identifier(TABLE_NAME_LEMMAS))
    )

    # embeddings: one row per (sentence, token position)
    _print_and_execute(
        SQL(
            dedent(
                """\
                CREATE TABLE {table_name_embeddings} (
                    token_text TEXT,
                    token_id INTEGER,
                    lemma_id_fk INTEGER REFERENCES {table_name_lemmas}(lemma_id),
                    sentence_id_fk INTEGER REFERENCES {table_name_sentences}(sentence_id),
                    PRIMARY KEY (sentence_id_fk, token_id),
                    embedding VECTOR({vector_dim})
                )
                """
            )
        ).format(
            table_name_embeddings=Identifier(TABLE_NAME_EMBEDDINGS),
            table_name_lemmas=Identifier(TABLE_NAME_LEMMAS),
            table_name_sentences=Identifier(TABLE_NAME_SENTENCES),
            vector_dim=Literal(int(vector_dim)),
        )
    )

## create_sent_embedded

In [None]:
def create_sent_embedded(sent_id, sent, tokens_list, embeddings_list):
    """Merge word pieces into whole tokens, average their embeddings and lemmatise.

    Consecutive pieces whose text starts with ``"##"`` are appended (without
    the marker) to the preceding token, and the embedding of the merged token
    is the mean of the embeddings of all its pieces. Tokens are not filtered
    here, so any special tokens the tokenizer added are kept as ordinary
    tokens.

    Args:
        sent_id: Identifier stored on the result.
        sent: Original sentence text stored on the result.
        tokens_list: Token strings, one per embedding.
        embeddings_list: Iterable of per-token tensors, parallel to
            ``tokens_list``.

    Returns:
        SentenceEmbedded with parallel ``token_list``, ``embedding_list`` and
        ``lemma_list``. All three are empty when no tokens were passed in.
    """
    token_list = []
    embedding_list = []
    t_prev = None
    e_prev = []
    for t, e in zip(tokens_list, embeddings_list):
        # A continuation piece can only be merged when there is a token to
        # merge it into; a leading "##" piece therefore starts a token itself
        # instead of failing on `None += str`.
        if t and t.startswith("##") and t_prev is not None:
            t_prev += t[2:]
            e_prev.append(e)
        else:
            # Flush the token collected so far (falsy tokens are skipped here).
            if t_prev:
                token_list.append(t_prev)
                embedding_list.append(torch.mean(torch.stack(e_prev), dim=0))
            t_prev = t
            e_prev = [e]
    # Flush the last token; with empty input there is nothing to flush and
    # torch.stack would raise on an empty list.
    if e_prev:
        token_list.append(t_prev)
        embedding_list.append(torch.mean(torch.stack(e_prev), dim=0))

    lemma_list = []
    if token_list:
        # Build a Doc from the already merged tokens so spaCy does not
        # re-tokenise, then run only the pipes needed for lemmas.
        doc = Doc(nlp.vocab, words=token_list)
        doc = nlp.get_pipe("tok2vec")(doc)
        doc = nlp.get_pipe("lemmatizer")(doc)
        lemma_list = [token.lemma_ for token in doc]

    return SentenceEmbedded(
        sent_id=sent_id,
        text=sent,
        token_list=token_list,
        embedding_list=embedding_list,
        lemma_list=lemma_list,
    )

## infer_embeddings

In [None]:
def infer_embeddings(sent_id, sent):
    """Run the transformer on one sentence and package the result.

    The sentence is tokenised with special tokens included, passed through the
    model without gradient tracking, and the last hidden state (batch axis
    removed) is handed to create_sent_embedded together with the token strings.

    Args:
        sent_id: Identifier forwarded to the resulting SentenceEmbedded.
        sent: Sentence text to embed.

    Returns:
        The SentenceEmbedded built by create_sent_embedded.
    """
    encoded = tokenizer(sent, return_tensors="pt", add_special_tokens=True)
    # Only one sentence is encoded, so row 0 holds all of its token ids.
    wordpieces = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    with torch.no_grad():
        last_hidden = model(**encoded).last_hidden_state
    per_token_vectors = last_hidden.squeeze(0)
    return create_sent_embedded(sent_id, sent, wordpieces, per_token_vectors)

## insert_into_db

In [None]:
def insert_into_db(table, sent_embedded):
    """Persist one embedded sentence: its text, its lemmas and its token rows.

    All statements for the sentence run inside a single transaction, so a
    failure part-way through does not leave a sentence with only some of its
    tokens stored (the connection is otherwise in autocommit mode).

    Args:
        table: Unused; kept so existing callers that pass a table name keep
            working. Target tables come from the TABLE_NAME_* settings.
        sent_embedded: SentenceEmbedded whose ``token_list``, ``lemma_list``
            and ``embedding_list`` are written position by position; the
            position becomes ``token_id``.
    """
    # The statements do not depend on the row being written, so they are
    # composed once per call and every value is bound through placeholders.
    sentence_query = SQL(
        "INSERT INTO {table_name_sentences} (sentence_id, text) VALUES (%s, %s)"
    ).format(table_name_sentences=Identifier(TABLE_NAME_SENTENCES))

    lemma_query = SQL(
        dedent(
            """\
            INSERT INTO {table_name_lemmas} (lemma_text) VALUES (%s)
            ON CONFLICT (lemma_text) DO NOTHING
            """
        )
    ).format(table_name_lemmas=Identifier(TABLE_NAME_LEMMAS))

    embedding_query = SQL(
        dedent(
            """\
            INSERT INTO {table_name_embeddings} (
                token_text,
                token_id,
                lemma_id_fk,
                sentence_id_fk,
                embedding
            )
            VALUES (
                %s,
                %s,
                (SELECT lemma_id from {table_name_lemmas} WHERE lemma_text=%s),
                %s,
                %s
            )
            """
        )
    ).format(
        table_name_embeddings=Identifier(TABLE_NAME_EMBEDDINGS),
        table_name_lemmas=Identifier(TABLE_NAME_LEMMAS),
    )

    if ENABLE_TEST:
        for query in (sentence_query, lemma_query, embedding_query):
            print(query.as_string())

    with cursor.connection.transaction():
        cursor.execute(sentence_query, (sent_embedded.sent_id, sent_embedded.text))

        for token_id, (token_text, lemma_text, embedding) in enumerate(
            zip(sent_embedded.token_list, sent_embedded.lemma_list, sent_embedded.embedding_list)
        ):
            if ENABLE_TEST:
                # The embedding itself is left out of the debug line on purpose.
                print((sent_embedded.sent_id, token_id, token_text, lemma_text))

            # lemmas: insert if new; an already known lemma is left untouched
            cursor.execute(lemma_query, (lemma_text,))

            # embeddings: the lemma id is looked up by text inside the statement
            cursor.execute(
                embedding_query,
                (
                    token_text,
                    token_id,
                    lemma_text,
                    sent_embedded.sent_id,
                    embedding.tolist(),
                ),
            )

## iterate_over_file

In [None]:
def iterate_over_file(path=None, encoding="utf-8"):
    """Lazily yield the lines of a text file, one line per iteration.

    Lines are yielded exactly as read, including their trailing newline.

    Args:
        path: File to read. Defaults to INPUT_FILE_PATH when omitted.
        encoding: Text encoding used to decode the file. Set explicitly so
            that reading non-ASCII text does not depend on the platform's
            default locale encoding.

    Yields:
        str: The next line of the file.
    """
    if path is None:
        path = INPUT_FILE_PATH
    with open(path, "r", encoding=encoding) as file:
        yield from file

# inference and persistence

In [None]:
# Rebuild the schema from scratch; rows from earlier runs are dropped.
set_up_db()

# Number of input lines processed when ENABLE_TEST is on.
limit = 10
# The zero-based line index doubles as the sentence id.
for sent_id, sent in enumerate(iterate_over_file()):
    if ENABLE_TEST and sent_id == limit:
        break
    sent_embedded = infer_embeddings(sent_id, sent)
    if ENABLE_TEST:
        print("-------------------------------------------------------------")
        # Lines come straight from the file, so the newline is removed for display only.
        print(sent.replace("\n", ""))
        print(sent_embedded.token_list)
        print(sent_embedded.lemma_list)
    # The first argument is not read by insert_into_db; target tables come
    # from the TABLE_NAME_* settings.
    insert_into_db(TABLE_NAME_EMBEDDINGS, sent_embedded)