In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

In [3]:
from regular import RegularTransformer, RegularFeedForward
from projected import ProjectedTransformer, ProjectedFeedForward
from exponential import ExponentialTransformer, ExponentialFeedForward
from train_constrained import train_model
import projections
import exponentials

In [4]:
def check_so3(X: torch.Tensor, atol: float = 1e-3):
    assert X.dim() == 2 and X.shape[1] == 9
    R = X.view(-1, 3, 3)

    I = torch.eye(3, device=R.device, dtype=R.dtype).expand(R.shape[0], 3, 3)
    ortho_err = torch.linalg.norm(R.transpose(-1, -2) @ R - I, dim=(-2, -1))  # (B,)
    det_err = torch.abs(torch.linalg.det(R) - 1.0)                            # (B,)

    ok = (ortho_err <= atol) & (det_err <= atol)
    return {"ok": ok, "ortho_err": ortho_err, "det_err": det_err}

def check_so3_flat(R9: torch.Tensor, tol: float = 1e-4):
    """
    R9: [B, 9] or [B, S, 9]
    Returns: (ok_mask, max_orth_err, max_det_err)
    """
    orig = R9.shape
    if R9.dim() == 3:
        B, S, _ = orig
        R = R9.reshape(B*S, 3, 3)
    else:
        B = orig[0]
        R = R9.reshape(B, 3, 3)

    I = torch.eye(3, device=R.device, dtype=R.dtype).expand(R.shape[0], 3, 3)
    orth_err = torch.linalg.norm(R.transpose(-1, -2) @ R - I, dim=(-2, -1))
    det_err  = torch.abs(torch.linalg.det(R) - 1.0)

    ok = (orth_err <= tol) & (det_err <= tol)
    if R9.dim() == 3:
        ok = ok.reshape(B, S)

    return ok, orth_err.max().item(), det_err.max().item()


def check_se3_flat(G16: torch.Tensor, tol_R: float = 1e-4, tol_last: float = 1e-6):
    """
    G16: [B, 16] (row-major flatten of 4x4).
    Checks:
      - top-left 3x3 is in SO(3) (via check_so3_flat)
      - last row equals [0,0,0,1]
    Returns: dict with masks + max errors.
    """
    B = G16.shape[0]
    G = G16.reshape(B, 4, 4)

    R = G[:, :3, :3].reshape(B, 9)          # [B,9]
    ok_R, max_orth, max_det = check_so3_flat(R, tol=tol_R)

    last = G[:, 3, :]                        # [B,4]
    target = torch.tensor([0., 0., 0., 1.], device=G.device, dtype=G.dtype).expand_as(last)
    last_err = torch.max(torch.abs(last - target), dim=-1).values  # [B]
    ok_last = last_err <= tol_last

    ok_all = ok_R & ok_last

    return dict(
        ok=ok_all,
        ok_R=ok_R,
        ok_last_row=ok_last,
        max_orth_err=max_orth,
        max_det_err=max_det,
        max_last_row_err=last_err.max().item(),
    )

In [5]:
loaded = torch.load("./../Data/so3_dataset.pt", map_location="cpu", weights_only = False)

for k, v in loaded.items():
    try:
        print(k, v.shape)
    except AttributeError:
        print(k, type(v))

X_train = torch.tensor(loaded['X_train'], dtype = torch.float32)
Y_train = torch.tensor(loaded['Y_train'], dtype = torch.float32)

for name in ["X_train","Y_train","X_val","Y_val","X_test","Y_test"]:
    assert torch.isfinite(torch.tensor(loaded[name])).all(), f"Found NaN/Inf in {name}"

X_train (4000, 9)
Y_train (4000, 9)
X_val (800, 9)
Y_val (800, 9)
X_test (1200, 9)
Y_test (1200, 9)


In [6]:
model = ExponentialFeedForward(9,3,3, exp_func = exponentials.so3, use_internal_exponential = True)

In [190]:
out = check_se3_flat(Y_train, tol_R=1e-4, tol_last=1e-6)
print(out["ok"].float().mean(), out["max_orth_err"], out["max_det_err"], out["max_last_row_err"])

tensor(0.9945) 1.4141149520874023 0.9881886839866638 0.0


In [177]:
out = check_so3(X_train, atol=1e-5)
print(out["ok"].float().mean(), out["ortho_err"].max(), out["det_err"].max())

tensor(1.) tensor(2.1100e-07) tensor(2.3842e-07)


In [7]:
Y_pred = model(X_train)

In [142]:
Y_pred.square().sum(dim=1)

tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)

