# 🚀 Model Training

This notebook guides you through the process of training a policy on your converted dataset. 

The process is broken down into a few simple steps:
1.  **Setup**: Apply necessary patches to the `lerobot` library.
2.  **Dataset**: Specify the path to your training data.
3.  **Configuration**: Select a model architecture and its hyperparameters.
4.  **Training**: Launch the training process.

--- 
### 1. Setup

First, apply our custom patches to the `lerobot` library. This only needs to be done once per session.

In [1]:
# Set environment variables BEFORE any imports
import os
import warnings

os.environ["LEROBOT_VIDEO_BACKEND"] = "pyav"

# Suppress torchvision video deprecation warning
warnings.filterwarnings("ignore", message=".*video decoding and encoding capabilities.*")

from example_policies import lerobot_patches

lerobot_patches.apply_patches()
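The setup cell selects the `pyav` video backend, which decodes the dataset videos through the PyAV package (imported as `av`). A missing PyAV install may only surface once the dataloader starts decoding frames, so a quick check up front can save time. The helper name `pyav_available` below is a hypothetical illustration, not part of `example_policies`.

```python
import importlib.util


def pyav_available() -> bool:
    """Return True if the PyAV package (imported as `av`) can be found."""
    return importlib.util.find_spec("av") is not None


if not pyav_available():
    # The 'pyav' backend set in LEROBOT_VIDEO_BACKEND cannot decode videos without PyAV
    print("PyAV is not installed; try `pip install av` before training.")
```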

--- 
### 2. Select Dataset

> **Action Required:** Update `DATA_DIR` to point to the dataset you created in the previous notebook.

In [2]:
import pathlib

# TODO: Set the path to your converted dataset directory.
DATA_DIR = pathlib.Path("/home/yizhang/Projects/hackathon-ki-fabrik/data/sort_red_blocks_80_old")
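Before building a configuration, it can help to confirm that `DATA_DIR` really points at a converted dataset. The sketch below assumes the LeRobot v2.x on-disk layout, in which `meta/info.json` records fields such as `total_episodes`, `total_frames` and `fps`; `summarize_dataset` is a hypothetical helper. For the dataset used in this notebook, the training log later reports 82 episodes and 21308 frames, so those are the values you would expect to see.

```python
import json
import pathlib


def summarize_dataset(data_dir: pathlib.Path) -> dict:
    """Read basic statistics from a LeRobot-style dataset directory.

    Assumes the v2.x layout, where `meta/info.json` holds keys such as
    `total_episodes`, `total_frames` and `fps`.
    """
    info_path = pathlib.Path(data_dir) / "meta" / "info.json"
    if not info_path.is_file():
        raise FileNotFoundError(f"No meta/info.json under {data_dir}; is DATA_DIR correct?")
    info = json.loads(info_path.read_text())
    # Missing keys come back as None instead of raising
    return {key: info.get(key) for key in ("total_episodes", "total_frames", "fps")}


# Example: print(summarize_dataset(DATA_DIR))
```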

--- 
### 3. Select Model Configuration

We provide several pre-made configurations as a starting point and recommend `dit_flow_config`. You can also adjust parameters such as `batch_size` as needed.
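The configuration cell below reports its training length in epochs and converts it to optimizer steps. The printed figures (21308 frames, batch size 64, 200 epochs, 66400 steps) are consistent with dropping the last partial batch of each epoch, i.e. `steps_per_epoch = num_frames // batch_size`. The sketch assumes exactly that rule; `steps_for_epochs` is a hypothetical name, not the factory's actual code.

```python
def steps_for_epochs(num_frames: int, batch_size: int, epochs: int) -> tuple[int, int]:
    """Convert an epoch budget into optimizer steps.

    Assumes one step consumes one full batch and partial batches are dropped,
    so steps per epoch is the floor of num_frames / batch_size.
    """
    steps_per_epoch = num_frames // batch_size
    return steps_per_epoch, steps_per_epoch * epochs


per_epoch, total = steps_for_epochs(num_frames=21308, batch_size=64, epochs=200)
print(per_epoch, total)  # 332 66400
```

With a batch size of 32 the same rule gives 21308 // 32 = 665 steps per epoch, so changing `batch_size` also changes how many steps an epoch budget translates into.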

In [3]:
# Select one of the following configurations
from example_policies.config_factory import diffusion_config, dit_flow_config, dit_flow_image_config

# cfg = dit_flow_config(DATA_DIR, enable_wandb=False)
cfg = dit_flow_image_config(DATA_DIR, enable_wandb=False)

# Disable multiprocessing if there are issues with dataloader workers
# cfg.num_workers = 0




📊 Training by epochs:
   - Dataset size: 21308 frames
   - Batch size: 64
   - Epochs: 200
   - Calculated steps: 66400
   - Save every: 100 epochs (33200 steps)

Final Training Configuration (full details):
TrainPipelineConfig(dataset=DatasetConfig(repo_id='sort_red_blocks_80',
                                          root=PosixPath('/data/sort_red_blocks_80'),
                                          episodes=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ...

In [None]:
import json

# cfg.log_freq = 1
cfg.save_freq = 10000
cfg.steps = 80000

cfg.policy.optimizer_lr = 2e-4

cfg.job_name = "ditflow_sort_red_blocks_80"
cfg.output_dir = pathlib.Path("/home/yizhang/Projects/hackathon-ki-fabrik/outputs/ditflow_sort_red_blocks_80")
cfg.wandb.enable = True
cfg.wandb.project = "paper"
cfg.wandb.entity = "470620104-technical-university-of-munich"

print(json.dumps(cfg.to_dict(), indent=4))

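You can pass additional policy parameters by looking them up in the lerobot configuration classes, e.g. `lerobot.policies.act.configuration_act`, and adapting the code cell accordingly:
```python
# Assumes act_config is provided by example_policies.config_factory alongside the presets above
from example_policies.config_factory import act_config

cfg = act_config(DATA_DIR, policy_kwargs={
    "optimizer_lr": 1e-5,
})
```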
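Because lerobot policy configurations such as `ACTConfig` are Python dataclasses, a misspelled key in `policy_kwargs` is easy to catch by comparing it against the dataclass fields. The sketch below shows the idea with a hypothetical `ToyPolicyConfig` stand-in so it runs without lerobot installed; `check_policy_kwargs` is likewise a hypothetical helper.

```python
import dataclasses


def check_policy_kwargs(config_cls: type, policy_kwargs: dict) -> None:
    """Raise ValueError if policy_kwargs contains a key that is not a field of config_cls."""
    valid = {field.name for field in dataclasses.fields(config_cls)}
    unknown = sorted(set(policy_kwargs) - valid)
    if unknown:
        raise ValueError(f"Unknown policy parameters: {unknown}")


# Hypothetical stand-in for a lerobot policy config dataclass
@dataclasses.dataclass
class ToyPolicyConfig:
    optimizer_lr: float = 1e-4
    chunk_size: int = 100


check_policy_kwargs(ToyPolicyConfig, {"optimizer_lr": 1e-5})  # passes silently
```

With lerobot installed, you could pass the real config class (for example `ACTConfig` from `lerobot.policies.act.configuration_act`) instead of the toy one.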

In [4]:
cfg.job_name = "ditflow_image_sort_red_blocks_80"  # alternative: "ditflow_sort_red_blocks_80"
cfg.output_dir = pathlib.Path("/home/jovyan/outputs/ditflow_image_sort_red_blocks_80/")
cfg.save_freq = 5000
cfg.steps = 20000
cfg.log_freq = 100

cfg.batch_size = 32

cfg.wandb.enable = True
cfg.wandb.project = "paper"
cfg.wandb.entity = "470620104-technical-university-of-munich"

print(cfg.to_dict())

{'dataset': {'repo_id': 'sort_red_blocks_80', 'root': '/data/sort_red_blocks_80', 'episodes': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81], 'image_transforms': {'enable': False, 'max_num_transforms': 3, 'random_order': False, 'tfs': {'brightness': {'weight': 1.0, 'type': 'ColorJitter', 'kwargs': {'brightness': [0.8, 1.2]}}, 'contrast': {'weight': 1.0, 'type': 'ColorJitter', 'kwargs': {'contrast': [0.8, 1.2]}}, 'saturation': {'weight': 1.0, 'type': 'ColorJitter', 'kwargs': {'saturation': [0.5, 1.5]}}, 'hue': {'weight': 1.0, 'type': 'ColorJitter', 'kwargs': {'hue': [-0.05, 0.05]}}, 'sharpness': {'weight': 1.0, 'type': 'SharpnessJitter', 'kwargs': {'sharpness': [0.5, 1.5]}}}}, 'revision': None, 'use_imagenet_stats': Tru
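The cell above switches from an epoch budget to a fixed step budget: 20000 steps at batch size 32, with a checkpoint every 5000 steps. Converting that back to epochs (samples seen divided by the 21308 frames in the dataset) shows the run covers roughly 30 passes over the data, far fewer than the 200 epochs of the default configuration. The sketch assumes a checkpoint is written whenever the step count is a multiple of `save_freq`; `describe_schedule` is a hypothetical helper.

```python
def describe_schedule(steps: int, save_freq: int, batch_size: int, num_frames: int) -> dict:
    """Summarize a step-based training budget in epochs and checkpoint steps."""
    return {
        # Each step consumes one batch, so samples seen = steps * batch_size
        "epochs": steps * batch_size / num_frames,
        # Assumes a checkpoint at every multiple of save_freq up to and including `steps`
        "checkpoint_steps": list(range(save_freq, steps + 1, save_freq)),
    }


plan = describe_schedule(steps=20000, save_freq=5000, batch_size=32, num_frames=21308)
print(round(plan["epochs"], 1), plan["checkpoint_steps"])  # 30.0 [5000, 10000, 15000, 20000]
```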

--- 
### 4. Start Training

This cell starts training. Metrics and logs are streamed to the console and, if enabled in the configuration, to Weights & Biases.
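The training log below includes a wandb error saying the notebook name could not be detected and pointing to the `WANDB_NOTEBOOK_NAME` environment variable. The run itself continues, but code saving is skipped. A minimal fix, assuming this notebook is saved as `train_policy.ipynb` (a placeholder name), is to set the variable before calling `train(cfg)`:

```python
import os

# Placeholder file name: replace with the actual name of this notebook
os.environ["WANDB_NOTEBOOK_NAME"] = "train_policy.ipynb"
```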

In [None]:
from example_policies.train import train

# Set video backend in config
cfg.dataset.video_backend = "pyav"

train(cfg)

INFO 2025-12-03 10:14:08 ts/train.py:111 {'batch_size': 32,
 'dataset': {'episodes': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, ...


Starting training...


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
INFO 2025-12-03 10:14:10 _patches.py:128 Track this run --> [1m[33mhttps://wandb.ai/470620104-technical-university-of-munich/paper/runs/5j26axxl[0m
INFO 2025-12-03 10:14:10 ts/train.py:127 Creating dataset


[1m[34mLogs will be synced with wandb.[0m


Resolving data files:   0%|          | 0/82 [00:00<?, ?it/s]

INFO 2025-12-03 10:14:10 ts/train.py:138 Creating policy
INFO 2025-12-03 10:14:11 ts/train.py:144 Creating optimizer and scheduler
INFO 2025-12-03 10:14:11 ts/train.py:156 [1m[33mOutput dir:[0m /home/jovyan/outputs/ditflow_image_sort_red_blocks_80
INFO 2025-12-03 10:14:11 ts/train.py:159 cfg.steps=20000 (20K)
INFO 2025-12-03 10:14:11 ts/train.py:160 dataset.num_frames=21308 (21K)
INFO 2025-12-03 10:14:11 ts/train.py:161 dataset.num_episodes=82
INFO 2025-12-03 10:14:11 ts/train.py:162 num_learnable_params=75699491 (76M)
INFO 2025-12-03 10:14:11 ts/train.py:163 num_total_params=75699741 (76M)
INFO 2025-12-03 10:14:11 ts/train.py:202 Start offline training on a fixed dataset


Number of flow params: 42.11M


INFO 2025-12-03 10:15:26 ts/train.py:232 step:100 smpl:3K ep:12 epch:0.15 loss:1.224 grdn:0.997 lr:2.0e-05 updt_s:0.075 data_s:0.677
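Each progress line packs several metrics into `key:value` tokens. Judging by their names, `smpl`, `ep` and `epch` appear to count samples, episodes and epochs seen so far, `grdn` the gradient norm, and `updt_s` and `data_s` the time per step spent on the update and on data loading. The numbers are self-consistent with the configuration: 100 steps at batch size 32 gives 3200 samples (shown as `3K`), and 3200 / 21308 frames is about 0.15 epochs. The hypothetical `parse_train_log` helper below turns such a line into a dictionary, assuming the `K`/`M`/`B` suffixes stand for thousands, millions and billions.

```python
# Assumed meaning of the abbreviated number suffixes in the log
_SUFFIXES = {"K": 1e3, "M": 1e6, "B": 1e9}


def parse_train_log(line: str) -> dict[str, float]:
    """Parse the `key:value` metrics that follow `step:` in a training log line."""
    metrics: dict[str, float] = {}
    start = line.find("step:")
    if start == -1:
        return metrics  # not a progress line
    # Skip the timestamp and source location, which also contain colons
    for token in line[start:].split():
        key, sep, raw = token.partition(":")
        if not sep or not raw:
            continue
        scale = _SUFFIXES.get(raw[-1], 1.0)
        number = raw[:-1] if raw[-1] in _SUFFIXES else raw
        metrics[key] = float(number) * scale
    return metrics


sample = ("INFO 2025-12-03 10:15:26 ts/train.py:232 step:100 smpl:3K ep:12 epch:0.15 "
          "loss:1.224 grdn:0.997 lr:2.0e-05 updt_s:0.075 data_s:0.677")
metrics = parse_train_log(sample)
print(metrics["smpl"], metrics["epch"])  # 3000.0 0.15
```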


In [None]:
# Upload the last checkpoint's pretrained_model folder to the Hugging Face Hub

!hf upload --repo-type model continuallearning/ditflow_sort_red_blocks_80 /home/yizhang/Projects/hackathon-ki-fabrik/outputs/ditflow_sort_red_blocks_80/checkpoints/last/pretrained_model
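If you prefer to stay in Python rather than shell out to the `hf` CLI, the `huggingface_hub` library offers `HfApi.create_repo` and `HfApi.upload_folder`. The sketch below assumes you are already authenticated (for example through the `HF_TOKEN` environment variable) and that the run's output directory follows the `checkpoints/last/pretrained_model` layout used in the command above; both helper names are hypothetical.

```python
import pathlib


def last_checkpoint_dir(output_dir: pathlib.Path) -> pathlib.Path:
    """Return the pretrained_model folder of the most recent checkpoint."""
    path = pathlib.Path(output_dir) / "checkpoints" / "last" / "pretrained_model"
    if not path.is_dir():
        raise FileNotFoundError(f"No checkpoint found at {path}")
    return path


def upload_last_checkpoint(output_dir: pathlib.Path, repo_id: str) -> None:
    """Upload the last checkpoint to a model repo on the Hugging Face Hub."""
    # Imported lazily so last_checkpoint_dir works without huggingface_hub installed
    from huggingface_hub import HfApi

    api = HfApi()
    api.create_repo(repo_id, repo_type="model", exist_ok=True)
    api.upload_folder(
        folder_path=str(last_checkpoint_dir(output_dir)),
        repo_id=repo_id,
        repo_type="model",
    )


# Example: upload_last_checkpoint(cfg.output_dir, "continuallearning/ditflow_sort_red_blocks_80")
```

Passing `cfg.output_dir` ties the upload path to whichever configuration was actually trained, which avoids accidentally uploading a run from a different output directory.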