# Embedding generation

This notebook contains the code to train and generate the different embedding models that are evaluated in our paper.

Paper reference: _Spliethöver, Keiff, Wachsmuth (2022): "No Word Embedding Model Is Perfect: Evaluating the Representation Accuracy for Social Bias in the Media", EMNLP 2022, Abu Dhabi._

Code & Data reference: https://github.com/webis-de/EMNLP-22

## Data preparation and loading

Please run the following two cells for any of the embedding models. They load the most common packages and set commonly used variables. They are necessary to run the training cells below.

In [None]:
import json
import logging
import os
import pickle
import sqlite3
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from os import listdir, path
from gensim.models import KeyedVectors, Word2Vec
from tokenizers import Tokenizer, normalizers
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import Lowercase, NFD, StripAccents, Strip, Replace
from tqdm.notebook import tqdm

PARENT_DIR = path.abspath("../src")
sys.path.append(PARENT_DIR)
from embedding_bias.config import (
    CRAWL_FIRST_YEAR, CRAWL_LAST_YEAR, SENTENCE_ENDING_TOKEN, NEWS_ARTICLE_DB_NAME)
from embedding_bias.preprocessing import preprocess_text
from embedding_bias.util import *

In [None]:
# PARENT_DIR is the absolute path of "../src" as a plain string, so the
# repository root has to be derived with path.dirname (a str has no ".parent").
DATA_DIR = path.join(path.dirname(PARENT_DIR), "data")
DB_PATH = path.join(DATA_DIR, "raw", NEWS_ARTICLE_DB_NAME)
ALLSIDES_RANKING_PATH = path.join(DATA_DIR, "raw", "allsides-ranking.csv")
OUTLET_CONFIG_PATH = path.join(DATA_DIR, "raw", "outlet-config.json")
WORD_SETS_PATH = path.join(DATA_DIR, "raw", "word-sets.json")

# Target sqlite database
target_db_connection = sqlite3.connect(DB_PATH)

# Outlet config file; every configured outlet is part of the selection
outlet_config = pd.read_json(OUTLET_CONFIG_PATH)
outlet_selection = outlet_config

# Word sets
with open(WORD_SETS_PATH, "r") as f:
    word_sets = json.load(f)

# Allsides ranking
allsides_ranking = pd.read_csv(ALLSIDES_RANKING_PATH)

# Map from lower-cased outlet name to its "allsides_rating" value
outlet_orientation_map = {
    o["name"].lower(): o["allsides_rating"]
    for _, o in outlet_selection.iterrows()}

# Groups of political orientations: coarse group name -> ratings it contains
orientation_groups = {
    "left": ["Lean Left", "Left"],
    "center": ["Center"],
    "right": ["Lean Right", "Right"]}

## word2vec models (Static embeddings)

In [None]:
# Pickle file caching the sentence-split and tokenized articles
W2V_ARTICLE_PREPROCESS_CACHE = path.join(
    DATA_DIR, "processed", "nato", "articles-preproc-cache.pkl")
# Directory the per-orientation word2vec models are saved to
W2V_MODEL_PATH = path.join(DATA_DIR, "models", "nato-w2v")
# Placeholder token, also used as the tokenizer's unknown-word token
LOW_COUNT_TOKEN_REPLACE = "<unk>"

# Whitespace tokenization with lowercasing, based on a word-level tokenizer
def hf_tokenize(texts):
    """Train a word-level tokenizer on ``texts`` and return the tokens.

    Normalization replaces SENTENCE_ENDING_TOKEN by a space, applies NFD,
    lowercases, strips accents and strips surrounding whitespace; the
    pre-tokenizer is ``Whitespace``.

    Args:
        texts: Collection of documents. It is handed to
            ``train_from_iterator`` as a whole and every document is passed to
            ``encode_batch`` as one batch.

    Returns:
        A list with one entry per document; each entry is a list containing
        the token list of every encoded element of that document.
    """
    tokenizer = Tokenizer(WordLevel(unk_token=LOW_COUNT_TOKEN_REPLACE))
    tokenizer.normalizer = normalizers.Sequence([
        Replace(SENTENCE_ENDING_TOKEN, " "),
        NFD(),
        Lowercase(),
        StripAccents(),
        Strip(),
    ])
    tokenizer.pre_tokenizer = Whitespace()
    vocab_trainer = WordLevelTrainer(special_tokens=[LOW_COUNT_TOKEN_REPLACE])

    print("Starting training.")
    tokenizer.train_from_iterator(texts, vocab_trainer)

    print("Starting tokenization.")
    encoded_docs = []
    for doc in tqdm(texts, mininterval=5):
        encoded_docs.append(tokenizer.encode_batch(doc))

    print("Extracting tokens.")
    tokens_per_doc = []
    for encoded_doc in tqdm(encoded_docs, mininterval=5):
        tokens_per_doc.append([encoding.tokens for encoding in encoded_doc])
    return tokens_per_doc

