In [13]:
import numpy as np
import torch
import torch.nn as nn
import torch
from torch.utils.data import IterableDataset, DataLoader
import torch.nn.functional as F
from modules import Trafo
from modules.galactic2icrs import cartesianGalactic_to_ICRS_torch
from modules import pairwise_min_velocity_diff_torch

# Create synthetic data: a single Gaussian cluster in Cartesian phase space (X, Y, Z, U, V, W)

In [15]:
# Fix the seeds so the synthetic catalogue, the torch-side subset sampling and
# the network initialisation are reproducible on Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

N = 10_000
# Cluster centre in Cartesian phase space
mu = np.array([
    # X    Y   Z
    100, -10,  5,
    # U  V   W
    1,   2,  -1,
])
# Positional dispersion (standard deviation) in X, Y, Z
sx = 1
# Velocity dispersion (standard deviation) in U, V, W
sv = 1
# multivariate_normal expects a *covariance* matrix, so the dispersions
# (standard deviations) are squared on the diagonal.
C = np.diag([sx**2, sx**2, sx**2, sv**2, sv**2, sv**2])

# Samples in the Cartesian space we want to cluster in, i.e. the space the
# autoencoder's latent representation should learn to reproduce.
data = rng.multivariate_normal(mu, C, size=N)
data_icrs = Trafo().cart2spher(data)

In [4]:
class RandomSubset5DDataset(IterableDataset):
    """Stream of random, duplicate-free row subsets of a catalogue.

    Each iteration draws ``subset_size`` distinct row indices (sampling without
    replacement) and yields those rows, so every yielded item is already one
    batch. Intended for ``DataLoader(ds, batch_size=None)``.
    """

    def __init__(self, data5D: torch.Tensor, subset_size: int, iters_per_epoch: int):
        """
        data5D:          (N, 5) tensor of your full catalogue [ra, dec, plx, pmra, pmdec]
        subset_size:     M, how many samples per batch; must satisfy 1 <= M <= N
        iters_per_epoch: number of random subsets to draw each epoch (>= 0)

        Raises ValueError if subset_size or iters_per_epoch is out of range.
        """
        super().__init__()
        n_rows = data5D.size(0)
        # randperm(N)[:M] silently returns fewer than M rows when M > N, and
        # M == 0 would yield empty batches, so reject both up front.
        if not 1 <= subset_size <= n_rows:
            raise ValueError(f"subset_size must be in [1, {n_rows}], got {subset_size}")
        if iters_per_epoch < 0:
            raise ValueError(f"iters_per_epoch must be >= 0, got {iters_per_epoch}")
        self.data = data5D
        self.N = n_rows
        self.M = subset_size
        self.iters = iters_per_epoch

    def __len__(self):
        """Total number of subsets yielded per epoch (summed over all workers)."""
        return self.iters

    def __iter__(self):
        n_iters = self.iters
        worker = torch.utils.data.get_worker_info()
        if worker is not None:
            # Each DataLoader worker gets its own copy of the dataset; without
            # this split every worker would replay the whole epoch and the
            # loader would yield num_workers * iters_per_epoch batches.
            per_worker, extra = divmod(self.iters, worker.num_workers)
            n_iters = per_worker + (1 if worker.id < extra else 0)
        for _ in range(n_iters):
            # sample M unique indices without replacement
            idx = torch.randperm(self.N, device=self.data.device)[:self.M]
            yield self.data[idx]  # shape (M, 5)

# ——————————————
# Usage example:

# Batch geometry: each yielded batch holds M distinct sources, and one epoch
# consists of Iters independently drawn batches.
M = 512        # samples per batch
Iters = 200    # subsets per epoch

# The full 5D catalogue as a float32 tensor, one row per source
full5D = torch.as_tensor(
    data_icrs[['ra', 'dec', 'parallax', 'pmra', 'pmdec']].values,
    dtype=torch.float32,
)

