In [None]:
import operator

import numpy as np
import pandas as pd
import torch
from einops import rearrange
from torch.utils.data import Dataset

from utils import my_transform

class MantisShrimpDataset(Dataset):
    """Per-rank view of the Mantis Shrimp image cutouts joined with redshift catalogs.

    Each sample pairs one row of this rank's image block (GALEX, Pan-STARRS and unWISE
    pixels stored as one flat vector per row) with the object's redshift ``z``, two
    extinction columns (``ebv_csfd``, ``ebv_planck``) and two external photo-z estimates
    (``zphot`` from the SDSS MGS file, ``z_phot0`` from the WISE-PS1-STRM file). Only
    objects whose image exists and whose ``z`` is below ``ZMAX`` are exposed.
    """

    # Directory holding the catalogs, exists_mask.npy and the npy_blocks/ tree.
    DEFAULT_DATA_ROOT = '/rcfs/projects/mantis_shrimp/mantis_shrimp/data'

    def __init__(self, kind: str, WORLD_RANK: int = 0, ZMAX: float = 1.6,
                 data_root: str = DEFAULT_DATA_ROOT):
        """Load the joined catalog, this rank's index chunk and its memory-mapped image block.

        Args:
            kind: split to load; one of 'train', 'val', 'test'.
            WORLD_RANK: chunk of the split to load; selects row ``WORLD_RANK`` of
                ``npy_blocks/{kind}_indices.npy`` and the image block
                ``npy_blocks/{kind}/mantis_shrimp_{WORLD_RANK}.npy``.
            ZMAX: objects with ``z >= ZMAX`` (or NaN ``z``) are excluded.
            data_root: root directory of all input files (defaults to the project path).
        """
        assert kind in ['train', 'val', 'test'], f"kind must be 'train', 'val' or 'test', got {kind!r}"

        # Per-survey normalisation constants applied after my_transform in __getitem__.
        self.max = np.array([2.0336876, 8.301209523105817, 6.198499090528679])  # TODO

        catalog = self._merge_photoz_catalogs(
            pd.read_pickle(f'{data_root}/spectroscopy/redshifts_withextinction.pkl'),
            pd.read_csv(f'{data_root}/redshifts_broken_beck/SDSS_MGS/MGS_qwsaz123.csv',
                        usecols=['bestObjID', 'zphot']),
            pd.read_csv(f'{data_root}/redshifts_broken_beck/WISE_PS1_STRM.csv',
                        usecols=['dstArcSec',
                                 'cellDistance_Photoz',
                                 'z_phot0',
                                 'z_photErr',
                                 'prob_Galaxy',
                                 'photoObjID_survey']),
        )

        # Row WORLD_RANK lists the catalog rows held, in order, by this rank's image block.
        indices = np.load(f'{data_root}/npy_blocks/{kind}_indices.npy', mmap_mode='r')[WORLD_RANK]
        exists_mask = np.load(f'{data_root}/exists_mask.npy', mmap_mode='r')[indices]

        z = catalog['z'].values[indices]
        # Keep objects whose image exists AND whose redshift is below ZMAX (NaN compares False).
        keep = np.logical_and(exists_mask, z < ZMAX)

        self.z = z[keep]
        self.ebv_csfd = catalog['ebv_csfd'].values[indices][keep]
        self.ebv_planck = catalog['ebv_planck'].values[indices][keep]
        self.zphot_MGS = catalog['zphot'].values[indices][keep]
        self.zphot_WPS = catalog['z_phot0'].values[indices][keep]

        # The image block is a read-only memmap we do not filter, so map each dataset
        # position to the corresponding row of the unfiltered block instead.
        self.idx_to_imgidx = dict(zip(np.arange(keep.sum()), np.where(keep)[0]))

        self.img = np.load(f'{data_root}/npy_blocks/{kind}/mantis_shrimp_{WORLD_RANK}.npy',
                           mmap_mode='r')

    @staticmethod
    def _merge_photoz_catalogs(catalog, mgs, wps):
        """Left-join the MGS and WPS photo-z columns onto ``catalog`` without changing its rows.

        Args:
            catalog: main catalog keyed by ``photoObjID_survey``.
            mgs: frame with ``bestObjID`` (join key) and ``zphot``.
            wps: frame keyed by ``photoObjID_survey``.

        Returns:
            ``catalog`` with the extra columns, same length and row order, so positional
            indices computed against the catalog remain valid.
        """
        n_rows = len(catalog)
        # pandas matches null keys to each other, and duplicated right-hand keys fan out
        # left rows; both would shift every positional index applied to the result.
        mgs = mgs.dropna(subset=['bestObjID']).drop_duplicates('bestObjID')
        wps = wps.dropna(subset=['photoObjID_survey']).drop_duplicates('photoObjID_survey')

        merged = pd.merge(catalog, mgs, how='left',
                          left_on='photoObjID_survey', right_on='bestObjID')
        merged = merged.drop(columns='bestObjID')
        merged = pd.merge(merged, wps, how='left', on='photoObjID_survey')

        assert len(merged) == n_rows, 'photo-z merge changed the catalog row count'
        return merged

    def __len__(self):
        """Number of objects in this rank that passed the exists/ZMAX filter."""
        return len(self.idx_to_imgidx)

    def __getitem__(self, idx):
        """Return the preprocessed sample at position ``idx``.

        Returns a tuple of float32 arrays ``(galex, panstarrs, unwise, z, ebvs,
        zphot_MGS, zphot_WPS)``. Assuming ``self.img`` is 2-D (rows x flattened pixels),
        galex and unwise are ``(1, 2, 32, 32)`` and panstarrs is ``(1, 5, 170, 170)``.
        ``z`` is 0-d, ``ebvs`` is ``(2,)`` (ebv_csfd, ebv_planck) and each photo-z is ``(1,)``.

        Negative indices count from the end. An out-of-range ``idx`` raises IndexError,
        which also lets plain ``for sample in dataset`` iteration terminate.
        """
        n_samples = len(self)
        position = operator.index(idx)
        if position < 0:
            position += n_samples
        if not 0 <= position < n_samples:
            raise IndexError(f'index {idx} is out of range for a dataset of length {n_samples}')

        row = np.array([position])
        img = self.img[np.array([self.idx_to_imgidx[position]])]
        if img.ndim == 1:
            img = img[None,]  # add leading batch dimension

        # Each image row is [GALEX | Pan-STARRS | unWISE], each flattened as (filter, h, w).
        galex = rearrange(img[:, 0:2048], 'b (f h w) -> b f h w', f=2, h=32, w=32)
        panstarrs = rearrange(img[:, 2048:146548], 'b (f h w) -> b f h w', f=5, h=170, w=170)
        unwise = rearrange(img[:, 146548:], 'b (f h w) -> b f h w', f=2, h=32, w=32)

        # arcsinh scaling via utils.my_transform, then per-survey normalisation.
        galex = my_transform(galex, 0.2) / self.max[0]  # TODO
        panstarrs = my_transform(panstarrs, 0.2) / self.max[1]
        unwise = my_transform(unwise, 0.2) / self.max[2]

        # Both extinction slices are (1,), so the concatenation is (2,).
        ebvs = np.concatenate([self.ebv_csfd[row], self.ebv_planck[row]])

        return (
            galex.astype('float32'),
            panstarrs.astype('float32'),
            unwise.astype('float32'),
            self.z[row].squeeze().astype('float32'),
            ebvs.astype('float32'),
            self.zphot_MGS[row].astype('float32'),
            self.zphot_WPS[row].astype('float32'),
        )

