In [1]:
!pip install rasterio
!pip install geopandas
!pip install tqdm
!pip install shapely

Collecting rasterio
  Downloading rasterio-1.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl.metadata (6.5 kB)
Downloading rasterio-1.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 22.3/22.3 MB 116.1 MB/s eta 0:00:00
Downloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl (11 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1.2 cligj-0.7.2 rasterio-1.4.3


In [2]:
import os
import cv2
import rasterio
import numpy as np
import seaborn as sns

from tqdm import tqdm
from pathlib import Path
from rasterio.windows import Window
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader

from scipy.ndimage import gaussian_filter

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

In [3]:
class ResidualBlock(nn.Module):
    """
    Two 3x3 convolutions with instance normalisation plus a shortcut branch.

    The shortcut is a 1x1 convolution when the channel count changes and an
    identity mapping otherwise. Both branches are summed and passed through a
    LeakyReLU with negative slope 0.1. Spatial size is preserved (padding=1).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Main branch: conv -> norm -> activation -> conv -> norm.
        main_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
        ]
        self.conv = nn.Sequential(*main_layers)

        # Shortcut branch: only project when the channel counts differ.
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return LeakyReLU(main_branch(x) + shortcut(x))."""
        merged = self.conv(x) + self.skip(x)
        return F.leaky_relu(merged, negative_slope=0.1)

class Autoencoder(nn.Module):
    """
    Residual U-Net style Autoencoder.

    The encoder applies five ``ResidualBlock`` stages (32, 64, 128, 256 and
    ``latent_dim`` channels) separated by 2x2 max pooling, so the bottleneck
    is 16x smaller than the input along each spatial dimension. The decoder
    mirrors it with four stride-2 transposed convolutions; before each decoder
    block the upsampled map is concatenated with the encoder map of the same
    resolution (skip connection).

    Args:
        input_channels: Number of channels of the input and of the
            reconstruction.
        latent_dim: Number of channels of the bottleneck feature map.

    Note:
        Input height and width must be multiples of 16. Otherwise the
        upsampled decoder maps no longer match the encoder maps and the
        ``torch.cat`` calls in ``forward`` fail.

        The final ``Tanh`` bounds reconstructions to [-1, 1].
        NOTE(review): confirm the training targets are scaled to that range.
    """

    def __init__(self, input_channels=3, latent_dim=256):
        super().__init__()

        #----------ENCODER----------#
        self.enc1 = ResidualBlock(input_channels, 32)
        self.enc2 = ResidualBlock(32, 64)
        self.enc3 = ResidualBlock(64, 128)
        self.enc4 = ResidualBlock(128, 256)
        self.enc5 = ResidualBlock(256, latent_dim)

        # Shared, parameter-free 2x downsampling used between encoder stages.
        self.pool = nn.MaxPool2d(2, 2)

        #----------DECODER----------#
        # Each decoder block receives the upsampled map concatenated with the
        # matching encoder map, hence the doubled input channel counts.
        self.up5 = nn.ConvTranspose2d(latent_dim, 256, 2, stride=2)
        self.dec5 = ResidualBlock(256 + 256, 256)

        self.up4 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec4 = ResidualBlock(128 + 128, 128)

        self.up3 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec3 = ResidualBlock(64 + 64, 64)

        self.up2 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec2 = ResidualBlock(32 + 32, 32)

        self.final_conv = nn.Conv2d(32, input_channels, 3, padding=1)
        self.output_activation = nn.Tanh()

    # NOTE: this must be a method of the class (indented); defined at module
    # level, nn.Module would fall back to its unimplemented default forward.
    def forward(self, x):
        """
        Encode and reconstruct a batch of images.

        Args:
            x: Batched tensor of shape (N, input_channels, H, W) with H and W
                multiples of 16.

        Returns:
            Tuple ``(out, e5)`` where ``out`` is the Tanh-bounded
            reconstruction with the same shape as ``x`` and ``e5`` is the
            bottleneck feature map of shape (N, latent_dim, H/16, W/16).
        """
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        e5 = self.enc5(self.pool(e4))

        # Decoder with skip connections
        d5 = self.up5(e5)
        d5 = torch.cat([d5, e4], dim=1)
        d5 = self.dec5(d5)

        d4 = self.up4(d5)
        d4 = torch.cat([d4, e3], dim=1)
        d4 = self.dec4(d4)

        d3 = self.up3(d4)
        d3 = torch.cat([d3, e2], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec2(d2)

        out = self.output_activation(self.final_conv(d2))
        return out, e5

In [4]:
class Dataset(torch.utils.data.Dataset):
    """
    Sliding-window patch dataset over a single (H, W, C) image array.

    Square patches of side ``patch_size`` are cut every ``stride`` pixels,
    patches with too few bright pixels are discarded, and the remaining ones
    are stored channel-first as float32. Optionally every channel is
    standardised (zero mean, unit variance) over all kept patches.

    The base class is referenced by its qualified name because this class
    rebinds the name ``Dataset``; with ``class Dataset(Dataset)`` re-running
    the notebook cell would make the class inherit from its own previous
    definition.

    Args:
        rgb_stack: Array of shape (H, W, C).
        patch_size: Side length of the square patches.
        stride: Step in pixels between consecutive patch origins.
        normalize: If True, standardise each channel with a ``StandardScaler``
            fitted on all kept patches.
        valid_threshold: Used both as the brightness cutoff (scaled by 255 for
            uint8 input) and as the minimum fraction of values that must
            exceed that cutoff for a patch to be kept.

    Attributes:
        patches: float32 array of shape (N, C, patch_size, patch_size), or an
            empty array when no patch qualifies.
        positions: List of (y, x) top-left corners, aligned with ``patches``.
        scaler: The fitted ``StandardScaler``, or None when normalisation was
            not applied (``normalize=False`` or no valid patches).
    """

    def __init__(self, rgb_stack, patch_size=256, stride=16,
                 normalize=True, valid_threshold=0.1):
        self.rgb_stack = rgb_stack
        self.patch_size = patch_size
        self.normalize = normalize
        # Always defined so callers can test `dataset.scaler is None` instead
        # of hitting an AttributeError when normalisation did not run.
        self.scaler = None

        # Extract patches
        self.patches, self.positions = self._extract_patches(stride, valid_threshold)

        # Normalize if requested (nothing to fit when no patch was kept)
        if normalize and len(self.patches) > 0:
            self.patches = self._normalize_patches(self.patches)

    def _normalize_patches(self, patches):
        """Standardise each channel over all patches; stores the fitted scaler."""
        n_samples, n_channels, height, width = patches.shape

        # StandardScaler works on (samples, features): one row per pixel and
        # one column per channel -> (N*H*W, C).
        pixels = patches.transpose(0, 2, 3, 1).reshape(-1, n_channels)

        self.scaler = StandardScaler()
        normalized = self.scaler.fit_transform(pixels)

        # Back to (N, C, H, W); stored C-contiguous so each patch is one
        # contiguous block of memory.
        normalized = normalized.reshape(n_samples, height, width, n_channels)
        return np.ascontiguousarray(normalized.transpose(0, 3, 1, 2), dtype=np.float32)

    def _extract_patches(self, stride, valid_threshold):
        """
        Extract patches from the RGB stack.

        Returns:
            Tuple ``(patches, positions)``: a float32 array of channel-first
            patches and the list of their (y, x) top-left corners.
        """
        h, w, c = self.rgb_stack.shape
        size = self.patch_size
        patches = []
        positions = []

        # Brightness cutoff: uint8 data is on a 0-255 scale, anything else is
        # compared against the raw threshold.
        # NOTE(review): valid_threshold doubles as the minimum valid-pixel
        # fraction below — confirm that sharing one value is intended.
        if self.rgb_stack.dtype == np.uint8:
            brightness_cutoff = valid_threshold * 255
        else:
            brightness_cutoff = valid_threshold

        for y in range(0, h - size + 1, stride):
            for x in range(0, w - size + 1, stride):
                patch = self.rgb_stack[y:y + size, x:x + size, :]

                # Fraction of values (over all channels) above the cutoff
                valid_ratio = np.mean(patch > brightness_cutoff)

                if valid_ratio >= valid_threshold:
                    # Transpose to (C, H, W) for PyTorch
                    patches.append(np.transpose(patch, (2, 0, 1)))
                    positions.append((y, x))

        patches = np.array(patches, dtype=np.float32)
        print(f"Extracted {len(patches)} valid patches")

        return patches, positions

    def __len__(self):
        """Number of kept patches."""
        return len(self.patches)

    def __getitem__(self, idx):
        """Return patch ``idx`` as a float32 tensor of shape (C, H, W)."""
        return torch.FloatTensor(self.patches[idx])