# Getting started with Hyrax Custom Dataset Classes

In this notebook we build a custom dataset class for Hyrax, and show how you can use the 
`prepare` verb in Hyrax to test various aspects of your new dataset class.

First we will create some fake data: 1000 random tensors of shape 3x10x10, each with a fake filename:

In [1]:
import numpy as np
import torch

rng = np.random.default_rng()
num_tensors = 1000

# Generate filenames
alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
filename_length = 15
filenames = ["".join(rng.choice(alphabet, filename_length)) for _ in range(num_tensors)]

# Generate tensors
shape = (3, 10, 10)
random_data = {file: torch.from_numpy(rng.random(size=shape, dtype=np.float32)) for file in filenames}
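
Before building on this data, it can help to confirm it has the expected layout. The following is a minimal NumPy-only sketch that regenerates data the same way and checks the count, shape, and dtype. It skips the `torch.from_numpy` conversion, and the fixed seed `0` and the name `arrays` are assumptions made for illustration rather than part of the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed chosen only for illustration
num_tensors = 1000
filename_length = 15
shape = (3, 10, 10)

alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
filenames = ["".join(rng.choice(alphabet, filename_length)) for _ in range(num_tensors)]

# NumPy arrays stand in for the torch tensors used in the notebook
arrays = {name: rng.random(size=shape, dtype=np.float32) for name in filenames}

# A dict keyed by filename silently drops duplicate names, so compare counts
assert len(arrays) == num_tensors

first = arrays[filenames[0]]
print(first.shape, first.dtype)
```

Comparing `len(arrays)` with `num_tensors` is a cheap guard: a collision between two random 15-letter names is extremely unlikely, but it would otherwise go unnoticed.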

## Building a custom Dataset class

We will treat these tensors as if they are on the filesystem, and write a dataset class that gives Hyrax access to 
these "files". Here `get_image` stands in for a library function that returns a torch.Tensor from one of our "files", 
and `_list_filenames` stands in for a library function that lists the filenames in a particular path.

The first thing we need to do is make a new class derived from `HyraxDataset` and `torch.utils.data.Dataset`, as shown below.

In [2]:
from torch.utils.data import Dataset
from hyrax.data_sets import HyraxDataset
from pathlib import Path
from typing import Union


class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config: dict, data_location: Union[Path, str] = None):
        self.filenames = MyDataset._list_filenames(data_location)

        super().__init__(config)

    def __getitem__(self, idx):
        return {
            "data": {"image": self.get_image(idx)},
        }

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _list_filenames(path_to_data):
        """This is a pretend implementation so we ignore path_to_data"""
        global filenames
        return filenames

    def get_image(self, index):
        """Pretend to read specific data from the disk."""
        filename = self.filenames[index]
        global random_data
        return random_data[filename]

Key aspects of this class that you will need to replicate are:

* `__init__` must call `super().__init__(config)`. This is important for Hyrax to function appropriately, and 
gives you access to Hyrax's config in other functions should you want it later. You will probably want to 
access `config["general"]["data_dir"]` to figure out what directory to start in.

* `__getitem__` must be implemented. It takes an index and returns the data for that element; in the class 
above, that is a dictionary wrapping the `torch.Tensor` for the image.

* `__len__` must return the number of elements in your dataset.

Note that all of these are instance methods that use `self` as the first argument. This `self` is the current
`MyDataset` object, and allows you to set and get values as is done with `self.filenames` in the code above.
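
The `__getitem__` and `__len__` methods follow ordinary Python conventions that do not depend on Hyrax or torch. As a minimal sketch (the class `ToyDataset` and its contents are hypothetical and exist only for illustration), any class that defines these two methods can be indexed with `[]` and measured with `len()`:

```python
class ToyDataset:
    """Hypothetical stand-in that implements only the indexing protocol."""

    def __init__(self, items):
        # Stored on self so the other methods can reach it
        self.items = list(items)

    def __getitem__(self, idx):
        # Called when you write toy[idx]
        return self.items[idx]

    def __len__(self):
        # Called when you write len(toy)
        return len(self.items)


toy = ToyDataset(["a", "b", "c"])
print(toy[0])
print(len(toy))
```

The same pattern applies to `MyDataset`, with `self.filenames` playing the role of `self.items`.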

The functions `_list_filenames()` and `get_image()` both read our fake data, and are there so we 
have a working demonstration. How you organize your own file reading code is entirely 
up to you!


We're now going to start up Hyrax and use the `prepare` verb to create an instance of this class and see
that it works correctly. Note that we have set `config["general"]["data_dir"]` to specify the location of our
data for the `__init__` function we wrote earlier, and `config["data_set"]["name"]` to the 
name of our class, so that Hyrax knows to use our dataset class rather than one of the built-in ones.

Our `h.prepare()` line in the script will have the effect of calling our `__init__` function with the 
current hyrax config.

In [3]:
import hyrax

h = hyrax.Hyrax()
h.config["general"]["data_dir"] = "/fake/path/to/some/data"
h.config["data_set"]["name"] = "MyDataset"

dataset = h.prepare()

[2025-09-12 12:15:10,407 hyrax:INFO] Runtime Config read from: /Users/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-12 12:15:12,179 hyrax.data_sets.data_provider:INFO] No fields were specified for 'data'. The request will be modified to select all by default. You can specify `fields` in `model_inputs`.
[2025-09-12 12:15:12,208 hyrax.prepare:INFO] Finished Prepare


### Testing

The object we received from `h.prepare()` is an instance of our dataset, which we can test for functionality.

We're going to index into the dataset object with `[]`, which has the effect of calling our `__getitem__` function
and returning the result.

We're also going to call `len()` on the dataset, which has the effect of calling our `__len__` function.

In [4]:
print("Checking __getitem__ ...", end="\n\n")
item = dataset[0]

print('Shape of our first element, should be "torch.Size([3,10,10])": ')
print(item.shape, end="\n\n")

print("Type of our first element, should be \"<class 'torch.Tensor'>\": ")
print(type(item), end="\n\n")

print("Checking __len__ ...\n\nShould print 1000: ")
print(len(dataset))

Checking __getitem__ ...



KeyError: 'object_id'

This dataset class is suitable for training or inference with Hyrax; however, you may want to read on to learn
about more advanced features such as custom IDs for your data elements, metadata, and configuration access.

Below is a short example that uses the HyraxAutoencoder built-in model, demonstrating that training is possible:

In [None]:
import hyrax

h = hyrax.Hyrax()
h.config["general"]["data_dir"] = "/fake/path/to/some/data"
h.config["data_set"]["name"] = "MyDataset"
h.config["model"]["name"] = "HyraxAutoencoder"

h.train()

## Extending to support visualization

This section is primarily concerned with binding different sorts of metadata to your dataset. This metadata
is used by the Hyrax visualization components to identify the source data of your latent space representation
and link it back to a particular object/event in your astronomical dataset. 

When we built `MyDataset` above, we invisibly picked up two major aspects from `HyraxDataset`:

1. Unique IDs: Every tensor in our dataset got an ID of a sequential zero-based index, which was exactly the 
argument to `__getitem__`/`[]`. This list of IDs is available as an iterator by calling `ids()` on the dataset
object. These IDs are used in inference results and visualizations of the data, but they can be overridden.

2. Metadata Interface: Every `HyraxDataset` can provide an astropy `Table` of values in the same order as 
its `__getitem__`/`[]`. This allows each tensor in the dataset to have associated scalar data such as ra/dec, 
ephemeris parameters, redshift, magnitude, etc. For our class there currently is no metadata.

Below is how we would access the metadata and IDs, demonstrating the default behavior when your custom class
does not override them:

In [None]:
import hyrax

h = hyrax.Hyrax()
h.config["general"]["data_dir"] = "/fake/path/to/some/data"
h.config["data_set"]["name"] = "MyDataset"

dataset = h.prepare()

print("\nIDs:")
print(f"list(dataset.ids())[0:10] = {list(dataset.ids())[0:10]}")


print("\nMetadata field list:")
print(f"dataset.metadata_fields() = {dataset.metadata_fields()} (there is no metadata)")

### Adding IDs

We're going to use the filename in our fake data as IDs by adding a single `ids()` method to our `MyDataset` 
object. The most expedient way to do this will be to redefine the entire class below. Note that functions 
marked with a comment are just the same as earlier.

Note that the `ids()` function is required to return a generator, so we will use a `for` loop and `yield`
each sequential value. This interface allows Hyrax to partially enumerate the IDs in a dataset when that
is desirable. It is easy enough to get all the ids in order with `list(dataset.ids())`.

In [None]:
from torch.utils.data import Dataset
from hyrax.data_sets import HyraxDataset


class MyDataset(HyraxDataset, Dataset):
    def ids(self):
        for filename in self.filenames:
            yield filename

    # Unchanged from before below this comment ...
    def __init__(self, config: dict):
        self.filenames = MyDataset._list_filenames(config["general"]["data_dir"])
        super().__init__(config)

    def __getitem__(self, idx):
        return MyDataset._read_tensor(self.filenames[idx])

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _list_filenames(path_to_data):
        """This is a pretend implementation so we ignore path_to_data"""
        global filenames
        return filenames

    @staticmethod
    def _read_tensor(filename):
        """Pretend to read a particular tensor from the disk."""
        global random_data
        return random_data[filename]
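
To illustrate why a generator is convenient for `ids()`, the sketch below uses a hypothetical stand-alone class (`ToyIds`, which is not a Hyrax class) whose `ids()` method has the same shape as the one above. `itertools.islice` then takes only the first few IDs without building the full list:

```python
from itertools import islice


class ToyIds:
    """Hypothetical class whose ids() mirrors the method shown above."""

    def __init__(self, filenames):
        self.filenames = filenames

    def ids(self):
        # Yield one ID at a time rather than returning a full list
        for filename in self.filenames:
            yield filename


toy = ToyIds([f"file_{i}" for i in range(1000)])

first_three = list(islice(toy.ids(), 3))  # consumes only three values
all_ids = list(toy.ids())  # consumes everything, in order
print(first_three)
print(len(all_ids))
```

Each call to `ids()` returns a fresh generator, so partially consuming one does not affect the next.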

Running `prepare` again on our newly defined dataset class, we can see that the ids are now the fake 
"filenames" we generated at the top of the notebook, rather than sequential integers:

In [None]:
import hyrax

h = hyrax.Hyrax()
h.config["general"]["data_dir"] = "/fake/path/to/some/data"
h.config["data_set"]["name"] = "MyDataset"

dataset = h.prepare()

print("\nIDs:")
print(f"list(dataset.ids())[0:5] = {list(dataset.ids())[0:5]}")

### Adding Metadata

Now we are going to generate some fake metadata for our fake data. This will take the form of 
random ra/dec pairs for each fake object.

In [None]:
import astropy.units as u
from astropy.coordinates import SkyCoord

ras = rng.uniform(low=0.0, high=360.0, size=num_tensors) * u.deg
decs = rng.uniform(low=-90.0, high=90.0, size=num_tensors) * u.deg

To provide metadata, we give `HyraxDataset` an astropy table containing all of the metadata in the constructor for our class. We do this in `__init__` by passing the table to `super().__init__` as a second, optional argument (`metadata_table`).

Note the new function `_read_metadata()` which constructs this table. On a real dataset this function would
most likely call astropy's `Table.read` [high level interface](https://docs.astropy.org/en/latest/io/unified.html) to construct a table directly from your catalog.

As before we re-implement the entire class below with small modifications marked with comments:

In [None]:
from torch.utils.data import Dataset
from hyrax.data_sets import HyraxDataset


class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config: dict):
        self.filenames = MyDataset._list_filenames(config["general"]["data_dir"])
        metadata_table = MyDataset._read_metadata(config["general"]["data_dir"])
        super().__init__(config, metadata_table=metadata_table)

    @staticmethod
    def _read_metadata(path_to_data):
        """This is a pretend implementation so we don't use the path passed, which you might use
        to find your .csv/.fits/.tsv catalog file and call astropy's Table.read().

        We simply construct a table from our mock data"""
        from astropy.table import Table

        global ras, decs, filenames
        return Table({"object_id": filenames, "ra": ras, "dec": decs})

    # Unchanged from before below this comment ...
    def ids(self):
        for filename in self.filenames:
            yield filename

    def __getitem__(self, idx):
        return MyDataset._read_tensor(self.filenames[idx])

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _list_filenames(path_to_data):
        """This is a pretend implementation so we ignore path_to_data"""
        global filenames
        return filenames

    @staticmethod
    def _read_tensor(filename):
        """Pretend to read a particular tensor from the disk."""
        global random_data
        return random_data[filename]

Now that our dataset class supports metadata, we can access the metadata interface directly on the dataset object using the `metadata_fields` and `metadata` functions.

- `metadata_fields` lists the available fields. In our case only "ra" and "dec" are available, but 
this is only because that is what was defined in the cell above.
- `metadata` takes a list (or array) of indexes, and a list (or array) of valid fields. It returns a numpy rec-array of the selected metadata fields for the selected data indexes. It is essentially
equivalent to `metadata_table[indexes][fields].as_array()` where `metadata_table` is the original astropy table.
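
The row-then-field selection can be illustrated without Hyrax. The following is a NumPy-only sketch of the same indexing pattern on a structured array; the `catalog` values are made up, and this is an analogy for the shape of the result rather than Hyrax's implementation.

```python
import numpy as np

# Hypothetical catalog with three fields, standing in for the metadata table
catalog = np.array(
    [
        ("a", 10.0, -5.0),
        ("b", 20.0, 0.0),
        ("c", 30.0, 5.0),
        ("d", 40.0, 10.0),
        ("e", 50.0, 15.0),
    ],
    dtype=[("object_id", "U8"), ("ra", "f8"), ("dec", "f8")],
)

indexes = [1, 3, 4]
fields = ["ra"]

# Select rows first, then keep only the requested fields
selected = catalog[indexes][fields]
print(selected.dtype.names)
print(selected["ra"])
```

The result keeps one record per requested index, in the order the indexes were given, with only the requested fields present.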

In [None]:
import hyrax
from astropy.table import Table

h = hyrax.Hyrax()
h.config["general"]["data_dir"] = "/fake/path/to/some/data"
h.config["data_set"]["name"] = "MyDataset"

dataset = h.prepare()

print("\nMetadata field list:")
print(f"dataset.metadata_fields() = {dataset.metadata_fields()}")
print('Table(dataset.metadata([1, 3, 4], ["ra"])) =>')
Table(dataset.metadata([1, 3, 4], ["ra"]))

Now that we have a Dataset capable of 'ra' and 'dec' metadata, we can do a full analysis with Hyrax: `train`-ing the model, `infer`-ing the latent space, `umap`-ping the latent space to a 2D representation, and `visualize`-ing the result.

At time of writing, `visualize` requires "object_id", "ra" and "dec" fields to be defined in order to work at all. Note the appearance of those same fields in the visualizer table to the immediate right of the "x" and "y" values for the 2D projected latent space.

In [None]:
import hyrax

h = hyrax.Hyrax()
h.config["general"]["data_dir"] = "/fake/path/to/some/data"
h.config["data_set"]["name"] = "MyDataset"
h.config["model"]["name"] = "HyraxAutoencoder"

h.train()
h.infer()
h.umap()
h.visualize()