In [10]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import torch
import numpy as np
import torch.nn as nn
from utils.sim_utils import *
import torch.nn.functional as F
from utils.couzin_utils import *
from torch.distributions import Normal
from models.ModularNetworks import Attention

### Autoencoder for Spatio-Temporal Representation Learning

Reference implementations:

- https://github.com/HSoo-Kim/SpatioTemporal-AutoEncoder/blob/main/Models/ST_AutoEncoder.py
- https://github.com/AlexanderFabisch/vtae/blob/master/trajectory_vae.py



In [11]:
# Device shared by the training cells below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Run the Couzin simulation (number_of_sharks=0); only the expert tensor is kept.
_, exp_tensor, _ = run_couzin_simulation(
    visualization="off",
    max_steps=500,
    alpha=0.01,
    constant_speed=5,
    shark_speed=5,
    area_width=50,
    area_height=50,
    number_of_sharks=0,
    n=32,
)

# Layout: [n_frames, agents, neigh, features]
print(f"Expert data shape: {exp_tensor.shape}")

Expert data shape: torch.Size([500, 32, 31, 5])


In [12]:
def sample_data(data, consecutive_frames=10, batch_size=10):
    """Draw random windows of consecutive frames from a trajectory tensor.

    Window start indices are drawn independently (with replacement), uniformly
    over every position where a full window fits.

    Args:
        data: Tensor whose first dimension is time, e.g. [frames, agents, neigh, features].
        consecutive_frames: Window length T.
        batch_size: Number of windows to draw.

    Returns:
        Tensor of shape [batch_size, consecutive_frames, *data.shape[1:]] on data's device.

    Raises:
        ValueError: if consecutive_frames is negative or longer than the trajectory.
    """
    frames = data.shape[0]
    if consecutive_frames < 0 or consecutive_frames > frames:
        raise ValueError(
            f"consecutive_frames={consecutive_frames} must be in [0, {frames}] "
            f"(number of frames in data)"
        )

    # One randint call for all windows instead of a Python loop with a host sync per window.
    starts = torch.randint(0, frames - consecutive_frames + 1, (batch_size,), device=data.device)
    offsets = torch.arange(consecutive_frames, device=data.device)
    frame_index = starts[:, None] + offsets[None, :]  # [batch_size, T]

    # Advanced indexing on dim 0 gathers the windows: [batch_size, T, *data.shape[1:]]
    return data[frame_index]

In [13]:
# Quick sanity check of the sampler output layout: [B, T, A, N, F]
expert_batch = sample_data(exp_tensor, batch_size=10, consecutive_frames=10)
print(f"Sampled expert batch shape: {expert_batch.shape}")

Sampled expert batch shape: torch.Size([10, 10, 32, 31, 5])


In [23]:
class NeighborPooling(nn.Module):
    """Attention-weighted pooling over the neighbour axis.

    Each neighbour's feature vector is embedded by a small MLP and scored by
    ``Attention``. The embeddings are then summed using softmax weights taken
    over dim 2, which is the neighbour axis for the [B*T, A, N, F] input
    that TransitionEncoder passes in.

    Args:
        features: Number of input features per neighbour.
        embd_dim: Width of the pooled embedding.
    """

    def __init__(self, features=4, embd_dim=32):
        super().__init__()
        hidden = 64
        self.embed = nn.Sequential(
            nn.Linear(features, hidden),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden, embd_dim),
            nn.LeakyReLU(0.1),
        )
        # NOTE(review): presumably yields one logit per neighbour (trailing dim 1) — confirm.
        self.attention = Attention(features)

    def forward(self, states):
        neighbour_embeddings = self.embed(states)
        attn = self.attention(states).softmax(dim=2)
        weighted = neighbour_embeddings * attn
        # Reduces the neighbour axis; presumably [B*T, A, embd_dim] — confirm Attention's output shape.
        return weighted.sum(dim=2)
    

class AgentPooling(nn.Module):
    """Per-agent MLP (embd_dim -> 64 -> z) on a pooled neighbour embedding.

    Despite the name, no reduction over agents happens here. The MLP acts on the
    last dimension only, so leading dimensions (batch, time, agent) are preserved.
    """

    def __init__(self, embd_dim=32, z=32):
        super().__init__()
        layers = [nn.Linear(embd_dim, 64), nn.LeakyReLU(0.1), nn.Linear(64, z)]
        self.embed = nn.Sequential(*layers)

    def forward(self, pooled_embd):
        return self.embed(pooled_embd)