In [207]:
B = X_train.shape[0]
a = torch.randn(B, 6, device=X_train.device, dtype=X_train.dtype)  # trivial tangent
dt = torch.tensor(1.0, device=X_train.device, dtype=X_train.dtype)
X_step = exponentials.se3(X_step, a, dt)
out_step = check_se3_flat(X_step, tol_R=1e-4, tol_last=1e-6)
print("se3(X_train,0) ok% =", out_step["ok"].float().mean().item(),
      out_step["max_orth_err"], out_step["max_det_err"], out_step["max_last_row_err"])

se3(X_train,0) ok% = 0.9944999814033508 1.4141782522201538 0.99268639087677 0.0


In [211]:
out = check_se3_flat(Y_pred, tol_R=1e-2, tol_last=1e-6)
print(out["ok"].float().mean(), out["max_orth_err"], out["max_det_err"], out["max_last_row_err"])

tensor(0.9942) 1.4145879745483398 0.9926833510398865 0.0


In [170]:
out = check_so3(Y_pred, atol=1e-5)
print(out["ok"].float().mean(), out["ortho_err"].max(), out["det_err"].max())

tensor(1.) tensor(1.6745e-06, grad_fn=<MaxBackward1>) tensor(1.4305e-06, grad_fn=<MaxBackward1>)


In [140]:
assert torch.isfinite(Y_pred).all(), f"Found NaN/Inf in {name}"

In [141]:
Y_pred.shape

torch.Size([4000, 3])

# Overfit Test

In [8]:
from torch.utils.data import TensorDataset, DataLoader

In [20]:
# tensors
X_train = torch.tensor(loaded["X_train"], dtype=torch.float32)
Y_train = torch.tensor(loaded["Y_train"], dtype=torch.float32)
X_val   = torch.tensor(loaded["X_val"],   dtype=torch.float32)
Y_val   = torch.tensor(loaded["Y_val"],   dtype=torch.float32)

# datasets
train_ds = TensorDataset(X_train[:16,:], Y_train[:16,:])
val_ds   = TensorDataset(X_val, Y_val)

# loaders
batch_size = 256  # change if you want

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, drop_last=False)

In [None]:
# tensors
X_train = torch.tensor(loaded["X_train"], dtype=torch.float32)
Y_train = torch.tensor(loaded["Y_train"], dtype=torch.float32)
X_val   = torch.tensor(loaded["X_val"],   dtype=torch.float32)
Y_val   = torch.tensor(loaded["Y_val"],   dtype=torch.float32)

# compute tau from TRAIN translations (no leakage)
G = X_train.view(-1, 4, 4)
t = G[:, :3, 3]
tau = t.std()   # scalar tensor

# normalize translation column (rows 0..2, col 3) in ALL splits
def normalize_se3_translation(X16: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
    G = X16.view(-1, 4, 4).clone()
    G[:, :3, 3] = G[:, :3, 3] / tau
    return G.view(-1, 16)

X_train_n = normalize_se3_translation(X_train, tau)
Y_train_n = normalize_se3_translation(Y_train, tau)
X_val_n   = normalize_se3_translation(X_val,   tau)
Y_val_n   = normalize_se3_translation(Y_val,   tau)

# datasets (use normalized tensors)
train_ds = TensorDataset(X_train_n[:16, :], Y_train_n[:16, :])
val_ds   = TensorDataset(X_val_n, Y_val_n)

# loaders
batch_size = 256
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, drop_last=False)


In [26]:
model = ExponentialFeedForward(16,6,3, exp_func = exponentials.se3, dropout = 0.0, dt = 1)

In [27]:
model, logs = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    lr=1e-3,
    num_epochs=1000,
    device='cuda',
    weight_decay=0,
    early_stop=0,      # disable early stopping
    verbose=True
)

epoch    1 | train 5.511e-01 | val 4.384e-01 | lr 1.0e-03
epoch   20 | train 3.941e-01 | val 4.233e-01 | lr 1.0e-03
epoch   40 | train 3.627e-01 | val 4.220e-01 | lr 1.0e-03
epoch   60 | train 3.400e-01 | val 4.095e-01 | lr 1.0e-03
epoch   80 | train 3.205e-01 | val 4.044e-01 | lr 1.0e-03
epoch  100 | train 3.033e-01 | val 4.022e-01 | lr 1.0e-03
epoch  120 | train 2.880e-01 | val 4.031e-01 | lr 1.0e-03
epoch  140 | train 2.736e-01 | val 4.086e-01 | lr 1.0e-03
epoch  160 | train 2.578e-01 | val 4.175e-01 | lr 1.0e-03
epoch  180 | train 2.400e-01 | val 4.284e-01 | lr 1.0e-03
epoch  200 | train 2.202e-01 | val 4.425e-01 | lr 1.0e-03
epoch  220 | train 2.009e-01 | val 4.556e-01 | lr 1.0e-03
epoch  240 | train 1.838e-01 | val 4.727e-01 | lr 1.0e-03
epoch  260 | train 1.666e-01 | val 4.935e-01 | lr 1.0e-03
epoch  280 | train 1.493e-01 | val 5.084e-01 | lr 1.0e-03
epoch  300 | train 1.324e-01 | val 5.210e-01 | lr 1.0e-03
epoch  320 | train 1.152e-01 | val 5.355e-01 | lr 1.0e-03
epoch  340 | t

