# FITS Dataloader

A dataloader that iterates through a directory of FITS files, each containing an image of a large patch of the sky in one colour band.

Each sample that the dataloader returns will be an array of cutouts from a given patch with all requested bands.

In [1]:
import glob
import os

import h5py
import numpy as np
import torch
from astropy.io import fits
from torchvision.transforms import v2

def build_fits_dataloader(fits_path, bands, norm_type, batch_size, num_workers,
                     img_size=64, pix_mean=None, pix_std=None, 
                     augment=False, shuffle=True):
    """
    Build a DataLoader that yields, per item, all batches cut from one sky patch.

    Args:
    - fits_path: Directory that is searched for the patch FITS files.
    - bands: Band names that every returned patch must provide.
    - norm_type: Normalization passed to FitsDataset ('minmax', 'zscore',
      'global', or anything else for no normalization).
    - batch_size: Number of cutouts per batch inside each returned item.
    - num_workers: Number of DataLoader worker processes.
    - img_size: Height and width of each cutout in pixels.
    - pix_mean, pix_std: Statistics used by the 'global' normalization. When
      left as None, the defaults declared by FitsDataset are used.
    - augment: If True, apply blur / random resized crop / random flips.
    - shuffle: Shuffle both the patch order and the cutouts within a patch.

    Returns:
    - A torch.utils.data.DataLoader over a FitsDataset.
    """
    transforms = None
    if augment:
        transforms = v2.Compose([v2.GaussianBlur(kernel_size=5, sigma=(0.1,1.5)),
                                 v2.RandomResizedCrop(size=(img_size, img_size), scale=(0.8,1), antialias=True),
                                 v2.RandomHorizontalFlip(p=0.5),
                                 v2.RandomVerticalFlip(p=0.5)])

    dataset_kwargs = dict(bands=bands, img_size=img_size,
                          batch_size=batch_size, shuffle=shuffle,
                          norm=norm_type, transform=transforms)
    # Only forward the global statistics when the caller supplied them.
    # Passing None would override the dataset's own defaults and make the
    # 'global' normalization fail on `tensor - None`.
    if pix_mean is not None:
        dataset_kwargs['global_mean'] = pix_mean
    if pix_std is not None:
        dataset_kwargs['global_std'] = pix_std

    dataset = FitsDataset(fits_path, **dataset_kwargs)

    # Each dataset item is already a stack of batches for one patch, so the
    # DataLoader itself loads a single item (patch) at a time.
    return torch.utils.data.DataLoader(dataset, batch_size=1, 
                                       shuffle=shuffle, num_workers=num_workers,
                                       pin_memory=True)

def find_HSC_bands(fits_path, bands):
    '''
    An HSC specific function that returns a list of file paths with all of the requested bands.

    Files are expected to be named `calexp-HSC-{band}-{patch}`, where `{patch}`
    is the last two dash-separated fields of the file name
    (e.g. `9707-4%2C0.fits`).

    Args:
    - fits_path: Directory containing the `calexp-HSC-*.fits` files.
    - bands: Band names that a patch must provide to be included.

    Returns:
    - A list with one entry per complete patch (sorted by patch name). Each
      entry is a list of file paths, one per band, in the order of `bands`.
    '''
    fits_dir = os.fspath(fits_path)

    # Escape the directory so glob metacharacters in it are taken literally
    pattern = os.path.join(glob.escape(fits_dir), "calexp-HSC-*.fits")

    # Compare on bare file names so that the way the directory is spelled
    # (e.g. a trailing slash) cannot make existing files look missing.
    available = {os.path.basename(path) for path in glob.glob(pattern)}

    # Convert 'calexp-HSC-I-9707-4%2C0.fits' to '9707-4%2C0.fits'
    unique_patches = sorted({'-'.join(name.split('-')[-2:]) for name in available})

    filenames = []
    for patch in unique_patches:
        band_names = [f'calexp-HSC-{b}-{patch}' for b in bands]

        # Keep the patch only if every requested band exists on disk
        if all(name in available for name in band_names):
            filenames.append([os.path.join(fits_dir, name) for name in band_names])
    print(f"Found {len(filenames)} patches with the {bands} bands.")

    return filenames