# Load the tokenized articles from the pickle cache; build and store the cache
# from the database when it does not exist yet.
if not path.exists(W2V_ARTICLE_PREPROCESS_CACHE):
    print("No cache found. Loading from database.")
    articles = get_articles_as_df(
        allsides_ranking=allsides_ranking,
        db_connection=target_db_connection,
        outlet_selection=outlet_selection,
        preprocessed=True)

    print("Sentence-splitting articles.")
    # Each article text becomes a list of sentence strings
    articles.text = articles.text.apply(
        lambda article_text: article_text.split(SENTENCE_ENDING_TOKEN))

    print("Tokenizing articles.")
    # Requires min. 128GB mem., but is fast af
    article_sentences = articles.text.tolist()
    articles["text_prep"] = hf_tokenize(article_sentences)

    with open(W2V_ARTICLE_PREPROCESS_CACHE, "wb") as f:
        pickle.dump(articles, f)
else:
    print("Found article cache. Loading from file.")
    with open(W2V_ARTICLE_PREPROCESS_CACHE, "rb") as f:
        articles = pickle.load(f)

In [None]:
class TextCorpora:
    """Re-iterable, lazy stream of whitespace-tokenized sentences.

    ``documents`` is an iterable of documents, each of which is an iterable of
    sentence strings. Iterating the object yields ``sentence.split()`` for
    every sentence, document by document, without materializing the corpus.
    """

    def __init__(self, documents):
        # Only a reference is kept; nothing is copied or tokenized up front
        self.docs = documents

    def __iter__(self):
        for document in self.docs:
            yield from (sentence.split() for sentence in document)

def get_outlet_orientation(outlet: str) -> str:
    """Return the orientation stored for ``outlet`` in ``outlet_orientation_map``.

    The lookup is case-insensitive (the name is lower-cased first). A
    ``KeyError`` is raised if the outlet is not part of the map.
    """
    normalized_name = outlet.lower()
    return outlet_orientation_map[normalized_name]

def train_embeddings(document_iterator):
    """Train and return a skip-gram ``Word2Vec`` model on ``document_iterator``."""
    hyperparameters = dict(
        vector_size=300,  # dimensionality of word vectors
        window=5,  # max dist. between current and predicted word within a sent.
        min_count=5,  # ignores all words with total frequency lower than this
        workers=32,  # number of threads to train the model
        sg=1,  # whether to use skip-gram (1) or cbow (0)
        epochs=5,  # epochs to train for
    )
    return Word2Vec(document_iterator, **hyperparameters)

# ------------------------------------------------------------------------------
# Train word embedding models per orientation
# Word2Vec.save does not create missing directories
os.makedirs(W2V_MODEL_PATH, exist_ok=True)

for orientation, grouping in orientation_groups.items():
    # Tokenized articles of all outlets belonging to this orientation group
    orientation_articles = articles[
        articles.orientation.isin(grouping)].text_prep
    print(f"Flattening articles for {orientation}.")
    # Every entry of `text_prep` is a list of sentences and every sentence is
    # already a list of tokens (see hf_tokenize). Dropping the document level
    # therefore yields the "iterable of token lists" that Word2Vec expects.
    articles_flat = [
        sentence for doc in orientation_articles for sentence in doc]
    print(f"Training word embedding model '{orientation}' with Word2Vec...")

    # Pass the flat sentence list directly. Wrapping it in TextCorpora would
    # descend one level too far (into the single tokens) and train the model on
    # one-word "sentences" that carry no context at all.
    word_embedding = train_embeddings(articles_flat)
    word_embedding.save(f"{W2V_MODEL_PATH}/{orientation}.model")

print("Done.")

## Frequency agnostic embeddings (FreqAgn)

In [None]:
from collections import Counter

from nltk.tokenize import word_tokenize

from tokenizers import Tokenizer, normalizers
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import Lowercase, NFD, StripAccents, Strip, Replace

tqdm.pandas()

In [None]:
def preprocess_text(text):
    """Tokenize ``text`` with NLTK's ``word_tokenize`` after lowercasing.

    Occurrences of SENTENCE_ENDING_TOKEN are replaced by a single space before
    the text is lower-cased and tokenized.

    NOTE(review): this definition shadows the ``preprocess_text`` imported from
    ``embedding_bias.preprocessing`` in the first cell; confirm this is intended.
    """
    without_sentence_markers = text.replace(SENTENCE_ENDING_TOKEN, " ")
    lowered = without_sentence_markers.lower()
    return word_tokenize(lowered)

