# Cannabis Grow Segmentation with TorchGeo

This notebook demonstrates how to train a semantic segmentation model to identify cannabis grows using NAIP imagery and polygon annotations. The model learns to map 4-band (RGB + NIR) aerial imagery to binary masks that highlight cannabis cultivation areas.

## Overview

1. **Data**: We use NAIP (National Agriculture Imagery Program) aerial imagery and JSON polygon annotations
2. **Model**: U-Net architecture with a pretrained ResNet50 encoder for semantic segmentation
3. **Training**: Binary cross-entropy loss with the Adam optimizer and a `ReduceLROnPlateau` learning-rate scheduler

## Setup
First, let's import all the necessary libraries.

In [69]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchgeo.datasets import NAIP
from torchgeo.samplers import RandomGeoSampler
from torchgeo.transforms import AugmentationSequential
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Tuple
import rasterio
from rasterio.features import rasterize
from shapely.geometry import shape, Polygon
import albumentations as A
from albumentations.pytorch import ToTensorV2

## Custom Dataset

We create a custom PyTorch Dataset class that:
1. Loads NAIP imagery using rasterio
2. Converts JSON polygon annotations into binary masks
3. Applies data augmentation transforms

The dataset expects:
- NAIP imagery in GeoTIFF format
- Annotations in JSON format with polygon coordinates
- Each JSON file should contain:
  - `imagePath`: Path to corresponding NAIP image
  - `imageHeight` and `imageWidth`: Image dimensions
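  - `shapes`: List of shapes, each with a `label` and a `points` list of `[x, y]` pixel coordinates; only shapes labeled `cannabis` with at least 3 points are rasterized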
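The field list above appears to follow the LabelMe annotation layout; that is an assumption, since the notebook does not name a tool. A minimal sketch of a file in that shape, plus a hypothetical `cannabis_polygons` helper that keeps only the entries the dataset actually uses. The file name and coordinates are made-up illustration values.

```python
import json

# Hypothetical annotation in the layout the dataset expects (values are made up)
example = {
    "imagePath": "tile_0001.tif",  # resolved relative to naip_root
    "imageHeight": 512,
    "imageWidth": 512,
    "shapes": [
        {"label": "cannabis", "points": [[10, 10], [60, 10], [60, 40], [10, 40]]},
        {"label": "greenhouse", "points": [[100, 100], [150, 100], [150, 150]]},
        {"label": "cannabis", "points": [[200, 200], [210, 210]]},  # too few points
    ],
}


def cannabis_polygons(annotation: dict) -> list:
    """Return the point lists the dataset would rasterize (hypothetical helper)."""
    required = {"imagePath", "imageHeight", "imageWidth", "shapes"}
    missing = required - set(annotation)
    if missing:
        raise KeyError(f"annotation is missing fields: {sorted(missing)}")
    return [
        s["points"]
        for s in annotation["shapes"]
        if s["label"] == "cannabis" and len(s["points"]) >= 3
    ]


# Round-trip through JSON text, as the dataset does with json.load
loaded = json.loads(json.dumps(example))
polys = cannabis_polygons(loaded)
print(len(polys))  # 1
```

Only the first shape survives: the second has a different label and the third has fewer than 3 points.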

In [70]:
class CannabisSegmentationDataset(Dataset):
    def __init__(
        self,
        naip_root: str,
        annotations_dir: str,
        transform=None,
    ):
        self.naip_root = Path(naip_root)
        self.annotations_dir = Path(annotations_dir)
        self.transform = transform
        
        # Get list of annotation files
        self.annotation_files = sorted(self.annotations_dir.glob("*.json"))  # sorted for a stable order
        
    def __len__(self) -> int:
        return len(self.annotation_files)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Load annotation file
        with open(self.annotation_files[idx], "r") as f:
            annotation = json.load(f)
        
        # Get corresponding NAIP image path
        image_path = self.naip_root / annotation["imagePath"]
        
        # Load NAIP image (4 channels)
        with rasterio.open(image_path) as src:
            image = src.read()  # This will read all 4 bands
            image = np.transpose(image, (1, 2, 0))  # CHW -> HWC
            
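        # Create mask from polygons
        height, width = annotation["imageHeight"], annotation["imageWidth"]
        
        # Collect every valid cannabis polygon first; rasterizing inside the loop
        # would overwrite the mask and keep only the last polygon
        polygons = []
        for shape_data in annotation["shapes"]:
            if shape_data["label"] == "cannabis":
                points = shape_data["points"]  # [[x, y], ...] in pixel coordinates
                if len(points) >= 3:  # Need at least 3 points for a valid polygon
                    polygon = Polygon(points)
                    if polygon.is_valid:
                        polygons.append(polygon)
        
        # Burn all polygons in one pass; tiles without cannabis get an all-zero mask
        if polygons:
            mask = rasterize(
                polygons,
                out_shape=(height, width),
                fill=0,
                default_value=1,
                dtype=np.uint8
            )
        else:
            mask = np.zeros((height, width), dtype=np.uint8)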
        
        # Apply transforms if specified
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
        
        # Make sure mask is float32 for BCE loss
        mask = mask.float() if isinstance(mask, torch.Tensor) else torch.tensor(mask, dtype=torch.float32)
        
        return {
            "image": image,
            "mask": mask.unsqueeze(0)  # Add channel dimension for binary mask
        }

## Data Augmentation

We define data augmentation transforms using the Albumentations library to:
1. Increase the effective size of our training dataset
2. Improve model generalization
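3. Make the model robust to variations in:
   - Orientation (random 90° rotations and horizontal/vertical flips)
   - Lighting conditions (brightness/contrast)

We normalize the RGB bands with the ImageNet mean and standard deviation (matching the pretrained encoder) and use approximate placeholder values (mean 0.5, std 0.225) for the NIR band.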

In [71]:
def get_transforms(image_size=(512, 512)):
    train_transform = A.Compose([
        A.Resize(height=image_size[0], width=image_size[1]),
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        # Updated normalization for 4 channels (R,G,B,NIR)
        A.Normalize(
            mean=[0.485, 0.456, 0.406, 0.5],  # ImageNet RGB means + approximate NIR mean
            std=[0.229, 0.224, 0.225, 0.225],  # ImageNet RGB stds + approximate NIR std
        ),
        ToTensorV2(),
    ])
    
    val_transform = A.Compose([
        A.Resize(height=image_size[0], width=image_size[1]),
        # Updated normalization for 4 channels (R,G,B,NIR)
        A.Normalize(
            mean=[0.485, 0.456, 0.406, 0.5],  # ImageNet RGB means + approximate NIR mean
            std=[0.229, 0.224, 0.225, 0.225],  # ImageNet RGB stds + approximate NIR std
        ),
        ToTensorV2(),
    ])
    
    return train_transform, val_transform
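The NIR values in `A.Normalize` are rough placeholders. One option, sketched here under the assumption that tiles are uint8 HWC arrays (as the dataset produces before transforms), is to compute per-band statistics from your own training tiles on the same 0-1 scale `A.Normalize` uses (it divides by `max_pixel_value=255` by default). `compute_band_stats` is a hypothetical helper, and the random tiles stand in for real NAIP reads.

```python
import numpy as np


def compute_band_stats(tiles, max_pixel_value=255.0):
    """Per-band mean/std over a list of HWC uint8 tiles, scaled to [0, 1].

    Accumulates running sums in float64 so many tiles can be processed
    without holding them all in one array.
    """
    n_bands = tiles[0].shape[2]
    total = np.zeros(n_bands)
    total_sq = np.zeros(n_bands)
    count = 0
    for tile in tiles:
        pixels = tile.reshape(-1, n_bands).astype(np.float64) / max_pixel_value
        total += pixels.sum(axis=0)
        total_sq += (pixels ** 2).sum(axis=0)
        count += pixels.shape[0]
    mean = total / count
    # Clamp tiny negative values caused by floating-point round-off
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean.tolist(), std.tolist()


# Stand-in data: random 4-band tiles (R, G, B, NIR); replace with real NAIP tiles
rng = np.random.default_rng(0)
tiles = [rng.integers(0, 256, size=(64, 64, 4), dtype=np.uint8) for _ in range(3)]
mean, std = compute_band_stats(tiles)
print([round(m, 3) for m in mean], [round(s, 3) for s in std])
```

The resulting lists could be passed as `mean=` and `std=` to both `A.Normalize` calls so the training and validation transforms stay consistent.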

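## Model Setup

The `UNetWithResNet` class builds a U-Net whose encoder is a pretrained ResNet50:

### 1. Encoder Path (Contracting):
- Uses a pretrained ResNet50 (ImageNet `IMAGENET1K_V2` weights) as the encoder backbone
- The first convolution is widened to 4 input channels: the RGB filters are copied from the pretrained weights and the NIR filter is initialized with their mean
- A stride-2 stem (Conv2d -> BatchNorm -> ReLU) and max pooling, followed by ResNet50's four stages (`layer1` to `layer4`)
- The stages output 256, 512, 1024, and 2048 channels at 1/4, 1/8, 1/16, and 1/32 of the input resolution
- Each stage is built from bottleneck residual blocks (1x1 Conv -> 3x3 Conv -> 1x1 Conv, each followed by BatchNorm, plus a residual connection)
- A 3x3 center convolution (Conv2d -> BatchNorm -> ReLU) sits between the encoder and the decoder

### 2. Decoder Path (Expanding):
- 4 decoder blocks, each doubling the spatial resolution
- Skip connections from the stem and from the first three encoder stages
- Each block: bilinear Upsample + 1x1 Conv -> Concatenate with skip -> Double Conv (Conv2d -> BatchNorm -> ReLU, twice)

### 3. Output:
- 3x3 convolution (64 -> 32 channels) followed by a 1x1 convolution to produce the segmentation map
- Sigmoid activation for binary segmentation
- Final 2x bilinear upsampling back to the input resolution

The model is designed for binary segmentation of cannabis grows in NAIP imagery, with the following configuration:
- Input: 4 channels (R, G, B, NIR)
- Output: 1 channel (per-pixel cannabis probability)
- Uses a pretrained ResNet50 encoder for feature extraction
- Uses skip connections to preserve spatial information
- Uses batch normalization for stable training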
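To make the resolution bookkeeping concrete, here is a small pure-Python trace of channels and spatial size through the network for a 512x512 input. It follows the strides listed above (stem and max pool each halve the resolution, `layer2` to `layer4` halve it again, each decoder block doubles it) and assumes the input size is divisible by 32; the stage list is written out by hand for illustration, not read from the model.

```python
def trace_shapes(size=512):
    """Return (name, channels, height) for each stage of the ResNet50 U-Net."""
    stages = []
    h = size // 2                       # stride-2 stem convolution
    stages.append(("e0 stem conv", 64, h))
    h //= 2                             # max pool
    for name, ch, stride in [("e1 layer1", 256, 1), ("e2 layer2", 512, 2),
                             ("e3 layer3", 1024, 2), ("e4 layer4", 2048, 2)]:
        h //= stride
        stages.append((name, ch, h))
    for name, ch in [("d4", 1024), ("d3", 512), ("d2", 256), ("d1", 64)]:
        h *= 2                          # each decoder block upsamples 2x
        stages.append((name, ch, h))
    stages.append(("output", 1, h * 2))  # final_conv + 2x bilinear upsample
    return stages


for name, ch, h in trace_shapes(512):
    print(f"{name:14s} {ch:5d} x {h} x {h}")
```

Each decoder output matches the spatial size of the skip tensor it is concatenated with (`d4` with `e3`, down to `d1` with `e0`), which is why the decoder's `F.interpolate` fallback is only needed for input sizes not divisible by 32.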

In [72]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class UNetWithResNet(nn.Module):
    def __init__(self, n_channels=4, n_classes=1, pretrained=True):
        super().__init__()
        
        # Load pre-trained ResNet50
        resnet = models.resnet50(weights='IMAGENET1K_V2' if pretrained else None)
        
        # Modify first conv layer to accept 4 channels
        if n_channels != 3:
            new_conv1 = nn.Conv2d(
                n_channels, 
                64,
                kernel_size=7, 
                stride=2, 
                padding=3,
                bias=False
            )
            
            # Initialize new conv1 with weights from pre-trained model
            with torch.no_grad():
                new_conv1.weight.zero_()
                new_conv1.weight[:,:3,:,:] = resnet.conv1.weight
                new_conv1.weight[:,3:,:,:] = resnet.conv1.weight.mean(dim=1, keepdim=True)
            
            resnet.conv1 = new_conv1
        
        # Encoder layers
        self.firstconv = resnet.conv1      # 64 channels
        self.firstbn = resnet.bn1
        self.firstrelu = resnet.relu
        self.firstmaxpool = resnet.maxpool
        self.encoder1 = resnet.layer1      # 256 channels
        self.encoder2 = resnet.layer2      # 512 channels
        self.encoder3 = resnet.layer3      # 1024 channels
        self.encoder4 = resnet.layer4      # 2048 channels
        
        # Center convolution
        self.center = nn.Sequential(
            nn.Conv2d(2048, 2048, kernel_size=3, padding=1),
            nn.BatchNorm2d(2048),
            nn.ReLU(inplace=True)
        )
        
        # Decoder layers: DecoderBlock(in_channels, skip_channels, out_channels)
        self.decoder4 = DecoderBlock(2048, 1024, 1024)  # up 2048->1024, concat e3 (1024), output: 1024
        self.decoder3 = DecoderBlock(1024, 512, 512)    # up 1024->512, concat e2 (512), output: 512
        self.decoder2 = DecoderBlock(512, 256, 256)     # up 512->256, concat e1 (256), output: 256
        self.decoder1 = DecoderBlock(256, 64, 64)       # up 256->64, concat e0 (64), output: 64
        
        # Final convolution
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, n_classes, kernel_size=1),
            nn.Sigmoid()
        )
        
        # Final upsampling to match input size
        self.final_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    
    def forward(self, x):
        # Save input size for final upsampling
        input_size = x.size()[2:]
        
        # Initial convolution
        x = self.firstconv(x)      # 64 channels
        x = self.firstbn(x)
        x = self.firstrelu(x)
        e0 = x                     # Save for skip connection
        x = self.firstmaxpool(x)
        
        # Encoder path
        e1 = self.encoder1(x)      # 256 channels
        e2 = self.encoder2(e1)     # 512 channels
        e3 = self.encoder3(e2)     # 1024 channels
        e4 = self.encoder4(e3)     # 2048 channels
        
        # Center
        e4 = self.center(e4)
        
        # Decoder path with skip connections
        d4 = self.decoder4(e4, e3)  # 1024 channels
        d3 = self.decoder3(d4, e2)  # 512 channels
        d2 = self.decoder2(d3, e1)  # 256 channels
        d1 = self.decoder1(d2, e0)  # 64 channels
        
        # Final convolution
        out = self.final_conv(d1)
        
        # Final upsampling to match input size
        out = self.final_up(out)
        
        # Ensure output size matches input size exactly
        if out.size()[2:] != input_size:
            out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=True)
        
        return out

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )
        
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x, skip):
        x = self.up(x)
        
        # Handle different spatial dimensions
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
            
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x
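The first-layer change in `UNetWithResNet` can be checked in isolation. A minimal sketch, using a randomly initialized 3-channel `nn.Conv2d` as a stand-in for the pretrained ResNet stem (so no weights are downloaded), that builds the 4-channel layer the same way: RGB filters copied, extra-channel filters set to the mean of the RGB filters. `inflate_conv` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def inflate_conv(conv3: nn.Conv2d, n_channels: int = 4) -> nn.Conv2d:
    """Copy RGB filters into a wider conv; extra input channels get the RGB mean."""
    new_conv = nn.Conv2d(n_channels, conv3.out_channels,
                         kernel_size=conv3.kernel_size, stride=conv3.stride,
                         padding=conv3.padding, bias=conv3.bias is not None)
    with torch.no_grad():
        new_conv.weight[:, :3] = conv3.weight
        new_conv.weight[:, 3:] = conv3.weight.mean(dim=1, keepdim=True)
        if conv3.bias is not None:
            new_conv.bias.copy_(conv3.bias)
    return new_conv


torch.manual_seed(0)
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # stand-in stem
rgbn_conv = inflate_conv(rgb_conv, n_channels=4)

x = torch.randn(1, 4, 64, 64)
print(rgbn_conv(x).shape)  # torch.Size([1, 64, 32, 32])
```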

## Training Loop

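The training function handles:
1. Training and validation phases for each epoch
2. Loss calculation and backpropagation
3. Validation loss computed without gradient tracking
4. Learning-rate scheduling and per-epoch progress logging

We use:
- Binary cross-entropy loss (`nn.BCELoss`): appropriate for binary segmentation, since the model outputs per-pixel probabilities through a sigmoid
- Adam optimizer: adaptive per-parameter learning rates
- Learning rate of 1e-5, halved by `ReduceLROnPlateau` when the validation loss has not improved for 5 epochs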
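The model ends in `nn.Sigmoid` and is trained with `nn.BCELoss`. PyTorch also provides `nn.BCEWithLogitsLoss`, which fuses the sigmoid into the loss for better numerical stability; switching to it would mean removing the final `nn.Sigmoid` and applying `torch.sigmoid` only at inference. A small check on random stand-in tensors that the two agree for moderate logits, and how they differ when the sigmoid saturates:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 1, 8, 8)                # raw model outputs (no sigmoid)
masks = (torch.rand(2, 1, 8, 8) > 0.5).float()  # binary float32 targets, like the dataset

bce_on_probs = nn.BCELoss()(torch.sigmoid(logits), masks)
bce_with_logits = nn.BCEWithLogitsLoss()(logits, masks)
print(bce_on_probs.item(), bce_with_logits.item())  # equal up to float error

# With a large wrong logit, sigmoid rounds to exactly 1.0 in float32, so BCELoss
# falls back on its documented clamp of log() at -100
big_logit = torch.tensor([[40.0]])
zero_target = torch.tensor([[0.0]])
saturated = nn.BCELoss()(torch.sigmoid(big_logit), zero_target).item()
stable = nn.BCEWithLogitsLoss()(big_logit, zero_target).item()
print(saturated, stable)  # ~100.0 vs ~40.0
```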

In [73]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler.ReduceLROnPlateau,  # step() takes the validation loss
    num_epochs: int,
    device: torch.device
) -> Tuple[List[float], List[float]]:
    train_losses = []
    val_losses = []
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        epoch_loss = 0
        batch_count = 0
        
        for batch_idx, batch in enumerate(train_loader):
            images = batch["image"].to(device)
            masks = batch["mask"].to(device)
            
            # Print shapes for debugging
            if batch_idx == 0 and epoch == 0:
                print(f"\nInput shapes:")
                print(f"Images: {images.shape}")
                print(f"Masks: {masks.shape}")
            
            optimizer.zero_grad()
            outputs = model(images)
            
            # Print output shape for debugging
            if batch_idx == 0 and epoch == 0:
                print(f"Outputs: {outputs.shape}\n")
                
                # Verify shapes match
                assert outputs.shape == masks.shape, \
                    f"Shape mismatch: outputs {outputs.shape} vs masks {masks.shape}"
            
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            batch_count += 1
        
        train_loss = epoch_loss / batch_count
        train_losses.append(train_loss)
        
        # Validation phase
        model.eval()
        val_loss = 0
        val_count = 0
        
        with torch.no_grad():
            for batch in val_loader:
                images = batch["image"].to(device)
                masks = batch["mask"].to(device)
                
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
                val_count += 1
        
        val_loss = val_loss / val_count
        val_losses.append(val_loss)
        
        # Update learning rate
        scheduler.step(val_loss)
        
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        print("-" * 40)
    
    return train_losses, val_losses
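`train_model` tracks only loss. Segmentation runs often also report overlap metrics; a minimal sketch of IoU and Dice for the model's sigmoid outputs, assuming a 0.5 threshold (an illustrative choice). `binary_iou_dice` is a hypothetical helper, not part of the notebook; it could be accumulated in the validation loop alongside `val_loss`.

```python
import torch


def binary_iou_dice(probs: torch.Tensor, masks: torch.Tensor,
                    threshold: float = 0.5, eps: float = 1e-7):
    """IoU and Dice over a batch of (N, 1, H, W) probabilities and {0, 1} masks.

    Pixels are pooled across the whole batch (micro-average); eps makes an
    empty prediction on an empty mask count as a perfect match.
    """
    preds = (probs > threshold).float()
    intersection = (preds * masks).sum()
    union = preds.sum() + masks.sum() - intersection
    iou = (intersection + eps) / (union + eps)
    dice = (2 * intersection + eps) / (preds.sum() + masks.sum() + eps)
    return iou.item(), dice.item()


# Toy example: the prediction covers 4 pixels, the ground truth 2 of them
probs = torch.zeros(1, 1, 4, 4)
probs[0, 0, 0, :] = 0.9
masks = torch.zeros(1, 1, 4, 4)
masks[0, 0, 0, :2] = 1.0
iou, dice = binary_iou_dice(probs, masks)
print(round(iou, 3), round(dice, 3))  # 0.5 0.667
```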

## Main Training Script

This section puts everything together:
1. Sets up the device (GPU if available)
2. Configures data paths and transforms
3. Creates datasets and data loaders
4. Initializes the model, loss function, and optimizer
5. Trains the model
6. Visualizes training progress
7. Saves the trained model

Note: Adjust the paths (`naip_root` and `annotations_dir`) to match your data locations.

In [82]:
def main():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Set paths
    naip_root = "../cannabis-parcels/cannabis-parcels-masked"
    annotations_dir = "../cannabis-parcels/cannabis-parcels-masked"
    
    # Set image size - smaller size to handle memory constraints
    image_size = (512, 512)
    
    # Get transforms
    train_transform, val_transform = get_transforms(image_size=image_size)
    
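    # Create two views of the same (sorted) files: one with augmentation, one without
    train_full = CannabisSegmentationDataset(
        naip_root=naip_root,
        annotations_dir=annotations_dir,
        transform=train_transform
    )
    val_full = CannabisSegmentationDataset(
        naip_root=naip_root,
        annotations_dir=annotations_dir,
        transform=val_transform
    )
    
    # Hold out 20% of the tiles for validation so training and validation don't overlap
    indices = torch.randperm(len(train_full), generator=torch.Generator().manual_seed(42)).tolist()
    n_val = max(1, len(indices) // 5)
    train_dataset = torch.utils.data.Subset(train_full, indices[n_val:])
    val_dataset = torch.utils.data.Subset(val_full, indices[:n_val])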
    
    # Create data loaders with smaller batch size
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)  # Reduced from 4
    val_loader = DataLoader(val_dataset, batch_size=2)  # Reduced from 4
    
    # Initialize model with pre-trained ResNet backbone
    model = UNetWithResNet(
        n_channels=4,
        n_classes=1,
        pretrained=True
    ).to(device)
    
    # Set up loss function and optimizer
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-5)  # Reduced learning rate
    
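    # Halve the LR when validation loss stops improving for 5 epochs
    # (`verbose` is deprecated in recent PyTorch; train_model prints the LR each epoch)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min', 
        factor=0.5, 
        patience=5
    )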
    
    # Train model
    train_losses, val_losses = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,  # Add scheduler to training
        num_epochs=50,
        device=device
    )
    
    # Plot training curves
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
    
    # Save model
    torch.save(model.state_dict(), "cannabis_segmentation_model.pth")
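Once `cannabis_segmentation_model.pth` is saved, the weights can be loaded into a fresh `UNetWithResNet(n_channels=4, n_classes=1, pretrained=False)` and used to predict masks. A sketch of the prediction step, assuming the input has already gone through the validation transform (resize, normalize, `ToTensorV2`); `predict_mask` is a hypothetical helper, and a tiny stand-in model replaces the U-Net so the example runs on its own.

```python
import torch
import torch.nn as nn


def predict_mask(model: nn.Module, image: torch.Tensor,
                 threshold: float = 0.5) -> torch.Tensor:
    """Return an (H, W) uint8 mask from a (C, H, W) normalized image tensor."""
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        probs = model(image.unsqueeze(0).to(device))  # model already applies sigmoid
    return (probs[0, 0] > threshold).to(torch.uint8).cpu()


# Stand-in for the trained U-Net; with the real model this would be:
#   model = UNetWithResNet(n_channels=4, n_classes=1, pretrained=False)
#   model.load_state_dict(torch.load("cannabis_segmentation_model.pth", map_location="cpu"))
stand_in = nn.Sequential(nn.Conv2d(4, 1, kernel_size=1), nn.Sigmoid())
image = torch.randn(4, 64, 64)
mask = predict_mask(stand_in, image)
print(mask.shape, mask.dtype)  # torch.Size([64, 64]) torch.uint8
```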

In [83]:
if __name__ == "__main__":
    main() 

Using device: cpu

Input shapes:
Images: torch.Size([2, 4, 512, 512])
Masks: torch.Size([2, 1, 512, 512])
Outputs: torch.Size([2, 1, 512, 512])

Epoch 1/50
Train Loss: 0.7362
Val Loss: 0.7259
Learning Rate: 0.000010
----------------------------------------
Epoch 2/50
Train Loss: 0.7208
Val Loss: 0.7230
Learning Rate: 0.000010
----------------------------------------
Epoch 3/50
Train Loss: 0.7084
Val Loss: 0.7001
Learning Rate: 0.000010
----------------------------------------
Epoch 4/50
Train Loss: 0.6934
Val Loss: 0.6870
Learning Rate: 0.000010
----------------------------------------
Epoch 5/50
Train Loss: 0.6692
Val Loss: 0.6726
Learning Rate: 0.000010
----------------------------------------
Epoch 6/50
Train Loss: 0.6422
Val Loss: 0.6552
Learning Rate: 0.000010
----------------------------------------


KeyboardInterrupt: 