In [1]:
import boto3
from botocore.exceptions import ClientError
from bs4 import BeautifulSoup
import chromadb
from collections import Counter
from collections import defaultdict
import json
from langchain_text_splitters import HTMLHeaderTextSplitter, HTMLSectionSplitter
import os
from pathlib import Path
from pinecone import Pinecone
import sqlite3

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_threat_feeds_secret(secret_name="threat-feeds-secrets", region_name="us-east-2"):
    """Fetch the threat-feeds secret from AWS Secrets Manager and decode it as JSON.

    Args:
        secret_name: Secrets Manager secret id. Defaults to "threat-feeds-secrets".
        region_name: AWS region that hosts the secret. Defaults to "us-east-2".

    Returns:
        The value of ``json.loads`` on the secret's ``SecretString``; callers
        index it by key (e.g. ``"pinecone-api-key"``), so presumably a dict —
        verify against the stored secret.

    Raises:
        botocore.exceptions.ClientError: when Secrets Manager rejects the request;
            for the possible error codes see
            https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html
        json.JSONDecodeError: when the secret string is not valid JSON.
    """
    client = boto3.session.Session().client(
        service_name="secretsmanager",
        region_name=region_name,
    )
    # ClientError is deliberately left to propagate to the caller unchanged.
    response = client.get_secret_value(SecretId=secret_name)
    return json.loads(response["SecretString"])

In [3]:
def chunk_document(splitter, doc):
    """Split an HTML document and merge fragments that share the same header path.

    Fragments returned by ``splitter.split_text`` whose header metadata is
    identical are concatenated into a single chunk.

    Args:
        splitter: object whose ``split_text(str)`` returns documents exposing
            ``metadata`` (a mapping; for ``HTMLHeaderTextSplitter`` it maps a
            header label to the header text) and ``page_content``.
        doc: the raw HTML text to split.

    Returns:
        list[str]: one chunk per distinct header path, in order of first
        appearance, with fragments joined by a blank line. Empty when the
        splitter returns no fragments.
    """
    grouped = defaultdict(list)
    for fragment in splitter.split_text(doc):
        # Key on the (label, text) pairs rather than a " > "-joined string of the
        # values, so that e.g. an h1 "X" and a lone h2 "X", or a header that
        # itself contains " > ", are not merged into the same chunk.
        grouped[tuple(fragment.metadata.items())].append(fragment.page_content)

    return ["\n\n".join(parts) for parts in grouped.values()]

In [4]:
def persist_documents(pagedata_dir="pagedata", vector_path="pagevector"):
    """Chunk every saved report page and upsert the chunks into Chroma.

    Pages are read from ``<pagedata_dir>/<source>/<page file>``. Each page is
    split on h1/h2 headers via ``chunk_document`` and its chunks are upserted
    into the persistent ``reports`` collection with ids
    ``"<chunk index>:<report_id>"``. Here ``report_id`` is the page file name
    without its extension.

    Args:
        pagedata_dir: root directory holding one sub-directory per source.
            Defaults to "pagedata".
        vector_path: directory of the persistent Chroma database.
            Defaults to "pagevector".

    Raises:
        Exception: wrapping (and chained from) any error raised by
            ``collection.upsert`` for a page.
    """
    crdb = chromadb.PersistentClient(path=vector_path)
    collection = crdb.get_or_create_collection(name="reports", metadata={
        "hnsw:M": 32,
        "hnsw:search_ef": 100,
    })

    headers_to_split_on = [("h1", "Main Topic"), ("h2", "Sub Topic")]
    splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)

    pagedata = Path(pagedata_dir)
    # Only sub-directories are sources; stray top-level files (e.g. .DS_Store)
    # would otherwise crash the listing. Sorted for a deterministic order.
    for source_dir in sorted(p for p in pagedata.iterdir() if p.is_dir()):
        source = source_dir.name
        print(f"Indexing documents for {source}")
        for page_path in sorted(p for p in source_dir.iterdir() if p.is_file()):
            report_id = page_path.stem

            with open(page_path) as f:
                documents = chunk_document(splitter, f.read())

            # chromadb rejects an upsert with an empty id list, and a page
            # without chunks has nothing to index anyway.
            if not documents:
                continue

            ids = [f"{i}:{report_id}" for i in range(len(documents))]
            try:
                collection.upsert(ids=ids, documents=documents)
            except Exception as e:
                raise Exception(f"error upserting for {source} {report_id}: {e}") from e

        print(f"Finished indexing documents for {source}. Total {collection.count()} documents")

    print(f"Indexed {collection.count()} documents")

