# 🚀 Model Training

This notebook walks you through training a policy on your converted dataset.

The process is broken down into a few simple steps:
1.  **Setup**: Apply necessary patches to the `lerobot` library.
2.  **Dataset**: Specify the path to your training data.
3.  **Configuration**: Select a model architecture and its hyperparameters.
4.  **Training**: Launch the training process.

--- 
### 1. Setup

First, set the video backend and apply our custom patches to the `lerobot` library. Run this cell once per kernel session, before anything else imports `lerobot`.

In [None]:
# Set environment variables BEFORE importing lerobot or example_policies
import os
import warnings

os.environ["LEROBOT_VIDEO_BACKEND"] = "pyav"

# Suppress torchvision video deprecation warning
warnings.filterwarnings("ignore", message=".*video decoding and encoding capabilities.*")

from example_policies import lerobot_patches

lerobot_patches.apply_patches()
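
Before moving on, you may want to confirm the environment is set up as intended. The sketch below is a hypothetical helper (`check_video_backend` is not part of `example_policies` or `lerobot`). It only checks that the `LEROBOT_VIDEO_BACKEND` variable set above equals `"pyav"` and that the PyAV package (imported as `av`) is installed, without actually importing it.

```python
import importlib.util
import os


def check_video_backend(expected: str = "pyav") -> dict:
    """Hypothetical sanity check: report the configured backend and PyAV availability."""
    backend = os.environ.get("LEROBOT_VIDEO_BACKEND")
    # find_spec returns None when a top-level package is not installed
    pyav_installed = importlib.util.find_spec("av") is not None
    return {
        "backend": backend,
        "backend_ok": backend == expected,
        "pyav_installed": pyav_installed,
    }


status = check_video_backend()
if not status["backend_ok"]:
    print(f"LEROBOT_VIDEO_BACKEND is {status['backend']!r}; expected 'pyav'.")
if not status["pyav_installed"]:
    print("PyAV ('av') is not installed; decoding with the pyav backend will not work.")
```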

--- 
### 2. Select Dataset

> **Action Required:** Update `DATA_DIR` to point to the dataset you created in the previous notebook.

In [None]:
import pathlib

# TODO: Set the path to your converted dataset directory.
DATA_DIR = pathlib.Path("/data/[TODO]")
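
A quick check can catch a wrong path before training starts. The helper below is a hypothetical sketch, not part of `example_policies`. It assumes the converted dataset follows a LeRobot v2-style layout with its metadata in `meta/info.json`; adjust the expected path if your converter writes something different.

```python
import json
import pathlib


def validate_dataset_dir(data_dir) -> dict:
    """Hypothetical check that data_dir looks like a converted LeRobot dataset.

    Assumes a LeRobot v2-style layout with metadata in meta/info.json and
    returns that metadata as a dict.
    """
    data_dir = pathlib.Path(data_dir)
    if "[TODO]" in str(data_dir):
        raise ValueError("DATA_DIR still contains the [TODO] placeholder.")
    if not data_dir.is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {data_dir}")
    info_path = data_dir / "meta" / "info.json"
    if not info_path.is_file():
        raise FileNotFoundError(f"Expected metadata file is missing: {info_path}")
    # Parse the metadata so you can inspect counts such as frames or episodes
    with open(info_path, encoding="utf-8") as f:
        return json.load(f)
```

In the notebook you could call `info = validate_dataset_dir(DATA_DIR)` and look at keys such as `total_frames` or `total_episodes`, if your `info.json` contains them.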

--- 
### 3. Select Model Configuration

We provide several pre-made configurations as a starting point; we recommend `DiTFlowConfig`. You can also adjust hyperparameters such as `batch_size` as needed.

In [None]:
# Select one of the following configuration classes
from example_policies.config_factory import DiffusionConfig, DiTFlowConfig

# Create a configuration (DiTFlowConfig is recommended)
config = DiTFlowConfig(
    dataset_root_dir=DATA_DIR,
    wandb_enable=False,
    policy_kwargs={"image_only": True}
)

# Build the training configuration
cfg = config.build()

# Uncomment to load data in the main process if dataloader worker processes fail
# cfg.num_workers = 0

You can customize the configuration by passing different dataclass parameters, for example:
```python
config = DiTFlowConfig(
    dataset_root_dir=DATA_DIR,
    batch_size=64,
    lr=2e-4,
    steps=100_000,
    save_freq=10_000,
    wandb_enable=True,
    policy_kwargs={"image_only": True}
)
cfg = config.build()
```

Available configuration classes: `ACTConfig`, `DiffusionConfig`, `DiTFlowConfig`, `SmolVLAConfig`, `Pi0Config`
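
When choosing `steps`, `batch_size`, and `save_freq`, it can help to translate them into passes over the data and a number of checkpoints. The function below is a hypothetical back-of-the-envelope estimate, not part of `example_policies`. It assumes each training sample corresponds to one frame and that a checkpoint is written every `save_freq` steps; `num_frames` is a value you supply, for example from your dataset metadata.

```python
def training_budget(steps: int, batch_size: int, save_freq: int, num_frames: int) -> dict:
    """Hypothetical estimate of epochs and checkpoint count for a run.

    Assumes one sample per frame and a checkpoint every `save_freq` steps.
    """
    if min(steps, batch_size, save_freq, num_frames) <= 0:
        raise ValueError("All arguments must be positive.")
    samples_seen = steps * batch_size
    return {
        "samples_seen": samples_seen,
        "epochs": samples_seen / num_frames,
        "checkpoints": steps // save_freq,
    }


# Values from the customization example above, with a made-up size of 200,000 frames
print(training_budget(steps=100_000, batch_size=64, save_freq=10_000, num_frames=200_000))
# prints {'samples_seen': 6400000, 'epochs': 32.0, 'checkpoints': 10}
```

If `epochs` comes out well below 1, the run ends before each frame has been sampled once on average.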

--- 
### 4. Start Training

This cell starts training. Metrics and logs are streamed to the console and, if you set `wandb_enable=True`, to Weights & Biases.

In [None]:
from example_policies.train import train

# Decode dataset videos with PyAV, matching LEROBOT_VIDEO_BACKEND from step 1
cfg.dataset.video_backend = "pyav"

train(cfg)
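
Checkpoints are written periodically during training (every `save_freq` steps). The helper below is a hypothetical sketch for finding the most recent one. It assumes step-numbered subdirectories (e.g. `010000`) under a `checkpoints` folder in the run's output directory, which is how LeRobot's training script typically lays them out; verify the actual layout of your run before relying on it.

```python
import pathlib
from typing import Optional


def latest_checkpoint(output_dir) -> Optional[pathlib.Path]:
    """Hypothetical helper: return the highest-numbered checkpoint directory, if any.

    Assumes checkpoints live in <output_dir>/checkpoints/<step>/ with purely
    numeric directory names; other entries (e.g. a 'last' symlink) are ignored.
    """
    ckpt_root = pathlib.Path(output_dir) / "checkpoints"
    if not ckpt_root.is_dir():
        return None
    step_dirs = [p for p in ckpt_root.iterdir() if p.is_dir() and p.name.isdigit()]
    if not step_dirs:
        return None
    # Compare numerically so unpadded names (e.g. "9" vs "10") also order correctly
    return max(step_dirs, key=lambda p: int(p.name))
```

Pass the run's output directory; in LeRobot's training configuration this is `cfg.output_dir`, but check that it is set for your config.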