def split_3d_array(input_array, img_size, overlap):
    """
    Draw a batch of randomly positioned square cutouts from a larger 3D array.

    The cutouts are NOT laid out on a regular grid: every cutout position is
    sampled uniformly at random. `img_size` and `overlap` only determine how
    many cutouts are drawn, namely the number a grid with that overlap would
    contain.

    Args:
    - input_array: The larger 3D numpy array of shape (C, H, W).
    - img_size: The desired size of each cutout in both height and width.
    - overlap: The overlap between adjacent cutouts as a fraction (below 1).

    Returns:
    - A numpy array of shape (N, C, img_size, img_size), where N is the number of cutouts.

    Raises:
    - ValueError: if `img_size` is not positive or exceeds H or W, or if
      `overlap` is 1 or larger.
    """
    C, H, W = input_array.shape

    if img_size <= 0 or img_size > H or img_size > W:
        raise ValueError(
            f"img_size must be between 1 and min(H, W)={min(H, W)}, got {img_size}")

    # Effective step between neighbouring cutouts on the equivalent grid
    stride = img_size - img_size * overlap
    if stride <= 0:
        raise ValueError(f"overlap must be smaller than 1, got {overlap}")

    num_cutouts = int((H / stride) * (W / stride))
    if num_cutouts == 0:
        # Keep the documented 4D shape even when nothing is drawn
        return np.empty((0, C, img_size, img_size), dtype=input_array.dtype)

    cutout_list = []
    for _ in range(num_cutouts):
        # Row is drawn before column; keep this order so seeded runs reproduce
        h_start = np.random.randint(0, H - img_size + 1)
        w_start = np.random.randint(0, W - img_size + 1)

        cutout_list.append(
            input_array[:, h_start:h_start+img_size, w_start:w_start+img_size])

    return np.array(cutout_list)

