In [1]:
import ast
import pandas as pd
import re
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torch.nn.functional as F
import torch.nn as nn
from pytorch_lightning import Trainer
from sklearn.model_selection import train_test_split
from pytorch_lightning.callbacks import ModelCheckpoint

In [2]:
def read_knot_data(filename):
    records = []
    
    with open(filename, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            # 1) Convert curly braces to square brackets
            line = line.replace('{', '[').replace('}', ']')
            
            # 2) Convert Mathematica *^ floats to Python e floats
            #    Do each sign separately, then default to no sign
            line = line.replace('*^-', 'e-')
            line = line.replace('*^+', 'e+')
            line = line.replace('*^', 'e')
            
            # 3) literal_eval
            data = ast.literal_eval(line)
            records.append(data)
    
    # Create DataFrame
    df = pd.DataFrame(records, columns=["knot_name", "J_zeros", "C_zeros", "volume"])
    return df

In [3]:
df_knots=read_knot_data("nameJ2zerosJ3zerosvol.txt")

In [4]:
def sort_complex_pairs(pair_list):
    """
    Given a list of [real, imag] pairs, 
    sort them first by real component, then by imaginary component.
    """
    if not pair_list:
        return pair_list
    # Use Python's built-in 'sorted' with a tuple (real, imag) as the key
    return sorted(pair_list, key=lambda pair: (pair[0], pair[1]))

In [5]:
df_knots['J_zeros'] = df_knots['J_zeros'].apply(sort_complex_pairs)
df_knots['C_zeros'] = df_knots['C_zeros'].apply(sort_complex_pairs)
df_knots.head()

Unnamed: 0,knot_name,J_zeros,C_zeros,volume
0,4_1,"[[-0.3090169943749479, -0.9510565162951543], [...","[[-0.9713385871308796, -0.2377001244227466], [...",2.029883
1,5_2,"[[-0.33911004330436717, -0.8223754344096812], ...","[[-1.0758734219330786, -0.36954428316759746], ...",2.828122
2,6_1,"[[-0.40662961271472337, -0.7490398002735331], ...","[[-1.0198202123407993, -0.3943807100072422], [...",3.163963
3,6_2,"[[-0.49883183995589636, -1.001302556337741], [...","[[-1.0522537067062143, -0.36270316966133176], ...",4.400833
4,6_3,"[[-0.40096886790241876, -0.9160916804409108], ...","[[-1.1795647308385209, 0], [-0.914211156015100...",5.693021


In [6]:
class KnotsDataset(Dataset):
    def __init__(self, df_knots, input_col='C_zeros', target_col='J_zeros'):
        """
        df_knots: a pandas DataFrame.
        input_col: name of the column containing the input list of [re, im] pairs.
        target_col: name of the column containing the target list of [re, im] pairs.
        """
        self.df_knots = df_knots.reset_index(drop=True)
        self.input_col = input_col
        self.target_col = target_col

    def __len__(self):
        return len(self.df_knots)

    def __getitem__(self, idx):
        """
        Return (x, y) where x is the list of [re, im] for J_roots_pos_imag,
        and y is the list of [re, im] for C_roots_pos_imag.
        """
        row = self.df_knots.iloc[idx]
        x = row[self.input_col]  # a list of [real, imag]
        y = row[self.target_col] # a list of [real, imag]
        return x, y

In [7]:
def knots_collate_fn(batch):
    """
    batch: list of (x, y), where:
      x is a list of [re, im] floats,
      y is a list of [re, im] floats.

    Returns:
      x_tensor: (batch_size, max_len, 2) float
      x_mask:   (batch_size, max_len) bool  (True=valid, False=pad)
      y_tensor: (batch_size, max_len, 3) float
                where [re, im, 0] = valid
                      [0,  0,  1] = padded
    """
    # Separate out all x and y
    x_list = [item[0] for item in batch]  # list of lists of [re, im]
    y_list = [item[1] for item in batch]

    batch_size = len(batch)
    # Find the maximum lengths
    max_len_x = max((len(x_seq) for x_seq in x_list), default=0)
    max_len_y = max((len(y_seq) for y_seq in y_list), default=0)
    max_len   = max(max_len_x, max_len_y)  # unify so we can compare 1:1

    # Create tensors
    x_tensor = torch.zeros((batch_size, max_len, 2), dtype=torch.float)
    x_mask   = torch.zeros((batch_size, max_len), dtype=torch.bool)

    # For y, we have an extra dimension for the "valid/pad" flag
    y_tensor = torch.zeros((batch_size, max_len, 3), dtype=torch.float)

    # Fill each row
    for i, (x_seq, y_seq) in enumerate(zip(x_list, y_list)):
        # --- Inputs ---
        Lx = len(x_seq)
        for j in range(Lx):
            re_val, im_val = x_seq[j]
            x_tensor[i, j, 0] = re_val
            x_tensor[i, j, 1] = im_val
            x_mask[i, j] = True  # Mark valid

        # --- Targets ---
        Ly = len(y_seq)
        for j in range(Ly):
            re_val, im_val = y_seq[j]
            y_tensor[i, j, 0] = re_val
            y_tensor[i, j, 1] = im_val
            y_tensor[i, j, 2] = 0.0  # 0 => valid token

        # Pad the remainder of y with an indicator 1.0 => padded
        for j in range(Ly, max_len):
            y_tensor[i, j, 2] = 1.0  # 1 => padded

    return x_tensor, x_mask, y_tensor

In [24]:
train_df, valtest_df = train_test_split(df_knots, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(valtest_df, test_size=0.5, random_state=42)

train_dataset = KnotsDataset(train_df, input_col='C_zeros', target_col='J_zeros')
val_dataset = KnotsDataset(val_df, input_col='C_zeros', target_col='J_zeros')
test_dataset = KnotsDataset(test_df, input_col='C_zeros', target_col='J_zeros')

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=knots_collate_fn)#, multiprocessing_context="forkserver", persistent_workers=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=knots_collate_fn)#, multiprocessing_context="forkserver", persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=knots_collate_fn)#, multiprocessing_context="forkserver", persistent_workers=True)

In [25]:
num_knots = len(train_loader.dataset)
print("Number of knots:", num_knots)

for x_tensor, x_mask, y_tensor in train_loader:
    print("x_tensor shape:", x_tensor.shape)  # (batch_size, max_len, 2)
    print("y_tensor shape:", y_tensor.shape)  # (batch_size, max_len, 3)
    break

Number of knots: 141836
x_tensor shape: torch.Size([32, 67, 2])
y_tensor shape: torch.Size([32, 67, 3])


In [26]:
# Chamfer Distance Loss (simple version)
def chamfer_loss(pred, target):
    # pred, target: (B, N, 2) tensors
    D = torch.cdist(pred, target)  # (B, N, N)
    loss1 = D.min(dim=2)[0].mean(dim=1)  # pred -> target
    loss2 = D.min(dim=1)[0].mean(dim=1)  # target -> pred
    return (loss1 + loss2).mean()

In [28]:
# DeepSets model
class DeepSets(nn.Module):
    def __init__(self, in_dim=2, hidden_dim=128, num_outputs=49):
        super().__init__()
        self.phi = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.rho = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * num_outputs)
        )
        self.num_outputs = num_outputs

    def forward(self, x, mask):
        # x: (B, N, 2), mask: (B, N) bool
        x = self.phi(x)            # (B, N, H)
        mask = mask.unsqueeze(-1)  # (B, N, 1) for broadcasting
        x = x * mask               # zero out padded elements
        x = x.sum(dim=1)           # sum only valid elements
        x = self.rho(x)            # (B, 2*num_outputs)
        return x.view(-1, self.num_outputs, 2)


In [None]:
epochs = 10

model = DeepSets(in_dim=2, hidden_dim=128, num_outputs=49)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(epochs):
    model.train()
    train_loss = 0.0

    for x_tensor, x_mask, y_tensor in train_loader:
        # y_tensor: (B, max_len, 3), need to mask out padded entries
        y_valid = y_tensor[:, :, 2] == 0  # shape (B, max_len)
        y_real  = y_tensor[:, :, :2]      # shape (B, max_len, 2)

        # Get valid target positions only
        pred = model(x_tensor, x_mask)
        y_valid = (y_tensor[:, :, 2] == 0)      # (B, max_len) bool mask
        y_real = y_tensor[:, :, :2].float()    # (B, max_len, 2)

        # Zero out padded entries in target
        padded_target = y_real * y_valid.unsqueeze(-1).float()  # (B, max_len, 2)

        loss = chamfer_loss(pred, padded_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)
    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}")



Epoch 0: Train Loss = 0.1983
Epoch 1: Train Loss = 0.1580
Epoch 2: Train Loss = 0.1437
Epoch 3: Train Loss = 0.1365
Epoch 4: Train Loss = 0.1319
