In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Chamfer Distance Loss (simple version)
def chamfer_loss(pred, target):
    """Symmetric Chamfer distance between two batches of point sets.

    For every predicted point the distance to its nearest target point is
    taken, and vice versa; both directions are averaged over their own set,
    added together, and finally averaged over the batch.

    Args:
        pred: Tensor of shape (B, N, 2) with the predicted points.
        target: Tensor of shape (B, M, 2) with the reference points
            (M may differ from N).

    Returns:
        A zero-dimensional tensor holding the batch-averaged loss.
    """
    # Pairwise Euclidean distances (cdist default p=2): (B, N, M).
    pairwise = torch.cdist(pred, target)

    # Nearest target for each predicted point -> (B, N).
    nearest_to_pred = pairwise.min(dim=2).values
    # Nearest predicted point for each target -> (B, M).
    nearest_to_target = pairwise.min(dim=1).values

    # Per-sample loss: mean over each set, both directions summed -> (B,).
    per_sample = nearest_to_pred.mean(dim=1) + nearest_to_target.mean(dim=1)
    return per_sample.mean()

In [4]:
# DeepSets model
class DeepSets(nn.Module):
    """Permutation-invariant set-to-set regressor in the DeepSets style.

    Every input element is embedded independently by ``phi``, the embeddings
    are sum-pooled over the set dimension, and ``rho`` decodes the pooled
    vector into ``num_outputs`` two-dimensional points.

    Args:
        in_dim: Number of features per input element (2 for (real, imag)).
        hidden_dim: Width of the hidden layers in ``phi`` and ``rho``.
        num_outputs: Number of 2-D points predicted for each input set.
    """

    def __init__(self, in_dim=2, hidden_dim=128, num_outputs=10):
        super().__init__()
        # Per-element encoder, applied independently to each set member.
        self.phi = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Decoder for the pooled set embedding; two coordinates per output point.
        self.rho = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * num_outputs)
        )
        self.num_outputs = num_outputs

    def forward(self, x, mask=None):
        """Map a batch of input sets to a batch of predicted point sets.

        Args:
            x: Tensor of shape (B, N, in_dim), e.g. (real, imag) root sets.
            mask: Optional tensor of shape (B, N). Non-zero / ``True`` marks
                elements that belong to the set; zero / ``False`` marks
                elements (such as padding) that must not contribute to the
                pooled embedding. ``None`` (default) pools over all N rows.

        Returns:
            Tensor of shape (B, num_outputs, 2).

        Raises:
            ValueError: If ``mask`` is given and its shape is not (B, N).
        """
        x = self.phi(x)         # (B, N, H)
        if mask is not None:
            if mask.shape != x.shape[:-1]:
                raise ValueError(
                    f"mask shape {tuple(mask.shape)} does not match "
                    f"input set shape {tuple(x.shape[:-1])}"
                )
            # phi maps an all-zero row to a non-zero embedding (biases), so
            # ignored rows have to be zeroed explicitly before the sum.
            x = x * mask.unsqueeze(-1).to(x.dtype)
        x = x.sum(dim=1)        # (B, H)
        x = self.rho(x)         # (B, 2*num_outputs)
        return x.view(-1, self.num_outputs, 2)  # (B, M, 2)

In [None]:
# --- Configuration -----------------------------------------------------------
SEED = 0              # fixes the random data and weight init so Run All is reproducible
B = 8   # batch size
N = 10  # number of input roots
M = 10  # number of output roots
NUM_EPOCHS = 100      # number of full-batch optimisation steps
LOG_EVERY = 10        # print the loss every LOG_EVERY epochs
LEARNING_RATE = 1e-3  # Adam step size

# Seed before any random tensor or layer is created.
torch.manual_seed(SEED)

# Example input and target data, to be replaced with the padded data from JonesPloynomialTransformer.txt
# TODO: once real padded data is loaded, make sure padding entries are excluded
# from the sum-pooling and from the Chamfer loss.
input_roots = torch.randn(B, N, 2)     # zeros of Jones 2 poly
target_roots = torch.randn(B, M, 2)    # zeros of Jones 3 poly

model = DeepSets(in_dim=2, hidden_dim=128, num_outputs=M)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop (full batch: the same B sets are used at every step)
model.train()  # mode never changes inside the loop, so set it once
for epoch in range(NUM_EPOCHS):
    pred_roots = model(input_roots)
    loss = chamfer_loss(pred_roots, target_roots)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

Epoch 0: Loss = 1.2391
Epoch 10: Loss = 0.8649
Epoch 20: Loss = 0.7368
Epoch 30: Loss = 0.6444
Epoch 40: Loss = 0.5678
Epoch 50: Loss = 0.5200
Epoch 60: Loss = 0.4763
Epoch 70: Loss = 0.4341
Epoch 80: Loss = 0.4233
Epoch 90: Loss = 0.3947
