## Training a Neural Tree

So far we've learned the basic structure of a `NeuralTree` and seen how task objects are used to interface with datasets.
Now we'll see how a `NeuralTree` is trained.


### Load the configuration file

In [None]:
from omegaconf import OmegaConf
import hydra

with hydra.initialize(config_path="./hydra"):
    cfg = hydra.compose(config_name="3_training_a_neural_tree")
    OmegaConf.set_struct(cfg, False)
print(OmegaConf.to_yaml(cfg))

### Setup

Set the random seed. The model itself is instantiated in each scenario below.

In [None]:
import lightning as L

# set random seed
L.seed_everything(seed=cfg.random_seed, workers=True)
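
`L.seed_everything` seeds the Python, NumPy, and PyTorch random number generators so that repeated runs start from the same state. As a rough illustration of why this matters, the sketch below uses only Python's standard `random` module rather than Lightning. `draw_weights` is a hypothetical stand-in for any randomized initialization and is not part of `cortex` or Lightning.

```python
import random


def draw_weights(seed: int, n: int = 3) -> list[float]:
    """Hypothetical stand-in for a randomized weight initialization."""
    rng = random.Random(seed)  # independent generator, seeded explicitly
    return [rng.random() for _ in range(n)]


# The same seed reproduces the same draws; a different seed does not.
first_run = draw_weights(seed=0)
second_run = draw_weights(seed=0)
different_seed = draw_weights(seed=1)

print(first_run == second_run)
print(first_run == different_seed)
```

Fixing the seed once, before the model is built, makes reruns of the notebook comparable.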



### Scenario 1: Training data is read from disk

In this scenario, the `log_fluorescence` task data is read from disk using `cortex.data.dataset.TAPEFluorescenceDataset`.
Passing `skip_task_setup=False` to the `build_tree` method loads the data while the tree is built.

In [None]:
# instantiate model
model = hydra.utils.instantiate(cfg.tree)
model.build_tree(cfg, skip_task_setup=False)

### Scenario 2: Training data is passed as NumPy arrays at runtime

In this scenario, the `log_fluorescence` task data is held in memory at runtime as a generic `cortex.data.dataset.NumpyDataset` object.
To load it, pass `skip_task_setup=True` to the `build_tree` method so that no data is read while the tree is built,
then call `task.data_module.setup` manually, passing the arrays through the `dataset_kwargs` keyword argument.

In [None]:
from cortex.data.dataset import TAPEFluorescenceDataset
import pandas as pd
from omegaconf import DictConfig

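# swap the on-disk dataset for the generic in-memory NumpyDataset
# ("???" is OmegaConf's marker for a mandatory value that has not been set yet)
cfg.tasks.protein_property.log_fluorescence.data_module.dataset_config = DictConfig(
    {"_target_": "cortex.data.dataset.NumpyDataset", "train": "???"}
)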
model = hydra.utils.instantiate(cfg.tree)
model.build_tree(cfg, skip_task_setup=True)

root = "./.cache"
train_dataset = TAPEFluorescenceDataset(root=root, train=True, download=True)
test_dataset = TAPEFluorescenceDataset(root=root, train=False, download=False)

# pool both splits into a single DataFrame (note: `_data` is a private attribute)
src_df = pd.concat([train_dataset._data, test_dataset._data], ignore_index=True)

task_setup_kwargs = {
    # task key
    "log_fluorescence": {
        # dataset kwarg
        "data": {
            "tokenized_seq": src_df["tokenized_seq"].values,
            "log_fluorescence": src_df["log_fluorescence"].values,
        }
    }
}

for task_key, task_obj in model.task_dict.items():
    task_obj.data_module.setup(stage="test", dataset_kwargs=task_setup_kwargs[task_key])
    task_obj.data_module.setup(stage="fit", dataset_kwargs=task_setup_kwargs[task_key])
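
The pattern used here, building the object first and handing it the data later through `setup`, can be sketched in isolation. The example below is a minimal illustration in plain Python and does not reproduce the `cortex` implementation: `ToyDataModule` is a hypothetical class, and the sequences and values are toy placeholders, not real measurements. It assumes only what the cell above shows, namely that `setup` takes a `stage` and a `dataset_kwargs` dictionary containing a `"data"` mapping of column names to arrays.

```python
from typing import Any, Optional


class ToyDataModule:
    """Hypothetical data module that receives its data in setup(), not __init__()."""

    def __init__(self) -> None:
        # One list of (input, target) rows per stage.
        self.datasets: dict[str, list[tuple[Any, ...]]] = {}

    def setup(self, stage: str, dataset_kwargs: Optional[dict] = None) -> None:
        data = (dataset_kwargs or {}).get("data")
        if data is None:
            raise ValueError("in-memory data must be passed via dataset_kwargs")
        columns = list(data.values())
        lengths = {len(col) for col in columns}
        if len(lengths) != 1:
            raise ValueError(f"columns differ in length: {sorted(lengths)}")
        # Zip the columns into rows and store them under the requested stage.
        self.datasets[stage] = list(zip(*columns))


# Same nesting as above: task key -> dataset kwarg -> column name -> values.
task_dict = {"log_fluorescence": ToyDataModule()}
task_setup_kwargs = {
    "log_fluorescence": {
        "data": {
            "tokenized_seq": ["M S K", "M S R"],  # toy placeholder sequences
            "log_fluorescence": [3.8, 1.3],  # toy placeholder targets
        }
    }
}

for task_key, data_module in task_dict.items():
    for stage in ("test", "fit"):
        data_module.setup(stage=stage, dataset_kwargs=task_setup_kwargs[task_key])

print(task_dict["log_fluorescence"].datasets["fit"])
```

Keying the keyword arguments by task name lets a single loop serve a tree with several tasks, each receiving only its own arrays.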

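### Fit the model

With the task data set up by either scenario, instantiate the trainer from the configuration and fit the model.

In [None]:
# instantiate the trainer from the config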
trainer = hydra.utils.instantiate(cfg.trainer)

trainer.fit(
    model,
    train_dataloaders=model.get_dataloader(split="train"),
    val_dataloaders=model.get_dataloader(split="val"),
)