def hf_tokenize(texts):
    """Train a size-limited word-level tokenizer and tokenize whole documents.

    Every element of ``texts`` is a sequence of strings that is joined with
    single spaces into one document string. The vocabulary is limited to
    MAX_VOCAB_SIZE entries.

    Returns:
        A list with the token list of each document.
    """
    # IMPORTANT NOTE: REQUIRES tokenizers version >= 12.x
    # NOTE(review): presumably this refers to release 0.12.x; confirm.
    tokenizer = Tokenizer(WordLevel(unk_token=LOW_COUNT_TOKEN_REPLACE))
    tokenizer.normalizer = normalizers.Sequence([
        Replace(SENTENCE_ENDING_TOKEN, " "),
        NFD(),
        Lowercase(),
        StripAccents(),
        Strip(),
    ])
    tokenizer.pre_tokenizer = Whitespace()
    vocab_trainer = WordLevelTrainer(
        vocab_size=MAX_VOCAB_SIZE,
        special_tokens=[LOW_COUNT_TOKEN_REPLACE, SENTENCE_ENDING_TOKEN],
        show_progress=True)

    print("Starting training.")
    training_corpus = [" ".join(sentences) for sentences in texts]
    tokenizer.train_from_iterator(training_corpus, vocab_trainer)

    print("Starting tokenization.")
    encodings = tokenizer.encode_batch([" ".join(doc) for doc in texts])

    print("Extracting tokens.")
    tokens_per_doc = []
    for encoding in tqdm(encodings, mininterval=5):
        tokens_per_doc.append(encoding.tokens)
    return tokens_per_doc

In [None]:
# Loading the original data
# Root folder of the language-model corpus files and of the article cache
AWSLSTM_CORPUS_PATH = path.join(DATA_DIR, "processed", "corpus-awd-lstm-format")
ARTICLE_PREPROCESS_CACHE = path.join(
    AWSLSTM_CORPUS_PATH, "articles_preprocessed.pkl")

# Placeholder / unknown-word token and upper bound of the tokenizer vocabulary
LOW_COUNT_TOKEN_REPLACE = "<unk>"
MAX_VOCAB_SIZE = 40000

# Reuse the pickled, tokenized articles when present; otherwise build them
# from the database and store them for the next run.
if not path.exists(ARTICLE_PREPROCESS_CACHE):
    print("No cache found. Loading from database.")
    articles = get_articles_as_df(
        db_connection=target_db_connection,
        outlet_selection=outlet_selection,
        allsides_ranking=allsides_ranking)

    print("Preprocessing articles.")
    # Each article text becomes a list of sentence strings
    articles.text = articles.text.apply(
        lambda article_text: article_text.split(SENTENCE_ENDING_TOKEN))
    # Requires min. 128GB mem., but is fast af
    articles["text_prep"] = hf_tokenize(articles.text)

    with open(ARTICLE_PREPROCESS_CACHE, "wb") as f:
        pickle.dump(articles, f)
else:
    print("Found article cache. Loading from file.")
    with open(ARTICLE_PREPROCESS_CACHE, "rb") as f:
        articles = pickle.load(f)

In [None]:
# Tokens that occur at most this often are treated as low-count tokens
LOW_COUNT_THRESHOLD = 5

def get_token_counts(articles, threshold=LOW_COUNT_THRESHOLD) -> dict:
    """Count how often every token occurs in the ``text_prep`` column.

    Args:
        articles: DataFrame with a ``text_prep`` column holding one token list
            per row.
        threshold: Tokens occurring at most this often are counted as
            "not passing the threshold". This only affects the printed
            statistics, not the returned counts.

    Returns:
        A plain ``dict`` mapping each token to its number of occurrences over
        all rows.
    """
    token_counter = Counter()
    # Iterate the column directly; iterrows() would build a Series per row
    for tokens in articles["text_prep"]:
        token_counter.update(tokens)

    print(f"Original dict size: {len(token_counter)}")
    low_count_total = sum(
        1 for count in token_counter.values() if count <= threshold)
    print(f"Tokens that don't pass the threshold: {low_count_total}")

    return dict(token_counter)

# Prepare our data into the AWD-LSTM specific format (only needs to be run the
# first time).
# This is simple white-space separated tokens in a text file, one article per
# line, split into train (70%), test (20%) and validation (remainder) files.
def _write_corpus_split(file_name, split_articles, token_counter):
    """Write ``split_articles`` to ``file_name``, one article per line.

    Tokens whose count in ``token_counter`` does not exceed LOW_COUNT_THRESHOLD
    are replaced by LOW_COUNT_TOKEN_REPLACE to reduce the vocabulary size.
    """
    with open(file_name, mode="w", encoding="utf-8") as f:
        for tokens in split_articles["text_prep"]:
            text_clean = [
                t if token_counter[t] > LOW_COUNT_THRESHOLD
                    else LOW_COUNT_TOKEN_REPLACE
                for t in tokens]
            f.write(" ".join(text_clean))
            f.write("\n")