# Wrap the random-subset stream in a DataLoader; batch_size=None turns off
# automatic batching, so every item the dataset yields is already one batch.
ds = RandomSubset5DDataset(full5D, M, Iters)
loader = DataLoader(ds, batch_size=None)

# Create the semi-analytic autoencoder (learned MLP encoder + analytic decoder)

In [32]:
def build_mlp(input_dim: int, hidden_dims: tuple, output_dim: int,
              activation=nn.ReLU, inplace=True) -> nn.Sequential:
    """
    Build a feed-forward MLP: input_dim → *hidden_dims → output_dim.

    A fresh ``activation()`` module is inserted after every Linear layer except
    the last, so the final layer is a plain Linear with no non-linearity.

    Parameters
    ----------
    input_dim : int
        Number of input features.
    hidden_dims : tuple of int
        Widths of the hidden layers, in order; may be empty (single Linear).
    output_dim : int
        Number of output features.
    activation : callable returning an nn.Module, default nn.ReLU
        Activation class/factory, called once per hidden layer.
    inplace : bool, default True
        Passed as ``inplace=`` only if ``activation`` accepts that keyword
        (e.g. ReLU, LeakyReLU, ELU). Activations without it (e.g. Tanh, GELU,
        Sigmoid) are constructed with no arguments instead of raising TypeError.

    Returns
    -------
    nn.Sequential
    """
    import inspect

    try:
        accepts_inplace = 'inplace' in inspect.signature(activation).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: fall back to always passing the keyword.
        accepts_inplace = True
    act_kwargs = {'inplace': inplace} if accepts_inplace else {}

    dims = [input_dim, *hidden_dims, output_dim]
    n_linear = len(dims) - 1
    layers = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        # add activation after every layer except the last
        if i < n_linear - 1:
            layers.append(activation(**act_kwargs))
    return nn.Sequential(*layers)


class SemiAnalyticAutoencoder(nn.Module):
    """Autoencoder with a learned MLP encoder and an analytic (fixed) decoder.

    The encoder maps each row of the pairwise matrix produced by
    ``pairwise_min_velocity_diff_torch`` to a 6-D latent vector. The decoder
    converts those latents with ``cartesianGalactic_to_ICRS_torch`` and
    recomputes the pairwise matrix from the result, so the encoder holds all
    trainable parameters.
    """

    def __init__(self, M, hidden_dims=(128, 64)):
        """
        M : int
            number of samples in each random batch (so diff-matrices are M×M);
            this is also the encoder's input width
        hidden_dims : tuple of int
            widths of the hidden layers of the row-wise MLP encoder
            (any number of layers)
        """
        super().__init__()
        self.M = M

        # encoder: M → *hidden_dims → 6, applied independently to each row
        self.encoder = build_mlp(
            input_dim=M,
            hidden_dims=hidden_dims,
            output_dim=6,
        )

    def forward(self, batch5D):
        """
        batch5D : Tensor or array-like of shape (B, M, 5)
            last-dim = [ra, dec, parallax, pmra, pmdec]; converted to the
            encoder's dtype and device
        Returns
        -------
        recon        : Tensor (B, M, M)
            pairwise matrix recomputed from the decoded latents
        diff_target  : Tensor (B, M, M)
            pairwise matrix of the input batch
        astro_pred5  : Tensor (B, M, 5)
            first 5 columns of the decoded ICRS values
        Raises
        ------
        ValueError
            if the input is not 3-D with last dim 5, or its M differs from self.M
        """
        # Match the encoder's parameters; otherwise e.g. float64 / numpy input
        # makes nn.Linear fail with a dtype mismatch. No copy is made when the
        # input already has the right dtype and device.
        param = next(self.encoder.parameters())
        x5 = torch.as_tensor(batch5D, dtype=param.dtype, device=param.device)
        if x5.dim() != 3 or x5.size(-1) != 5:
            raise ValueError(f"Expected (B, {self.M}, 5), got {tuple(x5.shape)}")
        B, M, _ = x5.shape
        if M != self.M:
            raise ValueError(f"Model built for M={self.M}, got batch size {M}")

        # 1) analytic ground truth
        #    NOTE(review): original note says pairwise_min_velocity_diff_torch
        #    ignores a 6th column, so 5 input columns suffice — confirm in module
        diff_target = pairwise_min_velocity_diff_torch(x5)  # (B, M, M)

        # 2) encode each row of diff_target → 6D
        rows = diff_target.reshape(B * M, M)  # (B*M, M)
        lat = self.encoder(rows)              # (B*M, 6)
        lat = lat.view(B, M, 6)               # (B, M, 6)

        # 3) analytic decode back to M×M
        astro_icrs = cartesianGalactic_to_ICRS_torch(lat)         # (B, M, 6)
        # keep the first 5 dims to regress back to x5; assumes the decoder's
        # column order matches [ra, dec, parallax, pmra, pmdec] — TODO confirm
        astro_pred5 = astro_icrs[..., :5]                         # (B, M, 5)
        recon = pairwise_min_velocity_diff_torch(astro_icrs)      # (B, M, M)

        return recon, diff_target, astro_pred5

