In [1]:
import os
import time
import pylab
import torch
import pickle
import numpy as np
from math import *
import matplotlib as mpl
from cycler import cycler
from numpy.linalg import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from models.PredatorPolicy import PredatorPolicy
from models.PreyPolicy import PreyPolicy

mpl.use('TkAgg')
from utils.couzin_utils import run_couzin_simulation

In [2]:
from models.ModularNetworks import PairwiseInteraction, Attention
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

class ModularPolicy(nn.Module):
    def __init__(self, features=4):
        super(ModularPolicy, self).__init__()

        self.pairwise = PairwiseInteraction(features)
        self.attention = Attention(features)

    def forward(self, states, deterministic=True):
        mu, sigma = self.pairwise(states)

        weights_logit = self.attention(states)
        weights = torch.softmax(weights_logit, dim=1)

        if deterministic:
            scaled_action = torch.sigmoid(mu)
            action = (scaled_action * weights).sum(dim=1)
            return action
        else:
            action = Normal(mu, sigma).rsample()
            scaled_action = torch.sigmoid(action)
            action = (scaled_action * weights).sum(dim=1)
            return action
        
    def set_parameters(self, init=True):
        if init is True:
            for layer in self.modules():
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()



def pretrain_policy(policy, expert_data, batch_size=256, epochs=250, lr=1e-3, deterministic=True, device='cpu'):
    policy.to(device)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    frames, agents, neigh, features = expert_data.shape
    expert_data = expert_data.view(frames * agents, neigh, features)
    
    states = expert_data[..., :4]
    actions = expert_data[:, 0, 4]

    dataset = TensorDataset(states, actions)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(1, epochs + 1):

        for s, a in loader:
            states = s.to(device)
            actions = a.to(device)

            action_prey = policy.forward(states, deterministic=deterministic)

            loss = F.mse_loss(action_prey, actions)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % 25 == 0:
            print(f"Epoch {epoch}/{epochs}, Loss: {loss.item():.6f}")

    return policy

In [3]:
bc_folder = "BC Training - 19.12.2025_14.25 - Couzin Data"

model_folder = rf"..\data\2. Training\training"
bc_path = os.path.join(model_folder, "BC", bc_folder)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BC Simulation
pred_policy = torch.load(os.path.join(bc_path, "bc_pred_policy.pt"), weights_only=False)

expert_pred_tensor, expert_prey_tensor, logs = run_couzin_simulation(visualization="off", max_steps=100, alpha=0.01, 
                                                       constant_speed=2, shark_speed=5, area_width=50, area_height=50, 
                                                       number_of_sharks=1, n=32)


policy = ModularPolicy().to(device)
pretrain_policy(policy, expert_data=expert_prey_tensor, batch_size=4048, epochs=100, lr=1e-3, deterministic=True, device=device)
print("Pretraining done.\n")

  loss = F.mse_loss(action_prey, actions)


Epoch 25/100, Loss: 0.465000
Epoch 50/100, Loss: 0.381186
Epoch 75/100, Loss: 0.363026
Epoch 100/100, Loss: 0.359746
Pretraining done.



In [None]:
class Agent:
    def __init__(self, agent_id, speed, area_width, area_height):
        self.id = agent_id
        self.pos = np.array([np.random.uniform(0, area_width),
                             np.random.uniform(0, area_height)], dtype=np.float64)

        self.theta = np.random.uniform(-np.pi, np.pi)
        self.vel = np.array([np.cos(self.theta), np.sin(self.theta)], dtype=np.float64) * speed

    def update_position(self, dt):
        self.pos += self.vel * dt


def apply_turnrate_on_theta(agent, dtheta_raw, speed, max_turn=np.pi/12):
    d = float(dtheta_raw)

    d = (np.clip(d, 0.0, 1.0) * 2.0 - 1.0) * max_turn

    agent.theta = (agent.theta + d + np.pi) % (2*np.pi) - np.pi
    agent.vel = np.array([np.cos(agent.theta), np.sin(agent.theta)], dtype=np.float64) * speed