class TransitionEncoder(nn.Module):
    """Encode windows of neighbourhood states into per-agent latent transitions.

    Neighbours are pooled with ``NeighborPooling``, and the pooled vector is mapped
    to a per-agent latent of width ``z`` with ``AgentPooling``.

    Args:
        features: Number of input features per neighbour.
        embd_dim: Width of the pooled neighbour embedding.
        z: Width of the per-agent latent state.
    """

    def __init__(self, features=4, embd_dim=32, z=32):
        super().__init__()

        self.neigh_pooling = NeighborPooling(features=features, embd_dim=embd_dim)
        # AgentPooling is the agent-level MLP defined in this cell. The previous
        # reference to AgentEmbedding only resolved if a later cell had already run.
        self.agent_pooling = AgentPooling(embd_dim=embd_dim, z=z)

    def forward(self, states):
        """Encode a [B, T, A, N, F] window.

        Returns:
            (z_state, transition_feature):
                z_state: [B, T, A, z] latent state per frame and agent.
                transition_feature: [B, T-1, A, 2*z], z_t concatenated with z_{t+1} - z_t.
        """
        batch, frames, agents, neigh, features = states.shape
        # reshape (not view): states is typically a feature slice of a larger tensor.
        flat = states.reshape(batch * frames, agents, neigh, features)

        pooled_embd = self.neigh_pooling(flat)
        pooled = pooled_embd.reshape(batch, frames, agents, -1)

        z_state = self.agent_pooling(pooled)

        z_t = z_state[:, :-1]
        z_tp1 = z_state[:, 1:]
        dz = z_tp1 - z_t
        transition_feature = torch.cat([z_t, dz], dim=-1)
        # Return the latent states too; TransitionAE needs them to build its target.
        return z_state, transition_feature


class TransitionDecoder(nn.Module):
    """MLP predicting the next latent state from the current one (width ``z`` in and out)."""

    def __init__(self, z=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z, 64),
            nn.LeakyReLU(0.1),
            nn.Linear(64, 64),
            nn.LeakyReLU(0.1),
            nn.Linear(64, z)
        )

    def forward(self, z_t):
        return self.net(z_t)


class TransitionAE(nn.Module):
    """Latent transition model: encode frames, then predict z_{t+1} from z_t.

    NOTE(review): z_tp1 is produced by the same encoder and is not detached.
    An MSE between z_hat_tp1 and z_tp1 alone is therefore minimised by a constant
    (collapsed) encoder. Check representations before relying on this loss.
    """

    def __init__(self, features=4, embd_dim=32, z=32):
        super().__init__()
        self.encoder = TransitionEncoder(features=features, embd_dim=embd_dim, z=z)
        self.decoder = TransitionDecoder(z=z)

    def forward(self, states):
        """Run encoder and decoder on a [B, T, A, N, F] window.

        Returns:
            (z_state, dz, z_hat_tp1, z_tp1):
                z_state [B, T, A, z], dz = z_{t+1} - z_t [B, T-1, A, z],
                decoder prediction z_hat_tp1 and target z_tp1, both [B, T-1, A, z].
        """
        z_state, _ = self.encoder(states)
        z_t = z_state[:, :-1]
        z_tp1 = z_state[:, 1:]
        dz = z_tp1 - z_t
        z_hat_tp1 = self.decoder(z_t)
        return z_state, dz, z_hat_tp1, z_tp1

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transition autoencoder: predict z_{t+1} from z_t in latent space.
model = TransitionAE(features=4, embd_dim=32, z=32).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Regenerate the expert rollout with the same simulation settings as the first cell.
_, exp_tensor, _ = run_couzin_simulation(
    visualization="off",
    max_steps=500,
    alpha=0.01,
    constant_speed=5,
    shark_speed=5,
    area_width=50,
    area_height=50,
    number_of_sharks=0,
    n=32,
)

epochs = 200
model.train()  # nothing in the loop switches to eval mode, so setting it once suffices

