# Designing an Intelligent Agent for the Wumpus World - Part 3

The Wumpus World is a classic problem in artificial intelligence that involves designing an agent to navigate a grid-based environment filled with hazards, such as pits and a fearsome Wumpus, while searching for gold. The agent must operate under uncertainty, leveraging logical reasoning and perception to make decisions in a partially observable world.

This project focuses on implementing an intelligent Wumpus World agent capable of:

- Perceiving environmental cues such as breezes (near pits) and stenches (near the Wumpus).
- Applying logical and probabilistic inference to find a safe path to the gold.
- Balancing exploration and caution to avoid fatal encounters.

The project highlights key AI concepts, including decision-making under uncertainty, logical reasoning, and search algorithms, providing a practical application of intelligent agent design in a challenging and dynamic environment.

In [1]:
from typing import List
from enum import Enum, auto
import networkx
import matplotlib.pyplot as plt
import random
import numpy
import torch
from pomegranate.distributions import Categorical
from pomegranate.distributions import ConditionalCategorical
from pomegranate.bayesian_network import BayesianNetwork
from itertools import product

In [2]:
class Percept:
    """Everything the agent observes after one environment transition.

    Instances are produced by ``Environment.init`` (time step 0) and
    ``Environment.step``.

    Attributes:
        time_step: Number of steps taken so far (0 for the initial percept).
        bump: True if the last FORWARD ran into the outer wall.
        breeze: True if a pit is on or orthogonally adjacent to the agent's square.
        stench: True if the wumpus is on or orthogonally adjacent to the agent's square.
        scream: True if the last SHOOT killed the wumpus.
        glitter: True if the gold is on the agent's square.
        reward: Integer reward for the last transition.
        done: True once the episode has ended.
    """
    time_step: int
    bump: bool
    breeze: bool
    stench: bool
    scream: bool
    glitter: bool
    reward: int
    done: bool

    def __init__(self, time_step: int, bump: bool, breeze: bool, stench: bool, scream: bool, glitter: bool, reward: int, done: bool):
        # ``reward`` is an int (e.g. -1, -1000, +1000), not a bool as previously annotated.
        self.time_step = time_step
        self.bump = bump
        self.breeze = breeze
        self.stench = stench
        self.scream = scream
        self.glitter = glitter
        self.reward = reward
        self.done = done

    def __str__(self):
        # Multi-line, human-readable dump used by the agents' logging; ``done`` is not printed.
        return f"\ntime_step: {self.time_step}\nbump: {self.bump}\nstench: {self.stench}\nbreeze: {self.breeze}\nscream: {self.scream}\nglitter: {self.glitter}\nreward: {self.reward}"

In [3]:
class Action(Enum):
    """The six actions an agent can pass to ``Environment.step``.

    Member values are fixed integers 0-5 so they can be stored or compared
    numerically.
    """
    LEFT = 0      # rotate a quarter turn counter-clockwise (Orientation.turn_left)
    RIGHT = 1     # rotate a quarter turn clockwise (Orientation.turn_right)
    FORWARD = 2   # move one square in the current heading
    GRAB = 3      # pick up the gold if it is on the current square
    SHOOT = 4     # fire the single arrow along the current heading
    CLIMB = 5     # leave the cave (only has an effect at (1, 1))

In [4]:
class Orientation(Enum):
    """Compass heading of the agent.

    Values increase clockwise (E=0, S=1, W=2, N=3), so turning is modular
    arithmetic on ``value``.
    """
    E = 0
    S = 1
    W = 2
    N = 3

    def symbol(self) -> str:
        """Return the arrow character used when drawing the board."""
        arrows = ">v<^"  # indexed by value: E, S, W, N
        return arrows[self.value]

    def _rotated(self, quarter_turns: int) -> 'Orientation':
        """Return the heading reached after ``quarter_turns`` clockwise quarter turns."""
        return Orientation((self.value + quarter_turns) % len(Orientation))

    def turn_right(self) -> 'Orientation':
        """Heading after one clockwise quarter turn (E -> S -> W -> N -> E)."""
        return self._rotated(1)

    def turn_left(self) -> 'Orientation':
        """Heading after one counter-clockwise quarter turn (E -> N -> W -> S -> E)."""
        return self._rotated(-1)

    def reverse(self) -> 'Orientation':
        """The opposite heading (half turn)."""
        return self._rotated(2)

    def __str__(self):
        # Member names are already the single compass letters.
        return self.name

In [5]:
class Verbosity(Enum):
    """How much an agent's ``run`` method prints."""
    NONE = 0   # only the final cumulative reward
    LOGS = 1   # additionally every percept and action
    ALL = 2    # additionally the rendered board

class NaiveAgent:
    """Baseline agent that ignores its percepts and acts uniformly at random."""
    verbosity: Verbosity

    def __init__(self, verbosity):
        self.verbosity = verbosity

    def choose_action(self):
        """Return one member of ``Action``, each with equal probability."""
        options = list(Action)
        return random.choice(options)

    def _show_state(self, env, percept):
        """Draw the board (ALL) and print the percept (LOGS or ALL)."""
        if self.verbosity == Verbosity.ALL:
            env.visualize()
        if self.verbosity != Verbosity.NONE:
            print('Percept:', percept)

    def run(self):
        """Play one episode and print the cumulative reward.

        The world uses pit probability 0.2, and climbing out requires the gold.
        """
        env = Environment()
        total_reward = 0
        percept = env.init(0.2, False)
        while not percept.done:
            self._show_state(env, percept)
            action = self.choose_action()
            if self.verbosity != Verbosity.NONE:
                print()
                print('Action:', action)
                print()
            percept = env.step(action)
            total_reward += percept.reward
        # Show the terminal state too.
        self._show_state(env, percept)
        print('Cumulative reward:', total_reward)

