# Notebook to run GPT, Gemini, and Replicate models

In [None]:
from openai import OpenAI
import replicate
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig, HarmBlockThreshold, HarmCategory

import pandas as pd
import numpy as np
import os
import time
import re
import matplotlib.pyplot as plt

## Setup all APIs

In [None]:
# Load API credentials from a .env file into the process environment.
# OPENAI_API_KEY is read from the environment further down; the replicate
# client presumably picks up its token (REPLICATE_API_TOKEN) the same way.
import os
from dotenv import find_dotenv, load_dotenv

# find_dotenv() only locates the file; its result must be passed on, a bare
# call on its own has no effect.
load_dotenv(find_dotenv())

In [None]:
from getpass import getpass

# OpenAI client used for the 'gpt*' models; the key comes from the environment.
oai_key = os.environ.get("OPENAI_API_KEY")
# Optional interactive fallback when the variable is not set:
# oai_key = oai_key or getpass("Enter your OpenAI API key: ")
openai_client = OpenAI(api_key=oai_key)

In [None]:
# Gemini
# project_id = ""   # add project ID and location
# vertexai.init(project=project_id, location="")

## Prediction Params & Method

In [None]:
# Generation settings read by get_single_prediction, chosen to make model
# output as repeatable as possible: temperature 0, no nucleus truncation
# (top_p 1), a fixed seed, and a cap on generated tokens.
temperature, top_p, seed, max_tokens = 0, 1, 42, 2048

In [None]:
# System instruction for every model (sent as the system message to OpenAI,
# prepended to the prompt for the other backends).
sys_prompt = ("You are a cybersecurity expert "
              "specializing in cyberthreat intelligence.")

In [None]:
# Short model keys used in this notebook -> provider-specific model ids.
# The key prefix decides the backend in get_single_prediction:
# 'llama*' -> Replicate, 'gemini' -> Vertex AI, 'gpt*' -> OpenAI.
model_mapping = dict([
    ("llama3-70b", "meta/meta-llama-3-70b-instruct"),
    ("llama3-8b", "meta/meta-llama-3-8b-instruct"),
    ("gemini", "gemini-1.5-pro"),
    ("gpt3", "gpt-3.5-turbo"),
    ("gpt4", "gpt-4-turbo"),
    ("gpt4.1-mini", "gpt-4.1-mini-2025-04-14"),
    ("gpt4o-mini", "gpt-4o-mini-2024-07-18"),
])

In [None]:
def get_single_prediction(question, model_name):
    """Send one prompt to `model_name` and return the generated text.

    The backend is chosen from the prefix of `model_name`:
      * 'llama', 'mistral', 'gemma', '01-ai' -> Replicate (`replicate.run`)
      * 'gpt'                                -> OpenAI chat completions
      * 'gemini'                             -> Vertex AI `GenerativeModel`
    The concrete model id is looked up in `model_mapping` (KeyError if absent).
    Sampling uses the module-level `temperature`, `top_p`, `seed` (not passed
    to Gemini) and `max_tokens`; `sys_prompt` is the system instruction.

    Args:
        question: the user prompt text.
        model_name: short model key, e.g. 'gpt4o-mini' or 'llama3-8b'.

    Returns:
        The model's response text.

    Raises:
        ValueError: if `model_name` has none of the supported prefixes.
    """
    if model_name.startswith(('llama', 'mistral', 'gemma', '01-ai')):
        # Replicate: the system prompt is prepended to the question because a
        # single 'prompt' string is sent.
        model = model_mapping[model_name]
        prompt = sys_prompt + ' ' + question
        # Gemma is run at a fixed temperature of 0.01 instead of `temperature`
        # (presumably the endpoint does not accept 0 -- TODO confirm).
        run_temperature = 0.01 if model_name.startswith('gemma') else temperature
        replicate_input = {
            'prompt': prompt,
            'max_tokens': max_tokens,
            'temperature': run_temperature,
            'top_p': top_p,
            'seed': seed,
        }
        # The run result is iterated as text pieces and joined into one string.
        return "".join(replicate.run(model, input=replicate_input))

    if model_name.startswith('gpt'):
        # OpenAI: system prompt and question are sent as separate chat messages.
        model = model_mapping[model_name]
        response = openai_client.chat.completions.create(
            model=model,
            messages=[
                {'role': 'system', 'content': sys_prompt},
                {'role': 'user', 'content': question}
            ],
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            seed=seed
        )
        return response.choices[0].message.content

    if model_name.startswith('gemini'):
        # Vertex AI: presumably needs vertexai.init(project=..., location=...)
        # to have been run first (see the Gemini setup cell).
        model = GenerativeModel(model_name=model_mapping[model_name])
        prompt = sys_prompt + ' ' + question
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            max_output_tokens=max_tokens,
        )
        # Only block content rated high-harm, presumably so that security
        # questions are not filtered out.
        safety_settings = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        }
        response = model.generate_content(
            prompt,
            generation_config=generation_config,
            safety_settings=safety_settings
        )
        output = response.text
        time.sleep(1)   # crude throttle between Gemini calls to avoid request errors
        return output

    # Without this, an unknown prefix surfaced as an UnboundLocalError.
    raise ValueError('Unknown model: ' + repr(model_name))