# Flow Matching

In [113]:
flow_proj_dict = torch.load("./outputs/sphere_dataset/alpha1.0/lr1e-3_wd1e-4/model.pt")
torch.load("./outputs/sphere_dataset/alpha1.0/lr1e-3_wd1e-4/meta.pt")

{'dataset': '/tmp/sonthal/fm_cache/sphere_dataset.pt',
 'train_shape': (4000, 3),
 'val_shape': (800, 3),
 'hparams': {'dataset': '/tmp/sonthal/fm_cache/sphere_dataset.pt',
  'outdir': 'outputs/sphere_dataset/alpha1.0/lr1e-3_wd1e-4',
  'hidden_dim': 256,
  'num_layers': 8,
  'lr': 0.001,
  'weight_decay': 0.0001,
  'num_epochs': 2000,
  'batch_size': 256,
  'early_stop': 1000,
  'seed': 0,
  'device': 'cuda',
  'scheduler_patience': 100,
  'scheduler_factor': 0.5,
  'T': 'auto',
  'alpha': 1.0,
  'num_timesteps': 30,
  'velocity_scale': 0.5,
  'velocity_cov_scale': 1.0,
  'T_used': 2.0,
  'r_median': 1.0},
 'best_val_loss': 0.05418527197647602,
 'best_epoch': 6,
 'epochs_ran': 1007}

In [104]:
flow_proj_dict['state_dict'].keys()

odict_keys(['net.0.weight', 'net.0.bias', 'net.1.weight', 'net.1.bias', 'net.4.weight', 'net.4.bias', 'net.5.weight', 'net.5.bias', 'net.8.weight', 'net.8.bias', 'net.9.weight', 'net.9.bias', 'net.12.weight', 'net.12.bias', 'net.13.weight', 'net.13.bias', 'net.16.weight', 'net.16.bias', 'net.17.weight', 'net.17.bias', 'net.20.weight', 'net.20.bias', 'net.21.weight', 'net.21.bias', 'net.24.weight', 'net.24.bias', 'net.25.weight', 'net.25.bias', 'net.28.weight', 'net.28.bias', 'net.29.weight', 'net.29.bias', 'net.32.weight', 'net.32.bias', 'net.33.weight', 'net.33.bias', 'net.36.weight', 'net.36.bias'])

In [120]:
from flow_matching import FlowVelocityNet, flow_matching_projection

In [142]:
flow_proj = FlowVelocityNet(input_dim=3, hidden_dim=256, num_layers=8).to('cuda')
flow_proj.load_state_dict(flow_proj_dict['state_dict'])
flow_proj = flow_proj.to('cuda')
flow_proj.eval();

In [143]:
def proj_flow(x):
    return flow_matching_projection(x, flow_proj)

In [144]:
loaded = torch.load("./../Data/sphere_dataset.pt", map_location="cpu", weights_only = False)

for k, v in loaded.items():
    try:
        print(k, v.shape)
    except AttributeError:
        print(k, type(v))

X_train = torch.tensor(loaded['X_train'], dtype = torch.float32)
Y_train = torch.tensor(loaded['Y_train'], dtype = torch.float32)

for name in ["X_train","Y_train","X_val","Y_val","X_test","Y_test"]:
    assert torch.isfinite(torch.tensor(loaded[name])).all(), f"Found NaN/Inf in {name}"

X_train torch.Size([4000, 3])
Y_train (4000, 3)
X_val torch.Size([800, 3])
Y_val (800, 3)
X_test torch.Size([1200, 3])
Y_test (1200, 3)


  X_train = torch.tensor(loaded['X_train'], dtype = torch.float32)
  assert torch.isfinite(torch.tensor(loaded[name])).all(), f"Found NaN/Inf in {name}"


In [161]:
# tensors
X_train = torch.tensor(loaded["X_train"], dtype=torch.float32)
Y_train = torch.tensor(loaded["Y_train"], dtype=torch.float32)
X_val   = torch.tensor(loaded["X_val"],   dtype=torch.float32)
Y_val   = torch.tensor(loaded["Y_val"],   dtype=torch.float32)

