In [7]:
import sys
from os import path

# This depends on where the file is placed. If it is in the 'agents/llm/' it will work
# also there will be a log in 'agents/llm/env/logs' (it is relative to path) 
sys.path.append('../..')

from env.network_security_game import Network_Security_Environment
from env.scenarios import scenario_configuration, smaller_scenario_configuration, tiny_scenario_configuration
from env.game_components import ActionType, Action, IP, Data


from cyst.api.configuration import *
import openai
from tenacity import retry, stop_after_attempt
import argparse
import jinja2

In [51]:
# I do not show the locl services right now because they are not used in the game
local_services = ['bash', 'powershell', 'remote desktop service', 'windows login', 'can_attack_start_here']

# Needed to create the Action object
action_mapper = {
    "ScanNetwork": ActionType.ScanNetwork,
    "FindServices": ActionType.FindServices,
    "FindData": ActionType.FindData,
    "ExfiltrateData": ActionType.ExfiltrateData,
    "ExploitService": ActionType.ExploitService
}

In [64]:
# Validate the action send by the LLM
# If it is not valid I don't send it to the game
# There might be a bug in FindData and ExflitrateData because I could not test
# Aldo, the actions are invalid if are about hosts we don't know at the momemt
# bacause the environo

def validate_action_in_state(response, state):
    contr_hosts = [str(host) for host in state.controlled_hosts]
    known_hosts = [str(host) for host in state.known_hosts]
    known_nets = [str(net) for net in list(state.known_networks)]

    try:
        if response["action"] == 'ScanNetwork':
            if response["parameters"]["target_network"] in known_nets:
                return True 
        elif response["action"] == 'FindServices':
            if response["parameters"]["target_host"] in known_hosts:
                return True
        elif response["action"] == 'ExploitService':
            ip_addr = response["parameters"]["target_host"]
            if ip_addr in known_hosts: 
                for service in list(state.known_services[ip_addr]):
                    if service.name == response["parameters"]["target_service"]:
                        return True
        elif response["action"] == 'FindData':
            if response["parameters"]["target_host"] in contr_hosts:
                return True
        else:
            for ip_data in state.known_data:
                params = response["parameters"]
                if isinstance(params, str):
                    params = eval(params)
                ip_addr = params["source_host"]
                if ip_data == ip_addr and ip_addr in contr_hosts:
                    if params["data"] in list(state.known_data[ip_data]):
                        return True
        return False 
    except:
        return False

In [53]:
def create_status_from_state(state, memories):
    """
    Read the returned states and the memories (if exist) and create
    a status prompt for the LLM
    """
    contr_hosts = [str(host) for host in state.controlled_hosts]
    known_hosts = [str(host) for host in state.known_hosts]
    known_nets = [str(net) for net in list(state.known_networks)]

    prompt = "Current status:\n"
    if len(memories) > 0:
        for memory in memories:
            prompt += f'You have taken action {{"action":"{memory[0]}", "parameters":"{memory[1]}"}} in the past. {memory[2]}\n' 
    else:
        prompt += ""
    prompt += f"Controlled hosts are {' and '.join(contr_hosts)}\n"
    
    prompt += f"Known networks are {' and '.join(known_nets)}\n"
    prompt += f"Known hosts are {' and '.join(known_hosts)}\n"

    for ip_service in state.known_services:
        services = []
        if len(list(state.known_services[ip_service])) > 0:
            for serv in state.known_services[ip_service]:
                if serv.name not in local_services:
                    services.append(serv.name)
            if len(services) > 0:
                prompt += f"Known services for host {ip_service} are {' and '.join(str(services))}\n"
    
    for ip_data in state.known_data:
        if len(state.known_data[ip_data]) > 0:
            prompt += f"Known data for host {ip_data} are {' and '.join(list(state.known_data[ip_data]))}\n"

    return prompt

In [54]:
# Initial goal
goal = {
    "known_networks":set(),
    "known_hosts":set(),
    "controlled_hosts":set(),
    "known_services":{},
    "known_data":{IP("213.47.23.195"):Data("User1", "DataFromServer1")}
}

attacker_start = {
    "known_networks":set(),
    "known_hosts":set(),
    "controlled_hosts":{IP("213.47.23.195"),IP("192.168.2.2")},
    "known_services":{},
    "known_data":{}
}

In [55]:
# Create the environment with the big scenario
env = Network_Security_Environment(random_start=False, verbosity=0)
cyst_config = scenario_configuration.configuration_objects

In [57]:
# Init the environment
observation = env.initialize(win_conditions=goal, 
                                 defender_positions=False, 
                                 attacker_start_position=attacker_start, 
                                 max_steps=30, 
                                 agent_seed=42,
                                 cyst_config=cyst_config)


In [58]:
print(observation.state.as_json)

<bound method GameState.as_json of GameState(controlled_hosts={213.47.23.195, 192.168.2.2}, known_hosts={213.47.23.195, 192.168.2.2}, known_services={}, known_data={}, known_networks={192.168.2.0/24, 192.168.3.0/24, 192.168.1.0/24})>


In [47]:
# This is the status prompt
print(create_status_from_state(observation.state, []))

Current status:
Controlled hosts are 213.47.23.195 and 192.168.2.2
Known networks are 192.168.2.0/24 and 192.168.3.0/24 and 192.168.1.0/24
Known hosts are 213.47.23.195 and 192.168.2.2



In [48]:
# Create the action the LLM give manually
response = {"action":"FindServices", "parameters":{"target_host": "192.168.2.3"}}
params = response["parameters"]

In [63]:
# Validate the action
# This is optional in the manual test, let the environment tell you if it's a valid action
validate_action_in_state(response, observation.state)

False

In [59]:
# Take the action
action = Action(action_mapper[response["action"]], params)
action

Action <ActionType.FindServices|{'target_host': '192.168.2.3'}>

In [60]:
# Read the new state
observation = env.step(action)
observation.state.as_json

<bound method GameState.as_json of GameState(controlled_hosts={213.47.23.195, 192.168.2.2}, known_hosts={213.47.23.195, 192.168.2.2}, known_services={}, known_data={}, known_networks={192.168.2.0/24, 192.168.3.0/24, 192.168.1.0/24})>

In [62]:
# If the new state is "bigger" that the previous tell the LLM it was successfull
# If not tell it it was not helpful
# If it was invalid, say it was invalid