In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from scipy.spatial.distance import cdist


In [36]:
from probabilistic import (
    create_label_dataset,
    create_anchors,
    probabilistic_transformer_loss,
    predict,
    sample_sphere,
    sample_so3,
    sample_se3,
    train_probabilistic_model
)
from train_constrained import train_model
from regular import RegularTransformer, RegularFeedForward

In [4]:
def check_so3(X: torch.Tensor, atol: float = 1e-3):
    assert X.dim() == 2 and X.shape[1] == 9
    R = X.view(-1, 3, 3)

    I = torch.eye(3, device=R.device, dtype=R.dtype).expand(R.shape[0], 3, 3)
    ortho_err = torch.linalg.norm(R.transpose(-1, -2) @ R - I, dim=(-2, -1))  # (B,)
    det_err = torch.abs(torch.linalg.det(R) - 1.0)                            # (B,)

    ok = (ortho_err <= atol) & (det_err <= atol)
    return {"ok": ok, "ortho_err": ortho_err, "det_err": det_err}

def check_so3_flat(R9: torch.Tensor, tol: float = 1e-4):
    """
    R9: [B, 9] or [B, S, 9]
    Returns: (ok_mask, max_orth_err, max_det_err)
    """
    orig = R9.shape
    if R9.dim() == 3:
        B, S, _ = orig
        R = R9.reshape(B*S, 3, 3)
    else:
        B = orig[0]
        R = R9.reshape(B, 3, 3)

    I = torch.eye(3, device=R.device, dtype=R.dtype).expand(R.shape[0], 3, 3)
    orth_err = torch.linalg.norm(R.transpose(-1, -2) @ R - I, dim=(-2, -1))
    det_err  = torch.abs(torch.linalg.det(R) - 1.0)

    ok = (orth_err <= tol) & (det_err <= tol)
    if R9.dim() == 3:
        ok = ok.reshape(B, S)

    return ok, orth_err.max().item(), det_err.max().item()


def check_se3_flat(G16: torch.Tensor, tol_R: float = 1e-4, tol_last: float = 1e-6):
    """
    G16: [B, 16] (row-major flatten of 4x4).
    Checks:
      - top-left 3x3 is in SO(3) (via check_so3_flat)
      - last row equals [0,0,0,1]
    Returns: dict with masks + max errors.
    """
    B = G16.shape[0]
    G = G16.reshape(B, 4, 4)

    R = G[:, :3, :3].reshape(B, 9)          # [B,9]
    ok_R, max_orth, max_det = check_so3_flat(R, tol=tol_R)

    last = G[:, 3, :]                        # [B,4]
    target = torch.tensor([0., 0., 0., 1.], device=G.device, dtype=G.dtype).expand_as(last)
    last_err = torch.max(torch.abs(last - target), dim=-1).values  # [B]
    ok_last = last_err <= tol_last

    ok_all = ok_R & ok_last

    return dict(
        ok=ok_all,
        ok_R=ok_R,
        ok_last_row=ok_last,
        max_orth_err=max_orth,
        max_det_err=max_det,
        max_last_row_err=last_err.max().item(),
    )


# Test Probabilistic Model on SO(3) and SE(3) Dataset 

In [72]:
loaded = torch.load("./../Data/protein_dataset.pt", map_location="cpu", weights_only = False)

for k, v in loaded.items():
    try:
        print(k, v.shape)
    except AttributeError:
        print(k, type(v))

X_train = torch.tensor(loaded['X_train'], dtype = torch.float32)
Y_train = torch.tensor(loaded['Y_train'], dtype = torch.float32)

for name in ["X_train","Y_train","X_val","Y_val","X_test","Y_test"]:
    assert torch.isfinite(torch.tensor(loaded[name])).all(), f"Found NaN/Inf in {name}"


X_train (4000, 16)
Y_train (4000, 16)
X_val (800, 16)
Y_val (800, 16)
X_test (1200, 16)
Y_test (1200, 16)


In [84]:
# Create anchors from training data (recommended approach)
num_anchors = 100
Y_train_np = Y_train.numpy()
anchors = create_anchors(num_anchors, training_data=Y_train_np, manifold='se3')
# anchors = create_anchors(num_anchors, manifold='se3')
anchors_tensor = torch.tensor(anchors, dtype=torch.float32)
print(f"Created {num_anchors} anchors with shape {anchors.shape}")
print(anchors_tensor.shape)