In [6]:
class Location:
    """A square on the 4x4 Wumpus grid, addressed by 1-based (x, y).

    x grows to the east (right) and y grows to the north (up); (1, 1) is the
    bottom-left corner. Instances compare and hash by value but are mutable
    (``forward`` and ``set_to`` change them in place). Do not mutate an
    instance while it is stored in a set or used as a dict key, or lookups
    will silently miss.
    """
    x: int
    y: int

    def __init__(self, x: int, y: int):
        self.x = x
        self.y = y

    def __str__(self):
        return f'({self.x}, {self.y})'

    def __eq__(self, other):
        # Returning NotImplemented (instead of raising AttributeError) lets
        # comparisons such as ``loc == None`` fall back to identity -> False.
        if not isinstance(other, Location):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def __hash__(self):
        return hash((self.x, self.y))

    def is_left_of(self, location: 'Location') -> bool:
        """True if self is directly west of ``location``."""
        return self.x == location.x - 1 and self.y == location.y

    def is_right_of(self, location: 'Location') -> bool:
        """True if self is directly east of ``location``."""
        return self.x == location.x + 1 and self.y == location.y

    def is_above(self, location: 'Location') -> bool:
        """True if self is directly north of ``location``."""
        return self.x == location.x and self.y == location.y + 1

    def is_below(self, location: 'Location') -> bool:
        """True if self is directly south of ``location``."""
        return self.x == location.x and self.y == location.y - 1

    def neighbours(self) -> 'tuple[Location, Location, Location, Location]':
        """Return the four orthogonal neighbours as a tuple (west, east, north, south).

        Off-grid coordinates (e.g. x == 0) are NOT filtered out. Callers only
        test membership against on-grid squares, so they never match.
        """
        return (Location(self.x - 1, self.y), Location(self.x + 1, self.y),
                Location(self.x, self.y + 1), Location(self.x, self.y - 1))

    def is_location(self, location: 'Location') -> bool:
        """True if ``location`` has the same coordinates as self (same as ``==``)."""
        return self.x == location.x and self.y == location.y

    def at_left_edge(self) -> bool:
        return self.x == 1

    def at_right_edge(self) -> bool:
        # Grid size is hard-coded to 4x4 throughout this file.
        return self.x == 4

    def at_top_edge(self) -> bool:
        # Grid size is hard-coded to 4x4 throughout this file.
        return self.y == 4

    def at_bottom_edge(self) -> bool:
        return self.y == 1

    def is_edge(self):
        """True if the square touches any border of the grid (corners included)."""
        return self.at_left_edge() or self.at_right_edge() or self.at_top_edge() or self.at_bottom_edge()

    def is_corner(self):
        """True for (1, 1), (1, 4), (4, 1) and (4, 4)."""
        return self.x in (1, 4) and self.y in (1, 4)

    def get_edge_count(self):
        """Return the number of on-grid orthogonal neighbours.

        The result is 2 for a corner, 3 for another border square and 4 for
        an interior square.
        """
        if self.is_corner():
            return 2
        if self.is_edge():
            return 3
        return 4

    def forward(self, orientation) -> bool:
        """Move one square in ``orientation`` in place, clamped to the grid.

        Returns True if the move would have left the grid (a bump). In that
        case the coordinate is clamped back so the location is unchanged.
        """
        if orientation == Orientation.N:
            self.y += 1
        elif orientation == Orientation.W:
            self.x -= 1
        elif orientation == Orientation.S:
            self.y -= 1
        elif orientation == Orientation.E:
            self.x += 1

        if self.x < 1:
            self.x = 1
            return True
        elif self.x > 4:
            self.x = 4
            return True

        if self.y < 1:
            self.y = 1
            return True
        elif self.y > 4:
            self.y = 4
            return True
        return False

    def set_to(self, location: 'Location'):
        """Copy the coordinates of ``location`` into self (in place)."""
        self.x = location.x
        self.y = location.y

    @staticmethod
    def from_linear(n: int) -> 'Location':
        """Convert a row-major index 0..15 to a Location.

        Index 0 is (1, 1), 3 is (4, 1), 4 is (1, 2) and 15 is (4, 4). This is
        the same layout as ``ProbAgent.get_index``.

        Raises:
            ValueError: if ``n`` is outside 0..15.
        """
        if not 0 <= n < 16:
            raise ValueError(f"linear index must be in 0..15, got {n}")
        row, column = divmod(n, 4)
        return Location(column + 1, row + 1)

    def to_linear(self) -> int:
        """Convert self to a row-major index 0..15 (inverse of ``from_linear``).

        Raises:
            ValueError: if self is not on the 4x4 grid.
        """
        if not (1 <= self.x <= 4 and 1 <= self.y <= 4):
            raise ValueError(f"location {self} is outside the 4x4 grid")
        return (self.x - 1) + 4 * (self.y - 1)

    @staticmethod
    def random() -> 'Location':
        """Return a uniformly random on-grid Location (may be (1, 1))."""
        x = random.randint(1, 4)
        y = random.randint(1, 4)
        return Location(x, y)