for orientation, grouping in orientation_groups.items():
    print(f"{'=' * 30} {orientation}")
    orientation_filename_safe = orientation.lower().replace(" ", "-")
    # open(..., "w") does not create missing directories
    orientation_dir = path.join(AWSLSTM_CORPUS_PATH, orientation_filename_safe)
    os.makedirs(orientation_dir, exist_ok=True)

    orientation_articles = articles[articles.orientation.isin(grouping)]

    # Retrieve tokens with very few mentions to reduce vocab size
    print("Counting token occurrences.")
    token_counter = get_token_counts(
        orientation_articles, threshold=LOW_COUNT_THRESHOLD)

    # Contiguous, end-exclusive split boundaries. Every article ends up in
    # exactly one split; the former "+ 1" starts and the "-1" end silently
    # dropped one article at each boundary and the very last article.
    n_articles = len(orientation_articles)
    train_end = n_articles * 70 // 100
    test_end = train_end + n_articles * 20 // 100

    train_file_name = path.join(orientation_dir, "train.txt")
    test_file_name = path.join(orientation_dir, "test.txt")
    val_file_name = path.join(orientation_dir, "valid.txt")

    print("Generating train data.")
    _write_corpus_split(
        train_file_name, orientation_articles.iloc[:train_end], token_counter)
    print("Generating test data.")
    _write_corpus_split(
        test_file_name, orientation_articles.iloc[train_end:test_end],
        token_counter)
    print("Generating validation data.")
    _write_corpus_split(
        val_file_name, orientation_articles.iloc[test_end:], token_counter)

The steps above actually only prepare the data for the embedding generation. To run the training, please use the script `src/Frequency-Agnostic/frage-lstm-train.sh`. Refer to the README file for details.

## Decontextualized Embeddings (Decontext)


In [None]:
import nltk
import torch

from flair.embeddings import TransformerWordEmbeddings
from flair.data import Sentence
from nltk.tokenize import word_tokenize

In [None]:
# Download some NLTK dependencies
# NOTE(review): presumably the resource needed by `word_tokenize`, which is
# used in the cells below. Newer NLTK releases may need a different resource
# name; confirm against the installed version.
nltk.download('punkt')

In [None]:
# Folder for intermediate pickle caches (tokenized sentences, embeddings)
DECON_CACHE_PATH = path.join(DATA_DIR, "processed", "contextualized2static")
# Folder the final per-orientation embedding files are written to
DECON_MODEL_PATH = path.join(DATA_DIR, "models", "contextualized2static")
# Data folder handed to `get_test_vocabulary` as `similarity_eval_data_path`
SIM_EVAL_DATA_DIR = path.join(
    PARENT_DIR, "embedding_evaluation", "embedding_evaluation", "data")

# Max number of sentences per token before cutting off (due to CUDA-OOM)
MAX_SENTENCE_NUMBER = 1000000

# Pre-trained contextualized embedding model
embedding_model = TransformerWordEmbeddings("bert-base-uncased")

In [None]:
# Loading news article data
# This replaces any `articles` frame left over from an earlier section, so the
# `text` column holds whatever `get_articles_as_df` returns by default.
articles = get_articles_as_df(
    db_connection=target_db_connection,
    outlet_selection=outlet_selection,
    allsides_ranking=allsides_ranking)

In [None]:
vocabulary = get_test_vocabulary(word_sets=word_sets, similarity_eval_data_path=SIM_EVAL_DATA_DIR)
vocab_list = list(vocabulary)

# Load the sentences collected in a previous run, if a cache is available
sentences_per_orientation_cache_file = path.join(
    DECON_CACHE_PATH, "sentences-tokenized-per-orientation.pkl")
if path.exists(sentences_per_orientation_cache_file):
    print("Found sentence per orientation cache. Loading from file.")
    with open(sentences_per_orientation_cache_file, "rb") as f:
        sentences_by_orientation = pickle.load(f)
else:
    sentences_by_orientation = {}

