In [1]:
# test_stage1_setup.py
#
# Quick test to verify Stage 1 is working with your data

import torch
import numpy as np
from JAISP_dataset import make_loader
from stage1_jepa_foundation import JAISPFoundation

def test_data_loading(rubin_dir, euclid_dir):
    """Build the paired tile loader and print a summary of its first batch.

    Args:
        rubin_dir: Directory of Rubin tiles, forwarded to ``make_loader``.
        euclid_dir: Directory of Euclid tiles, forwarded to ``make_loader``.

    Returns:
        The first batch produced by the loader. The keys read here are
        ``x_rubin``, ``x_euclid``, ``rms_rubin``, ``rms_euclid`` and
        ``mask_euclid``, each indexed per tile.
    """
    banner = "=" * 60
    print(banner)
    print("TESTING DATA LOADER")
    print(banner)

    loader_options = dict(
        rubin_dir=rubin_dir,
        euclid_dir=euclid_dir,
        batch_size=2,
        shuffle=True,
        num_workers=0,  # keep loading in the main process for easier debugging
        persistent_workers=False,
    )
    dataset, dataloader = make_loader(**loader_options)

    print(f"✓ Dataset loaded: {len(dataset)} tiles")

    # Only the first batch is needed for the checks
    batch = next(iter(dataloader))

    print(f"\n✓ Batch structure:")
    n_tiles = len(batch['x_rubin'])
    print(f"  x_rubin:  {n_tiles} tiles")
    print(f"  x_euclid: {len(batch['x_euclid'])} tiles")

    # (printed label, batch key) for every per-tile entry whose shape is shown
    shape_fields = (
        ("Rubin:  ", 'x_rubin'),
        ("Euclid: ", 'x_euclid'),
        ("Rubin RMS:  ", 'rms_rubin'),
        ("Euclid RMS: ", 'rms_euclid'),
    )
    for tile_idx in range(n_tiles):
        print(f"\n  Tile {tile_idx}:")
        for label, key in shape_fields:
            print(f"    {label}{batch[key][tile_idx].shape}")
        # The mask is printed by value rather than by shape
        print(f"    Euclid mask: {batch['mask_euclid'][tile_idx]}")

    return batch