for epoch in range(1, epochs + 1):
    # 10 windows of 10 frames; keep the first 4 features -> [B, T, A, N, 4]
    expert_batch = sample_data(exp_tensor, consecutive_frames=10, batch_size=10)
    states = expert_batch[..., :4].to(device)

    _, _, z_hat_tp1, z_tp1 = model(states)
    loss = F.mse_loss(z_hat_tp1, z_tp1)

    opt.zero_grad()
    loss.backward()
    opt.step()

    print(f"epoch {epoch}: loss={loss.item():.6f}")

epoch 1: loss=0.012030
epoch 2: loss=0.010073
epoch 3: loss=0.008443
epoch 4: loss=0.007064
epoch 5: loss=0.005863
epoch 6: loss=0.004832
epoch 7: loss=0.003960
epoch 8: loss=0.003226
epoch 9: loss=0.002629
epoch 10: loss=0.002142
epoch 11: loss=0.001747
epoch 12: loss=0.001430
epoch 13: loss=0.001176
epoch 14: loss=0.000986
epoch 15: loss=0.000825
epoch 16: loss=0.000693
epoch 17: loss=0.000593
epoch 18: loss=0.000490
epoch 19: loss=0.000407
epoch 20: loss=0.000346
epoch 21: loss=0.000292
epoch 22: loss=0.000243
epoch 23: loss=0.000217
epoch 24: loss=0.000186
epoch 25: loss=0.000171
epoch 26: loss=0.000153
epoch 27: loss=0.000136
epoch 28: loss=0.000129
epoch 29: loss=0.000118
epoch 30: loss=0.000115
epoch 31: loss=0.000100
epoch 32: loss=0.000093
epoch 33: loss=0.000087
epoch 34: loss=0.000085
epoch 35: loss=0.000080
epoch 36: loss=0.000081
epoch 37: loss=0.000080
epoch 38: loss=0.000078
epoch 39: loss=0.000073
epoch 40: loss=0.000074
epoch 41: loss=0.000068
epoch 42: loss=0.000061
e

### Encoder-only training with VICReg (variance-invariance-covariance regularisation)

In [47]:
class NeighborPooling(nn.Module):
    """Mask-aware attention pooling over the neighbour axis.

    Each neighbour's features are concatenated with a per-feature availability
    mask (1 = observed, 0 = dropped) before being embedded and scored. The MLP and
    attention input width is therefore ``2 * features``.

    Args:
        features: Number of raw features per neighbour.
        embd_dim: Width of the pooled embedding.
    """

    def __init__(self, features=4, embd_dim=32):
        super().__init__()

        in_dim = features * 2  # raw features + per-feature mask indicator

        self.embed = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.LeakyReLU(0.1),
            nn.Linear(64, embd_dim),
            nn.LeakyReLU(0.1),
        )

        self.attention = Attention(in_dim)

    def forward(self, states, neigh_mask=None, feat_mask=None):
        """Pool neighbour embeddings with a masked softmax over dim 2.

        Args:
            states: Neighbour features; TransitionEncoder passes [B*T, A, N, F].
            neigh_mask: Optional 0/1 mask broadcastable to the attention logits.
                0 excludes that neighbour from the pooling.
            feat_mask: Optional 0/1 mask with the same shape as ``states``.
                0 marks a dropped feature.

        Returns:
            Attention-weighted sum of neighbour embeddings over dim 2. An agent whose
            neighbours are all masked gets a zero vector instead of NaN.
        """
        if feat_mask is None:
            feat_mask = torch.ones_like(states)
        else:
            # Hide dropped features. Previously the raw value still reached the
            # network next to a 0 indicator, so feature dropout had no effect.
            states = states * feat_mask

        cat_states = torch.cat([states, feat_mask], dim=-1)

        embed = self.embed(cat_states)
        weights_logit = self.attention(cat_states)

        dropped = None
        if neigh_mask is not None:
            dropped = neigh_mask == 0
            # Finite minimum instead of -inf: a fully masked row then gives a uniform
            # softmax rather than NaN (and NaN gradients).
            weights_logit = weights_logit.masked_fill(dropped, torch.finfo(weights_logit.dtype).min)

        weights = torch.softmax(weights_logit, dim=2)

        if dropped is not None:
            # Masked neighbours contribute nothing, including in fully masked rows.
            weights = weights.masked_fill(dropped, 0.0)

        pooled = (embed * weights).sum(dim=2)

        return pooled
    