# # compute tau from TRAIN translations (no leakage)
# G = X_train.view(-1, 4, 4)
# t = G[:, :3, 3]
# tau = t.std()   # scalar tensor

# # normalize translation column (rows 0..2, col 3) in ALL splits
# def normalize_se3_translation(X16: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
#     G = X16.view(-1, 4, 4).clone()
#     G[:, :3, 3] = G[:, :3, 3] / tau
#     return G.view(-1, 16)

# X_train_n = normalize_se3_translation(X_train, tau)
# Y_train_n = normalize_se3_translation(Y_train, tau)
# X_val_n   = normalize_se3_translation(X_val,   tau)
# Y_val_n   = normalize_se3_translation(Y_val,   tau)

# datasets (use normalized tensors)
train_ds = TensorDataset(X_train, Y_train)
val_ds   = TensorDataset(X_val, Y_val)

# loaders
batch_size = 256
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, drop_last=False)


  X_train = torch.tensor(loaded["X_train"], dtype=torch.float32)
  X_val   = torch.tensor(loaded["X_val"],   dtype=torch.float32)


In [162]:
model = ProjectedFeedForward(3,3,3,proj_func = proj_flow, dropout = 0.0)

In [163]:
model, logs = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    lr=1e-3,
    num_epochs=1000,
    device='cuda',
    weight_decay=0,
    early_stop=0,      # disable early stopping
    verbose=True
)

epoch    1 | train 3.669e-01 | val 3.289e-01 | lr 1.0e-03
epoch   20 | train 1.190e-01 | val 1.172e-01 | lr 1.0e-03
epoch   40 | train 8.880e-02 | val 8.906e-02 | lr 1.0e-03
epoch   60 | train 7.462e-02 | val 7.368e-02 | lr 1.0e-03
epoch   80 | train 6.268e-02 | val 6.159e-02 | lr 1.0e-03
epoch  100 | train 4.974e-02 | val 4.528e-02 | lr 1.0e-03
epoch  120 | train 4.385e-02 | val 3.952e-02 | lr 1.0e-03
epoch  140 | train 4.011e-02 | val 3.475e-02 | lr 1.0e-03
epoch  160 | train 3.543e-02 | val 3.015e-02 | lr 1.0e-03
epoch  180 | train 3.081e-02 | val 2.786e-02 | lr 1.0e-03
epoch  200 | train 2.849e-02 | val 2.723e-02 | lr 1.0e-03
epoch  220 | train 2.666e-02 | val 2.674e-02 | lr 1.0e-03
epoch  240 | train 2.486e-02 | val 2.536e-02 | lr 1.0e-03
epoch  260 | train 2.522e-02 | val 2.509e-02 | lr 1.0e-03
epoch  280 | train 2.247e-02 | val 2.206e-02 | lr 1.0e-03
epoch  300 | train 2.298e-02 | val 2.247e-02 | lr 1.0e-03
epoch  320 | train 2.261e-02 | val 2.213e-02 | lr 1.0e-03
epoch  340 | t

In [168]:
(model.to('cpu')(X_train.to('cpu')).square().sum(dim=1) - 1).abs().mean()

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [127]:
x0 = X_train[:8].to('cuda').clone().requires_grad_(True)   # pick a tiny batch
y  = proj_flow(x0)

loss = (y**2).sum()   # any scalar function works
loss.backward()

print("x0.grad is None? ", x0.grad is None)
print("||x0.grad||:      ", x0.grad.norm().item())
print("max |grad|:       ", x0.grad.abs().max().item())

x0.grad is None?  False
||x0.grad||:       5.557411193847656
max |grad|:        1.9498889446258545


In [138]:
eps = 1e-2
i = 0  # one sample
j = 0  # one coordinate

device = 'cuda'

x0 = X_train[:2].to(device).clone().requires_grad_(True)
y  = proj_flow(x0+torch.randn_like(x0))
f  = y[0,0]            # scalar
f.backward()
g_autograd = x0.grad[i,j].item()

with torch.no_grad():
    x_plus = x0.detach().clone()
    x_minus = x0.detach().clone()
    x_plus[i,j] += eps
    x_minus[i,j] -= eps
    f_plus  = proj_flow(x_plus)[0,0].item()
    f_minus = proj_flow(x_minus)[0,0].item()
g_fd = (f_plus - f_minus)/(2*eps)

print("autograd:", g_autograd)
print("fd:      ", g_fd)


autograd: 0.7022574543952942
fd:       0.8332699537277222
