In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


  cpu = _conversion_method_template(device=torch.device("cpu"))


Inputs at time t:  
  $N_{t-1}$ (prob vector, length $m$)  
  $y_t$ (vector, length $m$)  
  $r_t$ (vector, length $m$)  
  $g_{t-1}$ (vector, length $m$)  

 Concatenate all  vector of length $4m$


| MLP (2 layers)        | 
|-----------------------|
| Input dim: 4m         |
| Hidden layers: e.g.128 |
| Output dim: m         |
| Outputs: Δ logits     |
| (delta_logits)        |



Previous logits:  
$$
\ell_{t-1} = \log(N_{t-1} + \varepsilon)
$$

Add residual:  
$$
\ell_t = \ell_{t-1} + \Delta \text{logits}
$$



$$
\text{Softmax}(\ell_t) \to N_t  \quad \text{(probability distribution at time t)}
$$

 \( N_t \) is then used in:  
  - \( g_t \) update:  
  $$
  g_t = \alpha \odot g_{t-1} + (N_t - \alpha \odot N_{t-1}) \odot y_t
  $$
  - next time step input \( N_{t} \)


In [3]:

# ----------------------------------------------------------------------
# F network: predicts changes in logits (residuals) instead of full logits
# ----------------------------------------------------------------------
class FNetResidual(nn.Module):
    def __init__(self, m, hidden=128):
        super().__init__()
        self.fc1 = nn.Linear(4 * m, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, m)  # delta logits

    def forward(self, N_prev, y_t, r_t, g_prev):
        # Concatenate all inputs
        x = torch.cat([N_prev, y_t, r_t, g_prev], dim=-1)
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        delta_logits = self.out(h)

        # Previous logits from probabilities
        prev_logits = torch.log(N_prev + 1e-9)
        logits = prev_logits + delta_logits

        # Softmax to get new probability vector
        N_t = F.softmax(logits, dim=-1)
        return N_t, logits



In [4]:

# ----------------------------------------------------------------------
# Full dynamics model with fixed alpha and residual logits
# ----------------------------------------------------------------------
class DynamicsModelResidual(nn.Module):
    def __init__(self, m, alpha_values, hidden=128):
        super().__init__()
        self.F = FNetResidual(m, hidden)
        # Fixed alpha, registered as buffer so it moves with the model
        self.register_buffer("alpha", torch.tensor(alpha_values, dtype=torch.float32))

    def forward(self, N0, g0, ys, rs):
        """
        N0, g0: (batch, m) initial states
        ys, rs: (T, batch, m) exogenous inputs
        Returns:
            Ns: (T, batch, m)
            gs: (T, batch, m)
            logits_all: (T, batch, m)
        """
        N_prev = N0
        g_prev = g0
        Ns, gs, logits_all = [], [], []

        for t in range(ys.size(0)):
            y_t = ys[t]
            r_t = rs[t]

            # Predict N_t with residual logits
            N_t, logits = self.F(N_prev, y_t, r_t, g_prev)

            # Elementwise g_t update
            g_t = self.alpha * g_prev + (N_t - self.alpha * N_prev) * y_t

            # Store
            Ns.append(N_t)
            gs.append(g_t)
            logits_all.append(logits)

            # Update for next step
            N_prev = N_t
            g_prev = g_t

        return torch.stack(Ns), torch.stack(gs), torch.stack(logits_all)



In [7]:

# ----------------------------------------------------------------------
# Example usage / test run
# ----------------------------------------------------------------------

T = 360       # number of time steps
m = 3        # distribution size
batch = 32
alpha_values = torch.linspace(0.1, 0.9, m)  # Example alpha values

# Dummy inputs
ys = torch.rand(T, batch, m)
rs = torch.rand(T, batch, m)
N0 = F.softmax(torch.rand(batch, m), dim=-1)
g0 = torch.zeros(batch, m)

model = DynamicsModelResidual(m, alpha_values, hidden=64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


  self.register_buffer("alpha", torch.tensor(alpha_values, dtype=torch.float32))


In [8]:

for step in range(200):
    optimizer.zero_grad()
    Ns, gs, logits_all = model(N0, g0, ys, rs)

    # Example path-dependent loss:
    entropy = -(Ns * torch.log(Ns + 1e-9)).sum(dim=-1)  # (T, batch)
    loss = gs.pow(2).sum(dim=-1).mean() - 0.1 * entropy.mean()

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if step % 50 == 0:
        print(f"Step {step}: loss={loss.item():.4f}")


Step 0: loss=0.3087
Step 50: loss=-0.0112
Step 100: loss=-0.0411
Step 150: loss=-0.0641