# For each article, collect all sentences and sort them according to the
# political orientation of its media outlet.
# Use a cache, if available.
for orientation, grouping in orientation_groups.items():
    print("=" * 30, orientation)

    # Collect all sentences of the current orientation; we don't care about
    # document levels or specific outlets anymore at this point, so we can have
    # them simply in a flat list
    orientation_articles = articles[articles.orientation.isin(grouping)]
    sentences_tokenized_cache_file = path.join(
        DECON_CACHE_PATH, f"{orientation}-sentences-tokenized.pkl")

    if path.exists(sentences_tokenized_cache_file):
        print("Found tokenized cache. Loading from file.")
        with open(sentences_tokenized_cache_file, "rb") as f:
            orientation_sentences_tokenized = pickle.load(f)
    else:
        print("No tokenized cache found. Tokenizing sentences now.")
        # Split into sentences right here instead of mutating `articles.text`
        # only when the global cache is missing: otherwise a present global
        # cache combined with a missing per-orientation cache would iterate
        # over the characters of the unsplit text. Texts that are already
        # split into a list are used as they are.
        orientation_sentences_tokenized = [
            word_tokenize(s.lower())
            for text in tqdm(orientation_articles["text"])
            for s in (
                text.split(SENTENCE_ENDING_TOKEN)
                if isinstance(text, str) else text)]
        print("Caching tokenized sentences.")
        with open(sentences_tokenized_cache_file, "wb") as f:
            pickle.dump(orientation_sentences_tokenized, f)

    # Retrieve sentences of each token
    print("Retrieving sentences.")
    if orientation in sentences_by_orientation:
        print("Found orientation in cache.")
        orientation_sentences = sentences_by_orientation[orientation]
    else:
        orientation_sentences = {}

    # Register every vocabulary word that is not covered by the cache yet;
    # words without any matching sentence keep an empty list.
    pending_words = set()
    for word in vocab_list:
        word_lower = word.lower()
        if word_lower not in orientation_sentences:
            orientation_sentences[word_lower] = []
            pending_words.add(word_lower)

    # One pass over the corpus instead of one full scan per vocabulary word.
    # Sentences are appended in corpus order, as a per-word scan would do.
    if pending_words:
        for s in tqdm(orientation_sentences_tokenized):
            matched_words = pending_words.intersection(s)
            if matched_words:
                sentence_text = " ".join(s)
                for word_lower in matched_words:
                    orientation_sentences[word_lower].append(sentence_text)

    sentences_by_orientation[orientation] = orientation_sentences

# Cache result to disk
with open(sentences_per_orientation_cache_file, "wb") as f:
    pickle.dump(sentences_by_orientation, f)

In [None]:
# Retrieve contextualized embeddings per token for each sentence
def embed_token(token, sentence: str) -> list:
    """Embed ``sentence`` and return the embedding stored for ``token``.

    The position of ``token`` is the index of its first occurrence among the
    whitespace-separated words of the text reconstructed by flair. A
    ``ValueError`` is raised if the token does not occur there.

    NOTE(review): this assumes that whitespace-splitting the reconstructed
    text lines up one-to-one with flair's own tokens; confirm.
    """
    embedded_sentence = Sentence(sentence)
    embedding_model.embed(embedded_sentence)

    # Search in the flair sentence rather than in the raw input string
    words = embedded_sentence.to_original_text().split()
    token_position = words.index(token)

    return embedded_sentence[token_position].embedding

# Retrieve all sentences a token appears in and generate pooled embeddings
for orientation, grouping in orientation_groups.items():
    print("=" * 30, orientation)
    orientation_context_embedded_cache_file = path.join(
        DECON_CACHE_PATH, f"orientation-context-embedded-{orientation}.pkl")

    # Resume from the embeddings of an earlier (possibly interrupted) run
    if path.exists(orientation_context_embedded_cache_file):
        print("Found existing embedding cache. Loading from file.")
        with open(orientation_context_embedded_cache_file, "rb") as f:
            orientation_context_embeddings = pickle.load(f)
    else:
        print("No embedding cache found.")
        orientation_context_embeddings = {}

    # Retrieve the contextualized embeddings per token for each orientation
    for token, sentences in sentences_by_orientation[orientation].items():
        if len(sentences) < 1:
            print(f"No sentences for token '{token}'. Skipping")
            continue
        elif len(sentences) > MAX_SENTENCE_NUMBER:
            # Adjacent f-strings instead of a backslash continuation, which
            # would put the indentation of the next line into the message.
            print(
                f"Found {len(sentences)} sentences for '{token}'. "
                f"Truncating to {MAX_SENTENCE_NUMBER}.")
            sentences = sentences[:MAX_SENTENCE_NUMBER]
        else:
            print(f"Found {len(sentences)} sentences for '{token}'. No truncation necessary.")

        if token in orientation_context_embeddings:
            print(f"Found '{token}' in cache. Skipping.")
            continue

        print("Generating token embeddings.")
        token_embeds_lst = []
        for s in tqdm(sentences, mininterval=5, leave=False):
            token_embeds_lst.append(embed_token(token, s))
        # One row per sentence the token occurs in
        token_sentence_embeddings = torch.vstack(token_embeds_lst)

        # Pool over the sentence dimension; keepdim keeps a leading axis of 1
        orientation_context_embeddings[token] = {
            "mean_pooled_embedding": torch.mean(token_sentence_embeddings, dim=0, keepdim=True),
            "max_pooled_embedding": torch.amax(token_sentence_embeddings, dim=0, keepdim=True),
            "min_pooled_embedding": torch.amin(token_sentence_embeddings, dim=0, keepdim=True)
        }

        # Write embeddings for current token to a cache file, as this operation can take some
        # time and the process might be interrupted multiple times. The dump goes to a
        # temporary file that replaces the cache afterwards, so an interruption during the
        # write cannot corrupt the embeddings that were already cached.
        cache_tmp_file = f"{orientation_context_embedded_cache_file}.tmp"
        with open(cache_tmp_file, "wb") as f:
            pickle.dump(orientation_context_embeddings, f)
        os.replace(cache_tmp_file, orientation_context_embedded_cache_file)

