In [None]:
from lerobot.common.datasets.factory import make_dataset
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.default import DatasetConfig
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.datasets.utils import cycle
from lerobot.configs.types import FeatureType
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    has_method,
    init_logging,
)
import torch
from src.policies.baseline.configuration import BaselineConfig
from src.policies.baseline.modeling import BaselinePolicy

In [None]:
# Paths and training knobs — edit these to match your setup.
ROOT = "./dataset/transformed_data_notebook"  # local directory of the converted dataset
CKPT_PATH = "./ckpt/baseline_model"  # where the trained policy is saved
LOG_EVERY = 100  # print loss / learning rate every LOG_EVERY optimizer steps
TOTAL_EPOCHS = 100  # number of full passes over the dataset
SEED = 1000  # seed for torch's global RNG

# Seed torch's global RNG (all devices) so Restart & Run All reproduces the same
# run: torch-based randomness such as weight init, DataLoader shuffle order and
# worker seeds, and dropout masks are drawn from it.
torch.manual_seed(SEED)

## Configurations

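These are the default `BaselineConfig` settings. Any of them can be overridden by passing the field as a keyword argument when the config is constructed below.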

```
@PreTrainedConfig.register_subclass("omy_baseline")
@dataclass
class BaselineConfig(PreTrainedConfig):
    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 5
    n_action_steps: int = 5

    # Architecture.
    backbone: str = 'mlp' # 'mlp' or 'transformer'
    # Vision encoder
    vision_backbone: str ="facebook/dinov3-vitb16-pretrain-lvd1689m" #"facebook/dinov2-base"
    projection_dim : int = 128
    freeze_backbone:  bool = True


    # Num hidden layers
    n_hidden_layers: int = 5
    hidden_dim: int = 512   

    ## For transformer-based architectures
    n_heads: int = 4
    dim_feedforward: int = 2048
    feedforward_activation: str = "gelu"
    dropout: float = 0.1
    pre_norm: bool = True
    n_encoder_layers: int = 6

    # Training preset
    optimizer_lr: float = 1e-3
    optimizer_weight_decay: float = 1e-6

    # Learning rate scheduler parameters 
    lr_warmup_steps: int = 1000
    total_training_steps: int = 500000
```

In [None]:
'''
Load Dataset and Configurations     
'''
# Dataset config for the locally converted dataset stored under ROOT.
dataset_cfg = DatasetConfig("transformed_data")
dataset_cfg.root = ROOT
pipeline_cfg = TrainPipelineConfig(dataset_cfg)

# Policy config: the fields passed here override the BaselineConfig defaults
# listed in the "Configurations" section; every other field keeps its default.
cfg = BaselineConfig(
    chunk_size=10,
    n_action_steps=10,
    backbone='mlp',
    optimizer_lr=5e-4,
    n_hidden_layers=4,
    hidden_dim=256,
    # Vision-encoder settings, relevant when the dataset provides image features.
    # You need access to this DINOv3 model to use it; if you don't have access,
    # use "facebook/dinov2-base" instead.
    vision_backbone='facebook/dinov3-vitb16-pretrain-lvd1689m',
    projection_dim=128,
    freeze_backbone=True,
)


## Build Dataset

In [None]:
# Attach the policy config and its optimizer / LR-scheduler presets to the
# training pipeline config before the dataset is built from it.
pipeline_cfg.policy = cfg
pipeline_cfg.optimizer = cfg.get_optimizer_preset()
pipeline_cfg.scheduler = cfg.get_scheduler_preset()

# Build the dataset. Its metadata supplies the feature descriptions and the
# dataset statistics that are handed to the policy constructor.
dataset = make_dataset(pipeline_cfg)
ds_meta = dataset.meta
features = dataset_to_policy_features(ds_meta.features)

# Action features are what the policy predicts; every other feature is an input.
cfg.output_features = {name: feat for name, feat in features.items() if feat.type is FeatureType.ACTION}
cfg.input_features = {name: feat for name, feat in features.items() if feat.type is not FeatureType.ACTION}

# Keyword arguments for the BaselinePolicy constructor.
kwargs = {"dataset_stats": ds_meta.stats, "config": cfg}

## Load Policy

In [None]:
'''
Instantiate Policy
'''
# Build the policy from the config and dataset statistics collected above,
# then move it to the configured device. `nn.Module.to()` returns the module
# itself, so the trailing `;` stops Jupyter from printing its (very long) repr.
policy = BaselinePolicy(**kwargs)
policy.to(pipeline_cfg.policy.device);

In [None]:
'''
Create Optimizer and Scheduler
'''
optimizer, lr_scheduler = make_optimizer_and_scheduler(pipeline_cfg, policy)

# DataLoader settings, named so they are easy to find and tune.
BATCH_SIZE = 64
NUM_WORKERS = 4

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    drop_last=False,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Page-locked (pinned) host memory is what lets `.to(device, non_blocking=True)`
    # in the training loop copy asynchronously; it only helps CUDA devices.
    pin_memory=str(pipeline_cfg.policy.device).startswith("cuda"),
)

In [None]:
'''
Check Parameter Counts
'''
# Count every parameter, and separately those that will receive gradient
# updates (parameters with requires_grad=False, i.e. frozen ones, are excluded).
total_params = list(policy.parameters())
trainable_params = [param for param in total_params if param.requires_grad]
n_total = sum(param.numel() for param in total_params)
n_trainable = sum(param.numel() for param in trainable_params)
print(f"Total number of parameters: {format_big_number(n_total)}")
print(f"Number of trainable parameters: {format_big_number(n_trainable)}")

## Training Loop

In [None]:
'''
Training Loop
'''
# The checkpoint is saved by the following cell, so it is written exactly once.
device = get_safe_torch_device(pipeline_cfg.policy.device, log=True)

# Max gradient norm from the optimizer preset; a missing or falsy value
# disables clipping.
grad_clip_norm = getattr(pipeline_cfg.optimizer, "grad_clip_norm", None)

step = 0
for epoch in range(TOTAL_EPOCHS):
    print(f"Starting epoch {epoch+1}/{TOTAL_EPOCHS}")
    for batch in dataloader:
        # Move tensors to the training device; non-tensor entries are left as-is.
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)
        policy.train()
        loss, output_dict = policy.forward(batch)
        optimizer.zero_grad()
        loss.backward()
        # Apply the preset's gradient clipping, which a bare backward/step
        # sequence would otherwise silently ignore. Parameters without a
        # gradient (e.g. frozen ones) are skipped by clip_grad_norm_.
        if grad_clip_norm:
            torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm)
        optimizer.step()

        # Step through pytorch scheduler at every batch instead of epoch
        if lr_scheduler is not None:
            lr_scheduler.step()
        step += 1
        if step % LOG_EVERY == 0:
            print(f"Step: {step}, Loss: {loss.item():.4f}, learning rate: {optimizer.param_groups[0]['lr']:.6f}")

In [None]:
# Save the trained policy to the checkpoint directory configured at the top.
checkpoint_dir = CKPT_PATH
policy.save_pretrained(checkpoint_dir)