#### Test

In [None]:
# Sample CVE -> CWE mapping prompt used to check that each API responds.
question = "".join([
    "Analyze the following CVE description and map it to the appropriate CWE. ",
    "Provide a brief justification. The last line of your answer should only contain the CWE ID.\n\n",
    "CVE Description:\n\n",
    "Dell EMC CloudLink 7.1 and all prior versions contain an Improper Input Validation Vulnerability. ",
    "A remote low privileged attacker may potentially exploit this vulnerability, ",
    "leading to execution of arbitrary files on the server.",
])

##### Are all the APIs working?

In [None]:
# Smoke test: the OpenAI backend should answer the sample question.
print(get_single_prediction(question=question, model_name='gpt4.1-mini'))

In [None]:
# print(get_single_prediction(question, 'gemini'))

In [None]:
# print(get_single_prediction(question, 'llama3-8b'))

# Run Evaluation for a Dataset

### Answer-formatting functions
While these functions capture most output formats of the LLMs we studied, we still had to manually collect some responses from the generated response files.

In [None]:
def format_rcm(text):
    """Extract the last CWE identifier (e.g. 'CWE-20') from a model response.

    Returns:
        ``(cwe_id, True)`` with the last ``CWE-<digits>`` occurrence, or
        ``(text_on_one_line, False)`` when none is found. The fallback
        replaces newlines with spaces (as format_mcq/format_taa do) so the
        result file keeps exactly one entry per line.
    """
    matches = re.findall(r'CWE-\d+', text)
    if matches:
        # The prompt asks for the CWE ID on the last line, so prefer the last hit.
        return matches[-1], True
    return ' '.join(text.split('\n')), False

def format_vsp(text):
    """Extract the last CVSS v3.1 base-metric vector from a model response.

    Matches ``AV:x/AC:x/PR:x/UI:x/S:x/C:x/I:x/A:x`` where each value is a run
    of letters; a leading ``CVSS:3.1/`` prefix is not part of the match.

    Returns:
        ``(vector, True)`` with the last vector found, or
        ``(text_on_one_line, False)`` when none is found. The fallback
        replaces newlines with spaces (as format_mcq/format_taa do) so the
        result file keeps exactly one entry per line.
    """
    cvss_pattern = r'AV:[A-Za-z]+/AC:[A-Za-z]+/PR:[A-Za-z]+/UI:[A-Za-z]+/S:[A-Za-z]+/C:[A-Za-z]+/I:[A-Za-z]+/A:[A-Za-z]+'
    matches = re.findall(cvss_pattern, text)
    if matches:
        return matches[-1], True
    return ' '.join(text.split('\n')), False

def format_mcq(text):
    """Extract the chosen option letter (A-D) from a multiple-choice response.

    Only the last non-blank line is inspected, with surrounding whitespace
    ignored. It yields a letter when it
      * starts with an option marker such as ``B)``,
      * ends with an option letter, e.g. ``The answer is B``, or
      * ends with a bold option, e.g. ``**B**``.
    Otherwise the whole response is returned with newlines replaced by spaces,
    so it can be reviewed manually while still taking one result line.
    """
    options = ('A', 'B', 'C', 'D')
    markers = ('A)', 'B)', 'C)', 'D)')
    # Skip any number of trailing blank lines, and never index past the start
    # of the text (empty or '\n'-only responses would otherwise raise IndexError).
    non_blank = [line.strip() for line in text.split('\n') if line.strip()]
    if non_blank:
        last_line = non_blank[-1]
        if last_line.startswith(markers):
            return last_line[0]
        if last_line.endswith(options):
            return last_line[-1]
        # '**B**' -> 'B'; guard the length and require an actual option letter.
        if last_line.endswith('**') and len(last_line) >= 3 and last_line[-3] in options:
            return last_line[-3]
    return ' '.join(text.split('\n'))

