## Libraries

In [None]:
!pip install langchain langchain_core langchain_community langchain-huggingface torch accelerate bitsandbytes docarray unstructured jq openpyxl

## Environment Variables and Constants

Set the API keys and environment variables required for running the app

In [2]:
import os
import warnings

# SECURITY: API credentials must never be hard-coded in a notebook. The tokens
# that used to be embedded in this cell are exposed to everyone who can read
# the file (and its history) and must be revoked and re-issued. Supply the
# secrets through the process environment instead, e.g. a shell `export`, the
# batch job script, or an untracked .env file loaded before the kernel starts.
_SECRET_ENV_VARS = (
    "HUGGINGFACEHUB_API_TOKEN",
    "LANGCHAIN_API_KEY",
    "OPENAI_API_KEY",
)
_missing_secrets = [name for name in _SECRET_ENV_VARS if not os.environ.get(name)]
if _missing_secrets:
    # Warn instead of raising so the cell still runs when a given service
    # is not used in the current session.
    warnings.warn(
        "Missing credential environment variables: "
        + ", ".join(_missing_secrets)
        + ". Export them before starting the notebook if the corresponding "
        "service is needed.",
        stacklevel=2,
    )

# Point the Hugging Face cache and home directory at the project scratch area.
os.environ["HF_HUB_CACHE"] = "/scratch/project_2013047/models"
os.environ["HF_HOME"] = "/scratch/project_2013047/models"


# Prompt template for the RAG chain. It has three placeholders, filled by the
# chain built in `setup_rag_chain`:
#   {user}            - the scenario text passed to `chain.invoke`
#   {scenarios}       - the similar scenarios returned by the retriever
#   {countermeasures} - the countermeasures looked up for those scenarios
# NOTE(review): the prompt text contains typos ("from you retrieval",
# "reasoining"). They are left as-is here because the string is sent to the
# model verbatim; fix them deliberately if the prompt is revised.
TEMPLATE = """
You are an assistant in security risk analysis.
You will be provided with risk scenarios that have certain threats and vulnerabilities. For the threats you will also be provided with possible counter measures.
You will be provided with a user scenario and based on that you will be provided with context of related scenarios from you retrieval vector store.
You will also be provided with possible countermeasures for all similar scenarios.
You need to suggest the appropriate counter measure for the user scenario and give a reasoining as to why it is appropriate.
Answer the question based only on the following context. If the question does not relate with the context, just reply 'I don't know'

User: {user}

Scenarios: {scenarios}

countermeasures: {countermeasures}
"""


## Requirements

Import the modules used throughout the notebook (the packages themselves are installed in the Libraries section above).

In [4]:
import json
import os

import pandas as pd
import torch
from langchain_community.document_loaders import JSONLoader
from langchain_community.vectorstores import DocArrayInMemorySearch
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
                                   HuggingFacePipeline)
from langchain_core.documents.base import Document
from pandas import DataFrame
from transformers import BitsAndBytesConfig
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.base import RunnableLambda
from langchain_community.document_loaders import UnstructuredExcelLoader

## Utility Functions

Utility functions for operations such as loading data, formatting data, etc.

In [47]:
def load_excel_to_dataframe(file_path: str, header=0, index_col=0, reset_index=False, cols: list = None) -> DataFrame:
    """
    Read an Excel sheet into a flat pandas dataframe.

    The sheet is read with `pd.read_excel`; the index built from
    `index_col` is then moved back into ordinary columns, so the result
    always carries a default integer index.

    Args:
        * file_path (str): Path to the Excel file
        * header (int): Row to use for the column labels
        * index_col: Column(s) to read as the index; pass a list when the
          leading columns of the sheet form a multi-level row hierarchy
        * reset_index (bool): Accepted but not consulted; the index is
          reset unconditionally
        * cols (list | None): Replacement column names, assigned after the
          index has been reset. Skipped when empty or None
    Returns:
        A pandas `DataFrame` object with the loaded data
    """

    frame = pd.read_excel(file_path, header=header, index_col=index_col).reset_index()

    # Nothing to rename: hand the frame back with the sheet's own labels.
    if not cols:
        return frame

    frame.columns = cols
    return frame

