In [None]:

# Domain extent; the trajectory plots fix the x-axis to [0, x_max].
x_max, y_max = 200, 200

# Model constants and initial conditions.
# NOTE(review): the physical meaning of these coefficients is defined by the
# state/costate functions in other cells — confirm there.
alpha, sigma, beta = 0.00075, 0.05, 0.06
rate = 0.03
delta = 100        # each fishing-ground costate is initialized to -delta
gamma = 0.001      # not referenced in this cell
epsilon = 1e-05    # forwarded to the costate equations
V = 1              # not referenced in this cell

# Box bounds applied to every control component during optimization.
u_min, u_max = 0, 5

# Scenario size and time discretization.
n_ships, n_grounds = 5, 7
t_final = 400
num_timesteps = 200



# Initialize scenario. `grounds` and `ships` are expected to come from an
# earlier cell (the generator calls below are commented out).
# NOTE(review): assumes ships is an (n_ships, 2) and grounds an (n_grounds, 2)
# array of positions — confirm against the generator cell.
# grounds = generate_fishing_grounds(n_grounds, x_max, y_max)
# ships = generate_ships(n_ships, x_max, y_max)
N = ships.shape[0]
K = grounds.shape[0]
# Passed to the state/costate equations; presumably a home/harbor point — verify.
h = np.array([x_max / 2, 0])

# Control tensor of shape (num_timesteps, N, K + 2): one weight per fishing
# ground followed by two extra controls. Only the first K columns enter the
# attraction term of the control update; the last two start at 0 and 2.
U_tensor = np.ones((num_timesteps, N, K + 2))
U_tensor[:, :, -2] = 0
U_tensor[:, :, -1] = 2

# Every fishing ground starts with the same stock F0.
F0 = 100
F = np.ones(K) * F0

# Flatten initial conditions into y0 = [ship positions (row-major)..., stocks...].
# Bug fix: this used to bind `new_x = ships` and then read `x`, which is
# undefined on a fresh kernel or, on a re-run, the stale (n_ships, 2, T)
# trajectory from the previous sweep — producing a wrong x_size and y0.
x = ships
new_x = x  # kept as an alias in case other cells refer to it
x_size = np.prod(x.shape)
x_shape = x.shape
F_size = np.prod(F.shape)
p_size = x_size
p_shape = x_shape
y0 = np.concatenate((x.flatten(), F))

# Costate boundary values: zero for ship positions, -delta for each ground's
# stock. Presumably terminal conditions, since the costate solution is
# time-reversed after solving — confirm in solve_costate_equations.
# Built from p_shape so it always matches the flattened position block.
p_init = np.zeros(p_shape)
pf_init = np.full(F.shape, -delta)
p0 = np.concatenate((p_init.flatten(), pf_init))

