# 02 — Segmentation Dataset Exploration: RiceSEG for Rice Field Weed Detection

**Purpose:** Explore the RiceSEG dataset before training a segmentation model (DeepLabV3+).  
**Runtime:** CPU only — no GPU needed. Save your GPU hours for training notebooks.  
**Platform:** Works on both Kaggle and Google Colab.

## What This Notebook Covers

1. **Load RiceSEG** — discover dataset structure, build image-mask pairs
2. **Mask analysis** — class pixel distribution, class weights, per-country breakdown
3. **Weed pixel analysis** — how sparse are weeds? Which images are most useful?
4. **Sample visualization** — images, masks, and color-coded overlays
5. **Image properties** — dimensions, mask values, file sizes
6. **Training recommendations** — loss function, augmentation, expected baseline

### About RiceSEG

| Property | Value |
|----------|-------|
| **Images** | ~3,078 |
| **Resolution** | 512x512 pixels |
| **Classes** | 6: Background, Green vegetation, Senescent vegetation, Panicle, Weeds, Duckweed |
| **Countries** | China, Japan, India, Philippines, Tanzania |
| **Format** | Image + pixel-level mask pairs |
| **Relevance** | Rice field weeds — the Philippines subset is closest to Indonesian conditions |

> **Key challenge:** Weed pixels are sparse (~1.6%) due to herbicide use at collection sites. This means class-weighted or focal loss is critical for training.

---
## 1. Platform Detection & Setup

Same pattern as notebook 01 — detect Kaggle vs Colab vs local.

In [None]:
import os
import sys

# --- Platform Detection ---
IS_KAGGLE = os.path.exists('/kaggle/input')

try:
    import google.colab
    IS_COLAB = True
except ImportError:
    IS_COLAB = False

IS_LOCAL = not IS_KAGGLE and not IS_COLAB

PLATFORM = 'kaggle' if IS_KAGGLE else ('colab' if IS_COLAB else 'local')
print(f'Platform detected: {PLATFORM}')
print(f'Python version: {sys.version}')

### Install Dependencies

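Lightweight exploration: Pillow for image loading, matplotlib for plots, plus numpy and pandas for arrays and tables. All four come preinstalled on Kaggle and Colab; the install below mainly matters for local runs.

In [None]:
import subprocess
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'Pillow', 'numpy', 'pandas', 'matplotlib'])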

print('Dependencies ready.')

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from pathlib import Path
from PIL import Image
from collections import Counter
from IPython.display import display  # used for DataFrame previews below
import json
import warnings
warnings.filterwarnings('ignore')

# Consistent plot style
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 11

print('Imports ready.')

---
## 2. Load RiceSEG Dataset

### Dataset Structure

RiceSEG organizes data by country/region, with image-mask pairs. The expected structure varies, but typically:

```
riceseg/
  images/
    *.png or *.jpg
  masks/ (or labels/ or annotations/)
    *.png
```

Or organized by country:

```
riceseg/
  China/
    images/ + masks/
  Philippines/
    images/ + masks/
```

### Class Definitions

| Class ID | Name | Color (for overlay) |
|----------|------|---------------------|
| 0 | Background | Black |
| 1 | Green vegetation | Green |
| 2 | Senescent vegetation | Yellow |
| 3 | Panicle | Orange |
| 4 | **Weeds** | Red |
| 5 | **Duckweed** | Blue |

Classes 4 and 5 are our primary targets — everything else is context.
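Because class IDs are small non-negative integers, a whole mask can be colorized in one vectorized step by using it to index a lookup table. A minimal sketch, assuming single-channel integer masks with values 0-5; the RGB rows below reuse the RGB part of `CLASS_COLORS` from the next cell and are display choices, not an official RiceSEG palette.

```python
import numpy as np

# Row index = class ID; RGB values mirror CLASS_COLORS (alpha dropped)
PALETTE = np.array([
    [0, 0, 0],        # 0 Background: black
    [0, 200, 0],      # 1 Green vegetation: green
    [200, 200, 0],    # 2 Senescent vegetation: yellow
    [255, 165, 0],    # 3 Panicle: orange
    [255, 0, 0],      # 4 Weeds: red
    [0, 100, 255],    # 5 Duckweed: blue
], dtype=np.uint8)


def colorize_mask(mask: np.ndarray) -> np.ndarray:
    """Map an (H, W) class-ID mask to an (H, W, 3) uint8 RGB image."""
    if mask.ndim != 2:
        raise ValueError(f'expected a single-channel mask, got shape {mask.shape}')
    if int(mask.max()) >= len(PALETTE):
        raise ValueError(f'mask contains class ID {int(mask.max())} outside the palette')
    # Fancy indexing: each pixel value selects one palette row
    return PALETTE[mask]


demo = np.array([[0, 4], [5, 1]], dtype=np.uint8)
rgb = colorize_mask(demo)
print(rgb.shape)           # (2, 2, 3)
print(rgb[0, 1].tolist())  # [255, 0, 0]  (weed pixel)
```

The same indexing trick avoids a per-class Python loop, which matters when rendering many 512x512 masks.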

In [None]:
# --- RiceSEG class definitions ---
RICESEG_CLASSES = {
    0: 'Background',
    1: 'Green vegetation',
    2: 'Senescent vegetation',
    3: 'Panicle',
    4: 'Weeds',
    5: 'Duckweed',
}
NUM_CLASSES = len(RICESEG_CLASSES)

# Colors for overlay visualization (RGBA)
CLASS_COLORS = {
    0: (0, 0, 0, 0),          # Background — transparent
    1: (0, 200, 0, 128),      # Green vegetation
    2: (200, 200, 0, 128),    # Senescent vegetation
    3: (255, 165, 0, 128),    # Panicle — orange
    4: (255, 0, 0, 180),      # Weeds — red (highlighted)
    5: (0, 100, 255, 180),    # Duckweed — blue (highlighted)
}

# --- Set data path ---
if IS_KAGGLE:
    DATA_ROOT = Path('/kaggle/input/riceseg')
elif IS_COLAB:
    DATA_ROOT = Path('/content/riceseg')
else:
    DATA_ROOT = Path('./data/riceseg')

print(f'Data root: {DATA_ROOT}')
print(f'Exists: {DATA_ROOT.exists()}')

if not DATA_ROOT.exists():
    print()
    print('=' * 60)
    print('RICESEG NOT FOUND — Setup Instructions')
    print('=' * 60)
    print('RiceSEG is hosted on HuggingFace, NOT Kaggle.')
    print()
    print('Option 1: Upload to Kaggle as private dataset')
    print('  1. Download from HuggingFace (search: "RiceSEG")')
    print('  2. Go to kaggle.com > Datasets > New Dataset')
    print('  3. Upload the extracted folder, name it "riceseg"')
    print('  4. Attach to this notebook via "Add Data"')
    print()
    print('Option 2: For Colab/local')
    print(f'  Download and extract to: {DATA_ROOT}')
    print('=' * 60)

In [None]:
# --- Discover dataset structure ---
if DATA_ROOT.exists():
    contents = sorted(DATA_ROOT.iterdir())
    print(f'Contents of {DATA_ROOT}:')
    for item in contents[:30]:
        kind = 'DIR' if item.is_dir() else f'FILE ({item.suffix})'
        if item.is_dir():
            sub_count = sum(1 for _ in item.iterdir())
            print(f'  {kind}: {item.name}/ ({sub_count} items)')
        else:
            size = item.stat().st_size / 1024
            print(f'  {kind}: {item.name} — {size:.1f} KB')
    
    if len(contents) > 30:
        print(f'  ... and {len(contents) - 30} more items')
    
    # Count all image and mask files
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    all_files = list(DATA_ROOT.rglob('*'))
    all_images = [f for f in all_files if f.suffix.lower() in image_extensions]
    
    print(f'\nTotal files: {len(all_files)}')
    print(f'Image/mask files: {len(all_images)}')
    
    # Show directory tree (2 levels)
    print(f'\nDirectory tree:')
    for d in sorted(DATA_ROOT.rglob('*')):
        if d.is_dir():
            depth = len(d.relative_to(DATA_ROOT).parts)
            if depth <= 2:
                indent = '  ' * depth
                sub_files = sum(1 for f in d.iterdir() if f.is_file())
                sub_dirs = sum(1 for f in d.iterdir() if f.is_dir())
                print(f'{indent}{d.name}/ ({sub_files} files, {sub_dirs} subdirs)')
else:
    print('Data root does not exist. Follow setup instructions above.')

### Build Image-Mask Pairs

We need to match each image to its corresponding segmentation mask. The pairing logic depends on the directory structure — images and masks usually share the same filename but live in different folders (e.g., `images/001.png` ↔ `masks/001.png`).
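A quick way to check pairing coverage, independent of the directory-probing function below, is to index the masks by filename stem once and look each image up in that dict. A minimal sketch: `pair_by_stem` is a hypothetical helper, and it assumes stems are unique within each folder (if a mask folder held both `001.png` and `001.jpg`, only one would be kept).

```python
import tempfile
from pathlib import Path

IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}


def pair_by_stem(img_dir: Path, mask_dir: Path):
    """Pair images with masks sharing a filename stem.

    Returns (pairs, unmatched): pairs is a list of (image, mask) paths,
    unmatched lists images that have no mask with the same stem.
    """
    # Index masks once instead of probing the filesystem per extension
    masks = {p.stem: p for p in mask_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS}
    pairs, unmatched = [], []
    for img in sorted(img_dir.iterdir()):
        if img.suffix.lower() not in IMAGE_EXTS:
            continue
        if img.stem in masks:
            pairs.append((img, masks[img.stem]))
        else:
            unmatched.append(img)
    return pairs, unmatched


# Self-contained demo with empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / 'images').mkdir()
    (root / 'masks').mkdir()
    for name in ['001.jpg', '002.jpg', '003.jpg']:
        (root / 'images' / name).touch()
    for name in ['001.png', '002.png']:
        (root / 'masks' / name).touch()
    demo_pairs, demo_unmatched = pair_by_stem(root / 'images', root / 'masks')
    print(len(demo_pairs), [p.name for p in demo_unmatched])   # 2 ['003.jpg']
```

A non-empty `unmatched` list is an early sign that images and masks use different naming schemes.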

In [None]:
def find_image_mask_pairs(data_root):
    """Find image-mask pairs in the dataset.
    
    Tries multiple common structures:
    1. images/ + masks/ (or labels/, annotations/)
    2. country/images/ + country/masks/
    3. Paired by filename across directories
    
    Returns: list of dicts with image_path, mask_path, country (if available)
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    mask_dir_names = {'masks', 'mask', 'labels', 'label', 'annotations', 'annotation', 'gt', 'groundtruth'}
    image_dir_names = {'images', 'image', 'img', 'rgb', 'input'}
    
    pairs = []
    
    # Strategy 1: Top-level images/ + masks/
    for img_dir_name in image_dir_names:
        img_dir = data_root / img_dir_name
        if not img_dir.exists():
            continue
        for mask_dir_name in mask_dir_names:
            mask_dir = data_root / mask_dir_name
            if not mask_dir.exists():
                continue
            
            for img_file in sorted(img_dir.rglob('*')):
                if img_file.suffix.lower() not in image_extensions:
                    continue
                # Try matching mask with same stem
                for ext in image_extensions:
                    mask_candidate = mask_dir / (img_file.stem + ext)
                    if mask_candidate.exists():
                        pairs.append({
                            'image_path': str(img_file),
                            'mask_path': str(mask_candidate),
                            'country': 'unknown',
                        })
                        break
    
    if pairs:
        return pairs
    
    # Strategy 2: country/images/ + country/masks/
    for country_dir in sorted(data_root.iterdir()):
        if not country_dir.is_dir():
            continue
        
        for img_dir_name in image_dir_names:
            img_dir = country_dir / img_dir_name
            if not img_dir.exists():
                continue
            for mask_dir_name in mask_dir_names:
                mask_dir = country_dir / mask_dir_name
                if not mask_dir.exists():
                    continue
                
                for img_file in sorted(img_dir.rglob('*')):
                    if img_file.suffix.lower() not in image_extensions:
                        continue
                    for ext in image_extensions:
                        mask_candidate = mask_dir / (img_file.stem + ext)
                        if mask_candidate.exists():
                            pairs.append({
                                'image_path': str(img_file),
                                'mask_path': str(mask_candidate),
                                'country': country_dir.name,
                            })
                            break
    
    if pairs:
        return pairs
    
    # Strategy 3: Find all images and try to match with nearby masks
    all_images = sorted(f for f in data_root.rglob('*') 
                       if f.suffix.lower() in image_extensions)
    
    # Group by parent directory
    by_dir = {}
    for img in all_images:
        by_dir.setdefault(str(img.parent), []).append(img)
    
    # For each directory, try to find a sibling mask directory
    for dir_path, imgs in by_dir.items():
        dir_p = Path(dir_path)
        parent = dir_p.parent
        dir_lower = dir_p.name.lower()
        
        # Skip if this IS a mask directory
        if dir_lower in mask_dir_names:
            continue
        
        # Look for sibling mask directory
        for mask_dir_name in mask_dir_names:
            mask_dir = parent / mask_dir_name
            if mask_dir.exists():
                for img_file in imgs:
                    for ext in image_extensions:
                        mask_candidate = mask_dir / (img_file.stem + ext)
                        if mask_candidate.exists():
                            # Try to determine country from path
                            rel = img_file.relative_to(data_root)
                            country = rel.parts[0] if len(rel.parts) > 2 else 'unknown'
                            pairs.append({
                                'image_path': str(img_file),
                                'mask_path': str(mask_candidate),
                                'country': country,
                            })
                            break
    
    return pairs


# Build pairs
if DATA_ROOT.exists():
    pairs = find_image_mask_pairs(DATA_ROOT)
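    df = pd.DataFrame(pairs)
    if len(df) > 0:
        # An image can match in several mask folders (e.g. masks/ and labels/); keep one pair per image
        df = df.drop_duplicates(subset='image_path').reset_index(drop=True)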
    
    print(f'Found {len(df)} image-mask pairs')
    if len(df) > 0:
        print(f'\nCountry distribution:')
        print(df['country'].value_counts().to_string())
        print(f'\nSample pairs:')
        display(df.head(10))
    else:
        print('No pairs found. Check the directory structure above.')
        print('You may need to adjust the pairing logic for this dataset version.')
else:
    df = None
    print('Data root not found. Cannot load pairs.')

---
## 3. Mask Analysis

### Understanding Segmentation Masks

Each mask is an image where pixel values represent class IDs (0-5). Let's verify the mask values match our expected class definitions and measure the class distribution across the entire dataset.
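If the check below shows that the masks are 3-channel color images rather than single-channel ID maps, they need converting before any counting. A minimal sketch, assuming a hypothetical color-to-ID table; the real colors (if the masks are color-coded at all) must come from the dataset's own documentation.

```python
import numpy as np

# HYPOTHETICAL color -> class-ID table; replace with the dataset's real palette
COLOR_TO_ID = {
    (0, 0, 0): 0,
    (0, 255, 0): 1,
    (255, 255, 0): 2,
    (255, 128, 0): 3,
    (255, 0, 0): 4,
    (0, 0, 255): 5,
}


def rgb_to_class_ids(rgb: np.ndarray, ignore_index: int = 255) -> np.ndarray:
    """Convert an (H, W, 3|4) color mask to an (H, W) uint8 class-ID mask.

    Colors not in COLOR_TO_ID get `ignore_index`, so anti-aliased or
    JPEG-compressed edge colors are flagged instead of silently becoming background.
    """
    rgb = rgb[..., :3]  # drop alpha if present
    out = np.full(rgb.shape[:2], ignore_index, dtype=np.uint8)
    for color, cls_id in COLOR_TO_ID.items():
        match = np.all(rgb == np.array(color, dtype=rgb.dtype), axis=-1)
        out[match] = cls_id
    return out


demo = np.array([[[255, 0, 0], [0, 0, 0]],
                 [[0, 0, 255], [10, 20, 30]]], dtype=np.uint8)
print(rgb_to_class_ids(demo).tolist())   # [[4, 0], [5, 255]]
```

Using 255 as an "ignore" label is a common segmentation convention; the count of such pixels tells you how clean the color encoding is.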

In [None]:
if df is not None and len(df) > 0:
    # Analyze class distribution across ALL masks
    # (This may take a few minutes for 3K+ masks)
    
    class_pixel_counts = np.zeros(NUM_CLASSES, dtype=np.int64)
    images_with_class = np.zeros(NUM_CLASSES, dtype=np.int64)  # How many images contain each class
    total_pixels = 0
    mask_values_seen = set()
    errors = []
    
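    print(f'Analyzing {len(df)} masks (this may take a few minutes)...')
    
    for i, (_, row) in enumerate(df.iterrows()):
        try:
            mask = np.array(Image.open(row['mask_path']))
            if mask.ndim != 2:
                # Color-coded (RGB/RGBA) masks must be mapped to class IDs first;
                # comparing them to class IDs directly would count each channel separately
                raise ValueError(f'expected a single-channel mask, got shape {mask.shape}')
            
            # Track unique values (catches encodings outside 0-5, e.g. 0/255 binary masks)
            mask_values_seen.update(np.unique(mask).tolist())
            
            # One pass per mask: count every value, keep the first NUM_CLASSES bins
            counts = np.bincount(mask.ravel(), minlength=NUM_CLASSES)[:NUM_CLASSES]
            class_pixel_counts += counts
            images_with_class += counts > 0
            
            total_pixels += mask.size
            
        except Exception as e:
            errors.append((row['mask_path'], str(e)))
        
        if (i + 1) % 500 == 0:
            print(f'  Processed {i + 1}/{len(df)} masks...')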
    
    print(f'\nDone! Analyzed {len(df) - len(errors)} masks ({len(errors)} errors)')
    print(f'\nUnique mask values seen: {sorted(mask_values_seen)}')
    print(f'Expected class IDs: {list(range(NUM_CLASSES))}')
    
    # Unexpected values?
    expected = set(range(NUM_CLASSES))
    unexpected = mask_values_seen - expected
    if unexpected:
        print(f'Unexpected mask values: {sorted(unexpected)}')
        print('These may represent additional classes or encoding artifacts.')
    
    # Class pixel distribution
    print(f'\n=== Class Pixel Distribution ===')
    print(f'Total pixels analyzed: {total_pixels:,}')
    for cls_id in range(NUM_CLASSES):
        pct = class_pixel_counts[cls_id] / total_pixels * 100
        img_pct = images_with_class[cls_id] / len(df) * 100
        print(f'  {RICESEG_CLASSES[cls_id]:25s}: {class_pixel_counts[cls_id]:>12,} pixels ({pct:5.2f}%) — in {images_with_class[cls_id]:,} images ({img_pct:.1f}%)')
else:
    print('No data to analyze.')

In [None]:
# Plot class pixel distribution
if df is not None and len(df) > 0 and total_pixels > 0:
    class_names_list = [RICESEG_CLASSES[i] for i in range(NUM_CLASSES)]
    class_pcts = class_pixel_counts / total_pixels * 100
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # Bar chart
    bar_colors = ['#333333', '#2ca02c', '#bcbd22', '#ff7f0e', '#d62728', '#1f77b4']
    bars = axes[0].barh(class_names_list, class_pcts, color=bar_colors)
    axes[0].set_xlabel('Pixel Percentage (%)')
    axes[0].set_title('Class Distribution (by pixel count)')
    for bar, pct in zip(bars, class_pcts):
        axes[0].text(bar.get_width() + 0.2, bar.get_y() + bar.get_height()/2,
                     f'{pct:.2f}%', va='center', fontsize=10)
    
    # Pie chart (exclude background for clarity)
    fg_counts = class_pixel_counts[1:]  # Skip background
    fg_names = class_names_list[1:]
    fg_colors = bar_colors[1:]
    axes[1].pie(fg_counts, labels=fg_names, colors=fg_colors,
                autopct='%1.1f%%', startangle=90, textprops={'fontsize': 10})
    axes[1].set_title('Foreground Class Proportions\n(excluding background)')
    
    plt.suptitle('RiceSEG — Class Pixel Distribution', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    # Compute class weights (inverse frequency, for training)
    # Exclude zero-count classes to avoid division by zero
    nonzero = class_pixel_counts > 0
    class_weights = np.zeros(NUM_CLASSES)
    if nonzero.any():
        freq = class_pixel_counts[nonzero] / total_pixels
        weights = 1.0 / freq
        weights = weights / weights.mean()  # Normalize so mean weight = 1
        class_weights[nonzero] = weights
    
    print('\n=== Recommended Class Weights (for weighted loss) ===')
    for cls_id in range(NUM_CLASSES):
        print(f'  {RICESEG_CLASSES[cls_id]:25s}: weight = {class_weights[cls_id]:.2f}')
    
    weed_weight = class_weights[4] if 4 < len(class_weights) else 0
    print(f'\nWeed class weight: {weed_weight:.1f}x the mean class weight (how strongly weed errors are up-weighted in the loss)')

In [None]:
# Per-country distribution
if df is not None and len(df) > 0 and df['country'].nunique() > 1:
    countries = sorted(df['country'].unique())
    country_stats = {}
    
    for country in countries:
        country_df = df[df['country'] == country]
        country_pixels = np.zeros(NUM_CLASSES, dtype=np.int64)
        total = 0
        
        for _, row in country_df.iterrows():
            try:
                mask = np.array(Image.open(row['mask_path']))
                for cls_id in range(NUM_CLASSES):
                    country_pixels[cls_id] += np.sum(mask == cls_id)
                total += mask.size
            except Exception:
                pass
        
        if total > 0:
            country_stats[country] = {
                'count': len(country_df),
                'class_pcts': country_pixels / total * 100,
            }
    
    if country_stats:
        # Plot per-country class distribution
        fig, ax = plt.subplots(figsize=(14, 6))
        
        x = np.arange(len(country_stats))
        width = 0.12
        bar_colors = ['#333333', '#2ca02c', '#bcbd22', '#ff7f0e', '#d62728', '#1f77b4']
        
        for cls_id in range(NUM_CLASSES):
            values = [country_stats[c]['class_pcts'][cls_id] for c in country_stats]
            ax.bar(x + cls_id * width, values, width, 
                   label=RICESEG_CLASSES[cls_id], color=bar_colors[cls_id])
        
        ax.set_xlabel('Country')
        ax.set_ylabel('Pixel Percentage (%)')
        ax.set_title('Class Distribution by Country')
        ax.set_xticks(x + width * (NUM_CLASSES - 1) / 2)
        ax.set_xticklabels([f'{c}\n(n={country_stats[c]["count"]})' for c in country_stats])
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        
        plt.tight_layout()
        plt.show()
        
        # Highlight Philippines (most relevant)
        if 'Philippines' in country_stats:
            ph = country_stats['Philippines']
            print(f'\n=== Philippines Subset (most relevant for Indonesia) ===')
            print(f'Images: {ph["count"]}')
            for cls_id in range(NUM_CLASSES):
                print(f'  {RICESEG_CLASSES[cls_id]:25s}: {ph["class_pcts"][cls_id]:.2f}%')
else:
    print('Single country or no country metadata — skipping per-country analysis.')

### Interpretation — Class Imbalance

The weed pixel distribution reveals a key training challenge:

**Weed pixels are very sparse** (~1-2% of total pixels). This means:
- A model that never predicts the weed class gives up only ~1-2% pixel accuracy, so pixel accuracy barely rewards finding weeds
- With unweighted cross-entropy the loss is dominated by majority-class pixels, so weed pixels contribute little to the gradient
- **Focal loss** or **heavily weighted cross-entropy** is critical
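Focal loss scales each pixel's cross-entropy by (1 - p_t)^gamma, where p_t is the predicted probability of the true class: FL = -alpha_t (1 - p_t)^gamma log(p_t). Confidently correct pixels (most of the background) then contribute almost nothing, leaving the loss to the hard, rare pixels. A minimal NumPy sketch of the multi-class form; `focal_loss` is a hypothetical helper, and in training the same arithmetic would be written with the framework's tensor ops.

```python
import numpy as np


def focal_loss(logits, target, gamma=2.0, alpha=None):
    """Mean multi-class focal loss.

    logits: (N, C) raw scores, one row per pixel
    target: (N,) integer class IDs
    alpha:  optional (C,) per-class weights (e.g. the class weights computed above)
    """
    # Numerically stable log-softmax
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    log_pt = log_probs[np.arange(len(target)), target]   # log-prob of the true class
    pt = np.exp(log_pt)
    loss = -((1.0 - pt) ** gamma) * log_pt
    if alpha is not None:
        loss = loss * np.asarray(alpha)[target]
    return loss.mean()


# One easy pixel (true class 0, confidently right) and one hard pixel
# (true class 4 = weed, mostly predicted as class 0)
logits = np.array([[5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                   [3.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
target = np.array([0, 4])
ce = focal_loss(logits, target, gamma=0.0)   # gamma=0 reduces to cross-entropy
fl = focal_loss(logits, target, gamma=2.0)
print(f'CE={ce:.3f}  FL={fl:.3f}')
```

With gamma=0 and no alpha the function is plain cross-entropy, which makes it easy to check against a reference implementation.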

**Training implications:**
- Use focal loss (gamma=2) or weighted CE with weed class weight 10-20x
- Consider oversampling images that actually contain weed pixels
- Monitor **weed-class IoU** separately — overall IoU will be misleading
- The Philippines subset may have different weed proportions — evaluate separately
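Oversampling can be expressed as per-image sampling weights: images with any weed or duckweed pixels get a larger weight, and the sampler draws indices in proportion. A minimal sketch, assuming a boolean `any_weed` flag per image like the one built in the next cell; `weed_boost` is a hypothetical knob. In PyTorch the same weight vector could be passed to `torch.utils.data.WeightedRandomSampler`, which does not require the weights to sum to 1.

```python
import numpy as np


def sampling_weights(any_weed, weed_boost=5.0):
    """Per-image draw probabilities: weed-containing images are `weed_boost`
    times more likely to be drawn than weed-free ones. Sums to 1."""
    any_weed = np.asarray(any_weed, dtype=bool)
    w = np.where(any_weed, weed_boost, 1.0)
    return w / w.sum()


# Illustrative flags: 2 of 10 images contain weeds
any_weed = np.array([True, False, False, True] + [False] * 6)
p = sampling_weights(any_weed, weed_boost=5.0)

# Expected share of weed images per draw: 2*5 / (2*5 + 8*1) = 10/18
print(f'weed-image share per draw: {p[any_weed].sum():.3f}')   # 0.556

# One "epoch" of indices drawn with replacement
rng = np.random.default_rng(0)
epoch_indices = rng.choice(len(any_weed), size=len(any_weed), replace=True, p=p)
```

Without the boost, weed images would make up only 20% of draws in this toy setup; the boost trades some repetition of weed images for more weed pixels per epoch.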

In [None]:
# Analyze which images contain weed pixels (class 4 and 5)
if df is not None and len(df) > 0:
    weed_info = []
    
    for _, row in df.iterrows():
        try:
            mask = np.array(Image.open(row['mask_path']))
            weed_pixels = np.sum(mask == 4)
            duckweed_pixels = np.sum(mask == 5)
            total = mask.size
            
            weed_info.append({
                'image_path': row['image_path'],
                'mask_path': row['mask_path'],
                'country': row['country'],
                'weed_pixels': weed_pixels,
                'duckweed_pixels': duckweed_pixels,
                'weed_pct': weed_pixels / total * 100,
                'duckweed_pct': duckweed_pixels / total * 100,
                'any_weed': weed_pixels > 0 or duckweed_pixels > 0,
            })
        except Exception:
            pass
    
    weed_df = pd.DataFrame(weed_info)
    
    with_weed = weed_df['any_weed'].sum()
    total_imgs = len(weed_df)
    
    print(f'=== Weed Pixel Analysis ===')
    print(f'Total images: {total_imgs}')
    print(f'Images WITH any weed/duckweed pixels: {with_weed} ({with_weed/total_imgs*100:.1f}%)')
    print(f'Images WITHOUT weed pixels: {total_imgs - with_weed} ({(total_imgs - with_weed)/total_imgs*100:.1f}%)')
    
    # Distribution of weed pixel area
    weed_present = weed_df[weed_df['any_weed']]
    if len(weed_present) > 0:
        print(f'\nAmong images WITH weeds:')
        print(f'  Mean weed area: {weed_present["weed_pct"].mean():.2f}%')
        print(f'  Max weed area:  {weed_present["weed_pct"].max():.2f}%')
        print(f'  Mean duckweed area: {weed_present["duckweed_pct"].mean():.2f}%')
        print(f'  Max duckweed area:  {weed_present["duckweed_pct"].max():.2f}%')
        
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        
        axes[0].hist(weed_present['weed_pct'], bins=30, color='#d62728', edgecolor='white', alpha=0.8)
        axes[0].set_xlabel('Weed Pixel Percentage (%)')
        axes[0].set_ylabel('Number of Images')
        axes[0].set_title(f'Weed Pixel Area Distribution\n(n={len(weed_present)} images with weeds)')
        
        axes[1].hist(weed_present['duckweed_pct'], bins=30, color='#1f77b4', edgecolor='white', alpha=0.8)
        axes[1].set_xlabel('Duckweed Pixel Percentage (%)')
        axes[1].set_ylabel('Number of Images')
        axes[1].set_title(f'Duckweed Pixel Area Distribution\n(n={len(weed_present)} images with duckweed)')
        
        plt.suptitle('Weed Pixel Area Distribution (images with weeds only)', fontsize=13, fontweight='bold')
        plt.tight_layout()
        plt.show()

---
## 4. Sample Visualization

Let's inspect image-mask pairs visually. We'll show:
- Original image
- Segmentation mask (class IDs as grayscale)
- Color-coded overlay (mask on top of image)

We prioritize samples that **contain weed/duckweed pixels**: most images have none, so random sampling would mostly show masks without any weed pixels.

In [None]:
def create_overlay(image, mask, class_colors, alpha=0.5):
    """Create a color-coded overlay of the mask on the image."""
    img_arr = np.array(image.convert('RGB')).astype(np.float32)
    overlay = np.zeros((*mask.shape, 4), dtype=np.float32)
    
    for cls_id, color in class_colors.items():
        cls_mask = mask == cls_id
        if cls_mask.any():
            overlay[cls_mask] = [c / 255.0 for c in color]
    
    # Blend: where overlay has alpha, mix image and overlay color
    result = img_arr.copy()
    for c in range(3):
        mask_alpha = overlay[:, :, 3]
        result[:, :, c] = (1 - mask_alpha * alpha) * img_arr[:, :, c] + mask_alpha * alpha * overlay[:, :, c] * 255
    
    return np.clip(result, 0, 255).astype(np.uint8)


if df is not None and len(df) > 0:
    # Select samples WITH weed pixels (more interesting)
    if 'weed_df' in dir() and len(weed_df[weed_df['any_weed']]) > 0:
        weed_samples = weed_df[weed_df['any_weed']].nlargest(12, 'weed_pct')
    else:
        weed_samples = df.sample(n=min(12, len(df)), random_state=42)
    
    n_samples = min(4, len(weed_samples))
    fig, axes = plt.subplots(n_samples, 3, figsize=(15, 4 * n_samples))
    if n_samples == 1:
        axes = axes.reshape(1, -1)
    
    for row_idx, (_, sample) in enumerate(weed_samples.head(n_samples).iterrows()):
        img_path = sample['image_path']
        mask_path = sample['mask_path']
        
        img = Image.open(img_path).convert('RGB')
        mask = np.array(Image.open(mask_path))
        
        # Original image
        axes[row_idx, 0].imshow(img)
        axes[row_idx, 0].set_title('Image', fontsize=10)
        axes[row_idx, 0].axis('off')
        
        # Mask (class IDs)
        axes[row_idx, 1].imshow(mask, cmap='tab10', vmin=0, vmax=NUM_CLASSES - 1)
        axes[row_idx, 1].set_title('Mask (class IDs)', fontsize=10)
        axes[row_idx, 1].axis('off')
        
        # Color overlay
        overlay = create_overlay(img, mask, CLASS_COLORS)
        axes[row_idx, 2].imshow(overlay)
        country = sample.get('country', '?')
        weed_pct = sample.get('weed_pct', 0)
        axes[row_idx, 2].set_title(f'Overlay — {country} (weed: {weed_pct:.1f}%)', fontsize=10)
        axes[row_idx, 2].axis('off')
    
    plt.suptitle('RiceSEG — Images with Highest Weed Content', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    # Legend
    print('Color legend:')
    for cls_id, name in RICESEG_CLASSES.items():
        r, g, b, _ = CLASS_COLORS[cls_id]
        print(f'  Class {cls_id}: {name} — RGB({r}, {g}, {b})')

In [None]:
# Show Philippines-only subset (most relevant to Indonesia)
if df is not None and len(df) > 0 and 'country' in df.columns:
    ph_df = df[df['country'].str.lower().str.contains('phil', na=False)]
    
    if len(ph_df) > 0:
        print(f'Philippines subset: {len(ph_df)} images')
        
        # Find Philippines images with weeds
        if 'weed_df' in dir():
            ph_weed = weed_df[weed_df['country'].str.lower().str.contains('phil', na=False) & weed_df['any_weed']]
        else:
            ph_weed = ph_df.sample(n=min(4, len(ph_df)), random_state=42)
        
        n_ph = min(4, len(ph_weed))
        if n_ph > 0:
            fig, axes = plt.subplots(n_ph, 3, figsize=(15, 4 * n_ph))
            if n_ph == 1:
                axes = axes.reshape(1, -1)
            
            for row_idx, (_, sample) in enumerate(ph_weed.head(n_ph).iterrows()):
                img = Image.open(sample['image_path']).convert('RGB')
                mask = np.array(Image.open(sample['mask_path']))
                overlay = create_overlay(img, mask, CLASS_COLORS)
                
                axes[row_idx, 0].imshow(img)
                axes[row_idx, 0].set_title('Image', fontsize=10)
                axes[row_idx, 0].axis('off')
                
                axes[row_idx, 1].imshow(mask, cmap='tab10', vmin=0, vmax=NUM_CLASSES - 1)
                axes[row_idx, 1].set_title('Mask', fontsize=10)
                axes[row_idx, 1].axis('off')
                
                axes[row_idx, 2].imshow(overlay)
                axes[row_idx, 2].set_title('Overlay', fontsize=10)
                axes[row_idx, 2].axis('off')
            
            plt.suptitle('Philippines Subset — Most Relevant to Indonesian Conditions', 
                         fontsize=14, fontweight='bold')
            plt.tight_layout()
            plt.show()
        else:
            print('No Philippines images with weed pixels found.')
    else:
        print('No Philippines subset found in the country metadata.')
        print(f'Available countries: {df["country"].unique().tolist()}')

### Visual Observations Checklist

After looking at the samples above, note:

- [ ] **Mask quality:** Are boundaries between classes smooth or jagged?
- [ ] **Weed appearance:** What do weeds look like in these images? Color, shape, texture?
- [ ] **Mask alignment:** Do masks align well with visible features in the images?
- [ ] **Philippines subset:** How do Philippines images compare to what you'd expect in Indonesian rice fields?
- [ ] **Duckweed vs weeds:** Can you visually distinguish duckweed (floating, green) from terrestrial weeds?

In [None]:
# Image dimension analysis
if df is not None and len(df) > 0:
    sample_check = df.sample(n=min(200, len(df)), random_state=42)
    
    widths, heights, mask_shapes = [], [], []
    file_sizes = []
    
    for _, row in sample_check.iterrows():
        try:
            img = Image.open(row['image_path'])
            mask = Image.open(row['mask_path'])
            # Open both files before recording anything, so the lists stay index-aligned
            w, h = img.size
            widths.append(w)
            heights.append(h)
            file_sizes.append(Path(row['image_path']).stat().st_size / 1024)
            mask_shapes.append(mask.size)
        except Exception:
            pass
    
    print(f'=== Image Properties (sample of {len(widths)}) ===')
    print(f'Width  — min: {min(widths)}, max: {max(widths)}, unique: {len(set(widths))}')
    print(f'Height — min: {min(heights)}, max: {max(heights)}, unique: {len(set(heights))}')
    print(f'File size — min: {min(file_sizes):.1f} KB, max: {max(file_sizes):.1f} KB, mean: {np.mean(file_sizes):.1f} KB')
    
    # Verify mask dimensions match image dimensions
    mismatched = sum(1 for i, ms in enumerate(mask_shapes) 
                     if ms != (widths[i], heights[i]))
    print(f'\nMask-image dimension mismatches: {mismatched}')
    if mismatched == 0:
        print('All masks match their image dimensions.')

In [None]:
# Mask value analysis
if df is not None and len(df) > 0:
    check_sample = df.sample(n=min(50, len(df)), random_state=42)
    
    all_values = set()
    value_counts = Counter()
    
    for _, row in check_sample.iterrows():
        try:
            mask = np.array(Image.open(row['mask_path']))
            vals = np.unique(mask)
            all_values.update(vals.tolist())
            for v in vals:
                value_counts[v] += np.sum(mask == v)
        except Exception:
            pass
    
    print(f'=== Mask Value Analysis (sample of {len(check_sample)}) ===')
    print(f'Unique values: {sorted(all_values)}')
    print(f'\nValue → Class mapping:')
    for val in sorted(all_values):
        name = RICESEG_CLASSES.get(val, f'UNKNOWN (value={val})')
        pct = value_counts[val] / sum(value_counts.values()) * 100
        print(f'  {val} → {name} ({pct:.2f}%)')
    
    # Check mask dtype
    sample_mask = np.array(Image.open(df.iloc[0]['mask_path']))
    print(f'\nMask dtype: {sample_mask.dtype}')
    print(f'Mask shape: {sample_mask.shape}')
    print(f'Expected: uint8, (512, 512) or similar single-channel')

In [None]:
# Width, height, and file size distributions (from the 200-image sample above)
if df is not None and len(df) > 0 and file_sizes:
    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
    
    axes[0].hist(widths, bins=20, color='steelblue', edgecolor='white')
    axes[0].set_title('Image Width Distribution')
    axes[0].set_xlabel('Width (px)')
    
    axes[1].hist(heights, bins=20, color='coral', edgecolor='white')
    axes[1].set_title('Image Height Distribution')
    axes[1].set_xlabel('Height (px)')
    
    axes[2].hist(file_sizes, bins=30, color='mediumseagreen', edgecolor='white')
    axes[2].set_title('File Size Distribution')
    axes[2].set_xlabel('Size (KB)')
    
    plt.suptitle('RiceSEG — Image Properties', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.show()

In [None]:
# Summary configuration for training
if df is not None and len(df) > 0:
    summary = {
        'dataset': 'RiceSEG',
        'task': 'segmentation',
        'total_images': len(df),
        'num_classes': NUM_CLASSES,
        'class_names': RICESEG_CLASSES,
        'image_size': f'{min(widths)}x{min(heights)}' if len(set(widths)) == 1 and len(set(heights)) == 1 else 'varies',
        'platform': PLATFORM,
    }
    
    if 'class_weights' in dir():
        summary['class_weights'] = {RICESEG_CLASSES[i]: round(class_weights[i], 2) for i in range(NUM_CLASSES)}
    
    if 'weed_df' in dir():
        summary['images_with_weeds'] = int(weed_df['any_weed'].sum())
        summary['weed_pixel_pct'] = round(class_pixel_counts[4] / total_pixels * 100, 3) if total_pixels > 0 else 0
    
    print('=== RiceSEG — Training Configuration Summary ===')
    print(json.dumps({k: v for k, v in summary.items() if k != 'class_names'}, indent=2))
    
    print(f'\n=== Recommended Training Setup for DeepLabV3+ ===')
    print(f'Input size: 512x512 (native resolution — no resize needed)')
    print(f'Backbone: ResNet-50 (pretrained ImageNet)')
    print(f'Loss: Focal loss (gamma=2) or weighted cross-entropy')
    print(f'Weed class weight: 10-20x (compensate for sparse weed pixels)')
    print(f'Augmentation: flip, rotate, color jitter, random crop 384x384')
    print(f'Normalize to: ImageNet stats (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])')

---
## 5. Training Recommendations for DeepLabV3+

### Model Configuration

| Parameter | Recommended Value | Reasoning |
|-----------|-------------------|-----------|
| **Architecture** | DeepLabV3+ | Widely used, strong semantic segmentation baseline with a good accuracy/speed balance |
| **Backbone** | ResNet-50 (ImageNet pretrained) | Enough capacity for 6 classes, not too heavy for Kaggle GPU |
| **Input size** | 512x512 | Native RiceSEG resolution — no information loss |
| **Loss function** | Focal loss (gamma=2) | Handles extreme class imbalance (weeds ~1.6%) |
| **Class weights** | See summary above | Weed/duckweed classes need 10-20x weight |
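Raw inverse-frequency weights, as computed in Section 3, can grow very large for a class that covers only 1-2% of pixels, which tends to make training unstable. One common tempering, shown here as a sketch rather than the notebook's method, takes the square root of the inverse frequency, renormalizes to mean 1, and clips to a ceiling. The pixel counts are made up for illustration, and the 20x ceiling echoes the range in the table above.

```python
import numpy as np


def tempered_class_weights(pixel_counts, power=0.5, max_weight=20.0):
    """Inverse-frequency weights raised to `power`, normalized to mean 1,
    then clipped to [0, max_weight]. power=1.0 is plain inverse frequency;
    power=0.5 is the softer square-root variant. Zero-count classes get 0."""
    counts = np.asarray(pixel_counts, dtype=np.float64)
    weights = np.zeros_like(counts)
    present = counts > 0
    freq = counts[present] / counts.sum()
    w = (1.0 / freq) ** power
    weights[present] = w / w.mean()
    return np.clip(weights, 0.0, max_weight)


# Made-up pixel counts for the 6 classes (NOT measured RiceSEG numbers)
example_counts = [400_000, 350_000, 100_000, 80_000, 16_000, 4_000]
raw = tempered_class_weights(example_counts, power=1.0, max_weight=np.inf)
soft = tempered_class_weights(example_counts, power=0.5)
names = ['bg', 'green', 'senescent', 'panicle', 'weed', 'duckweed']
for name, r, s in zip(names, raw, soft):
    print(f'{name:10s} raw={r:6.2f}  sqrt={s:5.2f}')
```

The square root shrinks the duckweed/background ratio from 100x to 10x in this toy example, while keeping the ordering of the weights unchanged.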

### Augmentation Strategy

| Transform | Parameters | Why |
|-----------|-----------|-----|
| HorizontalFlip | p=0.5 | Rice fields look the same when mirrored |
| VerticalFlip | p=0.5 | Aerial/top-down views have no canonical "up" direction |
| RandomRotate90 | p=0.5 | Extends orientation invariance to 90° rotations |
| ColorJitter | brightness=0.2, contrast=0.2 | Handle varying lighting conditions |
| RandomCrop | 384x384 from 512x512 | Augment while keeping enough context |
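The key constraint behind this table is that geometric transforms (flips, 90-degree rotations, crops) must be applied identically to image and mask, color jitter must touch only the image, and the mask must never be interpolated. A minimal NumPy sketch of that pairing, assuming a 512x512 input; `augment_pair` is hypothetical, and its brightness/contrast jitter is a simplified stand-in rather than torchvision's exact ColorJitter. In practice, libraries such as Albumentations apply one transform pipeline to an image and its mask together.

```python
import numpy as np


def augment_pair(image, mask, rng, crop=384, jitter=0.2):
    """Apply the same random geometry to an (H, W, 3) uint8 image and an
    (H, W) class-ID mask; brightness/contrast jitter hits the image only."""
    if rng.random() < 0.5:                      # horizontal flip
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:                      # vertical flip
        image, mask = image[::-1], mask[::-1]
    if rng.random() < 0.5:                      # random 90/180/270-degree rotation
        k = int(rng.integers(1, 4))
        image, mask = np.rot90(image, k), np.rot90(mask, k)

    # Random crop: the same window for both arrays
    h, w = mask.shape
    top = int(rng.integers(0, h - crop + 1))
    left = int(rng.integers(0, w - crop + 1))
    image = image[top:top + crop, left:left + crop]
    mask = mask[top:top + crop, left:left + crop]

    # Image-only jitter: contrast around the mean, then brightness scaling
    img = image.astype(np.float32)
    contrast = rng.uniform(1 - jitter, 1 + jitter)
    brightness = rng.uniform(1 - jitter, 1 + jitter)
    img = (img - img.mean()) * contrast + img.mean()
    img = np.clip(img * brightness, 0, 255).astype(np.uint8)
    return img, np.ascontiguousarray(mask)


rng = np.random.default_rng(42)
demo_image = rng.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
demo_mask = rng.integers(0, 6, size=(512, 512), dtype=np.uint8)
aug_img, aug_mask = augment_pair(demo_image, demo_mask, rng)
print(aug_img.shape, aug_mask.shape)   # (384, 384, 3) (384, 384)
```

Because the mask is only sliced, flipped, and rotated, its values stay exact class IDs; resizing a mask would instead need nearest-neighbor interpolation.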

### Expected Baseline Performance

| Metric | Expected Range | Notes |
|--------|---------------|-------|
| Overall mIoU | 40-55% | Averaged across all 6 classes |
| Weed IoU | 10-30% | Sparse class, hardest to predict |
| Background IoU | 85-95% | Dominant class, easy to predict |
| Overall pixel accuracy | >90% | Misleading due to class imbalance |
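Per-class IoU is TP / (TP + FP + FN), usually read off a confusion matrix accumulated over the whole validation set. A minimal NumPy sketch (the helper names are hypothetical), including a toy case that shows why pixel accuracy is misleading here: a prediction that never outputs the weed class still scores high accuracy but zero weed IoU.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """(C, C) counts; rows = ground truth, columns = prediction."""
    idx = num_classes * y_true.ravel().astype(np.int64) + y_pred.ravel().astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def per_class_iou(cm):
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    denom = tp + fp + fn
    # Classes absent from both truth and prediction get NaN (excluded from mIoU here)
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(denom > 0, tp / denom, np.nan)


# Toy 10x10 mask: 98 background pixels (class 0), 2 weed pixels (class 4)
truth = np.zeros((10, 10), dtype=np.uint8)
truth[0, :2] = 4
pred = np.zeros_like(truth)                 # model never predicts weeds

cm = confusion_matrix(truth, pred, num_classes=6)
iou = per_class_iou(cm)
pixel_acc = np.diag(cm).sum() / cm.sum()
print(f'pixel accuracy = {pixel_acc:.2f}')                  # 0.98
print(f'background IoU = {iou[0]:.2f}')                     # 0.98
print(f'weed IoU       = {iou[4]:.2f}')                     # 0.00
print(f'mIoU (present classes) = {np.nanmean(iou):.2f}')    # 0.49
```

Whether absent classes are skipped (as here) or counted as 0 changes mIoU noticeably on small validation sets, so it is worth stating which convention a reported number uses.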

### Training Strategy

1. **Phase 1:** Freeze backbone, train decoder (5-10 epochs, lr=1e-3)
2. **Phase 2:** Unfreeze all, fine-tune (10-20 epochs, lr=1e-4)
3. **Evaluate on:** Weed IoU specifically, not just overall metrics
4. **Consider:** Philippines-only subset for faster iteration (smaller but most relevant)

---
## 6. What's Next

| Next Step | Notebook | What It Does |
|-----------|----------|-------------|
| Train segmentation model | `03-segmentation-baseline.ipynb` (or `05-*`) | Train DeepLabV3+ on RiceSEG |
| Classification baseline | `04-classification-baseline.ipynb` | Train EfficientNetV2-S on Crop & Weed or Bangladesh data |

**Key inputs from this notebook:**
- Class weights for weighted loss function
- Knowledge that weed pixels are sparse (~1.6%)
- Philippines subset identified as most relevant
- 512x512 native resolution confirmed
- Mask format verified (single-channel, integer class IDs)