def format_taa(text):
    """Flatten an attribution ('taa') response onto a single line.

    The attribution is extracted by hand afterwards, so the full response is
    kept; only newlines are turned into spaces so it fits on one result line.
    """
    return text.replace('\n', ' ')

In [None]:
def run_evaluation(file_path, task, model_name):
    """Run `model_name` over every prompt in a TSV dataset and save the outputs.

    Each row's 'Prompt' column is sent to get_single_prediction. The response
    is reduced to an answer by the formatter for `task`:
    'rcm' -> format_rcm, 'vsp' -> format_vsp, 'mcq' -> format_mcq,
    'taa' -> format_taa.

    Two files are written next to the input file (its extension replaced):
      * ``<stem>_<model_name>_response.txt``: every raw response, each preceded
        by a ``#####<row>#####`` header ('Error' if the request failed);
      * ``<stem>_<model_name>_result.txt``: one formatted answer per line
        ('Error' if the request or the formatting failed).
    Elapsed time, generated character count and, for 'rcm'/'vsp', the number
    of responses without the expected pattern are printed at the end.

    Raises:
        ValueError: if `task` is unknown. This is checked before any model
            call, so no API requests are wasted.
    """
    formatters = {'rcm': format_rcm, 'vsp': format_vsp, 'mcq': format_mcq, 'taa': format_taa}
    if task not in formatters:
        raise ValueError('Task unknown!')
    formatter = formatters[task]
    # Only the rcm/vsp formatters return an (answer, success) pair.
    reports_success = task in ('rcm', 'vsp')

    # Keep track of time and total #chars generated
    start_time = time.time()
    count_chars = 0
    instructions_failed = 0

    data = pd.read_csv(file_path, encoding='utf-8', sep='\t')

    # all_responses holds the raw model output, all_results the formatted
    # answer; each gets exactly one entry per row so the two files line up.
    all_responses = []
    all_results = []

    for index, row in data.iterrows():
        prompt = row['Prompt']
        response = 'Error'
        answer = 'Error'
        try:
            output = get_single_prediction(prompt, model_name)
            # len() raises TypeError (handled below) if the model returned None.
            count_chars += len(output)
            response = output
            if reports_success:
                answer, success = formatter(output)
                if not success:
                    instructions_failed += 1
            else:
                answer = formatter(output)
        except Exception as e:
            answer = 'Error'
            print('Exception at row ', index+1)
            print(e)
        all_responses.append(response)
        # The result file holds one answer per line; never let one span lines.
        all_results.append(answer.replace('\n', ' '))
        print(index+1, answer)

    time_taken = time.time() - start_time
    print('Time taken:', time_taken)
    print('#Characters generated:', count_chars)
    print('#Instructions failed:', instructions_failed)

    # Save all the responses & results next to the input file. splitext strips
    # only the extension; split('.')[0] turned '../data/x.tsv' into ''.
    stem = os.path.splitext(file_path)[0]
    out_response = stem + '_' + model_name + '_response.txt'
    out_result = stem + '_' + model_name + '_result.txt'

    with open(out_response, 'w', encoding='utf-8') as fh:
        fh.write(''.join(
            '#####' + str(i) + '#####\n' + resp + '\n\n'
            for i, resp in enumerate(all_responses, start=1)
        ))
    with open(out_result, 'w', encoding='utf-8') as fh:
        fh.write('\n'.join(all_results))

    print('------- Done --------')
    print('------- Done --------')

In [None]:
# Multiple-choice questions with gpt4o-mini.
run_evaluation(file_path='../data/cti-mcq.tsv', task='mcq', model_name='gpt4o-mini')

In [None]:
# CVE -> CWE mapping with gpt4o-mini.
run_evaluation(file_path='../data/cti-rcm.tsv', task='rcm', model_name='gpt4o-mini')

In [None]:
# CVSS vector prediction with gpt4o-mini.
run_evaluation(file_path='../data/cti-vsp.tsv', task='vsp', model_name='gpt4o-mini')

In [None]:
# Attribution task with gpt4o-mini (answers are extracted manually afterwards).
run_evaluation(file_path='../data/cti-taa.tsv', task='taa', model_name='gpt4o-mini')