In [None]:
import logging
import os
import tempfile

import anacal
import fitsio
import matplotlib.pyplot as plt
import numpy as np
import torch

# Module logger; the helpers below report shapes and estimates at INFO level,
# so the root logger is configured to let INFO records through.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level="INFO")

In [None]:
# Load one batch of generated images and PSF parameters for inspection.
# map_location='cpu' (as in the extract_* helpers) lets tensors that were saved
# from a GPU load on a CPU-only machine.
# NOTE(review): `image` / `psf` are not used by the visible pipeline cells,
# which reload these files by path.
image = torch.load('/data/scratch/taodingr/AnaCal/batch_1_images.pt', map_location='cpu')
psf = torch.load('/data/scratch/taodingr/AnaCal/batch_1_psf_params.pt', map_location='cpu')

In [None]:
def extract_psf_for_anacal(batch_psf_file, image_idx_in_batch):
    """Extract the PSF image of one image in a batch, normalized for AnaCal.

    Parameters
    ----------
    batch_psf_file : str or path-like
        File written with ``torch.save``. It must support
        ``image_idx_in_batch in data`` and ``data[image_idx_in_batch]``, and
        each entry must contain a ``'psf_image'`` item (presumably a dict keyed
        by image index -- TODO confirm against the generator).
    image_idx_in_batch : int
        Index of the image within the batch.

    Returns
    -------
    numpy.ndarray
        The PSF image as ``float32``, rescaled so that it sums to one.

    Raises
    ------
    ValueError
        If the index or the ``'psf_image'`` entry is missing, or if the PSF
        sum is not a finite positive number.
    """
    # map_location='cpu' lets files saved from GPU tensors load on CPU-only hosts.
    psf_data = torch.load(batch_psf_file, map_location='cpu')

    if image_idx_in_batch not in psf_data:
        raise ValueError(f"Image index {image_idx_in_batch} not found in PSF data")

    psf_info = psf_data[image_idx_in_batch]

    if 'psf_image' not in psf_info:
        raise ValueError("No PSF image found in data - need to regenerate with updated code")

    raw_psf = psf_info['psf_image']
    if isinstance(raw_psf, torch.Tensor):
        # detach() so tensors that still track gradients can be converted.
        raw_psf = raw_psf.detach().cpu().numpy()
    # asarray also accepts nested lists; float64 keeps the normalization accurate.
    psf_image = np.asarray(raw_psf, dtype=np.float64)

    psf_sum = np.sum(psf_image)
    # An infinite sum would pass a plain "> 0" test and yield a zero/NaN PSF.
    if not np.isfinite(psf_sum):
        raise ValueError(f"PSF image has a non-finite sum ({psf_sum})!")
    if psf_sum <= 0:
        raise ValueError("PSF image has zero or negative sum!")

    # AnaCal expects a PSF that is normalized to unit sum.
    psf_image = psf_image / psf_sum

    logger.info(f"PSF shape: {psf_image.shape}, PSF sum: {psf_sum:.6f}")
    return psf_image.astype(np.float32)

def extract_galaxy_image(batch_image_file, image_idx_in_batch, band_idx=0):
    """Extract a single-band galaxy image for one image in a batch.

    Parameters
    ----------
    batch_image_file : str or path-like
        File written with ``torch.save`` that is indexable by image position
        (a tensor or array of images -- presumably ``[batch, bands, H, W]``;
        TODO confirm).
    image_idx_in_batch : int
        Index of the image within the batch.
    band_idx : int, default 0
        Band to keep when the selected image is 3-D (assumed
        ``[bands, height, width]``). It is ignored for 2-D images.

    Returns
    -------
    numpy.ndarray
        The 2-D image as ``float32``.

    Raises
    ------
    ValueError
        If the selected image is neither 2-D nor 3-D.
    """
    # map_location='cpu' lets files saved from GPU tensors load on CPU-only hosts.
    batch_images = torch.load(batch_image_file, map_location='cpu')

    selected = batch_images[image_idx_in_batch]
    if isinstance(selected, torch.Tensor):
        # detach() so tensors that still track gradients can be converted.
        selected = selected.detach().cpu().numpy()
    galaxy_image = np.asarray(selected)

    # AnaCal works on single-band images, so pick one band from a band stack.
    if galaxy_image.ndim == 3:
        galaxy_image = galaxy_image[band_idx]

    if galaxy_image.ndim != 2:
        raise ValueError(
            f"Expected a 2D image or a 3D [bands, height, width] stack, "
            f"got shape {galaxy_image.shape}"
        )

    logger.info(f"Galaxy image shape: {galaxy_image.shape}")
    return galaxy_image.astype(np.float32)

In [None]:
def _fpfs_g1_estimate(out, prefix):
    """Return sum(w * e1) / sum(d(w * e1)/dg1) for one FPFS kernel.

    ``prefix`` selects the ellipticity fields (``"fpfs"``, ``"fpfs1"`` or
    ``"fpfs2"``). The weight ``fpfs_w`` and its derivative ``fpfs_dw_dg1`` are
    always taken from the base kernel, as in the reference example. The
    derivative uses the product rule: dw/dg1 * e1 + w * de1/dg1.
    Returns NaN when the denominator sums to zero (e.g. no detections).
    """
    e1 = out[f"{prefix}_e1"]
    de1_dg1 = out[f"{prefix}_de1_dg1"]
    numerator = np.sum(out["fpfs_w"] * e1)
    denominator = np.sum(out["fpfs_dw_dg1"] * e1 + out["fpfs_w"] * de1_dg1)
    if denominator == 0:
        return float("nan")
    return float(numerator / denominator)