def enforce_walls(agent, area_width, area_height):
    bounced = False

    if agent.pos[0] < 0:
        agent.pos[0] = 0
        agent.theta = np.pi - agent.theta
        bounced = True
    elif agent.pos[0] > area_width:
        agent.pos[0] = area_width
        agent.theta = np.pi - agent.theta
        bounced = True

    if agent.pos[1] < 0:
        agent.pos[1] = 0
        agent.theta = -agent.theta
        bounced = True
    elif agent.pos[1] > area_height:
        agent.pos[1] = area_height
        agent.theta = -agent.theta
        bounced = True

    if bounced:
        agent.theta = (agent.theta + np.pi) % (2*np.pi) - np.pi



def get_state_tensors(prey_log_step, pred_log_step, n_pred=1, 
                      area_width=50, area_height=50,
                      prey_speed=5, pred_speed=5, device="cuda"):
    
    combined = np.vstack([pred_log_step, prey_log_step]).astype(np.float32)  # [N,6]
    n_agents = combined.shape[0]

    xs, ys  = combined[:, 0], combined[:, 1]
    vxs, vys = combined[:, 2], combined[:, 3]
    dir_x, dir_y = combined[:, 4], combined[:, 5]

    theta = np.arctan2(dir_y, dir_x).astype(np.float32)
    theta_norm = (theta / np.pi).astype(np.float32)
    cos_t = np.cos(theta).astype(np.float32)
    sin_t = np.sin(theta).astype(np.float32)

    xs_scaled = np.clip(xs, 0, area_width) / float(area_width)
    ys_scaled = np.clip(ys, 0, area_height) / float(area_height)

    dx = xs_scaled[None, :] - xs_scaled[:, None]
    dy = ys_scaled[None, :] - ys_scaled[:, None]

    rel_vx = cos_t[:, None] * vxs[None, :] + sin_t[:, None] * vys[None, :]
    rel_vy = -sin_t[:, None] * vxs[None, :] + cos_t[:, None] * vys[None, :]

    speed = max(prey_speed, pred_speed)
    rel_vx = np.clip(rel_vx, -speed, speed) / speed
    rel_vy = np.clip(rel_vy, -speed, speed) / speed

    theta_mat = np.tile(theta_norm[:, None], (1, n_agents))
    features = np.stack([dx, dy, rel_vx, rel_vy, theta_mat], axis=-1).astype(np.float32)

    mask = ~np.eye(n_agents, dtype=bool)
    neigh = features[mask].reshape(n_agents, n_agents-1, 5)  # [agent, neigh, feat]

    tensor = torch.from_numpy(neigh).unsqueeze(0).to(device=device, dtype=torch.float32)
    step, agent, neigh, feat = tensor.size()

    tensor = tensor.view(step * agent, neigh, feat)  # [1*agent, N-1, 5]

    pred_tensor = tensor[:n_pred]    # [n_pred, N-1, 5] (kann 0 sein)
    prey_tensor = tensor[n_pred:]    # [n_prey, N-1, 5]

    pred_states = pred_tensor[..., :4]
    prey_states = prey_tensor[..., :4]

    return pred_states, prey_states



