In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

In [8]:
# Quick look at nn.Linear's repr; bias=True is the default.
nn.Linear(in_features=10, out_features=5)

Linear(in_features=10, out_features=5, bias=True)

In [7]:
# ---------- utils ----------
def mlp(sizes, act=nn.ReLU, out_act=None):
    """Build a feed-forward stack of ``nn.Linear`` layers.

    Args:
        sizes: layer widths ``[in, h1, ..., out]``; one ``nn.Linear`` is made
            for every consecutive pair of entries.
        act: activation class, instantiated after every Linear except the last.
        out_act: optional activation class instantiated after the final Linear;
            ``None`` leaves the output un-activated.

    Returns:
        ``nn.Sequential`` holding the layers in order (empty if ``len(sizes) < 2``).
    """
    n_linear = len(sizes) - 1
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(fan_in, fan_out))
        is_last = idx == n_linear - 1
        if not is_last:
            modules.append(act())
        elif out_act is not None:
            modules.append(out_act())
    return nn.Sequential(*modules)

In [5]:
# Sanity check: a 3 -> 64 -> 64 -> 1 MLP with ReLU between layers and a linear head.
output = mlp(sizes=[3, 64, 64, 1], act=nn.ReLU, out_act=None)
print(output)

Sequential(
  (0): Linear(in_features=3, out_features=64, bias=True)
  (1): ReLU()
  (2): Linear(in_features=64, out_features=64, bias=True)
  (3): ReLU()
  (4): Linear(in_features=64, out_features=1, bias=True)
)


In [19]:
# Negative indices count from the end: x[-2] is the second-to-last element.
x = list(range(1, 5))
print(x[-2])

3


In [21]:
# unsqueeze at dim -2 inserts a new axis before the last one: shape [4] -> [1, 4].
x = torch.arange(1, 5)
x.unsqueeze(-2)

tensor([[1, 2, 3, 4]])

In [20]:
# NOTE(review): this cell repeats the previous unsqueeze cell verbatim; one of them can go.
x = torch.tensor([1, 2, 3, 4])
x[None, :]  # for a 1-D tensor this equals torch.unsqueeze(x, -2)

tensor([[1, 2, 3, 4]])

In [None]:
# ---------- relational energy E_theta ----------
class RelationalEnergy(nn.Module):
    """
    Relational energy model E_theta.

        E_theta(x, a, w) = [ f_theta( sum_{t,i,j} sig(a_i)*sig(a_j) * g_theta(x_t_i, x_t_j, w), w ) ]^2

    The pair sum runs over all ordered pairs (i, j), including i == j.

    forward(x, a, w):
        x: [B, T, N, Dx]   batch, time steps, entities, per-entity features
        a: [B, N]          real-valued attention logits (sigmoid applied inside)
        w: [B, Dw]         concept code
        returns: [B] non-negative energy (square of f's scalar output)

    Constructor:
        Dx, Dw: entity / concept-code sizes (stored on the module as .Dx / .Dw).
        hidden: width of g's hidden and output layers and of f's hidden layers.
        T_weighted: stored as an attribute; forward() does not read it.
    """

    def __init__(self, Dx, Dw, hidden=128, T_weighted=True):
        super().__init__()
        self.Dx, self.Dw = Dx, Dw
        # Pair encoder: concatenated (x_i, x_j, w) -> hidden features.
        self.g = mlp([2 * Dx + Dw, hidden, hidden, hidden])
        # Scorer: pooled pair features concatenated with w -> one scalar (squared later).
        self.f = mlp([hidden + Dw, hidden, hidden, 1])
        self.T_weighted = T_weighted

    def forward(self, x, a, w):
        B, T, N, Dx = x.shape
        gate = torch.sigmoid(a)                                            # [B, N]
        # Pair gate sig(a_i) * sig(a_j), shaped to broadcast over time and features.
        pair_gate = (gate[:, :, None] * gate[:, None, :])[:, None, :, :, None]  # [B,1,N,N,1]

        # Pairwise inputs [B, T, N, N, 2*Dx + Dw]: x_i varies along dim 2, x_j along dim 3.
        left = x[:, :, :, None, :].expand(B, T, N, N, Dx)
        right = x[:, :, None, :, :].expand(B, T, N, N, Dx)
        code = w.view(B, 1, 1, 1, -1).expand(B, T, N, N, -1)
        pairs = torch.cat((left, right, code), dim=-1)

        # Encode every pair, gate it, then pool over time and both entity axes.
        gated = self.g(pairs) * pair_gate                                  # [B,T,N,N,H]
        summary = gated.sum(dim=(1, 2, 3))                                 # [B,H]

        score = self.f(torch.cat((summary, w), dim=-1)).squeeze(-1)        # [B]
        return score.pow(2)

# ---------- SGLD samplers ----------
def sgld_step(var, grad, alpha):
    noise = torch.randn_like(var) * (alpha ** 0.5)
    return var + 0.5 * alpha * grad + noise