def run_anacal_with_fits_exact(save_folder, batch_num, image_idx_in_batch,
                               mag_zero=30.0, pixel_scale=0.2, noise_factor=0.23):
    """Run AnaCal FPFS on one simulated image, round-tripping inputs via FITS.

    Band 0 of the galaxy image and its normalized PSF are written to FITS files
    and read back with ``fitsio.read``. The intent is to feed
    ``anacal.fpfs.process_image`` exactly the way the reference example reads
    its inputs. The files live in a private temporary directory that is removed
    afterwards, even on error.

    Parameters
    ----------
    save_folder : str
        Folder holding ``batch_{batch_num}_images.pt`` and
        ``batch_{batch_num}_psf_params.pt``.
    batch_num : int
        Batch number used in the file names.
    image_idx_in_batch : int
        Index of the image within the batch.
    mag_zero : float, default 30.0
        Magnitude zero point passed to AnaCal.
    pixel_scale : float, default 0.2
        Pixel scale passed to AnaCal (arcsec per pixel, presumably).
    noise_factor : float, default 0.23
        Noise standard deviation. AnaCal receives ``noise_factor**2`` as
        ``noise_variance``. The keyword names match the notebook's ``config``
        dict, so ``**config`` can be passed.

    Returns
    -------
    The object returned by ``anacal.fpfs.process_image``. The per-kernel
    ``sum(w*e1) / sum(d(w*e1)/dg1)`` ratios are logged at INFO level.
    """
    batch_image_file = f"{save_folder}/batch_{batch_num}_images.pt"
    batch_psf_file = f"{save_folder}/batch_{batch_num}_psf_params.pt"

    galaxy_image = extract_galaxy_image(batch_image_file, image_idx_in_batch, band_idx=0)
    psf_image = extract_psf_for_anacal(batch_psf_file, image_idx_in_batch)

    # A private temp directory avoids littering the CWD and avoids name
    # collisions between concurrent runs; it is deleted on exit, even on error.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_gal_file = os.path.join(tmp_dir, f"temp_galaxy_{batch_num}_{image_idx_in_batch}.fits")
        temp_psf_file = os.path.join(tmp_dir, f"temp_psf_{batch_num}_{image_idx_in_batch}.fits")

        fitsio.write(temp_gal_file, galaxy_image, clobber=True)
        fitsio.write(temp_psf_file, psf_image, clobber=True)

        fpfs_config = anacal.fpfs.FpfsConfig(
            sigma_arcsec=0.52,   # The first measurement scale (also for detection)
            sigma_arcsec1=0.45,  # The second measurement scale
            sigma_arcsec2=0.55,  # The third measurement scale
        )

        # Read back from FITS, exactly like the reference example.
        gal_array = fitsio.read(temp_gal_file)
        psf_array = fitsio.read(temp_psf_file)

        logger.info(f"FITS galaxy shape: {gal_array.shape}, dtype: {gal_array.dtype}")
        logger.info(f"FITS PSF shape: {psf_array.shape}, dtype: {psf_array.dtype}")

        out = anacal.fpfs.process_image(
            fpfs_config=fpfs_config,
            mag_zero=mag_zero,
            gal_array=gal_array,
            psf_array=psf_array,
            pixel_scale=pixel_scale,
            noise_variance=noise_factor ** 2.0,
            noise_array=None,
            detection=None,
        )

    logger.info("\n=== RESULTS (g1 estimate per kernel) ===")
    for label, prefix in (("base kernel", "fpfs"), ("kernel 1", "fpfs1"), ("kernel 2", "fpfs2")):
        logger.info(f"{label} [{prefix}]: {_fpfs_g1_estimate(out, prefix)}")

    return out

In [None]:
# Run settings. NOTE(review): `config` is not consumed by the calls below;
# run_anacal_with_fits_exact uses its own mag_zero=30.0 / pixel_scale=0.2 /
# noise 0.23 values -- keep the two in sync.
save_folder = "/data/scratch/taodingr/AnaCal"
config = dict(pixel_scale=0.2, mag_zero=30.0, noise_factor=0.23)

In [None]:
# Run AnaCal on the first N_TEST_IMAGES images of one batch.
# `results_by_image` keeps every output keyed by image index;
# `results` still refers to the last one.
batch_num = 1
N_TEST_IMAGES = 10
results_by_image = {}
for image_idx in range(N_TEST_IMAGES):
    results = run_anacal_with_fits_exact(
        save_folder=save_folder,
        batch_num=batch_num,
        image_idx_in_batch=image_idx,
    )
    results_by_image[image_idx] = results


# True shear values from the catalog (for comparison with the AnaCal estimates)

In [None]:
# path to the stored catalog 
# map_location='cpu' (as in the extract_* helpers) lets tensors that were saved
# from a GPU load on a CPU-only machine.
catalog = torch.load('/data/scratch/taodingr/AnaCal/batch_1_catalog.pt', map_location='cpu')

In [None]:
# True shear 1 values from the catalog for the first 10 images
# ([:10] matches range(10) in the AnaCal loop above).
# NOTE(review): the indexing assumes an (at least) 5-D array; what axes 1-2 and
# index 0 on axis 3 mean is not visible here -- confirm against the catalog
# layout. The bare .squeeze() drops every size-1 axis.
catalog['shear_1'][:10, :, :, 0, :].squeeze()

In [None]:
# True shear 2 values from the catalog for the first 10 images
# ([:10] matches range(10) in the AnaCal loop above).
# NOTE(review): the indexing assumes an (at least) 5-D array; what axes 1-2 and
# index 0 on axis 3 mean is not visible here -- confirm against the catalog
# layout. The bare .squeeze() drops every size-1 axis.
catalog['shear_2'][:10, :, :, 0, :].squeeze()