In [7]:
class Environment:
    """Simulator for one episode of the 4x4 Wumpus World.

    Call ``init`` to generate a random world and obtain the first Percept,
    then call ``step`` once per Action until ``Percept.done`` is True. The
    agent starts at (1, 1) facing east, holding one arrow.
    """
    wumpus_location: Location
    wumpus_alive: bool
    agent_location: Location
    agent_orientation: Orientation
    agent_has_arrow: bool
    agent_has_gold: bool
    game_over: bool
    gold_location: Location
    pit_locations: List[Location]
    time_step: int
    allow_climb_without_gold: bool

    def init(self, pit_prob: float, allow_climb_without_gold: bool):
        """Generate a fresh random world and return the initial Percept.

        Args:
            pit_prob: Independent probability that each square other than
                (1, 1) holds a pit.
            allow_climb_without_gold: If True, CLIMB at (1, 1) ends the
                episode even when the agent is not carrying the gold.

        Returns:
            The Percept for time step 0 (no bump, no scream, reward 0).
        """
        # Placement order (wumpus, gold, pits) fixes the sequence of random
        # draws, so a seeded ``random`` reproduces the same world.
        self.make_wumpus()
        self.make_gold()
        self.pit_locations = []
        self.make_pits(pit_prob)

        self.agent_location = Location(1, 1)
        self.agent_orientation = Orientation.E
        self.agent_has_arrow = True
        self.agent_has_gold = False

        self.game_over = False
        self.time_step = 0
        self.allow_climb_without_gold = allow_climb_without_gold
        return Percept(self.time_step, False, self.is_breeze(), self.is_stench(), False, self.is_glitter(), 0, self.game_over)

    def get_random_location(self):
        """Return a uniformly random square other than the start square (1, 1)."""
        while True:
            candidate = Location.random()
            if not (candidate.x == 1 and candidate.y == 1):
                return candidate

    def make_wumpus(self):
        """Place a living wumpus on a random non-start square."""
        self.wumpus_location = self.get_random_location()
        self.wumpus_alive = True

    def make_gold(self):
        """Place the gold on a random non-start square (it may share a square with a pit or the wumpus)."""
        self.gold_location = self.get_random_location()

    def make_pits(self, pit_prob: float):
        """Independently add a pit to each non-start square with probability ``pit_prob``."""
        for y in range(1, 5):
            for x in range(1, 5):
                if x == 1 and y == 1:
                    continue  # the start square never holds a pit
                # random.random() is in [0, 1), so ``<`` hits with probability
                # exactly pit_prob (``<=`` could add a pit even when pit_prob == 0).
                if random.random() < pit_prob:
                    self.pit_locations.append(Location(x, y))

    def is_pit_at(self, location: Location) -> bool:
        return location in self.pit_locations

    def is_pit_adjacent_to_agent(self) -> bool:
        return any(neighbor in self.pit_locations for neighbor in self.agent_location.neighbours())

    def is_wumpus_adjacent_to_agent(self) -> bool:
        return self.wumpus_location in self.agent_location.neighbours()

    def is_agent_at_hazard(self) -> bool:
        """True if the agent shares a square with the wumpus (alive or dead) or a pit."""
        # Compare by value: the previous identity check (``is``) was only True
        # when both attributes happened to reference the same Location object.
        return (self.agent_location == self.wumpus_location
                or self.agent_location in self.pit_locations)

    def is_wumpus_at(self, location: Location) -> bool:
        return self.wumpus_location == location

    def is_agent_at(self, location: Location) -> bool:
        return self.agent_location == location

    def is_gold_at(self, location: Location) -> bool:
        return self.gold_location == location

    def is_glitter(self) -> bool:
        """True if the gold is on the agent's square (also while it is being carried)."""
        return self.agent_location == self.gold_location

    def is_breeze(self) -> bool:
        """True if a pit is on or orthogonally adjacent to the agent's square."""
        return self.is_pit_at(self.agent_location) or self.is_pit_adjacent_to_agent()

    def is_stench(self) -> bool:
        """True if the wumpus (alive or dead) is on or orthogonally adjacent to the agent's square."""
        return self.agent_location == self.wumpus_location or self.is_wumpus_adjacent_to_agent()

    def wumpus_in_line_of_fire(self) -> bool:
        """True if the wumpus lies on the agent's square or straight ahead of it.

        An agent on the border facing the outer wall has an empty line of fire.
        """
        line_of_fire = []
        if self.agent_orientation == Orientation.N:
            if self.agent_location.y < 4:
                for y in range(self.agent_location.y, 5):
                    line_of_fire.append(Location(self.agent_location.x, y))
        elif self.agent_orientation == Orientation.S:
            if self.agent_location.y > 1:
                for y in range(self.agent_location.y, 0, -1):
                    line_of_fire.append(Location(self.agent_location.x, y))
        elif self.agent_orientation == Orientation.E:
            if self.agent_location.x < 4:
                for x in range(self.agent_location.x, 5):
                    line_of_fire.append(Location(x, self.agent_location.y))
        elif self.agent_orientation == Orientation.W:
            if self.agent_location.x > 1:
                for x in range(self.agent_location.x, 0, -1):
                    line_of_fire.append(Location(x, self.agent_location.y))

        return self.wumpus_location in line_of_fire

    def kill_attempt(self) -> bool:
        """Kill the wumpus if it is alive and in the line of fire; return True on a kill."""
        if self.wumpus_in_line_of_fire() and self.wumpus_alive:
            self.wumpus_alive = False
            return True
        return False

    def step(self, action: Action) -> Percept:
        """Apply ``action`` and return the resulting Percept.

        Rewards:
            -1 for every step,
            an extra -1000 for walking into a live wumpus or a pit,
            an extra -10 for shooting (only while the arrow is held),
            an extra +1000 for climbing out at (1, 1) with the gold.
        """
        bump = False
        scream = False
        reward = -1
        if action == Action.FORWARD:
            bump = self.agent_location.forward(self.agent_orientation)
            if self.agent_has_gold:
                # Carried gold travels with the agent: gold_location now refers to
                # the same (mutable) Location object as agent_location.
                self.gold_location = self.agent_location
            if self.is_wumpus_at(self.agent_location) and self.wumpus_alive:
                reward -= 1000
                self.game_over = True
            # ``not game_over`` prevents a double penalty for a wumpus sitting in a pit.
            if self.is_pit_at(self.agent_location) and not self.game_over:
                reward -= 1000
                self.game_over = True
        if action == Action.SHOOT and self.agent_has_arrow:
            reward -= 10
            scream = self.kill_attempt()
            self.agent_has_arrow = False
        if action == Action.CLIMB:
            if self.agent_location == Location(1, 1):
                if self.allow_climb_without_gold:
                    self.game_over = True
                if self.agent_has_gold:
                    self.game_over = True
                    reward += 1000
        if action == Action.LEFT:
            self.agent_orientation = self.agent_orientation.turn_left()
        if action == Action.RIGHT:
            self.agent_orientation = self.agent_orientation.turn_right()
        if action == Action.GRAB and self.is_gold_at(self.agent_location):
            self.agent_has_gold = True

        self.time_step += 1
        return Percept(self.time_step, bump, self.is_breeze(), self.is_stench(), scream, self.is_glitter(), reward, self.game_over)

    def visualize(self):
        """Print the board, top row (y = 4) first.

        Each cell shows four characters: agent arrow, 'P' for a pit,
        'W'/'w' for a live/dead wumpus, and 'G' for the gold.
        """
        for y in range(3, -1, -1):
            line = '|'
            for x in range(0, 4):
                loc = Location(x + 1, y + 1)
                cell_symbols = [' ', ' ', ' ', ' ']
                if self.is_agent_at(loc):
                    cell_symbols[0] = self.agent_orientation.symbol()
                if self.is_pit_at(loc):
                    cell_symbols[1] = 'P'
                if self.is_wumpus_at(loc):
                    cell_symbols[2] = 'W' if self.wumpus_alive else 'w'
                if self.is_gold_at(loc):
                    cell_symbols[3] = 'G'
                line += ''.join(cell_symbols)
                line += '|'
            print(line)

In [8]:
class Node:
    """A vertex of the agent's navigation graph: a grid square plus a heading.

    Equality and hashing are by value, so a freshly built ``Node`` can be used
    to look up a vertex already stored in a set or a networkx graph.
    """
    location: 'Location'
    orientation: 'Orientation'

    def __init__(self, location, orientation):
        self.location = location
        self.orientation = orientation

    def __eq__(self, other):
        # NotImplemented (rather than AttributeError) for foreign types lets
        # Python fall back to its default comparison, e.g. ``node == None`` -> False.
        if not isinstance(other, Node):
            return NotImplemented
        return self.location == other.location and self.orientation == other.orientation

    def __str__(self):
        return f"X: {self.location.x} Y: {self.location.y} Or: {self.orientation}"

    def __hash__(self):
        # Must stay consistent with __eq__: both use (location, orientation).
        return hash((self.location, self.orientation))