class AgentEmbedding(nn.Module):
    """Two-layer MLP (embd_dim -> 64 -> z) applied to each agent's pooled embedding.

    Operates on the last dimension only, so every leading dimension
    (batch, time, agent) is carried through unchanged.
    """

    def __init__(self, embd_dim=32, z=32):
        super().__init__()
        hidden_width = 64
        first = nn.Linear(embd_dim, hidden_width)
        act = nn.LeakyReLU(0.1)
        last = nn.Linear(hidden_width, z)
        self.embed = nn.Sequential(first, act, last)

    def forward(self, pooled_embd):
        return self.embed(pooled_embd)


class TransitionEncoder(nn.Module):
    """Mask-aware encoder from neighbourhood-state windows to latent transitions.

    Args:
        features: Features per neighbour in ``states``.
        embd_dim: Width of the pooled neighbour embedding.
        z: Width of the per-agent latent state.
    """

    def __init__(self, features=4, embd_dim=32, z=32):
        super().__init__()
        self.z = z
        self.neigh_pooling = NeighborPooling(features=features, embd_dim=embd_dim)
        self.agent_pooling = AgentEmbedding(embd_dim=embd_dim, z=z)

    def forward(self, states, neigh_mask=None, feat_mask=None):
        """Encode a [B, T, A, N, F] window.

        Args:
            states: Neighbour features, shape [B, T, A, N, F].
            neigh_mask: Optional mask with B*T*A*N elements (e.g. [B, T, A, N, 1]);
                it is reshaped to [B*T, A, N, 1].
            feat_mask: Optional mask expandable to [B, T, A, N, F]
                (e.g. [B, T, A, 1, F] from TrajectoryAugmentation).

        Returns:
            (z_state, transition_feature) with shapes [B, T, A, z] and
            [B, T-1, A, 2*z]. The latter is z_t concatenated with z_{t+1} - z_t.
        """
        b, t, a, n, f = states.shape
        bt = b * t
        flat_states = states.reshape(bt, a, n, f)

        flat_neigh_mask = None
        if neigh_mask is not None:
            flat_neigh_mask = neigh_mask.reshape(bt, a, n, 1)

        flat_feat_mask = None
        if feat_mask is not None:
            flat_feat_mask = feat_mask.expand(b, t, a, n, f).reshape(bt, a, n, f)

        pooled = self.neigh_pooling(flat_states, neigh_mask=flat_neigh_mask, feat_mask=flat_feat_mask)
        z_state = self.agent_pooling(pooled.view(b, t, a, -1))

        z_prev = z_state[:, :-1]
        z_delta = z_state[:, 1:] - z_prev
        return z_state, torch.cat((z_prev, z_delta), dim=-1)

In [48]:
class TrajectoryAugmentation(nn.Module):
    """Stochastic augmentation for [B, T, A, N, F] neighbourhood-state windows.

    Returns a noisy copy of the states plus two 0/1 masks, all in ``states.dtype``:
      * neigh_mask [B, T, A, N, 1]: 0 marks a dropped neighbour.
      * feat_mask: 0 marks a dropped feature. Its shape is [B, T, A, 1, F]
        (shared across neighbours) when feat_drop > 0, else all-ones [B, T, A, N, F].
    The masks are returned, not applied; the encoder consumes them.

    Args:
        noise_std: Std of additive Gaussian noise (<= 0 disables).
        neigh_drop: Probability of dropping each neighbour (<= 0 disables).
        feat_drop: Probability of dropping each feature (<= 0 disables).
    """

    def __init__(self, noise_std=0.01, neigh_drop=0.10, feat_drop=0.05):
        super().__init__()
        self.noise_std = noise_std
        self.neigh_drop = neigh_drop
        self.feat_drop = feat_drop

    def forward(self, states):
        batch, frames, agents, neigh, features = states.shape
        device, dtype = states.device, states.dtype

        # Out-of-place in both branches, so the caller's tensor is never modified.
        if self.noise_std > 0:
            states = states + torch.randn_like(states) * self.noise_std
        else:
            states = states.clone()

        # Masks always use states.dtype. Previously the dropout branches were
        # hard-coded to float32 while the all-ones fallbacks used states.dtype.
        if self.neigh_drop > 0:
            keep = torch.rand(batch, frames, agents, neigh, 1, device=device) > self.neigh_drop
            neigh_mask = keep.to(dtype)
        else:
            neigh_mask = torch.ones((batch, frames, agents, neigh, 1), device=device, dtype=dtype)

        if self.feat_drop > 0:
            keep = torch.rand(batch, frames, agents, 1, features, device=device) > self.feat_drop
            feat_mask = keep.to(dtype)
        else:
            feat_mask = torch.ones((batch, frames, agents, neigh, features), device=device, dtype=dtype)

        return states, neigh_mask, feat_mask