--- Creating Anchors: Sampling 100 from Training Data ---
Created 100 anchors with shape (100, 16)
torch.Size([100, 16])


In [85]:
# Verify anchors are valid SE(3)
# out = check_so3(anchors_tensor)
# print(out["ok"].float().mean(), out["ortho_err"].max(), out["det_err"].max())

out = check_se3_flat(anchors_tensor, tol_R=1e-4, tol_last=1e-6)
print(out["ok"].float().mean(), out["max_orth_err"], out["max_det_err"], out["max_last_row_err"])

tensor(1.) 3.2167184826903394e-07 2.384185791015625e-07 0.0


In [86]:
# Create label dataset using Voronoi partitioning
X_train_np = X_train.numpy()
Y_train_np = Y_train.numpy()

train_loader, L_train = create_label_dataset(
    X_train_np, 
    Y_train_np, 
    anchors, 
    batch_size=500, 
    metric='euclidean'
)
print(f"Created training loader with {len(train_loader)} batches")
print(f"Label tensor shape: {L_train.shape}, unique labels: {L_train.unique().numel()}")

--- Generating Labels (Voronoi Partitioning) ---
    Training Data (T): 4000
    Anchors/Particles (N): 100
    Done. Created DataLoader with 8 batches.
Created training loader with 8 batches
Label tensor shape: torch.Size([4000]), unique labels: 100


In [87]:
X_val   = torch.tensor(loaded["X_val"],   dtype=torch.float32)
Y_val   = torch.tensor(loaded["Y_val"],   dtype=torch.float32)

# datasets
val_ds   = TensorDataset(X_val, Y_val)

# loaders
batch_size = 256  # change if you want

val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, drop_last=False)

In [88]:
model = nn.Sequential(RegularFeedForward(16,16,3,dropout=0.0), nn.Linear(16,100))

In [89]:
Y_pred = model(X_train)

In [90]:
Y_pred.shape

torch.Size([4000, 100])

In [91]:
# Train model
model, logs = train_probabilistic_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    anchors=anchors_tensor,
    lr=1e-3,
    num_epochs=1000,
    device='cuda',
    weight_decay=0,
    early_stop=0,      # disable early stopping
    verbose=True
)


epoch    1 | train 1.969e+00 | val 4.896e+06 | lr 1.0e-03
epoch   20 | train 1.948e+00 | val 4.235e+06 | lr 1.0e-03
epoch   40 | train 1.948e+00 | val 4.589e+06 | lr 1.0e-03
epoch   60 | train 1.939e+00 | val 4.425e+06 | lr 1.0e-03
epoch   80 | train 1.961e+00 | val 5.660e+06 | lr 1.0e-03
epoch  100 | train 1.940e+00 | val 4.651e+06 | lr 1.0e-03
epoch  120 | train 1.920e+00 | val 3.708e+06 | lr 1.0e-03
epoch  140 | train 1.941e+00 | val 3.754e+06 | lr 1.0e-03
epoch  160 | train 1.942e+00 | val 3.939e+06 | lr 1.0e-03
epoch  180 | train 1.961e+00 | val 4.071e+06 | lr 1.0e-03
epoch  200 | train 1.958e+00 | val 3.931e+06 | lr 1.0e-03
epoch  220 | train 1.955e+00 | val 3.861e+06 | lr 1.0e-03
epoch  240 | train 1.957e+00 | val 4.077e+06 | lr 1.0e-03
epoch  260 | train 1.957e+00 | val 4.082e+06 | lr 1.0e-03
epoch  280 | train 1.957e+00 | val 4.095e+06 | lr 1.0e-03
epoch  300 | train 1.957e+00 | val 4.104e+06 | lr 1.0e-03
epoch  320 | train 1.957e+00 | val 4.102e+06 | lr 1.0e-03
epoch  340 | t

KeyboardInterrupt: 

In [None]:
# Create anchors from training data
num_anchors = 100
Y_train_np = Y_train.numpy()
anchors = create_anchors(num_anchors, training_data=Y_train_np, manifold='sphere')
anchors_tensor = torch.tensor(anchors, dtype=torch.float32)
print(f"Created {num_anchors} anchors with shape {anchors.shape}")