def convert_df_to_dict(df: DataFrame, save_path=None) -> list[dict]:
    """
    Convert a pandas `DataFrame` into a list of JSON-compatible row dicts.

    The frame is serialised with `DataFrame.to_json(orient="records")` and
    parsed back with `json.loads`, so every row becomes one dictionary keyed
    by column name and missing values (NaN/None) come back as `None`.

    Args:
        * df (DataFrame): A pandas `DataFrame` object with the required data
        * save_path (str | None): If given (and non-empty), the records are
          also written to this path as JSON
    Returns:
        A list with one dictionary per row of `df` (empty list for an
        empty frame)
    """

    # Round-tripping through JSON guarantees plain Python types in the result.
    records = json.loads(df.to_json(orient="records"))

    if save_path:
        # Explicit encoding so the file does not depend on the platform default.
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(records, f)

    return records

def set_prompt(template) -> ChatPromptTemplate:
    """
    Build the chat prompt that is fed to the model.

    Args:
        * template (str): Prompt text containing `{placeholder}` fields

    Returns:
        The `ChatPromptTemplate` created from `template`
    """

    return ChatPromptTemplate.from_template(template)

def load_lrm_model_from_hf(model_id) -> ChatHuggingFace:
    """
    Load a Hugging Face text-generation model with 4-bit quantization and
    wrap it as a `ChatHuggingFace` chat model.

    Args:
        model_id (str): the Hugging Face model id to load

    Returns:
        A `ChatHuggingFace` model backed by a `HuggingFacePipeline`
    """

    four_bit_config = BitsAndBytesConfig(load_in_4bit=True)

    # Generation settings handed to the text-generation pipeline:
    # sampling is disabled and only the newly generated text is returned.
    generation_kwargs = {
        "max_new_tokens": 1024,
        "do_sample": False,
        "repetition_penalty": 1.03,
        "return_full_text": False,
    }

    pipeline_llm = HuggingFacePipeline.from_model_id(
        model_id=model_id,
        task="text-generation",
        pipeline_kwargs=generation_kwargs,
        model_kwargs={"quantization_config": four_bit_config},
        device_map="auto",
    )

    return ChatHuggingFace(llm=pipeline_llm)

def create_retrievar_from_vector_store(docs: list):
    """
    Build a retriever backed by an in-memory vector store.

    The documents are embedded with the "all-MiniLM-L6-v2" Hugging Face
    embedding model, stored in a `DocArrayInMemorySearch` vector store and
    exposed through that store's retriever.

    Args:
        docs (list): A list of `Documents` to embed and index.

    Returns:
        The retriever returned by the vector store's `as_retriever()`.
    """

    embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    store = DocArrayInMemorySearch.from_documents(docs, embedding_model)
    return store.as_retriever()

def remediation_lookup(scenario_docs: list[Document]) -> list[dict]:
    """
    Collect the countermeasures matching a list of scenario documents.

    Each document's `page_content` is parsed as JSON. Its `risk_id` and
    `vuln_id` are stripped of surrounding whitespace and looked up in the
    module-level `remediations_df` (columns `threat_id` and `vuln_id`);
    every matching row is added to the result.

    Documents missing either id are skipped, because a countermeasure can
    only be matched on the threat/vulnerability pair.

    Args:
        scenario_docs (list[Document]): Scenario documents whose
            `page_content` is a JSON object; `None` or an empty list
            yields an empty result.

    Returns:
        A list of dictionaries, one per matching row of `remediations_df`.
        Rows matched by several scenarios appear once per scenario.
    """

    countermeasures = []

    for doc in scenario_docs or []:
        # json.loads tolerates leading/trailing whitespace in page_content.
        scen_dict = json.loads(doc.page_content)

        threat_id = scen_dict.get("risk_id")
        vuln_id = scen_dict.get("vuln_id")
        if threat_id is None or vuln_id is None:
            continue

        # Cell values may carry stray spaces (e.g. ' M3').
        threat_id = threat_id.strip()
        vuln_id = vuln_id.strip()

        matches = remediations_df[
            (remediations_df["threat_id"] == threat_id)
            & (remediations_df["vuln_id"] == vuln_id)
        ]
        countermeasures.extend(convert_df_to_dict(matches))

    return countermeasures

