In [None]:
import itertools

import datasets
import numpy as np
import pandas as pd
import torch
import tqdm
import transformers

# Hugging Face model id; passed to both AutoTokenizer.from_pretrained and
# AutoModel.from_pretrained below so tokenizer and model always match.
EMBED_MODEL = 'nomic-ai/nomic-embed-text-v1.5'

In [None]:
# Open the Stack Overflow posts dataset in streaming mode (rows are produced
# while iterating rather than loaded as a whole).
ds = datasets.load_dataset('mikex86/stackoverflow-posts', streaming=True, split='train')

# Take the first five rows and display the body of the first post.
records = list(itertools.islice(ds, 5))
pd.DataFrame(records)['Body'].iloc[0]

In [None]:
# Tokenizer belonging to the embedding model.
tokenizer = transformers.AutoTokenizer.from_pretrained(EMBED_MODEL)

# Encode the body of the first streamed post and print the raw tokenizer output.
# NOTE: despite its name, `input_ids` holds everything the tokenizer returns,
# not only the id array.
first_post = next(iter(ds))
body = first_post['Body']
input_ids = tokenizer(body, return_tensors="np")
print(input_ids)

In [None]:
# Load the embedding model; trust_remote_code=True lets the checkpoint's own
# modelling code run, safe_serialization=True is forwarded to that loader.
model = transformers.AutoModel.from_pretrained(EMBED_MODEL, trust_remote_code=True, safe_serialization=True)

# Pick the best accelerator that is actually present. The previous logic moved
# the model to 'mps' whenever CUDA was missing (crashing on CPU-only hosts) and
# left it on the CPU when CUDA *was* available.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'
model = model.to(DEVICE)

# Inference only: switch off dropout and similar training-time behaviour.
model.eval()

In [None]:
# Number of leading embedding dimensions kept by mean_pooling.
EMBED_DIM = 128
# Number of token positions per key window (the step used by split).
KEY_WINDOW = 32

# TODO: Values should be added as tokens, not embeddings!
# Number of token positions per value window; each value window starts at the
# same offset as its key window.
VALUE_WINDOW = 64

# Column names of the raw dataset, read from its first streamed row. They are
# handed to ds.map(remove_columns=...) further down.
DS_KEYS = list(next(iter(ds)))

def mean_pooling(model_output, attention_mask, embed_dim=None):
    """Mask-aware mean pooling followed by layer norm and truncation.

    Args:
        model_output: Model output whose first element holds the per-token
            embeddings, indexed as (batch, sequence, hidden).
        attention_mask: (batch, sequence) tensor; positions with 0 are
            excluded from the average.
        embed_dim: Number of leading dimensions to keep after layer
            normalisation. Defaults to the module-level ``EMBED_DIM`` so
            existing callers behave exactly as before.

    Returns:
        A (batch, min(embed_dim, hidden)) tensor of pooled embeddings.
    """
    if embed_dim is None:
        embed_dim = EMBED_DIM

    token_embeddings = model_output[0]
    # Broadcast the mask over the hidden dimension so it can weight every feature.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

    summed = torch.sum(token_embeddings * mask, 1)
    # Clamp the token count so a fully masked row divides by 1e-9 instead of 0.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    pooled = summed / counts

    # Normalise over the full hidden size *before* truncating.
    pooled = torch.nn.functional.layer_norm(pooled, normalized_shape=(pooled.shape[1],))
    return pooled[:, :embed_dim]


def split(record):
    """Cut every post of a batch into key windows with matching value windows.

    Each post body is tokenized once. Starting at every multiple of
    ``KEY_WINDOW`` a key window of up to ``KEY_WINDOW`` positions and a value
    window of up to ``VALUE_WINDOW`` positions (same start offset) are cut out
    of the id sequence.

    Args:
        record: Batched mapping; ``record["Body"]`` is a list of post bodies.

    Returns:
        Dict with one entry per window across all posts:
          * ``key_tokens`` / ``value_tokens``: lists of 1-D string arrays.
          * ``{key,value}_{input_ids,token_type_ids,attention_mask}``: 2-D
            integer arrays, right-padded with 0 to the longest window.
    """
    id_fields = ("input_ids", "token_type_ids", "attention_mask")

    # Plain Python lists are requested (no return_tensors) because posts have
    # different lengths; each post is converted to an array individually.
    encoded = tokenizer(record["Body"])

    columns = {
        f"{side}_{field}": []
        for side in ("key", "value")
        for field in ("tokens",) + id_fields
    }

    for doc in range(len(encoded["input_ids"])):
        doc_ids = {field: np.asarray(encoded[field][doc]) for field in id_fields}
        # Derive the readable tokens from the ids themselves. tokenizer.tokenize()
        # omits the special tokens that tokenizer() inserts, which shifted the
        # token windows relative to the id windows.
        doc_tokens = tokenizer.convert_ids_to_tokens(doc_ids["input_ids"].tolist())

        for start in range(0, len(doc_tokens), KEY_WINDOW):
            for side, width in (("key", KEY_WINDOW), ("value", VALUE_WINDOW)):
                window = slice(start, start + width)
                columns[f"{side}_tokens"].append(np.array(doc_tokens[window]))
                for field in id_fields:
                    columns[f"{side}_{field}"].append(doc_ids[field][window])

    # Token columns stay ragged; id columns are padded into rectangular arrays.
    for name, rows in columns.items():
        if "_tokens" in name:
            continue
        width = max((len(row) for row in rows), default=0)
        columns[name] = np.array([np.pad(row, (0, width - len(row))) for row in rows])

    return columns