def sgld_optimize_x(E, x_init, a, w, steps=10, alpha=1e-2):
    """Refine states ``x`` with ``steps`` SGLD updates on the energy E(x, a, w).

    Only ``x`` is differentiated: ``torch.autograd.grad`` is taken w.r.t. ``x``
    alone, so ``a``, ``w`` and the ``.grad`` fields of E's parameters are untouched.
    Grad mode is enabled locally, so this works even when the caller sits inside
    ``torch.no_grad()``. The previous ``@torch.no_grad()`` decorator made
    ``autograd.grad`` fail on the first step.

    Args:
        E: callable energy model returning one energy per batch element.
        x_init: starting states; detached and copied, never modified in place.
        a, w: attention logits and concept code passed through to E unchanged.
        steps: number of SGLD updates.
        alpha: SGLD step size.

    Returns:
        Detached tensor shaped like ``x_init``.
    """
    x = x_init.detach().clone().requires_grad_(True)
    with torch.enable_grad():
        for _ in range(steps):
            E_x = E(x, a, w).sum()  # sum -> one scalar; per-sample grads stay independent
            grad, = torch.autograd.grad(E_x, x, create_graph=False)
            x = sgld_step(x, grad, alpha).detach().requires_grad_(True)
    return x.detach()

def sgld_optimize_a(E, x, a_init, w, steps=10, alpha=1e-2):
    """Refine attention logits ``a`` with ``steps`` SGLD updates on the energy E(x, a, w).

    Only ``a`` is differentiated: ``torch.autograd.grad`` is taken w.r.t. ``a``
    alone, so ``x``, ``w`` and the ``.grad`` fields of E's parameters are untouched.
    Grad mode is enabled locally, so this works even when the caller sits inside
    ``torch.no_grad()``. The previous ``@torch.no_grad()`` decorator made
    ``autograd.grad`` fail on the first step.

    Args:
        E: callable energy model returning one energy per batch element.
        x: states passed through to E unchanged.
        a_init: starting logits; detached and copied, never modified in place.
        w: concept code passed through to E unchanged.
        steps: number of SGLD updates.
        alpha: SGLD step size.

    Returns:
        Detached tensor shaped like ``a_init``.
    """
    a = a_init.detach().clone().requires_grad_(True)
    with torch.enable_grad():
        for _ in range(steps):
            E_a = E(x, a, w).sum()  # sum -> one scalar; per-sample grads stay independent
            grad, = torch.autograd.grad(E_a, a, create_graph=False)
            a = sgld_step(a, grad, alpha).detach().requires_grad_(True)
    return a.detach()

# ---------- execution-time concept inference (inner loop) ----------
def infer_concept_codes(E, demos, Dw, steps=10, lr=0.1):
    """
    Infer concept codes from demonstrations by plain gradient descent on the energy.

    demos: dict of tensors (presumably built from X_demo):
      'x0': [B,T,N,Dx], 'x1': [B,T,N,Dx], 'a': [B,N]
    Dw: size of each concept code.
    steps, lr: number and size of the gradient-descent steps.

    Both codes start from N(0, I). w_x is fitted so that E(x1, a, w_x) is low
    (generation consistency), and w_a so that E(x0, a, w_a) is low
    (identification consistency).

    Only the codes are differentiated (torch.autograd.grad), so the .grad fields
    of E's parameters are never written. Grad mode is enabled locally, so calling
    this under torch.no_grad() still works; before, loss.backward() raised in that case.

    Returns: w_x, w_a each [B, Dw], detached.
    """
    B = demos['x0'].shape[0]
    device = demos['x0'].device
    w_x = torch.randn(B, Dw, device=device, requires_grad=True)
    w_a = torch.randn(B, Dw, device=device, requires_grad=True)

    with torch.enable_grad():
        for _ in range(steps):
            Ex = E(demos['x1'], demos['a'], w_x)   # generation consistency
            Ea = E(demos['x0'], demos['a'], w_a)   # identification consistency
            loss = (Ex + Ea).mean()
            # allow_unused: a code that E ignores simply gets no update, as plain SGD would do.
            grads = torch.autograd.grad(loss, (w_x, w_a), allow_unused=True)
            with torch.no_grad():
                for code, grad in zip((w_x, w_a), grads):
                    if grad is not None:
                        code.sub_(lr * grad)
    return w_x.detach(), w_a.detach()