# Verify anchors are on sphere
anchor_norms = torch.linalg.norm(anchors_tensor, dim=1)
print(f"Anchor norm stats: min={anchor_norms.min():.6f}, max={anchor_norms.max():.6f}, mean={anchor_norms.mean():.6f}")


In [None]:
# Create label dataset
X_train_np = X_train.numpy()
Y_train_np = Y_train.numpy()

train_loader, L_train = create_label_dataset(
    X_train_np, 
    Y_train_np, 
    anchors, 
    batch_size=256, 
    metric='euclidean'
)
print(f"Created training loader with {len(train_loader)} batches")
print(f"Label tensor shape: {L_train.shape}, unique labels: {L_train.unique().numel()}")


In [None]:
# Create validation loader
X_val_np = torch.tensor(loaded["X_val"], dtype=torch.float32).numpy()
Y_val_np = torch.tensor(loaded["Y_val"], dtype=torch.float32).numpy()

val_distances = cdist(Y_val_np, anchors, metric='euclidean')
val_labels = torch.tensor(np.argmin(val_distances, axis=1), dtype=torch.long)

val_ds = TensorDataset(
    torch.tensor(X_val_np, dtype=torch.float32),
    val_labels
)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, drop_last=False)
print(f"Created validation loader with {len(val_loader)} batches")


In [25]:
# Create model
input_dim = X_train.shape[1]  # 3 for sphere input
output_dim = Y_train.shape[1]  # 3 for sphere output

model = ProbabilisticFeedForward(
    input_dim=input_dim,
    hidden_dim=128,
    num_layers=4,
    num_anchors=num_anchors,
    dropout=0.1
)
print(f"Model created: input_dim={input_dim}, num_anchors={num_anchors}")


Model created: input_dim=3, num_anchors=100


In [26]:
# Train model
model, logs = train_probabilistic_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    anchors=anchors_tensor,
    lr=1e-3,
    num_epochs=1000,
    device='cuda',
    weight_decay=0,
    early_stop=0,      # disable early stopping
    verbose=True
)


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [27]:
# Test prediction
model.eval()
X_test = X_train[:100].to('cuda')
anchors_cuda = anchors_tensor.to('cuda')

Y_pred = predict(model, X_test, anchors_cuda, method='expectation')
print(f"Prediction shape: {Y_pred.shape}")

# Check if predictions are on sphere
pred_norms = torch.linalg.norm(Y_pred, dim=1)
print(f"Prediction norm stats: min={pred_norms.min():.6f}, max={pred_norms.max():.6f}, mean={pred_norms.mean():.6f}")
print(f"Predictions on sphere (within 1e-3): {(torch.abs(pred_norms - 1.0) < 1e-3).float().mean():.4f}")


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [28]:
# Check for NaN/Inf
Y_pred_cpu = Y_pred.cpu()
assert torch.isfinite(Y_pred_cpu).all(), "Found NaN/Inf in predictions"
print("All predictions are finite ✓")


NameError: name 'Y_pred' is not defined

In [29]:
# Compare expectation vs argmax prediction methods
Y_pred_expectation = predict(model, X_test, anchors_cuda, method='expectation')
Y_pred_argmax = predict(model, X_test, anchors_cuda, method='argmax')

expectation_norms = torch.linalg.norm(Y_pred_expectation, dim=1)
argmax_norms = torch.linalg.norm(Y_pred_argmax, dim=1)

print(f"Expectation method - norm stats: min={expectation_norms.min():.6f}, max={expectation_norms.max():.6f}, mean={expectation_norms.mean():.6f}")
print(f"Argmax method - norm stats: min={argmax_norms.min():.6f}, max={argmax_norms.max():.6f}, mean={argmax_norms.mean():.6f}")
print(f"Argmax on sphere (within 1e-3): {(torch.abs(argmax_norms - 1.0) < 1e-3).float().mean():.4f}")


NameError: name 'X_test' is not defined

# Test on SO(3)

In [30]:
loaded = torch.load("./../Data/so3_dataset.pt", map_location="cpu", weights_only = False)