class VicRegProjector(nn.Module):
    """VICReg expander: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        input_dim: Width of the incoming representation. The callers in this notebook
            pass 64 = 2*z for the concatenated [z_t, dz] transition features.
        hidden_dim: Width of the hidden layer (default 128, unchanged).
        output_dim: Width of the embedding fed to vicreg_loss (default 128, unchanged).
    """

    def __init__(self, input_dim=64, hidden_dim=128, output_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim))

    def forward(self, states):
        """Map a [batch, input_dim] tensor to [batch, output_dim]."""
        return self.net(states)


# https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
def off_diagonal(tensor):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor (row-major order)."""
    dim = tensor.size(0)
    mask = ~torch.eye(dim, dtype=torch.bool, device=tensor.device)
    return tensor[mask]

def vicreg_loss(z1, z2, sim_coeff=25.0, std_coeff=15.0, cov_coeff=5.0, eps=1e-4):
    """VICReg loss between two batches of paired embeddings.

    Args:
        z1, z2: Tensors of shape [batch, dim]; row i of z1 is paired with row i of z2.
        sim_coeff: Weight of the invariance (MSE) term.
        std_coeff: Weight of the variance hinge term.
        cov_coeff: Weight of the off-diagonal covariance term.
        eps: Added to the variance before the square root.

    Returns:
        (loss, logs): the scalar loss and a dict of detached components
        "sim", "std", "cov" and "std_mean".

    Raises:
        ValueError: if the inputs are not 2-D tensors of identical shape, or if
            batch < 2 (the unbiased variance and the (batch - 1) covariance
            normaliser are undefined and previously produced NaN/inf silently).
    """
    if z1.dim() != 2 or z1.shape != z2.shape:
        raise ValueError(
            f"vicreg_loss expects two [batch, dim] tensors of equal shape, "
            f"got {tuple(z1.shape)} and {tuple(z2.shape)}"
        )
    batch, dim = z1.shape
    if batch < 2:
        raise ValueError(f"vicreg_loss needs batch >= 2, got {batch}")

    # Invariance: the two views of each sample should map close together.
    sim_loss = torch.mean((z1 - z2) ** 2)

    z1 = z1 - z1.mean(dim=0)
    z2 = z2 - z2.mean(dim=0)

    # Variance: hinge keeps each dimension's std at or above 1.
    # NOTE(review): the linked reference halves each view's term; here they are
    # summed, so std_coeff acts at twice the reference scale — confirm intended.
    std_z1 = torch.sqrt(z1.var(dim=0) + eps)
    std_z2 = torch.sqrt(z2.var(dim=0) + eps)
    std_loss = torch.mean(F.relu(1.0 - std_z1)) + torch.mean(F.relu(1.0 - std_z2))

    # Covariance: penalise off-diagonal covariance to decorrelate dimensions.
    cov_z1 = (z1.T @ z1) / (batch - 1)
    cov_z2 = (z2.T @ z2) / (batch - 1)
    cov_loss = (off_diagonal(cov_z1).pow(2).sum() / dim) + (off_diagonal(cov_z2).pow(2).sum() / dim)

    loss = sim_coeff * sim_loss + std_coeff * std_loss + cov_coeff * cov_loss
    logs = {
        "sim": sim_loss.detach(),
        "std": std_loss.detach(),
        "cov": cov_loss.detach(),
        "std_mean": 0.5 * (std_z1.mean().detach() + std_z2.mean().detach()),
    }
    return loss, logs


In [49]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Augmentation, encoder and VICReg projector; the projector input is 2*z ([z_t, dz]).
aug_states = TrajectoryAugmentation(noise_std=0.01, neigh_drop=0.10, feat_drop=0.05).to(device)
encoder = TransitionEncoder(features=4, embd_dim=32, z=32).to(device)
projector = VicRegProjector(input_dim=64).to(device)

trainable_params = [*encoder.parameters(), *projector.parameters()]
opt = torch.optim.Adam(trainable_params, lr=1e-3, weight_decay=1e-6)

_, exp_tensor, _ = run_couzin_simulation(
    visualization="off",
    max_steps=500,
    alpha=0.01,
    constant_speed=5,
    shark_speed=5,
    area_width=50,
    area_height=50,
    number_of_sharks=0,
    n=32,
)

epochs = 1000

# Nothing below switches to eval mode, so training mode is set once.
encoder.train()
projector.train()

for epoch in range(1, epochs + 1):
    # 64 windows of 10 frames, first 4 features -> [B, T, A, N, 4]
    expert_batch = sample_data(exp_tensor, consecutive_frames=10, batch_size=64)
    states = expert_batch[..., :4].to(device)

    # Two independently augmented views of the same batch.
    x1, neigh_mask1, feat_mask1 = aug_states(states)
    x2, neigh_mask2, feat_mask2 = aug_states(states)

    z_state1, trans1 = encoder(x1, neigh_mask=neigh_mask1, feat_mask=feat_mask1)
    z_state2, trans2 = encoder(x2, neigh_mask=neigh_mask2, feat_mask=feat_mask2)

    # Merge every leading axis into one sample axis: [-1, 64]
    r1 = trans1.flatten(0, -2)
    r2 = trans2.flatten(0, -2)

    y1 = projector(r1)
    y2 = projector(r2)

    loss, logs = vicreg_loss(y1, y2)

    opt.zero_grad(set_to_none=True)
    loss.backward()
    nn.utils.clip_grad_norm_(trainable_params, 1.0)
    opt.step()

    if epoch % 10 == 0:
        print(f"epoch {epoch:03d}: loss={loss.item():.6f} "
              f"sim={logs['sim'].item():.4f} std={logs['std'].item():.4f} "
              f"cov={logs['cov'].item():.4f} std_mean={logs['std_mean'].item():.3f}")


epoch 010: loss=23.844624 sim=0.0497 std=1.4030 cov=0.3113 std_mean=0.299
epoch 020: loss=22.559509 sim=0.0338 std=1.3162 cov=0.3945 std_mean=0.342
epoch 030: loss=21.826788 sim=0.0222 std=1.2564 cov=0.4852 std_mean=0.372
epoch 040: loss=21.201599 sim=0.0201 std=1.2302 cov=0.4494 std_mean=0.385
epoch 050: loss=20.696140 sim=0.0195 std=1.1682 cov=0.5372 std_mean=0.416
epoch 060: loss=20.465712 sim=0.0198 std=1.1382 cov=0.5794 std_mean=0.431
epoch 070: loss=20.216318 sim=0.0201 std=1.1264 cov=0.5634 std_mean=0.437
epoch 080: loss=20.152771 sim=0.0224 std=1.1000 cov=0.6187 std_mean=0.450
epoch 090: loss=20.098295 sim=0.0233 std=1.0927 cov=0.6250 std_mean=0.454
epoch 100: loss=20.078236 sim=0.0257 std=1.0913 cov=0.6134 std_mean=0.454
epoch 110: loss=19.930140 sim=0.0245 std=1.0828 cov=0.6152 std_mean=0.459
epoch 120: loss=19.780333 sim=0.0243 std=1.0738 cov=0.6131 std_mean=0.463
epoch 130: loss=19.700085 sim=0.0237 std=1.0608 cov=0.6393 std_mean=0.470
epoch 140: loss=19.537148 sim=0.0227 s

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fresh modules and optimiser, presumably meant as arguments to train_encoder
# (which this cell defines but does not call).
aug = TrajectoryAugmentation(noise_std=0.01, neigh_drop=0.10, feat_drop=0.05).to(device)
encoder = TransitionEncoder(features=4, embd_dim=32, z=32).to(device)
projector = VicRegProjector(input_dim=64).to(device)

optimizer = torch.optim.Adam(
    [*encoder.parameters(), *projector.parameters()],
    lr=1e-3,
    weight_decay=1e-6,
)

_, exp_tensor, _ = run_couzin_simulation(
    visualization="off",
    max_steps=500,
    alpha=0.01,
    constant_speed=5,
    shark_speed=5,
    area_width=50,
    area_height=50,
    number_of_sharks=0,
    n=32,
)

epochs = 1000

def train_encoder(encoder, projector, aug, exp_tensor, epochs, optimizer, device,
                  batch_size=64, consecutive_frames=10, log_every=10, max_grad_norm=1.0):
    """Train ``encoder`` and ``projector`` with VICReg on two augmented views per batch.

    Each "epoch" is a single optimisation step on one freshly sampled batch.

    Args:
        encoder: Module mapping (states, neigh_mask=, feat_mask=) -> (z_state, transition_feature).
        projector: Module mapping flattened transition features to VICReg embeddings.
        aug: Augmentation returning (states, neigh_mask, feat_mask).
        exp_tensor: Expert trajectory tensor passed to ``sample_data``.
        epochs: Number of optimisation steps.
        optimizer: Optimizer over the encoder and projector parameters.
        device: Device the sampled states are moved to.
        batch_size: Windows per batch (forwarded to sample_data).
        consecutive_frames: Window length (forwarded to sample_data).
        log_every: Print losses every this many epochs (<= 0 disables logging).
        max_grad_norm: Gradient-norm clipping threshold.
    """
    # Hoisted: the parameter set does not change during training.
    params = list(encoder.parameters()) + list(projector.parameters())

    for epoch in range(1, epochs + 1):
        encoder.train()
        projector.train()

        expert_batch = sample_data(exp_tensor, consecutive_frames=consecutive_frames, batch_size=batch_size)
        states = expert_batch[..., :4].to(device)  # [B,T,A,N,4]

        # Use the `aug` argument; the previous body silently used the notebook-global `aug_states`.
        x1, neigh_mask1, feat_mask1 = aug(states)
        x2, neigh_mask2, feat_mask2 = aug(states)

        _, trans1 = encoder(x1, neigh_mask=neigh_mask1, feat_mask=feat_mask1)
        _, trans2 = encoder(x2, neigh_mask=neigh_mask2, feat_mask=feat_mask2)

        y1 = projector(trans1.reshape(-1, trans1.size(-1)))  # [-1, 2*z] -> projector
        y2 = projector(trans2.reshape(-1, trans2.size(-1)))

        loss, logs = vicreg_loss(y1, y2)

        # Use the `optimizer` argument; the previous body silently stepped the global `opt`.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(params, max_grad_norm)
        optimizer.step()

        if log_every > 0 and epoch % log_every == 0:
            print(f"epoch {epoch:03d}: loss={loss.item():.6f} "
                f"sim={logs['sim'].item():.4f} std={logs['std'].item():.4f} "
                f"cov={logs['cov'].item():.4f} std_mean={logs['std_mean'].item():.3f}")


In [52]:
# Inspect one training-sized batch and the 4-feature slice fed to the encoder.
expert_batch = sample_data(exp_tensor, batch_size=64, consecutive_frames=10)
print(f"Sampled expert batch shape: {expert_batch.shape}")
states = expert_batch[..., :4].to(device)
print(f"States shape: {states.shape}")

Sampled expert batch shape: torch.Size([64, 10, 32, 31, 5])
States shape: torch.Size([64, 10, 32, 31, 4])


In [50]:
# Portable path. The previous r'..\models\encoder.pt' is only a directory path on
# Windows; elsewhere it creates a file whose name literally contains backslashes.
encoder_path = os.path.join("..", "models", "encoder.pt")
os.makedirs(os.path.dirname(encoder_path), exist_ok=True)
# NOTE(review): this saves the whole module object, so loading needs the encoder's
# class definitions importable. Consider saving encoder.state_dict() instead.
torch.save(encoder, encoder_path)