# 🌊 Marine Debris Detection – Data Exploration

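This notebook introduces the marine debris detection system. It loads the project configuration, explores a small sample scene that stands in for the MARIDA (Marine Debris Archive) dataset, and previews the segmentation model.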

In [None]:
# Setup: make the project's `src` package importable before anything else is imported.
import sys
sys.path.insert(0, '..')  # presumably the repo root (holds `src/` and config.yaml) — one level up from this notebook

from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

from src.utils.device import print_device_info

# Report the compute device detected by the project's device helper.
print_device_info()

## 1. Load Configuration

In [None]:
from src.utils.config import load_config, get_default_config

# Prefer the repository's config.yaml; fall back to the built-in defaults when it is missing.
config_path = Path('../config.yaml')
config = load_config(str(config_path)) if config_path.exists() else get_default_config()

# Echo the settings the rest of the notebook relies on.
print("Configuration loaded!")
print(f"Bands: {config['data']['bands']}")
print(f"Model backbone: {config['model']['backbone']}")

## 2. Create Sample Data (used in place of the full MARIDA download)

In [None]:
from src.data.download import create_sample_data

# Directory the sample scene and mask are written to (read back in Section 3).
SAMPLE_DIR = '../data/sample'

# Generate the sample files so the notebook runs without the full dataset download.
sample_dir = create_sample_data(SAMPLE_DIR)
print(f"Sample data created at: {sample_dir}")

## 3. Load and Visualize Sample Data

In [None]:
import rasterio

# Files written by create_sample_data in the previous section.
sample_path = '../data/sample/sample_scene.tif'
mask_path = '../data/sample/sample_mask.tif'

# `read()` with no arguments returns every band as a (bands, rows, cols) array.
with rasterio.open(sample_path) as scene_ds:
    image = scene_ds.read()
    print(f"Image shape: {image.shape}")
    print(f"CRS: {scene_ds.crs}")
    print(f"Bounds: {scene_ds.bounds}")

# `read(1)` returns only the first band (1-based index) as a 2-D (rows, cols) array.
with rasterio.open(mask_path) as mask_ds:
    mask = mask_ds.read(1)
    print(f"Mask shape: {mask.shape}")
    print(f"Unique values: {np.unique(mask)}")

In [None]:
# Visualize the scene three ways: natural colour, NIR false colour, and the label mask.
# Indices below are 0-based positions along the first axis of `image`. If the band order
# matches the list in Section 4 (B2, B3, B4, B8, ...), then (2, 1, 0) = Red/Green/Blue and
# (3, 2, 1) = NIR/Red/Green. TODO confirm against the sample-data band order.
def _composite(band_indices, gain):
    """Stack the chosen bands into an (H, W, 3) array, scale by `gain`, and clip to [0, 1]."""
    stacked = np.stack([image[b] for b in band_indices], axis=-1)
    return np.clip(stacked * gain, 0, 1)

rgb = _composite((2, 1, 0), gain=5)  # brightened so dark water pixels are visible
nir = _composite((3, 2, 1), gain=3)

panels = [
    (rgb, 'RGB Composite', {}),
    (nir, 'NIR False Color', {}),
    (mask, 'Debris Mask', {'cmap': 'Reds'}),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for panel_ax, (panel_img, panel_title, imshow_kwargs) in zip(axes, panels):
    panel_ax.imshow(panel_img, **imshow_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.tight_layout()
plt.show()

## 4. Spectral Analysis

In [None]:
# Sentinel-2 band labels, assumed to follow the band order of `image` (TODO confirm against create_sample_data).
band_names = [
    'B2 (Blue)',
    'B3 (Green)',
    'B4 (Red)',
    'B8 (NIR)',
    'B11 (SWIR1)',
    'B12 (SWIR2)',
]

# Print min/max/mean for each labelled band.
print("Band Statistics:\n" + "-" * 50)
for band_idx in range(len(band_names)):
    band = image[band_idx]
    label = band_names[band_idx]
    print(f"{label}: min={band.min():.4f}, max={band.max():.4f}, mean={band.mean():.4f}")

In [None]:
# Compare mean spectral signatures of debris vs. water pixels.
# NOTE(review): assumes the mask encodes debris as 1 and water/background as 0 —
# check the "Unique values" printed in Section 3. Pixels with any other label are ignored.
debris_mask = mask == 1
water_mask = mask == 0


def mean_spectrum(pixel_mask):
    """Return the per-band mean of `image` over the pixels selected by `pixel_mask`.

    If no pixel is selected, every band is NaN. Matplotlib leaves NaN points out of the
    line instead of drawing a misleading zero reflectance. This also avoids numpy's
    "Mean of empty slice" warning.
    """
    if not pixel_mask.any():
        return [np.nan] * len(band_names)
    return [image[i][pixel_mask].mean() for i in range(len(band_names))]


debris_spectrum = mean_spectrum(debris_mask)
water_spectrum = mean_spectrum(water_mask)

fig, ax = plt.subplots(figsize=(10, 5))
band_positions = list(range(len(band_names)))
ax.plot(band_positions, water_spectrum, 'b-o', label='Water')
ax.plot(band_positions, debris_spectrum, 'r-o', label='Debris')
ax.set_xticks(band_positions)
ax.set_xticklabels([name.split()[0] for name in band_names])  # e.g. 'B2 (Blue)' -> 'B2'
ax.set_xlabel('Band')
ax.set_ylabel('Reflectance')
ax.set_title('Spectral Signatures')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## 5. Model Architecture Preview

In [None]:
import torch
from src.models.segformer import create_model
from src.utils.device import get_device

# Seed torch's global RNG first. With pretrained=False the weights are presumably freshly
# initialised, so without a seed they (and every output below) change on each Restart & Run All.
torch.manual_seed(0)

# Build the model through the project factory on the device chosen by get_device('auto').
device = get_device('auto')
model_config = {
    'backbone': 'mit_b2',
    'num_classes': 2,
    'in_channels': 6,  # one input channel per band listed in Section 4
    'pretrained': False,  # Don't download weights for exploration
}

model = create_model(model_config, device=device)

# Count parameters (all of them, and the subset that would receive gradients).
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model created on: {device}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Smoke-test a forward pass on random input and inspect the output shape and value range.
model.eval()

# Draw the dummy input from a dedicated, fixed-seed CPU generator. The printed range is then
# reproducible, and torch's global RNG state is left untouched. The channel count is taken
# from model_config so it cannot drift from the model definition.
input_generator = torch.Generator().manual_seed(0)
dummy_input = torch.randn(
    1, model_config['in_channels'], 256, 256, generator=input_generator
).to(device)

with torch.no_grad():
    output = model(dummy_input)

print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Output range: [{output.min().item():.2f}, {output.max().item():.2f}]")

## 6. Next Steps

1. **Download MARIDA dataset**: `python scripts/download_marida.py`
2. **Train model**: `python scripts/train.py`
3. **Run inference**: `python scripts/predict.py --input <image.tif>`

See README.md for detailed instructions.