# Batch: Tianshou's Core Data Structure

The `Batch` class is Tianshou's fundamental data structure for efficiently storing and manipulating heterogeneous data in reinforcement learning. This tutorial covers its conceptual foundations, how it behaves in practice, and best practices for using it.


In [None]:
import pickle
from typing import cast

import numpy as np
import torch
from torch.distributions import Categorical, Normal

from tianshou.data import Batch
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol

## 1. Introduction: Why Batch?

### The Challenge in Reinforcement Learning

Reinforcement learning algorithms face a fundamental data management challenge:

1. **Diverse Data Requirements**: Different RL algorithms need different data fields:
   - Basic algorithms: `state`, `action`, `reward`, `done`, `next_state`
   - Actor-Critic: additionally `advantages`, `returns`, `values`
   - Policy Gradient: additionally `log_probs`, `old_log_probs`
   - Off-policy with prioritized replay: additionally `priority_weights`

2. **Heterogeneous Observation Spaces**: Environments return diverse observation types:
   - Simple: vectors (`np.array([1.0, 2.0, 3.0])`)
   - Complex: images (arrays of shape `(84, 84, 3)`)
   - Hybrid: dictionaries combining multiple modalities
   ```python
   obs = {
       'camera': np.zeros((64, 64, 3)),
       'velocity': np.array([1.2, 0.5]),
       'inventory': np.array([5, 2, 0])
   }
   ```

3. **Data Flow Across Components**: Data must flow seamlessly through:
   - Collectors (gathering experience from environments)
   - Replay Buffers (storing and sampling transitions)
   - Policies and Algorithms (learning and inference)

### Why Not Alternatives?

#### Plain Dictionaries
A plain dictionary can hold RL data, but it lacks essential features:
```python
data = {'obs': np.array([1, 2]), 'reward': np.array([1.0, 2.0])}
```

This works in principle, but a dict has no shape or length semantics, cannot be indexed or sliced as a whole, and offers no type safety.

#### TensorDict
`TensorDict` (used by TorchRL) is a powerful alternative, but several points favor Batch here:
- **Batch supports arbitrary objects**, not just tensors (useful for object-dtype arrays, custom types)
- **Batch has better type checking** via `BatchProtocol` (enables IDE autocompletion)
- **Batch preceded TensorDict** and provides a stable foundation for Tianshou
- **TensorDict isn't part of core PyTorch** (external dependency)

### What is Batch?

**Batch = Dictionary + Array hybrid with RL-specific features**

Key capabilities:
- **Dict-like**: Key-value storage with attribute access (`batch.obs`, `batch.reward`)
- **Array-like**: Shape, indexing, slicing (`batch[0]`, `batch[:10]`, `batch.shape`)
- **Hierarchical**: Nested structures for complex data
- **Type-safe**: Protocol-based typing for IDE support
- **RL-aware**: Special handling for distributions, missing values, heterogeneous aggregation

## 2. Core Concepts

### Hierarchical Named Tensors

Batch stores **hierarchical named tensors**: collections of tensors whose identifiers form a structured hierarchy. Consider tensors `[t1, t2, t3, t4]` with names `[name1, name2, name3, name4]`, where `name1` and `name2` are under namespace `name0`. The fully qualified name of `t1` is then `name0.name1`.

### Tree Structure Visualization

The structure can be visualized as a tree with:
- **Root**: The Batch object itself
- **Internal nodes**: Keys (names)
- **Leaf nodes**: Values (scalars, arrays, tensors)

```mermaid
graph TD
    root["Batch (root)"]
    root --> obs["obs"]
    root --> act["act"]
    root --> rew["rew"]
    obs --> camera["camera"]
    obs --> sensory["sensory"]
    camera --> cam_data["np.array(3,3)"]
    sensory --> sens_data["np.array(5,)"]
    act --> act_data["np.array(2,)"]
    rew --> rew_data["3.66"]
    
    style root fill:#e1f5ff
    style obs fill:#fff4e1
    style act fill:#fff4e1
    style rew fill:#fff4e1
    style camera fill:#ffe1f5
    style sensory fill:#ffe1f5
    style cam_data fill:#e8f5e1
    style sens_data fill:#e8f5e1
    style act_data fill:#e8f5e1
    style rew_data fill:#e8f5e1
```

In [None]:
# Example: hierarchical structure
data = {
    "act": np.array([1.0, 2.0]),
    "rew": 3.66,
    "obs": {
        "camera": np.zeros((3, 3)),
        "sensory": np.ones(5),
    },
}

batch = Batch(data)
print(batch)
print("\nAccessing nested values:")
print(f"batch.obs.camera.shape = {batch.obs.camera.shape}")
print(f"batch.obs.sensory = {batch.obs.sensory}")

### Data Flow in RL Pipeline

Batch facilitates data flow throughout the RL pipeline:

```mermaid
graph LR
    A[Environment] -->|ObsBatchProtocol| B[Collector]
    B -->|RolloutBatchProtocol| C[Replay Buffer]
    C -->|RolloutBatchProtocol| D[Policy]
    D -->|ActBatchProtocol| A
    D -->|BatchWithAdvantagesProtocol| E[Algorithm/Trainer]
    E --> D
    
    style A fill:#e1f5ff
    style B fill:#fff4e1
    style C fill:#ffe1f5
    style D fill:#e8f5e1
    style E fill:#f5e1e1
```

Each labeled arrow names the `BatchProtocol` that defines which fields are expected at that stage.

## 3. Basic Operations

### 3.1 Construction

Batch objects can be constructed in several ways:

In [None]:
# From keyword arguments
batch1 = Batch(a=4, b=[5, 5], c="hello")
print("From kwargs:", batch1)

# From dictionary
batch2 = Batch({"a": 4, "b": [5, 5], "c": "hello"})
print("\nFrom dict:", batch2)

# From list of dictionaries (automatically stacked)
batch3 = Batch([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
print("\nFrom list of dicts:", batch3)

# Nested batch
batch4 = Batch(obs=Batch(x=1, y=2), act=5)
print("\nNested:", batch4)

### 3.2 Content Rules

Understanding what Batch can store and how it converts data:

In [None]:
# Keys must be strings
batch = Batch()
batch.key1 = "value"
batch.key2 = np.array([1, 2, 3])
print("Keys:", list(batch.keys()))

# Automatic conversions
demo = Batch(
    scalar_int=5,  # → np.array(5)
    scalar_float=3.14,  # → np.array(3.14)
    list_nums=[1, 2, 3],  # → np.array([1, 2, 3])
    list_mixed=[1, "hello", None],  # → np.array([1, "hello", None], dtype=object)
    dict_val={"x": 1, "y": 2},  # → Batch(x=1, y=2)
)

print("\nAutomatic conversions:")
print(f"scalar_int type: {type(demo.scalar_int)}, value: {demo.scalar_int}")
print(f"list_nums type: {type(demo.list_nums)}, dtype: {demo.list_nums.dtype}")
print(f"list_mixed dtype: {demo.list_mixed.dtype}")
print(f"dict_val type: {type(demo.dict_val)}")

**Important conversions:**
- Lists of numbers → NumPy arrays
- Lists with mixed types → Object-dtype arrays
- Dictionaries → Batch objects (recursively)
- Scalars → 0-d NumPy arrays

### 3.3 Access Patterns

**Important: Understanding Iteration**

In [None]:
batch = Batch(a=[1, 2, 3], b=[4, 5, 6])

# Attribute vs dictionary access (equivalent)
print("Attribute access:", batch.a)
print("Dict access:", batch["a"])

# Getting keys
print("\nKeys:", list(batch.keys()))

# Gotcha: iteration is array-like (over indices), not over keys
print("\nIteration behavior:")
print("for x in batch iterates over batch[0], batch[1], ..., NOT keys!")
for i, item in enumerate(batch):
    print(f"batch[{i}] = {item}")

# This is different from dict behavior!
regular_dict = {"a": [1, 2, 3], "b": [4, 5, 6]}
print("\nCompare with dict iteration (iterates over keys):")
for key in regular_dict:
    print(f"key = {key}")

### 3.4 Indexing & Slicing

Batch supports NumPy-like indexing and slicing:

In [None]:
batch = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5.0, -5.0], [1.0, -2.0]])

print("Original batch shape:", batch.shape)
print("Original batch length:", len(batch))

# Single index
print("\nbatch[0]:")
print(batch[0])

# Slicing
print("\nbatch[:1]:")
print(batch[:1])

# Advanced indexing
print("\nbatch[[0, 1]]:")
print(batch[[0, 1]])

# Multi-dimensional indexing
print("\nbatch[:, 0] (first column of all arrays):")
print(batch[:, 0])

In [None]:
# Broadcasting and in-place operations
batch[:, 1] += 10
print("After batch[:, 1] += 10:")
print(batch)

### 3.5 Stack, Concatenate, and Split

Combining and splitting batches:

In [None]:
# Stack: adds a new dimension
batch1 = Batch(a=np.array([1, 2]), b=np.array([5, 6]))
batch2 = Batch(a=np.array([3, 4]), b=np.array([7, 8]))

stacked = Batch.stack([batch1, batch2])
print("Stacked:")
print(stacked)
print(f"Shape: {stacked.shape}")

# Concatenate: extends along existing dimension
concatenated = Batch.cat([batch1, batch2])
print("\nConcatenated:")
print(concatenated)
print(f"Shape: {concatenated.shape}")

In [None]:
# Split
batch = Batch(a=np.arange(10), b=np.arange(10, 20))
splits = list(batch.split(size=3, shuffle=False))
print(f"Split into {len(splits)} batches:")
for i, split in enumerate(splits):
    print(f"Split {i}: a={split.a}, length={len(split)}")

### 3.6 Data Type Conversion

Converting between NumPy and PyTorch:

In [None]:
# Create batch with NumPy arrays
batch = Batch(a=np.zeros((3, 4)), b=np.ones(5))
print("Original (NumPy):")
print(f"batch.a type: {type(batch.a)}")

# Convert to PyTorch (in-place)
batch.to_torch_(dtype=torch.float32, device="cpu")
print("\nAfter to_torch_():")
print(f"batch.a type: {type(batch.a)}")
print(f"batch.a dtype: {batch.a.dtype}")

# Convert back to NumPy (in-place)
batch.to_numpy_()
print("\nAfter to_numpy_():")
print(f"batch.a type: {type(batch.a)}")

# Non-in-place versions return a new batch
batch_torch = batch.to_torch()
print("\nOriginal batch unchanged:", type(batch.a))
print("New batch:", type(batch_torch.a))

## 4. Type Safety with Protocols

### Why Protocols?

Batch needs to be **flexible** (no fixed set of fields, unlike dataclasses), yet we still want **type safety** and **IDE autocompletion**. Protocols provide the best of both worlds:

- **Runtime flexibility**: Add any fields dynamically
- **Static type checking**: Type checkers (mypy, pyright) verify correct usage
- **IDE support**: Autocompletion for expected fields

### What is BatchProtocol?

A `Protocol` defines an interface without implementation. Think of it as a contract: "any object with these fields is valid."

In [None]:
# Creating a typed batch using cast
# This enables IDE autocompletion and type checking

# ActBatchProtocol: just needs 'act' field
act_batch = cast(ActBatchProtocol, Batch(act=np.array([1, 2, 3])))
print("ActBatchProtocol:", act_batch.act)

# ObsBatchProtocol: needs 'obs' and 'info' fields
obs_batch = cast(
    ObsBatchProtocol,
    Batch(obs=np.array([[1.0, 2.0], [3.0, 4.0]]), info=np.array([{}, {}], dtype=object)),
)
print("\nObsBatchProtocol:", obs_batch.obs)

# RolloutBatchProtocol: needs obs, info, obs_next, act, rew, terminated, truncated
rollout_batch = cast(
    RolloutBatchProtocol,
    Batch(
        obs=np.array([[1.0, 2.0], [3.0, 4.0]]),
        obs_next=np.array([[2.0, 3.0], [4.0, 5.0]]),
        act=np.array([0, 1]),
        rew=np.array([1.0, 2.0]),
        terminated=np.array([False, True]),
        truncated=np.array([False, False]),
        info=np.array([{}, {}], dtype=object),
    ),
)
print("\nRolloutBatchProtocol reward:", rollout_batch.rew)

### Protocol Hierarchy

Tianshou defines a hierarchy of protocols for different use cases:

```mermaid
graph TD
    BP[BatchProtocol<br/>Base protocol] --> OBP[ObsBatchProtocol<br/>obs, info]
    BP --> ABP[ActBatchProtocol<br/>act]
    ABP --> ASBP[ActStateBatchProtocol<br/>act, state]
    OBP --> RBP[RolloutBatchProtocol<br/>+obs_next, act, rew,<br/>terminated, truncated]
    RBP --> BWRP[BatchWithReturnsProtocol<br/>+returns]
    BWRP --> BWAP[BatchWithAdvantagesProtocol<br/>+adv, v_s]
    ASBP --> MOBP[ModelOutputBatchProtocol<br/>+logits]
    MOBP --> DBP[DistBatchProtocol<br/>+dist]
    DBP --> DLPBP[DistLogProbBatchProtocol<br/>+log_prob]
    BWAP --> LOPBP[LogpOldProtocol<br/>+logp_old]
    
    style BP fill:#e1f5ff
    style OBP fill:#fff4e1
    style ABP fill:#fff4e1
    style RBP fill:#ffe1f5
    style BWRP fill:#e8f5e1
    style BWAP fill:#e8f5e1
    style DBP fill:#f5e1e1
    style LOPBP fill:#e1e1f5
```

### Using Protocols in Functions

Protocols enable type-safe function signatures:

In [None]:
def process_observations(batch: ObsBatchProtocol) -> np.ndarray:
    """Function that expects observations.

    IDE will autocomplete batch.obs and batch.info!
    Type checker will verify these fields exist.
    """
    # IDE knows batch.obs exists
    return batch.obs if isinstance(batch.obs, np.ndarray) else np.array(batch.obs)


def compute_advantage(batch: RolloutBatchProtocol) -> np.ndarray:
    """Function that expects rollout data.

    IDE will autocomplete batch.rew, batch.obs_next, etc.
    """
    # Placeholder only: returns the rewards; a real advantage estimate also needs value predictions
    return batch.rew  # IDE knows this exists


# Example usage: cast so that the type checker accepts the plain Batch
obs_data = cast(
    ObsBatchProtocol,
    Batch(obs=np.array([1, 2, 3]), info=np.array([{}, {}, {}], dtype=object)),
)
result = process_observations(obs_data)
print("Processed obs:", result)

**Key Protocol Types:**

- `ActBatchProtocol`: Just actions (for simple policies)
- `ObsBatchProtocol`: Observations and info
- `RolloutBatchProtocol`: Complete transitions (obs, act, rew, terminated, truncated, obs_next, info)
- `BatchWithReturnsProtocol`: Rollouts + computed returns
- `BatchWithAdvantagesProtocol`: Returns + advantages and values
- `DistBatchProtocol`: Contains distribution objects
- `LogpOldProtocol`: For importance sampling (PPO, etc.)

See `tianshou/data/types.py` for the complete list!

## 5. Distribution Slicing

### Why Special Handling?

PyTorch `Distribution` objects need special slicing because they're not simple arrays. When you slice `batch[0:2]`, Tianshou needs to slice the underlying distribution parameters correctly.

### Supported Distributions

Tianshou supports slicing for:
- `Categorical`: Discrete distributions
- `Normal`: Continuous Gaussian distributions
- `Independent`: Wraps other distributions

In [None]:
# Categorical distribution
probs = torch.tensor([[0.3, 0.7], [0.4, 0.6], [0.5, 0.5]])
dist = Categorical(probs=probs)
batch = Batch(dist=dist, values=np.array([1, 2, 3]))

print("Original batch length:", len(batch))
print("Original dist probs shape:", batch.dist.probs.shape)

# Slicing automatically handles the distribution
sliced = batch[0:2]
print("\nSliced batch length:", len(sliced))
print("Sliced dist probs shape:", sliced.dist.probs.shape)
print("Sliced values:", sliced.values)

In [None]:
# Normal distribution
loc = torch.tensor([0.0, 1.0, 2.0])
scale = torch.tensor([1.0, 1.0, 1.0])
normal_dist = Normal(loc=loc, scale=scale)
batch_normal = Batch(dist=normal_dist, actions=np.array([0.5, 1.5, 2.5]))

print("Normal distribution batch:")
print(f"Original mean: {batch_normal.dist.mean}")

# Index a single element
single = batch_normal[1]
print(f"\nSingle element mean: {single.dist.mean}")
print(f"Single element action: {single.actions}")

### Converting to At Least 2D

Sometimes you need to ensure distributions have a batch dimension:

In [None]:
from tianshou.data.batch import dist_to_atleast_2d

# Scalar distribution (no batch dimension)
scalar_dist = Categorical(probs=torch.tensor([0.3, 0.7]))
print("Scalar dist batch_shape:", scalar_dist.batch_shape)

# Convert to have batch dimension
batched_dist = dist_to_atleast_2d(scalar_dist)
print("Batched dist batch_shape:", batched_dist.batch_shape)

# For entire batch
scalar_batch = Batch(a=1, b=2, dist=Categorical(probs=torch.ones(3)))
print("\nBefore to_at_least_2d:", scalar_batch.dist.batch_shape)

batch_2d = scalar_batch.to_at_least_2d()
print("After to_at_least_2d:", batch_2d.dist.batch_shape)

### Use Cases

Distribution slicing is used in:
- **Policy sampling**: When policies output distributions, slicing batches preserves distribution structure
- **Replay buffer sampling**: Distributions are stored and retrieved correctly
- **Advantage computation**: Computing log probabilities on subsets of data

## 6. Advanced Topics

### 6.1 Key Reservation

Sometimes you know what keys you'll need but don't have values yet. Reserve keys using empty `Batch()` objects:

```mermaid
graph TD
    root["Batch"]
    root --> a["key1: np.array([1,2,3])"]
    root --> b["key2: Batch() (reserved)"]
    root --> c["key3"]
    c --> c1["subkey1: Batch() (reserved)"]
    c --> c2["subkey2: np.array([4,5])"]
    
    style root fill:#e1f5ff
    style a fill:#e8f5e1
    style b fill:#ffcccc
    style c fill:#fff4e1
    style c1 fill:#ffcccc
    style c2 fill:#e8f5e1
```

In [None]:
# Reserving keys
batch = Batch(
    known_field=np.array([1, 2]),
    future_field=Batch(),  # Reserved for later
)
print("Batch with reserved key:")
print(batch)

# Later, assign actual data
batch.future_field = np.array([3, 4])
print("\nAfter assignment:")
print(batch)

# Nested reservation
batch2 = Batch(
    obs=Batch(
        camera=Batch(),  # Reserved
        lidar=np.zeros(10),
    )
)
print("\nNested reservation:")
print(batch2)

### 6.2 Length and Shape Semantics

Understanding when `len()` works and what `shape` means:

In [None]:
# Normal case: all tensors same length
batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5, 6]))
print("Normal batch:")
print(f"len(batch1) = {len(batch1)}")
print(f"batch1.shape = {batch1.shape}")

# Scalars have no length
batch2 = Batch(a=5, b=10)
print("\nScalar batch:")
print(f"batch2.shape = {batch2.shape}")
try:
    print(f"len(batch2) = {len(batch2)}")
except TypeError as e:
    print(f"len(batch2) raises TypeError: {e}")

# Mixed lengths: returns minimum
batch3 = Batch(a=[1, 2], b=[3, 4, 5])
print("\nMixed length batch:")
print(f"len(batch3) = {len(batch3)} (minimum of 2 and 3)")

# Reserved keys are ignored
batch4 = Batch(a=[1, 2, 3], reserved=Batch())
print("\nBatch with reserved key:")
print(f"len(batch4) = {len(batch4)} (reserved key ignored)")

### 6.3 Empty Batches

Understanding different meanings of "empty":

In [None]:
# 1. No keys at all
empty1 = Batch()
print("No keys:")
print(f"len(empty1.get_keys()) = {len(list(empty1.get_keys()))}")
print(f"len(empty1) = {len(empty1)}")

# 2. Has keys but they're all reserved
empty2 = Batch(a=Batch(), b=Batch())
print("\nReserved keys only:")
print(f"len(empty2.get_keys()) = {len(list(empty2.get_keys()))}")
print(f"len(empty2) = {len(empty2)}")

# 3. Has data but length is 0
empty3 = Batch(a=np.array([]), b=np.array([]))
print("\nZero-length arrays:")
print(f"len(empty3.get_keys()) = {len(list(empty3.get_keys()))}")
print(f"len(empty3) = {len(empty3)}")

**Checking emptiness:**
- `len(batch.get_keys()) == 0`: No keys (completely empty)
- `len(batch) == 0`: No data elements (may have reserved keys)

**The `.empty()` and `.empty_()` methods:**
Despite their names, these do not test for emptiness; they reset values to zero (or `None` for object-dtype entries):

In [None]:
batch = Batch(a=[1, 2, 3], b=["x", "y", "z"])
print("Original:", batch)

# Empty specific index
batch[0] = Batch.empty(batch[0])
print("\nAfter emptying index 0:")
print(batch)

### 6.4 Heterogeneous Aggregation

Stacking/concatenating batches with different keys:

```mermaid
graph LR
    A["Batch(a=[1,2], c=5)"] --> C["Batch.stack"]
    B["Batch(b=[3,4], c=6)"] --> C
    C --> D["Batch(a=[[1,2],[0,0]],<br/>b=[[0,0],[3,4]],<br/>c=[5,6])"]
    
    style A fill:#e1f5ff
    style B fill:#fff4e1
    style C fill:#ffe1f5
    style D fill:#e8f5e1
```

In [None]:
# Stack with different keys (missing keys padded with zeros)
batch_a = Batch(a=np.ones((2, 3)), shared=np.array([1, 2]))
batch_b = Batch(b=np.zeros((2, 4)), shared=np.array([3, 4]))

stacked = Batch.stack([batch_a, batch_b])
print("Stacked batch:")
print(f"a.shape = {stacked.a.shape} (padded with zeros for batch_b)")
print(f"b.shape = {stacked.b.shape} (padded with zeros for batch_a)")
print(f"shared.shape = {stacked.shared.shape} (in both batches)")
print(stacked)

### 6.5 Missing Values

Handling `None` and `NaN` values:

In [None]:
# Batch with missing values
batch = Batch(a=[1, 2, None, 4], b=[5.0, np.nan, 7.0, 8.0], c=[[1, 2], [3, 4], [5, 6], [7, 8]])

# Check for nulls
print("Has null?", batch.hasnull())

# Get null mask
null_mask = batch.isnull()
print("\nNull mask:")
print(f"a: {null_mask.a}")
print(f"b: {null_mask.b}")

# Drop rows with any null
clean_batch = batch.dropnull()
print("\nAfter dropnull() (keeps rows 0 and 3):")
print(f"Length: {len(clean_batch)}")
print(f"a: {clean_batch.a}")
print(f"b: {clean_batch.b}")

### 6.6 Value Transformations

Applying functions to all values recursively:

In [None]:
batch = Batch(a=np.array([1, 2, 3]), nested=Batch(b=np.array([4.0, 5.0]), c=np.array([6, 7, 8])))

# Apply transformation (returns new batch)
doubled = batch.apply_values_transform(lambda x: x * 2)
print("Original batch a:", batch.a)
print("Doubled batch a:", doubled.a)
print("Doubled nested.b:", doubled.nested.b)

# In-place transformation
batch.apply_values_transform(lambda x: x + 10, inplace=True)
print("\nAfter in-place +10:")
print("a:", batch.a)
print("nested.b:", batch.nested.b)

## 7. Surprising Behaviors & Gotchas

### Iteration Does NOT Iterate Over Keys!

**This is the most common source of confusion:**

In [None]:
batch = Batch(a=[1, 2, 3], b=[4, 5, 6])

print("WRONG: This doesn't iterate over keys!")
for item in batch:
    print(f"item = {item}")  # Prints batch[0], batch[1], batch[2]

print("\nCORRECT: To iterate over keys:")
for key in batch.keys():
    print(f"key = {key}")

print("\nCORRECT: To iterate over key-value pairs:")
for key, value in batch.items():
    print(f"{key} = {value}")

### Automatic Type Conversions

Be aware of these automatic conversions:

In [None]:
# Lists become arrays
batch = Batch(a=[1, 2, 3])
print("List → array:", type(batch.a), batch.a.dtype)

# Dicts become Batch
batch = Batch(a={"x": 1, "y": 2})
print("Dict → Batch:", type(batch.a))

# Scalars become 0-d NumPy arrays
batch = Batch(a=5)
print("Scalar → np.ndarray:", type(batch.a), batch.a)

# Mixed types → object dtype
batch = Batch(a=[1, "hello", None])
print("Mixed → object:", batch.a.dtype, batch.a)

### Length Edge Cases

In [None]:
# 1. Scalars have no length
batch_scalar = Batch(a=5, b=10)
try:
    len(batch_scalar)
except TypeError as e:
    print(f"Scalar batch: {e}")

# 2. Empty nested batches ignored in len()
batch_empty_nested = Batch(a=[1, 2, 3], b=Batch())
print(f"\nWith empty nested: len = {len(batch_empty_nested)} (ignores b)")

# 3. Different lengths: returns minimum
batch_different = Batch(a=[1, 2], b=[1, 2, 3, 4])
print(f"Different lengths: len = {len(batch_different)} (minimum)")

# 4. None values don't affect length
batch_none = Batch(a=[1, 2, 3], b=None)
print(f"With None: len = {len(batch_none)} (None ignored)")

### String Keys Only

In [None]:
# Integer keys not allowed
try:
    batch = Batch({1: "value", 2: "other"})
except AssertionError as e:
    print("Integer keys not allowed:", e)

# String keys work
batch = Batch({"key1": "value", "key2": "other"})
print("\nString keys work:", list(batch.keys()))

### Cat vs Stack Behavior

`Batch.stack` pads missing keys, but `Batch.cat` is stricter: in recent Tianshou versions it requires all batches to share the same structure.

In [None]:
# Stack pads missing keys with zeros
b1 = Batch(a=[1, 2])
b2 = Batch(b=[3, 4])
stacked = Batch.stack([b1, b2])
print("Stack (different keys):")
print(f"  a: {stacked.a}  (b2.a padded with 0)")
print(f"  b: {stacked.b}  (b1.b padded with 0)")

# Cat requires the same keys in every batch
b3 = Batch(a=[1, 2], b=[3, 4])
b4 = Batch(a=[5, 6], b=[7, 8])
concatenated = Batch.cat([b3, b4])
print("\nCat (same keys):")
print(f"  a: {concatenated.a}")
print(f"  b: {concatenated.b}")

# Cat with different structures raises error
try:
    Batch.cat([b1, b2])  # Different keys!
except ValueError:
    print("\nCat with different keys: ValueError raised")

## 8. Best Practices

### When to Use Batch

**Good use cases:**
- Collecting environment data (transitions, episodes)
- Storing replay buffer data
- Passing data between components (collector → buffer → policy)
- Handling heterogeneous observations (dict spaces)

**Consider alternatives:**
- Simple scalar tracking (use regular variables)
- Pure tensor operations (use PyTorch tensors directly)
- Deeply nested arbitrary structures (use dataclasses)

### Structuring Your Batches

**Use protocols for type safety:**

In [None]:
# Good: Use protocols for clear interfaces
def train_step(batch: RolloutBatchProtocol) -> float:
    """IDE knows what fields exist."""
    loss = ((batch.rew - 0.5) ** 2).mean()  # Type-safe
    return float(loss)


# Create properly typed batch
train_batch = cast(
    RolloutBatchProtocol,
    Batch(
        obs=np.random.randn(10, 4),
        obs_next=np.random.randn(10, 4),
        act=np.random.randint(0, 2, 10),
        rew=np.random.randn(10),
        terminated=np.zeros(10, dtype=bool),
        truncated=np.zeros(10, dtype=bool),
        info=np.array([{}] * 10, dtype=object),
    ),
)

loss = train_step(train_batch)
print(f"Loss: {loss:.4f}")

**Consistent key naming:**
- Follow Tianshou conventions: `obs`, `act`, `rew`, `terminated`, `truncated`
- Use descriptive names: `camera_obs` not `co`
- Avoid name collisions with Batch methods: don't use `keys`, `items`, `get`, etc.

**When to nest vs flatten:**

In [None]:
# Good: Nest related data
batch_nested = Batch(
    obs=Batch(
        camera=np.zeros((32, 64, 64, 3)), lidar=np.zeros((32, 100)), position=np.zeros((32, 3))
    ),
    act=np.zeros(32),
)
print("Nested structure for related obs:")
print(f"  Access: batch.obs.camera.shape = {batch_nested.obs.camera.shape}")

# Less good: Flat structure loses semantic grouping
batch_flat = Batch(
    camera=np.zeros((32, 64, 64, 3)),
    lidar=np.zeros((32, 100)),
    position=np.zeros((32, 3)),
    act=np.zeros(32),
)
print("\nFlat structure (works but less clear):")
print(f"  Access: batch.camera.shape = {batch_flat.camera.shape}")

### Performance Tips

**Use in-place operations:**

In [None]:
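import time

batch = Batch(a=np.random.randn(1000, 100))

# Non-in-place: to_torch() returns a new Batch on every call
start = time.perf_counter()
for _ in range(100):
    _ = batch.to_torch()
time_copy = time.perf_counter() - start

# In-place: converts the values of the existing Batch (round trip back to NumPy each time)
start = time.perf_counter()
for _ in range(100):
    batch.to_torch_()
    batch.to_numpy_()
time_inplace = time.perf_counter() - start

# Timings depend on hardware and array sizes; treat them as indicative only
print(f"Copy: {time_copy:.4f}s")
print(f"In-place: {time_inplace:.4f}s")
print(f"Ratio (copy / in-place): {time_copy / time_inplace:.1f}x")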

**Be mindful of copies:**

In [None]:
arr = np.array([1, 2, 3])

# Default: numeric NumPy arrays are stored by reference, not copied (be careful!)
batch1 = Batch(a=arr)
batch1.a[0] = 999
print(f"Original array modified: {arr}")  # Changed!

# Explicit copy when needed
arr = np.array([1, 2, 3])
batch2 = Batch(a=arr, copy=True)
batch2.a[0] = 999
print(f"Original array preserved: {arr}")  # Unchanged

**Avoid unnecessary conversions:**

In [None]:
# Inefficient: multiple conversions
batch = Batch(a=np.random.randn(100, 10))
batch.to_torch_()
batch.to_numpy_()  # Unnecessary if we just need NumPy

# Efficient: convert once, use many times
batch = Batch(a=np.random.randn(100, 10))
batch.to_torch_()  # Convert once
# ... do torch operations ...
# Keep as torch if that's what you need!

### Common Patterns

**Pattern 1: Building batches incrementally**

In [None]:
# Collect data from multiple steps
step_data = []
for i in range(5):
    step_data.append({"obs": np.random.randn(4), "act": i, "rew": np.random.randn()})

# Convert to batch (automatically stacks)
episode_batch = Batch(step_data)
print("Episode batch shape:", episode_batch.shape)
print("obs shape:", episode_batch.obs.shape)

**Pattern 2: Slicing for mini-batches**

In [None]:
# Large batch
large_batch = Batch(obs=np.random.randn(100, 4), act=np.random.randint(0, 2, 100))

# Split into mini-batches
batch_size = 32
for mini_batch in large_batch.split(batch_size, shuffle=True):
    print(f"Mini-batch size: {len(mini_batch)}")
    # Train on mini_batch...
    break  # Just show one iteration

**Pattern 3: Extending batches**

In [None]:
# Start with some data
batch = Batch(obs=np.array([[1, 2], [3, 4]]), act=np.array([0, 1]))
print("Initial:", len(batch))

# Add more data
new_data = Batch(obs=np.array([[5, 6]]), act=np.array([1]))
batch.cat_(new_data)
print("After cat_:", len(batch))
print("obs:", batch.obs)

## 9. Summary

### Key Takeaways

1. **Batch = Dict + Array**: Combines key-value storage with array operations
2. **Hierarchical Structure**: Perfect for complex RL data (nested observations, etc.)
3. **Type Safety via Protocols**: Use `BatchProtocol` subclasses for IDE support and type checking
4. **Special RL Features**: Distribution slicing, heterogeneous aggregation, missing value handling
5. **Remember**: Iteration is over indices, NOT keys!

### Quick Reference

| Operation | Code | Notes |
|-----------|------|-------|
| Create | `Batch(a=1, b=[2, 3])` | Auto-converts types |
| Access | `batch.a` or `batch["a"]` | Equivalent |
| Index | `batch[0]`, `batch[:10]` | Returns sliced Batch |
| Iterate indices | `for item in batch:` | Yields batch[0], batch[1], ... |
| Iterate keys | `for k in batch.keys():` | Like dict |
| Stack | `Batch.stack([b1, b2])` | Adds dimension |
| Concatenate | `Batch.cat([b1, b2])` | Extends dimension |
| Split | `batch.split(size=10)` | Returns an iterator; shuffles unless `shuffle=False` |
| To PyTorch | `batch.to_torch_()` | In-place |
| To NumPy | `batch.to_numpy_()` | In-place |
| Transform | `batch.apply_values_transform(fn)` | Recursive |

### Next Steps

- **Collector Deep Dive**: See how Batch flows through data collection
- **Buffer Deep Dive**: Understand how Batch is stored and sampled
- **Policy Guide**: Learn how policies work with BatchProtocol
- **API Reference**: Full details at [Batch API documentation](https://tianshou.org/en/stable/api/tianshou.data.html#tianshou.data.Batch)

### Questions?

- Check the [Tianshou GitHub discussions](https://github.com/thu-ml/tianshou/discussions)
- Review [issue tracker](https://github.com/thu-ml/tianshou/issues) for known gotchas
- Read the [source code](https://github.com/thu-ml/tianshou/blob/master/tianshou/data/batch.py) - it's well-documented!

## Appendix: Serialization & Advanced Topics

### Pickle Support

In [None]:
# Batch objects are picklable
original = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))

# Serialize and deserialize
serialized = pickle.dumps(original)
restored = pickle.loads(serialized)

print("Original obs.a:", original.obs.a)
print("Restored obs.a:", restored.obs.a)
print("Equal:", original == restored)

### Advanced Indexing

In [None]:
# Multi-dimensional data
batch = Batch(a=np.random.randn(5, 3, 2))
print("Original shape:", batch.a.shape)

# Various indexing operations
print("batch[0].a.shape:", batch[0].a.shape)
print("batch[:, 0].a.shape:", batch[:, 0].a.shape)
print("batch[[0, 2, 4]].a.shape:", batch[[0, 2, 4]].a.shape)