# IRST Library - Complete Tutorial
## Infrared Small Target Detection with Deep Learning

Welcome to the comprehensive tutorial for the IRST Library! This notebook will guide you through the complete workflow of infrared small target detection, from installation to deployment.

### What you'll learn:
- 🔧 **Setup & Installation** - Get started with IRST Library
- 📊 **Data Loading & Exploration** - Work with SIRST dataset
- 🧠 **Model Training** - Train state-of-the-art models like SERANKNet
- 📈 **Evaluation & Metrics** - Assess model performance
- 🎯 **Inference** - Detect targets in new images
- 🚀 **Deployment** - Export and serve models

### Prerequisites:
- Python 3.8+
- PyTorch 1.12+
- CUDA (optional, for GPU acceleration)

Let's get started! 🚀

## 1. Import Required Libraries

First, let's import all the essential libraries we'll need for this tutorial.

In [None]:
# Core libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Deep learning libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# IRST Library - Main components
from irst_library import IRSTDetector
from irst_library.datasets import SIRSTDataset, IRSTD1kDataset
from irst_library.models import SERANKNet, ACMNet, MSHNet
from irst_library.training import IRSTTrainer
from irst_library.evaluation import IRSTEvaluator
from irst_library.utils import visualize_detection, plot_metrics

# Computer vision libraries
import cv2
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Utilities
from tqdm.auto import tqdm
import logging
from datetime import datetime

# Set up plotting
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline

print("✅ All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

## 2. Load and Explore Data

Let's load the SIRST dataset and explore its characteristics. The SIRST dataset contains infrared images with small target annotations.

In [None]:
# Set up data directory
DATA_DIR = "./data/SIRST"
MODEL_DIR = "./models"
OUTPUT_DIR = "./outputs"

# Create directories if they don't exist
for dir_path in [DATA_DIR, MODEL_DIR, OUTPUT_DIR]:
    Path(dir_path).mkdir(parents=True, exist_ok=True)

# Load SIRST dataset
print("📊 Loading SIRST Dataset...")

# Training dataset
train_dataset = SIRSTDataset(
    root=DATA_DIR,
    split="train",
    transform=None,  # We'll add transforms later
    download=True    # Download if not available
)

# Test dataset
test_dataset = SIRSTDataset(
    root=DATA_DIR,
    split="test",
    transform=None,
    download=False
)

print(f"✅ Dataset loaded successfully!")
print(f"   📈 Training samples: {len(train_dataset)}")
print(f"   📊 Test samples: {len(test_dataset)}")

# Explore dataset statistics
sample_image, sample_mask = train_dataset[0]
print(f"\n📐 Image dimensions: {sample_image.shape}")
print(f"📐 Mask dimensions: {sample_mask.shape}")
print(f"🎨 Image dtype: {sample_image.dtype}")
print(f"🎨 Mask dtype: {sample_mask.dtype}")
print(f"📊 Image range: [{sample_image.min():.3f}, {sample_image.max():.3f}]")
print(f"📊 Mask range: [{sample_mask.min():.3f}, {sample_mask.max():.3f}]")
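
The `A.Normalize` settings used later (`mean=0.485`, `std=0.229`) are the ImageNet red-channel statistics, not values measured on SIRST. If you prefer statistics from your own training split, the NumPy sketch below keeps running sums, so the whole dataset never has to be in memory at once. It assumes each image is a 2-D array already scaled to [0, 1]. `dataset_mean_std` is a hypothetical helper, not part of IRST Library.

```python
import numpy as np

def dataset_mean_std(images):
    """Pixel mean and std over an iterable of 2-D arrays scaled to [0, 1]."""
    total = 0.0      # running sum of pixel values
    total_sq = 0.0   # running sum of squared pixel values
    count = 0        # number of pixels seen so far
    for img in images:
        arr = np.asarray(img, dtype=np.float64)
        total += arr.sum()
        total_sq += np.square(arr).sum()
        count += arr.size
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2; clamp at 0 to absorb floating-point rounding
    var = max(total_sq / count - mean ** 2, 0.0)
    return mean, var ** 0.5

# Two toy "images": one all 0.0 and one all 1.0, so mean = 0.5 and std = 0.5
imgs = [np.zeros((4, 4)), np.ones((4, 4))]
mean, std = dataset_mean_std(imgs)
print(f"mean={mean:.3f}, std={std:.3f}")  # mean=0.500, std=0.500
```

`A.Normalize` divides by `max_pixel_value` before it standardizes, so pass these values in [0, 1] units, not [0, 255].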

In [None]:
# Visualize sample images from the dataset
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle('SIRST Dataset Sample Images', fontsize=16, fontweight='bold')

for i in range(4):
    # Get sample
    image, mask = train_dataset[i * 50]  # Sample every 50th image
    
    # Convert to numpy if tensor
    if torch.is_tensor(image):
        image = image.numpy()
    if torch.is_tensor(mask):
        mask = mask.numpy()
    
    # Original image
    axes[0, i].imshow(image.squeeze(), cmap='gray')
    axes[0, i].set_title(f'Image {i+1}')
    axes[0, i].axis('off')
    
    # Ground truth mask
    axes[1, i].imshow(mask.squeeze(), cmap='hot')
    axes[1, i].set_title(f'Ground Truth {i+1}')
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()

# Analyze target characteristics
print("\n🎯 Target Analysis:")
target_counts = []
target_sizes = []

for i in range(min(100, len(train_dataset))):  # Analyze first 100 samples
    _, mask = train_dataset[i]
    if torch.is_tensor(mask):
        mask = mask.numpy()
    
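    # Count targets as connected components (a binary mask has only two values,
    # so counting unique values could never report more than one target)
    binary = (mask.squeeze() > 0).astype(np.uint8)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    num_targets = num_labels - 1  # label 0 is the background
    target_counts.append(num_targets)
    
    # Record the pixel area of every individual target (row 0 of stats is the background)
    target_sizes.extend(stats[1:, cv2.CC_STAT_AREA].tolist())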

print(f"📊 Average targets per image: {np.mean(target_counts):.2f}")
print(f"📊 Max targets per image: {np.max(target_counts)}")
print(f"📊 Images with targets: {np.sum(np.array(target_counts) > 0)} / {len(target_counts)}")
if target_sizes:
    print(f"📊 Average target size: {np.mean(target_sizes):.2f} pixels")
    print(f"📊 Target size range: [{np.min(target_sizes)}, {np.max(target_sizes)}] pixels")
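
To count targets you have to count connected blobs, not distinct pixel values. A binary mask contains only two values, so `len(np.unique(mask)) - 1` can never exceed one. OpenCV's `cv2.connectedComponentsWithStats` and `scipy.ndimage.label` do this labelling for you. The plain-Python sketch below shows what they compute: it flood-fills each 8-connected foreground region and records its pixel area. `component_areas` is a hypothetical name used only for illustration.

```python
from collections import deque
import numpy as np

def component_areas(mask):
    """Return the pixel area of each 8-connected foreground region in a 2-D mask."""
    fg = np.asarray(mask) > 0
    visited = np.zeros_like(fg, dtype=bool)
    h, w = fg.shape
    areas = []
    for r in range(h):
        for c in range(w):
            if not fg[r, c] or visited[r, c]:
                continue
            # Breadth-first flood fill of one new region starting at (r, c)
            queue = deque([(r, c)])
            visited[r, c] = True
            area = 0
            while queue:
                y, x = queue.popleft()
                area += 1
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        ny, nx = y + dy, x + dx
                        if 0 <= ny < h and 0 <= nx < w and fg[ny, nx] and not visited[ny, nx]:
                            visited[ny, nx] = True
                            queue.append((ny, nx))
            areas.append(area)
    return areas

mask = np.zeros((8, 8), dtype=np.uint8)
mask[1:3, 1:3] = 1   # a 2x2 target
mask[5, 6] = 1       # a single-pixel target
print(len(np.unique(mask)) - 1)  # 1  (unique-value count misses a target)
print(component_areas(mask))     # [4, 1]
```

OpenCV reports the same per-region areas in the `cv2.CC_STAT_AREA` column of its `stats` output, with row 0 reserved for the background.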

## 3. Data Preprocessing

Now let's set up data preprocessing pipelines including normalization, augmentation, and proper data loaders for training.
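
Geometric augmentations must move the image and its mask together. If they don't, a target only one or two pixels wide ends up labelled in the wrong place. Albumentations takes care of this when you pass both arrays in the same call. The NumPy sketch below shows the underlying idea: draw the random parameters once and apply them to both arrays. `paired_flip_rot90` is a hypothetical helper, not the library's implementation.

```python
import numpy as np

def paired_flip_rot90(image, mask, rng):
    """Apply the same random 90-degree rotation and flips to an image and its mask."""
    k = int(rng.integers(0, 4))          # number of 90-degree turns, drawn once
    flip_h = bool(rng.random() < 0.5)    # horizontal flip decision, drawn once
    flip_v = bool(rng.random() < 0.5)    # vertical flip decision, drawn once

    def apply(arr):
        arr = np.rot90(arr, k)
        if flip_h:
            arr = np.fliplr(arr)
        if flip_v:
            arr = np.flipud(arr)
        return np.ascontiguousarray(arr)

    return apply(image), apply(mask)

rng = np.random.default_rng(0)
image = np.zeros((6, 6), dtype=np.float32)
mask = np.zeros((6, 6), dtype=np.uint8)
image[1, 4] = 1.0   # bright one-pixel target
mask[1, 4] = 1      # its label
aug_img, aug_mask = paired_flip_rot90(image, mask, rng)
# The target pixel and its label still coincide after augmentation
print(np.array_equal(aug_img > 0, aug_mask > 0))  # True
```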

In [None]:
# Define augmentation pipelines
train_transform = A.Compose([
    # Geometric transformations
    A.RandomRotate90(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Transpose(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.1,
        scale_limit=0.1,
        rotate_limit=15,
        p=0.5
    ),
    
    # Photometric transformations
    A.RandomBrightnessContrast(
        brightness_limit=0.2,
        contrast_limit=0.2,
        p=0.5
    ),
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
    A.GaussianBlur(blur_limit=(3, 7), p=0.3),
    
    # Normalization
    A.Normalize(
        mean=[0.485],  # ImageNet R-channel stats reused for 1-channel IR; swap in dataset stats if available
        std=[0.229],
        max_pixel_value=255.0
    ),
    ToTensorV2()
])  # 'mask' is a built-in albumentations target, so no additional_targets needed

test_transform = A.Compose([
    A.Normalize(
        mean=[0.485],
        std=[0.229],
        max_pixel_value=255.0
    ),
    ToTensorV2()
])

# Create datasets with transforms
train_dataset_aug = SIRSTDataset(
    root=DATA_DIR,
    split="train",
    transform=train_transform
)

test_dataset_clean = SIRSTDataset(
    root=DATA_DIR,
    split="test",
    transform=test_transform
)

print(f"✅ Preprocessing setup complete!")
print(f"   🔄 Training dataset with augmentation: {len(train_dataset_aug)} samples")
print(f"   🔍 Test dataset (clean): {len(test_dataset_clean)} samples")

# Create data loaders
BATCH_SIZE = 16
NUM_WORKERS = 4

train_loader = DataLoader(
    train_dataset_aug,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
    drop_last=True
)

test_loader = DataLoader(
    test_dataset_clean,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

print(f"   📦 Train batches: {len(train_loader)}")
print(f"   📦 Test batches: {len(test_loader)}")
print(f"   🔢 Batch size: {BATCH_SIZE}")

# Visualize augmented samples
print("\n🎨 Augmentation Examples:")
sample_batch = next(iter(train_loader))
images, masks = sample_batch

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle('Data Augmentation Examples', fontsize=16, fontweight='bold')

for i in range(4):
    # Denormalize for visualization
    img = images[i].squeeze().numpy()
    img = img * 0.229 + 0.485  # Reverse normalization
    img = np.clip(img, 0, 1)
    
    mask = masks[i].squeeze().numpy()
    
    axes[0, i].imshow(img, cmap='gray')
    axes[0, i].set_title(f'Augmented Image {i+1}')
    axes[0, i].axis('off')
    
    axes[1, i].imshow(mask, cmap='hot')
    axes[1, i].set_title(f'Augmented Mask {i+1}')
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()
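
The denormalization line in the visualization loop (`img * 0.229 + 0.485`) recovers pixels in [0, 1], not [0, 255]. That works because `A.Normalize` divides by `max_pixel_value` first and only then standardizes. The small NumPy check below writes out that arithmetic explicitly. `normalize` and `denormalize` are illustrative helpers that mirror the formula; they are not library functions.

```python
import numpy as np

MEAN, STD, MAX_PIXEL = 0.485, 0.229, 255.0

def normalize(img_uint8):
    # Same arithmetic as A.Normalize: (x / max_pixel_value - mean) / std
    return (img_uint8.astype(np.float32) / MAX_PIXEL - MEAN) / STD

def denormalize(img_norm):
    # Inverse used for plotting: lands back in [0, 1], not [0, 255]
    return img_norm * STD + MEAN

pixels = np.array([[0, 128, 255]], dtype=np.uint8)
norm = normalize(pixels)          # values roughly -2.118, 0.074 and 2.249
restored = denormalize(norm)      # values roughly 0.0, 0.502 and 1.0
print(np.allclose(restored, pixels / 255.0, atol=1e-6))  # True
```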

## 4. Model Training

Time to train our SERANKNet model! We'll use the IRST Library's built-in trainer with advanced features like mixed precision, learning rate scheduling, and early stopping.
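
The configuration below requests `'scheduler': 'cosine'` with `'warmup_epochs': 5`. This notebook does not show how `IRSTTrainer` implements that schedule. A common formulation is a linear warmup followed by cosine decay, and the plain-Python sketch below assumes it. `lr_at_epoch` and its `min_lr` argument are hypothetical.

```python
import math

def lr_at_epoch(epoch, base_lr=1e-4, warmup_epochs=5, total_epochs=50, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr (epochs are 0-indexed)."""
    if epoch < warmup_epochs:
        # Linear ramp: epoch 0 -> base_lr / warmup_epochs, last warmup epoch -> base_lr
        return base_lr * (epoch + 1) / warmup_epochs
    # Progress through the decay phase, from 0.0 at its start to 1.0 at the last epoch
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs - 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

schedule = [lr_at_epoch(e) for e in range(50)]
print(f"{schedule[0]:.1e}, {schedule[4]:.1e}, {schedule[5]:.1e}, {schedule[-1]:.1e}")
# 2.0e-05, 1.0e-04, 1.0e-04, 0.0e+00
```

Warmup keeps the first updates small while batch statistics and attention weights settle. Cosine decay then lowers the step size smoothly toward the end of training.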

In [None]:
# Initialize SERANKNet model
print("🧠 Initializing SERANKNet model...")

model = SERANKNet(
    in_channels=1,  # Grayscale infrared images
    num_classes=1,  # Binary segmentation
    pretrained=True,
    use_attention=True
)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(f"✅ Model initialized on {device}")
print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"📊 Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Set up training configuration
config = {
    'model_name': 'serank_sirst_tutorial',
    'epochs': 50,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'scheduler': 'cosine',
    'warmup_epochs': 5,
    'mixed_precision': True,
    'gradient_clipping': 1.0,
    'early_stopping_patience': 10,
    'save_best_only': True,
    'monitor_metric': 'val_iou',
    'loss_weights': {
        'dice': 0.5,
        'focal': 0.3,
        'iou': 0.2
    }
}

# Initialize trainer
trainer = IRSTTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,  # Demo only: picking the best epoch on the test split inflates later test scores
    config=config,
    device=device,
    save_dir=MODEL_DIR
)

print(f"🚀 Trainer initialized with configuration:")
for key, value in config.items():
    print(f"   {key}: {value}")

# Training loop with progress tracking
print(f"\n🔥 Starting training for {config['epochs']} epochs...")
print(f"💾 Checkpoints will be saved to: {MODEL_DIR}/{config['model_name']}_best.pth")

# Start training
training_history = trainer.train()

print(f"✅ Training completed!")
print(f"💾 Best model saved at: {MODEL_DIR}/{config['model_name']}_best.pth")

## 5. Model Evaluation

Let's evaluate our trained model using comprehensive metrics and visualizations to understand its performance on infrared small target detection.
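
The pixel-level scores printed below all derive from three counts: true positives, false positives and false negatives. The NumPy sketch below computes IoU, Dice, precision, recall and F1 from a pair of binary masks, as a reference for reading those scores. This notebook does not show whether `IRSTEvaluator` averages per image or pools pixels over the whole test set. The sketch assumes a single pooled computation, and `pixel_metrics` is a hypothetical helper.

```python
import numpy as np

def pixel_metrics(pred, gt, eps=1e-8):
    """Pixel-level IoU, Dice, precision, recall and F1 for binary masks."""
    pred = np.asarray(pred).astype(bool)
    gt = np.asarray(gt).astype(bool)
    tp = np.sum(pred & gt)     # target pixels found
    fp = np.sum(pred & ~gt)    # background pixels flagged as target
    fn = np.sum(~pred & gt)    # target pixels missed
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return {
        'iou': tp / (tp + fp + fn + eps),
        'dice': 2 * tp / (2 * tp + fp + fn + eps),
        'precision': precision,
        'recall': recall,
        'f1': 2 * precision * recall / (precision + recall + eps),
    }

gt = np.zeros((8, 8), dtype=np.uint8)
gt[2:4, 2:4] = 1      # 4 target pixels
pred = np.zeros_like(gt)
pred[2:4, 2:5] = 1    # 6 predicted pixels: 4 correct plus 2 false alarms
m = pixel_metrics(pred, gt)
print({k: round(float(v), 3) for k, v in m.items()})
# {'iou': 0.667, 'dice': 0.8, 'precision': 0.667, 'recall': 1.0, 'f1': 0.8}
```

For binary masks, pixel-level Dice and F1 are algebraically the same quantity, 2TP / (2TP + FP + FN). Matching values for the two are therefore expected.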

In [None]:
# Load the best trained model
best_model_path = f"{MODEL_DIR}/{config['model_name']}_best.pth"
if os.path.exists(best_model_path):
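    checkpoint = torch.load(best_model_path, map_location=device)
    # Accept either a bare state_dict or a checkpoint dict that wraps one
    # ('model_state_dict' is an assumed key name; check what IRSTTrainer saves)
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)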
    print(f"✅ Loaded best model from {best_model_path}")
else:
    print("⚠️ Best model not found, using current model state")

# Initialize evaluator
evaluator = IRSTEvaluator(
    model=model,
    device=device,
    metrics=['iou', 'dice', 'precision', 'recall', 'f1', 'ber']
)

# Evaluate on test set
print("📊 Evaluating model on test set...")
model.eval()
test_metrics = evaluator.evaluate(test_loader)

print(f"\n🎯 Test Results:")
print(f"   IoU Score: {test_metrics['iou']:.4f}")
print(f"   Dice Score: {test_metrics['dice']:.4f}")
print(f"   Precision: {test_metrics['precision']:.4f}")
print(f"   Recall: {test_metrics['recall']:.4f}")
print(f"   F1 Score: {test_metrics['f1']:.4f}")
print(f"   BER (Background Error Rate): {test_metrics['ber']:.4f}")

# Plot training history
if training_history:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Training History', fontsize=16, fontweight='bold')
    
    # Loss curves
    axes[0, 0].plot(training_history['train_loss'], label='Train Loss', linewidth=2)
    axes[0, 0].plot(training_history['val_loss'], label='Validation Loss', linewidth=2)
    axes[0, 0].set_title('Loss Curves')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # IoU curves
    axes[0, 1].plot(training_history['train_iou'], label='Train IoU', linewidth=2)
    axes[0, 1].plot(training_history['val_iou'], label='Validation IoU', linewidth=2)
    axes[0, 1].set_title('IoU Curves')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('IoU')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # Learning rate
    axes[1, 0].plot(training_history['learning_rate'], linewidth=2, color='orange')
    axes[1, 0].set_title('Learning Rate Schedule')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].grid(True, alpha=0.3)
    
    # Metrics summary
    metrics_data = {
        'Metric': ['IoU', 'Dice', 'Precision', 'Recall', 'F1'],
        'Score': [test_metrics['iou'], test_metrics['dice'], 
                 test_metrics['precision'], test_metrics['recall'], test_metrics['f1']]
    }
    axes[1, 1].bar(metrics_data['Metric'], metrics_data['Score'], color='skyblue', alpha=0.7)
    axes[1, 1].set_title('Test Metrics Summary')
    axes[1, 1].set_ylim(0, 1)
    axes[1, 1].grid(True, alpha=0.3, axis='y')
    
    # Add value labels on bars
    for i, v in enumerate(metrics_data['Score']):
        axes[1, 1].text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom', fontweight='bold')
    
    plt.tight_layout()
    plt.show()

# Performance analysis by target size
print("\n🔍 Performance Analysis by Target Size:")
size_analysis = evaluator.analyze_by_target_size(test_loader)
for size_range, metrics in size_analysis.items():
    print(f"   {size_range}: IoU={metrics['iou']:.4f}, Count={metrics['count']}")
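
This notebook does not show the bucket boundaries `analyze_by_target_size` uses. The sketch below illustrates the idea with hypothetical thresholds (under 10 pixels, 10 to 40 pixels, 40 pixels and above). It assigns each ground-truth target to a bucket by pixel area and averages a per-target IoU within each bucket. Both `bucket_mean_iou` and the thresholds are assumptions for illustration.

```python
from collections import defaultdict

# Hypothetical bucket edges in pixels (lower bound inclusive, upper bound exclusive)
BUCKETS = [
    ("small (<10 px)", 0, 10),
    ("medium (10-40 px)", 10, 40),
    ("large (>=40 px)", 40, float("inf")),
]

def bucket_mean_iou(targets):
    """targets: iterable of (area_in_pixels, iou) pairs, one per ground-truth target."""
    ious = defaultdict(list)  # bucket name -> list of per-target IoUs
    for area, iou in targets:
        for name, lo, hi in BUCKETS:
            if lo <= area < hi:
                ious[name].append(iou)
                break
    return {name: {'iou': sum(v) / len(v), 'count': len(v)} for name, v in ious.items()}

targets = [(4, 0.2), (6, 0.6), (15, 0.7), (22, 0.9), (55, 0.8)]
result = bucket_mean_iou(targets)
for size_range, metrics in result.items():
    print(f"{size_range}: IoU={metrics['iou']:.4f}, Count={metrics['count']}")
# small (<10 px): IoU=0.4000, Count=2
# medium (10-40 px): IoU=0.8000, Count=2
# large (>=40 px): IoU=0.8000, Count=1
```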

## 6. Model Inference

Now let's use our trained model to detect infrared small targets in new images and visualize the results.
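
The inference code below applies `torch.sigmoid` and then thresholds the result at 0.5. The sigmoid is monotonic and equals exactly 0.5 at an input of 0. You can therefore get the same binary mask by thresholding the raw logits at 0, which skips the exponential. A naive `1 / (1 + np.exp(-x))`, like the one in the Flask example later, can also emit overflow warnings for large negative logits. The NumPy sketch below shows a numerically stable sigmoid and checks the threshold equivalence. `stable_sigmoid` is an illustrative helper.

```python
import numpy as np

def stable_sigmoid(x):
    """Sigmoid that never exponentiates a large positive number."""
    x = np.asarray(x, dtype=np.float64)
    out = np.empty_like(x)
    pos = x >= 0
    out[pos] = 1.0 / (1.0 + np.exp(-x[pos]))   # exp(-x) <= 1 on this branch
    ex = np.exp(x[~pos])                        # exp(x) < 1 on this branch
    out[~pos] = ex / (1.0 + ex)
    return out

logits = np.array([-1000.0, -3.0, -1e-9, 0.0, 2.5, 800.0])
probs = stable_sigmoid(logits)
mask_from_probs = probs > 0.5
mask_from_logits = logits > 0.0
print(np.array_equal(mask_from_probs, mask_from_logits))  # True
```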

In [None]:
# Create high-level detector for easy inference
detector = IRSTDetector.from_pretrained(
    model_path=best_model_path,
    config=config,
    device=device
)

print("🎯 IRST Detector initialized for inference!")

# Inference on test samples
print("\n🔍 Running inference on test samples...")
model.eval()

# Select random test samples
import random
test_indices = random.sample(range(len(test_dataset_clean)), 6)

fig, axes = plt.subplots(3, 6, figsize=(20, 12))
fig.suptitle('Model Inference Results', fontsize=18, fontweight='bold')

with torch.no_grad():
    for idx, test_idx in enumerate(test_indices):
        # Get test sample
        image, gt_mask = test_dataset_clean[test_idx]
        
        # Add batch dimension
        image_batch = image.unsqueeze(0).to(device)
        
        # Inference
        pred_mask = model(image_batch)
        pred_mask = torch.sigmoid(pred_mask)  # Apply sigmoid for binary classification
        pred_mask = (pred_mask > 0.5).float()  # Threshold
        
        # Convert back to numpy
        image_np = image.squeeze().numpy()
        gt_mask_np = gt_mask.squeeze().numpy()
        pred_mask_np = pred_mask.squeeze().cpu().numpy()
        
        # Denormalize image for visualization
        image_vis = image_np * 0.229 + 0.485
        image_vis = np.clip(image_vis, 0, 1)
        
        # Plot original image
        axes[0, idx].imshow(image_vis, cmap='gray')
        axes[0, idx].set_title(f'Input Image {idx+1}')
        axes[0, idx].axis('off')
        
        # Plot ground truth: draw the image first, then overlay only the target pixels
        axes[1, idx].imshow(image_vis, cmap='gray')
        axes[1, idx].imshow(np.ma.masked_where(gt_mask_np == 0, gt_mask_np), cmap='autumn', alpha=0.8)
        axes[1, idx].set_title(f'Ground Truth {idx+1}')
        axes[1, idx].axis('off')
        
        # Plot prediction with the same overlay order as the ground truth
        axes[2, idx].imshow(image_vis, cmap='gray')
        axes[2, idx].imshow(np.ma.masked_where(pred_mask_np == 0, pred_mask_np), cmap='autumn', alpha=0.8)
        axes[2, idx].set_title(f'Prediction {idx+1}')
        axes[2, idx].axis('off')

plt.tight_layout()
plt.show()

# Calculate per-sample metrics
print("\n📊 Per-sample Performance:")
sample_metrics = []

for idx, test_idx in enumerate(test_indices):
    image, gt_mask = test_dataset_clean[test_idx]
    image_batch = image.unsqueeze(0).to(device)
    
    with torch.no_grad():
        pred_mask = model(image_batch)
        pred_mask = torch.sigmoid(pred_mask)
        pred_mask = (pred_mask > 0.5).float()
    
    # Calculate IoU for this sample
    pred_np = pred_mask.squeeze().cpu().numpy()
    gt_np = gt_mask.squeeze().numpy()
    
    intersection = np.sum(pred_np * gt_np)
    union = np.sum(pred_np) + np.sum(gt_np) - intersection
    # An empty prediction on an empty ground truth counts as a perfect match
    iou = 1.0 if union == 0 else intersection / union
    
    sample_metrics.append(iou)
    print(f"   Sample {idx+1}: IoU = {iou:.4f}")

print(f"\n📈 Average IoU on selected samples: {np.mean(sample_metrics):.4f}")

# Inference time benchmarking
print("\n⚡ Performance Benchmarking:")
import time

# Warm up
for _ in range(10):
    with torch.no_grad():
        _ = model(image_batch)

# Benchmark
inference_times = []
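for _ in range(100):
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure no earlier GPU work is still running
    start_time = time.perf_counter()  # monotonic, high-resolution clock
    with torch.no_grad():
        _ = model(image_batch)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for this forward pass to finish
    
    inference_times.append(time.perf_counter() - start_time)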

avg_time = np.mean(inference_times) * 1000  # Convert to ms
fps = 1.0 / (np.mean(inference_times))

print(f"   ⏱️ Average inference time: {avg_time:.2f} ms")
print(f"   🚀 Throughput: {fps:.1f} FPS")
print(f"   💾 Model size: {os.path.getsize(best_model_path) / 1024 / 1024:.2f} MB")
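
Wall-clock timing is easy to skew. Asynchronous GPU work must finish before you read the clock, and a handful of slow runs can pull the mean upward. The framework-agnostic sketch below uses `time.perf_counter` and runs a warm-up. It accepts an optional `sync` callable (for example `torch.cuda.synchronize`) and reports the median and 95th percentile alongside the mean. `benchmark` is a hypothetical helper.

```python
import statistics
import time

def benchmark(fn, warmup=10, iters=100, sync=None):
    """Time fn() and return latency statistics in milliseconds."""
    for _ in range(warmup):
        fn()
    if sync:
        sync()                      # drain queued work before timing starts
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync:
            sync()                  # wait for the work fn() launched
        times_ms.append((time.perf_counter() - start) * 1000.0)
    times_ms.sort()
    median = statistics.median(times_ms)
    return {
        'mean_ms': statistics.fmean(times_ms),
        'median_ms': median,
        'p95_ms': times_ms[int(0.95 * (len(times_ms) - 1))],
        'fps_from_median': 1000.0 / median if median > 0 else float('inf'),
    }

# Stand-in workload. In the notebook you might call, inside `with torch.no_grad():`,
# benchmark(lambda: model(image_batch), sync=torch.cuda.synchronize)
stats = benchmark(lambda: sum(i * i for i in range(10_000)), warmup=2, iters=20)
print({k: round(v, 3) for k, v in stats.items()})
```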

## 7. Model Export and Deployment

Finally, let's export our trained model for production deployment using ONNX and create a simple REST API.
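
Metadata such as `test_metrics` often contains NumPy scalars or arrays, and `json.dump` rejects some of these types. For example, `np.float32` is not a Python `float`. The recursive converter sketched below as a hypothetical helper turns them into plain Python types before writing. PyTorch tensors could be handled the same way via `.item()` or `.tolist()`, but this sketch leaves them out.

```python
import json
import numpy as np

def to_jsonable(obj):
    """Recursively convert NumPy scalars/arrays inside dicts and lists to Python types."""
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):      # np.float32, np.int64, np.bool_, ...
        return obj.item()
    return obj

metrics = {'iou': np.float32(0.75), 'count': np.int64(12), 'curve': np.array([0.1, 0.2])}
# json.dumps(metrics) would raise TypeError here
print(json.dumps(to_jsonable(metrics)))  # {"iou": 0.75, "count": 12, "curve": [0.1, 0.2]}
```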

In [None]:
# Export model to ONNX format
print("📦 Exporting model to ONNX...")

# Create dummy input for ONNX export
dummy_input = torch.randn(1, 1, 256, 256).to(device)
onnx_path = f"{OUTPUT_DIR}/serank_sirst_model.onnx"

try:
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size', 2: 'height', 3: 'width'},
            'output': {0: 'batch_size', 2: 'height', 3: 'width'}
        }
    )
    print(f"✅ ONNX model exported to: {onnx_path}")
    print(f"📦 ONNX model size: {os.path.getsize(onnx_path) / 1024 / 1024:.2f} MB")
except Exception as e:
    print(f"❌ ONNX export failed: {e}")

# Save model metadata
metadata = {
    'model_name': 'SERANKNet',
    'dataset': 'SIRST',
    'input_size': [1, 1, 256, 256],
    'num_classes': 1,
    'metrics': {k: float(v) for k, v in test_metrics.items()},  # plain floats so json.dump accepts them
    'training_config': config,
    'export_date': datetime.now().isoformat(),
    'model_version': '1.0.0'
}

import json
metadata_path = f"{OUTPUT_DIR}/model_metadata.json"
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2)

print(f"✅ Model metadata saved to: {metadata_path}")

# Create deployment package
print("\n🚀 Creating deployment package...")

deployment_files = {
    'model': best_model_path,
    'onnx': onnx_path,
    'metadata': metadata_path,
    'config': f"{OUTPUT_DIR}/deployment_config.yaml"
}

# Create deployment config
deployment_config = {
    'model': {
        'name': 'serank_sirst',
        'version': '1.0.0',
        'input_shape': [1, 256, 256],
        'preprocessing': {
            'normalize': True,
            'mean': [0.485],
            'std': [0.229]
        },
        'postprocessing': {
            'threshold': 0.5,
            'min_target_size': 5
        }
    },
    'api': {
        'host': '0.0.0.0',
        'port': 8000,
        'max_batch_size': 16,
        'timeout': 30
    },
    'monitoring': {
        'log_predictions': True,
        'collect_metrics': True,
        'alert_threshold': 0.1
    }
}

import yaml
with open(deployment_files['config'], 'w') as f:
    yaml.dump(deployment_config, f, default_flow_style=False)

print(f"✅ Deployment configuration saved!")

# Simple REST API example
print("\n🌐 Creating sample REST API...")

api_code = '''
from flask import Flask, request, jsonify
import numpy as np
from PIL import Image
import io
import base64
import onnxruntime as ort

app = Flask(__name__)

# Load ONNX model
ort_session = ort.InferenceSession("serank_sirst_model.onnx")

@app.route('/health', methods=['GET'])
def health():
    return jsonify({'status': 'healthy', 'model': 'serank_sirst'})

@app.route('/predict', methods=['POST'])
def predict():
    try:
        # Get image from request
        data = request.json
        image_b64 = data['image']
        
        # Decode image
        image_bytes = base64.b64decode(image_b64)
        image = Image.open(io.BytesIO(image_bytes)).convert('L')
        
        # Preprocess
        image = image.resize((256, 256))
        image_array = np.array(image, dtype=np.float32) / 255.0
        image_array = (image_array - 0.485) / 0.229
        image_array = image_array[None, None, :, :]  # Add batch and channel dims
        
        # Inference
        outputs = ort_session.run(None, {'input': image_array})
        prediction = outputs[0][0, 0]  # Remove batch and channel dims
        
        # Postprocess
        prediction = 1 / (1 + np.exp(-prediction))  # Sigmoid
        binary_mask = (prediction > 0.5).astype(np.uint8)
        
        # Count targets
        from scipy import ndimage
        labeled, num_targets = ndimage.label(binary_mask)
        
        return jsonify({
            'num_targets': int(num_targets),
            'confidence': float(np.max(prediction)),
            'prediction_shape': prediction.shape,
            'status': 'success'
        })
        
    except Exception as e:
        return jsonify({'error': str(e), 'status': 'error'}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8000, debug=False)  # never expose Flask's debug server publicly
'''

api_path = f"{OUTPUT_DIR}/api_server.py"
with open(api_path, 'w') as f:
    f.write(api_code)

print(f"✅ REST API created: {api_path}")

# Docker deployment
dockerfile_content = '''
FROM python:3.8-slim

# Install dependencies
RUN pip install torch torchvision flask onnxruntime numpy pillow scipy

# Copy model files
COPY serank_sirst_model.onnx /app/
COPY api_server.py /app/
COPY model_metadata.json /app/

WORKDIR /app

# Expose port
EXPOSE 8000

# Run API
CMD ["python", "api_server.py"]
'''

dockerfile_path = f"{OUTPUT_DIR}/Dockerfile"
with open(dockerfile_path, 'w') as f:
    f.write(dockerfile_content)

print(f"✅ Dockerfile created: {dockerfile_path}")

# Summary
print("\n🎉 Deployment Package Ready!")
print(f"📁 Output directory: {OUTPUT_DIR}")
print(f"   📦 PyTorch model: {os.path.basename(best_model_path)}")
print(f"   📦 ONNX model: {os.path.basename(onnx_path)}")
print(f"   📄 Metadata: model_metadata.json")
print(f"   ⚙️ Config: deployment_config.yaml")
print(f"   🌐 API: api_server.py")
print(f"   🐳 Docker: Dockerfile")

print(f"\n🚀 To deploy:")
print(f"   1. cd {OUTPUT_DIR}")
print(f"   2. docker build -t irst-api .")
print(f"   3. docker run -p 8000:8000 irst-api")
print(f"   4. Test: curl -X GET http://localhost:8000/health")

print(f"\n📊 Final Model Performance:")
print(f"   🎯 IoU: {test_metrics['iou']:.4f}")
print(f"   ⚡ Speed: {fps:.1f} FPS")
print(f"   💾 Size: {os.path.getsize(best_model_path) / 1024 / 1024:.2f} MB")
print(f"\n✨ Tutorial completed successfully! ✨")
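
The generated `/predict` endpoint expects a JSON body containing a base64-encoded image under the key `image`. A minimal standard-library client is sketched below. `build_predict_payload` and `post_predict` are hypothetical names. The default URL assumes the container from the deployment steps above is running on `localhost:8000`.

```python
import base64
import json
import urllib.request

def build_predict_payload(image_bytes):
    """Encode raw image-file bytes (PNG, JPEG, ...) as the JSON body /predict expects."""
    return json.dumps({'image': base64.b64encode(image_bytes).decode('ascii')}).encode('utf-8')

def post_predict(image_bytes, url="http://localhost:8000/predict", timeout=30):
    """POST an image to the running API and return the decoded JSON response."""
    req = urllib.request.Request(
        url,
        data=build_predict_payload(image_bytes),
        headers={'Content-Type': 'application/json'},  # needed for Flask's request.json
        method='POST',
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode('utf-8'))

# Usage (requires the server to be running):
# with open("ir_frame.png", "rb") as f:
#     print(post_predict(f.read()))
```

If the server returns the 500 error response, `urllib.request.urlopen` raises `urllib.error.HTTPError`. Its `read()` method still gives you the JSON error body.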

## 🎯 What's Next?

Congratulations! You've completed the comprehensive IRST Library tutorial. Here are your next steps:

### 📚 **Explore More Resources**

- **[Model Zoo](../docs/models.md)** - Discover all available pretrained models and their capabilities
- **[Dataset Guide](../docs/datasets.md)** - Learn about dataset preparation and advanced data handling
- **[Benchmarks](../docs/BENCHMARKS.md)** - Compare performance across different models and datasets
- **[API Reference](../docs/api_reference.md)** - Complete API documentation with examples

### 🧪 **Advanced Tutorials**

Continue your journey with specialized notebooks:

1. **[Advanced Training Techniques](training_advanced.ipynb)** - Multi-GPU training, hyperparameter optimization, and custom losses
2. **[Model Zoo Exploration](model_zoo_tutorial.ipynb)** - Compare different architectures and find the best model for your use case
3. **[Dataset Preparation](dataset_preparation.ipynb)** - Create custom datasets and advanced augmentation strategies
4. **[Production Deployment](deployment_tutorial.ipynb)** - Deploy models to cloud platforms and edge devices
5. **[Benchmarking & Analysis](benchmarking_tutorial.ipynb)** - Comprehensive performance analysis and optimization

### 🚀 **Production Ready**

Your model is now ready for:
- **Research Publications** - Use in academic papers with proper citations
- **Commercial Applications** - Deploy in production environments
- **Open Source Contributions** - Contribute back to the community
- **Custom Implementations** - Extend the library for specific use cases

### 🤝 **Community & Support**

- **GitHub Issues** - Report bugs or request features
- **Discussions** - Ask questions and share experiences
- **Contributions** - Help improve the library
- **Publications** - Cite our work in your research

### 📊 **Performance Summary**

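Your trained model's numbers were printed along the way:
- **IoU Score**: the test-set IoU from Section 5, repeated in the final summary of Section 7
- **Inference Speed**: the average latency and FPS from the Section 6 benchmark, measured on your hardware at batch size 1
- **Model Size**: the checkpoint size on disk, reported in Sections 6 and 7
- **Export Formats**: PyTorch checkpoint, ONNX model, and a deployment package (metadata, YAML config, Flask API, Dockerfile)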

---

## 🎉 **Congratulations!**

You've successfully:
✅ Loaded and explored the SIRST dataset  
✅ Implemented data preprocessing and augmentation  
✅ Trained a state-of-the-art SERANKNet model  
✅ Evaluated model performance with comprehensive metrics  
✅ Performed inference on new images  
✅ Exported the model for production deployment  
✅ Created a complete deployment package  

**Happy target detecting!** 🎯✨