# **Load the libraries**

In [2]:
# Import essential libraries for deep learning and data handling
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
from google.colab import drive

# Libraries for image handling and visualization
!pip install rasterio
import matplotlib.pyplot as plt
import rasterio
from rasterio.windows import Window

# Additional utilities
import logging
import tqdm
from datetime import datetime
import itertools
import pickle

Collecting rasterio
  Downloading rasterio-1.3.10-cp310-cp310-manylinux2014_x86_64.whl (21.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.5/21.5 MB[0m [31m56.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Collecting snuggs>=1.4.1 (from rasterio)
  Downloading snuggs-1.4.7-py3-none-any.whl (5.4 kB)
Installing collected packages: snuggs, affine, rasterio
Successfully installed affine-2.4.0 rasterio-1.3.10 snuggs-1.4.7


# **Setup and Data Path Configuration for Colab**

In [6]:
# Setup for mounting Google Drive in Colab
from google.colab import drive
drive.mount('/content/gdrive')

# Directory where the data is stored
data_dir = "/content/gdrive/My Drive/adleo_my/final_project/data/"

# Paths to images, labels, and catalog
image_paths = list(Path(data_dir).glob("images/*.tif"))
label_paths = list(Path(data_dir).glob("labels/*.tif"))
catalog_paths = "/content/gdrive/MyDrive/adleo_my/final_project/data/label-catalog-filtered.csv"

# Ensure that paths are sorted and matched
image_paths.sort()
label_paths.sort()

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


## **Step 1: Definition of the ResUNet Architecture**


In [7]:
# Residual Block Class
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, downsample=None, upsample=None):
        super(ResBlock, self).__init__()
        self.upsample = upsample
        stride = 2 if downsample else 1

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Adjust the shortcut to match dimensions
        if in_channels != out_channels or downsample:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.upsample and out.size() != identity.size():
            identity = F.interpolate(identity, size=out.shape[2:], mode='bilinear', align_corners=False)

        out += identity
        out = self.relu(out)
        return out

class ResUNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResUNet, self).__init__()
        self.encoder1 = ResBlock(in_channels, 64)
        self.encoder2 = ResBlock(64, 128, downsample=True)
        self.encoder3 = ResBlock(128, 256, downsample=True)
        self.encoder4 = ResBlock(256, 512, downsample=True)

        self.decoder1 = ResBlock(512, 256, upsample=True)
        self.decoder2 = ResBlock(256, 128, upsample=True)
        self.decoder3 = ResBlock(128, 64, upsample=True)
        self.decoder4 = ResBlock(64, out_channels, upsample=True)

        self.final_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        d1 = self.decoder1(e4)
        d2 = self.decoder2(F.interpolate(d1, scale_factor=2) + e3)  # Ensure dimension match
        d3 = self.decoder3(F.interpolate(d2, scale_factor=2) + e2)  # Ensure dimension match
        d4 = self.decoder4(F.interpolate(d3, scale_factor=2) + e1)  # Ensure dimension match

        out = self.final_conv(d4)
        return out

## **Step 2: Utility Functions for Data Handling in Satellite Image Processing**

In [8]:
# Utility function to load image data
def load_data(filepath, is_label=False):
    """Loads image data from a file using rasterio, with option to load as label data."""
    with rasterio.open(filepath) as src:
        if is_label:
            # For labels, assume single channel and integer type
            data = src.read(1)  # read the first channel
        else:
            # For images, read all bands
            data = src.read()

        # Convert data to float32 and scale to [0, 1]
        if not is_label:
            data = data.astype(np.float32)
            data /= 255.0

    return torch.from_numpy(data)

# Function to normalize image data
def normalize_data(image, mean, std):
    """Normalize image data using provided mean and standard deviation."""
    if image.ndim == 3:  # if image has multiple channels
        for i in range(image.shape[0]):  # normalize each channel
            image[i, :, :] = (image[i, :, :] - mean[i]) / std[i]
    else:  # if single channel (e.g., label or grayscale image)
        image = (image - mean) / std
    return image

# Function to apply transformations to dataset (augmentation)
def transform(image, label):
    """Apply transformations to the image and label for data augmentation."""
    # Example transformation: vertical flip
    if torch.rand(1) > 0.5:
        image = torch.flip(image, [1])  # flip along vertical axis
        label = torch.flip(label, [1])  # flip along vertical axis
    return image, label

# Function to get windows of data from large images
def get_data_window(data, row_start, row_end, col_start, col_end):
    """Extract a window of data from a larger array."""
    return data[:, row_start:row_end, col_start:col_end]

# Example use within a dataset class
class SatelliteImageDataset(torch.utils.data.Dataset):
    """Dataset class for satellite images, utilizing the utility functions above."""
    def __init__(self, image_paths, label_paths, transform=None):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label_path = self.label_paths[idx]

        image = load_data(image_path)
        label = load_data(label_path, is_label=True)

        if self.transform:
            image, label = self.transform(image, label)

        return image, label

## **Step 3: Data Preparation and Transformation Setup for Image Processing**

In [9]:
# Adjust normalization to handle images that may have an extra channel
class Normalize(transforms.Normalize):
    """Normalize only the first three channels of the image."""
    def __call__(self, tensor):
        if tensor.size(0) == 4:
            tensor[:3] = super().__call__(tensor[:3])
            return tensor
        return super().__call__(tensor)

# Define transformations for images (assuming these are RGB images)
image_transform = transforms.Compose([
    transforms.ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Define transformations for labels (if these are masks, just convert to tensor)
label_transform = transforms.Compose([
    transforms.ToTensor()
])

class CustomDataset(Dataset):
    def __init__(self, image_paths, label_paths, image_transform=None, label_transform=None):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.image_transform = image_transform
        self.label_transform = label_transform

    def __len__(self):
        # Ensuring the length function does not cause an index out of range
        return min(len(self.image_paths), len(self.label_paths))

    def __getitem__(self, index):
        if index >= len(self.image_paths) or index >= len(self.label_paths):
            raise IndexError("Index {} is out of range".format(index))

        image_path = self.image_paths[index]
        label_path = self.label_paths[index]

        try:
            with rasterio.open(image_path) as src:
                image = src.read().astype('float32')
            with rasterio.open(label_path) as src:
                label = src.read(1).astype('float32')
        except Exception as e:
            raise RuntimeError("Error opening image files at index {}: {}".format(index, e))

        # Ensure the image is three channels (RGB)
        if image.shape[0] == 1:
            image = np.tile(image, (3, 1, 1))  # Duplicate the single channel to make it three channels
        elif image.shape[0] > 3:
            image = image[:3, :, :]  # Use only the first three channels if there are more

        image = np.transpose(image, (1, 2, 0))  # CHW to HWC for torchvision transforms

        if self.image_transform:
            image = self.image_transform(image)
        if self.label_transform:
            label = self.label_transform(label)

        return image, label

# Assume image_transform and label_transform are defined elsewhere:
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

label_transform = transforms.ToTensor()

# Example usage:
train_dataset = CustomDataset(image_paths, label_paths, image_transform, label_transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_dataset = CustomDataset(image_paths, label_paths, image_transform, label_transform)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

## **Step 4: Model Training Setup and Execution**

In [10]:
# Define hyperparameters
learning_rate = 0.001
batch_size = 16
epochs = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model
model = ResUNet(in_channels=3, out_channels=1).to(device)

# Define the loss function and optimizer
criterion = nn.BCEWithLogitsLoss()  # Binary Cross-Entropy Loss for binary classification tasks
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop definition
def train_loop(dataloader, model, loss_fn, optimizer, device):
    """Function to execute the training loop."""
    model.train()  # Set the model to training mode
    total_loss = 0

    for batch, (images, labels) in enumerate(dataloader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass: Compute predictions and loss
        preds = model(images)
        loss = loss_fn(preds, labels)

        # Backward pass: Compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Print loss every 100 batches
        if batch % 100 == 0:
            current_loss = total_loss / (batch + 1)
            print(f"Batch {batch}, Loss: {current_loss:.4f}")

# Validation loop definition
def validate_loop(dataloader, model, loss_fn, device):
    """Function to execute the validation loop."""
    model.eval()  # Set the model to evaluation mode
    total_loss = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            loss = loss_fn(preds, labels)
            total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Validation Loss: {avg_loss:.4f}")

## **Step 5: Model Training and Validation Execution**

In [None]:
# Execute training and validation
def run_training(train_loader, val_loader, model, loss_fn, optimizer, epochs, device):
    """Executes the training and validation process over a given number of epochs."""
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs} -------------------------------")

        # Train the model for one epoch
        train_loop(train_loader, model, loss_fn, optimizer, device)

        # Validate the model at the end of the epoch
        validate_loop(val_loader, model, loss_fn, device)

    print("Training complete.")

# Save the trained model
def save_model(model, path):
    """Saves the model to the specified path."""
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

# Setup data loaders (assuming train_loader and val_loader are already defined)

# Paths for saving the model
model_save_path = "/content/gdrive/My Drive/trained_resunet.pth"

# Running the training and validation
run_training(train_loader, val_loader, model, criterion, optimizer, epochs, device)

# Saving the trained model
save_model(model, model_save_path)

Epoch 1/1 -------------------------------
Batch 0, Loss: 0.7018
Batch 100, Loss: 0.7021
Batch 200, Loss: 0.7020
