In [None]:
import time

import torch
from SwarmEnvironment import *
from SwarmAgent import *
import numpy
# ------------------------------
# 4. Training Loop
# ------------------------------
# Run configuration, environment/agent construction and hyperparameters used
# by the training loop below.
load_model = True  # resume from the saved checkpoint files instead of fresh weights

# Seed both numpy and torch so a Restart & Run All is repeatable.
numpy.random.seed(20)
torch.manual_seed(20)

# Prefer CUDA, but fall back gracefully so the notebook still runs on machines
# without an NVIDIA GPU (Apple MPS first, then plain CPU).
if torch.cuda.is_available():
    manual_selected_device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    manual_selected_device = torch.device("mps")
else:
    manual_selected_device = torch.device("cpu")

n_agents = 20
space_size = 20
visible_neighbor_amount = 4
error_tolerance = 0.5  # goal tolerance, forwarded to env.step()
collision_tolerance = 0.25  # collision tolerance, forwarded to env.step()
linear_displacement = 0.125
angular_displacement = 22.5

env = SwarmEnv(
    n_agents=n_agents,
    space_size=space_size,
    linear_displacement=linear_displacement,
    angular_displacement=angular_displacement,
    visible_neighbor_amount=visible_neighbor_amount,
)
env.set_random_goals()
agent = Agent(
    state_dim=env.observation_dimension,
    action_dim=env.action_amount,
    device=manual_selected_device,
)
if load_model:
    # map_location lets a checkpoint written on one device (e.g. CUDA) be
    # restored on whichever device was selected above.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    agent.model.load_state_dict(
        torch.load("swarm_agent_model.pth", map_location=manual_selected_device)
    )
    agent.target.load_state_dict(
        torch.load("swarm_target_model.pth", map_location=manual_selected_device)
    )
print(env.action_amount)


# Q-learning gamma. Conventionally this is the discount factor (not the
# learning rate) — confirm against the Agent implementation.
agent.gamma = 0.9
agent.epsilon = 1.0  # action randomness, 1 for fully random
agent.batch_size = 128

training_steps = 1000  # number of episodes the training loop runs
episodes_length = 516  # maximum number of environment steps per episode

epsilon_decay = 0.999  # action randomness decay rate
# NOTE(review): not referenced anywhere in the visible cells.
epsilon_decay_accelerating_factor = 0.99
epsilon_min = 0.05  # minimum epsilon

env.non_goal_reward = -1.0
env.goal_reward = 50.0
env.collision_reward = -2.0
# NOTE(review): presumably scales a distance-based reward term relative to the
# per-step displacement — confirm in SwarmEnvironment.
env.distance_reward_factor = 5.0 / linear_displacement

# Per-episode logs filled in by the training loop. `numpy` is used explicitly
# because the file imports it under that name; `np` was only available as a
# side effect of the wildcard imports.
total_rewards = numpy.zeros(training_steps)
epsilons = numpy.zeros(training_steps)
time_spend = numpy.zeros(training_steps)

# Run `training_steps` episodes of at most `episodes_length` steps each, logging
# the mean reward, epsilon and wall time per episode, then save the weights.
for episode in range(training_steps):
    # NOTE(review): the goals are re-drawn after reset() has already returned
    # the first observations — confirm those observations do not depend on the
    # goals, otherwise the first step of each episode acts on stale input.
    observations = env.reset()
    env.set_random_goals()
    total_reward = 0
    start_time = time.perf_counter()  # monotonic clock for measuring durations
    step = 0  # keeps `step` defined for the log line if episodes_length is 0
    for step in range(episodes_length):
        # One batched forward pass selects the actions of all agents
        actions = agent.select_multiple_actions(observations)

        # Environment step (expects actions as a list or array)
        next_observations, rewards, done, _ = env.step(
            actions,
            error_tolerance=error_tolerance,
            collision_tolerance=collision_tolerance,
        )

        # Store transitions for all agents
        for i in range(env.n_agents):
            agent.store(observations[i], actions[i], rewards[i], next_observations[i])

        # Train DQN
        agent.train_step()

        # Move to next step
        observations = next_observations
        total_reward += numpy.mean(rewards)

        # End early if environment finishes
        if done:
            agent.train_step(done)
            break

    # The target network is refreshed once per episode
    agent.update_target()

    delta_time = time.perf_counter() - start_time

    total_rewards[episode] = total_reward
    epsilons[episode] = agent.epsilon
    time_spend[episode] = delta_time
    print(f"Episode {episode + 1}, steps {step + 1:.0f} (done: {env.done_count/n_agents * 100.0:.3f}% collision: {env.collision_count:.0f}), Average total reward {total_reward:.5f}, epsilon {agent.epsilon:.5f} time {delta_time:.5f}s")

    # Decay exploration after logging, so the log shows the epsilon actually used
    agent.epsilon_decay(epsilon_min, epsilon_decay)

# Save model weights
torch.save(agent.model.state_dict(), "swarm_agent_model.pth")
torch.save(agent.target.state_dict(), "swarm_target_model.pth")


