In [None]:
## Prepare Sentinel-1 (SAR) and Sentinel-2 (optical) satellite images
## Preprocessing script for projection and resampling

In [None]:
import math
import os

import rasterio
from rasterio.crs import CRS
from rasterio.warp import calculate_default_transform, reproject, Resampling


# List of file paths for the 12 Sentinel-2 bands
file_paths = [
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B01_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B02_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B03_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B04_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B05_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B06_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B07_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B08_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B8A_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B09_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B11_(Raw).tiff",
    "/path/to/your/data/2020-08-11-00:00_2020-08-11-23:59_Sentinel-2_L2A_B12_(Raw).tiff"
]

'''
# List of file paths for Sentinel-1 VV and VH band paths
file_paths = [
    "/path/to/your/data/2019-03-30-00:00_2019-03-30-23:59_Sentinel-1_IW_VV+VH_VV_(Raw).tiff",
    "/path/to/your/data/2019-03-30-00:00_2019-03-30-23:59_Sentinel-1_IW_VV+VH_VH_(Raw).tiff"
]
'''


def get_resolution(file_path):
    """
    Get the spatial resolution (pixel size) of a TIFF file.

    Parameters:
        file_path (str): Path to the raster file.

    Returns:
        tuple: (resolution_x, resolution_y) in meters.
    """
    with rasterio.open(file_path) as src:
        return src.res


def project_raster(input_file, output_file, target_crs):
    """
    Reproject a raster to a specified coordinate reference system (CRS).

    Parameters:
        input_file (str): Path to the input raster.
        output_file (str): Path to the output reprojected raster.
        target_crs (CRS): Target coordinate reference system.
    """
    with rasterio.open(input_file) as src:
        transform, width, height = calculate_default_transform(
            src.crs, target_crs, src.width, src.height, *src.bounds
        )

        kwargs = src.meta.copy()
        kwargs.update({
            "crs": target_crs,
            "transform": transform,
            "width": width,
            "height": height
        })

        with rasterio.open(output_file, "w", **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, i),
                    destination=rasterio.band(dst, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform,
                    dst_crs=target_crs,
                    resampling=Resampling.nearest
                )


def resample_raster(input_file, output_file, target_res_x, target_res_y):
    """
    Resample a raster to a specified spatial resolution.

    Parameters:
        input_file (str): Path to the input raster.
        output_file (str): Path to the output resampled raster.
        target_res_x (float): Desired pixel width in meters.
        target_res_y (float): Desired pixel height in meters.
    """
    with rasterio.open(input_file) as src:
        transform = rasterio.Affine(
            target_res_x, 0.0, src.bounds.left, 0.0, -target_res_y, src.bounds.top
        )
        width = int((src.bounds.right - src.bounds.left) / target_res_x)
        height = int((src.bounds.top - src.bounds.bottom) / target_res_y)

        kwargs = src.meta.copy()
        kwargs.update({
            "transform": transform,
            "width": width,
            "height": height,
            "dtype": "float32"
        })

        with rasterio.open(output_file, "w", **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, i),
                    destination=rasterio.band(dst, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform,
                    dst_crs=src.crs,
                    resampling=Resampling.bilinear
                )


for path in file_paths:
    res_x, res_y = get_resolution(path)
    print(f"Spatial resolution of {path.split('/')[-1]}: {res_x}m x {res_y}m")

for path in file_paths:
    output_path = f"{path.split('.')[0]}_proj.tiff"
    project_raster(path, output_path, CRS.from_epsg(32723))

for path in file_paths:
    projected_path = f"{path.split('.')[0]}_proj.tiff"
    res_x, res_y = get_resolution(projected_path)
    print(f"Projected resolution of {projected_path.split('/')[-1]}: {res_x}m x {res_y}m")

for path in file_paths:
    projected_path = f"{path.split('.')[0]}_proj.tiff"
    output_resampled_path = f"{projected_path.split('.')[0]}_10m.tiff"
    resample_raster(projected_path, output_resampled_path, 10, 10)

