# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AstroStem(nn.Module):
    """
    Instrument-specific 'stem' that projects Science + RMS images onto a
    common ``grid_size x grid_size`` token grid.

    The science image and its RMS (noise) map are concatenated along the
    channel axis. The result is center-cropped so its side length is a
    multiple of ``grid_size``, then patchified with a non-overlapping strided
    convolution. A learned positional embedding is added and the tokens are
    layer-normalised.

    With the default ``grid_size=32``:
        * Rubin:  512 px  -> no crop       -> stride 16 -> 32x32 tokens
        * Euclid: 1050 px -> crop to 1024  -> stride 32 -> 32x32 tokens

    Args:
        in_channels: Number of science channels. The convolution sees
            ``2 * in_channels`` channels because the RMS map is concatenated
            with the science image.
        input_size: Expected (square) spatial size of ``img`` and ``rms``.
        embed_dim: Token embedding dimension.
        grid_size: Side length of the output token grid. The default of 32
            reproduces the original hard-coded behaviour.

    Raises:
        ValueError: if ``input_size`` is smaller than ``grid_size``.
    """

    def __init__(self, in_channels, input_size, embed_dim=256, grid_size=32):
        super().__init__()
        if input_size < grid_size:
            raise ValueError(
                f"input_size ({input_size}) must be >= grid_size ({grid_size})"
            )
        self.input_size = input_size
        self.grid_size = grid_size

        # Largest multiple of grid_size that fits in the input. The remainder
        # is trimmed symmetrically, e.g. 1050 -> 1024 drops 13 px per side.
        # Keeping a power-of-two-friendly size gives an exact patch tiling.
        self.crop_size = (input_size // grid_size) * grid_size
        self.crop_offset = (input_size - self.crop_size) // 2
        stride = self.crop_size // grid_size

        # Science + RMS are concatenated, doubling the input channels.
        self.proj = nn.Conv2d(in_channels * 2, embed_dim, kernel_size=stride, stride=stride)
        self.pos_embed = nn.Parameter(torch.zeros(1, grid_size * grid_size, embed_dim))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, img, rms):
        """
        Tokenize a science image together with its RMS map.

        Args:
            img: Science image, expected (B, C, H, W) with H == W == input_size.
            rms: RMS/noise map with the same batch and spatial shape as ``img``.

        Returns:
            Tensor of shape (B, grid_size * grid_size, embed_dim).

        Raises:
            ValueError: if the spatial size differs from ``input_size``.
        """
        # 1. Uncertainty integration: concatenate flux and noise.
        x = torch.cat([img, rms], dim=1)  # (B, 2*C, H, W)

        # 2. Resolution matching. Reject wrong sizes explicitly. Otherwise the
        #    strided conv would silently drop edge pixels or produce a token
        #    count that mismatches pos_embed.
        spatial = tuple(x.shape[-2:])
        if spatial != (self.input_size, self.input_size):
            raise ValueError(
                f"expected spatial size {(self.input_size, self.input_size)}, got {spatial}"
            )
        if self.crop_size != self.input_size:
            lo = self.crop_offset
            hi = lo + self.crop_size
            x = x[:, :, lo:hi, lo:hi]

        # 3. Tokenization: (B, E, g, g) -> (B, g*g, E)
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        tokens = tokens + self.pos_embed
        return self.norm(tokens)