# ---------- training step (outer loop) ----------
def training_step(E, batch_demo, batch_train, opt_theta, K=10, alpha=1e-2, lam=1.0):
    """
    One outer-loop optimisation step on the energy parameters theta.

    batch_demo / batch_train: dicts with 'x0', 'x1', 'a' tensors (same shapes).
    opt_theta: optimizer over E's parameters; it is zeroed and stepped here.
    K: inner-loop length, used for both concept inference and SGLD sampling.
    alpha: SGLD step size.
    lam: weight of L_kl relative to the contrastive L_ml.

    Returns a dict of Python floats: 'loss', 'L_ml', 'L_kl', 'Ex_pos', 'Ex_neg'.
    """
    # 1) Infer concept codes from the demos. infer_concept_codes returns detached
    #    codes, which already stops gradients into theta (the paper's simplification).
    #    Do NOT wrap it in torch.no_grad(): its inner loop has to differentiate E
    #    w.r.t. the codes.
    Dw = getattr(E, 'Dw', None)
    if Dw is None:
        # Assumes a RelationalEnergy-style E: f consumes [pooled g output, w], so
        # Dw = f's input width - g's output width. f.in_features alone is hidden + Dw.
        Dw = E.f[0].in_features - E.g[-1].out_features
    w_x, w_a = infer_concept_codes(E, batch_demo, Dw=Dw, steps=K)

    # 2) Sample negatives via SGLD: states start at x0, attention starts from noise.
    x0, x1, a = batch_train['x0'], batch_train['x1'], batch_train['a']
    a_init = torch.randn_like(a)
    x_tilde = sgld_optimize_x(E, x0, a, w_x, steps=K, alpha=alpha)
    a_tilde = sgld_optimize_a(E, x0, a_init, w_a, steps=K, alpha=alpha)

    # 3) Losses: contrastive softplus(E_pos - E_neg) for generation and identification.
    Ex_pos = E(x1, a, w_x)
    Ex_neg = E(x_tilde, a, w_x)
    La_pos = E(x0, a, w_a)
    La_neg = E(x0, a_tilde, w_a)

    Lx = F.softplus(Ex_pos - Ex_neg).mean()
    La = F.softplus(La_pos - La_neg).mean()
    L_ml = Lx + La

    # NOTE(review): x_tilde / a_tilde come back detached from the samplers, so this
    # term's only gradient is w.r.t. theta, and it LOWERS the negatives' energy.
    # L_ml pushes that same energy UP. A KL term that backprops through the sampling
    # chain with theta frozen would need the samplers to keep their graph. TODO:
    # confirm the intended formulation before relying on lam > 0.
    L_kl = (Ex_neg + La_neg).mean()  # entropy terms omitted (constant with fixed noise)

    loss = L_ml + lam * L_kl

    opt_theta.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(E.parameters(), 1.0)
    opt_theta.step()

    return {
        'loss': loss.item(),
        'L_ml': L_ml.item(),
        'L_kl': L_kl.item(),
        'Ex_pos': Ex_pos.mean().item(),
        'Ex_neg': Ex_neg.mean().item()
    }


In [None]:
def sample_line_concept(B, T, N, Dx=2, noise=0.01, length=1.0, device='cpu', k=4):
    """Synthesize a batch for the concept "k attended entities lie on a line".

    Every entity gets uniform features in [-1, 1]^Dx at every time step. For each
    batch element, k distinct entities are marked as attended (logit 3.0, others
    0.0). At the LAST time step only, their first two features are overwritten with
    k evenly spaced points on a random segment of the given length, plus Gaussian
    jitter with std ``noise``.

    Args:
        B, T, N, Dx: batch size, time steps, entities, per-entity feature size (>= 2).
        noise: std of the jitter added to the line points.
        length: length of the segment the points span.
        device: torch device for the returned tensors.
        k: number of attended entities per sample (must be <= N). Defaults to the
           previous hard-coded value of 4.

    Returns:
        dict with 'x0', 'x1' ([B, T, N, Dx]) and 'a' ([B, N]).

    Raises:
        ValueError: if Dx < 2 or k > N. Both used to fail later with a shape mismatch.
    """
    if Dx < 2:
        raise ValueError(f"Dx must be >= 2 to place points on a 2-D line, got Dx={Dx}")
    if k > N:
        raise ValueError(f"cannot attend to k={k} entities when only N={N} exist")

    # x0: random positions in [-1, 1]^Dx
    x0 = torch.rand(B, T, N, Dx, device=device) * 2 - 1
    a = torch.zeros(B, N, device=device)
    # Parametric positions along the line are identical for every sample.
    t_vals = torch.linspace(-length / 2, length / 2, k, device=device)
    for b in range(B):
        idx = torch.randperm(N)[:k]
        a[b, idx] = 3.0   # unbounded logit; the model passes it through a sigmoid
        # place the k attended entities near a random line at the last frame
        p0 = torch.rand(2, device=device) * 2 - 1
        direction = F.normalize(torch.rand(2, device=device) * 2 - 1, dim=0)
        line_pts = p0 + t_vals[:, None] * direction
        x0[b, -1, idx, :2] = line_pts + noise * torch.randn_like(line_pts)
    # NOTE(review): x1 is an exact copy of x0, so both already show the line at the
    # last frame and the "generation" target equals its own starting state.
    # Confirm this is intended (vs. x0 = raw positions, x1 = arranged positions).
    x1 = x0.clone()
    return {'x0': x0, 'x1': x1, 'a': a}