In [None]:
# Generating final token->embedding dictionary from the cached embedding file.
# Uses the specific pooled embeddings setting defined below.

vocabulary = get_test_vocabulary(word_sets=word_sets, similarity_eval_data_path=SIM_EVAL_DATA_DIR)
vocab_list = list(vocabulary)
pooled_embedding_type = "mean_pooled_embedding"

# For each orientation, load the embedding cache, copy tensors to CPU and build the final embedding
# dictionary. The output is a text file with a "<count> <dimension>" header line followed by one
# "<token> <v1> <v2> ..." line per token.
for orientation, grouping in orientation_groups.items():
    print("=" * 30, orientation)
    model_path = f"{DECON_MODEL_PATH}/{orientation}.model"
    orientation_context_embedded_cache_file = path.join(
        DECON_CACHE_PATH, f"orientation-context-embedded-{orientation}.pkl")

    print("Loading cached orientation embeddings.")
    with open(orientation_context_embedded_cache_file, "rb") as f:
        orientation_context_embeddings = pickle.load(f)

    print("Copying vectors to CPU and writing to disk.")
    # Collect the lines first: the header has to state the number of vectors that are actually
    # written, which differs from the cache size whenever a vocabulary token is missing.
    token_vector_strings = []
    embedding_size = None
    for token in tqdm(vocabulary, mininterval=2):
        try:
            token_vector = (
                orientation_context_embeddings[token][pooled_embedding_type][0].tolist())
        except KeyError:
            print(f"Token {token} not found. Skipping.")
            continue
        if embedding_size is None:
            # Taken from the first vector found instead of a hard-coded sample token
            embedding_size = len(token_vector)
        token_vector_str = " ".join([str(d) for d in token_vector])
        token_vector_strings.append(f"{token} {token_vector_str}")

    if not token_vector_strings:
        print(f"No embeddings found for '{orientation}'. Skipping model file.")
        continue

    with open(f"{model_path}", "w", encoding="utf-8") as f:
        f.write(f"{len(token_vector_strings)} {embedding_size}\n")
        for token_vector_line in token_vector_strings:
            f.write(f"{token_vector_line}\n")

print("Generation done.")

## Temporal models

In [None]:
import torch

from flair.embeddings import TransformerWordEmbeddings
from flair.data import Sentence
from nltk.tokenize import word_tokenize

In [None]:
# NOTE: these names deliberately reuse (and overwrite) the ones of the
# Decontext section; from here on they point to the "temporal" sub-folders.
# Folder for intermediate pickle caches (tokenized sentences, embeddings)
DECON_CACHE_PATH = path.join(DATA_DIR, "processed", "contextualized2static", "temporal")
# Folder the final per-orientation-and-year embedding files are written to
DECON_MODEL_PATH = path.join(DATA_DIR, "models", "contextualized2static", "temporal")
# Data folder handed to `get_test_vocabulary` as `similarity_eval_data_path`
SIM_EVAL_DATA_DIR = path.join(
    PARENT_DIR, "embedding_evaluation", "embedding_evaluation", "data")

# Max number of sentences per token before cutting off (due to CUDA-OOM)
MAX_SENTENCE_NUMBER = 1000000

# Pre-trained contextualized embedding model
embedding_model = TransformerWordEmbeddings("bert-base-uncased")

In [None]:
# Loading news article data
# This replaces any `articles` frame left over from an earlier section; the
# following cells filter it by date and add a `year` column.
articles = get_articles_as_df(
    db_connection=target_db_connection,
    outlet_selection=outlet_selection,
    allsides_ranking=allsides_ranking)

In [None]:
# Filter articles based on their publication year, so we only keep the timeframe of interest
# Dates with fewer than 10 characters cannot hold a full "YYYY-MM-DD" value and are dropped.
# `.copy()` makes each filtered frame independent, so the column assignments below do not
# operate on a slice of the previous frame (SettingWithCopyWarning).
articles = articles[~(articles.date.str.len() < 10)].copy()
articles["date_dt"] = pd.to_datetime(articles.date, format="%Y-%m-%d")
articles = articles[
    (articles.date_dt.dt.year >= 2010) & (articles.date_dt.dt.year <= 2021)].copy()
articles["year"] = articles.date_dt.dt.year

In [None]:
# Retrieve and cache all sentences for each orientation and year

vocabulary = get_test_vocabulary(word_sets=word_sets, similarity_eval_data_path=SIM_EVAL_DATA_DIR)
vocab_list = list(vocabulary)