In [None]:
def visualize_swarm(agent, env, steps=50, save=False, interval=10):
    """
    Roll the policy out in ``env`` and visualize the swarm movement in 3D.

    The environment is reset and stepped for at most ``steps`` steps (stopping
    early once every agent is flagged done); the recorded positions are then
    rendered.
    - Agents that are 'done' are shown in green from that step onward.
    - save=False → live animation (in Jupyter)
    - save=True  → export 4 gifs: normal, xy, xz, yz

    Args:
        agent: object providing ``select_multiple_actions(observations)``.
        env: environment providing ``reset()``,
            ``step(actions, error_tolerance=...)``, ``positions``, ``goal``
            and ``space_size``.
        steps: maximum number of environment steps to roll out.
        save: export GIF files instead of displaying a live animation.
        interval: delay between frames in milliseconds.

    Notes:
        Reads ``error_tolerance`` from the notebook's global scope, and expects
        ``plt``, ``animation``, ``display`` and ``clear_output`` to already be
        in scope (presumably matplotlib / IPython.display — they are not
        imported in the visible cells; confirm).
    """
    obs = env.reset()
    positions_history = [env.positions.copy()]

    # Take the agent count from the recorded positions rather than an env
    # attribute, so this does not depend on the attribute's exact name.
    n_agents = len(positions_history[0])

    # Frame 0 is the reset state, before any step has been taken: nobody is
    # done yet. Recording it keeps done_history aligned 1:1 with
    # positions_history (and non-empty even when steps == 0).
    done_flags = numpy.zeros(n_agents, dtype=bool)
    done_history = [done_flags.copy()]

    print("stepping")

    for _ in range(steps):
        actions = agent.select_multiple_actions(obs)
        obs, _, done, _ = env.step(actions, error_tolerance=error_tolerance)

        # Support both scalar and per-agent done
        if numpy.isscalar(done):
            done_flags[:] = done
        else:
            done_flags |= numpy.array(done)  # once done → always done

        done_history.append(done_flags.copy())
        positions_history.append(env.positions.copy())

        if numpy.all(done_flags):
            break

    print("plotting")

    positions_history = numpy.array(positions_history)  # [T + 1, n_agents, 3]
    done_history = numpy.array(done_history)            # [T + 1, n_agents]
    n_frames, n_agents, _ = positions_history.shape
    last_step = n_frames - 1  # frame index equals the number of steps taken

    # A single shared goal is repeated so every agent has its own row
    goals = numpy.atleast_2d(env.goal)
    if goals.shape[0] == 1:
        goals = numpy.repeat(goals, n_agents, axis=0)

    def new_scene(title):
        """Create a 3D figure with the goal markers, an empty agent scatter and one guide line per agent."""
        fig = plt.figure(figsize=(6, 6))
        ax = fig.add_subplot(111, projection='3d')
        ax.set_xlim(0, env.space_size)
        ax.set_ylim(0, env.space_size)
        ax.set_zlim(0, env.space_size)
        ax.set_xlabel("X-axis")
        ax.set_ylabel("Y-axis")
        ax.set_zlabel("Z-axis")
        ax.set_title(title)

        scat = ax.scatter([], [], [], c='blue', s=50, label='Agents')
        ax.scatter(goals[:, 0], goals[:, 1], goals[:, 2],
                   c='red', s=100, marker='*', label='Goals')
        lines = [ax.plot([], [], [], 'gray', linestyle='--', linewidth=1)[0]
                 for _ in range(n_agents)]
        ax.legend()
        return fig, ax, scat, lines

    def draw_frame(frame, scat, lines):
        """Move the agent markers and agent→goal guide lines to the state recorded at `frame`."""
        pos = positions_history[frame]
        colors = ['green' if d else 'blue' for d in done_history[frame]]
        scat._offsets3d = (pos[:, 0], pos[:, 1], pos[:, 2])
        scat.set_color(colors)

        for i, line in enumerate(lines):
            line.set_data([pos[i, 0], goals[i, 0]], [pos[i, 1], goals[i, 1]])
            line.set_3d_properties([pos[i, 2], goals[i, 2]])

    # ----------------------------
    # Helper to make and save GIFs
    # ----------------------------
    def make_animation(view_name, elev, azim):
        """Render the rollout from one camera angle and write it to a GIF."""
        fig, ax, scat, lines = new_scene(f"3D Swarm Movement ({view_name})")
        ax.view_init(elev=elev, azim=azim)

        def init():
            scat._offsets3d = ([], [], [])
            return [scat, *lines]

        def update(frame):
            draw_frame(frame, scat, lines)
            ax.set_title(f"3D Swarm Movement ({view_name}) - Step {frame}/{last_step}")
            return [scat, *lines]

        ani = animation.FuncAnimation(
            fig, update, frames=n_frames, init_func=init,
            interval=interval, blit=False
        )

        filename = f"swarm_simulation_{view_name.lower()}.gif"
        ani.save(filename, writer='pillow')
        print(f"✅ Saved {filename}")

        # Release the figure so exporting several views does not pile them up
        plt.close(fig)

    # ------------------------------------------------
    # A) SAVE FOUR VIEWS AS GIFS
    # ------------------------------------------------
    if save:
        make_animation("normal", elev=30, azim=45)
        make_animation("xy", elev=90, azim=-90)
        make_animation("xz", elev=0, azim=-90)
        make_animation("yz", elev=0, azim=0)
        return

    # ------------------------------------------------
    # B) LIVE DISPLAY (REAL-TIME UPDATE in Jupyter)
    # ------------------------------------------------
    plt.ion()
    fig, ax, scat, lines = new_scene("3D Swarm Movement (Live)")

    for frame in range(n_frames):
        draw_frame(frame, scat, lines)
        ax.set_title(f"3D Swarm Movement (Live) - Step {frame}/{last_step}")
        display(fig)
        # wait=True defers the clear until the next frame is ready (less flicker)
        clear_output(wait=True)
        plt.pause(interval / 1000.0)  # interval is given in milliseconds

    plt.ioff()
    plt.show()


In [None]:
  # Roll the current agent out for up to 600 steps and export the GIF views.
  visualize_swarm(
      agent,
      env,
      steps=600,
      save=True,
  )