### Load Data

In [32]:
import pandas as pd

In [33]:
# Read the full table from disk; `system_number` is the only column this cell relies on by name.
data_df = pd.read_csv('../data/Easier Dataset.csv')

# One 2-D NumPy array per distinct `system_number`: that group's rows with
# the grouping column removed. Groups come back in groupby's (sorted-key) order.
matrices = [
    system_rows.drop(columns='system_number').to_numpy()
    for _system_id, system_rows in data_df.groupby('system_number')
]

### LSTM-VAE

In [34]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import torch.nn.functional as F

In [35]:
class PlanetarySystemsDataset(Dataset):
    """Map-style dataset over a list of per-system matrices.

    Each item is returned as a fresh ``float32`` tensor built from the
    matrix stored at the requested index.
    """

    def __init__(self, matrices):
        # Keep a reference only; conversion to tensors happens lazily per item.
        self.matrices = matrices

    def __len__(self):
        """Number of stored matrices."""
        n_systems = len(self.matrices)
        return n_systems

    def __getitem__(self, idx):
        """Return the matrix at ``idx`` as a ``torch.float32`` tensor (copied)."""
        system_matrix = self.matrices[idx]
        return torch.tensor(system_matrix, dtype=torch.float32)

def collate_fn(batch):
    """Layer-normalise, then zero-pad, a list of variable-length sequences.

    Args:
        batch: list of 2-D float tensors, each of shape ``(seq_len_i, n_features)``.

    Returns:
        padded: ``(len(batch), max_seq_len, n_features)`` tensor; rows past a
            sequence's true length are zeros.
        lengths: 1-D ``torch.long`` tensor holding each sequence's true length.
    """
    # Normalize each row across its feature axis before padding.
    # F.layer_norm without weight/bias yields the same values as a freshly
    # initialised nn.LayerNorm (weight=1, bias=0, eps=1e-5), but creates no
    # learnable parameters -- so the batch carries no autograd graph into
    # throw-away modules and stays a plain data tensor.
    normed = [F.layer_norm(seq, (seq.shape[1],)) for seq in batch]
    lengths = [seq.shape[0] for seq in normed]
    padded = pad_sequence(normed, batch_first=True)
    return padded, torch.tensor(lengths, dtype=torch.long)

In [36]:
class LSTMVAEncoder(nn.Module):
    """LSTM encoder producing latent Gaussian parameters and length estimates.

    A 2-layer LSTM reads a padded batch of sequences (packed with their true
    lengths). The final hidden state of the top layer feeds three heads:
    ``fc_mu`` and ``fc_logvar`` for the latent distribution, and
    ``length_head`` for a scalar length estimate. A separate MLP,
    ``z_to_length``, estimates the length from a latent vector alone.
    """

    def __init__(self, input_dim, hidden_dim=128, latent_dim=20):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=2, batch_first=True)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.length_norm = nn.LayerNorm(hidden_dim)
        # Both length regressors share one architecture and differ only in input width.
        self.length_head = self._length_regressor(hidden_dim, hidden_dim)
        self.z_to_length = self._length_regressor(latent_dim, hidden_dim)

    @staticmethod
    def _length_regressor(in_features, hidden_dim):
        """Build the MLP mapping ``in_features`` -> 1 through widths hidden, hidden/2, hidden/4."""
        half = hidden_dim // 2
        quarter = hidden_dim // 4
        return nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, half),
            nn.LayerNorm(half),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(half, quarter),
            nn.LayerNorm(quarter),
            nn.ReLU(),
            nn.Linear(quarter, 1),
        )

    def forward(self, x, lengths):
        """Encode a padded batch.

        Args:
            x: padded input, batch-first.
            lengths: true length of every sequence in ``x`` (any order).

        Returns:
            Tuple ``(mu, logvar, length_pred)``; ``length_pred`` has its
            trailing singleton dimension squeezed away.
        """
        packed_input = pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False
        )
        _, (hidden_states, _) = self.lstm(packed_input)
        top_hidden = hidden_states[-1]  # last LSTM layer
        length_pred = self.length_head(self.length_norm(top_hidden)).squeeze(-1)
        return self.fc_mu(top_hidden), self.fc_logvar(top_hidden), length_pred

    def predict_length_from_z(self, z):
        """Estimate a sequence length from latent vectors ``z`` via ``z_to_length``."""
        raw_length = self.z_to_length(z)
        return raw_length.squeeze(-1)