# Load the sentences collected in a previous run, if a cache is available
sentences_per_orientation_cache_file = path.join(
    DECON_CACHE_PATH, "sentences-tokenized-per-orientation.pkl")
if path.exists(sentences_per_orientation_cache_file):
    print("Found sentence per orientation cache. Loading from file.")
    with open(sentences_per_orientation_cache_file, "rb") as f:
        sentences_by_orientation = pickle.load(f)
else:
    sentences_by_orientation = {}

# For each article, collect all sentences and sort them according to the
# political orientation of its media outlet and the year of its publication.
# Use a cache, if available.
for orientation, grouping in orientation_groups.items():
    print("=" * 30, orientation)
    articles_filtered = articles[articles.orientation.isin(grouping)]

    for year in articles_filtered.year.unique():
        print("=" * 20, year)

        # Collect all sentences of the current orientation and year; we don't care about document
        # levels or specific outlets at this point, so we can simply collect them in a flat list.
        orientation_year_articles = articles_filtered[articles_filtered.year == year]
        print(f"Processing {len(orientation_year_articles)} articles.")
        sentences_tokenized_cache_file = path.join(
            DECON_CACHE_PATH, f"{orientation}-{year}-sentences-tokenized.pkl")

        if path.exists(sentences_tokenized_cache_file):
            print("Found tokenized cache. Loading from file.")
            with open(sentences_tokenized_cache_file, "rb") as f:
                orientation_year_sentences_tokenized = pickle.load(f)
        else:
            print("No tokenized cache found. Tokenizing sentences now.")
            # Split into sentences right here instead of mutating `articles.text` only when the
            # global cache is missing: otherwise a present global cache combined with a missing
            # per-year cache would iterate over the characters of the unsplit text. Texts that
            # are already split into a list are used as they are.
            orientation_year_sentences_tokenized = [
                word_tokenize(s.lower())
                for text in tqdm(orientation_year_articles["text"])
                for s in (
                    text.split(SENTENCE_ENDING_TOKEN) if isinstance(text, str) else text)]
            print("Caching tokenized sentences.")
            with open(sentences_tokenized_cache_file, "wb") as f:
                pickle.dump(orientation_year_sentences_tokenized, f)

        # Retrieve sentences of each token
        print("Retrieving sentences.")
        orientation_year_key = f"{orientation}-{year}"
        if orientation_year_key in sentences_by_orientation:
            print("Found orientation in cache.")
            orientation_year_sentences = sentences_by_orientation[orientation_year_key]
        else:
            orientation_year_sentences = {}

        # Register every vocabulary word that is not covered by the cache yet; words without any
        # matching sentence keep an empty list.
        pending_words = set()
        for word in vocab_list:
            word_lower = word.lower()
            if word_lower not in orientation_year_sentences:
                orientation_year_sentences[word_lower] = []
                pending_words.add(word_lower)

        # Retrieve sentences for each token and join the tokens of each sentence into a string.
        # One pass over the corpus instead of one full scan per vocabulary word; sentences are
        # appended in corpus order, as a per-word scan would do.
        if pending_words:
            for s in tqdm(orientation_year_sentences_tokenized):
                matched_words = pending_words.intersection(s)
                if matched_words:
                    sentence_text = " ".join(s)
                    for word_lower in matched_words:
                        orientation_year_sentences[word_lower].append(sentence_text)

        sentences_by_orientation[orientation_year_key] = orientation_year_sentences

# Cache the collected sentences to disk
with open(sentences_per_orientation_cache_file, "wb") as f:
    pickle.dump(sentences_by_orientation, f)

In [None]:
# Retrieve contextualized embeddings per token for each sentence
def embed_token(token, sentence: str) -> list:
    """Return the embedding of ``token`` within the embedded ``sentence``.

    The sentence is wrapped in a flair ``Sentence`` and embedded with
    ``embedding_model``. The token is located as the first whitespace-separated
    word of flair's reconstructed text that equals ``token``; ``ValueError`` is
    raised when there is no such word.

    NOTE(review): relies on the whitespace split and flair's tokens sharing the
    same indices; confirm.
    """
    flair_sent = Sentence(sentence)
    embedding_model.embed(flair_sent)

    # Look the token up in the flair sentence, not in the raw input string
    reconstructed_words = flair_sent.to_original_text().split()
    position = reconstructed_words.index(token)

    return flair_sent[position].embedding