In [None]:
from mantis_shrimp import datasets

# Package implementation of the dataset (rank 0 of the training split). Unlike the
# prototype class defined in the first cell it takes a `loc` argument, which presumably
# selects the storage location — confirm against mantis_shrimp.datasets.
MSD = datasets.MantisShrimpDataset(
    kind='train',
    WORLD_RANK=0,
    ZMAX=1.6,
    loc='vast',
)

In [None]:
# Number of samples the package dataset exposes for train / rank 0.
n_train_rank0 = len(MSD)
n_train_rank0

In [None]:
# Same split/rank/ZMAX with the prototype class from the first cell, for comparison
# with the package version above.
len(
    MantisShrimpDataset(kind='train', WORLD_RANK=0, ZMAX=1.6)
)

### Using FFCV for data loading

We use the FFCV library together with a custom dataset definition to speed up data loading by 17x. First, the dataset has to be written to disk with FFCV's `DatasetWriter`. The shape of every array we want to store must be known in advance and passed as the `shape` argument of `NDArrayField`. A key point is that the dtype of each array the dataset returns must match the `dtype` declared for that field in the `DatasetWriter` schema.

In [1]:
from ffcv.writer import DatasetWriter
from ffcv.fields import NDArrayField, FloatField
from mantis_shrimp import datasets
import numpy as np

BETON_DIR = '/rcfs/projects/mantis_shrimp/Adam'
FLOAT32 = np.dtype('float32')


def make_beton_fields():
    """Return a fresh FFCV field schema for one Mantis Shrimp sample.

    Shapes and dtypes must agree exactly with what the dataset returns for each field.
    """
    return {
        'galex': NDArrayField(shape=(1, 2, 32, 32), dtype=FLOAT32),
        'panstarrs': NDArrayField(shape=(1, 5, 170, 170), dtype=FLOAT32),
        'unwise': NDArrayField(shape=(1, 2, 32, 32), dtype=FLOAT32),
        'z': FloatField(),
        'ebvs': NDArrayField(shape=(2,), dtype=FLOAT32),
        'zphot_MGS': NDArrayField(shape=(1,), dtype=FLOAT32),
        'zphot_WPS': NDArrayField(shape=(1,), dtype=FLOAT32),
    }