def run_env_simulation(prey_policy=None, pred_policy=None, n_prey=32, n_pred=1, max_steps=100, prey_speed=15, pred_speed=15, area_width=50, area_height=50, visualization='off', device="cpu"):

    prey_policy.to(device)
    pred_policy.to(device)

    prey = [Agent(i, prey_speed, area_width, area_height) for i in range(n_prey)]
    pred = [Agent(i, pred_speed, area_width, area_height) for i in range(n_pred)]

    prey_pos = np.zeros((n_prey, 2))
    prey_vel = np.zeros((n_prey, 2))
    
    pred_pos = np.zeros((n_pred, 2))
    pred_vel = np.zeros((n_pred, 2))

    t = 0

    if visualization == 'on':
        fig, ax = plt.subplots()

    prey_log = np.zeros((max_steps, n_prey, 6))
    predator_log = np.zeros((max_steps, n_pred, 6))

    pred_tensor_list = []
    prey_tensor_list = []

    while t < max_steps:
        # Prey
        for i, agent in enumerate(prey):
            prey_pos[i, :] = agent.pos
            prey_vel[i, :] = agent.vel

            vel_norm = (norm(agent.vel[0:2]) + 1e-12)
            if vel_norm > 1e-12:
                dir_xy = agent.vel[0:2] / vel_norm
            else:
                dir_xy = np.zeros(2, dtype=np.float64)

            prey_log[t, i, 0:2] = agent.pos[0:2]
            prey_log[t, i, 2:4] = agent.vel[0:2]
            prey_log[t, i, 4:6] = dir_xy

        if n_pred > 0:
            for i, predator in enumerate(pred):
                pred_pos[i, :] = predator.pos

                v = norm(predator.vel)
                pred_vel[i, :] = predator.vel / (v + 1e-12) / 80 * area_width

                vel_norm_s = norm(predator.vel[0:2])
                if vel_norm_s > 1e-12:
                    dir_xy_s = predator.vel[0:2] / vel_norm_s
                else:
                    dir_xy_s = np.zeros(2, dtype=np.float64)

                predator_log[t, i, 0:2] = predator.pos[0:2]
                predator_log[t, i, 2:4] = predator.vel[0:2]
                predator_log[t, i, 4:6] = dir_xy_s

        # --- Visualization ---
        if visualization == 'on':
            ax.clear()

            pylab.quiver(
                prey_pos[:, 0], prey_pos[:, 1],
                prey_vel[:, 0], prey_vel[:, 1],
                scale=120,
                width=0.01,
                headwidth=3,
                headlength=3,
                headaxislength=3,
            )

            if n_pred > 0:
                pylab.quiver(
                    pred_pos[:, 0], pred_pos[:, 1],
                    pred_vel[:, 0], pred_vel[:, 1],
                    color="#FF0000",
                    scale=15,
                    width=0.01,
                    headwidth=3,
                    headlength=3,
                    headaxislength=3,
                )
            ax.set_aspect('equal', 'box')
            ax.set_xlim(0, area_width)
            ax.set_ylim(0, area_height)

            plt.pause(0.00000001)

        pred_states, prey_states = get_state_tensors(prey_log[t], predator_log[t], n_pred=n_pred, area_width=area_width, area_height=area_height, prey_speed=prey_speed, pred_speed=pred_speed, device=device)  # [1,agent,neigh,feat]

        if n_pred > 0:
            pred_states = pred_tensor[..., :4]
            pred_actions = pred_policy.forward(pred_states)

            agents, neigh, feat = pred_states.shape
            pred_actions_exp = pred_actions.unsqueeze(1).expand(-1, neigh, -1)
            step_pred_tensor = torch.cat([pred_states, pred_actions_exp], dim=0)
            pred_tensor_list.append(step_pred_tensor)


        prey_actions = prey_policy.forward(prey_states, deterministic=True)

        agents, neigh, feat = prey_states.shape
        prey_actions_exp = prey_actions.unsqueeze(1).expand(-1, neigh, -1)  # [A, neigh, 1]

        step_prey_tensor = torch.cat([prey_states, prey_actions_exp], dim=-1)  # [A, neigh, 5]
        prey_tensor_list.append(step_prey_tensor)

        for i, agent in enumerate(prey):
            apply_turnrate_on_theta(agent, prey_actions[i], prey_speed)

        if n_pred > 0:
            for i, predator in enumerate(pred):
                apply_turnrate_on_theta(predator, pred_actions[i], pred_speed)

        for agent in prey:
            enforce_walls(agent, area_width, area_height)
            agent.update_position(dt=1.0)

        if n_pred > 0:
            for predator in pred:
                enforce_walls(predator, area_width, area_height)
                predator.update_position(dt=1.0)

        t += 1

    pred_tensor = torch.stack(pred_tensor_list, dim=0) if n_pred > 0 else 0
    prey_tensor = torch.stack(prey_tensor_list, dim=0) 

    return pred_tensor, prey_tensor

In [None]:
gen_pred_tensor, gen_prey_tensor = run_env_simulation(visualization='off', prey_policy=policy, pred_policy=pred_policy, 
                                                      n_prey=32, n_pred=0, max_steps=100, 
                                                      pred_speed=5, prey_speed=5,
                                                      area_width=50, area_height=50, device=device)

expert_pred_tensor, expert_prey_tensor, logs = run_couzin_simulation(visualization="on", max_steps=100, alpha=0.01, 
                                                       constant_speed=5, shark_speed=5, area_width=50, area_height=50, 
                                                       number_of_sharks=0, n=32)