class LSTMVADecoder(nn.Module):
    """LSTM decoder that unrolls one latent vector into a sequence.

    The same latent vector is fed to a 2-layer LSTM at every time step, and a
    linear layer projects each hidden state to ``output_dim`` features.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.lstm = nn.LSTM(latent_dim, hidden_dim, num_layers=2, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, z, seq_len):
        """Decode ``z`` of shape (batch, latent_dim) into (batch, seq_len, output_dim)."""
        # Insert a time axis, then copy the latent vector along it seq_len times.
        latent_steps = z.unsqueeze(1)
        lstm_input = latent_steps.repeat(1, seq_len, 1)
        hidden_seq, _final_state = self.lstm(lstm_input)
        return self.fc(hidden_seq)

In [39]:
# Prepare dataset and dataloader
# Fail early with a clear message instead of an IndexError on matrices[0] below.
assert len(matrices) > 0, "no planetary systems were loaded"
dataset = PlanetarySystemsDataset(matrices)
BATCH_SIZE = 64
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

ENCODING_SIZE = 20      # width of the latent vector z
LSTM_HIDDEN_SIZE = 128  # hidden width of both LSTMs
input_dim = matrices[0].shape[1]  # features per row, taken from the first matrix

# Instantiate models
encoder = LSTMVAEncoder(input_dim=input_dim, hidden_dim=LSTM_HIDDEN_SIZE, latent_dim=ENCODING_SIZE)
# The decoder consumes latent vectors produced by the encoder, so its
# latent_dim must be ENCODING_SIZE (it was previously set to input_dim).
decoder = LSTMVADecoder(latent_dim=ENCODING_SIZE, hidden_dim=LSTM_HIDDEN_SIZE, output_dim=input_dim)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

In [40]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = matrices[0].shape[1]
HIDDEN_DIM = 128
LATENT_DIM = 20
KL_WEIGHT = 1e-3  # weight of the KL term in the total loss
SEED = 0

# Seed before building the models so weight initialisation, batch shuffling
# and the reparameterisation noise are repeatable across runs.
torch.manual_seed(SEED)

# (Re)build the models on `device`; these are the instances that get trained.
encoder = LSTMVAEncoder(input_dim, HIDDEN_DIM, LATENT_DIM).to(device)
decoder = LSTMVADecoder(LATENT_DIM, HIDDEN_DIM, input_dim).to(device)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
# NOTE: `criterion` is not used by the loop below -- every loss term is
# computed explicitly. It is kept only so the name stays defined.
criterion = nn.MSELoss(reduction='sum')

EPOCHS = 10
for epoch in range(EPOCHS):
    encoder.train()
    decoder.train()
    total_loss = 0  # sum of per-batch losses over the epoch (not an average)
    for padded, lengths in loader:
        padded = padded.to(device)
        # `lengths` itself stays on the CPU: pack_padded_sequence (inside the
        # encoder) requires CPU lengths. Move one copy to `device` for the
        # losses and the mask instead of transferring it repeatedly.
        lengths_dev = lengths.to(device)
        target_lengths = lengths_dev.float()
        batch_size, max_len = padded.size(0), padded.size(1)

        optimizer.zero_grad()
        mu, logvar, length_pred = encoder(padded, lengths)

        # Reparameterisation: z = mu + sigma * eps with eps ~ N(0, I).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std

        # Two length regressions: one from the sampled z, one from the LSTM state.
        length_pred_from_z = encoder.predict_length_from_z(z)
        length_loss_z = F.mse_loss(length_pred_from_z, target_lengths)
        length_loss = F.mse_loss(length_pred, target_lengths)

        recon = decoder(z, max_len)
        # Mask out padded values for loss: True where the time step is real.
        mask = torch.arange(max_len, device=device)[None, :] < lengths_dev[:, None]
        recon_loss = ((recon - padded) ** 2).sum(dim=2)
        # Average the summed-over-features squared error over real time steps only.
        recon_loss = (recon_loss * mask).sum() / mask.sum()
        # KL divergence to the standard normal prior, averaged over the batch.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size

        loss = recon_loss + KL_WEIGHT * kld + length_loss + length_loss_z
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

Epoch 1, Loss: 98692.3960
Epoch 2, Loss: 10319.4348
Epoch 3, Loss: 555.8332
Epoch 4, Loss: 291.1423
Epoch 5, Loss: 248.4918
Epoch 6, Loss: 233.7783
Epoch 7, Loss: 216.9952
Epoch 8, Loss: 205.2572
Epoch 9, Loss: 196.1170
Epoch 10, Loss: 188.3839


In [41]:
# Put both networks in inference mode so the dropout layers are inactive.
decoder.eval()
encoder.eval()
with torch.no_grad():
    # Draw a single latent vector from the standard-normal prior.
    z = torch.randn(1, LATENT_DIM).to(device)
    # Let the latent->length MLP choose how many rows to emit, rounded and
    # limited to the range [1, 50].
    length_pred = encoder.predict_length_from_z(z)
    seq_len = int(length_pred.round().clamp(min=1, max=50).item())
    # Decode and bring the result back to the host as a NumPy array.
    # NOTE(review): training targets were layer-normalised in collate_fn, so
    # these values are presumably in that normalised space, not raw units.
    generated = decoder(z, seq_len).cpu().numpy()

In [42]:
# Display the sampled array; its shape is (1, seq_len, n_features) since one latent vector was decoded.
generated

array([[[-0.7783725 ,  1.2636299 , -0.48817095],
        [-0.6269555 ,  1.2938346 , -0.69207555]]], dtype=float32)