# NeuroTrain Complete Workflow Tutorial

This notebook demonstrates a complete end-to-end deep learning workflow using NeuroTrain, including:

1. Environment Setup
2. Data Loading and Exploration
3. Model Selection and Configuration
4. Training Process
5. Model Evaluation
6. Results Visualization
7. Model Export and Deployment

## Prerequisites

```bash
conda activate ntrain
uv pip install -e '.[cu128]'
```


In [None]:
# Step 1: Import Libraries and Setup
import sys
from pathlib import Path

# Make the project root (the notebook's parent directory) importable so `src.*` resolves
sys.path.insert(0, str(Path.cwd().parent))

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from src.dataset import get_train_valid_test_dataloader
from src.models import get_model
from src.metrics import accuracy, dice, iou_seg
from src.utils.criterion import get_criterion
from src.utils import EarlyStopping

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

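The cell above seeds PyTorch and NumPy. If your transforms or data code also draw from Python's built-in `random` module, a small helper can seed all three in one call. The sketch below is a hypothetical helper (`seed_everything` is not a NeuroTrain function shown in this tutorial). It imports PyTorch only if it is installed, and the optional cuDNN flags trade some GPU speed for run-to-run determinism.

```python
import random

import numpy as np


def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python, NumPy and (if installed) PyTorch RNGs.

    Hypothetical helper for this tutorial; not a NeuroTrain API.
    """
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not installed: nothing more to seed
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices
    if deterministic:
        # Prefer deterministic cuDNN kernels over the fastest ones
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


seed_everything(42)
first = (random.random(), np.random.rand())
seed_everything(42)
second = (random.random(), np.random.rand())
print(first == second)  # True: same seed, same draws
```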
## Step 2: Data Loading and Exploration

We'll use CIFAR-10 as an example dataset for this tutorial.


In [None]:
# Configure dataset
config = {
    "dataset": {
        "name": "cifar10",
        "root_dir": "../data/cifar10",
        "train": True,
        "download": True,
    },
    "training": {"batch_size": 128, "num_workers": 2},
}

# Load data
print("Loading datasets...")
train_loader, valid_loader, test_loader = get_train_valid_test_dataloader(config)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(valid_loader) if valid_loader else 0}")
print(f"Test batches: {len(test_loader) if test_loader else 0}")

# CIFAR-10 classes
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


# Visualize some samples
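def show_batch(loader, num_images=16):
    """Plot the first `num_images` images of one batch in a 4-column grid."""
    images, labels = next(iter(loader))
    num_images = min(num_images, len(images))  # the batch may hold fewer images
    ncols = 4
    nrows = (num_images + ncols - 1) // ncols

    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")  # hide axes on every panel, including unused ones
        if i >= num_images:
            continue
        img = images[i].permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
        # Rescale to [0, 1] for display; the epsilon avoids division by zero
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        ax.imshow(img)
        ax.set_title(classes[labels[i].item()])
    plt.tight_layout()
    plt.show()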


show_batch(train_loader)

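`show_batch` rescales each image to [0, 1] with its own min and max, which makes it displayable but shifts colors. If the data pipeline normalizes images with a known per-channel mean and standard deviation, inverting that step gives faithful colors. The sketch below assumes the commonly quoted CIFAR-10 statistics; the values NeuroTrain's transforms actually use (if any) are not shown in this notebook, so check the dataset module before relying on them.

```python
import numpy as np

# Assumed CIFAR-10 per-channel statistics; replace with the values your
# NeuroTrain transform actually uses (they are not shown in this notebook).
CIFAR10_MEAN = np.array([0.4914, 0.4822, 0.4465])
CIFAR10_STD = np.array([0.2470, 0.2435, 0.2616])


def denormalize(img_chw: np.ndarray, mean=CIFAR10_MEAN, std=CIFAR10_STD) -> np.ndarray:
    """Invert (x - mean) / std for a CHW image and return an HWC array in [0, 1]."""
    img_hwc = np.transpose(img_chw, (1, 2, 0))  # CHW -> HWC for imshow
    img_hwc = img_hwc * std + mean  # broadcasts over the last (channel) axis
    return np.clip(img_hwc, 0.0, 1.0)


# Round trip on a random "image": normalize it, then undo the normalization
rng = np.random.default_rng(0)
original = rng.random((3, 32, 32))
normalized = (original - CIFAR10_MEAN[:, None, None]) / CIFAR10_STD[:, None, None]
restored = denormalize(normalized)
print(np.allclose(restored, np.transpose(original, (1, 2, 0))))  # True
```

Inside `show_batch`, `denormalize(images[i].numpy())` would replace the `permute` and min-max lines.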
## Step 3: Model Creation

We'll use a ResNet18 model pretrained on ImageNet and fine-tune it for CIFAR-10.


In [None]:
# Model configuration
model_config = {
    "arch": "resnet18",
    "pretrained": True,
    "n_classes": 10,
    "n_channels": 3,
}

# Create model
print("Creating model...")
model = get_model("torchvision", model_config)
model = model.to(device)

# Model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel: {model_config['arch']}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: ~{total_params * 4 / 1024 / 1024:.1f} MB (float32)")

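The "Model size" line multiplies the parameter count by 4 bytes per float32 value. As a worked example, assume `get_model("torchvision", ...)` loads torchvision's ResNet18 (11,689,512 parameters with its 1000-class head) and only replaces the final fully connected layer with a 10-class one. That is an assumption: NeuroTrain may also adapt other layers for 32×32 inputs, in which case the printed total will differ.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of parameters in an nn.Linear layer."""
    return in_features * out_features + (out_features if bias else 0)


RESNET18_IMAGENET_PARAMS = 11_689_512  # torchvision resnet18 with a 1000-class head
FC_IN_FEATURES = 512  # width of ResNet18's final feature vector

# Swap the 1000-class head for a 10-class one
cifar_params = (
    RESNET18_IMAGENET_PARAMS
    - linear_params(FC_IN_FEATURES, 1000)
    + linear_params(FC_IN_FEATURES, 10)
)
size_mb = cifar_params * 4 / 1024 / 1024  # 4 bytes per float32 parameter

print(f"{cifar_params:,} parameters, ~{size_mb:.1f} MB")
# 11,181,642 parameters, ~42.7 MB
```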
## Step 4: Training Configuration

Define the loss function, optimizer, learning rate scheduler, and early-stopping criterion.


In [None]:
# Training configuration
num_epochs = 20
learning_rate = 0.001

# Loss function
criterion = nn.CrossEntropyLoss()
print(f"Loss function: {criterion.__class__.__name__}")

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
print(f"Optimizer: {optimizer.__class__.__name__}, LR: {learning_rate}")

# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
print(f"Scheduler: {scheduler.__class__.__name__}")

# Early stopping
early_stopping = EarlyStopping(patience=5, min_delta=1e-4, mode="min")
print(f"Early stopping: patience={early_stopping.patience}")

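`CosineAnnealingLR` lowers the learning rate along half a cosine wave over `T_max` scheduler steps. Its documented closed form is $\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\frac{\pi t}{T_{max}}\right)$, with $\eta_{min} = 0$ by default. Because the training loop calls `scheduler.step()` once per epoch, $t$ counts completed epochs. The function below is a standalone re-computation of that formula for illustration, not PyTorch's own code.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form learning rate of CosineAnnealingLR after `epoch` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Mirror the notebook's settings: lr=0.001, T_max=num_epochs=20
schedule = [cosine_annealing_lr(e, base_lr=0.001, t_max=20) for e in range(21)]
print(f"{schedule[0]:.6f} {schedule[10]:.6f} {schedule[20]:.6f}")
# 0.001000 0.000500 0.000000
```

The epoch summary prints the learning rate before `scheduler.step()`, so epoch *k* (1-indexed) reports `cosine_annealing_lr(k - 1, ...)`. If early stopping fires, the rate never reaches zero.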
## Step 5: Training Loop

Now let's implement the training loop with validation.


In [None]:
# Training loop
train_losses = []
val_losses = []
train_accs = []
val_accs = []
best_val_loss = float("inf")

print("\n" + "=" * 80)
print("Starting Training")
print("=" * 80)

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for images, labels in train_pbar:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        train_total += labels.size(0)
        train_correct += predicted.eq(labels).sum().item()

        # Update progress bar
        train_pbar.set_postfix(
            {
                "loss": f"{loss.item():.4f}",
                "acc": f"{100.*train_correct/train_total:.2f}%",
            }
        )

    avg_train_loss = train_loss / len(train_loader)
    train_acc = 100.0 * train_correct / train_total
    train_losses.append(avg_train_loss)
    train_accs.append(train_acc)

    # Validation phase
    if valid_loader:
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        avg_val_loss = val_loss / len(valid_loader)
        val_acc = 100.0 * val_correct / val_total
        val_losses.append(avg_val_loss)
        val_accs.append(val_acc)

        # Print epoch summary
        print(f"\nEpoch {epoch+1}/{num_epochs}:")
        print(f"  Train - Loss: {avg_train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"  Valid - Loss: {avg_val_loss:.4f}, Acc: {val_acc:.2f}%")
        print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")

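        # Save best model. Clone the tensors: state_dict().copy() only copies
        # references, which later optimizer steps would overwrite in place.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_state = {
                k: v.detach().clone() for k, v in model.state_dict().items()
            }
            print("  ✓ New best model!")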

        # Early stopping
        if early_stopping(avg_val_loss):
            print(f"\n⚠ Early stopping triggered at epoch {epoch+1}")
            break

    # Learning rate scheduling
    scheduler.step()

print("\n" + "=" * 80)
print("Training Complete!")
print("=" * 80)
print(f"Best validation loss: {best_val_loss:.4f}")

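The loop assumes `early_stopping(avg_val_loss)` returns `True` once the validation loss has failed to improve by at least `min_delta` for `patience` consecutive epochs. NeuroTrain's `EarlyStopping` implementation is not shown here, so the class below is a minimal, hypothetical stand-in that illustrates this behavior (including `mode="max"` for metrics such as accuracy); it is not the library's code.

```python
class SimpleEarlyStopping:
    """Hypothetical stand-in for an early-stopping callable (not NeuroTrain's class)."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0, mode: str = "min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None  # best metric value seen so far
        self.counter = 0  # epochs since the last improvement

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def __call__(self, value: float) -> bool:
        """Record one epoch's metric; return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Loss falls for three epochs, then stops improving by more than min_delta
stopper = SimpleEarlyStopping(patience=2, min_delta=1e-4, mode="min")
history = [1.0, 0.8, 0.7, 0.70005, 0.71]
decisions = [stopper(v) for v in history]
print(decisions)  # [False, False, False, False, True]
```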
## Step 6: Training Curves Visualization

Plot the loss and accuracy recorded at each epoch.


In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Epochs are logged 1-indexed, so number the x-axis the same way
train_epochs = range(1, len(train_losses) + 1)
val_epochs = range(1, len(val_losses) + 1)

# Loss curves
ax1.plot(train_epochs, train_losses, label="Training Loss", marker="o", linewidth=2)
if val_losses:
    ax1.plot(val_epochs, val_losses, label="Validation Loss", marker="s", linewidth=2)
ax1.set_xlabel("Epoch", fontsize=12)
ax1.set_ylabel("Loss", fontsize=12)
ax1.set_title("Training and Validation Loss", fontsize=14, fontweight="bold")
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Accuracy curves
ax2.plot(train_epochs, train_accs, label="Training Accuracy", marker="o", linewidth=2)
if val_accs:
    ax2.plot(val_epochs, val_accs, label="Validation Accuracy", marker="s", linewidth=2)
ax2.set_xlabel("Epoch", fontsize=12)
ax2.set_ylabel("Accuracy (%)", fontsize=12)
ax2.set_title("Training and Validation Accuracy", fontsize=14, fontweight="bold")
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print final statistics
print(f"\nFinal Training Accuracy: {train_accs[-1]:.2f}%")
if val_accs:
    print(f"Final Validation Accuracy: {val_accs[-1]:.2f}%")
print(f"Best Validation Loss: {best_val_loss:.4f}")

## Step 7: Model Evaluation on Test Set

Load the best model (the weights with the lowest validation loss) and evaluate it on the test set.


In [None]:
# Load best model
if "best_model_state" in locals():
    model.load_state_dict(best_model_state)
    print("Loaded best model for testing")

# Test evaluation
if test_loader:
    model.eval()
    test_loss = 0.0
    test_correct = 0
    test_total = 0

    # For confusion matrix
    all_predictions = []
    all_labels = []

    print("\nEvaluating on test set...")
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_acc = 100.0 * test_correct / test_total
    avg_test_loss = test_loss / len(test_loader)

    print("\n" + "=" * 80)
    print("Test Results")
    print("=" * 80)
    print(f"Test Loss: {avg_test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"Correct: {test_correct} / {test_total}")
else:
    print("No test loader available")

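`avg_test_loss` (like the training and validation averages) divides the summed batch losses by the number of batches. Because `nn.CrossEntropyLoss` returns the mean over each batch by default, this is a mean of batch means, which over-weights a smaller final batch. Weighting each batch by its size gives the exact per-sample average. The numbers below are made up purely to show the arithmetic.

```python
def mean_of_batch_means(batch_losses):
    """What the notebook computes: sum of per-batch mean losses / number of batches."""
    return sum(loss for loss, _ in batch_losses) / len(batch_losses)


def per_sample_mean(batch_losses):
    """Exact average over samples: weight each batch mean by its batch size."""
    total = sum(loss * n for loss, n in batch_losses)
    count = sum(n for _, n in batch_losses)
    return total / count


# (mean loss, batch size) pairs: two full batches and a small final batch
batches = [(0.50, 128), (0.50, 128), (2.00, 16)]
print(f"{mean_of_batch_means(batches):.4f}")  # 1.0000
print(f"{per_sample_mean(batches):.4f}")      # 0.5882
```

In the evaluation loop, the per-sample version amounts to accumulating `loss.item() * labels.size(0)` and dividing by `test_total`. Accuracy is already computed per sample.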
## Step 8: Confusion Matrix and Per-Class Metrics

Visualize the confusion matrix and compute per-class metrics.


In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

if test_loader and "all_predictions" in locals():
    # Compute confusion matrix
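    # Fix the label set so the matrix is always len(classes) x len(classes),
    # even if some class never appears in the labels or predictions
    cm = confusion_matrix(
        all_labels, all_predictions, labels=list(range(len(classes)))
    )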

    # Plot confusion matrix
    plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=classes,
        yticklabels=classes,
        cbar_kws={"label": "Count"},
    )
    plt.xlabel("Predicted Label", fontsize=12)
    plt.ylabel("True Label", fontsize=12)
    plt.title("Confusion Matrix", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.show()

    # Classification report
    print("\nPer-Class Metrics:")
    print(
        classification_report(
            all_labels,
            all_predictions,
            labels=list(range(len(classes))),  # keep target_names aligned
            target_names=classes,
            digits=4,
        )
    )

    # Per-class accuracy
    print("\nPer-Class Accuracy:")
    for i, class_name in enumerate(classes):
        class_correct = cm[i, i]
        class_total = cm[i].sum()
        class_acc = 100.0 * class_correct / class_total if class_total > 0 else 0
        print(f"  {class_name:12s}: {class_acc:6.2f}% ({class_correct}/{class_total})")

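With rows as true labels and columns as predictions (scikit-learn's convention), the per-class accuracy printed above is each class's recall: the diagonal entry divided by the row sum. Precision divides the same diagonal entry by the column sum instead. A vectorized NumPy sketch on a small, made-up 3-class matrix:

```python
import numpy as np


def per_class_precision_recall(cm: np.ndarray):
    """Per-class precision and recall from a confusion matrix (rows = true, cols = predicted)."""
    tp = np.diag(cm).astype(float)
    predicted_per_class = cm.sum(axis=0)  # column sums
    actual_per_class = cm.sum(axis=1)  # row sums
    # Leave 0 where a class is never predicted (precision) or never present (recall)
    precision = np.divide(
        tp, predicted_per_class, out=np.zeros_like(tp), where=predicted_per_class > 0
    )
    recall = np.divide(
        tp, actual_per_class, out=np.zeros_like(tp), where=actual_per_class > 0
    )
    return precision, recall


cm = np.array([
    [8, 2, 0],
    [1, 7, 2],
    [0, 0, 10],
])
precision, recall = per_class_precision_recall(cm)
print(np.round(precision, 4))
print(np.round(recall, 4))
```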
## Step 9: Sample Predictions Visualization

Let's visualize some predictions to see how the model performs.


In [None]:
if test_loader:
    # Get a batch of test images
    dataiter = iter(test_loader)
    images, labels = next(dataiter)
    images, labels = images.to(device), labels.to(device)

    # Get predictions
    model.eval()
    with torch.no_grad():
        outputs = model(images)
        _, predicted = outputs.max(1)

    # Move to CPU for visualization
    images = images.cpu()
    labels = labels.cpu()
    predicted = predicted.cpu()

    # Visualize predictions
    fig, axes = plt.subplots(4, 4, figsize=(16, 16))
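    for i, ax in enumerate(axes.flat):
        ax.axis("off")  # hide axes on every panel, including unused ones
        if i < len(images):
            img = images[i].permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
            # Rescale to [0, 1] for display; the epsilon avoids division by zero
            img = (img - img.min()) / (img.max() - img.min() + 1e-8)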

            ax.imshow(img)
            true_label = classes[labels[i]]
            pred_label = classes[predicted[i]]

            # Color: green if correct, red if wrong
            color = "green" if labels[i] == predicted[i] else "red"
            ax.set_title(
                f"True: {true_label}\nPred: {pred_label}",
                color=color,
                fontsize=10,
                fontweight="bold",
            )
            ax.axis("off")

    plt.suptitle(
        "Sample Predictions (Green=Correct, Red=Wrong)", fontsize=16, fontweight="bold"
    )
    plt.tight_layout()
    plt.show()

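The grid above shows the first 16 images of a single batch, which may contain few or no mistakes. To inspect errors specifically, you could first collect the indices where the prediction differs from the label. A small NumPy sketch with made-up arrays; in the notebook the inputs would be `labels.numpy()` and `predicted.numpy()`, or the `all_labels` and `all_predictions` lists gathered during testing. `misclassified_indices` is a hypothetical helper, not part of NeuroTrain.

```python
import numpy as np


def misclassified_indices(labels, predictions, max_items=None):
    """Indices of samples whose prediction differs from the true label."""
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    wrong = np.flatnonzero(labels != predictions)
    return wrong if max_items is None else wrong[:max_items]


labels = [3, 5, 1, 0, 3, 8]
predictions = [3, 3, 1, 2, 3, 9]
print(misclassified_indices(labels, predictions))  # [1 3 5]
```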
## Step 10: Model Export and Deployment

Save the model for future use and optionally export to ONNX format.


In [None]:
# Create output directory
output_dir = Path("../runs/tutorial_example")
output_dir.mkdir(parents=True, exist_ok=True)

# Save PyTorch model
model_path = output_dir / "best_model.pth"
torch.save(
    {
        "epoch": len(train_losses),
        "model_state_dict": (
            best_model_state if "best_model_state" in locals() else model.state_dict()
        ),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_losses": train_losses,
        "val_losses": val_losses,
        "train_accs": train_accs,
        "val_accs": val_accs,
        "test_acc": test_acc if "test_acc" in locals() else None,
        "config": {
            "model": model_config,
            "training": {"num_epochs": num_epochs, "learning_rate": learning_rate},
        },
    },
    model_path,
)

print(f"✓ Model saved to: {model_path}")
print(f"  File size: {model_path.stat().st_size / 1024 / 1024:.1f} MB")

# Optional: Export to ONNX (uncomment to use)
# try:
#     onnx_path = output_dir / 'model.onnx'
#     dummy_input = torch.randn(1, 3, 32, 32).to(device)
#     torch.onnx.export(
#         model,
#         dummy_input,
#         onnx_path,
#         input_names=['input'],
#         output_names=['output'],
#         dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
#         opset_version=11
#     )
#     print(f"✓ ONNX model saved to: {onnx_path}")
# except Exception as e:
#     print(f"⚠ ONNX export failed: {e}")

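To reuse the saved checkpoint, rebuild the same architecture and load the stored `model_state_dict`. The sketch below uses a tiny stand-in `nn.Sequential` model and a temporary file so it runs on its own; in the notebook you would rebuild the model with `get_model("torchvision", model_config)` and load `model_path` instead. `map_location="cpu"` lets a checkpoint saved on a GPU load on a CPU-only machine, and `weights_only=True` restricts unpickling to tensors and plain containers, which covers everything this checkpoint stores.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Stand-in architecture; in the notebook use get_model("torchvision", model_config)
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))


with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = Path(tmp) / "best_model.pth"

    # Save a checkpoint with the same layout as the notebook's (a subset of its keys)
    trained = make_model()
    torch.save(
        {"epoch": 3, "model_state_dict": trained.state_dict(), "test_acc": None},
        ckpt_path,
    )

    # Later: rebuild the architecture and restore the weights
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    restored = make_model()
    restored.load_state_dict(checkpoint["model_state_dict"])
    restored.eval()  # inference mode: disables dropout, freezes batch-norm statistics

    x = torch.randn(4, 3, 32, 32)
    with torch.no_grad():
        same = torch.allclose(trained(x), restored(x))
    print(checkpoint["epoch"], same)  # 3 True
```

The notebook's checkpoint also stores Adam's two moment buffers per parameter, so expect `best_model.pth` to be roughly three times the float32 model size printed in Step 3. Its `optimizer_state_dict` comes from the last epoch that ran, not from the best epoch.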
## Summary

Congratulations! You've completed a full deep learning workflow with NeuroTrain. Here's what we covered:

### ✅ What We Accomplished

1. **Environment Setup** - Configured PyTorch and NeuroTrain
2. **Data Loading** - Loaded and visualized CIFAR-10 dataset
3. **Model Creation** - Built a ResNet18 model with pretrained weights
4. **Training Configuration** - Set up loss, optimizer, and scheduler
5. **Model Training** - Trained the model with validation
6. **Visualization** - Plotted training curves and metrics
7. **Evaluation** - Tested model on test set
8. **Analysis** - Generated confusion matrix and per-class metrics
9. **Prediction Visualization** - Visualized sample predictions
10. **Model Export** - Saved model for deployment

### 📊 Key Results

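- Loss and accuracy curves for every epoch, with the best weights selected by validation loss
- Test-set loss and accuracy, a confusion matrix, and per-class precision, recall, and F1
- A checkpoint with model weights, optimizer state, and training history saved to `../runs/tutorial_example/best_model.pth`
- Sample predictions highlighting correct and incorrect classifications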

### 🚀 Next Steps

1. **Experiment with different models**: Try VGG, EfficientNet, or Vision Transformers
2. **Try different datasets**: Medical images, COCO, ImageNet
3. **Tune hyperparameters**: Learning rate, batch size, augmentation
4. **Advanced features**: Mixed precision training, distributed training
5. **Deploy your model**: Export to ONNX, quantization, TorchScript

### 📚 Additional Resources

- [Dataset Module Documentation](../docs/DATASET_MODULE.md)
- [Models Module Documentation](../docs/MODELS_MODULE.md)
- [Engine Module Documentation](../docs/ENGINE_MODULE.md)
- [Project Architecture](../docs/ARCHITECTURE.md)
- [More Examples](../examples/)

### 💡 Tips for Better Results

1. **Use data augmentation** to improve generalization
2. **Monitor validation metrics** to detect overfitting
3. **Use learning rate scheduling** for better convergence
4. **Save checkpoints regularly** to resume interrupted training
5. **Visualize results** to understand model behavior

---

**Happy Training with NeuroTrain! 🎉**