In [9]:
class NotSoNaiveAgent():
    """Random explorer that remembers the squares it visited and retraces them with the gold.

    While exploring it picks LEFT, RIGHT or FORWARD at random. Every square it
    stands on is added to ``graph``, a networkx DiGraph whose vertices are
    ``Node(location, orientation)`` and whose edges each correspond to one
    action: a quarter turn within a square, or a FORWARD move between two
    visited squares. On glitter it grabs the gold, runs A* over that graph to
    (1, 1) and then replays the resulting action list, ending with CLIMB.
    """
    verbosity: Verbosity
    has_gold: bool
    # Planned actions; run() pops from the front before making any other choice.
    actions_to_victory: list
        
    def __init__(self, verbosity):
        """Create an agent with an empty navigation graph.

        Args:
            verbosity: Verbosity level controlling logging in run().
        """
        self.has_gold = False
        self.verbosity = verbosity
        # Set of Nodes already added to ``graph`` (fast membership tests).
        self.nodes = set()
        self.actions_to_victory = []
        self.graph = networkx.DiGraph()
            
    def add_new_nodes(self, location, was_forward, orientation):
        """Add the square ``location`` to the navigation graph and connect it.

        Args:
            location: The agent's current square. run() passes a fresh copy of
                the environment's agent location.
            was_forward: True when the previous action was a FORWARD that did
                not bump, i.e. the agent just arrived from an adjacent square.
            orientation: The agent's current heading.

        On the first visit the four (location, heading) nodes are created and
        linked by LEFT/RIGHT rotation edges. When ``was_forward`` is True, the
        method also adds FORWARD edges between this square and the square just
        left, and between this square and every already-visited orthogonal
        neighbour, in both directions.

        NOTE(review): edge insertion order influences networkx neighbour
        iteration and therefore which of several equal-cost A* paths is found.
        """
        # Add all 4 orientations of the location as nodes of the graph (and of
        # ``self.nodes``) the first time the location is visited.
        newly_added_nodes = []
        for orient in Orientation:
            new_node = Node(location, orient)
            if new_node not in self.nodes:
                newly_added_nodes.append(new_node)
                self.nodes.add(new_node)
                self.graph.add_node(new_node, label=f"{new_node.location.x}, {new_node.location.y}, {new_node.orientation}")
                
        # Add two-way rotation edges (one LEFT/RIGHT action each) between the
        # headings of the same location.
        for new_node in newly_added_nodes:
            left_node = Node(location, new_node.orientation.turn_left())
            right_node = Node(location, new_node.orientation.turn_right())
            if left_node in self.nodes:
                self.graph.add_edge(new_node, left_node)
                self.graph.add_edge(left_node, new_node)
            if right_node in self.nodes:
                self.graph.add_edge(new_node, right_node)
                self.graph.add_edge(right_node, new_node)
        
        """If agent moved forward on the previous action, we are going to add a new edge
        between the previous node and the current one of the same orientation. ie. if agent
        moved from 1,1,E to 2,1,E.
        
        We are also adding a reversed path edge between the new location of reversed
        orientation and old location of reversed orientation such as if you moved from node
        1,1,E to 2,1,E we add a new edge from 2,1,W to 1,1,W
        """
        if was_forward:
            current_node = Node(location, orientation)
            # The square the agent came from is one step opposite to its heading.
            previous_location = None
            if orientation == Orientation.E and location.x-1 > 0:
                previous_location = Location(location.x-1, location.y)
            elif orientation == Orientation.N and location.y-1 > 0:
                previous_location = Location(location.x, location.y-1)
            elif orientation == Orientation.W and location.x+1 < 5:
                previous_location = Location(location.x+1, location.y)
            elif orientation == Orientation.S and location.y+1 < 5:
                previous_location = Location(location.x, location.y+1)
            if previous_location:
                previous_node = Node(previous_location, orientation)
                # Fetch the stored node equal to (location, orientation); it is
                # always present because all four headings were added above.
                new_node = None
                for node in self.nodes:
                    if node.location == location and node.orientation == orientation:
                        new_node = node
                        break
                if previous_node in self.nodes and new_node:
                    self.graph.add_edge(previous_node, new_node)

                # Reverse move: facing back the other way, FORWARD returns to the previous square.
                reversed_previous_node = None
                reversed_new_node = None
                if Node(new_node.location, new_node.orientation.reverse()) in self.nodes:
                    reversed_new_node = Node(new_node.location, new_node.orientation.reverse())
                if Node(previous_node.location, previous_node.orientation.reverse()) in self.nodes:
                    reversed_previous_node = Node(previous_node.location, previous_node.orientation.reverse())
                if reversed_previous_node and reversed_new_node:
                    self.graph.add_edge(reversed_new_node, reversed_previous_node)
                        
            # Connect this square with every already-visited orthogonal neighbour.
            # The orientation stored in each entry below is NOT a heading of the
            # agent: it tags which side of ``location`` the neighbour lies on
            # (E = x+1, N = y+1, W = x-1, S = y-1). Only neighbours whose node
            # exists (i.e. visited squares) get edges.
            neighbors = []
            if location.x + 1 < 5:
                neighbors.append(Node(Location(location.x+1, location.y), Orientation.E))
            if location.y + 1 < 5:
                neighbors.append(Node(Location(location.x, location.y+1), Orientation.N))
            if location.x - 1 > 0:
                neighbors.append(Node(Location(location.x-1, location.y), Orientation.W))
            if location.y - 1 > 0:
                neighbors.append(Node(Location(location.x, location.y-1), Orientation.S))
            for neighbor in neighbors:
                if neighbor in self.nodes:
                    # Facing the neighbour's side, FORWARD reaches it; from the
                    # neighbour, facing the opposite way, FORWARD comes back.
                    if neighbor.orientation == Orientation.E:
                        self.graph.add_edge(Node(location, Orientation.E), neighbor)
                        self.graph.add_edge(Node(neighbor.location, Orientation.W), Node(location, Orientation.W))
                    elif neighbor.orientation == Orientation.W:
                        self.graph.add_edge(Node(location, Orientation.W), neighbor)
                        self.graph.add_edge(Node(neighbor.location, Orientation.E), Node(location, Orientation.E))
                    elif neighbor.orientation == Orientation.N:
                        self.graph.add_edge(Node(location, Orientation.N), neighbor)
                        self.graph.add_edge(Node(neighbor.location, Orientation.S), Node(location, Orientation.S))
                    elif neighbor.orientation == Orientation.S:
                        self.graph.add_edge(Node(location, Orientation.S), neighbor)
                        self.graph.add_edge(Node(neighbor.location, Orientation.N), Node(location, Orientation.N))
            
        
    def draw_graph(self):
        """Plot the navigation graph with matplotlib; does nothing when verbosity is NONE.

        Uses a spring layout, so node positions are not aligned to the grid.
        """
        if (self.verbosity == Verbosity.NONE):
            return
        pos = networkx.spring_layout(self.graph)  # positions for all nodes
        networkx.draw(self.graph, pos, with_labels=True, labels={node: str(node) for node in self.graph.nodes}, node_color='skyblue', node_size=500, edge_color='gray')
        plt.show()
            
    def choose_action(self):
        """Exploration policy: LEFT, RIGHT or FORWARD uniformly at random (never GRAB/SHOOT/CLIMB)."""
        return random.choice([Action.LEFT, Action.RIGHT, Action.FORWARD])
    
    def find_actions_to_victory(self, current_node, nodes):
        """Translate an A* node path into the actions that follow it.

        Args:
            current_node: Node the agent occupies now.
            nodes: Sequence of Nodes (as returned by ``networkx.astar_path``).

        Replaces ``self.actions_to_victory``. A path entry equal to the node
        just reached is skipped (this drops the start node). A step to the
        left/right rotation of the current node becomes LEFT/RIGHT; any other
        step is assumed to be a FORWARD edge. As soon as a node at (1, 1) is
        reached, CLIMB is appended and the rest of the path is ignored. If the
        path never reaches (1, 1), no CLIMB is appended.
        """
        self.actions_to_victory = []
        for node in nodes:
            if node == current_node:
                continue
                
            if Node(current_node.location, current_node.orientation.turn_left()) == node:
                self.actions_to_victory.append(Action.LEFT)
            elif Node(current_node.location, current_node.orientation.turn_right()) == node:
                self.actions_to_victory.append(Action.RIGHT)
            else:
                self.actions_to_victory.append(Action.FORWARD)
                
            current_node = node
                
            if current_node.location == Location(1,1):
                self.actions_to_victory.append(Action.CLIMB)
                break
    
    def manhattan_distance(self, node1, node2):
        """A* heuristic: grid distance between two nodes' squares, ignoring heading.

        Every graph edge changes the square by at most one step, and edges are
        added without weights, so this never overestimates the remaining
        number of actions.
        """
        return abs(node1.location.x - node2.location.x) + abs(node1.location.y - node2.location.y)
            
    def run(self):
        """Play one episode and print the cumulative reward.

        The world uses pit probability 0.2, and climbing out requires the gold.
        Each turn: record the current square in the graph, then either pop the
        next planned action, GRAB (and plan the way home) on glitter, CLIMB
        when at (1, 1) with the gold, or explore at random.
        """
        env = Environment()
        cumulative_reward = 0
        percept = env.init(0.2, False)
        # Whether the previous action actually moved the agent to a new square.
        was_forward = False
        while not percept.done:
            action = None
            # Pass a copy: env.agent_location is mutated in place by later FORWARD steps.
            self.add_new_nodes(Location(env.agent_location.x, env.agent_location.y), was_forward, env.agent_orientation)
            if self.verbosity == Verbosity.ALL:
                env.visualize()
            if self.verbosity != Verbosity.NONE:
                print('Percept:', percept)
            if len(self.actions_to_victory):
                action = self.actions_to_victory.pop(0)
            else:
                if percept.glitter and not self.has_gold:
                    # Glitter means the gold is on this square, so the GRAB below succeeds.
                    action = Action.GRAB
                    self.has_gold = True
                    target_node = Node(Location(1, 1), Orientation.W)
                    # NOTE(review): current_node shares env.agent_location (a mutable
                    # Location); safe only because the path is computed right away.
                    # networkx.astar_path raises if no path exists; not handled here.
                    current_node = Node(env.agent_location, env.agent_orientation)
                    self.find_actions_to_victory(current_node, networkx.astar_path(self.graph, current_node, target_node, heuristic=self.manhattan_distance))
                elif self.has_gold and env.agent_location.is_location(Location(1,1)):
                    action = Action.CLIMB
                else:
                    action = self.choose_action()
            
            if self.verbosity != Verbosity.NONE:
                print()
                print('Action:', action)
                print()
            percept = env.step(action)
            # A FORWARD that bumps leaves the agent in place, so it is not a move.
            if action == Action.FORWARD and not percept.bump:
                was_forward = True
            else:
                was_forward = False
            cumulative_reward += percept.reward
        if self.verbosity == Verbosity.ALL:
            env.visualize()
            self.draw_graph()
        if self.verbosity != Verbosity.NONE:
            print('Percept:', percept)
        print('Cumulative reward:', cumulative_reward)
        