for k, v in loaded.items():
    try:
        print(k, v.shape)
    except AttributeError:
        print(k, type(v))

X_train = torch.tensor(loaded['X_train'], dtype = torch.float32)
Y_train = torch.tensor(loaded['Y_train'], dtype = torch.float32)

for name in ["X_train","Y_train","X_val","Y_val","X_test","Y_test"]:
    assert torch.isfinite(torch.tensor(loaded[name])).all(), f"Found NaN/Inf in {name}"


X_train (4000, 9)
Y_train (4000, 9)
X_val (800, 9)
Y_val (800, 9)
X_test (1200, 9)
Y_test (1200, 9)


In [31]:
# Check that Y_train is valid SE(3)
out = check_se3_flat(Y_train, tol_R=1e-4, tol_last=1e-6)
print(f"Y_train SE(3) validity: {out['ok'].float().mean():.4f}")
print(f"Max ortho err: {out['max_orth_err']:.6e}, Max det err: {out['max_det_err']:.6e}, Max last row err: {out['max_last_row_err']:.6e}")


RuntimeError: shape '[4000, 4, 4]' is invalid for input of size 36000

In [32]:
# Check that X_train is valid SO(3)
out = check_so3(X_train, atol=1e-5)
print(f"X_train SO(3) validity: {out['ok'].float().mean():.4f}")
print(f"Max ortho err: {out['ortho_err'].max():.6e}, Max det err: {out['det_err'].max():.6e}")


X_train SO(3) validity: 1.0000
Max ortho err: 2.109975e-07, Max det err: 2.384186e-07


In [33]:
# Create anchors from training data (recommended approach)
num_anchors = 100
Y_train_np = Y_train.numpy()
anchors = create_anchors(num_anchors, training_data=Y_train_np, manifold='se3')
anchors_tensor = torch.tensor(anchors, dtype=torch.float32)
print(f"Created {num_anchors} anchors with shape {anchors.shape}")


--- Creating Anchors: Sampling 100 from Training Data ---
Created 100 anchors with shape (100, 9)


In [34]:
# Verify anchors are valid SE(3)
out = check_se3_flat(anchors_tensor, tol_R=1e-4, tol_last=1e-6)
print(f"Anchors SE(3) validity: {out['ok'].float().mean():.4f}")
print(f"Max ortho err: {out['max_orth_err']:.6e}, Max det err: {out['max_det_err']:.6e}, Max last row err: {out['max_last_row_err']:.6e}")


RuntimeError: shape '[100, 4, 4]' is invalid for input of size 900

In [35]:
# Create label dataset using Voronoi partitioning
X_train_np = X_train.numpy()
Y_train_np = Y_train.numpy()

train_loader, L_train = create_label_dataset(
    X_train_np, 
    Y_train_np, 
    anchors, 
    batch_size=256, 
    metric='euclidean'
)
print(f"Created training loader with {len(train_loader)} batches")
print(f"Label tensor shape: {L_train.shape}, unique labels: {L_train.unique().numel()}")


--- Generating Labels (Voronoi Partitioning) ---
    Training Data (T): 4000
    Anchors/Particles (N): 100
    Done. Created DataLoader with 16 batches.
Created training loader with 16 batches
Label tensor shape: torch.Size([4000]), unique labels: 100


In [36]:
# Create validation loader
X_val_np = torch.tensor(loaded["X_val"], dtype=torch.float32).numpy()
Y_val_np = torch.tensor(loaded["Y_val"], dtype=torch.float32).numpy()

# For validation, we still need labels (closest anchors)
from scipy.spatial.distance import cdist
val_distances = cdist(Y_val_np, anchors, metric='euclidean')
val_labels = torch.tensor(np.argmin(val_distances, axis=1), dtype=torch.long)

val_ds = TensorDataset(
    torch.tensor(X_val_np, dtype=torch.float32),
    val_labels
)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, drop_last=False)
print(f"Created validation loader with {len(val_loader)} batches")


Created validation loader with 4 batches


In [37]:
# Create model
input_dim = X_train.shape[1]  # 9 for SO(3) input
output_dim = Y_train.shape[1]  # 16 for SE(3) output