# Write one .beton file per (split, world rank) pair. `dataset` and `writer` stay bound
# after the loop because later cells inspect them.
for kind in ('train', 'val', 'test'):
    for WORLD_RANK in range(16):
        dataset = datasets.MantisShrimpDataset(
            kind, WORLD_RANK, ZMAX=1.6, loc='rcfs', to_torch=False, mmap=True,
        )
        writer = DatasetWriter(
            f'{BETON_DIR}/mantis_shrimp_{kind}_{WORLD_RANK}.beton',
            make_beton_fields(),
            num_workers=16,  # number of parallel writer workers
        )
        writer.from_indexed_dataset(dataset)

  0%|                                                                                                                                                                                                                                                        | 0/188117 [00:00<?, ?it/s]

KeyboardInterrupt: 

Process Process-16:
Traceback (most recent call last):
  File "/people/tsou806/.conda/envs/ntkenvironment/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/people/tsou806/.conda/envs/ntkenvironment/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/people/tsou806/.local/lib/python3.9/site-packages/ffcv/writer.py", line 113, in worker_job_indexed_dataset
    handle_sample(sample, dest_ix, field_names, metadata, allocator, fields)
  File "/people/tsou806/.local/lib/python3.9/site-packages/ffcv/writer.py", line 51, in handle_sample
    field.encode(destination, field_value, allocator.malloc)
  File "/people/tsou806/.local/lib/python3.9/site-packages/ffcv/fields/ndarray.py", line 98, in encode
    destination[0], data_region = malloc(self.element_size)
  File "/people/tsou806/.local/lib/python3.9/site-packages/ffcv/memory_allocator.py", line 43, in malloc
    self.flush_page()
  File "/

In [None]:
# Re-runs the write for whichever `writer`/`dataset` pair the loop cell above left bound,
# i.e. the last (kind, WORLD_RANK) it reached before finishing or being interrupted.
# NOTE(review): depends entirely on kernel state from that cell; it fails on a fresh kernel.
writer.from_indexed_dataset(dataset)

In [None]:
from ffcv.transforms import ToTensor, ToDevice
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import NDArrayDecoder, FloatDecoder

# NOTE(review): the writer cell produces mantis_shrimp_{kind}_{rank}.beton files;
# test.beton is not written by this notebook. Confirm it exists before running.
BETON_PATH = '/rcfs/projects/mantis_shrimp/Adam/test.beton'
BATCH_SIZE = 32
NUM_WORKERS = 8
ORDERING = OrderOption.QUASI_RANDOM


def _array_pipeline():
    """Decode a stored NDArray field and convert it to a tensor."""
    return [NDArrayDecoder(), ToTensor()]


# Only galex is moved to the GPU here; every other field stays on the CPU.
PIPELINES = {
    'galex': _array_pipeline() + [ToDevice(torch.device('cuda'), non_blocking=True)],
    'panstarrs': _array_pipeline(),
    'unwise': _array_pipeline(),
    'z': [FloatDecoder(), ToTensor()],
    'ebvs': _array_pipeline(),
    'zphot_MGS': _array_pipeline(),
    'zphot_WPS': _array_pipeline(),
}

loader = Loader(
    BETON_PATH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    order=ORDERING,
    pipelines=PIPELINES,
)

In [None]:
# Sanity-check the loader: decode every batch once (so decode errors still surface), but
# report only the batch count and the first batch's per-field shapes. This replaces
# printing every tensor of every batch.
n_batches = 0
first_batch_shapes = None
for batch in loader:
    if first_batch_shapes is None:
        first_batch_shapes = [tuple(field.shape) for field in batch]
    n_batches += 1
first_batch_shapes, n_batches

In [None]:
# Inspect a single sample once instead of re-reading and re-transforming dataset[0] for
# every field. Shows the shape of every returned array, plus the redshift value itself.
# Names follow the order of the writer schema above.
SAMPLE_FIELD_NAMES = ('galex', 'panstarrs', 'unwise', 'z', 'ebvs', 'zphot_MGS', 'zphot_WPS')
sample = dataset[0]
sample_shapes = {name: np.shape(value) for name, value in zip(SAMPLE_FIELD_NAMES, sample)}
sample_shapes, sample[3]

In [2]:
from mantis_shrimp import datasets

# Size of the validation split for rank 0 using the package dataset.
val_rank0 = datasets.MantisShrimpDataset(kind='val', WORLD_RANK=0, ZMAX=1.6, loc='vast')
len(val_rank0)

26859