In [10]:
class Predicate():
    """Prior for a binary (False/True) random variable, stored as P(True) in ``p``."""

    def __init__(self, prob: float):
        self.p = prob

    def toList(self):
        """Return ``[P(False), P(True)]``."""
        p_false = 1 - self.p
        return [p_false, self.p]

    def toCategorical(self):
        """Wrap the distribution as a single-variable pomegranate Categorical."""
        probs = [self.toList()]
        return Categorical(probs)

class ProbAgent(NotSoNaiveAgent):
    pits = {}
    breezes = {}
    breeze_model = None
    heard_scream = False
    visited_locations = {}
    stenches = []
    wumpus_probs = []
    wumpus_alert = False
    stench_model = None
    pit_probs = []
    fully_explored_locations = {}
    risk_tolerance = 0.4
    
    def __init__(self, verbosity):
        """Create a probabilistic agent with uninformed beliefs about pits and the wumpus.

        Every square except (1, 1) starts with pit prior 0.2 and wumpus
        probability 1/15. The start square (1, 1) is known safe and visited.
        """
        # Re-create every container per instance; the class-level defaults are shared.
        self.pits = {}
        self.breezes = {}
        self.heard_scream = False
        self.visited_locations = {}
        self.stenches = []
        self.wumpus_probs = []
        self.wumpus_alert = False
        self.stench_model = None
        self.pit_probs = []
        self.fully_explored_locations = {}
        self.risk_tolerance = 0.4
        super().__init__(verbosity)
        # x is the outer loop, y the inner one: keys are inserted "11", "12", ..., "44".
        for x, y in product(range(1, 5), repeat=2):
            key = f"{x}{y}"
            is_start = x == 1 and y == 1
            self.pits[key] = Predicate(0.0 if is_start else 0.2).toCategorical()
            self.visited_locations[key] = is_start
            if is_start:
                self.fully_explored_locations[key] = False
            self.stenches.append(-1)  # -1 = not yet observed
            self.wumpus_probs.append(0 if is_start else 1./15)
    
    def generate_conditional_table(self, neighbor_count):
        """Build the conditional probability table of a breeze given its pit parents.

        Breeze is a deterministic OR of its parents: P(breeze) is 1 if any
        neighbouring square holds a pit, else 0.

        The table is nested ``neighbor_count`` levels deep, one axis of size 2
        per parent (index 0 = no pit, 1 = pit). The first parent is the
        outermost axis, matching ``itertools.product`` order. Each leaf is a
        ``[P(no breeze), P(breeze)]`` pair, and the whole table is wrapped in
        one outer list, the form passed to ``ConditionalCategorical`` by
        ``update_breeze``.

        Any ``neighbor_count >= 0`` is supported; on the 4x4 grid the values
        are 2 (corner), 3 (edge) or 4 (interior). For 0 parents the result is
        ``[[1.0, 0.0]]``.
        """
        # One leaf per assignment of the parents, in product order.
        leaves = [
            Predicate(1.0).toList() if any(combination) else Predicate(0.0).toList()
            for combination in product([False, True], repeat=neighbor_count)
        ]

        def nest(flat, depth):
            # Split the leaves in halves recursively: at each level the first
            # half is "this parent has no pit", the second half "has a pit".
            if depth == 0:
                return flat[0]
            half = len(flat) // 2
            return [nest(flat[:half], depth - 1), nest(flat[half:], depth - 1)]

        return [nest(leaves, neighbor_count)]

    def update_breeze(self, x, y, percept, edge_count):
        """Fold the breeze observation made at (x, y) into the pit beliefs.

        Args:
            x, y: 1-based square the agent is standing on (and survived).
            percept: Percept received there; only ``percept.breeze`` is read.
            edge_count: Number of on-grid neighbours of (x, y), i.e. the number
                of pit parents of this square's breeze variable (presumably
                from ``Location.get_edge_count``).
        """
        here = f"{x}{y}"
        # Standing here alive proves there is no pit on this square.
        self.pits[here] = Predicate(0.0).toCategorical()
        self.visited_locations[here] = True

        if not percept.breeze:
            # No breeze: every on-grid orthogonal neighbour is pit-free.
            candidates = (
                (x > 1, x - 1, y),
                (x < 4, x + 1, y),
                (y > 1, x, y - 1),
                (y < 4, x, y + 1),
            )
            for on_grid, px, py in candidates:
                if on_grid:
                    self.pits[f"{px}{py}"] = Predicate(0.0).toCategorical()

        # Reserve the key first (as before), then store a fresh OR-table CPT.
        self.breezes.setdefault(here, None)
        self.breezes[here] = ConditionalCategorical(self.generate_conditional_table(edge_count))
    
    def build_bayesian_network(self, current_x, current_y):
        """Build a small Bayesian network: the breeze at a square plus its pit parents.

        Parents are added in the order left, up, right, down (off-grid squares
        skipped). This is the same column order ``build_query`` produces; keep
        the two in sync (presumably pomegranate maps query columns to
        ``variables`` positionally). If no breeze CPT exists yet for the
        square, an empty network is returned.
        """
        variables = []
        edges = []

        breeze_key = f"{current_x}{current_y}"
        if breeze_key in self.breezes:
            breeze = self.breezes[breeze_key]

            parent_keys = []
            if current_x > 1:  # left
                parent_keys.append(f"{current_x - 1}{current_y}")
            if current_y < 4:  # up
                parent_keys.append(f"{current_x}{current_y + 1}")
            if current_x < 4:  # right
                parent_keys.append(f"{current_x + 1}{current_y}")
            if current_y > 1:  # down
                parent_keys.append(f"{current_x}{current_y - 1}")

            parents = [self.pits[key] for key in parent_keys if key in self.pits]
            # Pit variables first, breeze variable last; each pit points at the breeze.
            variables = parents + [breeze]
            edges = [(parent, breeze) for parent in parents]

        return BayesianNetwork(variables, edges)
        
    def build_query(self, x, y, percept):
        """Encode what is known around (x, y) as one row for the breeze network.

        Columns follow the parent order of ``build_bayesian_network`` (left,
        up, right, down, off-grid squares skipped), followed by the breeze
        observation. A pit column is 0 if that square's pit probability is
        exactly 0, 1 if it is exactly 1, and -1 otherwise (unknown; masked out
        by ``query_network``). The last column is 1 if ``percept.breeze`` else 0.
        """
        neighbour_table = (
            (x > 1, x - 1, y),  # left
            (y < 4, x, y + 1),  # up
            (x < 4, x + 1, y),  # right
            (y > 1, x, y - 1),  # down
        )
        row = []
        for on_grid, px, py in neighbour_table:
            if not on_grid:
                continue
            pit_prob = self.pits[f"{px}{py}"].probs[0][1]
            if pit_prob == 0:
                row.append(0)
            elif pit_prob == 1:
                row.append(1)
            else:
                row.append(-1)

        row.append(1 if percept.breeze else 0)
        return row

    
    def query_network(self, query):
        """Run ``self.breeze_model.predict_proba`` on a single query row.

        Entries below zero (the -1 "unknown" marker from ``build_query``) are
        masked out, presumably so pomegranate treats them as unobserved and
        infers them. Assumes ``breeze_model`` has been assigned (the class
        default is None).
        """
        batch = torch.tensor([query])
        observed = batch >= 0
        masked_batch = torch.masked.MaskedTensor(batch, mask=observed)
        return self.breeze_model.predict_proba(masked_batch)
    
    def get_index(self, x, y):
        """Map 1-based (x, y) to a row-major index 0..15, where x varies fastest.

        The row term is taken modulo 4, so out-of-range y values wrap around.
        """
        row = (y - 1) % 4
        column = x - 1
        return 4 * row + column
    
    def update_stenches(self, x, y, percept):
        """Update the wumpus-location probabilities after the agent arrives at (x, y).

        Nothing is updated once some cell already has probability exactly 1.
        Otherwise the current cell is set to 0 (the agent is alive here) and then:

        * stench perceived: ``self.stenches`` records 1 for this cell. The remaining
          probability is split equally over the unvisited neighbours that still had a
          non-zero probability, and every other cell is set to 0.
        * no stench: ``self.stenches`` records 0 for this cell. All neighbours are set
          to 0 and the surviving non-zero cells are renormalised to a uniform value.

        Args:
            x, y: Current 1-based grid coordinates.
            percept: Current percept; only ``percept.stench`` is read.
        """
        if any(prob == 1 for prob in self.wumpus_probs):
            return

        here = self.get_index(x, y)
        self.wumpus_probs[here] = 0

        # Neighbouring cells that exist on the grid, in left/right/down/up order.
        adjacent = []
        if x > 1:
            adjacent.append((x - 1, y))
        if x < 4:
            adjacent.append((x + 1, y))
        if y > 1:
            adjacent.append((x, y - 1))
        if y < 4:
            adjacent.append((x, y + 1))

        # Unvisited neighbours that may still hold the wumpus.
        candidates = [
            self.get_index(nx, ny)
            for nx, ny in adjacent
            if not self.visited_locations[f"{nx}{ny}"] and self.wumpus_probs[self.get_index(nx, ny)] > 0
        ]

        if not percept.stench:
            self.stenches[here] = 0
            # No stench here, so no neighbour can hold the wumpus.
            for nx, ny in adjacent:
                self.wumpus_probs[self.get_index(nx, ny)] = 0
            # Spread the probability uniformly over whatever cells remain possible.
            still_possible = [i for i in range(16) if self.wumpus_probs[i] > 0]
            for i in still_possible:
                self.wumpus_probs[i] = 1 / len(still_possible)
            return

        self.stenches[here] = 1
        # The wumpus must be next to this cell: split the mass over the candidates...
        for nx, ny in adjacent:
            cell = self.get_index(nx, ny)
            if self.wumpus_probs[cell] > 0:
                self.wumpus_probs[cell] = 1 / len(candidates)
        # ...and rule out every cell that is not a candidate.
        for i in range(16):
            if i not in candidates:
                self.wumpus_probs[i] = 0

    def get_stench_probs(self, x, y):
        """Return the wumpus probability of each on-grid neighbour of (x, y).

        Neighbours are listed in left, top, right, bottom order; those that would
        fall outside the 4x4 grid are omitted.
        """
        neighbours = (
            (x > 1, x - 1, y),  # left
            (y < 4, x, y + 1),  # top
            (x < 4, x + 1, y),  # right
            (y > 1, x, y - 1),  # bottom
        )
        return [
            self.wumpus_probs[self.get_index(nx, ny)]
            for on_grid, nx, ny in neighbours
            if on_grid
        ]
    
    def get_breeze_probs(self):
        """Return the pit probability of each neighbour from the last network query.

        Reads ``self.pit_probs`` (the result of ``query_network``). Every entry except
        the last contributes ``entry[0, 1]``, the probability of state 1 for the
        single queried sample. The dropped last entry is presumably the breeze
        observation appended last by ``build_query`` (TODO confirm).
        """
        neighbour_pit_probs = []
        for distribution in self.pit_probs[:-1]:
            neighbour_pit_probs.append(distribution[0, 1].item())
        return neighbour_pit_probs
        
    def get_neighbors_probs(self, x, y):
        """Combine pit and wumpus beliefs into one danger probability per neighbour.

        Neighbours are taken in left, top, right, bottom order, skipping cells off the
        4x4 grid. This is the order used by ``get_stench_probs``; the pit
        probabilities from ``get_breeze_probs`` are assumed to follow the same order
        (TODO confirm against ``build_bayesian_network``).

        Side effects:
            * sets ``self.wumpus_alert`` to True when a neighbour's wumpus probability
              is at least 0.5 (it is never cleared here);
            * marks a neighbour fully explored when its pit probability is exactly 1;
            * marks the current cell fully explored when no unvisited neighbour has a
              danger probability <= ``self.risk_tolerance``; otherwise records it as
              not fully explored, unless it already has an entry.

        Returns:
            ``p_wumpus + p_pit - p_wumpus * p_pit`` for each neighbour, i.e. the
            probability of either hazard when the two are treated as independent.
        """
        breezes = self.get_breeze_probs()
        stenches = self.get_stench_probs(x, y)

        neighbor_keys = []
        if x > 1:  # left
            neighbor_keys.append(f"{x-1}{y}")
        if y < 4:  # top
            neighbor_keys.append(f"{x}{y+1}")
        if x < 4:  # right
            neighbor_keys.append(f"{x+1}{y}")
        if y > 1:  # bottom
            neighbor_keys.append(f"{x}{y-1}")

        probs = []
        for i, key in enumerate(neighbor_keys):
            stench, breeze = stenches[i], breezes[i]
            if stench >= 0.5:
                self.wumpus_alert = True
            probs.append(stench + breeze - (stench * breeze))
            if breeze == 1:
                # A certain pit is never worth visiting, so treat it as done.
                self.fully_explored_locations[key] = True

        # The current cell stays open while some unvisited neighbour is within the
        # risk tolerance; each neighbour is paired with its own probability.
        has_open_neighbor = any(
            not self.visited_locations[key] and prob <= self.risk_tolerance
            for key, prob in zip(neighbor_keys, probs)
        )

        here = f"{x}{y}"
        if not has_open_neighbor:
            self.fully_explored_locations[here] = True
        else:
            self.fully_explored_locations.setdefault(here, False)

        return probs
    
    def where_to_go(self, x, y, probs, is_stuck = False):
        """Choose the adjacent cell to move to next.

        Args:
            x, y: Current 1-based grid coordinates.
            probs: Danger probability of each on-grid neighbour in left, top, right,
                bottom order, as returned by ``get_neighbors_probs``.
            is_stuck: When False, only unvisited neighbours are candidates. When
                True, neighbours not yet marked fully explored are candidates instead
                (a neighbour with no recorded exploration state counts as not
                explored).

        Returns:
            ``[nx, ny]`` of the candidate with the lowest probability strictly below
            ``self.risk_tolerance`` (the earliest neighbour wins ties), or ``[]`` if
            no candidate qualifies.

        Side effect:
            Marks the current cell fully explored when the chosen cell is its only
            remaining neighbour that is unvisited, not a certain wumpus, and not
            fully explored.
        """
        # Neighbour coordinates, aligned index-for-index with ``probs``.
        neighbor_coordinates = []
        if x > 1:
            neighbor_coordinates.append([x - 1, y])
        if y < 4:
            neighbor_coordinates.append([x, y + 1])
        if x < 4:
            neighbor_coordinates.append([x + 1, y])
        if y > 1:
            neighbor_coordinates.append([x, y - 1])

        location_to_go = []
        min_prob = self.risk_tolerance
        for i, (nx, ny) in enumerate(neighbor_coordinates):
            prob = probs[i]
            if prob < min_prob:
                key = f"{nx}{ny}"
                if is_stuck:
                    # .get(): a neighbour may not have an exploration entry yet.
                    is_candidate = not self.fully_explored_locations.get(key, False)
                else:
                    is_candidate = not self.visited_locations[key]
                if is_candidate:
                    min_prob = prob
                    location_to_go = [nx, ny]

        remaining_neighbors = [
            [nx, ny]
            for nx, ny in neighbor_coordinates
            if not (
                self.visited_locations[f"{nx}{ny}"]
                or self.wumpus_probs[self.get_index(nx, ny)] == 1
                or self.fully_explored_locations.get(f"{nx}{ny}", False)
            )
        ]

        # We are leaving to the last available location from the original location,
        # so mark the original as fully explored.
        if len(remaining_neighbors) == 1 and location_to_go in remaining_neighbors:
            self.fully_explored_locations[f"{x}{y}"] = True

        return location_to_go
    
    def how_to_go(self, x, y, orientation, destination):
        """Plan the turns plus one FORWARD that move the agent one cell toward ``destination``.

        Only the first differing axis is used (x before y), so ``destination`` is
        expected to be an adjacent cell.

        Args:
            x, y: Current grid coordinates.
            orientation: Current heading (``Orientation.N``/``E``/``S``/``W``).
            destination: ``[dx, dy]`` of the target cell.

        Returns:
            A list of ``Action`` values: no turn, ``[RIGHT]``, ``[RIGHT, RIGHT]`` or
            ``[LEFT]`` followed by ``FORWARD``. Returns ``[]`` if the destination is
            the current cell or ``orientation`` is not one of the four headings.
        """
        if destination[0] > x:
            heading = Orientation.E
        elif destination[0] < x:
            heading = Orientation.W
        elif destination[1] > y:
            heading = Orientation.N
        elif destination[1] < y:
            heading = Orientation.S
        else:
            return []

        # Compass headings in clockwise order: one RIGHT turn advances one step.
        clockwise = [Orientation.N, Orientation.E, Orientation.S, Orientation.W]
        if orientation not in clockwise:
            return []

        quarter_turns = (clockwise.index(heading) - clockwise.index(orientation)) % 4
        turns = ([], [Action.RIGHT], [Action.RIGHT, Action.RIGHT], [Action.LEFT])[quarter_turns]
        return turns + [Action.FORWARD]
    

    def to_shoot_wumpus(self, x, y, orientation):
        """Plan the turns and SHOOT aimed at an adjacent cell that likely holds the wumpus.

        Adjacent cells are checked clockwise, starting with the one the agent faces.
        The first on-grid cell whose ``self.wumpus_probs`` entry is at least 0.5 is
        targeted.

        Args:
            x, y: Current grid coordinates.
            orientation: Current heading (``Orientation.N``/``E``/``S``/``W``).

        Returns:
            No turn, ``[RIGHT]``, ``[RIGHT, RIGHT]`` or ``[LEFT]`` followed by
            ``SHOOT``, or ``[]`` when no adjacent cell qualifies or ``orientation``
            is not one of the four headings.
        """
        clockwise = [Orientation.N, Orientation.E, Orientation.S, Orientation.W]
        if orientation not in clockwise:
            return []

        # For each heading: (is the faced cell on the grid?, its x, its y).
        faced_cell = {
            Orientation.N: (y < 4, x, y + 1),
            Orientation.E: (x < 4, x + 1, y),
            Orientation.S: (y > 1, x, y - 1),
            Orientation.W: (x > 1, x - 1, y),
        }
        # Turns needed to face the heading k quarter-turns clockwise from now.
        turns_for_offset = ([], [Action.RIGHT], [Action.RIGHT, Action.RIGHT], [Action.LEFT])

        start = clockwise.index(orientation)
        for offset, turns in enumerate(turns_for_offset):
            on_grid, target_x, target_y = faced_cell[clockwise[(start + offset) % 4]]
            if on_grid and self.wumpus_probs[self.get_index(target_x, target_y)] >= 0.5:
                return turns + [Action.SHOOT]
        return []
    

    def run(self):
        """Play one episode in a freshly initialised ``Environment``.

        Each iteration:
        1. folds the outcome of a previous SHOOT into the wumpus beliefs;
        2. refreshes the pit (Bayesian network) and wumpus probabilities from the
           current percept;
        3. extends the navigation graph;
        4. picks one action: replay a queued plan, GRAB the gold, CLIMB out at
           (1, 1), shoot at a likely wumpus, or move toward the safest neighbouring
           cell (falling back to a path home when nothing qualifies).

        Returns:
            The sum of ``percept.reward`` over every step taken.
        """
        env = Environment()
        cumulative_reward = 0
        percept = env.init(0.2, True)
        was_forward = False
        was_shoot = False
        while not percept.done:
            if was_shoot:
                if percept.scream:
                    # The wumpus is dead: no cell can hold it any more.
                    self.wumpus_alert = False
                    for i in range(len(self.wumpus_probs)):
                        self.wumpus_probs[i] = 0
                else:
                    # Missed: the cell directly in front of the agent has no wumpus.
                    if env.agent_orientation == Orientation.W and env.agent_location.x > 1:
                        self.wumpus_probs[self.get_index(env.agent_location.x-1, env.agent_location.y)] = 0
                    elif env.agent_orientation == Orientation.N and env.agent_location.y < 4:
                        self.wumpus_probs[self.get_index(env.agent_location.x, env.agent_location.y+1)] = 0
                    elif env.agent_orientation == Orientation.E and env.agent_location.x < 4:
                        self.wumpus_probs[self.get_index(env.agent_location.x+1, env.agent_location.y)] = 0
                    elif env.agent_orientation == Orientation.S and env.agent_location.y > 1:
                        self.wumpus_probs[self.get_index(env.agent_location.x, env.agent_location.y-1)] = 0

                    # Every cell that is still possible is now treated as certain.
                    for i in range(len(self.wumpus_probs)):
                        if self.wumpus_probs[i] != 0:
                            self.wumpus_probs[i] = 1

                # The shot changed what is known about the neighbours, so re-open
                # the current cell for exploration.
                here = f"{env.agent_location.x}{env.agent_location.y}"
                if here in self.fully_explored_locations.keys() and self.fully_explored_locations[here]:
                    self.fully_explored_locations[here] = False

            # Pits-Breezes
            self.update_breeze(env.agent_location.x, env.agent_location.y, percept, env.agent_location.get_edge_count())
            self.breeze_model = self.build_bayesian_network(env.agent_location.x, env.agent_location.y)
            self.pit_probs = self.query_network(self.build_query(env.agent_location.x, env.agent_location.y, percept))
            if self.verbosity != Verbosity.NONE:
                print("Breeze Probs: ", self.pit_probs)

            # Wumpus-Stenches
            self.update_stenches(env.agent_location.x, env.agent_location.y, percept)
            if self.verbosity != Verbosity.NONE:
                print("Stench Probs: ", self.wumpus_probs)

            # Called here for its side effects (wumpus_alert, fully_explored_locations).
            self.get_neighbors_probs(env.agent_location.x, env.agent_location.y)

            action = None
            self.add_new_nodes(Location(env.agent_location.x, env.agent_location.y), was_forward, env.agent_orientation)
            if self.verbosity == Verbosity.ALL:
                env.visualize()
            if self.verbosity != Verbosity.NONE:
                print('Percept:', percept)
            if self.actions_to_victory:
                action = self.actions_to_victory.pop(0)
            else:
                if percept.glitter and not self.has_gold:
                    action = Action.GRAB
                    self.has_gold = True
                    target_node = Node(Location(1, 1), Orientation.W)
                    current_node = Node(env.agent_location, env.agent_orientation)
                    self.find_actions_to_victory(current_node, networkx.astar_path(self.graph, current_node, target_node, heuristic=self.manhattan_distance))
                elif self.has_gold and env.agent_location.is_location(Location(1,1)):
                    action = Action.CLIMB
                else:
                    destination = self.where_to_go(env.agent_location.x, env.agent_location.y, self.get_neighbors_probs(env.agent_location.x, env.agent_location.y))
                    if not destination:
                        destination = self.where_to_go(env.agent_location.x, env.agent_location.y, self.get_neighbors_probs(env.agent_location.x, env.agent_location.y), True)

                    # wumpus_alert is only cleared by a scream, so it may have been
                    # raised at an earlier cell. to_shoot_wumpus can then find no
                    # target here and return []; fall back to moving instead of
                    # popping from an empty plan.
                    # NOTE(review): env.wumpus_alive / env.agent_has_arrow read
                    # environment state directly rather than from percepts.
                    shot_plan = []
                    if self.wumpus_alert and env.wumpus_alive and env.agent_has_arrow:
                        shot_plan = self.to_shoot_wumpus(env.agent_location.x, env.agent_location.y, env.agent_orientation)

                    if shot_plan:
                        self.actions_to_victory = shot_plan
                        action = self.actions_to_victory.pop(0)
                    else:
                        if not destination:
                            target_node = Node(Location(1, 1), Orientation.W)
                            current_node = Node(env.agent_location, env.agent_orientation)
                            self.find_actions_to_victory(current_node, networkx.astar_path(self.graph, current_node, target_node, heuristic=self.manhattan_distance))
                        else:
                            self.actions_to_victory = self.how_to_go(env.agent_location.x, env.agent_location.y, env.agent_orientation, destination)

                        if not self.actions_to_victory:
                            action = Action.CLIMB
                        else:
                            action = self.actions_to_victory.pop(0)

            if self.verbosity != Verbosity.NONE:
                print()
                print('Action:', action)
                print()
            percept = env.step(action)
            # A FORWARD that bumped into a wall did not change the location.
            was_forward = action == Action.FORWARD and not percept.bump
            was_shoot = action == Action.SHOOT
            cumulative_reward += percept.reward
        if self.verbosity == Verbosity.ALL:
            env.visualize()
            self.draw_graph()
        if self.verbosity != Verbosity.NONE:
            print('Percept:', percept)
        print('Cumulative reward:', cumulative_reward)
        return cumulative_reward