In [33]:
device = torch.device('cpu')
# Reuse the batch size M from the data-loader cell so model and loader cannot
# silently disagree (the model raises if they do).
model = SemiAnalyticAutoencoder(M=M, hidden_dims=(1024, 512, 256, 128, 64, 32)).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 500
beta = 0.1  # weight of the astrometric-feature loss relative to the matrix loss
model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    n_batches = 0
    for batch5D in loader:                           # each batch5D: (M, 5)
        batch5D = batch5D.to(device).unsqueeze(0)    # → (1, M, 5)

        recon, target, astro_pred5 = model(batch5D)  # (1, M, M), (1, M, M), (1, M, 5)
        # 1) pairwise-matrix MSE
        loss_mat = F.mse_loss(recon, target)
        # 2) astro-features MSE: decoded features vs the original input
        loss_astro = F.mse_loss(astro_pred5, batch5D)
        # total loss
        loss = loss_mat + beta * loss_astro

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()
        n_batches += 1
    # average over the batches actually seen, not the nominal Iters
    print(f"Epoch {epoch}: avg loss = {total_loss / max(n_batches, 1):.6f}")

Epoch 0: avg loss = 7276.502719
Epoch 1: avg loss = 1.201085
Epoch 2: avg loss = 1.198250
Epoch 3: avg loss = 1.178690
Epoch 4: avg loss = 1.144219
Epoch 5: avg loss = 1.118202
Epoch 6: avg loss = 1.116410
Epoch 7: avg loss = 1.121658
Epoch 8: avg loss = 1.118961
Epoch 9: avg loss = 1.114386
Epoch 10: avg loss = 1.114094
Epoch 11: avg loss = 1.115171
Epoch 12: avg loss = 1.115516
Epoch 13: avg loss = 1.120365
Epoch 14: avg loss = 1.112367
Epoch 15: avg loss = 1.114928
Epoch 16: avg loss = 1.111764
Epoch 17: avg loss = 1.114697
Epoch 18: avg loss = 1.114460
Epoch 19: avg loss = 1.114372
Epoch 20: avg loss = 1.110418
Epoch 21: avg loss = 1.105784
Epoch 22: avg loss = 1.112582
Epoch 23: avg loss = 1.111187
Epoch 24: avg loss = 1.109915
Epoch 25: avg loss = 1.109958
Epoch 26: avg loss = 1.102282
Epoch 27: avg loss = 1.108710
Epoch 28: avg loss = 1.107262
Epoch 29: avg loss = 1.097012
Epoch 30: avg loss = 1.087760
Epoch 31: avg loss = 1.067882
Epoch 32: avg loss = 1.034269
Epoch 33: avg los

KeyboardInterrupt: 