# Cybereason AI Engineer Home Assignment

## Background

In this assignment, you will receive a dataset of alerts from different vendors and of different types. Our task is to classify those alerts with MITRE ATT&CK techniques or sub-techniques (for more information about MITRE ATT&CK, please visit https://attack.mitre.org/). <br/>
You don't need to understand all MITRE ATT&CK TTPs, but you need to understand the concept of MITRE ATT&CK and how to use it to classify alerts. <br/>
We will focus on a few MITRE ATT&CK techniques and sub-techniques.
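
To make the technique versus sub-technique distinction concrete, here is a minimal sketch that splits an ATT&CK ID into its parts. ATT&CK technique IDs have the form `T1059`, and sub-technique IDs add a three-digit suffix such as `T1059.001`. The helper name `parse_attack_id` is hypothetical, and it is an assumption that the `code` column of the dataset uses this ID format.

```python
import re

# Technique IDs look like "T1059"; sub-technique IDs append ".NNN", e.g. "T1059.001".
ATTACK_ID = re.compile(r"^T(\d{4})(?:\.(\d{3}))?$")


def parse_attack_id(code: str):
    """Return (technique_id, sub_technique_suffix or None).

    Hypothetical helper: raises ValueError when the string is not an ATT&CK ID.
    """
    match = ATTACK_ID.match(code.strip())
    if match is None:
        raise ValueError(f"not an ATT&CK technique ID: {code!r}")
    return f"T{match.group(1)}", match.group(2)


print(parse_attack_id("T1059"))      # a technique
print(parse_attack_id("T1059.001"))  # a sub-technique of T1059
```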

## Your Task - Classify incoming alerts

You received 2 datasets, train and test. Your goal is to build a model based on the train data (alerts with matching labels) and classify the test data. <br/>
Instructions:
- You have access to the LLaMA-8B model. Please use LLaMA-8B ONLY [You are NOT allowed to use other models \ embedding models]
- For each single alert generation, you can call the LLM ONLY ONE TIME. Also, limit yourself to MAX 8K input tokens.
- You are not allowed to use any other open-source model AS IS. If you choose to use an open-source model for some stage (not mandatory), you have to fine-tune or customize it. The reason for that is to make the task a bit harder :)
- You have an API function named `llm` that calls the LLM through an API call. It is limited and should be used for this task only!
- Please do not share the API key or use it for purposes other than this task.
- Please install the `requirements.txt` file before starting the task.
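
The 8K input-token limit can be guarded with a rough size check before each call. The sketch below is only a heuristic: the constant `CHARS_PER_TOKEN = 4` is an assumption (a common rule of thumb for English text), not the real LLaMA tokenizer, and the function names are hypothetical.

```python
MAX_INPUT_TOKENS = 8_000  # limit stated in the assignment
CHARS_PER_TOKEN = 4       # ASSUMPTION: rough average, not the real tokenizer


def estimate_tokens(text: str) -> int:
    """Crude token estimate derived from the character count (rounded up)."""
    return -(-len(text) // CHARS_PER_TOKEN)  # ceiling division


def fits_budget(system_message: str, prompt: str, budget: int = MAX_INPUT_TOKENS) -> bool:
    """True when the estimated size of both messages stays within the budget."""
    return estimate_tokens(system_message) + estimate_tokens(prompt) <= budget


print(fits_budget("You are a text classification model.", "Alert message: ..."))
```

Because the estimate is approximate, it is safer to leave some headroom below the limit rather than fill the budget exactly.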

# Final Solution:
My final solution has these key elements:
* Few-shot prompting with structured examples (19 shots) on LLaMA-8B
* Adjusting the [model params](https://console.groq.com/docs/text-chat)
* Text cleaning - simple cleaning (could be improved)
* Replacing similar examples with different ones (without an embedding model; done by manually reading and removing, due to the instruction limitations)
* Dataset undersampling - to balance the dataset
* LabelEncoder - encodes the alert codes before classification and decodes them back after prediction
* Defining the model behaviour with a "system" message
* JSON format - the model handles one alert at a time and returns JSON. The structure is then validated and returned as a Python class that carries the JSON attributes and inherits from pydantic's `BaseModel`. The JSON structure is: {alert: alert text, uid: uid provided with it, code: the label encoded by the LabelEncoder, name: alert name, confidence_score: the confidence of the prediction}
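
The undersampling step listed above does not appear in the notebook code, so the following is only a sketch of one way it could be done: cap the number of examples per class at random. The function name `undersample`, the cap of 3, and the toy rows are all assumptions for illustration.

```python
import random
from collections import defaultdict


def undersample(rows, label_key="code", max_per_class=3, seed=42):
    """Keep at most max_per_class rows per label, chosen at random (hypothetical helper)."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for row in rows:
        by_label[row[label_key]].append(row)

    balanced = []
    for label in sorted(by_label):
        group = by_label[label]
        # Small classes are kept whole; large classes are cut down to the cap.
        balanced.extend(rng.sample(group, min(max_per_class, len(group))))
    return balanced


# Toy data (made up): class "A" is heavily over-represented.
rows = [{"code": "A", "alert": f"a{i}"} for i in range(10)] + [{"code": "B", "alert": "b0"}]
balanced = undersample(rows)
print(len(balanced))
```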

Measurements:

As mentioned above, I put 19 samples in the prompt.

I used the other 81 samples to validate the model performance.

```
Number of correct predictions: 79
Number of incorrect predictions: 0
Confusion Matrix:
[[24  0  0  0  0  0]
 [ 0 26  0  0  0  0]
 [ 0  0 20  0  0  0]
 [ 0  0  0  6  0  0]
 [ 0  0  0  0  1  0]
 [ 0  0  0  0  0  2]]

Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        24
           2       1.00      1.00      1.00        26
           4       1.00      1.00      1.00        20
           5       1.00      1.00      1.00         6
           6       1.00      1.00      1.00         1
           7       1.00      1.00      1.00         2

    accuracy                           1.00        79
   macro avg       1.00      1.00      1.00        79
weighted avg       1.00      1.00      1.00        79

Micro-Averaged Precision: 1.0
Micro-Averaged Recall: 1.0
Micro-Averaged F1 Score: 1.0
```
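
In single-label multi-class classification, micro-averaged precision, recall and F1 all equal the accuracy, so the three micro scores above can be recomputed directly from the confusion matrix. A minimal check using the matrix values reported above:

```python
# Confusion matrix copied from the validation output above.
conf_matrix = [
    [24, 0, 0, 0, 0, 0],
    [0, 26, 0, 0, 0, 0],
    [0, 0, 20, 0, 0, 0],
    [0, 0, 0, 6, 0, 0],
    [0, 0, 0, 0, 1, 0],
    [0, 0, 0, 0, 0, 2],
]

total = sum(sum(row) for row in conf_matrix)                        # all validated alerts
correct = sum(conf_matrix[i][i] for i in range(len(conf_matrix)))   # diagonal entries
micro_score = correct / total                                       # equals the accuracy

print(total, correct, micro_score)
```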


In [None]:
# Load the Drive helper and mount
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install -r '/content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/requirements.txt'
!pip install tqdm

Collecting groq (from -r /content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/requirements.txt (line 1))
  Downloading groq-0.9.0-py3-none-any.whl (103 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 103.5/103.5 kB 3.5 MB/s eta 0:00:00
Collecting jupyter (from -r /content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/requirements.txt (line 2))
  Downloading jupyter-1.0.0-py2.py3-none-any.whl (2.7 kB)
Collecting python-dotenv (from -r /content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/requirements.txt (line 3))
  Downloading python_dotenv-1.0.1-py3-none-any.whl (19 kB)
Collecting httpx<1,>=0.23.0 (from groq->-r /content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/requirements.txt (line 1))
  Downloading httpx-0.27.0-py3-none-any.whl (75 kB)

In [None]:
import os
from groq import Client
from dotenv import load_dotenv
from pydantic import BaseModel # To validate the JSON class structure
import json # for json.dumps on MitreAttack class
from sklearn.metrics import confusion_matrix, classification_report, precision_score, recall_score, f1_score # To measure performance
import pandas as pd # For data operations
from sklearn.preprocessing import LabelEncoder # Encode the code into numeric label from 0-(n-1) where n is number of classes
import re # for text cleaning

# class to represent and validate the JSON
class MitreAttack(BaseModel):
    alert: str
    uid: str
    code: int
    name: str
    confidence_score: float

# class representing the llm model
class LlamaSmallClassifier():
  def __init__(self, env_path):
    # load .env file
    load_dotenv(dotenv_path=env_path)
    # initialize Groq API client
    groq_api_key = os.getenv("GROQ_API_KEY")
    self.client = Client(api_key=groq_api_key)

  def classify(self, prompt, uid, min_label, max_label):
    """Query the Groq API with a prompt and return JSON response,
    then validate the structure and return python class - MitreAttack with
    the JSON attributes
    min_label - the minimum label value
    max_label - the maximum label value
    prompt - the prompt to query the model
    uid - the uid provided with the alert
    JSON Schema:
    {
      alert: alert text,
      uid: uid provided with it,
      code: The label which was encoded by the LabelEncoder,
      name: alert name,
      confidence_score: the confidence of the prediction
    }
    """

    # Range of valid encoded labels, stated in the schema to deny hallucinations.
    # It is formatted straight into the "code" field: the earlier find()-based insert
    # searched for a pattern with a space after the colon, which this string does not
    # contain, so the range ended up at the wrong position.
    possible_classes = f'({min_label}-{max_label})'
    json_schema_string = (
    "{\"alert\":\"string (alert message Do not return any additional text, explanations, or comments)\","
    "\"uid\":\"string (uid provided by the user with the alert)\","
    f"\"code\":\"number {possible_classes} (MITRE ATT&CK technique or sub-technique encoded label Do not return any additional text, explanations, or comments)\","
    "\"name\":\"string (MITRE ATT&CK technique or sub-technique name Do not return any additional text, explanations, or comments)\","
    "\"confidence_score\":\"float (number (0-1)) Do not return any additional text, explanations, or comments\"}"
    )
    system_message = (
    f"You are a text classification model. "
    f"Your task is to classify the given alert message and respond in JSON according to the schema below. "
    f"Do not return any additional text, explanations, or comments. "
    f"The JSON schema should include: {json_schema_string}. "
    f"The JSON should have the following structure: "
    f"{json.dumps(MitreAttack.model_json_schema(), indent=2)}. "
    f"Do not return any additional text, explanations, or comments."
    )
    response = self.client.chat.completions.create(
      model=os.getenv("GROQ_MODEL_NAME"),
      messages=[
          {
            "content": system_message, # Define the model behavior
            "role": "system",
            "name": uid # The client name assigned as the alert uid
          },
          {
            "content": prompt, # Assign the (few shot) prompt of one alert
            "role": "user",
            "name": uid # The client name assigned as the alert uid
          }
      ],
      max_tokens=512,
      temperature=0,
      top_p=0.7,
      stream=False,
      stop=None,
      response_format={"type": "json_object"}
      )
    return MitreAttack.model_validate_json(response.choices[0].message.content)


  def print_classification_report(self, y, y_pred):
    """Print the confusion matrix, classification report and micro-averaged reports
    for the given labels and predictions
    y - The ground truth labels
    y_pred - The predicted labels"""

    # Calculate confusion matrix
    conf_matrix = confusion_matrix(y, y_pred)
    print("Confusion Matrix:")
    print(conf_matrix)

    # Calculate classification report
    class_report = classification_report(y, y_pred, target_names=None)
    print("\nClassification Report:")
    print(class_report)

    # Calculate micro-averaged precision
    micro_precision = precision_score(y, y_pred, average='micro')
    print(f"Micro-Averaged Precision: {micro_precision}")

    # Calculate micro-averaged recall
    micro_recall = recall_score(y, y_pred, average='micro')
    print(f"Micro-Averaged Recall: {micro_recall}")

    # Calculate micro-averaged F1 score
    micro_f1 = f1_score(y, y_pred, average='micro')
    print(f"Micro-Averaged F1 Score: {micro_f1}")



# class representing the prompt
class Prompt():
  def __init__(self, prompt=""):
    self.prompt = prompt

  # Using few shot prompting
  def design_json_prompt(self, learning_alerts, user_alerts, min_label, max_label):
    """ Returns a prompt to query llama-3-8b for mitre attack alert classification in JSON format """
    user_alerts_str = ""
    user_alerts_str += f"Alert message: {user_alerts['alert']}, uid: {user_alerts['uid']}\n"
    learning_alerts_str = ""
    for index, item in learning_alerts.iterrows():
      learning_alerts_str += f"Alert code: {item['code']}, Alert name: {item['name']}, Alert message: {item['alert']}\n"
    json_schema_string = "{\"alert\":\"string (alert message Do not return any additional text, explanations, or comments)\", \"uid\":\"string (uid provided by the user with the alert)\", \"code\": \"number  (MITRE ATT&CK technique or sub-technique encoded label Do not return any additional text, explanations, or comments)\", \"name\": \"string (MITRE ATT&CK technique or sub-technique name Do not return any additional text, explanations, or comments)\", \"confidence_score\": \"float (number (0-1)) Do not return any additional text, explanations, or comments\"}"

    possible_classes = f'({min_label}-{max_label})'
    # Find where the description of the "code" field starts in the schema string
    label_index = json_schema_string.find("code\": \"number ")
    # Position right after "code\": \"number ", where the label range is inserted
    insert_position = label_index + len("code\": \"number ")
    json_schema_string = json_schema_string[:insert_position] + " " + possible_classes + json_schema_string[insert_position:]
    prompt = f"""\
Here are the learning alerts with their codes, names, and messages:

{learning_alerts_str}
Here are the new alert messages to classify:
{user_alerts_str}

Instructions:

1. For each alert message in the new set, analyze its content to determine the appropriate MITRE ATT&CK technique or sub-technique based on the learning alerts.
2. Use the patterns and classifications from the learning alerts to match each new alert message to the correct alert code.
3. You must return json of the classification label only. No additional text is allowed.
Present the results in the following format only, no other text:
<{json_schema_string}>
"""
    return prompt


# Methods to handle the data operations

def replace_alerts(sampled_alerts, remaining_alerts, new_sampled_alert, new_remaining_alert):
  """Replace alerts with different ones in dataframes
  move new_sampled_alert from remaining_alerts to sampled_alerts
  move new_remaining_alert from sampled_alerts to remaining_alerts
  This is based on Text and not on index"""

  # Identify the indices
  remaining_alerts_index_1 = remaining_alerts.loc[remaining_alerts['alert'] == new_sampled_alert].index[0]
  sampled_alerts_index_1 = sampled_alerts.loc[sampled_alerts['alert'] == new_remaining_alert].index[0]

  # Extract the rows to be moved
  remaining_alerts_row_1 = remaining_alerts.loc[remaining_alerts_index_1]
  sampled_alerts_row_1 = sampled_alerts.loc[sampled_alerts_index_1]

  # Remove the rows from the original DataFrames
  remaining_alerts = remaining_alerts.drop(remaining_alerts_index_1)
  sampled_alerts = sampled_alerts.drop(sampled_alerts_index_1)

  # Add the rows to the other DataFrames using pd.concat
  sampled_alerts = pd.concat([sampled_alerts, pd.DataFrame([remaining_alerts_row_1])], ignore_index=True)
  remaining_alerts = pd.concat([remaining_alerts, pd.DataFrame([sampled_alerts_row_1])], ignore_index=True)

  return sampled_alerts, remaining_alerts


def move_alert(sampled_alerts, remaining_alerts, new_sampled_alert):
  """move new_sampled_alert from remaining_alerts to sampled_alerts
  This is based on Text and not on index"""

  # Identify the indices
  remaining_alerts_index_2 = remaining_alerts.loc[remaining_alerts['alert'] == new_sampled_alert].index[0]

  # Extract the rows to be moved
  remaining_alerts_row_2 = remaining_alerts.loc[remaining_alerts_index_2]

  # Remove the rows from the original DataFrames
  remaining_alerts = remaining_alerts.drop([remaining_alerts_index_2])

  # Add the rows to the other DataFrames using pd.concat
  sampled_alerts = pd.concat([sampled_alerts, pd.DataFrame([remaining_alerts_row_2])], ignore_index=True)

  return sampled_alerts, remaining_alerts


# Not used in main; it was used during development and HAS NOT been carefully tested
def split_alerts(df):
  """ Split the alerts by sampling 2 of each class at random, not considering similarity.
  This method was used for the first data split and could be more effective
  if the sample were representative of the entire feature distribution of a class """
  # Sample 2 alerts for each code, if possible
  def sample_alerts(group):
      return group.sample(2, random_state=42) if len(group) >= 2 else group

  # Concatenating the per-class samples keeps the original row index, which is needed below
  sampled_alerts = pd.concat([sample_alerts(group) for _, group in df.groupby('code')])

  # Identify the alerts that were not selected
  remaining_alerts = df.loc[~df.index.isin(sampled_alerts.index)]

  # Display the sampled alerts
  print("Sampled Alerts:")
  for _, item in sampled_alerts.iterrows():
      print(f"Alert code: {item['code']}, Alert name:{item['name']}, Alert message: {item['alert']}")

  # Display the remaining alerts
  print("\nRemaining Alerts:")
  for _, item in remaining_alerts.iterrows():
      print(f"Alert code: {item['code']}, Alert name:{item['name']}, Alert message: {item['alert']}")

  return sampled_alerts.reset_index(drop=True), remaining_alerts


# This method assumes that uid is a unique identifier of an alert
def get_train_val_alerts(prompt_df_path, data_path):
  """Get the train alerts for the prompt examples and the remaining data.
  Ensures that df_train and df_val do not share alerts, based on the uids
  prompt_df_path = path to the prompt data df
  data_path = path to the original train data provided"""

  df_train = pd.read_csv(prompt_df_path)
  data = pd.read_csv(data_path)

  # uids that are already used as prompt examples
  common_uids = df_train['uid']

  # Drop the rows with those uids from `data`; copy so that later column
  # assignments on df_val do not operate on a view of `data`
  df_val = data[~data['uid'].isin(common_uids)].copy()

  return df_train, df_val

# Text preprocessing function; simple, and could do more operations such as
# removing an entire email address
def preprocess_text(text):
    text = text.lower()  # Lowercase
    text = re.sub(r'[^a-z\s]', '', text)  # Remove special characters and numbers
    text = re.sub(r'\s+', ' ', text).strip()  # Remove extra spaces
    return text

# call this method on the train data if mistakes were made in the predictions
def move_incorrect_prediction_to_train(df_train, df_val, correct_predictions):
  """Move incorrect predictions to df_train and remove them from df_val
  correct_predictions is a bool list with True where the predictions were correct"""
  # Get the positions of the wrong predictions
  false_indices = [i for i, correct in enumerate(correct_predictions) if not correct]

  # Select rows from df_val by position
  rows_to_move = df_val.iloc[false_indices]

  # Remove these rows from df_val by their index labels (labels can differ from positions)
  df_val = df_val.drop(index=rows_to_move.index).reset_index(drop=True)

  # Append these rows to df_train using pd.concat
  df_train = pd.concat([df_train, rows_to_move]).reset_index(drop=True)

  return df_train, df_val
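
To illustrate what `preprocess_text` does, and why the comment above it suggests removing whole email addresses, here is the same function applied to a made-up alert (the sample string is hypothetical, not taken from the dataset). Digits and punctuation are stripped, so an IP address disappears entirely while an email address collapses into a single run of letters.

```python
import re


def preprocess_text(text):
    text = text.lower()  # Lowercase
    text = re.sub(r'[^a-z\s]', '', text)  # Remove special characters and numbers
    text = re.sub(r'\s+', ' ', text).strip()  # Remove extra spaces
    return text


sample = "Suspicious login from 10.0.0.5 by Admin@corp.com!"  # made-up alert text
cleaned = preprocess_text(sample)
print(cleaned)  # suspicious login from by admincorpcom
```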


In [None]:
from tqdm import tqdm
import numpy as np


def main():

  # Load the data
  data = pd.read_csv('/content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/train.csv')

  # Load the data which was prepared for the prompt
  df_train, df_val = get_train_val_alerts(
      '/content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/prompt_train.csv',
      '/content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/train.csv'
      )

  # Clean the data
  data['alert'] = data['alert'].apply(preprocess_text)
  df_train['alert'] = df_train['alert'].apply(preprocess_text)
  df_val['alert'] = df_val['alert'].apply(preprocess_text)

  # Label Encoding
  label_encoder = LabelEncoder()
  data['code'] = label_encoder.fit_transform(data['code'])
  df_train['code'] = label_encoder.transform(df_train['code'])
  df_val['code'] = label_encoder.transform(df_val['code'])

  min_label = 0
  max_label = len(label_encoder.classes_)-1

  # Initialize the LlamaSmallClassifier
  llama_small_classifier = LlamaSmallClassifier("/content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/.env")

  # Initialize the Prompt class
  prompt = Prompt()

  tqdm.pandas()

  mitre_attack_arr = []
  y_preds = []
  y = []

  # Function to process each row
  def process_row(row, prompt, llama_small_classifier, df_train, min_label, max_label):
      my_prompt = prompt.design_json_prompt(learning_alerts=df_train, user_alerts=row, min_label=min_label, max_label=max_label)
      mitre_attack = llama_small_classifier.classify(my_prompt, row['uid'], min_label, max_label)
      y_preds.append(mitre_attack.code)
      y.append(row['code'])
      return mitre_attack

  # Apply the function to each row with a progress bar, using a lambda to pass additional arguments
  mitre_attack_arr = df_val.progress_apply(lambda row: process_row(row, prompt, llama_small_classifier, df_train, min_label, max_label), axis=1)

  # Ensure the results are properly added to mitre_attack_arr
  mitre_attack_arr = list(mitre_attack_arr)

  # Convert to numpy arrays
  actuals = np.array(y)
  predictions = np.array(y_preds) # we can save space by using mitre_attack_arr[i]['code'], this is more readable in my opinion

  # Calculate correct predictions
  correct_predictions = (actuals == predictions)
  num_correct = np.sum(correct_predictions)
  num_incorrect = len(actuals) - num_correct

  print(f"Number of correct predictions: {num_correct}")
  print(f"Number of incorrect predictions: {num_incorrect}")

  # if we had mistakes, can also remove similars from df_train if req.
  # df_train, df_val = move_incorrect_prediction_to_train(df_train, df_val, correct_predictions)
  # inverse_transform WITH LABEL ENCODER BEFORE SAVING THE NEW SAMPLES!!


  llama_small_classifier.print_classification_report(actuals, predictions)

  mitre_attack_arr = []

  # Function to process each row
  def process_row(row, prompt, llama_small_classifier, df_train, min_label, max_label):
      my_prompt = prompt.design_json_prompt(learning_alerts=df_train, user_alerts=row, min_label=min_label, max_label=max_label)
      mitre_attack = llama_small_classifier.classify(my_prompt, row['uid'], min_label, max_label)
      return mitre_attack

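  # Now let's predict test results

  df_test = pd.read_csv('/content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/test.csv')
  # Clean the test alerts the same way as the prompt examples and the validation alerts
  df_test['alert'] = df_test['alert'].apply(preprocess_text)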

  # Apply the function to each row with a progress bar, using a lambda to pass additional arguments
  mitre_attack_arr = df_test.progress_apply(lambda row: process_row(row, prompt, llama_small_classifier, df_train, min_label, max_label), axis=1)

  # Ensure the results are properly added to mitre_attack_arr
  mitre_attack_arr = list(mitre_attack_arr)

  # Convert encoded labels back to original codes
  for attack in mitre_attack_arr:
      attack.code = label_encoder.inverse_transform([attack.code])[0]

  res_data = [(attack.uid, attack.code) for attack in mitre_attack_arr]
  # Create a DataFrame
  df = pd.DataFrame(res_data, columns=['uid', 'code'])

  # Specify the output file path
  output_file_path = '/content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/Test Predictions.csv'

  # Save the DataFrame to a CSV file
  df.to_csv(output_file_path, index=False)

  print(f"Data has been saved to {output_file_path}")


if __name__ == "__main__":
  main()

100%|██████████| 79/79 [03:18<00:00,  2.52s/it]


Number of correct predictions: 79
Number of incorrect predictions: 0
Confusion Matrix:
[[24  0  0  0  0  0]
 [ 0 26  0  0  0  0]
 [ 0  0 20  0  0  0]
 [ 0  0  0  6  0  0]
 [ 0  0  0  0  1  0]
 [ 0  0  0  0  0  2]]

Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        24
           2       1.00      1.00      1.00        26
           4       1.00      1.00      1.00        20
           5       1.00      1.00      1.00         6
           6       1.00      1.00      1.00         1
           7       1.00      1.00      1.00         2

    accuracy                           1.00        79
   macro avg       1.00      1.00      1.00        79
weighted avg       1.00      1.00      1.00        79

Micro-Averaged Precision: 1.0
Micro-Averaged Recall: 1.0
Micro-Averaged F1 Score: 1.0


100%|██████████| 45/45 [02:27<00:00,  3.27s/it]


Data has been saved to /content/drive/MyDrive/AI Engineer Home Assignment-20240709T133517Z-001/AI Engineer Home Assignment/Test Predictions.csv


## Questions

- Evaluate your results - How many correct answers do you have? How many incorrect?
- Do you have any suggestions to make the results more accurate? (Do not implement)

*   Evaluate your results - How many correct answers do you have? How many incorrect?
```
Few-Shot Prompting with Structured Examples:
Number of correct predictions: 79
Number of incorrect predictions: 0
```

*   Do you have any suggestions to make the results more accurate? (Do not implement)


1. Try more prompt engineering techniques (Zero-shot Prompting, Chain-of-Thought Prompting, Self-Consistency, Generated Knowledge Prompting, Prompt Chaining, Tree of Thoughts, Retrieval Augmented Generation, Automatic Reasoning and Tool-use, Automatic Prompt Engineer, Active-Prompt, Directional Stimulus Prompting, Program-Aided Language Models, ReAct, Reflexion, Graph Prompting)
2. Try different data manipulations such as paraphrasing, synonym replacement, back-translation, and normalizing the text in different ways
3. Try creating synthetic data and balancing the dataset
4. Try different model params
5. Supervised Fine-Tuning / RAG / PEFT
6. Try using the most accurate model for text classification
7. Analyze misclassified alerts to understand common failure modes and iteratively improve the model
8. Ensure data diversity among the prompt examples of each class (instead of random sampling; this can be done with a siamese model and cosine similarity)
9. Add all the possible classes to the prompt (of MITRE ATT&CK in general)
10. Check the confidence score in the JSON for more samples and move samples with a low confidence score to the prompt (but make sure not to overfit the model)
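
As a rough illustration of item 8, near-duplicate prompt examples can be filtered without any embedding model by comparing word sets. The sketch below swaps in Jaccard similarity over tokens for the siamese model and cosine similarity mentioned in that item; the function names, the 0.6 threshold and the sample alerts are all assumptions.

```python
import re


def tokens(text: str) -> set:
    """Lowercased alphabetic words of an alert."""
    return set(re.findall(r"[a-z]+", text.lower()))


def jaccard(a: str, b: str) -> float:
    """Share of words the two alerts have in common (0 = disjoint, 1 = identical sets)."""
    ta, tb = tokens(a), tokens(b)
    if not ta and not tb:
        return 1.0
    return len(ta & tb) / len(ta | tb)


def pick_diverse(alerts, max_similarity=0.6):
    """Greedily keep alerts that are not too similar to any alert already kept."""
    kept = []
    for alert in alerts:
        if all(jaccard(alert, other) < max_similarity for other in kept):
            kept.append(alert)
    return kept


# Made-up alerts: the first two differ by a single word.
alerts = [
    "powershell executed encoded command",
    "powershell executed an encoded command",
    "new scheduled task created by user",
]
print(pick_diverse(alerts))
```

Run per class, a filter like this would leave room in the prompt for examples that cover different phrasings of the same technique.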