class Stage1Foundation(nn.Module):
    """
    Roadmap Stage 1: cross-survey representation.

    Learns features shared between Rubin (ground-based) and Euclid
    (space-based) imaging. The model has three parts:

    * One ``AstroStem`` per survey maps Science + RMS images onto a 32x32
      token grid.
    * A single Transformer trunk with shared weights encodes both token
      sequences.
    * A small MLP predictor maps Rubin latents onto Euclid latents
      (JEPA-style cross-instrument prediction).

    Args:
        embed_dim: Token dimension used by stems, trunk and predictor.
        depth: Number of Transformer encoder layers in the shared trunk.
        num_heads: Attention heads per layer (must divide ``embed_dim``).
        rubin_channels: Number of Rubin science channels (default 6).
        euclid_channels: Number of Euclid science channels (default 4).
        dropout: Dropout rate inside each Transformer layer. The default of
            0.1 matches ``nn.TransformerEncoderLayer``'s own default.
    """

    def __init__(
        self,
        embed_dim=256,
        depth=8,
        num_heads=8,
        rubin_channels=6,
        euclid_channels=4,
        dropout=0.1,
    ):
        super().__init__()
        # Instrument-specific stems. Each maps its native resolution to the
        # same token grid, so the trunk sees sequences of equal length.
        self.rubin_stem = AstroStem(in_channels=rubin_channels, input_size=512, embed_dim=embed_dim)
        self.euclid_stem = AstroStem(in_channels=euclid_channels, input_size=1050, embed_dim=embed_dim)

        # Shared trunk: a standard (global self-attention) Transformer encoder.
        # Rubin and Euclid tokens both pass through these same weights.
        self.trunk = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 4,
                dropout=dropout,
                activation='gelu',
                batch_first=True,
            ),
            num_layers=depth,
        )

        # JEPA predictor head: predicts Euclid latents from Rubin latents
        # (the cross-instrument bridge).
        self.predictor = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, r_img, r_rms, e_img, e_rms):
        """
        Encode a Rubin/Euclid pair and predict Euclid latents from Rubin.

        Args:
            r_img, r_rms: Rubin science image and RMS map, as expected by
                ``self.rubin_stem``.
            e_img, e_rms: Euclid science image and RMS map, as expected by
                ``self.euclid_stem``.

        Returns:
            Dict with keys ``"r_latents"``, ``"e_latents"`` and ``"e_pred"``.
            Each value has shape (B, tokens, embed_dim).
        """
        # Map both instruments to the shared token grid.
        r_latents_init = self.rubin_stem(r_img, r_rms)
        e_latents_init = self.euclid_stem(e_img, e_rms)

        # Encode both views with the same trunk weights.
        # NOTE(review): the JEPA target (e_latents) comes from this same
        # online trunk and is only detached in the loss. Many JEPA setups use
        # an EMA target encoder instead; confirm this choice is intentional.
        r_latents = self.trunk(r_latents_init)
        e_latents = self.trunk(e_latents_init)

        # JEPA prediction: infer the space-based latents from the ground-based view.
        e_pred = self.predictor(r_latents)

        return {
            "r_latents": r_latents,
            "e_latents": e_latents,
            "e_pred": e_pred,
        }

def jepa_loss_fn(outputs, e_rms_map, threshold=0.1):
    """
    Latent-space JEPA loss between predicted and target Euclid latents.

    Returns the mean squared error between ``outputs["e_pred"]`` and
    ``outputs["e_latents"]``, averaged over every element. The target latents
    are detached (stop-gradient, following the JEPA/DINO recipe), so gradients
    flow only through the prediction branch.

    Args:
        outputs: Dict with tensors ``"e_pred"`` and ``"e_latents"`` of
            identical shape.
        e_rms_map: Currently unused. Reserved for masking or weighting latent
            patches that correspond to saturated/missing/noisy data in the
            target Euclid tile.
        threshold: Currently unused; reserved for the same masking.

    Returns:
        Scalar tensor holding the unweighted mean squared error.

    Raises:
        ValueError: if prediction and target shapes differ. Without this
            check, ``F.mse_loss`` would silently broadcast them.
    """
    # NOTE(review): no S/N weighting is applied yet; e_rms_map and threshold
    # are accepted only for interface compatibility.
    e_target = outputs["e_latents"].detach()
    e_pred = outputs["e_pred"]
    if e_pred.shape != e_target.shape:
        raise ValueError(
            f"e_pred shape {tuple(e_pred.shape)} != e_latents shape {tuple(e_target.shape)}"
        )
    return F.mse_loss(e_pred, e_target)