In [None]:
from collections import defaultdict

import weaviate
import weaviate.classes.config as wc
# import weaviate.classes.query as wq
from weaviate.util import generate_uuid5
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

from openie import StanfordOpenIE
from datasets import load_dataset
from tqdm.auto import tqdm

import numpy as np
import torch
from torch.nn import CrossEntropyLoss

from datasketch import MinHash, MinHashLSH
from nltk import ngrams

In [None]:
# Ask protobuf to use its pure-Python implementation, presumably to work around
# a protobuf / C-extension incompatibility in one of the dependencies — TODO confirm.
# NOTE(review): protobuf picks its implementation when it is first imported, so
# this presumably has no effect if the imports cell above already pulled it in
# (e.g. via weaviate/transformers). Move it to the very first cell if it is needed.
%env PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

In [None]:
# CoreNLP OpenIE options handed to the StanfordOpenIE wrapper.
# NOTE(review): non-strict triples presumably yield more (overlapping)
# extractions — confirm against the CoreNLP OpenIE documentation.
properties = dict()
properties["openie.affinity_probability_cap"] = 2 / 3
properties["openie.triple.strict"] = False
IEclient = StanfordOpenIE(properties=properties)

In [None]:
# GPT-2 scores candidate triples by perplexity (used via the global `model` /
# `tokenizer` in compute_ppl). Use the GPU when one is present and fall back to
# CPU otherwise, instead of failing on a hard-coded .to("cuda").
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = model.to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [None]:
def compute_ppl(predictions, batch_size: int = 16, add_start_token: bool = True, max_length=512,
                lm=None, lm_tokenizer=None):
    """Compute the perplexity of each text in ``predictions`` under a causal LM.

    Args:
        predictions: texts to score (a list of strings).
        batch_size: number of texts per forward pass.
        add_start_token: prepend the tokenizer's BOS token to every text so the
            first real token is also scored.
        max_length: maximum sequence length including the BOS token (when
            ``add_start_token`` is set); a falsy value disables truncation.
        lm: causal language model to use; defaults to the global ``model``.
        lm_tokenizer: matching tokenizer; defaults to the global ``tokenizer``.
            If it has no pad token and ``batch_size > 1``, its first special
            token is registered as the pad token (the tokenizer is mutated).

    Returns:
        dict with ``"perplexities"`` (one float per input, in input order) and
        ``"mean_perplexity"`` (their mean; NaN when ``predictions`` is empty).

    Raises:
        AssertionError: if padding or BOS requirements cannot be met, or an
            input is too short to be scored.
    """
    if lm is None:
        lm = model
    if lm_tokenizer is None:
        lm_tokenizer = tokenizer

    # Nothing to score: avoid tokenizing an empty batch and np.mean([]) warnings.
    if not isinstance(predictions, str) and len(predictions) == 0:
        return {"perplexities": [], "mean_perplexity": float("nan")}

    # Send inputs to wherever the model's weights live instead of guessing from
    # CUDA availability (the two can disagree, e.g. a CPU model on a GPU box).
    try:
        device = next(lm.parameters()).device
    except (AttributeError, StopIteration):
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Batched inputs need padding; if the tokenizer has no pad token, reuse an
    # existing special token as the pad token.
    if lm_tokenizer.pad_token is None and batch_size > 1:
        existing_special_tokens = list(lm_tokenizer.special_tokens_map_extended.values())
        assert (
            len(existing_special_tokens) > 0
        ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
        lm_tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

    if add_start_token and max_length:
        # Leave room for the BOS token that is prepended below.
        assert (
            lm_tokenizer.bos_token is not None
        ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    encodings = lm_tokenizer(
        predictions,
        add_special_tokens=False,
        padding=True,
        truncation=True if max_tokenized_len else False,
        max_length=max_tokenized_len,
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    encoded_texts = encodings["input_ids"]
    attn_masks = encodings["attention_mask"]

    # Every text needs at least one scored position (two without a BOS token).
    if add_start_token:
        assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
    else:
        assert torch.all(
            torch.ge(attn_masks.sum(1), 2)
        ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

    ppls = []
    loss_fct = CrossEntropyLoss(reduction="none")

    for start_index in range(0, len(encoded_texts), batch_size):
        # Slicing past the end is safe; the last batch is simply shorter.
        encoded_batch = encoded_texts[start_index:start_index + batch_size]
        attn_mask = attn_masks[start_index:start_index + batch_size]

        if add_start_token:
            n_rows = encoded_batch.size(dim=0)
            bos_tokens = torch.full(
                (n_rows, 1), lm_tokenizer.bos_token_id, dtype=encoded_batch.dtype, device=encoded_batch.device
            )
            encoded_batch = torch.cat([bos_tokens, encoded_batch], dim=1)
            attn_mask = torch.cat(
                [torch.ones((n_rows, 1), dtype=attn_mask.dtype, device=attn_mask.device), attn_mask], dim=1
            )

        labels = encoded_batch

        with torch.no_grad():
            out_logits = lm(encoded_batch, attention_mask=attn_mask).logits

        # Position t predicts token t+1: drop the last logit and the first label.
        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        # exp(mean token negative log-likelihood over non-padding positions), per row.
        perplexity_batch = torch.exp(
            (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
            / shift_attention_mask_batch.sum(1)
        )

        ppls += perplexity_batch.tolist()

    return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}

In [None]:
def _unique_triple_strings(text):
    """Return OpenIE triples for ``text`` as ``"subject | relation | object"`` strings.

    Underscores are replaced by spaces *before* de-duplication, so triples that
    differ only in ``_`` vs space collapse to one entry; first-seen order is kept.
    """
    unique = {}
    for triple in IEclient.annotate(text):
        # Field order is whatever the wrapper's dict yields; presumably
        # subject, relation, object — TODO confirm against the openie package.
        fields = tuple(triple.values())
        line = "{} | {} | {}".format(fields[0], fields[1], fields[2]).replace("_", " ")
        unique.setdefault(line, None)
    return list(unique)


def _near_duplicate_clusters(lines, threshold=0.4, num_perm=128, keep_singletons=False):
    """Greedily group near-duplicate strings with MinHash LSH over character 3-grams.

    Strings with the most LSH neighbours seed clusters first. Each uncovered
    seed forms a cluster with itself plus all of its neighbours (a neighbour
    may already belong to an earlier cluster), and every member is marked as
    covered so it cannot seed another cluster. Strings with no neighbours are
    skipped unless ``keep_singletons`` is true.
    """
    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
    signatures = []
    for i, line in enumerate(lines):
        signature = MinHash(num_perm=num_perm)
        for gram in ngrams(line, 3):
            signature.update("".join(gram).encode("utf-8"))
        lsh.insert(i, signature)
        signatures.append(signature)

    neighbours = [[j for j in lsh.query(sig) if j != i] for i, sig in enumerate(signatures)]

    clusters = []
    covered = set()
    # sorted() is stable even with reverse=True, so ties keep extraction order.
    for i in sorted(range(len(lines)), key=lambda k: len(neighbours[k]), reverse=True):
        if i in covered:
            continue
        if not neighbours[i] and not keep_singletons:
            continue
        members = [i] + neighbours[i]
        clusters.append([lines[k] for k in members])
        covered.update(members)
    return clusters


def get_triples_ie(text, threshold=0.4, num_perm=128, keep_singletons=False):
    """Extract OpenIE triples from ``text`` and keep one representative per near-duplicate group.

    Triples are rendered as ``"subject | relation | object"`` strings, grouped by
    MinHash-LSH similarity (``threshold``, ``num_perm``), and within each group the
    member with the lowest GPT-2 perplexity (see ``compute_ppl``) is kept.

    Args:
        text: document to run OpenIE on.
        threshold: Jaccard threshold for MinHashLSH.
        num_perm: number of MinHash permutations.
        keep_singletons: also keep triples with no near-duplicate. The default
            (False) keeps the original behaviour of discarding them.
            NOTE(review): that default may be unintentional — TODO confirm.

    Returns:
        Selected triple strings, in the order OpenIE produced them.
    """
    lines = _unique_triple_strings(text)
    clusters = _near_duplicate_clusters(lines, threshold, num_perm, keep_singletons)

    # Each candidate is scored once even if it appears in several clusters.
    candidates = list(dict.fromkeys(line for cluster in clusters for line in cluster))
    if not candidates:
        return []
    scores = compute_ppl(
        predictions=[c.replace(" | ", " ") for c in candidates], batch_size=128
    )["perplexities"]
    ppl = dict(zip(candidates, scores))

    # min() returns the first minimum, matching np.argmin's tie-breaking.
    best = {min(cluster, key=ppl.__getitem__) for cluster in clusters}
    return [line for line in lines if line in best]

In [None]:
# jina-embeddings-v3 presumably ships custom modelling code (it exposes
# .encode(task=..., truncate_dim=...)), hence trust_remote_code=True.
# NOTE(review): this executes code downloaded from the Hub — consider pinning a revision.
# Fall back to CPU when CUDA is unavailable instead of failing on .to("cuda").
embedding_model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v3", trust_remote_code=True
).to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Connect to a local Weaviate instance using the client's default connection
# settings (no arguments passed). Must have Weaviate installed and running!
client = weaviate.connect_to_local()

In [None]:
# Create the "Triples" collection on first run; later runs reuse it, so this
# cell is safe under Restart & Run All. Vectors are supplied client-side by the
# embedding model when inserting, so no server-side vectorizer is configured.
if not client.collections.exists("Triples"):
    client.collections.create(
        name="Triples",
        properties=[
            wc.Property(name="text", data_type=wc.DataType.TEXT),
        ],
        vectorizer_config=wc.Configure.Vectorizer.none(),
    )

In [None]:
# Handle to the "Triples" collection that the loading loop inserts into.
# NOTE(review): assumes the collection already exists on the server.
triples = client.collections.get("Triples")

In [None]:
# FineWeb "sample-10BT" train split. With streaming=True the rows are read
# lazily while iterating rather than downloaded up front; the loading loop
# only uses each row's "text" and "language" fields.
dataset = load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train", streaming=True)

In [None]:
# Stream documents, extract triples in batches, embed them and insert them
# into the "Triples" collection.
MAX_DOCS = 10_000_000    # stop after this many accepted documents
MAX_CHARS = 10_000       # skip documents longer than this many characters
DOCS_PER_BATCH = 1_000   # documents per extract / embed / insert round
LOG_EVERY = 100_000      # progress message interval, in accepted documents


def _flush_documents(texts):
    """Extract triples from ``texts``, embed them and add them to Weaviate.

    Returns the number of triples submitted (0 if none were extracted, in which
    case the embedding model and Weaviate are not called at all).
    NOTE(review): failed batch objects are not inspected here.
    """
    print("Extracting triples...")
    tri = [t for text in tqdm(texts) for t in get_triples_ie(text)]
    print("Inserting {} triples...".format(len(tri)))
    if not tri:
        return 0

    embeddings = embedding_model.encode(tri, task="text-matching", truncate_dim=32, max_length=64)
    with triples.batch.dynamic() as batch:
        for text, vector in zip(tri, embeddings):
            obj = {"text": text}
            # The UUID is derived from the object itself, so identical triple text
            # maps to the same UUID (presumably upserted, not duplicated — confirm).
            batch.add_object(
                properties=obj,
                uuid=generate_uuid5(obj),
                vector=vector,
            )
    return len(tri)


all_data = []
idx = 0    # accepted documents so far (filtered-out rows are not counted)
total = 0  # triples submitted to Weaviate so far
for row in dataset:
    if idx == MAX_DOCS:
        break
    if len(row["text"]) > MAX_CHARS or row["language"] != "en":
        continue

    all_data.append(row["text"])
    idx += 1
    # Logged once per milestone (the old check re-printed for every skipped row).
    if idx % LOG_EVERY == 0:
        print("{} rows processed.".format(idx))

    if len(all_data) >= DOCS_PER_BATCH:
        total += _flush_documents(all_data)
        print("Finished. Total: {}".format(total))
        all_data = []

# Flush the final partial batch, if any.
if all_data:
    total += _flush_documents(all_data)
    print("Finished. Total: {}".format(total))
    all_data = []

In [None]:
# Release the Weaviate connection opened by weaviate.connect_to_local().
client.close()