In [1]:
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
from data_manager import create_modified_crop_labels
from tqdm import tqdm

In [2]:
class CropDataset(Dataset):
    """Pixel-wise crop segmentation dataset over a stacked array of image patches.

    ``data`` is indexed as ``data[idx, row, col, band]``: the last band holds the raw
    crop label code and every other band is an image channel.

    Args:
        data: Patch array indexed (patch, row, col, band); copied to float64 once.
        transform: Optional callable applied to the image tensor after scaling.
        scaled_bands: Number of leading image bands multiplied by ``scale`` and
            clipped to [0, 1]; later image bands are passed through unchanged.
        scale: Multiplier applied to the leading ``scaled_bands`` bands.

    Each item is ``(image, label)``: ``image`` is a float32 tensor of shape
    (row, col, band - 1) (channels-last) and ``label`` an int64 tensor of shape
    (row, col) holding class indices in ``range(num_classes)``.
    """

    def __init__(self, data, transform=None, scaled_bands=18, scale=0.0001):
        self.data = data.astype(float)
        self.transform = transform
        self.scaled_bands = scaled_bands
        self.scale = scale

        # Fixed mapping from raw label codes to contiguous class indices
        self.label_map = {
            -1: 0,  # background
            1: 1,   # corn
            5: 2,   # soybean
            23: 3,  # spring wheat
            176: 4  # grassland/pasture
        }
        self.num_classes = len(set(self.label_map.values()))  # 5 classes including background

    def __len__(self):
        return len(self.data)

    def _map_labels(self, raw_labels):
        """Map raw label codes to class indices (int64 array of the same shape).

        Raises:
            ValueError: if any value has no entry in ``self.label_map``.
        """
        mapped = np.zeros(raw_labels.shape, dtype=np.int64)
        matched = np.zeros(raw_labels.shape, dtype=bool)
        for code, class_idx in self.label_map.items():
            hit = raw_labels == code
            mapped[hit] = class_idx
            matched |= hit
        if not matched.all():
            unknown = np.unique(raw_labels[~matched]).tolist()
            raise ValueError(f"Label values missing from label_map: {unknown}")
        return mapped

    def __getitem__(self, idx):
        # Copy: basic slicing returns a view, so scaling it in place would rewrite
        # self.data and every later read of this index would be rescaled again.
        image = self.data[idx, :, :, :-1].copy()  # All bands except last one (label)
        label = self.data[idx, :, :, -1]          # Last band is the label

        # Scale the leading bands and clip to [0, 1]
        k = self.scaled_bands
        image[:, :, :k] = np.clip(image[:, :, :k] * self.scale, 0, 1)

        image = torch.from_numpy(image).float()
        label = torch.from_numpy(self._map_labels(label))  # int64 -> torch.long

        if self.transform:
            image = self.transform(image)

        return image, label

In [3]:
# Crop codes passed as `unchanged_crops`; create_modified_crop_labels (data_manager)
# presumably remaps every other code — CropDataset.label_map expects only -1 and
# these codes afterwards (TODO confirm against data_manager).
unchanged_crops = [1, 5, 23, 176]

# Load and relabel each split.
train_data = create_modified_crop_labels(np.load('./training_data/train_patches.npy'),
                                         unchanged_crops=unchanged_crops)
valid_data = create_modified_crop_labels(np.load('./training_data/val_patches.npy'),
                                         unchanged_crops=unchanged_crops)
test_data = create_modified_crop_labels(np.load('./training_data/test_patches.npy'),
                                        unchanged_crops=unchanged_crops)

# Wrap each split in a CropDataset.
train_dataset, val_dataset, test_dataset = (
    CropDataset(split) for split in (train_data, valid_data, test_data)
)

print(f"Number of classes: {train_dataset.num_classes}")
print(f"Label mapping: {train_dataset.label_map}")

# Same batch size / worker count for every split; only training is shuffled.
loader_opts = {'batch_size': 16, 'num_workers': 4}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_opts)

for line in (
    f"\nTraining samples: {len(train_dataset)}",
    f"Validation samples: {len(val_dataset)}",
    f"Test samples: {len(test_dataset)}",
):
    print(line)

