# Hyrax Getting Started

In this getting started notebook, we'll create a Hyrax instance, train a built-in model on CIFAR training data, and then use the trained model to run inference over the full dataset.

## Create a Hyrax instance

In [1]:
import hyrax

h = hyrax.Hyrax()

[2025-08-25 21:14:19,112 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml


## Update the configuration

In [2]:
h.config["model"]["name"] = "HyraxAutoencoder"

For this demo, we'll make one adjustment to the default configuration that the `hyrax` object was instantiated with.
The `.config` attribute of the Hyrax instance exposes every configuration value, and any of them can be modified.
Many values can be set, but here we change only the model to train, selecting the built-in `HyraxAutoencoder`.

## Train a model

In [3]:
h.train()

[2025-08-25 21:14:33,430 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
2025-08-25 21:14:33,503 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<hyrax.data_sets.hyr': 
	{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x70da50623980>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025-08-25 21:14:33,504 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<hyrax.data_sets.hyr': 
	{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x70da5046eea0>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025/08/25 21:14:33 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2025-08-25 21:14:33,874 hyrax.pytorch_ignite:INFO] Training model on device: cuda


  2%|1         | 1/59 [00:00<?, ?it/s]

[2025-08-25 21:16:09,062 hyrax.pytorch_ignite:INFO] Total training time: 95.19[s]
[2025-08-25 21:16:09,063 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250825-211424-train-0D0a/checkpoint_epoch_10.pt
[2025-08-25 21:16:09,063 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250825-211424-train-0D0a/checkpoint_9_loss=-126.9350.pt
2025/08/25 21:16:09 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/08/25 21:16:09 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-08-25 21:16:09,079 hyrax.verbs.train:INFO] Finished Training
[2025-08-25 21:16:09,364 hyrax.model_exporters:INFO] Exported model to ONNX format: /home/drew/code/hyrax/docs/pre_executed/results/20250825-211424-train-0D0a/example_model_opset_20.onnx


HyraxAutoencoder(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): GELU(approximate='none')
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU(approximate='none')
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): GELU(approximate='none')
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU(approximate='none')
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): GELU(approximate='none')
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Linear(in_features=1024, out_features=64, bias=True)
  )
  (dec_linear): Sequential(
    (0): Linear(in_features=64, out_features=1024, bias=True)
    (1): GELU(approximate='none')
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): GELU(approximate='none')
    (2): 

Training output is stored in a time-stamped directory under `./results/`.
By default, a copy of the final configuration used in training is saved there as `runtime_config.toml`.
To train again with the same configuration, point Hyrax at this `runtime_config.toml` file.

If running in another notebook, instantiate a Hyrax object like so:
```
import hyrax

new_hyrax_instance = hyrax.Hyrax(config_file='./results/<timestamped_directory>/runtime_config.toml')
```

Or from the command line:
```
>> hyrax train --runtime-config ./results/<timestamped_directory>/runtime_config.toml
```

Note that here we train on a modest amount of CIFAR data, but Hyrax has demonstrated that it can scale to training sets with more than 1M samples.
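The training log gives enough information to estimate how much data was actually used. The progress bar reports 59 iterations per epoch, and the data loader log shows `'batch_size': 512`. Assuming each progress-bar step is one training batch and the final partial batch is kept (PyTorch's `DataLoader` default, `drop_last=False`), the size of the training split is bounded as in this sketch; `sample_count_bounds` is a hypothetical helper, not part of Hyrax.

```python
def sample_count_bounds(n_batches: int, batch_size: int) -> tuple[int, int]:
    """Range of dataset sizes that yield exactly n_batches per epoch.

    Assumes the last, possibly partial, batch is kept (drop_last=False).
    """
    low = (n_batches - 1) * batch_size + 1  # one sample spills into the last batch
    high = n_batches * batch_size  # every batch is full
    return low, high


print(sample_count_bounds(59, 512))  # (29697, 30208)
```

So the training split held roughly 30,000 samples per epoch under these assumptions.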

## Run inference

In [4]:
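# Put 100% of the data in the test split so inference covers all of it
h.config["data_set"]["test_size"] = 1.0
h.config["data_set"]["train_size"] = 0.0
h.config["data_set"]["validate_size"] = 0.0
# Batch size used by the data loader during inference
h.config["data_loader"]["batch_size"] = 128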

h.infer()

[2025-08-25 21:16:17,243 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
[2025-08-25 21:16:17,244 hyrax.verbs.infer:INFO] data set has length 50000
2025-08-25 21:16:17,246 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<hyrax.data_sets.hyr': 
	{'sampler': None, 'batch_size': 128, 'shuffle': False, 'pin_memory': True}
[2025-08-25 21:16:17,263 hyrax.verbs.infer:INFO] Saving inference results at: /home/drew/code/hyrax/docs/pre_executed/results/20250825-211609-infer-SY49
[2025-08-25 21:16:17,602 hyrax.pytorch_ignite:INFO] Evaluating model on device: cuda
[2025-08-25 21:16:17,605 hyrax.pytorch_ignite:INFO] Total epochs: 1


  0%|          | 1/391 [00:00<?, ?it/s]

[2025-08-25 21:16:30,310 hyrax.pytorch_ignite:INFO] Total evaluation time: 12.71[s]
[2025-08-25 21:16:30,434 hyrax.verbs.infer:INFO] Inference Complete.


<hyrax.data_sets.inference_dataset.InferenceDataSet at 0x70db7a0bafc0>

Once a model has been trained, its weights file can be used to run inference.
By default, `infer` looks for the most recent model weights file.
A specific weights file can be selected with `h.config['infer']['model_weights_file'] = <path_to_model_weights_file>`.
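To pick a weights file by hand, it helps to locate the most recent training run first. The sketch below assumes, based only on the directory names in the logs above (for example `20250825-211424-train-0D0a`), that training runs live in `<timestamp>-train-<id>` directories whose leading `YYYYMMDD-HHMMSS` timestamp sorts chronologically. `latest_train_dir` is a hypothetical helper, not a Hyrax API, and this is not necessarily how Hyrax itself chooses the default weights file.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def latest_train_dir(results_dir: str | Path) -> Path | None:
    """Return the newest '<timestamp>-train-<id>' directory, or None if there is none.

    Assumes names start with a YYYYMMDD-HHMMSS timestamp, so sorting by
    name is the same as sorting by time.
    """
    candidates = [p for p in Path(results_dir).glob("*-train-*") if p.is_dir()]
    return max(candidates, key=lambda p: p.name, default=None)


# Demo on a throwaway directory tree that mimics the names in the logs.
with tempfile.TemporaryDirectory() as tmp:
    for name in [
        "20250824-090000-train-abcd",  # hypothetical older run
        "20250825-211424-train-0D0a",
        "20250825-211609-infer-SY49",  # inference runs are ignored
    ]:
        (Path(tmp) / name).mkdir()
    newest = latest_train_dir(tmp)
    print(newest.name)  # 20250825-211424-train-0D0a
```

The returned directory can then be inspected for the weights file to assign to `h.config['infer']['model_weights_file']`.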

Here we use the most recently trained weights, update the data set splits so that 100% of the data is used for inference, and set the data loader batch size to 128.

With the configuration updated, we run inference by calling `h.infer()`.
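The inference log reports a data set of length 50000, and the batch size is set to 128, which accounts for the `391` total in the progress bar. This sketch reproduces that count with ceiling division, again assuming the last partial batch is kept; `batch_layout` is a hypothetical helper for illustration.

```python
def batch_layout(n_samples: int, batch_size: int) -> tuple[int, int]:
    """Return (number of batches, size of the final batch), keeping a partial last batch."""
    n_batches = -(-n_samples // batch_size)  # ceiling division using integers only
    last = n_samples - (n_batches - 1) * batch_size
    return n_batches, last


# 50,000 samples (from the "data set has length" log line), batch_size = 128
print(batch_layout(50_000, 128))  # (391, 80)
```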

Inference results are saved in the directory reported by the `Saving inference results at:` log line.
The default output format is batched `.npy` files.
In addition, a ChromaDB vector database is populated with the inference results to enable efficient similarity search.
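A vector database answers queries of the form "which stored vectors are most similar to this one?". The sketch below illustrates that idea with a brute-force cosine-similarity search in plain Python; it does not use ChromaDB's API, and the 3-dimensional vectors and object ids are made up. The real `HyraxAutoencoder` latent vectors are 64-dimensional, per the encoder's final `Linear(in_features=1024, out_features=64)` layer.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between a and b (both assumed non-zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def nearest(query: list[float], vectors: dict[str, list[float]], k: int = 2) -> list[str]:
    """Return the ids of the k stored vectors most similar to query (brute force)."""
    ranked = sorted(vectors, key=lambda vid: cosine_similarity(query, vectors[vid]), reverse=True)
    return ranked[:k]


# Made-up 3-D "latent vectors" keyed by hypothetical object ids.
latents = {
    "obj_0": [1.0, 0.0, 0.0],
    "obj_1": [0.9, 0.1, 0.0],
    "obj_2": [0.0, 1.0, 0.0],
    "obj_3": [0.0, 0.0, 1.0],
}
print(nearest([1.0, 0.0, 0.0], latents))  # ['obj_0', 'obj_1']
```

Brute force compares the query against every stored vector, which becomes slow for large result sets; that cost is what an index in a vector database is meant to avoid.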