In [None]:
## Raster tiling and feature extraction with CROMA
## Requires code and weights from https://github.com/antofuller/CROMA

In [None]:
import geopandas as gpd
import pandas as pd
import rasterio
from rasterio.windows import Window
import os


class Cutter:
    """
    Cut a raster image into smaller tiles based on a vector grid.

    Attributes:
        vector_file (str): Path to the shapefile containing the grid cells.
        raster_file (str): Path to the raster image to be split.
        output_folder (str): Folder where the tiles will be saved.
    """

    def __init__(self, vector_file, raster_file, output_folder):
        self.vector_file = vector_file
        self.raster_file = raster_file
        self.output_folder = output_folder

    def cut_images(self):
        """
        Read the input raster and vector grid, then export one raster patch per polygon.
        Each patch corresponds to the bounds of a single grid cell.
        """
        grid = gpd.read_file(self.vector_file)
        gdf = gpd.GeoDataFrame(geometry=grid.geometry)
        gdf["id"] = grid["id"]

        with rasterio.open(self.raster_file) as src:
            for _, row in grid.iterrows():
                cell_id = row["id"]
                geom = row.geometry

                window = src.window(*geom.bounds)
                subset = src.read(window=window)

                profile = src.profile
                profile.update({
                    "height": window.height,
                    "width": window.width,
                    "transform": rasterio.windows.transform(window, src.transform)
                })

                os.makedirs(self.output_folder, exist_ok=True)

                output_path = os.path.join(self.output_folder, f"{cell_id}.tif")
                with rasterio.open(output_path, "w", **profile) as dst:
                    dst.write(subset)

        gdf.to_file(os.path.join(self.output_folder, "check"))


import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import VisionDataset
import imageio.v2 as imageio
import numpy as np
import os


class Loader:
    """
    Prepare paired optical and SAR image datasets for feature extraction.

    Attributes:
        opt_root_dir (str): Directory containing optical image patches.
        sar_root_dir (str): Directory containing SAR image patches.
    """

    def __init__(self, opt_root_dir, sar_root_dir):
        self.opt_root_dir = opt_root_dir
        self.sar_root_dir = sar_root_dir

    def load_data(self, sample_size=224, batch_size=32):
        """
        Create a PyTorch DataLoader from paired optical and SAR images.

        Parameters:
            sample_size (int): Resize dimension for the images.
            batch_size (int): Number of samples per batch.
        """
        class CustomDataset(VisionDataset):
            """Custom paired dataset for optical and SAR images."""

            def __init__(self, opt_root, sar_root, transform=None, target_transform=None):
                super(CustomDataset, self).__init__(
                    root=opt_root, transform=transform, target_transform=target_transform
                )
                self.opt_root = opt_root
                self.sar_root = sar_root
                self.transform = transform

                self.samples = []
                opt_images = [
                    img_name for img_name in os.listdir(opt_root)
                    if os.path.isfile(os.path.join(opt_root, img_name))
                ]
                for img_name in opt_images:
                    opt_img_path = os.path.join(opt_root, img_name)
                    sar_img_path = os.path.join(sar_root, img_name)
                    self.samples.append((opt_img_path, sar_img_path))

            def __len__(self):
                return len(self.samples)

            def __getitem__(self, idx):
                opt_img_path, sar_img_path = self.samples[idx]
                opt_img = imageio.imread(opt_img_path)
                sar_img = imageio.imread(sar_img_path)
                opt_img = np.transpose(opt_img, (2, 0, 1)).astype(np.float32)
                sar_img = np.transpose(sar_img, (2, 0, 1)).astype(np.float32)
                opt_img = torch.tensor(opt_img)
                sar_img = torch.tensor(sar_img)
                if self.transform:
                    opt_img = self.transform(opt_img)
                    sar_img = self.transform(sar_img)
                return opt_img, sar_img, opt_img_path, sar_img_path

        transform = transforms.Resize((sample_size, sample_size), antialias=True)
        dataset = CustomDataset(
            opt_root=self.opt_root_dir,
            sar_root=self.sar_root_dir,
            transform=transform
        )
        self.dataset = dataset
        self.loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