# Pull one shuffled training batch to sanity-check shapes and label values.
sample_image, sample_label = next(iter(train_loader))
print(f"\nImage shape: {sample_image.shape}")
print(f"Label shape: {sample_label.shape}")
print(f"Unique labels in sample: {torch.unique(sample_label)}")

Number of classes: 5
Label mapping: {-1: 0, 1: 1, 5: 2, 23: 3, 176: 4}

Training samples: 1074
Validation samples: 231
Test samples: 231

Image shape: torch.Size([16, 224, 224, 18])
Label shape: torch.Size([16, 224, 224])
Unique labels in sample: tensor([0, 1, 2, 3, 4])


In [None]:
# Import Segformer
from transformers import SegformerForSemanticSegmentation
from transformers import SegformerImageProcessor

# CropDataset yields channels-last batches with 18 bands already scaled to [0, 1]
# (see the sample shape printed above). SegformerImageProcessor is built for 3-channel
# RGB (1/255 rescale, 3-value ImageNet mean/std), so it is not applied to these
# batches: processor=None makes the loops below feed the bands to the model directly.
NUM_INPUT_CHANNELS = 18
processor = None

# Initialize model. Overriding num_channels builds a first patch-embedding conv that
# takes 18 bands; its pretrained 3-channel weights no longer fit, so (like the new
# decode head) it is freshly initialized via ignore_mismatched_sizes.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=train_dataset.num_classes,
    num_channels=NUM_INPUT_CHANNELS,
    ignore_mismatched_sizes=True,
)

# Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Training setup; T_max follows num_epochs so the cosine schedule spans the whole run.
num_epochs = 100
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


def _to_pixel_values(images, processor, device):
    """Turn a channels-last (B, H, W, C) batch into channels-first pixel_values on `device`.

    With ``processor=None`` the batch is only permuted; otherwise the given Hugging Face
    image processor is applied (only meaningful for 3-channel images).
    """
    images = images.permute(0, 3, 1, 2)
    if processor is None:
        return images.to(device)
    return processor(images=images, return_tensors="pt")["pixel_values"].to(device)


def _run_epoch(model, loader, processor, device, optimizer=None, desc='Training'):
    """Run one pass over `loader`; return (mean batch loss, pixel accuracy).

    Trains (backward + optimizer step) when `optimizer` is given; otherwise runs
    in eval mode without gradient tracking.
    """
    is_train = optimizer is not None
    model.train(is_train)
    total_loss = 0.0
    correct = 0
    total_pixels = 0

    pbar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(is_train):
        for images, labels in pbar:
            pixel_values = _to_pixel_values(images, processor, device)
            labels = labels.to(device)

            # Forward pass (the model computes the loss against `labels`)
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Segformer returns logits below input resolution (1/4 for this
            # architecture); upsample to the label grid before per-pixel comparison.
            logits = torch.nn.functional.interpolate(
                outputs.logits.detach(), size=labels.shape[-2:],
                mode='bilinear', align_corners=False,
            )
            predicted = logits.argmax(dim=1)
            batch_correct = (predicted == labels).sum().item()
            correct += batch_correct
            total_pixels += labels.numel()
            total_loss += loss.item()

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{batch_correct / labels.numel():.4f}'
            })

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError
    return total_loss / max(len(loader), 1), correct / max(total_pixels, 1)


def train_epoch(model, train_loader, processor, optimizer, device):
    """Train `model` for one pass over `train_loader`; return (mean loss, pixel accuracy)."""
    return _run_epoch(model, train_loader, processor, device,
                      optimizer=optimizer, desc='Training')


def validate(model, val_loader, processor, device):
    """Evaluate `model` on `val_loader` without gradients; return (mean loss, pixel accuracy)."""
    return _run_epoch(model, val_loader, processor, device, desc='Validation')


# Training loop
checkpoint_path = 'best_segformer_model.pth'
best_val_acc = float('-inf')  # guarantees the first epoch writes a checkpoint

