# Environment 

In [14]:
import gym
from gym import Env, spaces
import numpy as np
import pygame
import random
import time
import math
# import folium  # Disabled - not needed for black background
# from selenium import webdriver  # Disabled - not needed for black background
# from PIL import Image  # Disabled - not needed for black background
import os
import copy
from typing import List

In [15]:
class Ship:
    """A ship in the naval environment.

    Holds the ship's position, movement speed, health and weapon state
    (torpedoes, decoy missiles, helicopter count). ``environment`` is kept as
    ``self.env`` and is queried by :meth:`move_ship_to_coordinates` for
    collision / line-of-sight checks against the other ships.
    """

    def __init__(self, environment, ship_id: int, x: int = 0, y: int = 0, speed=5, screen_width: int = 600, screen_height: int = 400,
                 ship_type: str = 'ship', firing_range: int = 100, ship_health: int = 100,
                 reload_delay: float = 0.5, target_delay: float = 0.2, helicop_count: int = 0, torpedo_count: int = 100,
                 torpedo_fire_speed: float = 2.0, torpedo_damage: int = 1, decoyM_count: int = 0,
                 decoyM_speed: float = 4.0, decoyM_blast_range: float = 2.0):

        self.ship_id = ship_id
        self.x = x
        self.y = y
        self.speed = speed
        # Playable area; set_position clamps to [0, width - 10] x [0, height - 10]
        self.width = screen_width
        self.height = screen_height
        self.ship_type = ship_type
        self.ship_health = ship_health
        self.firing_range = firing_range
        # Compared against time.time() differences by the firing code:
        # minimum seconds between shots, and seconds a lock must be held
        self.reload_delay = reload_delay
        self.target_delay = target_delay

        # Torpedo state
        self.torpedo_fire_speed = torpedo_fire_speed
        self.torpedo_count = torpedo_count
        self.torpedo_damage = torpedo_damage
        self.torpedoes = []  # Active (in-flight) torpedoes
        self.last_fire_time = 0
        self.target_lock_time = 0  # 0 means "no lock in progress"

        # Decoy missile state
        self.decoyM_count = decoyM_count
        self.decoy_missile = []  # Active decoy missiles
        self.decoyM_speed = decoyM_speed
        self.decoyM_blast_range = decoyM_blast_range
        self.last_decoy_fire_time = 0
        self.decoy_target_lock_time = 0  # 0 means "no lock in progress"

        self.helicop_count = helicop_count
        self.env = environment

    def set_position(self, x: float, y: float) -> bool:
        """Move the ship to (x, y), clamped to the playable area.

        Returns:
            True if (x, y) was already inside the bounds, False if either
            coordinate had to be clamped.
        """
        clamped_x = np.clip(x, 0, self.width - 10)
        clamped_y = np.clip(y, 0, self.height - 10)

        self.x, self.y = clamped_x, clamped_y
        # bool() so callers get a plain Python bool rather than numpy.bool_
        return bool(x == clamped_x and y == clamped_y)

    def move_ship_to_direction(self, heading) -> None:
        """Advance the ship one step of ``speed`` along ``heading`` (degrees).

        The resulting position is snapped to whole units and clamped to the
        playable area. Rounding (instead of ``int()`` truncation) keeps the snap
        symmetric: truncation floors positive coordinates, so a -0.87 step used
        to move the ship by -1 while a +0.87 step did not move it at all.
        """
        heading_rad = np.radians(heading)
        self.x += self.speed * np.cos(heading_rad)
        self.y += self.speed * np.sin(heading_rad)

        self.set_position(int(round(float(self.x))), int(round(float(self.y))))

    def move_ship_to_coordinates(self, target_position, threshold=0.01, angle_increment=45, max_angle_adjustment=90):
        """Take one step of at most ``speed`` towards ``target_position``.

        If ``self.env.check_for_collisions_while_ship_moves`` reports a ship in
        the way, headings rotated from the direct heading by 0, ±angle_increment,
        ... up to ±max_angle_adjustment degrees are tried in order. The first step
        whose path ``self.env.check_if_blocking_los`` does not report as blocked
        is taken.

        Returns:
            True if the ship was within ``threshold`` of the target (it is snapped
            onto it), or if the direct step was clamped at the screen boundary.
            False if further steps are needed (including after an avoidance step).
        """
        current_position = self.get_position()
        direction_vector = target_position - current_position
        distance = np.linalg.norm(direction_vector)

        # Close enough: snap onto the target and report completion
        if distance <= threshold:
            self.set_position(target_position[0], target_position[1])
            return True

        # Never overshoot the target in a single step
        move_distance = min(self.speed, distance)

        ship, ship_in_way = self.env.check_for_collisions_while_ship_moves(self, self.env.ships, target_position)

        if ship_in_way:
            current_movement_angle = np.arctan2(direction_vector[1], direction_vector[0]) * (180 / np.pi)
            blocker_position = ship.get_position()

            for adjustment in range(0, max_angle_adjustment + angle_increment, angle_increment):
                # A zero adjustment is the direct heading; trying it for both signs is redundant
                candidates = (0,) if adjustment == 0 else (-adjustment, adjustment)
                for angle_adjustment in candidates:
                    new_heading = (current_movement_angle + angle_adjustment) % 360
                    new_direction_vector = np.array([
                        np.cos(np.radians(new_heading)),
                        np.sin(np.radians(new_heading))
                    ])
                    new_position = current_position + new_direction_vector * move_distance

                    if not self.env.check_if_blocking_los(current_position, new_position, blocker_position):
                        self.set_position(new_position[0], new_position[1])
                        return False  # Avoiding a collision; target not reached yet

        # No blocker (or no clear detour found): step straight towards the target
        new_position = current_position + (direction_vector / distance) * move_distance

        # set_position returns False when clamped, i.e. the boundary was reached
        return not self.set_position(new_position[0], new_position[1])

    def get_position(self) -> np.ndarray:
        """Return the current position of the ship as a numpy array ``[x, y]``."""
        return np.array([self.x, self.y])

    def take_damage(self, damage: int):
        """Reduce health by ``damage`` (not below 0).

        Returns:
            (reward, done): reward is 100 if the ship was destroyed, 50 otherwise;
            done is True when health has reached 0.
        """
        self.ship_health = max(0, self.ship_health - damage)
        reward = 100 if self.ship_health == 0 else 50
        done = self.ship_health == 0
        return reward, done

    def target_in_range(self, target_pos, threshold=5) -> bool:
        """Return whether ``target_pos`` is within ``firing_range + threshold``."""
        distance_to_target = np.linalg.norm(target_pos - self.get_position())
        return distance_to_target <= self.firing_range + threshold

    def __repr__(self) -> str:
        info = (
            f"Ship(id={self.ship_id}, type={self.ship_type}, x={self.x}, y={self.y}, "
            f"health={self.ship_health}, speed={self.speed}, firing_range={self.firing_range}, "
            f"torpedo_count={self.torpedo_count}, torpedo_speed={self.torpedo_fire_speed}, torpedo_damage={self.torpedo_damage}, "
            f"active_torpedoes={len(self.torpedoes)}, "
            f"decoy_missiles={self.decoyM_count}, "
            f"helicop_count={self.helicop_count}"
        )
        return info


In [16]:
class Torpedo:
    """A straight-running torpedo fired by ``source`` at ``target``.

    ``direction`` is an angle in radians (it is passed directly to cos/sin).
    """

    def __init__(self, torpedo_id, x, y, speed, damage, direction, source, target, screen_width, screen_height):
        """
        Initialize a torpedo with position, speed, damage, heading, source, and target.
        """
        self.id = torpedo_id
        self.x = x
        self.y = y
        self.speed = speed
        self.damage = damage
        self.direction = direction
        self.source = source  # The ship that launched the torpedo
        self.target = target  # Target ship
        self.screen_width = screen_width
        self.screen_height = screen_height

        self.target_hit = False

    def move(self):
        """Advance the torpedo by ``speed`` along ``direction``."""
        self.x += self.speed * np.cos(self.direction)
        self.y += self.speed * np.sin(self.direction)

    def within_bounds(self):
        """Return True while the torpedo is inside [0, width] x [0, height]."""
        return 0 <= self.x <= self.screen_width and 0 <= self.y <= self.screen_height

    def hit_target(self, threshold=10):
        """
        Check whether the torpedo is within ``threshold`` of its target.

        On a hit, applies ``target.take_damage(damage)``, sets ``target_hit`` and
        returns that call's (reward, done). Otherwise returns (0, False).
        """
        distance = np.linalg.norm(np.array([self.x, self.y]) - self.target.get_position())
        if distance < threshold:
            reward, done = self.target.take_damage(self.damage)
            self.target_hit = True
            return reward, done
        return 0, False

    def check_collision(self, ships, threshold=10):
        """
        Return True if a ship other than the target and the launching ship is
        ahead of the torpedo (positive projection on its heading) and closer
        than ``threshold``.

        The launching ship is excluded: it starts at the torpedo's position and,
        if it moves along the firing direction faster than the torpedo, would
        otherwise be "ahead" of its own torpedo and destroy it.
        """
        # Position and heading do not change during the scan
        torpedo_position = np.array([self.x, self.y])
        direction_vector = np.array([np.cos(self.direction), np.sin(self.direction)])

        for ship in ships:
            if ship is self.target or ship is self.source:
                continue

            vector_to_ship = np.asarray(ship.get_position()) - torpedo_position

            # Only ships in front of the torpedo can block it
            if np.dot(vector_to_ship, direction_vector) > 0:
                if np.linalg.norm(vector_to_ship) < threshold:
                    return True
        return False

    def __repr__(self):
        """
        String representation of the torpedo for debugging purposes.
        """
        return (f"Torpedo(id={self.id}, x={self.x:.2f}, y={self.y:.2f}, speed={self.speed}, "
                f"damage={self.damage}, source={self.source}, target={self.target})")


