In [7]:
import json

import chromadb.utils.embedding_functions as embedding_functions
from chromadb import Client
from chromadb.config import Settings
from llama_index.core import Document

# Load the threat catalogue from mitre_fight.json. The encoding is pinned so the
# file is decoded the same way on every platform instead of using the locale
# default.
with open('mitre_fight.json', encoding='utf-8') as f:
    raw_data = json.load(f)

# Sub-sections appended to each threat document, in output order:
# (key in the threat entry, label written into the text, field used as the title).
# Detection entries are titled by their 'id'; every other section uses 'name'.
DOC_SECTIONS = [
    ("Procedure Examples", "Procedure Example", "name"),
    ("Detection", "Detection", "id"),
    ("Critical Assets", "Critical Asset", "name"),
    ("Pre-Conditions", "Pre-condition", "name"),
    ("Post-Conditions", "Post-condition", "name"),
    ("Mitigations", "Mitigation", "name"),
    # ("References", "Reference", "name"),  # currently excluded from the text
]

# Flatten each threat entry into one text document plus a small metadata dict.
# `docs` and `metadatas` stay index-aligned: entry i of each describes the same threat.
docs = []
metadatas = []

for threat_id, content in raw_data.items():
    # Header lines that every threat document starts with.
    parts = [
        f"ID: {threat_id}",
        f"Name: {content.get('Name')}",
        f"Description: {content.get('Description')}",
        f"Platform: {content.get('Platform')}",
        f"Tactics: {content.get('Tactics')}",
    ]

    for section_key, label, title_field in DOC_SECTIONS:
        # `or []` also covers a section that is present but null/empty, which
        # would otherwise raise when iterated.
        for entry in content.get(section_key) or []:
            parts.append(
                f"{label} - {entry.get(title_field)}: {entry.get('description')}"
            )

    docs.append("\n".join(parts))
    metadatas.append({"id": threat_id, "name": content.get("Name")})

# Wrap the flattened texts as LlamaIndex documents. This runs once, after the
# loop: building it inside the loop recreated the whole list for every threat
# and left the name undefined when the JSON contained no entries.
llamaindex_docs = [
    Document(text=text, metadata=meta)
    for text, meta in zip(docs, metadatas)
]


In [9]:
import chromadb
from llama_index.core import Document, VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core.storage.storage_context import StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Embedding model handed to the index for vectorising the documents.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en")

# Chroma database stored on disk under 'mitre_chroma'; the 'mitre' collection
# therefore survives kernel restarts and re-runs of this cell.
chroma_client = chromadb.PersistentClient(path='mitre_chroma')
chroma_collection = chroma_client.get_or_create_collection('mitre')
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

if chroma_collection.count() == 0:
    # First run: embed every document and write it into the collection.
    index = VectorStoreIndex.from_documents(
        llamaindex_docs,
        storage_context=storage_context,
        embed_model=embed_model,
    )
else:
    # The collection already holds vectors from an earlier run. Calling
    # from_documents again would insert every document a second time and make
    # the retriever return duplicates, so attach to the stored vectors instead.
    # NOTE: remove the 'mitre_chroma' directory to force a rebuild after the
    # source JSON changes.
    index = VectorStoreIndex.from_vector_store(
        vector_store,
        embed_model=embed_model,
    )

In [21]:
# Free-text description of the behaviour to look up in the index.
# Alternative query kept for reference:
# query = 'how do I mitigate handling of NAS counter values typically used in generating/verifying the message authentication codes (MAC) for replay protection of NAS layer messages'
query = 'Rogue AF/NEF modifies UE\u2019s configuration for a given external service'

# Retriever that returns the 4 most similar nodes for a query.
retriever = index.as_retriever(
    similarity_top_k=4,
)

# Scored nodes for the query; printed in the next cell.
res = retriever.retrieve(query)

In [22]:
# Show each retrieved node's metadata followed by its similarity score
# formatted to four decimal places.
for hit in res:
    score_text = f"-- {hit.score:.4f}"
    print(hit.metadata, score_text)

{'id': 'FGT5022', 'name': 'Alter Subscriber Profile'} -- 0.7666
{'id': 'FGT5008', 'name': 'Redirection of traffic via user plane network function'} -- 0.7471
{'id': 'FGT1600.501', 'name': 'Radio Interface'} -- 0.7333
{'id': 'FGT1608.502', 'name': 'Configure Operator Core Network'} -- 0.7189