for i in range(1):

    # Parameter bundles for the state and costate right-hand sides.
    state_params = [alpha, sigma, beta, rate, grounds, h, x_size, F_size, x_shape]
    params_costate = [alpha, sigma, beta, epsilon, rate, grounds, h, p_size, F_size, p_shape]

    # Solve state equations under the current controls.
    sol_state = solve_state_equations(state, U_tensor, state_params, t_final, y0, num_timesteps=num_timesteps)

    # Solve costate equations along the state trajectory.
    sol_costate = solve_costate_equations(costate, sol_state, U_tensor, params_costate, t_final, p0, x_size, F_size, num_timesteps=num_timesteps)

    # NOTE(review): p0 is not carried over between sweeps; an update of the
    # form `p0 = sol_costate.y[:, -1]` was previously left commented out here.

    # Reverse the costate columns so they run in the same time order as
    # sol_state (presumably the costate is integrated backward — confirm).
    costate_results = sol_costate.y[:, ::-1]

    # Unflatten: rows are laid out as [ship positions (row-major)..., stocks...].
    p_vectors = costate_results[:p_size].reshape(n_ships, 2, -1)
    pf_vectors = costate_results[p_size:]

    x = sol_state.y[:x_size].reshape((n_ships, 2, -1))
    F = sol_state.y[x_size:x_size + F_size]

    # Control update: for each timestep and ship, pick U so that the resulting
    # attraction vector is parallel to the costate p, while penalizing changes
    # from the previous timestep's (already updated) control.
    lambda_reg = 10  # regularization strength; adjust for desired smoothness
    bounds = [(u_min, u_max)] * (K + 2)  # box bounds on every control component

    # NOTE(review): assumes sol_state / sol_costate carry exactly
    # num_timesteps columns — confirm in the solver helpers.
    for t in range(num_timesteps):
        ship_positions = x[:, :, t]
        fish_remaining = F[:, t]
        D = d_matrix(ship_positions, grounds)
        G = ship_ground_gaussians(ship_positions, grounds, alpha)

        U = U_tensor[t, :, :]

        for ship in range(n_ships):
            p = p_vectors[ship, :, t]

            def minimize_dot_product(u):
                """Return |p . s_perp| plus a smoothness penalty for control vector u."""
                # Attraction toward each ground; only the first K controls enter.
                attraction = u[:-2, None] * G[ship, :, None] * D[ship] * fish_remaining[:, None]
                s = np.sum(attraction, axis=0)
                # s rotated by 90 degrees: p . s_perp == 0 iff s is parallel to p.
                orthogonal = np.array([s[1], -s[0]])
                misalignment = np.abs(np.dot(p, orthogonal))
                if t == 0:
                    # No previous timestep to regularize against (U_tensor[-1]
                    # would wrap around to the last timestep).
                    return misalignment
                regularization = lambda_reg * np.sum((U_tensor[t - 1, ship] - u) ** 2)
                return misalignment + regularization

            # Find the optimal U for the ship, starting from its current control.
            res = minimize(minimize_dot_product, U[ship], bounds=bounds)
            U_tensor[t, ship, :] = res.x

    # sum up over the last time step to get the total fish remaining
    initial_fish = np.sum(F[:, 0])
    total_fish_remaining = np.sum(F[:, -1])
    print("Fish Harvested: ", initial_fish - total_fish_remaining)

    # Color grounds by final stock relative to the best-stocked ground.
    # NOTE(review): the colorbar says "Percentage", but this is a fraction of
    # the maximum final stock, not of each ground's initial stock.
    final_fish = F[:, -1]
    max_final_fish = np.max(final_fish)
    if max_final_fish > 0:
        F_normalized = final_fish / max_final_fish
    else:
        # Every ground emptied: avoid 0/0 -> NaN colors.
        F_normalized = np.zeros_like(final_fish)

    fig = plt.figure(figsize=(12, 6))
    fig.patch.set_facecolor('none')

    ax = fig.add_subplot(111)
    ax.patch.set_facecolor('white')

    plt.scatter(grounds[:, 0], grounds[:, 1], c=F_normalized, cmap='viridis', marker='x', s=100, label='Fishing Grounds')

    cbar = plt.colorbar()
    cbar.set_label('Percentage of Fish Remaining', rotation=270, labelpad=20)

    plt.scatter(x[:, 0, 0], x[:, 1, 0], c='green', marker='^')

    for j in range(x.shape[0]):
        plt.plot(x[j, 0, :], x[j, 1, :])

    plt.xlabel('x')
    plt.ylabel('y')
    plt.title(fr'Optimal Control State Trajectories, $\beta = {beta}$')

    plt.xlim(0, x_max)
    # plt.ylim(0, y_max)

    # Proxy artists for the legend ('#fde725' is the top of the viridis map).
    fishing_grounds = mlines.Line2D([], [], color='#fde725', marker='x', linestyle='None', markersize=10, label='Fishing Grounds')
    starting_points = mlines.Line2D([], [], color='green', marker='^', linestyle='None', markersize=10, label='Initial Boat Positions')
    trajectories = mlines.Line2D([], [], color='blue', marker='_', linestyle='-', markersize=10, label='Trajectories')

    plt.legend(handles=[fishing_grounds, starting_points, trajectories], loc='lower right')

    # plt.savefig('massive_resolution.png', transparent=True)
    plt.show()

    # Second figure: trajectories annotated with normalized costate directions.
    fig, ax = plt.subplots(figsize=(12, 6))
    fig.patch.set_facecolor('none')

    # Plot fishing grounds
    plt.scatter(grounds[:, 0], grounds[:, 1], c='blue', marker='x', s=100, label='Fishing Grounds')

    # Initial positions of ships
    plt.scatter(x[:, 0, 0], x[:, 1, 0], c='green', marker='^', label='Initial Boat Positions')

    # Ships to annotate: all of them by default. Use e.g. [0, 1] or
    # np.random.choice(range(n_ships), 2, replace=False) to plot a subset.
    selected_ships = list(range(n_ships))

    # Plot trajectories for selected ships only, labelled by ship index
    # (previously labelled with the sweep counter `i` by mistake).
    for k in selected_ships:
        plt.plot(x[k, 0, :], x[k, 1, :], label=f'Ship {k+1} Trajectory')

    # Unit-length costate directions (p_vectors was unflattened above).
    # NORMALIZATOR is defined elsewhere; presumably a small positive constant
    # that keeps the division finite for zero-length vectors.
    norms = np.linalg.norm(p_vectors[selected_ships], axis=1, keepdims=True)
    normalized_p_vectors = p_vectors[selected_ships] / (norms + NORMALIZATOR)

    # Draw roughly 25 arrows per trajectory.
    step = max(1, num_timesteps // 25)
    indices = np.arange(0, num_timesteps, step)

    # Plot p vectors for selected ships
    for j, idx in enumerate(selected_ships):
        ax.quiver(x[idx, 0, indices], x[idx, 1, indices], normalized_p_vectors[j, 0, indices], normalized_p_vectors[j, 1, indices], color='red', scale=30, headwidth=2, headlength=2.5, width=0.005, headaxislength=2)

    # Set labels and titles
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title(fr'Optimal Control State Trajectories with Costate Vectors, $\beta = {beta}$')

    # Set plot limits
    plt.xlim(0, x_max)

    # Legend setup; the grounds proxy matches the blue markers drawn above
    # (it was previously yellow, copied from the viridis-colored first figure).
    fishing_grounds = mlines.Line2D([], [], color='blue', marker='x', linestyle='None', markersize=10, label='Fishing Grounds')
    starting_points = mlines.Line2D([], [], color='green', marker='^', linestyle='None', markersize=10, label='Initial Boat Positions')
    trajectories = mlines.Line2D([], [], color='blue', marker='_', linestyle='-', markersize=10, label='Trajectories')
    plt.legend(handles=[fishing_grounds, starting_points, trajectories], loc='lower right')

    plt.show()

    print(f"Iteration {i+1} done")

