# State tensor tracing/inference

In some scenarios your model returns large tensors as output, and you need to pass those outputs back into the model to continue the inference. A good example of that is a decoder. In a scenario like that, moving tensors from HBM to host memory and vice versa is slow, and sometimes you will not have enough memory to allocate everything.

State tensors allow you to keep a memory space reserved for a set of tensors. To update their values, you use an API provided by the traced model, as shown later in this notebook.

You can see a complete example of why state tensors are important for decoder models [here](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html). As you'll notice, that example is complex enough that the mechanism itself is hard to pick out. That's why you have this notebook.

You can also take a look at the [tracing API reference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-trace.html#torch_neuronx.trace), which documents the input parameter **input_output_aliases**.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
  def __init__(self):
      super(MLP, self).__init__()
      ## Dummy linear layer
      self.fc = nn.Linear(2, 2)
      ## Special set of tensors that will be transformed into state tensors
      self.cache = nn.ParameterList(
          [
              nn.Parameter(torch.zeros((2, 2), dtype=torch.float32), requires_grad=False)
          ]
      )
  def forward(self, x):
      x = F.relu(self.fc(x))
      for p in self.cache:
          x = p * x
      cache = [p + 1 for p in self.cache]
      return x, cache

Please note that `forward` has no input parameter for the cache. It returns a second output: a list of tensors computed from the ParameterList-based cache.

### Warm up

In [None]:
m = MLP()
x = torch.rand((2), dtype=torch.float32)
y,cache = m(x)
y,cache

### Trace

`idx=1` because `x` is output element 0, so the cache starts at output index 1. Here the cache has only one element. However, if you have a list of tensors, as you do when you build a kv_cache, each element of the list counts as an additional output and takes the next index (`idx+i`).
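
The index rule above can be sketched in plain Python. This is only an illustration: `alias_output_indices` is a hypothetical helper, not part of `torch_neuronx`, and it assumes the regular outputs come first and the cache tensors follow in order, as in this notebook.

```python
def alias_output_indices(num_regular_outputs, num_state_tensors):
    """Return the output index assigned to each state tensor.

    Assumes the regular outputs come first and the state tensors
    follow in order, one output index per tensor in the list.
    """
    return [num_regular_outputs + i for i in range(num_state_tensors)]


# This notebook: one regular output (x) and one cache tensor.
single = alias_output_indices(1, 1)
print(single)

# Hypothetical kv_cache: 2 layers with a key and a value tensor each.
kv = alias_output_indices(1, 4)
print(kv)
```

With a real model, each index would be paired with the corresponding parameter of the cache to build the `input_output_aliases` dictionary.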

In [None]:
import torch_neuronx

idx = 1  # x is output 0, so the first cache tensor is output 1
aliases = {c:idx+i for i,c in enumerate(m.cache)}
neuron_m = torch_neuronx.trace(m, x, input_output_aliases=aliases)

### Warm up neuron model

You should see something like this as the output:
```
(tensor([[0., 0.],
         [0., 0.]]),
 [tensor([[1., 1.],
          [1., 1.]])])
```

In [None]:
y,cache = neuron_m(x)
print(y,cache)

### Update state tensor
Now let's update the state tensors with new cache values, simulating the case where you take the output of the previous invocation and set it as the new cache:


In [None]:
for c,p in zip(cache, neuron_m.parameters()):
    p.copy_(c)

The loop above is how you copy new values into the state tensors without having to duplicate the memory on the host.
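
To look at the in-place behavior of `copy_` on its own, here is a minimal CPU-only sketch. It uses a plain `nn.Parameter` as a stand-in for a state tensor, which is an assumption for illustration only: no Neuron device or traced model is involved. It checks that the buffer address stays the same after the copy while the values change.

```python
import torch
import torch.nn as nn

# Stand-in for a state tensor: a frozen parameter like the one in MLP.cache.
# (Assumption for illustration; this runs on CPU, not on a Neuron device.)
state = nn.Parameter(torch.zeros((2, 2), dtype=torch.float32), requires_grad=False)
new_values = torch.ones((2, 2), dtype=torch.float32)

ptr_before = state.data_ptr()
with torch.no_grad():
    # In-place copy: overwrites the existing buffer instead of allocating a new one.
    state.copy_(new_values)
ptr_after = state.data_ptr()

same_buffer = ptr_before == ptr_after
values_updated = torch.equal(state, new_values)
print(same_buffer, values_updated)
```

Rebinding the name instead (for example `state = new_values`) would leave the original buffer untouched, which is why the loop uses `copy_`.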

### Invoke the model again

You'll see something like:
```
tensor([[0.8259, 0.0000],
        [0.8259, 0.0000]]) [tensor([[2., 2.],
        [2., 2.]])]
```

In [None]:
y,cache = neuron_m(x)
print(y,cache)