def test_patch_extraction(batch):
    """Extract patches from both instruments and print basic sanity checks.

    A ``PatchExtractor(patch_size=128, n_patches=4)`` is applied to the
    Rubin and Euclid entries of ``batch`` (image plus RMS map), after which
    the resulting shapes, finiteness and mean/std are printed.

    Args:
        batch: Mapping with the keys ``x_rubin``, ``rms_rubin``,
            ``x_euclid`` and ``rms_euclid``.

    Returns:
        Tuple ``(rubin_patches, rubin_weights, euclid_patches,
        euclid_weights)`` exactly as returned by the extractor.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TESTING PATCH EXTRACTION")
    print(rule)

    from stage1_jepa_foundation import PatchExtractor

    extractor = PatchExtractor(patch_size=128, n_patches=4)

    # Rubin
    print("\nExtracting Rubin patches...")
    r_patches, r_weights = extractor(batch['x_rubin'], batch['rms_rubin'])
    print(f"✓ Rubin patches: {r_patches.shape}")
    print(f"✓ Rubin weights: {r_weights.shape}")
    print(f"  Expected: ({len(batch['x_rubin']) * 4}, 6, 128, 128)")

    # Euclid
    print("\nExtracting Euclid patches...")
    e_patches, e_weights = extractor(batch['x_euclid'], batch['rms_euclid'])
    print(f"✓ Euclid patches: {e_patches.shape}")
    print(f"✓ Euclid weights: {e_weights.shape}")
    print(f"  Expected: ({len(batch['x_euclid']) * 4}, 4, 128, 128)")

    # NaN/Inf screening, in the same order for patches and then weights
    print("\nData quality checks:")
    finite_checks = (
        ("Rubin patches", r_patches),
        ("Euclid patches", e_patches),
        ("Rubin weights", r_weights),
        ("Euclid weights", e_weights),
    )
    for label, tensor in finite_checks:
        print(f"  {label} - finite: {torch.isfinite(tensor).all()}")

    # Summary statistics of the pixel values
    print("\nPatch statistics:")
    r_mean, r_std = r_patches.mean(), r_patches.std()
    print(f"  Rubin mean:  {r_mean:.4f}, std: {r_std:.4f}")
    e_mean, e_std = e_patches.mean(), e_patches.std()
    print(f"  Euclid mean: {e_mean:.4f}, std: {e_std:.4f}")

    return r_patches, r_weights, e_patches, e_weights


def test_encoder(patches, weights, encoder_type='rubin'):
    """Run a freshly initialised ViT encoder on extracted patches.

    Builds a ``ViTEncoder`` (ViT patch size 16, embed dim 384, depth 6,
    6 heads), runs a single gradient-free forward pass in evaluation mode
    and prints the output shape, finiteness and mean/std.

    Args:
        patches: Patch tensor passed as the encoder's first argument.
        weights: Weight tensor passed as the encoder's second argument.
        encoder_type: ``'rubin'`` selects 6 input channels; any other value
            selects 4 (the Euclid configuration).

    Returns:
        The encoder output for ``patches``.
    """
    print("\n" + "="*60)
    print(f"TESTING {encoder_type.upper()} ENCODER")
    print("="*60)

    from stage1_jepa_foundation import ViTEncoder

    in_channels = 6 if encoder_type == 'rubin' else 4
    encoder = ViTEncoder(
        in_channels=in_channels,
        patch_size=16,
        embed_dim=384,
        depth=6,
        num_heads=6
    )
    # Modules start in training mode. Switch to eval so that any dropout /
    # normalisation layers behave as they would at inference time, which is
    # what this gradient-free check is meant to exercise.
    # NOTE(review): assumes ViTEncoder is a torch.nn.Module — confirm.
    encoder.eval()

    print(f"Input shape: {patches.shape}")
    print(f"Weights shape: {weights.shape}")

    with torch.no_grad():
        features = encoder(patches, weights)

    print(f"✓ Output features: {features.shape}")
    print(f"  Expected: ({patches.shape[0]}, 384)")
    print(f"  Features finite: {torch.isfinite(features).all()}")
    print(f"  Features mean: {features.mean():.4f}, std: {features.std():.4f}")

    return features


def test_full_model(batch, device='cpu'):
    """Build the full JEPA model and check one forward and backward pass.

    The model is created with the fixed hyper-parameters below, moved to
    ``device`` and put in training mode. One forward pass is run on
    ``batch``; the loss, the embedding shapes and a matched-vs-unmatched
    similarity summary are printed; finally ``backward()`` is called on the
    loss.

    Args:
        batch: Batch passed directly to the model.
        device: Device the model is moved to.

    Returns:
        bool: True when the forward and backward passes complete, False
        when any exception is raised inside them (the traceback is printed).
    """
    print("\n" + "="*60)
    print("TESTING FULL JEPA MODEL")
    print("="*60)

    model = JAISPFoundation(
        patch_size=128,
        n_patches_per_tile=4,
        vit_patch_size=16,
        embed_dim=384,
        depth=6,
        num_heads=6,
        projection_dim=256,
        temperature=0.07
    ).to(device)

    # Count trainable parameters only
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✓ Model created")
    print(f"  Parameters: {n_params:,} ({n_params/1e6:.2f}M)")

    # Forward pass
    print("\nForward pass...")
    model.train()

    try:
        # NOTE(review): the batch is passed as loaded and is not moved to
        # `device` here; presumably the model handles placement — confirm.
        outputs = model(batch)

        print(f"✓ Forward pass successful!")
        print(f"\nOutputs:")
        print(f"  Loss: {outputs['loss'].item():.4f}")
        print(f"  z_rubin: {outputs['z_rubin'].shape}")
        print(f"  z_euclid: {outputs['z_euclid'].shape}")
        print(f"  z_rubin_raw: {outputs['z_rubin_raw'].shape}")
        print(f"  z_euclid_raw: {outputs['z_euclid_raw'].shape}")

        # Check embedding quality
        print("\nEmbedding quality:")
        # Diagnostics only: keep them out of the autograd graph so the
        # backward pass below differentiates the loss and nothing else.
        with torch.no_grad():
            similarity = torch.matmul(outputs['z_rubin'], outputs['z_euclid'].T)
            diag_sim = torch.diag(similarity).mean().item()
            off_diag = similarity.clone()
            off_diag.fill_diagonal_(0)
            n_off_diag = similarity.numel() - similarity.shape[0]
            if n_off_diag > 0:
                off_diag_sim = off_diag.sum().item() / n_off_diag
            else:
                # A single embedding pair has no unmatched entries; report
                # NaN instead of dividing by zero and failing the test.
                off_diag_sim = float('nan')

        print(f"  Diagonal similarity (matched): {diag_sim:.4f}")
        print(f"  Off-diagonal similarity (unmatched): {off_diag_sim:.4f}")
        print(f"  Separation: {diag_sim - off_diag_sim:.4f}")
        print(f"    (Should increase during training)")

        # Test backward pass
        print("\nTesting backward pass...")
        outputs['loss'].backward()
        print("✓ Backward pass successful!")

        return True

    except Exception as e:
        # Test boundary: report any failure and signal it via the return value
        print(f"✗ ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False


def main():
    """Run every Stage 1 setup check in order and print a summary.

    The checks run in sequence: data loading, patch extraction, the two
    encoders, then the full model. Any exception escaping a check is caught
    here, reported with its traceback, and treated as a failure.

    Returns:
        bool: True when the full-model check reported success, False when
        it failed or an earlier check raised.
    """
    # CONFIGURE THESE PATHS
    RUBIN_DIR = "../data/rubin_tiles_ecdfs"
    EUCLID_DIR = "../data/euclid_tiles_ecdfs"

    print("\n" + "="*60)
    print("JAISP STAGE 1 - SETUP TEST")
    print("="*60)
    print(f"\nData directories:")
    print(f"  Rubin:  {RUBIN_DIR}")
    print(f"  Euclid: {EUCLID_DIR}")

    # Check device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        print(f"\n✓ GPU detected: {torch.cuda.get_device_name(0)}")
        print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    else:
        print("\n⚠ No GPU detected - using CPU (will be slow)")

    try:
        # Test 1: Data loading
        batch = test_data_loading(RUBIN_DIR, EUCLID_DIR)

        # Test 2: Patch extraction
        rubin_patches, rubin_weights, euclid_patches, euclid_weights = test_patch_extraction(batch)

        # Test 3: Encoders (run for their printed diagnostics; the returned
        # features are not needed afterwards)
        test_encoder(rubin_patches, rubin_weights, 'rubin')
        test_encoder(euclid_patches, euclid_weights, 'euclid')

        # Test 4: Full model
        success = test_full_model(batch, device=device)

        if success:
            print("\n" + "="*60)
            print("✓ ALL TESTS PASSED!")
            print("="*60)
            print("\nYou're ready to train! Run:")
            print("  python train_stage1_foundation.py")
            print("\nRecommended settings for your data:")
            print("  - Batch size: 2-4 (depending on GPU memory)")
            print("  - Patch size: 128 (current)")
            print("  - Patches per tile: 4 (extracts 4 random crops per tile)")
            print("  - Total patches per batch: batch_size × 4")
        else:
            print("\n" + "="*60)
            print("✗ TESTS FAILED - check errors above")
            print("="*60)

        return bool(success)

    except Exception as e:
        # Top-level boundary: report the failure instead of crashing the run
        print("\n" + "="*60)
        print("✗ FATAL ERROR")
        print("="*60)
        print(f"{e}")
        import traceback
        traceback.print_exc()
        return False

# Run the setup checks only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


JAISP STAGE 1 - SETUP TEST

Data directories:
  Rubin:  ../data/rubin_tiles_ecdfs
  Euclid: ../data/euclid_tiles_ecdfs

✓ GPU detected: Quadro RTX 6000
  Memory: 23.60 GB
TESTING DATA LOADER
✓ Dataset loaded: 144 tiles

✓ Batch structure:
  x_rubin:  2 tiles
  x_euclid: 2 tiles

  Tile 0:
    Rubin:  torch.Size([6, 512, 512])
    Euclid: torch.Size([4, 1050, 1050])
    Rubin RMS:  torch.Size([6, 512, 512])
    Euclid RMS: torch.Size([4, 1050, 1050])
    Euclid mask: tensor([1., 1., 1., 1.])

  Tile 1:
    Rubin:  torch.Size([6, 512, 512])
    Euclid: torch.Size([4, 1050, 1050])
    Rubin RMS:  torch.Size([6, 512, 512])
    Euclid RMS: torch.Size([4, 1050, 1050])
    Euclid mask: tensor([1., 1., 1., 1.])

TESTING PATCH EXTRACTION

Extracting Rubin patches...
✓ Rubin patches: torch.Size([8, 6, 128, 128])
✓ Rubin weights: torch.Size([8, 6, 128, 128])
  Expected: (8, 6, 128, 128)

Extracting Euclid patches...
✓ Euclid patches: torch.Size([8, 4, 128, 128])
✓ Euclid weights: torch.Size([8, 