## Anatomy of a Neural Tree

### Introduction

A `NeuralTree` is a subclass of `torch.nn.Module` whose submodules we call _nodes_.
Each node returns a `dataclass`, so users can add or remove output fields
as needed for their specific use case.
A `NeuralTree` can be used to represent a wide variety of neural network architectures,
from simple feedforward networks to complex architectures like multi-modal, multi-task transformers.
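Dataclass outputs make this kind of extension easy. Below is a minimal stdlib-only sketch. `ToyRootOutput`, `ToyRootOutputWithMask`, and `feature_count` are hypothetical names used for illustration, not `cortex` classes, and nested lists stand in for tensors.

```python
from dataclasses import dataclass, field, fields


@dataclass
class ToyRootOutput:
    # Hypothetical node output: one feature vector per sequence in the batch.
    root_features: list


@dataclass
class ToyRootOutputWithMask(ToyRootOutput):
    # Subclassing adds a field without touching code that only reads
    # `root_features`.
    padding_mask: list = field(default_factory=list)


def feature_count(output: ToyRootOutput) -> int:
    # Downstream code that depends only on the base field keeps working.
    return len(output.root_features)


base = ToyRootOutput(root_features=[[0.1, 0.2], [0.3, 0.4]])
extended = ToyRootOutputWithMask(
    root_features=[[0.1, 0.2], [0.3, 0.4]],
    padding_mask=[True, False],
)
print(feature_count(base), feature_count(extended))  # 2 2
print([f.name for f in fields(extended)])  # ['root_features', 'padding_mask']
```

Because the subclass is still a `ToyRootOutput`, code written against the base fields does not need to change when a field is added.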

A `NeuralTree` has four types of nodes:
- `RootNode`: this type of node takes in one specific input modality and embeds the input into a continuous feature space.
- `TrunkNode`: this type of node takes in the outputs of one or more root nodes and aggregates them into a shared feature space.
- `BranchNode`: this type of node takes in the outputs of a trunk node and learns task-specific features.
- `LeafNode`: this type of node takes in the outputs of a branch node and transforms them into a task-specific output (e.g., classifier logits or the predictive mean of a regressor).
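The data flow between the four node types can be sketched with plain Python functions. This is a toy illustration under loose assumptions: every function name, the `VOCAB` lookup table, and the fixed numbers are hypothetical stand-ins for learned modules, and lists stand in for batched tensors.

```python
# Hypothetical per-residue values standing in for a learned embedding table.
VOCAB = {"A": 0.1, "E": 0.2, "I": 0.3, "K": 0.4, "M": 0.5, "Q": 0.6, "V": 0.7}


def toy_root(seq: str) -> list[float]:
    # RootNode role: embed one input modality (space-separated residues)
    # into continuous values, one number per token.
    return [VOCAB[tok] for tok in seq.split()]


def toy_trunk(*root_outputs: list[float]) -> list[float]:
    # TrunkNode role: aggregate one or more root outputs into a shared
    # space, here by elementwise summation (inputs must align in length).
    return [sum(vals) for vals in zip(*root_outputs)]


def toy_branch(trunk_features: list[float], scale: float) -> list[float]:
    # BranchNode role: compute task-specific features; a fixed scale
    # stands in for learned weights.
    return [scale * x for x in trunk_features]


def toy_leaf(branch_features: list[float]) -> list[float]:
    # LeafNode role: map branch features to a task output, here two
    # "logits" derived from the mean feature.
    mean = sum(branch_features) / len(branch_features)
    return [mean, -mean]


root_out = toy_root("M K K I A I A")
trunk_out = toy_trunk(root_out)
branch_out = toy_branch(trunk_out, scale=2.0)
logits = toy_leaf(branch_out)
print(len(root_out), len(logits))  # 7 2
```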

In this tutorial we will manually construct a simple ResNet for protein sequence classification. 

### Hyperparameters

```python
max_context_len = 16  # maximum number of input tokens
feature_dim = 32  # dimension of output features
kernel_size = 3  # size of convolutional kernels
```

### A Root Node

```python
import numpy as np
from cortex.model.root import Conv1dRoot
from cortex.tokenization import ProteinSequenceTokenizerFast
from cortex.transforms import HuggingFaceTokenizerTransform

tokenizer = ProteinSequenceTokenizerFast()
root_node = Conv1dRoot(
    tokenizer_transform=HuggingFaceTokenizerTransform(tokenizer),
    max_len=max_context_len,
    embed_dim=feature_dim,  # dimension of initial token embeddings
    channel_dim=feature_dim,  # dimension of intermediate features
    out_dim=feature_dim,
    num_blocks=2,  # number of residual blocks
    kernel_size=kernel_size,
)

# two toy protein sequences, with residues separated by spaces
protein_seqs = np.array(["M K K I A I A", "E V Q I A I A E"])
root_output = root_node(seq_array=protein_seqs)

print(root_output.root_features.shape)
```

### A Trunk Node

A `SumTrunk` takes the features of one or more root nodes and sums them together.
If the root nodes have different dimensions, the `SumTrunk` will use a linear layer to project them to the same dimension before summing them together.
In this example we have only one root node, and its output dimension already equals the trunk's output dimension, so the trunk is equivalent to an identity function.

In a later tutorial we will see how to use a `SumTrunk` to combine the outputs of multiple root nodes to make a multi-modal model.
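A minimal sketch of the project-then-sum behavior described above, assuming the projection is a plain linear map applied only when a root's dimension differs from the output dimension. `project` and `sum_trunk` are hypothetical helpers rather than the `SumTrunk` implementation, the weights are fixed instead of learned, and each root contributes a single feature vector instead of a batch of sequences.

```python
from typing import Optional


def project(vec: list[float], weight: list[list[float]]) -> list[float]:
    # Linear map with weight of shape (out_dim, in_dim):
    # out[i] = sum_j weight[i][j] * vec[j]
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, vec)) for row in weight]


def sum_trunk(
    root_features: list[list[float]],
    out_dim: int,
    weights: list[Optional[list[list[float]]]],
) -> list[float]:
    # Project each root's features to out_dim only when its size differs,
    # then sum the aligned vectors elementwise.
    projected = []
    for feats, weight in zip(root_features, weights):
        if len(feats) != out_dim:
            feats = project(feats, weight)
        projected.append(feats)
    return [sum(vals) for vals in zip(*projected)]


# One root already has dimension 2; the other has dimension 3 and is projected.
root_a = [1.0, 2.0]
root_b = [1.0, 0.0, 1.0]
w_b = [[1.0, 0.0, 0.0],
       [0.0, 0.0, 1.0]]
print(sum_trunk([root_a, root_b], out_dim=2, weights=[None, w_b]))  # [2.0, 3.0]

# A single root whose size already equals out_dim passes through unchanged.
print(sum_trunk([root_a], out_dim=2, weights=[None]))  # [1.0, 2.0]
```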

```python
from cortex.model.trunk import SumTrunk

num_roots = 1
trunk_node = SumTrunk(in_dims=[feature_dim] * num_roots, out_dim=feature_dim)
trunk_output = trunk_node(root_output)

print(trunk_output.trunk_features.shape)
```


### A Branch Node

A `BranchNode` learns features for a specific task (or group of tasks).
In this example we have only one task, so we set `num_blocks=0`, which makes the branch node equivalent to an identity function.

In a later tutorial we will see how to use a `BranchNode` to learn different features for different tasks in a multi-task model.
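Why does `num_blocks=0` give an identity map? A stack of zero residual blocks applies no transformation at all. The sketch below uses hypothetical helpers `make_residual_block` and `apply_blocks`, with a fixed scale standing in for a learned convolution; it does not model any other layers `Conv1dBranch` may contain.

```python
from typing import Callable

Block = Callable[[list[float]], list[float]]


def make_residual_block(scale: float) -> Block:
    # Hypothetical residual block computing x + f(x), where f is a fixed
    # elementwise scale standing in for a learned convolution.
    def block(x: list[float]) -> list[float]:
        return [xi + scale * xi for xi in x]
    return block


def apply_blocks(x: list[float], num_blocks: int) -> list[float]:
    # Apply num_blocks residual blocks in sequence; with num_blocks=0 the
    # loop body never runs and the input is returned unchanged.
    blocks = [make_residual_block(0.5) for _ in range(num_blocks)]
    for block in blocks:
        x = block(x)
    return x


features = [1.0, -2.0, 4.0]
print(apply_blocks(features, num_blocks=0))  # [1.0, -2.0, 4.0]
print(apply_blocks(features, num_blocks=2))  # [2.25, -4.5, 9.0]
```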

```python
from cortex.model.branch import Conv1dBranch

branch_node = Conv1dBranch(
    in_dim=feature_dim,
    channel_dim=feature_dim,
    out_dim=feature_dim,
    num_blocks=0,  # no residual blocks, since there is only one task
    kernel_size=kernel_size,
)
branch_output = branch_node(trunk_output)
print(branch_output.branch_features.shape)
```

### A Leaf Node

```python
from cortex.model.leaf import ClassifierLeaf

leaf_node = ClassifierLeaf(
    in_dim=feature_dim,
    num_classes=2,
    branch_key="protein_features_0",  # used to attach the leaf to a particular branch
)
leaf_output = leaf_node(branch_output)
print(leaf_output.logits.shape)
```

### A Neural Tree

We've defined all the nodes we need and shown how to manually pass data through them. 
We can accomplish the same thing by using a `NeuralTree` to manage the nodes for us.


```python
from torch import nn
from cortex.model.tree import SequenceModelTree

tree = SequenceModelTree(
    root_nodes=nn.ModuleDict({"protein_seq": root_node}),  # {"root_key": root_node}
    trunk_node=trunk_node,
    branch_nodes=nn.ModuleDict({"protein_features_0": branch_node}),  # {"branch_key": branch_node}
    leaf_nodes=nn.ModuleDict({"protein_property_0": leaf_node}),  # {"leaf_key": leaf_node}
)

tree_input = {"protein_seq": {"seq_array": protein_seqs}}  # {"root_key": {**root_kwargs}}
tree_output = tree(tree_input)

print(tree_output.leaf_outputs["protein_property_0"].logits.shape)
```
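The dictionaries passed to `SequenceModelTree` encode the wiring: root inputs are matched by root key, and each leaf finds its branch through `branch_key`. The stdlib-only sketch below mimics that routing with hypothetical `ToyTree` and `ToyLeaf` classes. It is an illustration under simplifying assumptions, not how `SequenceModelTree` is implemented, and plain lists stand in for tensors and node output dataclasses.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyLeaf:
    # Hypothetical leaf: `branch_key` names the branch it reads from.
    branch_key: str
    fn: Callable[[list[float]], list[float]]


@dataclass
class ToyTree:
    # Hypothetical tree mirroring the dict-of-nodes layout used above.
    root_nodes: dict[str, Callable[..., list[float]]]
    trunk_node: Callable[..., list[float]]
    branch_nodes: dict[str, Callable[[list[float]], list[float]]]
    leaf_nodes: dict[str, ToyLeaf]

    def __call__(self, tree_input: dict[str, dict]) -> dict[str, list[float]]:
        # 1. Each root receives the kwargs stored under its own key.
        root_outputs = [self.root_nodes[key](**kwargs) for key, kwargs in tree_input.items()]
        # 2. The single trunk aggregates all root outputs.
        trunk_output = self.trunk_node(*root_outputs)
        # 3. Every branch reads the shared trunk output.
        branch_outputs = {key: branch(trunk_output) for key, branch in self.branch_nodes.items()}
        # 4. Each leaf reads only the branch named by its branch_key.
        return {key: leaf.fn(branch_outputs[leaf.branch_key]) for key, leaf in self.leaf_nodes.items()}


toy_tree = ToyTree(
    # toy root: one number per sequence (its token count)
    root_nodes={"protein_seq": lambda seq_array: [float(len(s.split())) for s in seq_array]},
    trunk_node=lambda *outs: [sum(vals) for vals in zip(*outs)],
    branch_nodes={"protein_features_0": lambda x: x},
    leaf_nodes={"protein_property_0": ToyLeaf("protein_features_0", lambda x: [v / 10 for v in x])},
)
toy_out = toy_tree({"protein_seq": {"seq_array": ["M K K I A I A", "E V Q I A I A E"]}})
print(toy_out)  # {'protein_property_0': [0.7, 0.8]}
```

Keeping the wiring in dictionaries means that adding a task amounts to adding a branch entry and a leaf whose `branch_key` points at it.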

### Summary

What have we accomplished? This is certainly not the simplest way to construct this particular model, and splitting it into separate nodes may seem awkward and overly complicated.
As we will see in the next tutorial, the real power of these abstractions shows when we want to define more complex models or take an existing model and extend its behavior.