In [5]:
def compute_similarities(sq_conn, n_results=5, top_k=10):
    """Store, for every report, the ids of the other reports most similar to it.

    Each report's HTML page is chunked with the same h1/h2 splitter used for
    indexing. Every chunk is then queried against the Chroma ``reports``
    collection. The other reports whose chunks occur most often among the hits
    are written, comma-separated, to ``report.related_report_ids``.

    Args:
        sq_conn: open ``sqlite3.Connection`` whose ``report`` table has ``id``,
            ``source`` and ``related_report_ids`` columns.
        n_results: nearest chunks requested from Chroma per query chunk.
            Defaults to 5.
        top_k: maximum number of related report ids stored per report.
            Defaults to 10.

    Raises:
        Exception: if an UPDATE does not affect exactly one row. The exception is
            raised inside the connection's context manager, so that UPDATE is
            rolled back.
    """
    update_cur = sq_conn.cursor()

    crdb = chromadb.PersistentClient(path="pagevector")
    collection = crdb.get_collection("reports")

    headers_to_split_on = [("h1", "Main Topic"), ("h2", "Sub Topic")]
    splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)

    # Materialize the rows before writing. The loop commits UPDATEs to this
    # same table, and SQLite leaves it undefined whether a SELECT that is still
    # being stepped on the same connection observes those changes.
    reports = sq_conn.execute("SELECT id, source FROM report").fetchall()

    count = 0
    for report_id, source in reports:
        pagepath = os.path.join("pagedata", source, f"{report_id}.html")
        with open(pagepath) as f:
            chunks = chunk_document(splitter, f.read())

        top_match_ids = []
        # Skip Chroma for a page with no chunks: an empty query batch has nothing
        # to match (and chromadb's input validation may reject it).
        if chunks:
            results = collection.query(query_texts=chunks, n_results=n_results)
            # Chunk ids are "<chunk index>:<report_id>". Split only once so that
            # report ids which themselves contain ":" are kept whole.
            hit_counts = Counter(
                chunk_id.split(":", 1)[1]
                for id_set in results["ids"]
                for chunk_id in id_set
            )
            # report.id may come back as an int depending on the column type,
            # while chunk ids carry it as text; compare as strings so a report's
            # own chunks are really excluded.
            hit_counts.pop(str(report_id), None)
            top_match_ids = [rid for rid, _ in hit_counts.most_common(top_k)]

        with sq_conn:
            update_cur.execute("""
                UPDATE report SET
                    related_report_ids = ?
                WHERE id = ?
            """, (",".join(top_match_ids), report_id))

            if update_cur.rowcount != 1:
                raise Exception(f"update_cur.rowcount = {update_cur.rowcount}")

        count += 1

    print(f"Computed similarities for {count} rows")

In [9]:
def upload_to_pinecone(sq_conn):
    """Upload each report's parsed text to the Pinecone assistant, at most once.

    The function first applies ``sqlite_schema.sql`` on a best-effort basis.
    It then goes through every row of ``report``:

    - If ``pinecone_reports`` already records a file id for the report, the
      report is skipped.
    - Otherwise ``parseddata/<source>/<id>.txt`` is uploaded and the returned
      Pinecone file id is recorded in ``pinecone_reports``.

    Args:
        sq_conn: open ``sqlite3.Connection`` to the reports database.

    Raises:
        Exception: if recording an upload does not insert exactly one row.
    """
    # The schema script is re-applied on every run. Database errors (presumably
    # "table already exists" on re-runs — confirm against sqlite_schema.sql)
    # are reported but not fatal. An unreadable script file is not hidden.
    with open("sqlite_schema.sql") as f:
        schema_sql = f.read()
    try:
        with sq_conn:
            sq_conn.executescript(schema_sql)
    except sqlite3.Error as e:
        print(f"error applying schema: {e}")

    update_cur = sq_conn.cursor()

    pinecone_api_key = get_threat_feeds_secret()['pinecone-api-key']
    pc = Pinecone(api_key=pinecone_api_key)
    assistant = pc.assistant.Assistant(assistant_name="threat-feeds-assistant")

    # Fetch every report up front so that the per-report commits below never
    # interleave with a still-open SELECT cursor on the same connection.
    reports = sq_conn.execute("SELECT id, source FROM report").fetchall()

    count = 0
    for report_id, source in reports:
        already_uploaded = sq_conn.execute(
            "SELECT pinecone_file_id FROM pinecone_reports WHERE report_id = ?",
            (report_id,),
        ).fetchone()

        if already_uploaded:
            count += 1
            print(f"Skipped pinecone report {report_id} ({count})")
            continue

        pagepath = os.path.join("parseddata", source, f"{report_id}.txt")

        upload_resp = assistant.upload_file(
            file_path=pagepath,
            metadata={"report_id": report_id},
            timeout=None,
        )

        with sq_conn:
            # Name the columns instead of relying on the table's column order.
            update_cur.execute(
                "INSERT INTO pinecone_reports (report_id, pinecone_file_id) VALUES (?, ?)",
                (report_id, upload_resp.id),
            )

            if update_cur.rowcount != 1:
                raise Exception(f"update_cur.rowcount = {update_cur.rowcount}")

        count += 1
        print(f"Uploaded pinecone report {report_id} ({count})")

In [None]:
if __name__ == '__main__':
    # Pipeline stages to run in this session. Flip a flag instead of
    # commenting calls in and out.
    RUN_PERSIST_DOCUMENTS = False     # rebuild the Chroma index from pagedata/
    RUN_COMPUTE_SIMILARITIES = False  # refresh report.related_report_ids
    RUN_UPLOAD_TO_PINECONE = True     # push parsed reports to the assistant

    if RUN_PERSIST_DOCUMENTS:
        persist_documents()

    sq_conn = sqlite3.connect("reports.db")
    try:
        if RUN_COMPUTE_SIMILARITIES:
            compute_similarities(sq_conn)
        if RUN_UPLOAD_TO_PINECONE:
            upload_to_pinecone(sq_conn)
    finally:
        # sqlite3's connection context manager only commits/rolls back,
        # so the connection is closed explicitly.
        sq_conn.close()