def encode(record):
    """Embed the key windows of a batch of posts.

    The batch is chunked with ``split``; the key windows are run through
    ``model`` and pooled with ``mean_pooling``.

    Args:
        record: Batched mapping with a ``"Body"`` list, as passed by
            ``ds.map(..., batched=True)``.

    Returns:
        Dict with one entry per window:
          * ``key_embedding``: numpy array of pooled key embeddings.
          * ``value_input_ids`` / ``value_token_type_ids`` /
            ``value_attention_mask``: torch tensors on the model's device.
          * ``key_tokens`` / ``value_tokens``: the token windows from ``split``.
    """
    splits = split(record)

    # Follow whatever device the model lives on instead of assuming "mps".
    device = next(model.parameters()).device
    for name, values in splits.items():
        if "_tokens" in name:
            continue
        splits[name] = torch.tensor(values).to(device)

    with torch.no_grad():
        model_output = model(
            input_ids=splits["key_input_ids"],
            token_type_ids=splits["key_token_type_ids"],
            attention_mask=splits["key_attention_mask"],
        )
        key_embedding = mean_pooling(model_output, splits["key_attention_mask"])

    key_embedding = key_embedding.detach().cpu().numpy()
    # Only keys are embedded; the old print referenced an undefined
    # `value_embedding` and raised NameError on the first batch.
    print(len(splits["key_tokens"]), key_embedding.shape)

    return {
        "key_embedding": key_embedding,
        "value_input_ids": splits["value_input_ids"],
        "value_token_type_ids": splits["value_token_type_ids"],
        "value_attention_mask": splits["value_attention_mask"],
        "key_tokens": splits["key_tokens"],
        "value_tokens": splits["value_tokens"],
    }

# Number of posts handed to `encode` per call.
BATCH_SIZE = 16

# Map `encode` over the stream in batches and drop the raw dataset columns,
# leaving only the fields `encode` returns.
encode_ds = ds.map(encode, batched=True, batch_size=BATCH_SIZE, remove_columns=DS_KEYS)

# Smoke test: pull a single encoded row and print a summary of it.
encode_iter = iter(encode_ds)
batch = next(encode_iter)
print("KEY_EMBED", batch["key_embedding"].mean())
print("KEY", "".join(batch["key_tokens"]))
print("VAL", "".join(batch["value_tokens"]))

In [None]:
# Let's test the embedding performance.
# Add the first 10000 encoded key windows (chunks of posts, not whole documents) to the ANN index; a later cell queries the index with one of the stored embeddings.

import numpy as np
import annoy

# How many encoded rows to pull from the stream and index.
NUM_INDEXED = 10000
# Number of trees built for the Annoy index.
NUM_TREES = 10

index = annoy.AnnoyIndex(EMBED_DIM, "angular")

# Keep every indexed row so neighbour ids can be mapped back to their tokens.
records = []
encode_iter = itertools.islice(iter(encode_ds), NUM_INDEXED)
for item_id, row in enumerate(tqdm.tqdm(encode_iter, total=NUM_INDEXED)):
    index.add_item(item_id, row["key_embedding"])
    records.append(row)

index.build(NUM_TREES)

In [None]:
# One row per indexed record; the row position equals the Annoy item id.
df = pd.DataFrame(records)
df.head(10)

In [None]:
# Use the stored embedding of row 2 as the query vector.
q = df.iloc[2]["key_embedding"]
q.shape

# Fetch the 5 nearest neighbours (ids and distances) from the index.
idx, dist = index.get_nns_by_vector(q, 5, search_k=100, include_distances=True)

# Show the key tokens of the neighbours, space-joined, as a list of strings.
neighbour_tokens = df.iloc[idx]["key_tokens"]
neighbour_tokens.str.join(" ").tolist()