In [1]:
import pandas as pd
from pathlib import Path
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer
import numpy as np
import srsly
import hnswlib as hb

In [None]:
model_name = "all-mpnet-base-v2" # "all-MiniLM-L6-v2"
sentence_encoder = SentenceTransformer(model_name)

In [2]:
ds_dict = load_from_disk("../data/processed/civil_comments/")
train_ds = ds_dict["train"] #.select(range(1000))
train_df = train_ds.to_pandas()



In [None]:
embeddings = sentence_encoder.encode(train_df["text"].tolist(), show_progress_bar=True, batch_size=512, device="cuda")

In [None]:
np.save("../data/processed/civil_comments/civil_comments_index.npy", embeddings, fix_imports=False)

In [None]:
meta = srsly.read_yaml("../data/processed/civil_comments/metadata.yaml")

In [None]:
meta["embedding_model"] = model_name

In [None]:
srsly.write_yaml("../data/processed/civil_comments/metadata.yaml", meta)

In [3]:
embeddings = np.load("../data/processed/civil_comments/civil_comments_index.npy")

In [None]:
p = hb.Index(space="cosine", dim=embeddings.shape[1])
p.set_ef(200)
p.init_index(max_elements=embeddings.shape[0], M=64, ef_construction=200)
p.add_items(embeddings, np.arange(embeddings.shape[0]))

In [None]:
p.save_index("../data/processed/civil_comments/civil_comments_index.bin")

In [4]:
p = hb.Index(space="cosine", dim=embeddings.shape[1])
p.load_index("../data/processed/civil_comments/civil_comments_index.bin")

In [5]:
ids, distances = p.knn_query(embeddings[100, :], k=10)

In [9]:
train_df.loc[
    train_df["unique_id"].isin(ids[0, 1:]), 
    "text"
].tolist()

['Stop stereotyping.',
 'Wow - stereotypes much!',
 'over generalize much?',
 'Holy generalization, Batman.',
 'What are you basing your generalizations on?',
 'Unfortunately actions like these just reinforce stereotypes.',
 'People that make such huge assumptions about others are just clueless.',
 'This is a ridiculous stereotype....Give it s break. You then wonder who so many on the other side label you are "fake news" and liars.',
 "I'm not sure this is a stereotype so much as a segment of voters.  Of course the vast majority of these folks (speller checker put in fools, funny) vote Republican...there would not GOP without these voters. But there are well-meaning conservatives that do not see themselves in the current Democratic Party either and that's our mistake. We should allow what Keilor thinks of conservatives advise us too. And really be inclusive. Of course thoughtful discourse must past the smell test...and Trump doesn't."]

In [None]:
from src.data.datamodule import DataModule

In [None]:
dm = DataModule.from_dataset_dict(ds_dict)