In [17]:
class DefenseSystem():
    """Defensive behaviour of the defender ships around the HVU (High-Value Unit).

    Once any defender has the attacker within firing range, the positions of the
    attacker and the HVU are recorded. The defenders then move into a formation,
    and the whole unit (defenders + HVU) retreats directly away from the recorded
    attacker position until the HVU is ``safe_distance`` away from it.
    """

    # Formation names accepted by move_defenders_in_formation -> generating method
    _FORMATIONS = {
        'circle': 'circular_formation',
        'triangle': 'triangular_formation',
        'line': 'line_formation',
        'wedge': 'wedge_formation',
        'semicircle': 'half_circular_formation',
    }

    def __init__(self, env):
        """Create the defense system for ``env`` (the NavalShipEnv instance)."""
        self.env = env

        # Defense system state
        self.defense_active = False
        self.formation_done = False

        # Snapshots taken when the defense is triggered, and the retreat heading
        self.att_revealed_pos = None
        self.hvu_revealed_pos = None
        self.exit_direction = None

    def move_defenders_in_circle(self, circle_radius=100, speed=0.01):
        """Orbit the defenders around the HVU; ``speed`` is radians per step."""
        center_x, center_y = self.env.hvu.get_position()

        for i, ship in enumerate(self.env.defender_ships):
            # Each ship keeps an even angular offset; the whole ring rotates with step count
            angle = (self.env.info['step count'] * speed + i * 2 * np.pi / self.env.num_defenders) % (2 * np.pi)

            new_x = center_x + circle_radius * np.cos(angle)
            new_y = center_y + circle_radius * np.sin(angle)
            new_position = np.array([new_x, new_y])

            if self.env.info['step count'] == 0:
                # First step: place ships directly on the ring
                ship.set_position(new_x, new_y)
            else:
                # Afterwards: steer towards the moving ring position
                ship.move_ship_to_coordinates(new_position)

    def set_defenders_in_circle(self, circle_radius=100):
        """Place defender ships evenly on a fixed circle around the HVU (teleport, no motion)."""
        center_x, center_y = self.env.hvu.get_position()

        num_defenders = len(self.env.defender_ships)
        angle_step = 2 * np.pi / num_defenders

        for i, ship in enumerate(self.env.defender_ships):
            angle = i * angle_step
            new_x = center_x + circle_radius * np.cos(angle)
            new_y = center_y + circle_radius * np.sin(angle)
            ship.set_position(int(new_x), int(new_y))

    def check_for_defense_activation(self):
        """
        Activate the defense if any defender has the attacker in range, recording
        the attacker's and HVU's positions at that moment.
        """
        in_range_defenders = self.attacker_within_defender_range()

        if in_range_defenders:
            self.att_revealed_pos = self.env.attacker_ship.get_position()
            self.hvu_revealed_pos = self.env.hvu.get_position()
            self.defense_active = True

    def attacker_within_defender_range(self):
        """Return the defender ships whose firing range contains the attacker."""
        attacker_pos = self.env.attacker_ship.get_position()
        return [defender for defender in self.env.defender_ships if defender.target_in_range(attacker_pos)]

    def handle_defense_mechanism(self, formation_type):
        """
        Run one step of the defense: first form up, then retreat the whole unit
        away from the recorded attacker position. Deactivates once safe.
        """
        if not self.formation_done:
            self.formation_done = self.move_defenders_in_formation(formation_type)

        if self.formation_done:
            self.exit_direction = self.find_escape_direction(self.hvu_revealed_pos, self.att_revealed_pos)

            is_unit_safe = self.move_unit_away_from_attacker(self.exit_direction)

            if is_unit_safe:
                self.defense_active = False
                self.formation_done = False

    def move_defenders_in_formation(self, formation_type):
        """
        Move defenders one step towards the named formation.

        Args:
            formation_type: one of 'circle', 'triangle', 'line', 'wedge', 'semicircle'.

        Returns:
            True when the formation is complete (always True on step 0, where the
            ships are placed directly), otherwise the result of make_formation.

        Raises:
            ValueError: if ``formation_type`` is not a known formation.
        """
        try:
            method_name = self._FORMATIONS[formation_type]
        except KeyError:
            raise ValueError(
                f"Unknown formation_type {formation_type!r}; expected one of {sorted(self._FORMATIONS)}"
            ) from None
        target_positions = getattr(self, method_name)()

        if self.env.info['step count'] == 0:
            # First step: place ships directly on their slots
            for i, ship in enumerate(self.env.defender_ships):
                ship.set_position(target_positions[i][0], target_positions[i][1])
            return True
        return self.make_formation(self.env.defender_ships, target_positions)

    def half_circular_formation(self, circle_radius=100, start_angle_deg=90):
        """Spread defenders evenly over a 180° arc around the HVU, starting at ``start_angle_deg``."""
        center_x, center_y = self.env.hvu.get_position()
        num_defenders = len(self.env.defender_ships)

        start_angle_rad = np.radians(start_angle_deg)
        end_angle_rad = start_angle_rad + np.pi  # 180° span

        if num_defenders > 1:
            angle_step = (end_angle_rad - start_angle_rad) / (num_defenders - 1)
        else:
            angle_step = 0  # A single ship sits at the start angle

        return [
            (
                center_x + circle_radius * np.cos(start_angle_rad + i * angle_step),
                center_y + circle_radius * np.sin(start_angle_rad + i * angle_step)
            )
            for i in range(num_defenders)
        ]

    def circular_formation(self, radius=100):
        """Spread defenders evenly on a full circle of ``radius`` around the HVU, starting at angle 0."""
        center = self.env.hvu.get_position()
        num_defenders = len(self.env.defender_ships)

        angle_step = 2 * np.pi / num_defenders
        return [
            (
                center[0] + radius * np.cos(i * angle_step),
                center[1] + radius * np.sin(i * angle_step)
            )
            for i in range(num_defenders)
        ]

    def triangular_formation(self, radius_increment=50, initial_angle=0):
        """
        Concentric triangles around the HVU: ring k (k = 1, 2, ...) has radius
        k * radius_increment and 3 slots 120° apart, rotated by a further 180° per ring.

        Always returns 3 * ceil(n / 3) positions for n defenders, so there may
        be more slots than ships.
        """
        positions = []
        angle_increment = 360 / 3  # 3 ships per triangle, 120 degrees apart
        num_circles = math.ceil(len(self.env.defender_ships) / 3)
        hvu_position = self.env.hvu.get_position()
        for circle in range(1, num_circles + 1):
            radius = circle * radius_increment

            # Alternate the triangle orientation: +180 degrees per ring
            rotation_angle = initial_angle + (circle - 1) * 180

            for i in range(3):
                angle_radians = np.radians(i * angle_increment + rotation_angle)
                x = hvu_position[0] + radius * np.cos(angle_radians)
                y = hvu_position[1] + radius * np.sin(angle_radians)
                positions.append((x, y))

        return positions

    def line_formation(self, line_length=300, distance_from_hvu=30):
        """
        Line of defenders perpendicular to the HVU->attacker direction, centred
        ``distance_from_hvu`` from the revealed HVU position towards the attacker.
        """
        hvu_to_attacker_vector = (self.att_revealed_pos - self.hvu_revealed_pos).astype(float)
        hvu_to_attacker_vector /= np.linalg.norm(hvu_to_attacker_vector)

        # 2D perpendicular of (x, y) is (-y, x)
        perp_vector = np.array([-hvu_to_attacker_vector[1], hvu_to_attacker_vector[0]])

        center_of_line = self.hvu_revealed_pos + hvu_to_attacker_vector * distance_from_hvu

        return self.find_points_along_vector(center_of_line, perp_vector, line_length)

    def find_points_along_vector(self, center, direction_vector, line_length):
        """
        Return one point per defender, evenly spaced along ``direction_vector``
        and centred on ``center``; the spacing is line_length / (n - 1).
        """
        total_points = len(self.env.defender_ships)

        direction_vector = direction_vector / np.linalg.norm(direction_vector)

        if total_points == 1:
            spacing = 0
        else:
            spacing = line_length / (total_points - 1)

        points = []
        half_points = total_points // 2

        if total_points % 2 == 1:
            # Odd count: the centre itself is one of the points
            for i in range(-half_points, half_points + 1):
                offset = i * spacing
                points.append((center[0] + offset * direction_vector[0],
                               center[1] + offset * direction_vector[1]))
        else:
            # Even count: shift by half a slot so points straddle the centre
            for i in range(-half_points, half_points):
                offset = (i + 0.5) * spacing
                points.append((center[0] + offset * direction_vector[0],
                               center[1] + offset * direction_vector[1]))

        return points

    def wedge_formation(self, spread=300, distance_from_hvu=40):
        """
        V-shaped formation whose tip is ``distance_from_hvu`` from the revealed
        HVU position towards the attacker, with the arms trailing back towards the HVU.
        """
        hvu_to_attacker_vector = (self.att_revealed_pos - self.hvu_revealed_pos).astype(float)
        hvu_to_attacker_vector /= np.linalg.norm(hvu_to_attacker_vector)

        perp_vector = np.array([-hvu_to_attacker_vector[1], hvu_to_attacker_vector[0]])

        center = self.hvu_revealed_pos + hvu_to_attacker_vector * distance_from_hvu

        n = len(self.env.defender_ships)
        positions = []

        for i in range(n):
            # Sideways offset from the tip, and how far the slot trails back
            x_offset = (i - (n // 2)) * spread / n
            y_offset = abs(i - (n // 2)) * (spread / n)

            position = center + x_offset * perp_vector - y_offset * hvu_to_attacker_vector
            positions.append((position[0], position[1]))

        return positions

    def make_formation(self, defenders, target_positions):
        """
        Step every defender towards its slot; True once all report completion.
        """
        # List (not generator) so every ship moves even after one reports False
        line_formed = [ship.move_ship_to_coordinates(pos) for ship, pos in zip(defenders, target_positions)]
        return all(line_formed)

    def find_escape_direction(self, hvu_position, attacker_position):
        """
        Return the unit vector pointing from the attacker towards the HVU (i.e. away from the attacker).
        """
        hvu_to_attacker_vector = attacker_position - hvu_position
        hvu_to_attacker_vector = hvu_to_attacker_vector.astype(np.float64) / np.linalg.norm(hvu_to_attacker_vector)

        return -hvu_to_attacker_vector

    def move_unit_away_from_attacker(self, direction, safe_distance=300):
        """
        Move every defender and the HVU one step (their own ``speed``) along ``direction``.

        Returns:
            True if the HVU's actual (post-clamp) position is at least
            ``safe_distance`` from the recorded attacker position.
        """
        for ship in self.env.defender_ships:
            new_position = ship.get_position() + direction * ship.speed
            ship.set_position(new_position[0], new_position[1])

        hvu = self.env.hvu
        new_hvu_position = hvu.get_position() + direction * hvu.speed
        hvu.set_position(new_hvu_position[0], new_hvu_position[1])

        # Measure from where the HVU actually is; set_position may have clamped it
        unit_distance_from_attacker = np.linalg.norm(hvu.get_position() - self.att_revealed_pos)
        return unit_distance_from_attacker >= safe_distance

In [18]:
class FireMechanism:
    """Torpedo firing, torpedo movement and ramming-collision handling.

    Rewards are from the attacker's point of view: attacker progress adds to
    the reward, defender activity subtracts from it.
    """

    def __init__(self, environment):
        self.env = environment

    def _handle_firing_mechanics(self):
        """Let the attacker and defenders fire when in range/LOS; return the shaped reward."""
        reward = 0

        # Attacker -> HVU
        if self.env.attacker_ship.target_in_range(self.env.hvu.get_position()):
            reward += 2  # HVU inside the attacker's firing range
            if self.env.check_los_attacker():
                reward += 5  # Clear line of sight
                reward += self._handle_attacker_firing()

        # Defenders -> attacker
        for defender in self.env.defender_ships:
            if defender.target_in_range(self.env.attacker_ship.get_position()):
                reward -= 3  # Attacker inside a defender's range
                if self.env.check_los_defender(defender):
                    reward -= 5  # Attacker in that defender's LOS
                    reward -= self._handle_defender_firing(defender)

        return reward

    def _handle_attacker_firing(self):
        """
        Try to fire the attacker's torpedo at the HVU.

        Returns 20 when a torpedo was actually fired (and sets
        ``env.attacker_fired``), otherwise 0. Previously the 20 was added even
        when reload/lock delays prevented the shot.
        """
        if self.validate_and_fire(self.env.attacker_ship, self.env.hvu):
            self.env.attacker_fired = True
            return 20  # Reward for firing
        return 0

    def _handle_defender_firing(self, defender):
        """Fire ``defender`` at the attacker, but only once the attacker has fired; return the penalty."""
        penalty = 0
        if self.env.attacker_fired:
            if self.validate_and_fire(defender, self.env.attacker_ship):
                penalty += 20
        return penalty

    def validate_and_fire(self, ship, target):
        """
        Fire a torpedo at ``target`` once both cooldowns have elapsed.

        The reload delay is measured from ``ship.last_fire_time``; the target
        lock starts on the first eligible call and must be held for
        ``ship.target_delay`` seconds. Returns True if a torpedo was launched.
        """
        current_time = time.time()
        torpedo_fired = False

        if current_time - ship.last_fire_time >= ship.reload_delay:
            if ship.target_lock_time == 0:  # Start locking on the first eligible call
                ship.target_lock_time = current_time

            if current_time - ship.target_lock_time >= ship.target_delay:
                torpedo_fired = self.fire_torpedo(ship, target)

        return torpedo_fired

    def fire_torpedo(self, ship, target):
        """Launch a torpedo from ``ship`` straight at ``target``'s current position if any remain."""
        if ship.torpedo_count > 0:
            torpedo = Torpedo(
                torpedo_id=f"T{ship.ship_id}_{ship.torpedo_count}",
                x=ship.x, y=ship.y,
                speed=ship.torpedo_fire_speed,
                damage=ship.torpedo_damage,
                direction=np.arctan2(target.y - ship.y, target.x - ship.x),
                source=ship, target=target,
                screen_width=self.env.width, screen_height=self.env.height
            )

            ship.torpedoes.append(torpedo)
            ship.torpedo_count -= 1
            ship.last_fire_time = time.time()
            ship.target_lock_time = 0  # Next shot needs a fresh lock
            return True
        return False

    def _update_torpedo_position(self):
        """Advance all torpedoes, apply rewards/penalties, and end the episode on a kill."""
        attacker_reward, att_target_destroyed = self.update_torpedo(self.env.attacker_ship, self.env.ships)
        self.env.reward += attacker_reward

        if att_target_destroyed:
            self.env.info['HVU destroyed'] += 1
            self.env.done = True
            return

        for defender in self.env.defender_ships:
            penalty, def_target_destroyed = self.update_torpedo(defender, self.env.ships)
            self.env.reward -= penalty

            if def_target_destroyed:
                self.env.info['attacker destroyed'] += 1
                self.env.done = True
                return

    def update_torpedo(self, ship, env_ships, threshold=10):
        """
        Move each of ``ship``'s torpedoes, dropping ones that leave the screen,
        collide with another ship, or hit their target.

        Returns:
            (total reward from hits, whether the last processed hit destroyed its target).
        """
        reward = 0
        target_destroyed = False

        # Iterate a copy: torpedoes are removed from the list while scanning
        for torpedo in ship.torpedoes[:]:
            torpedo.move()

            if not torpedo.within_bounds() or torpedo.check_collision(env_ships, threshold):
                ship.torpedoes.remove(torpedo)
                continue

            r, target_destroyed = torpedo.hit_target(threshold)
            reward += r
            if torpedo.target_hit:
                ship.torpedoes.remove(torpedo)

            if target_destroyed:
                break

        return reward, target_destroyed

    def check_collisions(self):
        """Destroy both ships if the attacker rams the HVU or a defender; return True on collision."""
        attacker_pos = self.env.attacker_ship.get_position()

        hvu_pos = self.env.hvu.get_position()
        if self.check_collision(attacker_pos, hvu_pos):
            print("Collision detected between the attacker and the HVU! Both ships are destroyed.")
            self.env.attacker_ship.ship_health = 0
            self.env.hvu.ship_health = 0
            return True

        for defender_ship in self.env.defender_ships:
            defender_pos = defender_ship.get_position()
            if self.check_collision(attacker_pos, defender_pos):
                print(f"Collision detected between the attacker and defender {defender_ship.ship_id}! Both ships are destroyed.")
                self.env.attacker_ship.ship_health = 0
                defender_ship.ship_health = 0
                return True

        return False

    def check_collision(self, pos1, pos2, collision_range=20):
        """
        Return whether two positions are closer than ``collision_range``.
        """
        return np.linalg.norm([pos2[0] - pos1[0], pos2[1] - pos1[1]]) < collision_range
    

In [19]:
class DecoyMissileManager:
    """Launches and steers decoy missiles that intercept the attacker's torpedoes.

    Only defender ships whose ``ship_type`` is ``'def_decoyM'`` take part.
    Each active missile is a plain dict stored in the firing ship's
    ``decoy_missile`` list (see ``fire_missile`` for its keys).
    """

    def __init__(self, env):
        self.env = env

    def handle_DecoyM_firing_mechanics(self):
        """Try to fire a decoy missile at every attacker torpedo that is in range.

        For each active torpedo of ``env.attacker_ship`` and each decoy defender,
        ``validate_and_fire`` is called when the torpedo lies within the
        defender's ``firing_range``.
        """
        active_torpedoes = self.env.attacker_ship.torpedoes
        if not active_torpedoes:  # nothing to intercept
            return

        for torpedo in active_torpedoes:
            torpedo_pos = np.array([torpedo.x, torpedo.y], dtype=float)

            for defender in self.env.defender_ships:
                if defender.ship_type != 'def_decoyM':
                    continue
                # get_position() may return a plain tuple; convert before subtracting.
                defender_pos = np.asarray(defender.get_position(), dtype=float)
                if np.linalg.norm(defender_pos - torpedo_pos) <= defender.firing_range:
                    self.validate_and_fire(defender, torpedo)

    def validate_and_fire(self, ship, target):
        """Fire a decoy missile at ``target`` once reload and lock-on delays have passed.

        The first call after reloading starts the target lock; a missile is only
        launched once ``ship.target_delay`` seconds have elapsed since then.

        Returns:
            bool: True if a missile was launched on this call.
        """
        current_time = time.time()
        decoy_missile_fired = False

        if current_time - ship.last_decoy_fire_time >= ship.reload_delay:  # reloaded
            if ship.decoy_target_lock_time == 0:  # start locking on
                ship.decoy_target_lock_time = current_time

            if current_time - ship.decoy_target_lock_time >= ship.target_delay:  # lock complete
                decoy_missile_fired = self.fire_missile(ship, target)

        return decoy_missile_fired

    def fire_missile(self, ship, target):
        """Launch one decoy missile from ``ship`` towards ``target``.

        The missile dict has keys ``id``, ``x``, ``y``, ``speed``, ``source``
        (the firing ship), ``target`` (the torpedo) and ``target_hit``.
        Consumes one of ``ship.decoyM_count``, records the fire time and resets
        the target lock.

        Returns:
            bool: True if a missile was launched, False if the ship has none left.
        """
        if ship.decoyM_count > 0:
            decoy_missile = {
                'id': f"DM{ship.ship_id}_{ship.decoyM_count}",
                'x': ship.x, 'y': ship.y,
                'speed': ship.decoyM_speed,
                'source': ship, 'target': target,
                'target_hit': False
            }

            ship.decoy_missile.append(decoy_missile)
            ship.decoyM_count -= 1
            ship.last_decoy_fire_time = time.time()  # Update last fire time
            ship.decoy_target_lock_time = 0  # Reset target lock
            return True
        return False

    def update_decoy_missile(self):
        """Advance every active decoy missile by one step and resolve its outcome.

        A missile is removed without effect when it leaves the screen, comes
        within its ship's ``decoyM_blast_range`` of any ship other than its
        source, or when its target torpedo no longer exists. Otherwise, if it
        ends the step within ``decoyM_blast_range`` of its target, both the
        torpedo and the missile are destroyed.
        """
        for ship in self.env.ships:
            if ship.ship_type != 'def_decoyM' or not ship.decoy_missile:
                continue

            # Iterate over a snapshot: missiles are removed from the live list below,
            # and removing while iterating the list itself would skip elements.
            for missile in list(ship.decoy_missile):
                target = missile['target']
                target_pos = np.array([target.x, target.y], dtype=float)
                missile_pos = np.array([missile['x'], missile['y']], dtype=float)

                direction = target_pos - missile_pos
                distance = np.linalg.norm(direction)

                # Move towards the torpedo, but never past it; otherwise a missile
                # faster than the blast range can oscillate around its target forever.
                if distance != 0:
                    step = min(missile['speed'], distance)
                    missile['x'] += step * direction[0] / distance
                    missile['y'] += step * direction[1] / distance
                    distance -= step  # remaining distance after this step

                # Remove missiles that leave the screen, hit another ship, or lost their target.
                if (not self.within_bounds(missile)
                        or self.check_collision(self.env.ships, missile, ship.decoyM_blast_range)
                        or target not in self.env.attacker_ship.torpedoes):
                    ship.decoy_missile.remove(missile)
                    continue

                # Target torpedo inside the blast radius: destroy both.
                if distance < ship.decoyM_blast_range:
                    source_torpedoes = target.source.torpedoes
                    if target in source_torpedoes:
                        source_torpedoes.remove(target)
                    missile['target_hit'] = True
                    ship.decoy_missile.remove(missile)

    def within_bounds(self, missile):
        """Return True if the missile's (x, y) lies inside the env's width x height area."""
        return (
            0 <= float(missile['x']) <= self.env.width and
            0 <= float(missile['y']) <= self.env.height
        )

    def check_collision(self, ships, missile, threshold=10):
        """
        Check whether the missile is within ``threshold`` of any ship other than its source.

        This is a plain point-distance test against each ship's current position.
        Returns True if a collision is detected, False otherwise.
        """
        missile_pos = np.array([missile['x'], missile['y']])
        for ship in ships:
            if ship.ship_id == missile['source'].ship_id:  # never collide with the launcher
                continue

            ship_position = np.array(ship.get_position())
            if np.linalg.norm(ship_position - missile_pos) < threshold:
                return True  # Collision detected
        return False  # No collision detected


In [20]:
class HelicopterManager:
    """
    Manages the movement of helicopters launched from defender ships.

    Only one helicopter is airborne at a time. Each sortie:
    1. Takes off from the current defender ship (consuming one of its ``helicop_count``).
    2. Flies to the point on the patrol circle (radius ``env.helicop_path_radius``
       around the HVU) nearest to that defender.
    3. Orbits the HVU once, advancing ``env.helicop_speed`` degrees per call.
    4. Flies back to its defender, restores the helicopter count, and hands the
       turn to the next ship in ``env.defender_ships``.
    Defenders with no helicopters available are skipped.
    """

    # Distance (pixels) at which a returning helicopter counts as landed.
    DOCKING_DISTANCE = 10

    def __init__(self, env):
        self.env = env

        self.helicopter_position = None  # np.ndarray [x, y] while airborne

        self.helicopter_index = 0  # index of the current base in env.defender_ships
        self.helicopter_angle = 0  # bearing from the HVU, in degrees
        self.helicopter_active = False
        self.helicopter_takeoff_position = None  # Stores takeoff position for accurate return
        self.helicopter_state = None  # "takeoff", "circle" or "return"
        self._circle_end_angle = None  # bearing at which the current orbit is complete

    def move_defenders_helicop(self):
        """Advance the helicopter sortie by one step (intended to be called once per frame)."""
        defenders = self.env.defender_ships
        if not defenders:
            return

        # Keep the index valid if the defender list has shrunk since it was set.
        self.helicopter_index %= len(defenders)

        current_defender = defenders[self.helicopter_index]
        defender_pos = np.array(current_defender.get_position(), dtype=np.float64)
        hvu_pos = np.array(self.env.hvu.get_position(), dtype=np.float64)
        circular_entry_pos, norm_direction = self._calculate_circular_entry(defender_pos, hvu_pos)

        # If helicopter is inactive, launch it
        if not self.helicopter_active:
            self._launch_helicopter(current_defender, defender_pos)

        # Move helicopter according to its current state
        if self.helicopter_active:
            self._move_helicopter(current_defender, hvu_pos, defender_pos, circular_entry_pos, norm_direction)

    def _calculate_circular_entry(self, defender_pos, hvu_pos):
        """Return (nearest point on the patrol circle, unit vector HVU -> defender)."""
        direction = defender_pos - hvu_pos
        length = np.linalg.norm(direction)
        if length == 0:
            # Defender sits exactly on the HVU: every circle point is equally near,
            # so pick a fixed direction instead of dividing by zero (NaN).
            norm_direction = np.array([1.0, 0.0])
        else:
            norm_direction = direction / length
        return hvu_pos + norm_direction * self.env.helicop_path_radius, norm_direction

    def _launch_helicopter(self, current_defender, defender_pos):
        """Take off from ``current_defender`` if it has a helicopter; otherwise pass the turn on."""
        if current_defender.helicop_count > 0:
            current_defender.helicop_count -= 1
            # Own copy, so in-place moves never write through to the caller's array.
            self.helicopter_position = np.array(defender_pos, dtype=np.float64)
            self.helicopter_active = True
            hvu_x, hvu_y = self.env.hvu.get_position()
            self.helicopter_angle = np.degrees(np.arctan2(defender_pos[1] - hvu_y,
                                                           defender_pos[0] - hvu_x))
            self.helicopter_state = "takeoff"

        else:
            self.helicopter_index = (self.helicopter_index + 1) % len(self.env.defender_ships)

    def _move_helicopter(self, current_defender, hvu_pos, defender_pos, circular_entry_pos, norm_direction):
        """Dispatch one movement step according to ``self.helicopter_state``."""
        state_actions = {
            "takeoff": lambda: self._move_to_entry_point(current_defender, circular_entry_pos),
            "circle": lambda: self._move_in_circle(current_defender, hvu_pos, norm_direction),
            "return": lambda: self._return_to_base(current_defender, defender_pos)
        }
        if self.helicopter_state in state_actions:
            state_actions[self.helicopter_state]()

    def _move_to_entry_point(self, current_defender, circular_entry_pos):
        """Fly towards the circle entry point; on arrival switch to orbiting."""
        direction = circular_entry_pos - self.helicopter_position
        distance = np.linalg.norm(direction)
        if distance > self.env.helicop_speed:
            self.helicopter_position += (direction / distance) * self.env.helicop_speed
        else:
            self.helicopter_state = "circle"
            self.helicopter_position = np.array(circular_entry_pos, dtype=np.float64)
            hvu_x, hvu_y = self.env.hvu.get_position()
            self.helicopter_angle = np.degrees(np.arctan2(self.helicopter_position[1] - hvu_y,
                                                           self.helicopter_position[0] - hvu_x))
            # Fix the end of the orbit now. Re-deriving it every frame from the
            # defender's bearing drifts when the defender moves and jumps by 360
            # degrees when that bearing crosses +/-180.
            self._circle_end_angle = self.helicopter_angle + 360

    def _move_in_circle(self, current_defender, hvu_pos, norm_direction):
        """Advance ``helicop_speed`` degrees around the HVU; after a full loop, head home."""
        self.helicopter_angle += self.env.helicop_speed
        angle_rad = np.radians(self.helicopter_angle)
        self.helicopter_position = hvu_pos + self.env.helicop_path_radius * np.array([np.cos(angle_rad), np.sin(angle_rad)])

        end_angle = self._circle_end_angle
        if end_angle is None:
            # Orbit was not entered via _move_to_entry_point: fall back to the current bearing.
            end_angle = 360 + np.degrees(np.arctan2(norm_direction[1], norm_direction[0]))

        if self.helicopter_angle >= end_angle:
            self.helicopter_state = "return"
            self._circle_end_angle = None

    def _return_to_base(self, current_defender, defender_pos):
        """Fly back to the defender; once docked, restore its count and pass the turn on."""
        direction = defender_pos - self.helicopter_position
        distance = np.linalg.norm(direction)
        if distance >= self.DOCKING_DISTANCE:
            # Never step past the base when helicop_speed exceeds the remaining distance.
            step = min(self.env.helicop_speed, distance)
            self.helicopter_position = self.helicopter_position + (direction / distance) * step
        else:
            self.helicopter_position = np.array(defender_pos, dtype=np.float64)
            self.helicopter_active = False
            self.helicopter_index = (self.helicopter_index + 1) % len(self.env.defender_ships)
            self.helicopter_state = "takeoff"
            current_defender.helicop_count += 1

In [21]:
class MapGenerator:
    """Background and map-projection helper for the simulation window.

    The real-world map is currently disabled: ``render_background`` only fills
    the screen with black. The projection helpers convert between lat/lon and
    screen pixels around ``map_center`` at the given ``zoom`` level.
    """

    # Used when no map centre is supplied: (lat, lon) in degrees.
    DEFAULT_MAP_CENTER = (3.0000, 86.0000)

    def __init__(self, env, map_center=None, zoom=6):
        """
        Args:
            env: Environment providing ``width``, ``height`` and ``screen``.
            map_center: [lat, lon] of the screen centre. Defaults to
                ``DEFAULT_MAP_CENTER``. A fresh list is built per instance so
                instances never share (and mutate) one default list.
            zoom: Map zoom level used by the projection helpers.
        """
        self.env = env
        self.map_center = list(self.DEFAULT_MAP_CENTER) if map_center is None else map_center
        self.zoom = zoom  # Store zoom level
        self.map_initialized = False  # Flag to check if the map has been initialized
        self.map_image_path = None
        self.map_image = None

        # Initialize the map once
        self.map_initialise()

    def render_background(self):
        """Renders a black background instead of the real-world map."""
        self.env.screen.fill((0, 0, 0))  # Fill screen with black color
        # Optionally, you can still draw the grid over the black background
        # self._draw_grid(self.env.screen)  # Draw grid over the black background

    def map_initialise(self):
        """Skip map initialization since we're using black background."""
        if self.map_initialized:
            return  # Prevent re-initialization if already done

        # Skip map generation - we're using black background
        self.map_image = None
        self.map_initialized = True  # Mark initialization complete

    def _generate_map(self):
        """Render a Google-Maps-tiled Folium map to ``map.png`` and return that path.

        Requires the ``folium`` and ``selenium`` (``webdriver``) imports, which
        are currently commented out at the top of the file, and a headless
        Chrome. The tile API key is read from the ``GOOGLE_MAPS_API_KEY``
        environment variable; API keys must not be hard-coded in source.
        """
        html_map_file = "map.html"
        png_map_file = "map.png"

        # NOTE(review): a key used to be committed here; it should be revoked and rotated.
        api_key = os.environ.get("GOOGLE_MAPS_API_KEY", "")

        # Create a Folium map using Google Maps tiles
        folium_map = folium.Map(
            location=self.map_center,
            zoom_start=self.zoom,
            tiles=f"https://mt1.google.com/vt/lyrs=m&x={{x}}&y={{y}}&z={{z}}&key={api_key}",
            attr="Google Maps"
        )

        # Add latitude & longitude popups for accuracy verification
        folium_map.add_child(folium.LatLngPopup())
        folium_map.save(html_map_file)

        # Convert HTML map to PNG using Selenium
        options = webdriver.ChromeOptions()
        options.add_argument("--headless")
        driver = webdriver.Chrome(options=options)
        try:
            driver.get(f"file://{os.path.abspath(html_map_file)}")
            driver.set_window_size(self.env.width, self.env.height)
            time.sleep(1)  # Give the map tiles ~1 second to load
            driver.save_screenshot(png_map_file)
        finally:
            driver.quit()  # Always release the browser process, even on failure

        return png_map_file

    def _get_grid_spacing(self, zoom):
        """Return (lat_spacing, lon_spacing) in degrees for the given zoom level."""
        if zoom >= 12:
            return 0.01, 0.01
        elif zoom >= 10:
            return 0.1, 0.1
        elif zoom >= 8:
            return 0.5, 0.5
        elif zoom >= 6:
            return 1, 1
        elif zoom >= 4:
            return 5, 5
        else:
            return 10, 10

    def _latlon_to_pixels(self, lat, lon):
        """Convert (lat, lon) in degrees to integer screen pixels (x, y).

        Longitude is linear in radians and latitude uses the Mercator
        ``log(tan(pi/4 + lat/2))`` mapping; both are scaled by
        ``256 * 2**zoom`` around the screen centre (``map_center``).
        """
        # Prevents infinite values near the poles
        lat = np.clip(lat, -89.9999, 89.9999)

        def lat_to_mercator(lat):
            return math.log(math.tan(math.pi / 4 + math.radians(lat) / 2))

        def lon_to_mercator(lon):
            return math.radians(lon)

        screen_width, screen_height = self.env.width, self.env.height

        center_x = lon_to_mercator(self.map_center[1])
        center_y = lat_to_mercator(self.map_center[0])

        scale = 256 * 2**self.zoom

        x = (lon_to_mercator(lon) - center_x) * scale + screen_width / 2
        y = (center_y - lat_to_mercator(lat)) * scale + screen_height / 2

        return int(x), int(y)

    def _draw_grid(self, screen, grid_color=(200, 200, 200)):
        """
        Draws real-world latitude & longitude grid lines with labeled coordinates.
        """
        lat_spacing, lon_spacing = self._get_grid_spacing(self.zoom)
        screen_width, screen_height = self.env.width, self.env.height

        # Labels are collected first and blitted after all lines so text sits on top.
        grid_labels = []

        # Draw latitude lines (horizontal)
        for lat in np.arange(-90, 90, lat_spacing):
            start_x, start_y = self._latlon_to_pixels(lat, -180)
            end_x, end_y = self._latlon_to_pixels(lat, 180)

            pygame.draw.line(screen, grid_color, (start_x, start_y), (end_x, end_y), 1)

            # Add latitude label at the left side of the map
            if 0 <= start_y < screen_height:
                grid_labels.append((10, start_y - 5, f"{lat:.1f}°"))

        # Draw longitude lines (vertical)
        for lon in np.arange(-180, 180, lon_spacing):
            start_x, start_y = self._latlon_to_pixels(-90, lon)
            end_x, end_y = self._latlon_to_pixels(90, lon)

            pygame.draw.line(screen, grid_color, (start_x, start_y), (end_x, end_y), 1)

            # Add longitude label at the top of the map
            if 0 <= start_x < screen_width:
                grid_labels.append((start_x + 5, 10, f"{lon:.1f}°"))

        font = pygame.font.SysFont('Arial', 12, bold=True)
        for x, y, text in grid_labels:
            text_surface = font.render(text, True, (255, 255, 255))  # White text
            screen.blit(text_surface, (x, y))

    def pixels_to_latlon(self, x, y):
        """Convert screen pixels (x, y) back to (lat, lon) in degrees.

        This is the inverse of ``_latlon_to_pixels`` (up to that method's integer
        rounding). It uses the same radian-based longitude and Mercator latitude,
        scaled by ``256 * 2**zoom`` around ``map_center``. The previous version
        used ``lon / 180`` here, which is inconsistent with the forward
        projection's ``radians(lon)``.
        """
        screen_width, screen_height = self.env.width, self.env.height
        scale = 256 * 2**self.zoom

        # Map centre in projection space (same formulas as _latlon_to_pixels)
        center_x = math.radians(self.map_center[1])
        center_y = math.log(math.tan((math.pi / 4) + (math.radians(self.map_center[0]) / 2)))

        # Undo the pixel offset/scaling
        merc_x = center_x + (x - screen_width / 2) / scale
        merc_y = center_y - (y - screen_height / 2) / scale

        lat = math.degrees(2 * math.atan(math.exp(merc_y)) - math.pi / 2)
        lon = math.degrees(merc_x)

        return lat, lon

    def _render_latlon(self, x, y, rect_size):
        """
        Render the pixel coordinates "(x, y)" centred just below a ship drawn at (x, y).

        Despite the name, pixel coordinates (not lat/lon) are shown.
        """
        pixel_text = f"({int(x)}, {int(y)})"

        font = pygame.font.SysFont('Arial', 10, bold=True)
        text_surface = font.render(pixel_text, True, (255, 255, 255))  # White text for visibility on black background

        text_width, text_height = text_surface.get_size()

        # Centre horizontally; place just below the ship's rectangle with 5px padding
        text_x = int(x) - text_width // 2
        text_y = int(y) + rect_size // 2 + 5

        self.env.screen.blit(text_surface, (text_x, text_y))


In [22]:
class UIManager:
    """Pygame start-up form that collects defender and map settings.

    ``get_user_input_screen`` runs its own event loop until the user submits
    valid values and then returns them (see that method for the tuple layout).
    """

    # Per-type defender count boxes; only shown and editable when customisation is ticked.
    _CUSTOM_KEYS = ("num_sonar_def", "num_heli_def", "num_decoy_def", "num_default_def")
    # Characters accepted from the keyboard. '-' is needed for southern latitudes and
    # western longitudes; integer fields still reject it during validation. This is a
    # set (not a str) so the empty ``event.unicode`` of non-character keys is not accepted.
    _ALLOWED_CHARS = frozenset("0123456789.,-")
    # Formations offered for defenders while moving.
    _MOVING_FORMATIONS = ("triangle", "circle", "semicircle", "wedge", "line")

    def __init__(self, env):
        self.env = env
        self.title_font = pygame.font.SysFont('Arial', 26, bold=True)
        self.label_font = pygame.font.SysFont('Arial', 18)
        self.input_font = pygame.font.SysFont('Arial', 18)
        self.error_font = pygame.font.SysFont('Arial', 16)  # created once, not every frame

    def get_user_input_screen(self):
        """
        Display the input form until valid values are submitted.

        Returns:
            tuple: (num_defenders, use_custom, num_sonar, num_heli, num_decoy,
            num_default, default_formation, [map_lat, map_lon], zoom,
            [base_lat, base_lon], moving_formation). The four per-type counts
            are 0 when customisation is not enabled.
        """
        user_inputs, input_boxes, colors = self._initialize_inputs()

        checkbox_checked = False
        error_message = ""
        done = False
        clock = pygame.time.Clock()  # one clock for the whole loop so tick() paces frames

        while not done:
            self.env.screen.fill((240, 240, 240))
            self._render_title("Environment Details")

            self._render_checkbox(checkbox_checked)
            self._render_formation_selection(user_inputs)

            submit_button = self._render_inputs(user_inputs, input_boxes, colors, checkbox_checked)

            if error_message:
                self._render_error(error_message, submit_button.y + 50)

            pygame.display.flip()
            clock.tick(30)

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    quit()

                elif event.type == pygame.MOUSEBUTTONDOWN:
                    checkbox_checked, input_boxes, colors = self._handle_mouse_click(
                        event.pos, checkbox_checked, input_boxes
                    )
                    self._handle_formation_click(event.pos, user_inputs)

                    if submit_button.collidepoint(event.pos):
                        done, error_message = self._validate_inputs(user_inputs, checkbox_checked)

                elif event.type == pygame.KEYDOWN:
                    user_inputs, colors = self._handle_key_input(event, user_inputs, input_boxes, colors)

        return self._finalize_user_inputs(user_inputs, checkbox_checked)

    def _initialize_inputs(self):
        """Build the initial field values (from env defaults), box rects and box colours."""
        user_inputs = {
            "total_num_def": str(self.env.num_defenders),
            "num_sonar_def": str(self.env.num_def_with_sonar),
            "num_heli_def": str(self.env.num_def_with_helicopter),
            "num_decoy_def": str(self.env.num_def_with_decoy),
            "num_default_def": str(self.env.num_default_def),
            "map_center": f"{self.env.map_center[0]}, {self.env.map_center[1]}",
            "zoom": str(self.env.zoom),
            "base_location": f"{self.env.base_location[0]}, {self.env.base_location[1]}",
            "def_default_formation": self.env.def_default_formation,
            "def_moving_formation": self.env.def_moving_formation
        }

        input_boxes = {
            "total_num_def": pygame.Rect(100, 120, 180, 35),
            "num_sonar_def": pygame.Rect(100, 390, 180, 35),
            "num_heli_def": pygame.Rect(100, 460, 180, 35),
            "num_decoy_def": pygame.Rect(380, 390, 180, 35),
            "num_default_def": pygame.Rect(380, 460, 180, 35),
            "map_center": pygame.Rect(100, 200, 220, 35),
            "zoom": pygame.Rect(380, 200, 220, 35),
            "base_location": pygame.Rect(380, 120, 220, 35)
        }

        colors = {key: pygame.Color('lightskyblue3') for key in user_inputs}
        return user_inputs, input_boxes, colors

    def _render_title(self, text):
        """Draw ``text`` in black title font at (100, 50)."""
        instructions = self.title_font.render(text, True, (0, 0, 0))
        self.env.screen.blit(instructions, (100, 50))

    def _render_checkbox(self, checked):
        """Draw the 'customize defenders' checkbox and remember its rect for click tests."""
        self.checkbox_rect = pygame.Rect(100, 285, 20, 20)
        pygame.draw.rect(self.env.screen, (0, 0, 0), self.checkbox_rect, 2)  # Draw outer border

        if checked:
            pygame.draw.rect(self.env.screen, (0, 0, 0), self.checkbox_rect.inflate(-8, -8))  # Fill if checked

        label = self.label_font.render("Customize Defender Ship(Optional):", True, (0, 0, 0))
        self.env.screen.blit(label, (100, 255))

    def _render_formation_selection(self, inputs):
        """Draw the default- and moving-formation radio lists and remember their rects."""
        label = self.label_font.render("Select Defender Default Formation:", True, (0, 0, 0))
        self.env.screen.blit(label, (680, 95))

        start_x = 700
        start_y = 130
        gap_y = 40

        self.formation_rects = []

        for i, formation in enumerate(self.env.avail_def_default_formations):
            rect = pygame.Rect(start_x, start_y + i * gap_y, 20, 20)
            self.formation_rects.append((formation, rect))

            pygame.draw.rect(self.env.screen, (0, 0, 0), rect, 2)

            if inputs["def_default_formation"] == formation:
                pygame.draw.rect(self.env.screen, (0, 0, 0), rect.inflate(-8, -8))

            self.env.screen.blit(self.label_font.render(formation.capitalize(), True, (0, 0, 0)), (rect.x + 30, rect.y))

        # --- Moving formation list, placed below the default list ---
        move_label = self.label_font.render("Select Defender Moving Formation:", True, (0, 0, 0))
        self.env.screen.blit(move_label, (680, start_y + len(self.env.avail_def_default_formations) * gap_y + 20))

        start_y_moving = start_y + len(self.env.avail_def_default_formations) * gap_y + 55

        self.moving_formation_rects = []

        for i, formation in enumerate(self._MOVING_FORMATIONS):
            rect = pygame.Rect(start_x, start_y_moving + i * gap_y, 20, 20)
            self.moving_formation_rects.append((formation, rect))

            pygame.draw.rect(self.env.screen, (0, 0, 0), rect, 2)

            if inputs["def_moving_formation"] == formation:
                pygame.draw.rect(self.env.screen, (0, 0, 0), rect.inflate(-8, -8))

            self.env.screen.blit(self.label_font.render(formation.capitalize(), True, (0, 0, 0)), (rect.x + 30, rect.y))

    def _render_inputs(self, inputs, boxes, colors, custom_enabled):
        """Draw all visible input boxes and the submit button; return the button rect."""
        labels = {
            "total_num_def": "Total Defender Ships:",
            "num_sonar_def": "Defenders with Sonar:",
            "num_heli_def": "Defenders with Helicopter:",
            "num_decoy_def": "Defenders with Decoy Missile:",
            "num_default_def": "Default(generic) Defender Ships:",
            "map_center": "Map Center (Lat,Long):",
            "zoom": "Zoom Level:",
            "base_location": "Base Location (Lat,Long):"
        }

        self._draw_input_box(inputs["total_num_def"], boxes["total_num_def"], labels["total_num_def"], colors["total_num_def"])

        if custom_enabled:
            self.env.screen.blit(
                self.title_font.render("Enter Counts For Each Defender Ship:", True, (0, 0, 0)),
                (100, 320)
            )
            for key in self._CUSTOM_KEYS:
                self._draw_input_box(inputs[key], boxes[key], labels[key], colors[key])

        for key in ["map_center", "zoom", "base_location"]:
            self._draw_input_box(inputs[key], boxes[key], labels[key], colors[key])

        # The submit button moves down when the custom-count boxes are shown.
        submit_rect = pygame.Rect(100, 520 if custom_enabled else 330, 180, 40)
        pygame.draw.rect(self.env.screen, (0, 150, 0), submit_rect)
        text = self.input_font.render("Start Simulation", True, (255, 255, 255))
        self.env.screen.blit(text, (submit_rect.x + 25, submit_rect.y + 10))

        return submit_rect

    def _draw_input_box(self, text, rect, label, color):
        """Draw one bordered input box with its label 25px above it."""
        pygame.draw.rect(self.env.screen, color, rect, 2)
        self.env.screen.blit(self.label_font.render(label, True, (0, 0, 0)), (rect.x, rect.y - 25))
        self.env.screen.blit(self.input_font.render(text, True, (0, 0, 0)), (rect.x + 5, rect.y + 5))

    def _render_error(self, message, y_pos):
        """Draw ``message`` in red at (100, y_pos)."""
        error_surface = self.error_font.render(message, True, (200, 0, 0))
        self.env.screen.blit(error_surface, (100, y_pos))

    def _handle_mouse_click(self, pos, checkbox_state, input_boxes):
        """Toggle the checkbox if clicked and recompute which input box is active.

        The active box is marked by the 'dodgerblue2' colour. Hidden custom-count
        boxes (checkbox unticked) cannot be activated by clicking blank space.
        """
        if self.checkbox_rect.collidepoint(pos):
            checkbox_state = not checkbox_state

        colors = {}
        for key, rect in input_boxes.items():
            visible = checkbox_state or key not in self._CUSTOM_KEYS
            active = visible and rect.collidepoint(pos)
            colors[key] = pygame.Color('dodgerblue2') if active else pygame.Color('lightskyblue3')

        return checkbox_state, input_boxes, colors

    def _handle_formation_click(self, pos, inputs):
        """Select the default or moving formation whose radio box contains ``pos``."""
        for formation, rect in self.formation_rects:
            if rect.collidepoint(pos):
                inputs["def_default_formation"] = formation
                return  # Select only one at a time

        for formation, rect in self.moving_formation_rects:
            if rect.collidepoint(pos):
                inputs["def_moving_formation"] = formation
                return

    def _validate_inputs(self, inputs, custom):
        """Validate the form values.

        Returns:
            (bool, str): ``(True, "")`` when valid, otherwise ``(False, error_message)``.
        """
        # isdecimal (not isdigit): characters such as '²' pass isdigit but make int() fail.
        if not inputs["total_num_def"].isdecimal():
            return False, "Enter a valid number for total defender ships."

        if custom:
            if not all(inputs[key].isdecimal() for key in self._CUSTOM_KEYS):
                return False, "All custom defender values must be valid numbers."
            if sum(int(inputs[key]) for key in self._CUSTOM_KEYS) != int(inputs["total_num_def"]):
                return False, "Sum of custom types must equal total defender ships."

        try:
            lat, lon = map(float, inputs["map_center"].split(','))
        except ValueError:  # wrong number of parts or non-numeric part
            return False, "Map coordinates must be in lat,long format."
        if not (-90 <= lat <= 90 and -180 <= lon <= 180):
            return False, "Map coordinates must be valid lat,long."

        try:
            base_lat, base_lon = map(float, inputs["base_location"].split(','))
        except ValueError:
            return False, "Base location must be in lat,long format."
        if not (-90 <= base_lat <= 90 and -180 <= base_lon <= 180):
            return False, "Base location coordinates must be valid lat,long."

        if not inputs["zoom"].isdecimal():
            return False, "Zoom level must be a valid number."

        if not (1 <= int(inputs["zoom"]) <= 12):
            return False, "Zoom level must be between 1 and 12."

        return True, ""

    def _handle_key_input(self, event, inputs, boxes, colors):
        """Apply a KEYDOWN event to the active input box (Enter deactivates it)."""
        active_color = pygame.Color('dodgerblue2')
        for key in boxes:
            if colors[key] != active_color:
                continue
            if event.key == pygame.K_RETURN:
                colors[key] = pygame.Color('lightskyblue3')  # Deactivate on Enter
            elif event.key == pygame.K_BACKSPACE:
                inputs[key] = inputs[key][:-1]
            elif event.unicode in self._ALLOWED_CHARS:
                inputs[key] += event.unicode
        return inputs, colors

    def _finalize_user_inputs(self, inputs, custom):
        """Convert validated form strings into the tuple returned by ``get_user_input_screen``."""
        num_defenders = int(inputs["total_num_def"])
        num_sonar = int(inputs["num_sonar_def"]) if custom else 0
        num_heli = int(inputs["num_heli_def"]) if custom else 0
        num_decoy = int(inputs["num_decoy_def"]) if custom else 0
        num_default = int(inputs["num_default_def"]) if custom else 0
        lat, lon = map(float, inputs["map_center"].split(','))
        zoom = int(inputs["zoom"])
        formation = inputs["def_default_formation"]
        moving_formation = inputs["def_moving_formation"]
        base_lat, base_lon = map(float, inputs["base_location"].split(','))

        return num_defenders, custom, num_sonar, num_heli, num_decoy, num_default, formation, [lat, lon], zoom, [base_lat, base_lon], moving_formation


In [23]:
class Renderer:
    """
    Draws ships, torpedoes, decoy missiles, the helicopter, the base and the
    health HUD onto `env.screen` with Pygame. The background map is not drawn
    by this class.
    """

    def __init__(self, env, map_generator):
        """
        Args:
            env: environment providing `screen`, `font`, the ships and the
                other attributes read by the render methods.
            map_generator: object providing `_render_latlon(x, y, rect_size)`,
                used to label ships and the helicopter with their coordinates.
        """
        self.env = env
        self.map_generator = map_generator

    def _draw_triangle(self, x, y, size, color=(255, 255, 0)):
        """Draw a filled upward-pointing triangle reaching `size` pixels from (x, y)."""
        points = [(x, y - size), (x - size, y + size), (x + size, y + size)]
        pygame.draw.polygon(self.env.screen, color, points)

    def _render_ship(self, ship, color=(255, 255, 255), rect_size=20, line_thickness=1):
        """
        Render a living ship as a square and draw its firing range as a circle.
        Also displays the latitude and longitude of the ship below it.
        """
        if ship.ship_health > 0:
            x, y = ship.get_position()

            # Draw ship (square) centred on its position
            pygame.draw.rect(self.env.screen, color, (int(x) - rect_size // 2, int(y) - rect_size // 2, rect_size, rect_size))

            # Draw firing range (circle outline) around the ship
            pygame.draw.circle(self.env.screen, color, (int(x), int(y)), ship.firing_range, line_thickness)

            # Render lat/lon text below the ship
            self.map_generator._render_latlon(x, y, rect_size)

    def _render_torpedoes(self, ship, color=(255, 255, 255)):
        """Render every torpedo of `ship` that has not hit its target yet."""
        for torpedo in ship.torpedoes:
            if not torpedo.target_hit:
                pygame.draw.circle(self.env.screen, color, (int(torpedo.x), int(torpedo.y)), 5)

    def _render_decoyM(self, ship, color=(255, 165, 0)):
        """Render every decoy missile of `ship` that has not hit its target yet."""
        for missile in ship.decoy_missile:
            if not missile['target_hit']:
                pygame.draw.circle(self.env.screen, color, (int(missile['x']), int(missile['y'])), 5)

    def _draw_dotted_line(self, start_pos, end_pos, color=(255, 255, 255), width=3, segment_length=5, space_length=5):
        """
        Draw a dotted line between two ships.

        Args:
            start_pos, end_pos: ship objects (anything with `get_position()`),
                despite the parameter names.
            segment_length, space_length: nominal dash and gap lengths in pixels.
                Lines shorter than one dash+gap draw nothing.
        """
        x1, y1 = start_pos.get_position()
        x2, y2 = end_pos.get_position()
        total_length = np.linalg.norm([x2 - x1, y2 - y1])
        num_segments = int(total_length // (segment_length + space_length))
        dash_fraction = segment_length / (segment_length + space_length)

        for i in range(num_segments):
            start_t = i / num_segments
            end_t = (i + dash_fraction) / num_segments
            start_segment = (x1 + (x2 - x1) * start_t, y1 + (y2 - y1) * start_t)
            end_segment = (x1 + (x2 - x1) * end_t, y1 + (y2 - y1) * end_t)
            pygame.draw.line(self.env.screen, color, start_segment, end_segment, width)

    def _render_attacker(self):
        """
        Render the attacker ship and its torpedoes, plus a dotted firing line
        to the HVU when the HVU is in range and the line of sight is clear.
        """
        self._render_ship(self.env.attacker_ship, color=(255, 0, 0))  # Render attacker (red)
        self._render_torpedoes(self.env.attacker_ship, color=(255, 0, 0))

        if self.env.attacker_ship.target_in_range(self.env.hvu.get_position()) and self.env.check_los_attacker():
            self._draw_dotted_line(self.env.attacker_ship, self.env.hvu, color=(255, 0, 0))

    def _render_defenders(self):
        """
        Render all defender ships with their type icon, torpedoes and decoy
        missiles, plus a dotted firing line to the attacker when it is in range
        and the defender's line of sight is clear.
        """
        for defender in self.env.defender_ships:
            # Render ship and its type icon
            self._render_ship(defender, (0, 0, 255))  # Render defenders (blue)
            self.draw_defender_top_icon(defender)

            # Render torpedoes and decoy missiles
            self._render_torpedoes(defender, (0, 0, 255))
            self._render_decoyM(defender, (255, 165, 0))

            if defender.target_in_range(self.env.attacker_ship.get_position()) and self.env.check_los_defender(defender):
                self._draw_dotted_line(defender, self.env.attacker_ship, (0, 0, 255))

    def draw_defender_top_icon(self, defender):
        """
        Draw a small type marker on top of a living defender:
        yellow triangle for 'def_heli', white dot for 'def_sonar',
        orange square for 'def_decoyM'. Other types get no marker.
        """
        ship_type = defender.ship_type
        x, y = int(defender.x), int(defender.y)

        if defender.ship_health > 0:
            if ship_type == "def_heli":
                self._draw_triangle(x, y, 5)
            elif ship_type == "def_sonar":
                pygame.draw.circle(self.env.screen, (255, 255, 255), (x, y), 5)
            elif ship_type == "def_decoyM":
                rect_size = 12
                pygame.draw.rect(self.env.screen, (255, 140, 0), (x - rect_size // 2, y - rect_size // 2, rect_size, rect_size))

    def _render_helicopter(self, helicopter_position):
        """
        Render the helicopter as a yellow triangle with its range circle and
        lat/lon label. Does nothing when `helicopter_position` is None.
        """
        if helicopter_position is not None:
            x, y = int(helicopter_position[0]), int(helicopter_position[1])
            self._draw_triangle(x, y, 15)

            # Range of helicopter
            pygame.draw.circle(self.env.screen, (255, 255, 0), (x, y), self.env.helicop_range, 1)

            self.map_generator._render_latlon(x, y, rect_size=20)

    def _display_health(self):
        """
        Display the health of the attacker (red) and the HVU (dark green) in the
        bottom-left corner of the screen.
        """
        screen_height = self.env.height

        attacker_health_text = self.env.font.render(f"Attacker Health: {self.env.attacker_ship.ship_health}", True, (255, 0, 0))
        hvu_health_text = self.env.font.render(f"HVU Health: {self.env.hvu.ship_health}", True, (0, 100, 0))

        text_height = attacker_health_text.get_height()

        # Stack the two lines above the bottom-left corner
        bottom_offset = 10  # Space from the bottom
        left_offset = 10  # Space from the left

        self.env.screen.blit(attacker_health_text, (left_offset, screen_height - 2 * text_height - bottom_offset))
        self.env.screen.blit(hvu_health_text, (left_offset, screen_height - text_height - bottom_offset))

    def _render_base(self):
        """
        Render the base as a green circle (radius 30) with a black border.
        Skipped until the environment has computed `base_location_inPixels`.
        """
        # Guard on the attribute actually drawn, not just the lat/lon config
        base_px = getattr(self.env, "base_location_inPixels", None)
        if base_px is None:
            return

        # Black border first (slightly bigger radius), then the base on top
        pygame.draw.circle(self.env.screen, (0, 0, 0), base_px, 32)
        pygame.draw.circle(self.env.screen, (0, 200, 0), base_px, 30)


In [24]:
class NavalShipEnv(gym.Env):
    """
    Gym environment: an attacker ship (the agent driven through `step`) versus
    an HVU (High Value Unit) escorted by defender ships, rendered with Pygame.

    Ship ids: 0 = HVU, 1 = attacker, 2 onwards = defenders. Observations are the
    flattened (x, y) pixel positions of the attacker, the HVU and all defenders.
    """
    metadata = {'render.modes': ['human']}

    # Pixel distance at which the HVU counts as having reached the base. Shared
    # by move_hvu_to_base (stop moving) and _check_hvu_reached_base (end the
    # episode) so the two can never disagree about arrival.
    _BASE_ARRIVAL_RADIUS = 10

    # Attacker action index -> movement heading in degrees; _STAY_ACTION = no move.
    _ACTION_HEADINGS = {
        0: 90,    # Up
        1: 270,   # Down
        2: 180,   # Left
        3: 0,     # Right
        4: 45,    # Up-Right
        5: 135,   # Up-Left
        6: 225,   # Down-Left
        7: 315,   # Down-Right
    }
    _STAY_ACTION = 8

    def __init__(
        self,
        screen_width: int = 1000,
        screen_height: int = 600,
        env_name: str = "Naval Ship Environment",

        # Default Configuration of defenders and its formation
        tota_num_def: int = 5,
        num_def_with_sonar: int = 1,
        num_def_with_helicopter: int = 2,
        num_def_with_decoy: int = 1,
        num_default_def: int = 1,
        def_default_formation: str = "semicircle",
        map_center: List[float] = None,
        zoom: int = 6,
        base_location: List[float] = None,

        # --- Ship Configuration ---
        hvu_ship: dict = None,
        att_ship: dict = None,
        def_ship: dict = None,
        def_sonar: dict = None,
        def_heli: dict = None,
        def_decoyM: dict = None,

        # --- Formation and Movement ---
        def_moving_formation: str = "wedge",

        # --- Helicopter Configuration ---
        helicop_path_radius: int = 200,
        helicop_range: int = 150,
        helicop_speed: float = 2.0
    ):
        """
        Open the Pygame window, collect the run configuration from the user
        input screen, define the gym spaces and ship templates, then reset.

        Args:
            screen_width, screen_height: window size in pixels (also the upper
                bounds of the observation space).
            env_name: window caption.
            tota_num_def: total number of defender ships.
            num_def_with_sonar, num_def_with_helicopter, num_def_with_decoy,
            num_default_def: defender mix.
            def_default_formation: initial defender formation name.
            map_center: [lat, lon] of the map centre; defaults to [3.0, 86.0].
            zoom: map zoom level.
            base_location: [lat, lon] of the base the HVU heads to once the
                attacker has fired; defaults to [3.8, 87.5].
            hvu_ship, att_ship, def_ship, def_sonar, def_heli, def_decoyM:
                optional Ship keyword-argument templates; each is deep-copied,
                and a built-in default is used when omitted.
            def_moving_formation: formation used while the defence is active.
            helicop_path_radius, helicop_range, helicop_speed: helicopter settings.

        NOTE: the configuration values are stored on self first and then
        overwritten with what UIManager(self).get_user_input_screen() returns.
        """
        super().__init__()

        # None defaults avoid one list object being shared between instances
        if map_center is None:
            map_center = [3.0000, 86.0000]
        if base_location is None:
            base_location = [3.8, 87.5000]

        # Initialize Pygame and basic display
        pygame.init()
        self.width, self.height = screen_width, screen_height
        self.screen = pygame.display.set_mode((self.width, self.height))
        pygame.display.set_caption(env_name)
        self.clock = pygame.time.Clock()
        self.font = pygame.font.SysFont('Arial', 15)

        # --- Store Formations and Helicopter Config ---
        self.avail_def_default_formations = ["semicircle", "circle", "triangle"]
        self.avail_def_moving_formations = ["line", "wedge", "semicircle", "circle", "triangle"]

        self.helicop_path_radius = helicop_path_radius
        self.helicop_range = helicop_range
        self.helicop_speed = helicop_speed

        # Store default configuration of defenders and their formation
        # (presumably read by UIManager to pre-fill its inputs - verify)
        self.num_defenders = tota_num_def
        self.num_def_with_sonar = num_def_with_sonar
        self.num_def_with_helicopter = num_def_with_helicopter
        self.num_def_with_decoy = num_def_with_decoy
        self.num_default_def = num_default_def
        self.def_default_formation = def_default_formation
        self.map_center = map_center
        self.zoom = zoom
        self.base_location = base_location
        self.def_moving_formation = def_moving_formation

        # Get user input for the environment configuration
        (
            self.num_defenders,
            self.custmise_def,
            self.num_def_with_sonar,
            self.num_def_with_helicopter,
            self.num_def_with_decoy,
            self.num_default_def,
            self.def_formation,
            self.map_center,
            self.zoom,
            self.base_location,
            self.def_moving_formation
        ) = UIManager(self).get_user_input_screen()

        # Observation space: (x, y) of attacker + HVU + all defenders
        total_ships = self.num_defenders + 2
        self.observation_space = spaces.Box(
            low=np.zeros(2 * total_ships, dtype=np.float32),
            high=np.array([self.width, self.height] * total_ships, dtype=np.float32),
            dtype=np.float32
        )

        # Attacker action space, matching _move_attacker: eight headings
        # (0 Up, 1 Down, 2 Left, 3 Right, 4 Up-Right, 5 Up-Left, 6 Down-Left,
        # 7 Down-Right) plus 8 = stay
        self.action_space = spaces.Discrete(len(self._ACTION_HEADINGS) + 1)

        # --- Default Ship Templates ---
        self.ship_templates = {
            "hvu_ship": copy.deepcopy(hvu_ship) if hvu_ship else {
                'x': 600,
                'y': 300,
                'speed': 1.5,
                'ship_type': 'HVU',
                'firing_range': 0,
                'ship_health': 10,
                'torpedo_count': 0,
            },
            "att_ship": copy.deepcopy(att_ship) if att_ship else {
                'x': 300,
                'y': 100,
                'speed': 3.0,
                'ship_type': 'attacker_ship',
                'firing_range': 200,
                'ship_health': 10,
                'reload_delay': 0.5,
                'target_delay': 0.2,
                'torpedo_count': 100,
                'torpedo_fire_speed': 2.0,
                'torpedo_damage': 1,
            },
            "def_ship": copy.deepcopy(def_ship) if def_ship else {
                'x': 0,
                'y': 0,
                'speed': 1.5,
                'ship_type': 'defender',
                'firing_range': 100,
                'reload_delay': 0.5,
                'target_delay': 0.2,
                'helicop_count': 0,
                'torpedo_count': 100,
                'torpedo_fire_speed': 2.0,
                'torpedo_damage': 1,
                'decoyM_count': 0,
            },
            "def_sonar": copy.deepcopy(def_sonar) if def_sonar else {
                'x': 0,
                'y': 0,
                'speed': 1.5,
                'ship_type': 'def_sonar',
                'firing_range': 150,
                'reload_delay': 0.5,
                'target_delay': 0.2,
                'helicop_count': 0,
                'torpedo_count': 100,
                'torpedo_fire_speed': 2.0,
                'torpedo_damage': 1,
                'decoyM_count': 0,
            },
            "def_heli": copy.deepcopy(def_heli) if def_heli else {
                'x': 0,
                'y': 0,
                'speed': 1.5,
                'ship_type': 'def_heli',
                'firing_range': 100,
                'reload_delay': 0.5,
                'target_delay': 0.2,
                'helicop_count': 1,
                'torpedo_count': 100,
                'torpedo_fire_speed': 2.0,
                'torpedo_damage': 1,
                'decoyM_count': 0,
            },
            "def_decoyM": copy.deepcopy(def_decoyM) if def_decoyM else {
                'x': 0,
                'y': 0,
                'speed': 1.5,
                'ship_type': 'def_decoyM',
                'firing_range': 100,
                'reload_delay': 0.5,
                'target_delay': 0.2,
                'helicop_count': 0,
                'torpedo_count': 100,
                'torpedo_fire_speed': 2.0,
                'torpedo_damage': 1,
                'decoyM_count': 100,
                'decoyM_speed': 4.0,
                'decoyM_blast_range': 2.0,
            }
        }

        # Call environment reset
        self.reset()

    def reset(self):
        """
        Rebuild every ship, the map generator, renderer and managers for a new
        episode, and return the initial observation.
        """
        self.info = {
            'collision': 0,
            'attacker destroyed': 0,
            'HVU destroyed': 0,
            'step count': 0,
            'Returned to Base': 0
        }
        self.reward = 0  # Running reward total for the episode
        self.done = False
        self.paused = False
        self.attacker_fired = False

        # Initialize HVU (High Value Unit)
        self.hvu = Ship(
            self,
            ship_id=0,
            screen_width=self.width,
            screen_height=self.height,
            **self.ship_templates["hvu_ship"]
        )

        # Initialize Attacker Ship
        self.attacker_ship = Ship(
            self,
            ship_id=1,
            screen_width=self.width,
            screen_height=self.height,
            **self.ship_templates["att_ship"]
        )

        # Initialize Defender Ships
        self.defender_ships = self.define_defenders()

        # Combine all ships (this order defines the observation layout)
        self.ships = [self.attacker_ship, self.hvu] + self.defender_ships

        # Initialize Map Renderer
        self.mapGenerator = MapGenerator(self, map_center=self.map_center, zoom=self.zoom)
        self.renderer = Renderer(self, self.mapGenerator)

        # Initialize Managers
        self.defence_system = DefenseSystem(self)
        self.firemechanism = FireMechanism(self)
        self.helicopManager = HelicopterManager(self)
        self.DecoyMissileManager = DecoyMissileManager(self)

        # Convert base location from lat, lon to pixels for calculations
        self.base_location_inPixels = self.mapGenerator._latlon_to_pixels(self.base_location[0], self.base_location[1])

        return self._get_obs()

    def define_defenders(self):
        """
        Build the defender ships from the templates, numbering ids from 2.

        Non-custom setup: heli, decoyM, sonar, decoyM, heli, then plain
        defenders for any slots beyond five. Custom setup: the requested sonar,
        heli and decoy counts, then plain defenders up to `num_defenders`.
        """
        defender_ships = []
        current_id = 2  # Start from 2 (0: HVU, 1: Attacker)

        def add_ship_from_template(template_name):
            nonlocal current_id
            # Deep copy so each ship gets its own template values
            template = copy.deepcopy(self.ship_templates[template_name])
            ship = Ship(
                self,
                ship_id=current_id,
                screen_width=self.width,
                screen_height=self.height,
                **template
            )
            defender_ships.append(ship)
            current_id += 1

        if not self.custmise_def:
            # Default setup with 5 mixed defenders
            add_ship_from_template("def_heli")
            add_ship_from_template("def_decoyM")
            add_ship_from_template("def_sonar")
            add_ship_from_template("def_decoyM")
            add_ship_from_template("def_heli")

            # Fill remaining slots with basic defenders
            for _ in range(self.num_defenders - 5):
                add_ship_from_template("def_ship")

        else:
            # User custom configuration
            for _ in range(self.num_def_with_sonar):
                add_ship_from_template("def_sonar")

            for _ in range(self.num_def_with_helicopter):
                add_ship_from_template("def_heli")

            for _ in range(self.num_def_with_decoy):
                add_ship_from_template("def_decoyM")

            # Fill remaining defenders with default ships if the custom count is short
            total_custom = self.num_def_with_sonar + self.num_def_with_helicopter + self.num_def_with_decoy
            for _ in range(self.num_defenders - total_custom):
                add_ship_from_template("def_ship")

        return defender_ships

    def _get_obs(self):
        """Return the flattened (x, y) positions of all ships, in self.ships order."""
        observation = []
        for ship in self.ships:
            observation.extend(ship.get_position())
        return np.array(observation, dtype=np.float32)

    def step(self, action):
        """
        Advance the simulation by one step using the attacker `action`.

        Returns:
            (observation, reward, done, info), where `reward` is the reward
            earned during this step only (gym semantics). The running episode
            total is kept in `self.reward`.
        """
        # Snapshot the total so the per-step reward also includes anything the
        # helper managers may add to self.reward directly during this step.
        reward_before = self.reward

        # Move ships
        self._movements(action)

        # A collision ends the episode with a heavy penalty
        if self.firemechanism.check_collisions():
            self.reward -= 100
            self.info['collision'] += 1
            self.done = True
            return self._get_obs(), self.reward - reward_before, self.done, self.info

        # Calculate reward based on current state
        self.reward += self._calculate_reward()

        self.info['step count'] += 1
        return self._get_obs(), self.reward - reward_before, self.done, self.info

    def _movements(self, action):
        """Handle the movements of attacker, projectiles, defenders and HVU for one step."""
        self._move_attacker(action)  # Move attacker based on action

        self.firemechanism._update_torpedo_position()  # Update torpedo positions

        self.DecoyMissileManager.update_decoy_missile()  # Update decoy missiles

        self.move_defenders()  # Move defenders

        # Move HVU towards the base once the attacker has fired
        if self.attacker_fired:
            self.move_hvu_to_base()
            self._check_hvu_reached_base()

    def _check_hvu_reached_base(self):
        """
        End the episode when the HVU is within the arrival radius of the base.
        Stores the current distance in `self.distance_to_base_and_hvu`.
        """
        hvu_x, hvu_y = self.hvu.get_position()
        base_x, base_y = self.base_location_inPixels
        self.distance_to_base_and_hvu = np.linalg.norm(np.array([hvu_x, hvu_y]) - np.array([base_x, base_y]))

        # '<=' matches move_hvu_to_base, which stops once distance <= radius
        if self.distance_to_base_and_hvu <= self._BASE_ARRIVAL_RADIUS:
            self.done = True
            self.info['Returned to Base'] += 1

    def move_hvu_to_base(self):
        """
        Move the HVU one step (at most `hvu.speed` pixels) straight towards the
        base, unless it is already within the arrival radius.
        """
        hvu_x, hvu_y = self.hvu.get_position()
        base_x, base_y = self.base_location_inPixels

        direction = np.array([base_x - hvu_x, base_y - hvu_y])
        distance = np.linalg.norm(direction)

        if distance > self._BASE_ARRIVAL_RADIUS:
            # Never step past the base, even if the HVU is faster than the gap
            step = min(self.hvu.speed, distance)
            unit = direction / distance
            self.hvu.set_position(hvu_x + unit[0] * step, hvu_y + unit[1] * step)

    def move_defenders(self):
        """
        Move the defenders: the moving formation while the defence is active,
        otherwise keep the stationary formation. Then update defence activation
        and move the defenders' helicopter.
        """
        if self.defence_system.defense_active:
            self.defence_system.handle_defense_mechanism(formation_type=self.def_moving_formation)
        else:
            self.defence_system.move_defenders_in_formation(self.def_formation)

        self.defence_system.check_for_defense_activation()

        # Movement of the defenders' helicopter
        self.helicopManager.move_defenders_helicop()

    def _move_attacker(self, action):
        """
        Move the attacker for action index `action`: 0-7 move along the headings
        in `_ACTION_HEADINGS`, 8 stays put, anything else prints a warning.
        """
        heading = self._ACTION_HEADINGS.get(action)

        if heading is not None:
            self.attacker_ship.move_ship_to_direction(heading=heading)
        elif action == self._STAY_ACTION:
            pass  # No movement
        else:
            print(f"[WARNING] Invalid action received: {action}")

    def _calculate_reward(self):
        """
        Compute the shaping reward for the current state and run the per-step
        firing logic.

        +2 when the HVU is inside the attacker's firing range, otherwise
        -(attacker-HVU distance / 10); -5 when the attacker is inside any
        defender's range; +20 when the HVU is in range and no defender covers
        the attacker; plus whatever FireMechanism._handle_firing_mechanics
        returns. Decoy missile firing is also processed here; its return value
        is not used.
        """
        reward = 0

        attacker_pos = self.attacker_ship.get_position()
        hvu_pos = self.hvu.get_position()
        distance_to_hvu = np.linalg.norm(attacker_pos - hvu_pos)

        hvu_in_att_range = self.attacker_ship.target_in_range(hvu_pos)
        if hvu_in_att_range:
            reward += 2
        else:
            # Penalise drifting away from the HVU, proportional to the distance
            reward -= (distance_to_hvu / 10)

        # Attacker within any defender's firing range
        in_range_defenders = self.defence_system.attacker_within_defender_range()
        if in_range_defenders:
            reward -= 5

        # Bonus for holding the HVU in range while staying out of defender range
        if hvu_in_att_range and not in_range_defenders:
            reward += 20

        reward += self.firemechanism._handle_firing_mechanics()

        self.DecoyMissileManager.handle_DecoyM_firing_mechanics()

        return reward

    def check_los_attacker(self):
        """
        Return True when no defender blocks the attacker -> HVU line of sight.
        """
        attacker_pos = self.attacker_ship.get_position()
        hvu_pos = self.hvu.get_position()

        return not any(
            self.check_if_blocking_los(attacker_pos, hvu_pos, defender.get_position())
            for defender in self.defender_ships
        )

    def check_los_defender(self, defender):
        """
        Return True when `defender` has a clear line of sight to the attacker,
        i.e. neither the HVU nor any other defender lies on the defender ->
        attacker segment.
        """
        attacker_pos = self.attacker_ship.get_position()
        defender_pos = defender.get_position()

        # The HVU sitting on the line of fire blocks the shot
        if self.check_if_blocking_los(defender_pos, attacker_pos, self.hvu.get_position()):
            return False

        # So does any other defender on the line of fire
        for def_ship in self.defender_ships:
            if def_ship.ship_id == defender.ship_id:
                continue  # Skip the defender itself
            if self.check_if_blocking_los(defender_pos, attacker_pos, def_ship.get_position()):
                return False

        return True  # No ship is blocking the LOS

    def check_line_intersection(self, A, B, C, D):
        """
        Return True if segments AB and CD properly cross each other.

        Uses cross-product sign tests; touching at an endpoint and collinear
        overlap both produce a zero cross product and are NOT counted.

        Args:
            A, B (np.array): start and end points of the first segment.
            C, D (np.array): start and end points of the second segment.
        """
        def cross_product(v1, v2):
            return v1[0] * v2[1] - v1[1] * v2[0]

        AB = B - A
        CD = D - C

        # C and D on strictly opposite sides of AB, and A and B on strictly
        # opposite sides of CD
        cross1 = cross_product(AB, C - A)
        cross2 = cross_product(AB, D - A)
        cross3 = cross_product(CD, A - C)
        cross4 = cross_product(CD, B - C)

        return (cross1 * cross2 < 0) and (cross3 * cross4 < 0)

    def check_if_blocking_los(self, start_pos, end_pos, blocker_center, blocker_size=(20, 20)):
        """
        Return True if the segment start_pos -> end_pos properly crosses any
        edge of an axis-aligned rectangle of `blocker_size` centred on
        `blocker_center`.

        Args:
            start_pos, end_pos (np.array): segment endpoints (x, y).
            blocker_center (np.array): centre of the blocking ship (x, y).
            blocker_size (tuple): (width, height) of the blocker.
        """
        # A zero-length segment cannot be blocked
        if np.linalg.norm(end_pos - start_pos) == 0:
            return False

        half_width, half_height = blocker_size[0] / 2, blocker_size[1] / 2
        bottom_left = blocker_center + np.array([-half_width, -half_height])
        bottom_right = blocker_center + np.array([half_width, -half_height])
        top_left = blocker_center + np.array([-half_width, half_height])
        top_right = blocker_center + np.array([half_width, half_height])

        blocker_edges = [
            (bottom_left, bottom_right),  # Bottom edge
            (bottom_right, top_right),    # Right edge
            (top_right, top_left),        # Top edge
            (top_left, bottom_left),      # Left edge
        ]

        return any(
            self.check_line_intersection(start_pos, end_pos, edge_start, edge_end)
            for edge_start, edge_end in blocker_edges
        )

    def check_for_collisions_while_ship_moves(self, ship, all_ships, ship_target_pos):
        """
        Check whether another ship within `ship`'s firing range lies on the path
        from `ship`'s current position to `ship_target_pos`.

        Returns:
            (other_ship, True) for the first such ship, else (None, False).
        """
        ship_current_pos = ship.get_position()
        for other_ship in all_ships:
            if ship != other_ship:
                other_ship_pos = other_ship.get_position()
                if ship.target_in_range(other_ship_pos):
                    # Is other_ship on the segment from the current to the target position?
                    if self.check_if_blocking_los(ship_current_pos, ship_target_pos, other_ship_pos):
                        return other_ship, True  # Collision detected

        return None, False  # No collision detected

    def render(self, mode='human'):
        """Render the map, ships, firing lines, helicopter, base and health HUD."""
        # Background map
        self.mapGenerator.render_background()

        # HVU (green)
        self.renderer._render_ship(self.hvu, color=(0, 255, 0))

        # Attacker ship and its firing line
        self.renderer._render_attacker()

        # Defender ships and their firing lines
        self.renderer._render_defenders()

        self.renderer._render_helicopter(self.helicopManager.helicopter_position)

        # Base (green circle with a black border)
        self.renderer._render_base()

        # Health status
        self.renderer._display_health()

        # Update the screen and cap the frame rate
        pygame.display.update()
        self.clock.tick(15)  # 15 FPS for a slower, watchable simulation

    def close(self):
        """Close the Pygame window."""
        pygame.quit()

    def seed(self, seed=None):
        """Seed Python's `random` and NumPy's global RNG."""
        random.seed(seed)
        np.random.seed(seed)

# Define the Environment Details

In [25]:
hvu_ship = {     
    'x': 600,
    'y': 300,
    'speed': 1.5,
    'ship_type': 'HVU',
    'firing_range': 0,
    'ship_health': 10,
    'torpedo_count': 0,
}
    

att_ship = {      
    'x': 300,
    'y': 100,
    'speed': 3.0,
    'ship_type': 'attacker_ship',
    'firing_range': 200,
    'ship_health': 10,
    'reload_delay': 0.5,
    'target_delay': 0.2,
    'torpedo_count': 100,
    'torpedo_fire_speed': 2.0,
    'torpedo_damage': 1,    
}

    
def_ship = {        
    'x': 0,
    'y': 0,
    'speed': 1.5,
    'ship_type': 'defender',
    'firing_range': 100,
    'reload_delay': 0.5,
    'target_delay': 0.2,
    'helicop_count': 0,
    'torpedo_count': 100,
    'torpedo_fire_speed': 2.0,
    'torpedo_damage': 1,
    'decoyM_count': 0
}


def_sonar = {   
    'x': 0,
    'y': 0,
    'speed': 1.5,
    'ship_type': 'def_sonar',
    'firing_range': 150,
    'reload_delay': 0.5,
    'target_delay': 0.2,
    'helicop_count': 0,
    'torpedo_count': 100,
    'torpedo_fire_speed': 2.0,
    'torpedo_damage': 1,
    'decoyM_count': 0
}


def_heli = {        
    'x': 0,
    'y': 0,
    'speed': 1.5,
    'ship_type': 'def_heli',
    'firing_range': 100,
    'reload_delay': 0.5,
    'target_delay': 0.2,
    'helicop_count': 1,
    'torpedo_count': 100,
    'torpedo_fire_speed': 2.0,
    'torpedo_damage': 1,
    'decoyM_count': 0
    
}


def_decoyM = {       
    'x': 0,
    'y': 0,
    'speed': 1.5,
    'ship_type': 'def_decoyM',
    'firing_range': 100,
    'reload_delay': 0.5,
    'target_delay': 0.2,
    'helicop_count': 0,
    'torpedo_count': 100,
    'torpedo_fire_speed': 2.0,
    'torpedo_damage': 1,
    'decoyM_count': 100,
    'decoyM_speed': 3.0,
    'decoyM_blast_range': 2.0 
}


# Game Play

In [26]:
if __name__ == "__main__":

    # Initialise the Environment
    env = NavalShipEnv(
        screen_width = 1000,
        screen_height = 600,
        env_name = "Naval Ship Environment",

        # Default Configuration of defenders and its formation
        tota_num_def = 6,
        num_def_with_sonar = 1,
        num_def_with_helicopter = 2,
        num_def_with_decoy = 2,
        num_default_def = 1,
        def_default_formation = "semicircle",
        map_center = [3.0000, 86.0000],
        zoom = 6,
        base_location = [3.8, 87.5000],

        # --- Ship Configuration ---
        hvu_ship = hvu_ship,
        att_ship = att_ship,
        def_ship = def_ship,
        def_heli = def_heli,
        def_decoyM = def_decoyM,
        def_sonar = def_sonar,

        # --- Formation and Movement ---
        def_moving_formation = "line",

        # --- Helicopter Configuration ---
        helicop_path_radius = 200,  # helicopter path is a circle around the HVU
        helicop_range = 150,
        helicop_speed = 2.0
    )

    obs = env.reset()

    try:
        running = True
        while running:
            # P toggles pause; Q or closing the window quits
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_p:
                        env.paused = not env.paused
                    elif event.key == pygame.K_q:
                        running = False

            if not env.paused and not env.done:
                # Random-policy play (render() also caps the frame rate)
                action = env.action_space.sample()
                obs, reward, done, _ = env.step(action)
                env.render()
            else:
                # Paused or episode over: throttle the idle loop so it does
                # not spin a CPU core while waiting for input
                env.clock.tick(15)
    finally:
        # Release the Pygame window even if a step raises
        env.close()