# Robot CN Network - Data Analysis

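This notebook loads the collected demonstration data, plots the distribution of each action dimension, visualizes training and validation loss curves from the training log, and inspects the saved model checkpoint.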

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path

from robot_cn_network.data import RobotDataset
from robot_cn_network.utils import load_dataset, compute_metrics
from robot_cn_network.models import ACTPolicy, ModelConfig

## Load and Analyze Dataset

In [None]:
# Build the dataset object from the demonstrations stored on disk
data_path = "../data/demonstrations"
dataset = RobotDataset(
    data_path,
    sequence_length=1,
    action_horizon=10,
)

# Summarize what was loaded: sample count first, then episode count
print(f"Dataset size: {len(dataset)} samples")
print(f"Number of episodes: {len(dataset.episodes)}")

In [None]:
# Analyze action distributions
#
# Gathers the 'action' entry of every step in every episode and draws one
# histogram per action dimension (at most MAX_PLOTTED_DIMS of them) on a
# fixed 2x4 grid of subplots.
MAX_PLOTTED_DIMS = 7

all_actions = []
for episode in dataset.episodes:
    all_actions.extend(step['action'] for step in episode)

all_actions = np.array(all_actions)

if all_actions.size == 0:
    # Without this guard, shape[1] below raises IndexError on an empty array.
    print("No actions found in dataset; nothing to plot.")
else:
    if all_actions.ndim == 1:
        # Scalar actions yield a 1-D array; treat them as a single dimension
        # so the column indexing below works.
        all_actions = all_actions[:, np.newaxis]

    n_dims = min(MAX_PLOTTED_DIMS, all_actions.shape[1])

    fig, axes = plt.subplots(2, 4, figsize=(15, 8))
    axes = axes.flatten()

    for i in range(n_dims):
        axes[i].hist(all_actions[:, i], bins=50, alpha=0.7)
        axes[i].set_title(f'Action Dimension {i}')
        axes[i].set_xlabel('Value')
        axes[i].set_ylabel('Frequency')

    # The grid has 8 panels; hide the ones that received no histogram so the
    # figure does not show empty frames.
    for unused_ax in axes[n_dims:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    plt.show()

## Training Progress Analysis

In [None]:
# Plot the loss curves recorded during training, when a log file is present
import json
import os

log_path = "../outputs/training/training_log.json"
if not os.path.exists(log_path):
    print("No training logs found. Run training first.")
else:
    with open(log_path, 'r') as f:
        logs = json.load(f)

    # Training curve: one point per log entry
    epochs = [entry['epoch'] for entry in logs]
    train_losses = [entry['train_loss'] for entry in logs]

    # Validation curve: only the entries that carry a 'val_loss' key
    val_entries = [entry for entry in logs if 'val_loss' in entry]
    val_losses = [entry.get('val_loss') for entry in val_entries]

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses, label='Training Loss')
    if val_losses:
        val_epochs = [entry['epoch'] for entry in val_entries]
        plt.plot(val_epochs, val_losses, label='Validation Loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Progress')
    plt.legend()
    plt.grid(True)
    plt.show()

## Model Analysis

In [None]:
# Load trained model and analyze
def count_parameters(state_dict):
    """Return the total number of elements over all tensors in *state_dict*.

    Args:
        state_dict: Mapping whose values are inspected; entries that are not
            torch tensors are skipped instead of raising.

    Returns:
        The sum of ``numel()`` over every tensor value (0 for an empty mapping).
    """
    return sum(
        value.numel() for value in state_dict.values() if torch.is_tensor(value)
    )


model_path = "../outputs/training/best_model.pth"
# Path (imported alongside torch) keeps this cell independent of the `os`
# import that lives in the training-log cell.
if Path(model_path).exists():
    # NOTE(review): torch.load relies on pickle unless weights_only is in
    # effect; only open checkpoints that come from a trusted source.
    checkpoint = torch.load(model_path, map_location='cpu')

    # A missing or null 'metadata' entry is treated as empty.
    metadata = checkpoint.get('metadata') or {}

    print("Model checkpoint info:")
    print(f"Epoch: {checkpoint.get('epoch', 'N/A')}")
    print(f"Validation Loss: {metadata.get('val_loss', 'N/A')}")

    # Count parameters
    state_dict = checkpoint.get('model_state_dict')
    if state_dict is None:
        # Report the gap instead of failing with a KeyError.
        print("Checkpoint has no 'model_state_dict' entry; cannot count parameters.")
    else:
        total_params = count_parameters(state_dict)
        print(f"Total parameters: {total_params:,}")
else:
    print("No trained model found. Run training first.")