# Jetson Autopilot - Training

Train the CNN model for autonomous steering and throttle control.

In [None]:
import os
# These are set before torch is imported further down, so they are already in
# the environment when CUDA gets initialized: order GPUs by PCI bus ID and
# expose only device "0" to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import sys
# Put the parent directory first on the import path (presumably where the
# local `jetson` package lives - confirm against the repository layout).
sys.path.insert(0, '..')

import torch
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import display

# Report the installed PyTorch build and whether a CUDA device is usable.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    # Index 0 is the only device queried here.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## Configuration

In [None]:
from jetson import Config

config = Config()

# Model output location and name; the resulting config.model_path is printed
# at the end of this cell.
config.models_dir = Path("..") / "models"
config.model_name = "autopilot_v1"

# Dataset splits all live side by side under ../datasets.
DATASETS_DIR = Path("..") / "datasets"
TRAINING_DIR, VALIDATION_DIR, TESTING_DIR = (
    DATASETS_DIR / split for split in ("training", "validation", "testing")
)

# Hyperparameters for the training run.
hyper = config.training
hyper.batch_size = 128
hyper.max_epochs = 50
hyper.early_stopping_patience = 10
hyper.initial_lr = 5e-4

print(f"Model will be saved to: {config.model_path}")

## Load Datasets

In [None]:
from jetson import create_data_loaders

# One directory per split, passed through as keyword arguments.
split_dirs = {
    "training_dir": TRAINING_DIR,
    "validation_dir": VALIDATION_DIR,
    "testing_dir": TESTING_DIR,
}
train_loader, val_loader, test_loader = create_data_loaders(config=config, **split_dirs)

# Batch counts for the two loaders that are always reported.
for split_label, split_loader in (("Training", train_loader), ("Validation", val_loader)):
    print(f"{split_label} batches: {len(split_loader)}")

# The test loader is only reported when it is truthy.
if test_loader:
    print(f"Testing samples: {len(test_loader.dataset)}")

## Create Model

In [None]:
from jetson import AutopilotModel, Trainer

# Build the network from the model section of the config, then wrap it in a
# trainer that receives the full config.
model = AutopilotModel(config=config.model, pretrained=True)
trainer = Trainer(model=model, config=config)

# Total number of elements across all parameter tensors.
parameter_count = sum(tensor.numel() for tensor in model.parameters())
print(f"Model parameters: {parameter_count:,}")
print(f"Training on: {trainer.device}")

## Train

In [None]:
# Run training on the training/validation loaders. The returned `history` is
# indexed by 'training_loss' and 'validation_loss' in the plotting cell.
history = trainer.train(train_loader, val_loader)

## Plot Training Progress

In [None]:
# Plot both loss curves against a 1-based epoch axis.
fig, ax = plt.subplots(figsize=(10, 6))

epochs = range(1, len(history['training_loss']) + 1)
curves = (
    ('training_loss', 'Training Loss', 'o'),
    ('validation_loss', 'Validation Loss', 's'),
)
for history_key, curve_label, curve_marker in curves:
    ax.plot(epochs, history[history_key], label=curve_label, marker=curve_marker)

ax.set(xlabel='Epoch', ylabel='Loss (MSE)', title='Training Progress')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Lowest validation loss recorded in the history.
best_val_loss = min(history['validation_loss'])
print(f"Best validation loss: {best_val_loss:.6f}")

## Test Model

In [None]:
# Evaluate the model at config.model_path on the test loader, when one exists.
if test_loader:
    avg_loss, results = trainer.test(test_loader, config.model_path)

    # Count the result entries whose 'passed' flag is truthy.
    passed = len([entry for entry in results if entry['passed']])
    total = len(results)
    print(f"\nTest Results: {passed}/{total} passed")
    print(f"Average Loss: {avg_loss:.4f}")

## Visualize Test Results

In [None]:
from itertools import islice

from jetson.preprocessing import ImagePreprocessor
import torchvision.transforms.functional as TF

# Upper bound on how many test samples get their own figure.
MAX_PREVIEWS = 8

if test_loader:
    preprocessor = ImagePreprocessor(device=trainer.device, frame_size=config.model.frame_size)

    # Pair each batch with its entry in `results` (produced by the test cell).
    # islice over zip stops after MAX_PREVIEWS pairs without fetching an extra
    # batch, and ends cleanly if `results` is shorter than the loader instead
    # of raising IndexError.
    # NOTE(review): pairing the i-th batch with the i-th result and showing
    # only image[0] assumes the test loader yields one sample per batch -
    # confirm against create_data_loaders.
    previews = islice(zip(test_loader, results), MAX_PREVIEWS)
    for (name, image, target), result in previews:
        # Denormalize the first image of the batch for display.
        img_display = preprocessor.denormalize(image[0])

        fig, ax = plt.subplots(figsize=(4, 4))
        ax.imshow(img_display)
        ax.axis('off')

        # Title text and color both follow the pass/fail flag.
        if result['passed']:
            status, color = "PASS", 'green'
        else:
            status, color = "FAIL", 'red'

        ax.set_title(
            f"[{status}] Loss: {result['loss']:.4f}\n"
            f"Expected: [{result['expected'][0]:.2f}, {result['expected'][1]:.2f}]\n"
            f"Predicted: [{result['predicted'][0]:.2f}, {result['predicted'][1]:.2f}]",
            color=color,
            fontsize=10
        )

        plt.tight_layout()
        plt.show()

## Export for Jetson

Copy the model file to your Jetson device. TensorRT conversion will happen automatically on first run.

In [None]:
# Show where the model file is and the command line for running it on the
# Jetson; the command uses only the file name, not the full path.
saved_model = config.model_path
print(f"Model saved at: {saved_model}")
print("\nTo run on Jetson:")
print(f"  python -m jetson run --model-path {saved_model.name} --show-fps")