epoch_pbar = tqdm(range(num_epochs), desc='Epochs')
for epoch in epoch_pbar:
    train_loss, train_acc = train_epoch(model, train_loader, processor, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, processor, device)

    # Update learning rate once per epoch
    scheduler.step()

    epoch_pbar.set_postfix({
        'train_loss': f'{train_loss:.4f}',
        'train_acc': f'{train_acc:.4f}',
        'val_loss': f'{val_loss:.4f}',
        'val_acc': f'{val_acc:.4f}'
    })

    # Save best model by validation pixel accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), checkpoint_path)
        print(f'\nNew best model saved with validation accuracy: {val_acc:.4f}')

# Load best model for testing; map_location keeps this working if the device differs
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Test the model
test_loss, test_acc = validate(model, test_loader, processor, device)
print('\nTest Results:')
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')

In [5]:
# Load a fresh mit-b0 Segformer with a 5-class decode head for architecture
# inspection. Note: this rebinds `model`, discarding any previously trained one.
from transformers import SegformerForSemanticSegmentation
from transformers import SegformerImageProcessor

pretrained_name = "nvidia/mit-b0"
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_name,
    num_labels=5,
    ignore_mismatched_sizes=True,
)


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Print the module tree of the currently loaded Segformer.
print(str(model))

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

In [9]:
from transformers import SegformerForSemanticSegmentation
import torch.nn as nn

# ignore_mismatched_sizes lets from_pretrained re-initialize (rather than error on)
# checkpoint weights whose shapes differ from this 5-class model.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=5,
    ignore_mismatched_sizes=True
)

# Modify the first conv layer to accept 18 input channels instead of 3
NUM_INPUT_CHANNELS = 18
first_patch_embed = model.segformer.encoder.patch_embeddings[0]
old_proj = first_patch_embed.proj
new_proj = nn.Conv2d(
    in_channels=NUM_INPUT_CHANNELS,
    out_channels=old_proj.out_channels,
    kernel_size=old_proj.kernel_size,
    stride=old_proj.stride,
    padding=old_proj.padding,
    bias=old_proj.bias is not None
)
# Install the new conv: constructing it alone leaves the model's 3-channel stem in place.
first_patch_embed.proj = new_proj
# Keep the config in sync so a save/reload rebuilds an 18-channel first patch embedding.
model.config.num_channels = NUM_INPUT_CHANNELS
# NOTE(review): new_proj uses PyTorch's default random init, so the pretrained stem
# weights are discarded; consider deriving it from old_proj.weight if that helps.

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Display the weight Parameter of `new_proj`, the 18-input-channel Conv2d built in the previous cell.
new_proj.weight

Parameter containing:
tensor([[[[-2.9252e-02,  9.5352e-03, -1.6305e-02,  ...,  1.8966e-02,
           -1.9608e-02, -1.4077e-02],
          [-2.6893e-02, -3.3393e-02,  1.9009e-02,  ..., -2.4639e-02,
           -3.2623e-02, -4.6976e-04],
          [-2.9086e-02, -1.0357e-02, -8.1168e-04,  ..., -2.6713e-03,
           -6.8931e-03,  2.3142e-02],
          ...,
          [-5.4773e-04,  2.1291e-02, -2.4563e-03,  ...,  3.0521e-02,
            9.8139e-03,  2.0446e-02],
          [-2.1274e-02,  1.5798e-02, -2.9060e-02,  ..., -2.3121e-04,
           -9.5050e-03, -3.1537e-02],
          [ 7.4233e-03, -2.3940e-02,  3.0620e-02,  ..., -1.5762e-02,
            6.2191e-03,  6.4132e-03]],

         [[-9.2835e-04,  8.0656e-03, -2.7583e-02,  ...,  2.1031e-02,
           -1.3600e-02,  2.2036e-02],
          [ 6.3124e-03, -1.3604e-02, -2.3970e-02,  ..., -9.5724e-03,
           -1.2391e-02, -2.1033e-02],
          [-9.6284e-03, -2.3003e-04, -9.3021e-03,  ..., -8.9109e-03,
           -6.5988e-03,  2.7939e-02]