## Defining a Task

In the last tutorial we learned the basic structure of a `NeuralTree`. In this section we will explore `Task` objects, which define the interface
between a `NeuralTree` and the datasets we will use to train it.

### A Dataset

Every task starts with a dataset. In this example we will use the green fluorescent protein (GFP) fluorescence dataset from the TAPE benchmark (TODO add link).

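```python
from cortex.data.dataset import TAPEFluorescenceDataset

# Download the data (if needed) into ./.cache and load the training split
dataset = TAPEFluorescenceDataset(
    root='./.cache',
    download=True,
    train=True,
)
dataset[0]  # inspect the first example
```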
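
`dataset[0]` works because `TAPEFluorescenceDataset` is indexed like a map-style dataset. As a minimal sketch of that protocol (the `__len__`/`__getitem__` pair that PyTorch's `DataLoader` relies on), here is a hypothetical toy stand-in. Its column names are borrowed from the task configuration used later in this tutorial and its values are made up, so the real dataset's schema may differ.

```python
from typing import Any, Dict, List


class ToyFluorescenceDataset:
    """Hypothetical stand-in for a map-style dataset: sized and indexable."""

    def __init__(self, rows: List[Dict[str, Any]]) -> None:
        self._rows = rows

    def __len__(self) -> int:
        # A DataLoader uses the length to decide which indices to sample.
        return len(self._rows)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Each item is one example: an input sequence and its regression target.
        return self._rows[index]


# Toy rows: column names mirror the task config, values are invented.
toy = ToyFluorescenceDataset([
    {"tokenized_seq": "M K T A Y", "log_fluorescence": 3.1},
    {"tokenized_seq": "M K V A Y", "log_fluorescence": 1.7},
])
print(len(toy), toy[0])
```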

### A Task Data Module

The `cortex` package uses the `lightning` package to handle data loading and distributed training.
`TaskDataModule` subclasses `lightning.LightningDataModule`, so it exposes the standard dataloader methods such as `train_dataloader()`.

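```python
from cortex.data.data_module import TaskDataModule
from omegaconf import DictConfig

# Config for the dataset. '???' is OmegaConf's marker for a mandatory value
# that is not set yet; the data module presumably fills in `train` per split.
dataset_cfg = DictConfig(
    {
        '_target_': 'cortex.data.dataset.TAPEFluorescenceDataset',
        'root': './.cache',
        'download': True,
        'train': '???',
    }
)

data_module = TaskDataModule(
    batch_size=2,
    dataset_config=dataset_cfg,
)

# Pull a single batch from the training dataloader
train_loader = data_module.train_dataloader()
batch = next(iter(train_loader))
print(batch)
```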
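
The dataset config leaves `train` set to `'???'`, OmegaConf's marker for a mandatory value that has not been filled in. That suggests `TaskDataModule` supplies `train=True` or `train=False` itself when it builds each split; this is an inference from the config, not something this tutorial shows. The sketch below mimics the pattern with plain dictionaries. `build_split`, `batched`, and `ToyDataset` are hypothetical helpers, not `cortex` or `lightning` APIs.

```python
from typing import Any, Dict, Iterator, List

MISSING = "???"  # OmegaConf's spelling of a mandatory, not-yet-set value


class ToyDataset:
    """Hypothetical dataset that returns different rows per split."""

    def __init__(self, train: bool) -> None:
        self.rows = [{"idx": i, "train": train} for i in range(5 if train else 2)]

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, i: int) -> Dict[str, Any]:
        return self.rows[i]


def build_split(cfg: Dict[str, Any], train: bool) -> ToyDataset:
    # Fill the placeholder for this split; leave every other key untouched.
    kwargs = {k: (train if k == "train" and v == MISSING else v) for k, v in cfg.items()}
    if MISSING in kwargs.values():
        raise ValueError("config still has unset mandatory values")
    return ToyDataset(**kwargs)


def batched(dataset: ToyDataset, batch_size: int) -> Iterator[List[Dict[str, Any]]]:
    # Sequential batching; the last batch may be smaller than batch_size.
    for start in range(0, len(dataset), batch_size):
        yield [dataset[i] for i in range(start, min(start + batch_size, len(dataset)))]


cfg = {"train": MISSING}
train_batches = list(batched(build_split(cfg, train=True), batch_size=2))
print([len(b) for b in train_batches])  # [2, 2, 1]
```

A real `DataLoader` would also collate each batch into tensors and typically shuffles the training split; the sketch keeps batches as lists of dicts to stay dependency-free.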

### A Task object

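A task object in `cortex` determines how a batch of data from a dataloader is formatted and passed to a `NeuralTree` during training: `input_map` routes dataset columns to root nodes, while `outcome_cols` and `leaf_key` select the target columns and the leaf node that predicts them.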

```python
from cortex.task import RegressionTask

task = RegressionTask(
    data_module=data_module,
    input_map={"protein_seq": ["tokenized_seq"]},  # {root_key: [input_key]}
    outcome_cols=["log_fluorescence"],  # [*target_keys]
    leaf_key="log_fluorescence_0",  # name of the leaf node that predicts the targets
)

formatted_batch = task.format_batch(batch)
print(formatted_batch)
```
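
Conceptually, `format_batch` regroups the flat columns of a batch: `input_map` says which columns feed which root node, and `outcome_cols` says which columns become targets for the leaf named by `leaf_key`. The sketch below shows one way to do that regrouping with plain dictionaries. The top-level `root_inputs` and `leaf_targets` keys match how `formatted_batch` is used later in this tutorial, but the inner layout and the `targets` keyword are assumptions; print `formatted_batch` to see what `cortex` actually produces.

```python
from typing import Any, Dict, List


def format_batch_sketch(
    batch: Dict[str, List[Any]],
    input_map: Dict[str, List[str]],
    outcome_cols: List[str],
    leaf_key: str,
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical regrouping of a flat column batch into root inputs and leaf targets."""
    # One entry per root node, holding the columns listed for it in input_map.
    root_inputs = {
        root: {col: batch[col] for col in cols} for root, cols in input_map.items()
    }
    # One row per example, one value per outcome column.
    targets = [list(row) for row in zip(*(batch[col] for col in outcome_cols))]
    return {"root_inputs": root_inputs, "leaf_targets": {leaf_key: {"targets": targets}}}


toy_batch = {"tokenized_seq": ["M K T", "M K V"], "log_fluorescence": [3.1, 1.7]}
out = format_batch_sketch(
    toy_batch,
    input_map={"protein_seq": ["tokenized_seq"]},
    outcome_cols=["log_fluorescence"],
    leaf_key="log_fluorescence_0",
)
print(out["leaf_targets"]["log_fluorescence_0"]["targets"])  # [[3.1], [1.7]]
```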

## Usage

Now we will instantiate a `NeuralTree` as in the last tutorial, but this time we will use Hydra to simplify the instantiation.

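```python
import hydra
from omegaconf import OmegaConf

# Load the config named "2_defining_a_task" from the ./hydra directory
with hydra.initialize(config_path="./hydra"):
    cfg = hydra.compose(config_name="2_defining_a_task")
    # Disable struct mode so new keys can be added to the config later
    OmegaConf.set_struct(cfg, False)

tree = hydra.utils.instantiate(cfg.tree)  # build the NeuralTree from its `_target_`
tree.build_tree(cfg)  # add the nodes described in the config
tree
```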
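
`hydra.utils.instantiate` builds an object from any config node that has a `_target_` key, which is also how `dataset_cfg` names `TAPEFluorescenceDataset`. The function below is a simplified re-implementation of that idea using only `importlib`: it imports the dotted path and calls it with the remaining keys, recursing into nested nodes first. It omits much of what Hydra supports (for example `_partial_`, `_args_`, and interpolation), so treat `instantiate_sketch` as an illustration, not a replacement.

```python
import importlib
from typing import Any


def instantiate_sketch(node: Any) -> Any:
    """Simplified take on Hydra-style instantiation via a `_target_` key."""
    if isinstance(node, list):
        return [instantiate_sketch(item) for item in node]
    if not isinstance(node, dict):
        return node  # plain values pass through unchanged
    # Build nested nodes first so children exist before the parent is called.
    kwargs = {k: instantiate_sketch(v) for k, v in node.items() if k != "_target_"}
    if "_target_" not in node:
        return kwargs
    module_path, _, attr = node["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    return target(**kwargs)


cfg = {
    "_target_": "datetime.timedelta",
    "days": 1,
    "hours": 12,
}
print(instantiate_sketch(cfg))  # 1 day, 12:00:00
```

Keeping the class path in the config, rather than in code, is what lets one script build different trees or datasets just by swapping YAML files.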

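```python
# Forward pass: the root inputs flow through the tree to every leaf
tree_output = tree(formatted_batch["root_inputs"])
# `loc`: location parameter of the fluorescence leaf's predicted distribution
tree_output.leaf_outputs["log_fluorescence_0"].loc
```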

### Computing a task loss

```python
leaf_key = "log_fluorescence_0"
leaf_node = tree.leaf_nodes[leaf_key]

# Compare the leaf's outputs with the targets from the formatted batch
loss = leaf_node.loss(
    leaf_outputs=tree_output.leaf_outputs[leaf_key],
    root_outputs=tree_output.root_outputs["protein_seq"],
    **formatted_batch["leaf_targets"][leaf_key]
)
print(loss)
```
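
The leaf output exposes a `.loc`, which suggests the regression leaf predicts a distribution rather than a single number. If that distribution is Gaussian with a location and a scale, a natural training loss is the Gaussian negative log-likelihood. The sketch below computes it by hand for made-up values; whether `leaf_node.loss` uses exactly this objective, or adds other terms, is an assumption this tutorial does not confirm.

```python
import math
from typing import Sequence


def gaussian_nll(loc: Sequence[float], scale: Sequence[float], targets: Sequence[float]) -> float:
    """Mean negative log-likelihood of targets under independent Normal(loc, scale)."""
    if not (len(loc) == len(scale) == len(targets)):
        raise ValueError("loc, scale and targets must have the same length")
    total = 0.0
    for mu, sigma, y in zip(loc, scale, targets):
        if sigma <= 0:
            raise ValueError("scale must be positive")
        # -log N(y | mu, sigma^2) = log(sigma) + 0.5*log(2*pi) + (y - mu)^2 / (2*sigma^2)
        total += math.log(sigma) + 0.5 * math.log(2 * math.pi) + (y - mu) ** 2 / (2 * sigma**2)
    return total / len(targets)


# Made-up predictions that exactly match their targets, for illustration only.
print(round(gaussian_nll([3.0, 1.5], [1.0, 1.0], [3.0, 1.5]), 4))  # 0.9189
```

With the scale fixed at 1 this is half the squared error plus a constant, so it trains like mean-squared error; learning the scale additionally lets the model express per-example uncertainty.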

### Evaluating task output

```python
# Compute evaluation metrics for the leaf's predictions on this batch
leaf_node.evaluate(
    outputs=tree_output.leaf_outputs[leaf_key],
    **formatted_batch["leaf_targets"][leaf_key]
)
```
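
`leaf_node.evaluate` compares predictions with targets and returns summary metrics. Which metrics it reports is not shown here; as an assumption, the sketch below computes two that are common for this kind of task: root-mean-square error and Spearman's rank correlation (the metric the TAPE benchmark reports for fluorescence). Tied values receive their average rank, and Spearman's ρ is then the Pearson correlation of the ranks.

```python
import math
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks; tied values share the mean of the ranks they span."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2 + 1  # mean of positions start..end, 1-based
        for k in range(start, end + 1):
            ranks[order[k]] = mean_rank
        start = end + 1
    return ranks


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    # Undefined (division by zero) if either input is constant.
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def spearman(pred: Sequence[float], target: Sequence[float]) -> float:
    return pearson(average_ranks(pred), average_ranks(target))


def rmse(pred: Sequence[float], target: Sequence[float]) -> float:
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target))


# Made-up values: predictions keep the target ordering but are offset.
preds, targets = [1.0, 2.0, 3.5, 4.0], [1.5, 2.5, 3.0, 4.5]
print(round(spearman(preds, targets), 6), rmse(preds, targets))  # 1.0 0.5
```

Rank correlation only cares about ordering, which is why it suits screening-style tasks where the goal is to rank variants; RMSE complements it by measuring how far off the predicted values are.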