# Batch

`Batch` is the internal data structure used extensively in Tianshou. It is designed to store and manipulate hierarchical named tensors. This tutorial aims to help users understand the concept and behavior of `Batch` correctly, so that they can make the most of Tianshou.

The tutorial has three parts. We first explain the concept of hierarchical named tensors, then introduce the basic usage of `Batch`, and finally cover advanced topics.

## Hierarchical Named Tensors

"Hierarchical named tensors" refers to a set of tensors whose names form a hierarchy. Suppose there are four tensors `[t1, t2, t3, t4]` with names `[name1, name2, name3, name4]`, where `name1` and `name2` belong to the same namespace `name0`. Then the full name of tensor `t1` is `name0.name1`. That is, the hierarchy lies in the names of the tensors.

We can describe the structure of hierarchical named tensors using a tree. There is always a "virtual root" node to represent the whole object; internal nodes are keys (names), and leaf nodes are values (scalars or tensors).
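
The pure-Python sketch below makes the "full name" idea concrete. It walks a nested dict, which stands in for the tree, and collects the dotted full name of every leaf. The helper `full_names` is hypothetical and is not part of Tianshou. The example also assumes that `name3` and `name4` sit at the top level.

```python
from typing import Any


def full_names(tree: dict[str, Any], prefix: str = "") -> dict[str, Any]:
    """Map the dotted full name of every leaf in a nested dict to its value."""
    names: dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            # Internal node (a key): descend, extending the namespace
            names.update(full_names(value, name))
        else:
            # Leaf node (a value): record it under its full name
            names[name] = value
    return names


# name1 and name2 share the namespace name0; name3 and name4 are top-level
tree = {"name0": {"name1": "t1", "name2": "t2"}, "name3": "t3", "name4": "t4"}
print(full_names(tree))
# {'name0.name1': 't1', 'name0.name2': 't2', 'name3': 't3', 'name4': 't4'}
```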

Hierarchical named tensors are needed because we have to deal with the heterogeneity of reinforcement learning problems. The abstraction of RL is very simple:

```python
state, reward, done = env.step(action)
```

`reward` and `done` are simple: they are mostly scalar values. However, `state` and `action` vary across environments. For example, `state` can be a simple vector, a tensor, or a camera input combined with sensory input. In the last case, it is natural to store them as hierarchical named tensors. This hierarchy can go beyond `state` and `action`: we can store `state`, `action`, `reward`, and `done` together as hierarchical named tensors.

Note that storing hierarchical named tensors is as easy as creating nested dictionary objects:

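```python
import numpy as np

# Placeholder values standing in for real environment outputs (shapes are arbitrary)
done, reward = False, 0.0
camera = np.zeros((64, 64, 3))
sensory = np.zeros(8)
direct, point_3d, force = np.zeros(2), np.zeros(3), np.zeros(1)

transition = {
    'done': done,
    'reward': reward,
    'state': {
        'camera': camera,
        'sensory': sensory
    },
    'action': {
        'direct': direct,
        'point_3d': point_3d,
        'force': force,
    }
}
```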

The real problem is how to **manipulate them**, for example adding new transition tuples to a replay buffer and dealing with their heterogeneity. `Batch` is designed to easily create, store, and manipulate these hierarchical named tensors.

Think of `Batch` as a NumPy-enhanced version of a Python dictionary. It is also similar to PyTorch's tensordict, although with a somewhat different type structure.

<div align=center>
<img src="https://tianshou.readthedocs.io/en/master/_images/concepts_arch.png" title="The data flow is converted into a Batch in Tianshou">

<a> Data flow is converted into a Batch in Tianshou </a>
</div>

```python
import pickle

import numpy as np
import torch

from tianshou.data import Batch
```

## Basic Usage

Here we cover the basic usage of `Batch`: what `Batch` contains, how to construct `Batch` objects, and how to manipulate them.

### What Does Batch Contain

The content of `Batch` objects can be defined by the following rules:

1. A `Batch` object can be an empty `Batch()`, or have at least one key-value pair. `Batch()` can be used to reserve keys, too. See the Advanced Topics section for this usage.

2. The keys are always strings (they are names of corresponding values).

3. The values can be scalars, tensors, or Batch objects. The recursive definition makes it possible to form a hierarchy of batches.

4. Tensors are the most important values. In short, tensors are n-dimensional arrays of the same data type. We support two types of tensors: [PyTorch](https://pytorch.org/) tensor type `torch.Tensor` and [NumPy](https://numpy.org/) tensor type `np.ndarray`.

5. Scalars are also valid values. A scalar is a single boolean, number, or object. It can be a Python scalar (`False`, `1`, `2.3`, `None`, `'hello'`) or a NumPy scalar (`np.bool_(True)`, `np.int32(1)`, `np.float64(2.3)`). A scalar must not itself be a Batch, dict, or tensor.

**Note:** `Batch` cannot store `dict` objects, because internally `Batch` uses `dict` to store data. During construction, `dict` objects will be automatically converted to `Batch` objects.

The data types of tensors are bool and numbers (ints and floats of any size, as long as NumPy or PyTorch supports them). In addition, NumPy supports ndarrays of objects, and we take advantage of this feature to store non-number objects in `Batch`. To store data that are neither booleans nor numbers (such as strings and sets), put them in an `np.ndarray` with `dtype=object` (the `np.object` alias was removed in NumPy 1.24). This way, `Batch` can store any type of Python object.

```python
data = Batch(a=4, b=[5, 5], c="2312312", d=("a", -2, -3))
print(data)
print(data.b)
```

```
Batch(
    a: array(4),
    b: array([5, 5]),
    c: '2312312',
    d: array(['a', '-2', '-3'], dtype=object),
)
[5 5]
```


A `Batch` stores all passed-in data as key-value pairs, and automatically converts each value into a NumPy array when possible.
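
NumPy's own conversion rules explain the `d` entry in the example above. A tuple that mixes a string with integers becomes a string array, so `-2` and `-3` turn into the strings `'-2'` and `'-3'`. The sketch below reproduces that outcome with `np.asarray`, followed by a cast to `object` for non-numeric results. This is an assumption about the conversion that matches the printed output, not Tianshou's actual code; `to_array` is a hypothetical helper.

```python
import numpy as np


def to_array(value):
    """Hypothetical sketch: convert with NumPy, falling back to dtype=object."""
    arr = np.asarray(value)
    if not (np.issubdtype(arr.dtype, np.number) or arr.dtype == np.bool_):
        # Non-numeric, non-boolean results (e.g. strings) become Python objects
        arr = arr.astype(object)
    return arr


print(to_array([5, 5]))  # [5 5]
print(to_array(("a", -2, -3)).tolist())  # ['a', '-2', '-3']
print(to_array(("a", -2, -3)).dtype)  # object
```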

### Construction of Batch

There are two ways to construct a `Batch` object: from a `dict` (or a list of dicts), or from keyword arguments. Below are some examples.

#### Construct Batch from dict

```python
# Directly passing a dict object (possibly nested) works
data = Batch({"a": 4, "b": [5, 5], "c": "2312312"})
# The list is automatically converted to a NumPy array
print(data.b)
data.b = np.array([3, 4, 5])
print(data)
```

```
[5 5]
Batch(
    a: array(4),
    b: array([3, 4, 5]),
    c: '2312312',
)
```


```python
# A list of dict objects (possibly nested) will be automatically stacked
data = Batch([{"a": 0.0, "b": "hello"}, {"a": 1.0, "b": "world"}])
print(data)
```

```
Batch(
    a: array([0., 1.]),
    b: array(['hello', 'world'], dtype=object),
)
```


#### Construct Batch from kwargs

```python
# Construct a Batch with keyword arguments
data = Batch(a=[4, 4], b=[5, 5], c=[None, None])
print(data)
```

```
Batch(
    a: array([4, 4]),
    b: array([5, 5]),
    c: array([None, None], dtype=object),
)
```


```python
# Combining keyword arguments and batch_dict works fine
data = Batch(
    {"a": [4, 4], "b": [5, 5]}, c=[None, None]
)  # the first argument is a dict, and 'c' is a keyword argument
print(data)
```

```
Batch(
    a: array([4, 4]),
    b: array([5, 5]),
    c: array([None, None], dtype=object),
)
```


```python
arr = np.zeros((3, 4))
# By default, Batch only keeps a reference to the data, but it also supports copying
data = Batch(arr=arr, copy=True)  # data.arr is now a copy of 'arr'
```

#### Nested Batch construction

```python
# The dictionary can be nested, and it will be turned into a nested Batch
data = {
    "action": np.array([1.0, 2.0, 3.0]),
    "reward": 3.66,
    "obs": {
        "rgb_obs": np.zeros((3, 3)),
        "flatten_obs": np.ones(5),
    },
}

batch = Batch(data, extra="extra_string")
print(batch)
# batch.obs is also a Batch
print(type(batch.obs))
print(batch.obs.rgb_obs)
```

```
Batch(
    action: array([1., 2., 3.]),
    reward: array(3.66),
    obs: Batch(
             rgb_obs: array([[0., 0., 0.],
                             [0., 0., 0.],
                             [0., 0., 0.]]),
             flatten_obs: array([1., 1., 1., 1., 1.]),
         ),
    extra: 'extra_string',
)
<class 'tianshou.data.batch.Batch'>
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
```


```python
# A list of dicts/Batches is automatically stacked, which is convenient when
# you use parallelized environments to collect data.
batch = Batch([data] * 3)
print(batch)
print(batch.obs.rgb_obs.shape)
```

```
Batch(
    action: array([[1., 2., 3.],
                   [1., 2., 3.],
                   [1., 2., 3.]]),
    reward: array([3.66, 3.66, 3.66]),
    obs: Batch(
             flatten_obs: array([[1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1.]]),
             rgb_obs: array([[[0., 0., 0.],
                              [0., 0., 0.],
                              [0., 0., 0.]],

                             [[0., 0., 0.],
                              [0., 0., 0.],
                              [0., 0., 0.]],

                             [[0., 0., 0.],
                              [0., 0., 0.],
                              [0., 0., 0.]]]),
         ),
)
(3, 3, 3)
```


### Data Manipulation With Batch

Users can access the internal data by `b.key` or `b[key]`, where `b.key` finds the sub-tree with `key` as the root node. If the result is a sub-tree with non-empty keys, the key-reference can be chained, i.e. `b.key.key1.key2.key3`. When it reaches a leaf node, users get the data (scalars/tensors) stored in that `Batch` object.

```python
data = Batch(a=4, b=[5, 5])
print(data.b)
# obj.key is equivalent to obj["key"]
print(data["a"])
```

```
[5 5]
4
```


```python
# Iterating over data items like a dict is supported
for key, value in data.items():
    print(f"{key}: {value}")
```

```
a: 4
b: [5 5]
```


```python
# obj.keys() and obj.values() work just like dict.keys() and dict.values()
for key in data.keys():
    print(f"{key}")
```

```
a
b
```


```python
# obj.update() behaves like dict.update()
# this is the same as data.c = 1; data.d = 2; data.e = 3;
data.update(c=1, d=2, e=3)
print(data)
```

```
Batch(
    a: array(4),
    b: array([5, 5]),
    c: array(1),
    d: array(2),
    e: array(3),
)
```


```python
# Add or delete key-value pair in batch
batch1 = Batch({"a": [4, 4], "b": (5, 5)})
print(batch1)

batch1.c = Batch(c1=np.arange(3), c2=False)
del batch1.a
print(batch1)

# Access value by key
assert batch1["c"] is batch1.c
print("c" in batch1)
```

```
Batch(
    a: array([4, 4]),
    b: array([5, 5]),
)
Batch(
    b: array([5, 5]),
    c: Batch(
           c1: array([0, 1, 2]),
           c2: array(False),
       ),
)
True
```


**Note:** If `data` is a `dict` object, `for x in data` iterates over keys in the dict. However, it has a different meaning for `Batch` objects: `for x in data` iterates over `data[0], data[1], ..., data[-1]`.
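
Plain Python can mimic the two iteration semantics. In the sketch below, a dict of per-sample lists stands in for a `Batch`, which is a simplification, and `iter_rows` is a hypothetical helper. Iterating over the dict yields its keys. The Batch-like iteration yields one sample at a time and keeps every key.

```python
# Hypothetical column store: every key maps to a list with one entry per sample
columns = {"a": [[0.0, 2.0], [1.0, 3.0]], "b": [[5.0, -5.0], [1.0, -2.0]]}

# dict semantics: iteration yields the keys
print(list(columns))  # ['a', 'b']


def iter_rows(cols):
    """Batch-like semantics: yield one sample per index, keeping every key."""
    length = min(len(value) for value in cols.values())
    for i in range(length):
        yield {key: value[i] for key, value in cols.items()}


for row in iter_rows(columns):
    print(row)
# {'a': [0.0, 2.0], 'b': [5.0, -5.0]}
# {'a': [1.0, 3.0], 'b': [1.0, -2.0]}
```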

### Length, Shape, Indexing, and Slicing of Batch

`Batch` also partially reproduces the NumPy ndarray API. It supports advanced slicing such as `batch[:, i]`, as long as the slice is valid. NumPy's broadcasting mechanism works for `Batch`, too.

```python
# Initialize Batch with tensors
data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5.0, -5.0], [1.0, -2.0]])
# If values have the same length/shape, that length/shape is used for this Batch
print(len(data))
print(data.shape)
```

```
2
[2, 2]
```


```python
# Access the first item of all the stored tensors, while keeping the structure of Batch
print(data[0])
```

```
Batch(
    a: array([0., 2.]),
    b: array([ 5., -5.]),
)
```


```python
# Iterates over data[0], data[1], ..., data[-1]
for sample in data:
    print(sample.a)
```

```
[0. 2.]
[1. 3.]
```


```python
# Advanced slicing works just fine
# Arithmetic operations are passed to each value in the Batch, with broadcast enabled
data[:, 1] += 1
print(data)
```

```
Batch(
    a: array([[0., 3.],
              [1., 4.]]),
    b: array([[ 5., -4.],
              [ 1., -1.]]),
)
```


```python
# Amazingly, you can directly apply np.mean to a Batch object
print(np.mean(data))
```

```
Batch(
    a: array(2.),
    b: array(0.25),
)
```
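
Conceptually, the in-place update and the mean above act on each value separately. The sketch below writes out a rough equivalent with a plain dict of NumPy arrays. It describes the semantics only and is not Tianshou's implementation, yet it yields the same numbers as the printed `Batch`.

```python
import numpy as np

# A plain dict of NumPy arrays holding the same initial values as `data` above
values = {
    "a": np.array([[0.0, 2.0], [1.0, 3.0]]),
    "b": np.array([[5.0, -5.0], [1.0, -2.0]]),
}

# Mirrors data[:, 1] += 1: the index and the addition are applied to every value
for arr in values.values():
    arr[:, 1] += 1

# Mirrors np.mean(data): each value is reduced separately
means = {key: float(arr.mean()) for key, arr in values.items()}

print(values["a"].tolist())  # [[0.0, 3.0], [1.0, 4.0]]
print(values["b"].tolist())  # [[5.0, -4.0], [1.0, -1.0]]
print(means)  # {'a': 2.0, 'b': 0.25}
```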


```python
# Converting directly to a list is also supported
list(data)
```

```
[Batch(
     a: array([0., 3.]),
     b: array([ 5., -4.]),
 ),
 Batch(
     a: array([1., 4.]),
     b: array([ 1., -1.]),
 )]
```

#### Example with environment stepping

```python
# Suppose we have collected the outputs of stepping 4 environments
step_outputs = [
    {
        "act": np.random.randint(10),
        "rew": 0.0,
        "obs": np.ones((3, 3)),
        "info": {"done": np.random.choice(2), "failed": False},
        "terminated": False,
        "truncated": False,
    }
    for _ in range(4)
]
batch = Batch(step_outputs)
print(batch)
print(batch.shape)
```

```
Batch(
    terminated: array([False, False, False, False]),
    obs: array([[[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]],

                [[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]],

                [[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]],

                [[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]]]),
    truncated: array([False, False, False, False]),
    rew: array([0., 0., 0., 0.]),
    info: Batch(
              done: array([1, 1, 0, 1]),
              failed: array([False, False, False, False]),
          ),
    act: array([9, 1, 6, 6]),
)
[4]
```


```python
# Advanced indexing is supported, if we only want to select data in a given set of environments
print(batch[0])
print(batch[[0, 3]])
```

```
Batch(
    terminated: False,
    obs: array([[1., 1., 1.],
                [1., 1., 1.],
                [1., 1., 1.]]),
    truncated: False,
    rew: 0.0,
    info: Batch(
              done: 1,
              failed: False,
          ),
    act: 9,
)
Batch(
    terminated: array([False, False]),
    obs: array([[[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]],

                [[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]]]),
    truncated: array([False, False]),
    rew: array([0., 0.]),
    info: Batch(
              done: array([1, 1]),
              failed: array([False, False]),
          ),
    act: array([9, 6]),
)
```


```python
# Slicing is also supported
print(batch[-2:])
```

```
Batch(
    terminated: array([False, False]),
    obs: array([[[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]],

                [[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]]]),
    truncated: array([False, False]),
    rew: array([0., 0.]),
    info: Batch(
              done: array([0, 1]),
              failed: array([False, False]),
          ),
    act: array([6, 6]),
)
```


### Stack / Concatenate / Split of Batches

Stacking and concatenating multiple `Batch` instances, or splitting an instance into multiple batches, are all easy and intuitive in Tianshou. For now, we stick to the aggregation (stack/concatenate) of homogeneous (same structure) batches.

```python
data_1 = Batch(a=np.array([0.0, 2.0]), b=5)
data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)
data = Batch.stack((data_1, data_2))
print(data)
```

```
Batch(
    a: array([[0., 2.],
              [1., 3.]]),
    b: array([ 5, -5]),
)
```


```python
# Split into batches of size 1; pass shuffle=True to randomize their order
data_split = list(data.split(1, shuffle=False))
print(data_split)
```

```
[Batch(
    a: array([[0., 2.]]),
    b: array([5]),
), Batch(
    a: array([[1., 3.]]),
    b: array([-5]),
)]
```


```python
data_cat = Batch.cat(data_split)
print(data_cat)
```

```
Batch(
    a: array([[0., 2.],
              [1., 3.]]),
    b: array([ 5, -5]),
)
```


#### More concatenation and stacking examples

```python
# Concat batches with compatible keys
b1 = Batch(a=[{"b": np.float64(1.0), "d": Batch(e=np.array(3.0))}])
b2 = Batch(a=[{"b": np.float64(4.0), "d": {"e": np.array(6.0)}}])
b12_cat_out = Batch.cat([b1, b2])
print(b1)
print(b2)
print(b12_cat_out)
```

```
Batch(
    a: Batch(
           d: Batch(
                  e: array([3.]),
              ),
           b: array([1.]),
       ),
)
Batch(
    a: Batch(
           d: Batch(
                  e: array([6.]),
              ),
           b: array([4.]),
       ),
)
Batch(
    a: Batch(
           d: Batch(
                  e: array([3., 6.]),
              ),
           b: array([1., 4.]),
       ),
)
```


```python
# Stack batches with compatible keys
b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))
b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))
b34_stack = Batch.stack((b3, b4), axis=1)
print(b3)
print(b4)
print(b34_stack)
```

```
Batch(
    a: array([[0., 0.],
              [0., 0.],
              [0., 0.]]),
    b: array([[1., 1., 1.],
              [1., 1., 1.]]),
    c: Batch(
           d: array([[1],
                     [2]]),
       ),
)
Batch(
    a: array([[1., 1.],
              [1., 1.],
              [1., 1.]]),
    b: array([[1., 1., 1.],
              [1., 1., 1.]]),
    c: Batch(
           d: array([[0],
                     [3]]),
       ),
)
Batch(
    b: array([[[1., 1., 1.],
               [1., 1., 1.]],

              [[1., 1., 1.],
               [1., 1., 1.]]]),
    a: array([[[0., 0.],
               [1., 1.]],

              [[0., 0.],
               [1., 1.]],

              [[0., 0.],
               [1., 1.]]]),
    c: Batch(
           d: array([[[1],
                      [0]],

                     [[2],
                      [3]]]),
       ),
)
```


```python
# Split the batch into batches of size 1; shuffle=True breaks the order of the data
print(type(b34_stack.split(1)))
print(list(b34_stack.split(1, shuffle=True)))
```

```
<class 'generator'>
[Batch(
    b: array([[[1., 1., 1.],
               [1., 1., 1.]]]),
    a: array([[[0., 0.],
               [1., 1.]]]),
    c: Batch(
           d: array([[[2],
                      [3]]]),
       ),
), Batch(
    b: array([[[1., 1., 1.],
               [1., 1., 1.]]]),
    a: array([[[0., 0.],
               [1., 1.]]]),
    c: Batch(
           d: array([[[1],
                      [0]]]),
       ),
)]
```
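
Plain index lists are enough to sketch how a split partitions a batch. The helper `split_indices` below is hypothetical. It assumes that splitting walks the optionally shuffled indices in chunks of `size`, with a shorter final chunk. That assumption is consistent with the examples above but is not taken from the Tianshou source. Like `split`, the helper is a generator, so it produces chunks lazily.

```python
import random


def split_indices(length, size, shuffle=False):
    """Hypothetical: yield index chunks of at most `size`, optionally shuffled."""
    indices = list(range(length))
    if shuffle:
        random.shuffle(indices)
    for start in range(0, length, size):
        yield indices[start:start + size]


print(list(split_indices(5, 2)))  # [[0, 1], [2, 3], [4]]
chunks = list(split_indices(5, 2, shuffle=True))
# Shuffling changes the order, but every index still appears exactly once
print(sorted(i for chunk in chunks for i in chunk))  # [0, 1, 2, 3, 4]
```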


### Data Type Converting

Besides NumPy arrays, `Batch` also supports PyTorch tensors, with exactly the same usage.

```python
batch1 = Batch(a=np.arange(2), b=torch.zeros((2, 2)))
batch2 = Batch(a=np.arange(2), b=torch.ones((2, 2)))
batch_cat = Batch.cat([batch1, batch2, batch1])
print(batch_cat)
```

```
Batch(
    a: array([0, 1, 0, 1, 0, 1]),
    b: tensor([[0., 0.],
               [0., 0.],
               [1., 1.],
               [1., 1.],
               [0., 0.],
               [0., 0.]]),
)
```


If you no longer want to mix data types, you can easily convert all values in place with `to_torch_` or `to_numpy_`.

```python
data = Batch(a=np.zeros((3, 4)))
data.to_torch_(dtype=torch.float32, device="cpu")
print(data.a)
# data.to_numpy_ is also available
data.to_numpy_()
print(data.a)
```

```
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], dtype=torch.float64)
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
```


```python
batch_cat.to_numpy_()
print(batch_cat)
batch_cat.to_torch_()
print(batch_cat)
```

```
Batch(
    a: array([0, 1, 0, 1, 0, 1]),
    b: array([[0., 0.],
              [0., 0.],
              [1., 1.],
              [1., 1.],
              [0., 0.],
              [0., 0.]], dtype=float32),
)
Batch(
    a: tensor([0, 1, 0, 1, 0, 1], dtype=torch.int32),
    b: tensor([[0., 0.],
               [0., 0.],
               [1., 1.],
               [1., 1.],
               [0., 0.],
               [0., 0.]]),
)
```


### Serialization

`Batch` is serializable and therefore compatible with `pickle`: `Batch` objects can be saved to disk and later restored with Python's `pickle` module. This compatibility is especially important for distributed sampling from environments.

```python
batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))
batch_pk = pickle.loads(pickle.dumps(batch))
print(batch_pk)
```

```
Batch(
    obs: Batch(
             a: array(0.),
             c: tensor([1., 2.]),
         ),
    np: array([[0., 0., 0., 0.],
               [0., 0., 0., 0.],
               [0., 0., 0., 0.]]),
)
```


## Advanced Topics

From here on, this tutorial focuses on advanced topics of `Batch`, including key reservation, length/shape, and aggregation of heterogeneous batches.

### Key Reservations

In many cases, we know in the first place what keys we have, but we do not know the shape of values until we run the environment. To deal with this, Tianshou supports key reservations: **reserve a key and use a placeholder value**.

The usage is easy: just use `Batch()` as the value of reserved keys.

```python
a = Batch(b=Batch())  # 'b' is a reserved key
print(a)

# This is called hierarchical key reservation
a = Batch(b=Batch(c=Batch()), d=Batch())  # 'c' and 'd' are reserved keys
print(a)

a = Batch(key1=np.array([1, 2]), key2=np.array([3, 4]), key3=Batch(key4=Batch(), key5=Batch()))
print(a)
```

```
Batch(
    b: Batch(),
)
Batch(
    b: Batch(
           c: Batch(),
       ),
    d: Batch(),
)
Batch(
    key1: array([1, 2]),
    key2: array([3, 4]),
    key3: Batch(
              key4: Batch(),
              key5: Batch(),
          ),
)
```


Still, we can use a tree to show the structure of `Batch` objects with reserved keys, where reserved keys are special internal nodes that do not have attached leaf nodes.

**Note:** Reserved keys mean that in the future there will eventually be values attached to them. The values can be scalars, tensors, or even **Batch** objects. Understanding this is critical to understanding the behavior of `Batch` when dealing with heterogeneous Batches.

The introduction of reserved keys gives rise to the need to check if a key is reserved.

```python
# Examples of checking whether Batch is empty
print(len(Batch().get_keys()) == 0)
print(len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) == 0)
print(len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0)
print(len(Batch(d=1).get_keys()) == 0)
print(len(Batch(a=np.float64(1.0)).get_keys()) == 0)
```

```
True
False
True
False
False
```


To check whether a `Batch` is empty, use `len(batch.get_keys()) == 0` to detect direct emptiness (just a `Batch()`), or `len(batch) == 0` to detect recursive emptiness (a `Batch` object without any scalar/tensor leaf nodes).

**Note:** Do not confuse these checks with `Batch.empty`. `Batch.empty` and its in-place variant `Batch.empty_` are used to set some values to zeros or `None`. Check the API documentation for further details.
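
The two notions of emptiness can be sketched in plain Python. Nested dicts stand in for `Batch` objects and `{}` stands in for a reserved key. The helper names are hypothetical and do not exist in Tianshou. On the first three structures from the emptiness example above, the sketch gives the same answers.

```python
from typing import Any


def is_directly_empty(tree: dict[str, Any]) -> bool:
    """Like len(batch.get_keys()) == 0: the node has no keys at all."""
    return len(tree) == 0


def is_recursively_empty(tree: dict[str, Any]) -> bool:
    """No scalar/tensor leaf anywhere; only (nested) reserved keys, if any."""
    return all(isinstance(v, dict) and is_recursively_empty(v) for v in tree.values())


examples = [{}, {"a": {}, "b": {"c": {}}}, {"d": 1}]
for tree in examples:
    print(is_directly_empty(tree), is_recursively_empty(tree))
# True True
# False True
# False False
```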

### Length and Shape Details

The most common usage of `Batch` is to store a batch of data. The term "Batch" comes from the deep learning community, where it denotes a mini-batch of data sampled from the whole dataset. In this regard, "Batch" typically means a collection of tensors whose first dimensions are the same. The length of a `Batch` object is then simply the batch size.

If all the leaf nodes in a `Batch` object are tensors but they have different lengths, they can still be stored in `Batch`. However, `len(obj)` is ambiguous for such a `Batch`. Currently, Tianshou returns the length of the shortest tensor, but we strongly recommend not using `len(obj)` on `Batch` objects whose tensors have different lengths.

```python
# Examples of len and obj.shape for Batch objects
data = Batch(a=[5.0, 4.0], b=np.zeros((2, 3, 4)))
print(data.shape)
print(len(data))
print(data[0].shape)
try:
    len(data[0])
except TypeError as e:
    print(f"TypeError: {e}")
```

```
[2]
2
[]
TypeError: Entry for a in Batch(
    a: 5.0,
    b: array([[0., 0., 0., 0.],
              [0., 0., 0., 0.],
              [0., 0., 0., 0.]]),
) is 5.0 has no len()
```


**Note:** Following the convention of scientific computation, scalars have no length. If there is any scalar leaf node in a `Batch` object, an exception will occur when users call `len(obj)`.

Besides, values of reserved keys are undetermined, so they have no length either. More precisely, a reserved key can later hold a value of **any** length. When a `Batch` mixes tensors and reserved keys, `len(obj)` ignores the reserved keys and returns the minimum length of the tensors. When there are no tensors at all, `len(obj)` raises an exception if any scalar is present, while a `Batch` that contains only reserved keys has length 0, as the emptiness examples above show.
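
The length rules fit into a short pure-Python sketch. Nested dicts stand in for `Batch` objects, `{}` for a reserved key, lists for tensors, and any other value for a scalar. `batch_len` and `has_leaves` are hypothetical helpers. Returning 0 for a structure with only reserved keys follows the `len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0` example above, not the Tianshou source.

```python
from typing import Any


def has_leaves(tree: dict[str, Any]) -> bool:
    """True if any scalar/tensor leaf exists anywhere below this node."""
    return any(not isinstance(v, dict) or has_leaves(v) for v in tree.values())


def batch_len(tree: dict[str, Any]) -> int:
    """Sketch of the len() rules: lists act as tensors, {} as reserved keys."""
    lengths = []
    for key, value in tree.items():
        if isinstance(value, dict):
            if has_leaves(value):
                lengths.append(batch_len(value))  # nested batch holding data
            # otherwise only reserved keys below: ignored
        elif isinstance(value, (list, tuple)):
            lengths.append(len(value))  # tensor: length of its first dimension
        else:
            raise TypeError(f"scalar entry {key!r} has no len()")
    return min(lengths) if lengths else 0


print(batch_len({"a": [5.0, 4.0], "b": [[[0.0] * 4] * 3] * 2}))  # 2
print(batch_len({"a": [1, 2, 3], "b": [1, 2], "r": {}}))  # 2
print(batch_len({"a": {}, "b": {"c": {}}}))  # 0
try:
    batch_len({"a": 5.0, "b": [1, 2]})
except TypeError as err:
    print(f"TypeError: {err}")  # TypeError: scalar entry 'a' has no len()
```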

The `obj.shape` attribute of `Batch` behaves similarly to `len(obj)`:

1. If all the leaf nodes in a `Batch` object are tensors with the same shape, that shape is returned.

2. If all the leaf nodes in a `Batch` object are tensors but they have different shapes, the minimum length of each dimension is returned.

3. If there is any scalar value in a `Batch` object, `obj.shape` returns `[]`.

4. The shape of reserved keys is undetermined, too. We treat their shape as `[]`.
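
The four rules can be condensed into a per-dimension minimum over the leaf shapes, restricted to the dimensions that every leaf has. The pure-Python sketch below works on shape tuples only, and `batch_shape` is a hypothetical helper. Giving scalars and reserved keys the empty shape `()` is an assumption, but it reproduces rules 3 and 4 as well as the `[2]` result in the `len`/`shape` example above.

```python
def batch_shape(shapes: list[tuple[int, ...]]) -> list[int]:
    """Per-dimension minimum over the shapes of all leaf values.

    zip() stops at the shortest shape, so only dimensions shared by every
    leaf are kept; a scalar or reserved key, given the empty shape (),
    therefore makes the result [].
    """
    return [min(dims) for dims in zip(*shapes)]


print(batch_shape([(2, 2), (2, 2)]))  # [2, 2]  rule 1: identical shapes
print(batch_shape([(2,), (2, 3, 4)]))  # [2]  rule 2: a (2,) and b (2, 3, 4)
print(batch_shape([(3, 2), (2, 3)]))  # [2, 2]  rule 2: per-dimension minimum
print(batch_shape([(2,), ()]))  # []  rules 3 and 4
```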

### Aggregation of Heterogeneous Batches

In this section, we discuss aggregation operators (stack/concatenate) on heterogeneous `Batch` objects. We only consider heterogeneity in the structure of `Batch` objects. The aggregation is ultimately carried out by NumPy/PyTorch operators (`np.stack`, `np.concatenate`, `torch.stack`, `torch.cat`). Heterogeneity in values can make these operators fail (for example, stacking `np.ndarray` with `torch.Tensor`, or stacking tensors with different shapes), in which case an exception is raised.

The behavior is natural: for keys that are not shared across all batches, the batches that lack these keys are padded with zeros (or `None` if the data type is `object`).
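
For flat batches, NumPy alone is enough to sketch the padding rule. First take the union of keys. Then fill each missing entry with zeros, or `None` for object data, shaped like a batch that has the key. Finally call `np.stack` per key. `stack_with_padding` and `as_array` are hypothetical helpers that cover only flat dicts. Tianshou's `Batch.stack` also handles nesting and PyTorch tensors.

```python
import numpy as np


def as_array(value):
    """Convert with NumPy; non-numeric, non-boolean results become dtype=object."""
    arr = np.asarray(value)
    if not (np.issubdtype(arr.dtype, np.number) or arr.dtype == np.bool_):
        arr = arr.astype(object)
    return arr


def stack_with_padding(batches):
    """Stack flat dicts of values key by key, padding keys that a batch lacks."""
    # Union of keys, in order of first appearance
    keys = list(dict.fromkeys(key for batch in batches for key in batch))
    result = {}
    for key in keys:
        present = {i: as_array(b[key]) for i, b in enumerate(batches) if key in b}
        template = next(iter(present.values()))
        # Pad with None for object data, zeros otherwise, shaped like the template
        if template.dtype == object:
            filler = np.full(template.shape, None, dtype=object)
        else:
            filler = np.zeros_like(template)
        parts = [present.get(i, filler) for i in range(len(batches))]
        result[key] = np.stack(parts)  # raises ValueError on shape mismatch
    return result


stacked = stack_with_padding(
    [{"a": np.array([0.0, 2.0])}, {"a": np.array([1.0, 3.0]), "b": "done"}]
)
print(stacked["a"].tolist())  # [[0.0, 2.0], [1.0, 3.0]]
print(stacked["b"].tolist())  # [None, 'done']

# Heterogeneous values (here, different shapes) still make np.stack fail
try:
    stack_with_padding([{"a": np.zeros(2)}, {"a": np.zeros(3)}])
    mismatch_error = False
except ValueError:
    mismatch_error = True
print(mismatch_error)  # True
```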

```python
# Examples of stack: a is missing key `b`, and b is missing key `a`
a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
c = Batch.stack([a, b])
print(c.a.shape)
print(c.b.shape)
print(c.common.c.shape)
```

```
(2, 4, 4)
(2, 4, 6)
(2, 4, 5)
```


```python
# None or 0 is padded with appropriate shape
data_1 = Batch(a=np.array([0.0, 2.0]))
data_2 = Batch(a=np.array([1.0, 3.0]), b="done")
data = Batch.stack((data_1, data_2))
print(data)
```

```
Batch(
    a: array([[0., 2.],
              [1., 3.]]),
    b: array([None, 'done'], dtype=object),
)
```


```python
# Examples of cat: a is missing key `b`, and b is missing key `a`
a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
# Note: concatenating batches with mismatched keys like this is no longer
# admissible in recent versions of Tianshou, so the call is left commented out.
# c = Batch.cat([a, b])
# print(c.a.shape)
# print(c.b.shape)
# print(c.common.c.shape)
```

However, in some cases the batches are so heterogeneous that they cannot be aggregated:

```python
# This will raise an exception
try:
    a = Batch(a=np.zeros([4, 4]))
    b = Batch(a=Batch(b=Batch()))
    c = Batch.stack([a, b])
except Exception as e:
    print(f"Exception: {e}")
```



Then how to determine if batches can be aggregated? Let's rethink the purpose of reserved keys. What is the advantage of `a1=Batch(b=Batch())` over `a2=Batch()`? The only difference is that `a1.b` returns `Batch()` but `a2.b` raises an exception. That's to say, **we reserve keys for attribute reference**.

We say a key chain `k=[key1, key2, ..., keyn]` applies to `b` if the expression `b.key1.key2.{...}.keyn` is valid, and we denote its result by `b[k]`.

For a set of `Batch` objects denoted as S, they can be aggregated if there exists a `Batch` object `b` satisfying the following rules:

1. **Key chain applicability:** For any object `bi` in S, and any key chain `k`, if `bi[k]` is valid, then `b[k]` is valid.

2. **Type consistency:** If `bi[k]` is not `Batch()` (the last key in the key chain is not a reserved key), then the type of `b[k]` should be the same as `bi[k]` (both should be scalar/tensor/non-empty Batch values).

The `Batch` object `b` that satisfies these rules with the minimum number of keys determines the structure of the aggregation of S. The values are then easy to define: for any key chain `k` that applies to `b`, `b[k]` is the stack/concatenation of `[bi[k] for bi in S]` (if `k` does not apply to some `bi`, zeros or `None` of the appropriate size are filled in automatically). If all `bi[k]` are `Batch()`, the aggregation result is also an empty `Batch()`.
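
The two rules can be checked mechanically. In the pure-Python sketch below, nested dicts stand in for `Batch` objects, `{}` for a reserved key, lists for tensors, and any other value for a scalar. `key_chains` and `can_aggregate` are hypothetical helpers written from the rules as stated above, not Tianshou's own validation code. The union of all key chains plays the role of `b` (rule 1). A conflict in node kinds for the same chain means that no such `b` exists (rule 2). A reserved key is compatible with anything.

```python
from typing import Any


def kind(value: Any) -> str:
    """Classify a node: non-empty batch, reserved key, tensor, or scalar."""
    if isinstance(value, dict):
        return "batch" if value else "reserved"
    if isinstance(value, (list, tuple)):
        return "tensor"
    return "scalar"


def key_chains(tree: dict[str, Any], prefix: tuple[str, ...] = ()) -> dict[tuple[str, ...], str]:
    """Map every key chain k that applies to the tree to the kind of tree[k]."""
    chains: dict[tuple[str, ...], str] = {}
    for key, value in tree.items():
        chain = prefix + (key,)
        chains[chain] = kind(value)
        if isinstance(value, dict):
            chains.update(key_chains(value, chain))
    return chains


def can_aggregate(trees: list[dict[str, Any]]) -> bool:
    """Check whether a single structure b exists that every tree fits into."""
    merged: dict[tuple[str, ...], str] = {}
    for tree in trees:
        for chain, node_kind in key_chains(tree).items():
            if node_kind == "reserved":
                merged.setdefault(chain, "reserved")  # fits any kind
                continue
            previous = merged.get(chain, "reserved")
            if previous not in ("reserved", node_kind):
                return False  # type consistency violated
            merged[chain] = node_kind
    return True


# The failing example above: a.a is a tensor in one batch, a non-empty Batch in the other
print(can_aggregate([{"a": [[0.0] * 4] * 4}, {"a": {"b": {}}}]))  # False
# The stack example: disjoint keys plus a shared nested key
print(can_aggregate([{"a": [0.0], "common": {"c": [1.0]}},
                     {"b": [2.0], "common": {"c": [3.0]}}]))  # True
```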

### Miscellaneous Notes

1. It is often the case that the observations returned from the environment are all NumPy ndarrays, while the policy requires `torch.Tensor` for prediction and learning. For this purpose, Tianshou provides helper methods (`to_torch_` and `to_numpy_`) that convert the stored data in place into PyTorch tensors or NumPy arrays.

2. `obj.stack_([a, b])` is the same as `Batch.stack([obj, a, b])`, and `obj.cat_([a, b])` is the same as `Batch.cat([obj, a, b])`. Considering the frequent requirement of concatenating two `Batch` objects, Tianshou also supports `obj.cat_(a)` to be an alias of `obj.cat_([a])`.

3. `Batch.cat` and `Batch.cat_` currently do not support an `axis` argument, unlike `np.concatenate` and `torch.cat`.

4. `Batch.stack` and `Batch.stack_` support the `axis` argument, so one can stack batches along dimensions other than the first. Be cautious, however: if some keys are not shared across all batches, `stack` with `axis != 0` is undefined and currently raises an exception.

## Further Reading

Would you like to learn more advanced usages of `Batch`, or are you curious about how data is organized inside it? Check the [documentation](https://tianshou.readthedocs.io/en/master/03_api/data/batch.html) for more details.