def setup_rag_chain(model_id, template):
    """
    Build the retrieval-augmented generation (RAG) chain.

    Steps:
        1. Load "./data/Scenarios.xlsx", dump it to "./data/scen.json" and
           index the resulting documents in a vector-store retriever.
        2. Build the prompt from `template` and load the model `model_id`.
        3. Assemble the chain: the input text is passed through as `user`,
           the retriever supplies `scenarios`, `remediation_lookup` derives
           `countermeasures` from those same scenarios, and the filled
           prompt is sent to the model and parsed to a string.

    The retriever runs once per invocation; the countermeasure lookup reuses
    its output instead of running a second, identical search.

    `remediation_lookup` reads the module-level `remediations_df`, so that
    dataframe must exist before the chain is invoked.

    Args:
        model_id (str): Hugging Face model id passed to
            `load_lrm_model_from_hf`
        template (str): Prompt template with `{user}`, `{scenarios}` and
            `{countermeasures}` placeholders

    Returns:
        The runnable chain; invoke it with the user's scenario text to get
        the model's answer as a string.
    """

    scenarios_df = load_excel_to_dataframe("./data/Scenarios.xlsx", cols=["scen_id", "scen", "extended", "short", "details", "risk_id", "risk_desc", "vuln_id", "vuln_desc", "risk_occud_type"])
    # Called for its side effect: writes the JSON file the loader reads below.
    convert_df_to_dict(scenarios_df, save_path="./data/scen.json")
    loader_scen = JSONLoader(file_path="./data/scen.json", jq_schema='.[]', text_content=False)
    scenarios = create_retrievar_from_vector_store(loader_scen.load())

    prompt = set_prompt(template)
    model = load_lrm_model_from_hf(model_id=model_id)
    output_parser = StrOutputParser()

    chain = (
        {
            "user": RunnablePassthrough(),
            "scenarios": scenarios,
        }
        # Derive the countermeasures from the scenarios retrieved above.
        | RunnablePassthrough.assign(
            countermeasures=lambda inputs: remediation_lookup(inputs["scenarios"])
        )
        | prompt
        | model
        | output_parser
    )

    return chain
    
    
    
    

## RAG Workflow

The RAG workflow starts here.

In [None]:
# Load the remediation workbook. The header is taken from row index 2 and the
# first five columns are read as a multi-level index, which the loader then
# flattens back into ordinary columns. `remediation_lookup` reads this
# dataframe as a module-level name.
remediations_df = load_excel_to_dataframe("./data/Remediations.xlsx", header=2, index_col=[0,1,2,3,4], cols=["threat_id", "threat_desc", "vuln_id", "vuln_desc", "vthe", "remediation_id", "remediation_desc", "tech_nature"])
# remediations_df = pd.read_excel("./data/Remediations.xlsx", header=2, index_col=[0,1,2,3,4])

# TODO: fix the issue where some values are still NaN even after
# resetting the index. Until then, the threat id of rows 1360-1418 is
# patched by hand.
for i in range(1360, 1419):
    remediations_df.loc[i, "threat_id"] = "M26"
# NOTE(review): `setup_rag_chain` repeats the scenario loading, JSON export
# and retriever creation below; the module-level copies are not passed to it.
scenarios_df = load_excel_to_dataframe("./data/Scenarios.xlsx", cols=["scen_id", "scen", "extended", "short", "details", "risk_id", "risk_desc", "vuln_id", "vuln_desc", "risk_occud_type"])

# Convert the dataframe to JSON records and save them to disk.
# TODO: strip spaces from cells, e.g. risk_id = ' M3'
scenarios_dict = convert_df_to_dict(scenarios_df, save_path="./data/scen.json")
# remediations_dict = convert_df_to_dict(remediations_df, save_path="./data/rem.json")

# Convert the JSON data to LangChain `Document` objects (one per record).
loader_scen = JSONLoader(file_path="./data/scen.json",jq_schema='.[]', text_content=False)
# loader_rem = JSONLoader(file_path="./data/rem.json",jq_schema='.[]', text_content=False)

scenarios_doc = loader_scen.load()
# remediations_doc = loader_rem.load()

# Embed the documents into the vector store and get the retriever.
scenarios = create_retrievar_from_vector_store(scenarios_doc)
# remediations = create_retrievar_from_vector_store(remediations_doc)


In [None]:
# Build the RAG chain. `setup_rag_chain` has two required parameters, so the
# model id and the prompt template must be passed explicitly; calling it
# without arguments raises TypeError.
chain = setup_rag_chain(model_id="O1-OPEN/OpenO1-LLama-8B-v0.1", template=TEMPLATE)

In [None]:
# Ask the chain for a countermeasure recommendation for one example scenario.
a = chain.invoke("Only authorized users can open the cabinets containing classified documents and no tracking is required.")

In [None]:
# Show the answer returned by the chain.
print(a)