import torch
import torch.nn as nn
from tqdm import tqdm
import h5py
import pickle
from use_croma import PretrainedCROMA


class FeatureExtractor:
    """
    Extract deep multimodal features using a pretrained CROMA network.

    Attributes:
        dataloader (torch.utils.data.DataLoader): DataLoader providing paired images.
        use_8_bit (bool): Whether to normalize to 8-bit before feeding the model.
    """

    def __init__(self, dataloader, use_8_bit=True):
        self.dataloader = dataloader
        self.use_8_bit = use_8_bit
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.FE = PretrainedCROMA(
            pretrained_path="CROMA_base.pt", size="base", modality="both", image_resolution=120
        )
        self.FE.to(self.device)
        self.FE.eval()

    def normalize(self, x):
        """
        Normalize each image channel to either [0, 255] or [0, 1] based on use_8_bit.

        Parameters:
            x (torch.Tensor): Input tensor with shape (B, C, H, W).

        Returns:
            torch.Tensor: Normalized tensor.
        """
        x = x.float()
        imgs = []
        for channel in range(x.shape[1]):
            mean = x[:, channel, :, :].mean()
            std = x[:, channel, :, :].std()
            min_val = mean - 2 * std
            max_val = mean + 2 * std
            if self.use_8_bit:
                img = (x[:, channel, :, :] - min_val) / (max_val - min_val) * 255.0
                img = torch.clip(img, 0, 255).unsqueeze(dim=1).to(torch.uint8)
            else:
                img = (x[:, channel, :, :] - min_val) / (max_val - min_val)
                img = torch.clip(img, 0, 1).unsqueeze(dim=1)
            imgs.append(img)
        return torch.cat(imgs, dim=1)

    def extract_features(self, save_name=None):
        """
        Extract features for each image pair and optionally save them to disk.

        Parameters:
            save_name (str): Optional base name for the output files (.h5 and .pkl).

        Returns:
            tuple: (features, ids)
        """
        features_batches = []
        id_batches = []
        with torch.no_grad():
            for optical_images, sar_images, optical_img_paths, _ in tqdm(
                self.dataloader, desc="Extracting Features"
            ):
                optical_images = self.normalize(optical_images.to(self.device))
                sar_images = self.normalize(sar_images.to(self.device))
                if self.use_8_bit:
                    optical_images = optical_images.float() / 255
                    sar_images = sar_images.float() / 255
                outputs = self.FE(
                    SAR_images=sar_images, optical_images=optical_images
                )["joint_GAP"]
                features_batches.append(outputs.cpu())
                ids_i = torch.tensor([
                    int((path.split(".")[0]).split("/")[-1])
                    for path in optical_img_paths
                ])
                id_batches.append(ids_i)

        features = torch.cat(features_batches).numpy()
        ids = torch.cat(id_batches).numpy()

        if save_name:
            with h5py.File(f"{save_name}.h5", "w") as hf:
                hf.create_dataset("features", data=features)
                hf.create_dataset("ids", data=ids)

            with open(f"{save_name}.pkl", "wb") as f:
                pickle.dump([features, ids], f)

        return features, ids


# --- Execution pipeline ---

cutter = Cutter(
    "/path/to/your/data/grid.shp",
    "/path/to/your/data/sar_image.tif",
    "/path/to/your/data/patchs_sar_image"
)
cutter.cut_images()

cutter = Cutter(
    "/path/to/your/data/grid.shp",
    "/path/to/your/data/opt_image.tif",
    "/path/to/your/data/patchs_opt_image"
)
cutter.cut_images()

processor = Loader(
    opt_root_dir="/path/to/your/data/patchs_opt_image",
    sar_root_dir="/path/to/your/data/patchs_sar_image"
)
processor.load_data(sample_size=120, batch_size=16)
loader = processor.loader

feature_extractor = FeatureExtractor(dataloader=loader, use_8_bit=True)
features, ids = feature_extractor.extract_features("dl_features")