model = ProbabilisticFeedForward(
    input_dim=input_dim,
    hidden_dim=128,
    num_layers=4,
    num_anchors=num_anchors,
    dropout=0.1
)
print(f"Model created: input_dim={input_dim}, num_anchors={num_anchors}")


Model created: input_dim=9, num_anchors=100


In [38]:
# Test forward pass
x_test = X_train[:4]
logits = model(x_test)
print(f"Input shape: {x_test.shape}, Logits shape: {logits.shape}")
probs = F.softmax(logits, dim=1)
print(f"Probabilities shape: {probs.shape}, Sum per sample: {probs.sum(dim=1)}")


Input shape: torch.Size([4, 9]), Logits shape: torch.Size([4, 100])
Probabilities shape: torch.Size([4, 100]), Sum per sample: tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)


In [39]:
# Custom training function for probabilistic model
def train_probabilistic_model(
    model: nn.Module,
    train_loader,
    val_loader,
    anchors: torch.Tensor,
    lr: float,
    num_epochs: int,
    device,
    weight_decay: float = 1e-4,
    grad_clip: float = 1.0,
    scheduler_patience: int = 300,
    scheduler_factor: float = 0.8,
    early_stop: int = 30,
    verbose: bool = True,
):
    import copy
    model = model.to(device)
    anchors = anchors.to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="min", patience=scheduler_patience, factor=scheduler_factor
    )

    train_losses, val_losses, lrs = [], [], []
    best_val, best_epoch = float("inf"), -1
    best_state, patience = None, 0

    for epoch in range(num_epochs):
        # ---- train ----
        model.train()
        tr = 0.0
        n_tr = 0
        for batch in train_loader:
            x, labels = batch
            x, labels = x.to(device), labels.to(device)

            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = probabilistic_transformer_loss(logits, labels)
            loss.backward()

            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            opt.step()

            bs = x.shape[0]
            tr += loss.item() * bs
            n_tr += bs

        tr /= n_tr
        train_losses.append(tr)

        # ---- val ----
        model.eval()
        va = 0.0
        n_va = 0
        with torch.no_grad():
            for batch in val_loader:
                x, labels = batch
                x, labels = x.to(device), labels.to(device)
                logits = model(x)
                loss = probabilistic_transformer_loss(logits, labels)

                bs = x.shape[0]
                va += loss.item() * bs
                n_va += bs

        va /= n_va
        val_losses.append(va)

        # ---- scheduler ----
        sched.step(va)
        lrs.append(opt.param_groups[0]["lr"])

        # ---- best model ----
        if va < best_val:
            best_val, best_epoch = va, epoch
            best_state = copy.deepcopy(model.state_dict())
            patience = 0
        else:
            patience += 1

        if verbose and ((epoch == 0) or ((epoch + 1) % 20 == 0)):
            print(
                f"epoch {epoch+1:4d} | train {tr:.3e} | val {va:.3e} | lr {opt.param_groups[0]['lr']:.1e}"
            )

        if early_stop and patience >= early_stop:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    logs = dict(
        train_losses=train_losses,
        val_losses=val_losses,
        lrs=lrs,
        best_val_loss=best_val,
        best_epoch=best_epoch,
    )
    return model, logs


In [40]:
# Train model
model, logs = train_probabilistic_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    anchors=anchors_tensor,
    lr=1e-3,
    num_epochs=1000,
    device='cuda',
    weight_decay=0,
    early_stop=0,      # disable early stopping
    verbose=True
)


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [41]:
# Test prediction on training data
model.eval()
X_test = X_train[:100].to('cuda')
anchors_cuda = anchors_tensor.to('cuda')

Y_pred = predict(model, X_test, anchors_cuda, method='expectation')
print(f"Prediction shape: {Y_pred.shape}")
print(f"Expected shape: (100, 16)")


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [42]:
# Check if predictions are valid SE(3)
Y_pred_cpu = Y_pred.cpu()
out = check_se3_flat(Y_pred_cpu, tol_R=1e-2, tol_last=1e-6)
print(f"Predictions SE(3) validity: {out['ok'].float().mean():.4f}")
print(f"Max ortho err: {out['max_orth_err']:.6e}, Max det err: {out['max_det_err']:.6e}, Max last row err: {out['max_last_row_err']:.6e}")


NameError: name 'Y_pred' is not defined