class FitsDataset(torch.utils.data.Dataset):
    
    """
    Dataset loader for the cutout datasets.

    Every item corresponds to one sky patch for which all requested bands are
    available. The patch is loaded, cut into square cutouts, cleaned,
    optionally augmented and normalized, and returned grouped into batches.

    Args:
    - fits_path: Directory that is searched for the patch FITS files.
    - bands: Band names that a patch must provide; they become the channels.
    - img_size: Height and width of each cutout in pixels.
    - overlap: Overlap fraction handed to split_3d_array.
    - batch_size: Number of cutouts per batch in the returned item.
    - shuffle: If True, randomly permute the cutouts of a patch.
    - norm: 'minmax', 'zscore' or 'global'; any other value leaves the
      pixel values un-normalized.
    - transform: Optional callable applied to the (N, C, H, W) cutout tensor.
    - global_mean, global_std: Statistics used when norm is 'global'.
    - pixel_min, pixel_max: Optional clipping limits for the pixel values.

    Raises:
    - ValueError: if norm is 'global' but global_mean or global_std is None.
    """

    def __init__(self, fits_path, bands=['G','R','I','Z','Y'], img_size=64, overlap=0.,
                 batch_size=64, shuffle=True, norm=None, transform=None, 
                 global_mean=0.1, global_std=2., pixel_min=None, pixel_max=None):
        super().__init__()

        # Fail at construction instead of with an opaque TypeError while loading
        if norm == 'global' and (global_mean is None or global_std is None):
            raise ValueError(
                "norm='global' requires both global_mean and global_std to be set")
        
        self.fits_path = fits_path
        self.img_size = img_size
        self.overlap = overlap
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.norm = norm
        self.transform = transform
        self.global_mean = global_mean
        self.global_std = global_std 
        self.pixel_min = pixel_min
        self.pixel_max = pixel_max

        # Find names of patch fits files
        self.band_filenames = find_HSC_bands(fits_path, bands)
                        
    def __len__(self):
        # The number of fits patches with all of the requested bands
        return len(self.band_filenames)
    
    def __getitem__(self, idx):
        """
        Load patch `idx` and return its cutouts grouped into batches.

        Returns:
        - A tensor of shape (M, batch_size, C, H, W), where M is the number of
          complete batches; leftover cutouts that do not fill a batch are dropped.
        """

        # Grab fits filenames
        patch_filenames = self.band_filenames[idx]
        
        # Load all channels of the patch of sky. The image is read from HDU 1
        # into memory (memmap=False) inside a context manager so that the file
        # handle is released as soon as the band has been read.
        band_images = []
        for fn in patch_filenames:
            with fits.open(fn, mode='readonly', memmap=False) as hdul:
                band_images.append(hdul[1].data)
        # Organize into (C, H, W)
        cutouts = np.array(band_images)

        # Draw the cutouts for this patch based on img_size and overlap
        cutouts = split_3d_array(cutouts, self.img_size, self.overlap)

        # Shuffle images
        if self.shuffle:
            permutation = np.random.permutation(cutouts.shape[0])
            cutouts = cutouts[permutation]

        # Remove any NaN pixel values
        cutouts[np.isnan(cutouts)] = 0.

        # Clip pixel values
        if self.pixel_min is not None:
            cutouts[cutouts<self.pixel_min] = self.pixel_min
        if self.pixel_max is not None:
            cutouts[cutouts>self.pixel_max] = self.pixel_max

        # Apply any augmentations
        # NOTE(review): the transform receives the whole (N, C, H, W) tensor at
        # once; torchvision v2 transforms presumably draw a single set of random
        # parameters for all N cutouts - confirm if per-cutout variety is needed.
        cutouts = torch.from_numpy(cutouts)
        if self.transform is not None:
            cutouts = self.transform(cutouts)
            
        if self.norm=='minmax':
            # Normalize each sample between 0 and 1
            sample_min = torch.amin(cutouts, dim=(1,2,3), keepdim=True)
            sample_max = torch.amax(cutouts, dim=(1,2,3), keepdim=True)
            cutouts = (cutouts - sample_min) / (sample_max - sample_min + 1e-6)
        elif self.norm=='zscore':
            # Normalize each sample to have zero mean and unit variance
            sample_mean = torch.mean(cutouts, dim=(1,2,3), keepdim=True)
            sample_std = torch.std(cutouts, dim=(1,2,3), keepdim=True)
            cutouts = (cutouts - sample_mean) / (sample_std + 1e-6)
        elif self.norm=='global':
            # Normalize dataset to have zero mean and unit variance
            cutouts = (cutouts - self.global_mean) / self.global_std

        # Sort into M batches of batch_size. The per-cutout shape is taken from
        # the tensor itself so a transform that changes the size still works.
        M = cutouts.shape[0] // self.batch_size
        cutouts = cutouts[:M * self.batch_size].reshape(
            (M, self.batch_size, *cutouts.shape[1:]))

        return cutouts



## Construct the dataloader

Increasing `num_workers` allows the next patch to be loaded while you train on the current one. However, every worker holds a full patch (all requested bands) in memory, so setting it too high will quickly exhaust your RAM.

In [2]:
# Where the patch FITS files live and which bands every cutout must contain
fits_path = '/arc/projects/ots/pdr3_dud'
bands = ['G','R','I','Z','Y']

# Cutout geometry and per-sample normalization
img_size = 64
norm_type = 'minmax'

# Loading settings: cutouts per batch, and no background worker processes
batch_size = 8
num_workers = 0

dataloader = build_fits_dataloader(
    fits_path, bands, norm_type, batch_size, num_workers,
    img_size=img_size,
    pix_mean=None,
    pix_std=None,
    augment=True,
    shuffle=True,
)

Found 1477 patches with the ['G', 'R', 'I', 'Z', 'Y'] bands.


## Example usage

The outer loop loads a bunch of batches - all from the same patch in the sky - and the inner loop iterates through each batch, which allows you to do your training on one batch at a time.

In [3]:
# Every item from the dataloader holds all batches cut from one patch of sky
for sample_batches in dataloader:
    # Index 0 removes the length-1 dimension the DataLoader puts in front
    patch_batches = sample_batches[0]
    print('This patch of sky created cutouts with the shape:', patch_batches.shape)

    # Walk through the batches of this patch one at a time
    for batch in patch_batches:
        # Move batch to GPU and do your training stuff here
        pass
    print('Each batch has the shape:', batch.shape)

    # Only the first patch is needed for this demonstration
    break

This patch of sky created cutouts with the shape: torch.Size([538, 8, 5, 64, 64])
Each batch has the shape: torch.Size([8, 5, 64, 64])