In [11]:
# Verbosity has 3 levels: NONE shows only the cumulative score, LOGS also shows the
# percept logs, and ALL shows both the visualization and the logs.
EPISODES = 1000  # number of independent games averaged below
total_score = 0
# Uncomment to watch a single fully visualized game:
# ProbAgent(Verbosity.ALL).run()
for i in range(EPISODES):
    total_score += ProbAgent(Verbosity.NONE).run()
print("GRAND TOTAL AVERAGE: ", total_score / EPISODES)



Cumulative reward: -13
Cumulative reward: 966
Cumulative reward: 970
Cumulative reward: -13
Cumulative reward: -22
Cumulative reward: 981
Cumulative reward: -40
Cumulative reward: -2
Cumulative reward: 970
Cumulative reward: -11
Cumulative reward: -2
Cumulative reward: -13
Cumulative reward: 993
Cumulative reward: -2
Cumulative reward: -2
Cumulative reward: -2
Cumulative reward: -2
Cumulative reward: -2
Cumulative reward: -19
Cumulative reward: -2
Cumulative reward: -13
Cumulative reward: -11
Cumulative reward: 955
Cumulative reward: -22
Cumulative reward: -2
Cumulative reward: -2
Cumulative reward: -19
Cumulative reward: -2
Cumulative reward: -2
Cumulative reward: -17
Cumulative reward: 993
Cumulative reward: -34
Cumulative reward: 993
Cumulative reward: -13
Cumulative reward: -19
Cumulative reward: -2
Cumulative reward: 966
Cumulative reward: 991
Cumulative reward: -19
Cumulative reward: -22
Cumulative reward: -2
Cumulative reward: 991
Cumulative reward: 988
Cumulative reward: 959
Cu