# Retrieve all sentences a token appears in and generate pooled embeddings
for orientation, grouping in orientation_groups.items():
    print("=" * 30, orientation)
    articles_filtered = articles[articles.orientation.isin(grouping)]

    for year in articles_filtered.year.unique():
        print("=" * 20, year)
        orientation_year_context_embedded_cache_file = path.join(
            DECON_CACHE_PATH, f"orientation-context-embedded-{orientation}-{year}.pkl")

        # Resume from the embeddings of an earlier (possibly interrupted) run
        if path.exists(orientation_year_context_embedded_cache_file):
            print("Found existing embedding cache. Loading from file.")
            with open(orientation_year_context_embedded_cache_file, "rb") as f:
                orientation_year_context_embeddings = pickle.load(f)
        else:
            print("No embedding cache found.")
            orientation_year_context_embeddings = {}

        # Retrieve the contextualized embeddings per word for each orientation
        for token, sentences in sentences_by_orientation[f"{orientation}-{year}"].items():
            if len(sentences) < 1:
                print(f"No sentences for token '{token}'. Skipping")
                continue
            elif len(sentences) > MAX_SENTENCE_NUMBER:
                # Adjacent f-strings instead of a backslash continuation, which would put the
                # indentation of the next line into the message.
                print(
                    f"Found {len(sentences)} sentences for '{token}'. "
                    f"Truncating to {MAX_SENTENCE_NUMBER}.")
                sentences = sentences[:MAX_SENTENCE_NUMBER]
            else:
                print(f"Found {len(sentences)} sentences for '{token}'. No truncation necessary.")

            if token in orientation_year_context_embeddings:
                print(f"Found '{token}' in cache. Skipping.")
                continue

            print("Generating token embeddings.")
            token_embeds_lst = []
            for s in tqdm(sentences, mininterval=5, leave=False):
                token_embeds_lst.append(embed_token(token, s))
            # One row per sentence the token occurs in
            token_sentence_embeddings = torch.vstack(token_embeds_lst)

            # Pool over the sentence dimension; keepdim keeps a leading axis of 1
            orientation_year_context_embeddings[token] = {
                "mean_pooled_embedding": torch.mean(token_sentence_embeddings, dim=0, keepdim=True),
                "max_pooled_embedding": torch.amax(token_sentence_embeddings, dim=0, keepdim=True),
                "min_pooled_embedding": torch.amin(token_sentence_embeddings, dim=0, keepdim=True)
            }

            # Cache generated embeddings to disk after every token. The dump goes to a temporary
            # file that replaces the cache afterwards, so an interruption during the write
            # cannot corrupt the embeddings that were already cached.
            cache_tmp_file = f"{orientation_year_context_embedded_cache_file}.tmp"
            with open(cache_tmp_file, "wb") as f:
                pickle.dump(orientation_year_context_embeddings, f)
            os.replace(cache_tmp_file, orientation_year_context_embedded_cache_file)

In [None]:
# Generating final token->embedding dictionary from the cached embedding file.
# Uses the specific pooled embeddings setting defined below.

vocabulary = get_test_vocabulary(word_sets=word_sets, similarity_eval_data_path=SIM_EVAL_DATA_DIR)
vocab_list = list(vocabulary)
pooled_embedding_type = "mean_pooled_embedding"

# For each orientation and year, load the embedding cache and write a text file with a
# "<count> <dimension>" header line followed by one "<token> <v1> <v2> ..." line per token.
for orientation, grouping in orientation_groups.items():
    print("=" * 30, orientation)
    articles_filtered = articles[articles.orientation.isin(grouping)]

    for year in articles_filtered.year.unique():
        print("=" * 20, year)

        model_path = f"{DECON_MODEL_PATH}/{orientation}-{year}.model"
        orientation_year_context_embedded_cache_file = path.join(
            DECON_CACHE_PATH, f"orientation-context-embedded-{orientation}-{year}.pkl")

        print("Loading cached orientation embeddings.")
        with open(orientation_year_context_embedded_cache_file, "rb") as f:
            orientation_year_context_embeddings = pickle.load(f)

        print("Copying vectors to CPU and writing to disk.")
        token_vector_strings = []
        embedding_size = None
        for token in tqdm(vocabulary, mininterval=2):
            try:
                token_vector = (
                    orientation_year_context_embeddings[token][pooled_embedding_type][0].tolist())
            except KeyError:
                print(f"Token {token} not found. Skipping.")
                continue
            if embedding_size is None:
                # Taken from the first vector found; a hard-coded sample token would raise a
                # KeyError for every year in which that token has no sentences.
                embedding_size = len(token_vector)
            token_vector_line = " ".join([str(d) for d in token_vector])
            token_vector_strings.append(f"{token} {token_vector_line}")

        if not token_vector_strings:
            print(f"No embeddings found for '{orientation}-{year}'. Skipping model file.")
            continue

        # Save final embedding dictionary to disk; the header states the number of vectors
        # that are actually written.
        with open(f"{model_path}", "w", encoding="utf-8") as f:
            f.write(f"{len(token_vector_strings)} {embedding_size}\n")
            for token_vector_line in token_vector_strings:
                f.write(f"{token_vector_line}\n")

print("Generation done.")