In [4]:
!uv pip install tqdm

[2mUsing Python 3.12.3 environment at: C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite\.venv[0m
[2mResolved [1m2 packages[0m [2min 149ms[0m[0m
[2mInstalled [1m1 package[0m [2min 148ms[0m[0m
 [32m+[39m [1mtqdm[0m[2m==4.67.1[0m


In [6]:
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import rasterio
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [12]:
# Open 4 JP2 files using rasterio and convert to numpy array
def open_jp2_as_numpy(file_path: Path) -> np.ndarray:
    with rasterio.open(file_path) as src:
        data = src.read()
    return data


@dataclass
class JP2Dataset(Dataset):
    jp2_files: list[Path]

    def __len__(self) -> int:
        return len(self.jp2_files)

    def __getitem__(self, idx: int) -> torch.Tensor:
        jp2_file = self.jp2_files[idx]
        data = open_jp2_as_numpy(jp2_file)
        data = np.moveaxis(data, 0, -1)  # (H, W, C)

        # Handle single-band images
        if data.shape[-1] == 1:
            data = data.squeeze(-1)

        # Resize using numpy and torch (no PIL)
        data = data.astype(np.float32)
        data = torch.from_numpy(data)
        if data.ndim == 2:
            data = data.unsqueeze(0)  # (1, H, W)
        else:
            data = data.permute(2, 0, 1)  # (C, H, W)
        data = torch.nn.functional.interpolate(
            data.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False
        ).squeeze(0)
        data = data / 255.0
        return torch.tensor(data)


def create_dataloader(jp2_files: list[Path], batch_size: int = 4) -> DataLoader:
    dataset = JP2Dataset(jp2_files)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    return dataloader


def process_jp2_files(jp2_files: list[Path], batch_size: int = 4) -> None:
    dataloader = create_dataloader(jp2_files, batch_size)

    for batch in tqdm(dataloader, desc="Processing JP2 files"):
        # Here you can process the batch of images
        # For demonstration, we will just print the shape of each batch
        print(f"Batch shape: {batch.shape}")


if __name__ == "__main__":
    # Example JP2 files (replace with your actual file paths)
    jp2_files = [
        Path(
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\nir\B08.jp2"
        ),
        Path(
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\red\B04.jp2"
        ),
        Path(
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\green\B03.jp2"
        ),
        Path(
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\blue\B02.jp2"
        ),
    ]

    # Process the JP2 files
    process_jp2_files(jp2_files, batch_size=4)

  return torch.tensor(data)
Processing JP2 files: 100%|██████████| 1/1 [00:10<00:00, 10.01s/it]

Batch shape: torch.Size([4, 1, 256, 256])





---

In [32]:
import torch.nn as nn


class SimpleUNetV2(nn.Module):
    def __init__(self, dropout_rate: float = 0.3) -> None:
        super().__init__()
        self.dropout_rate = dropout_rate

        # Encoder
        self.enc1 = self.conv_block(4, 32)
        self.enc2 = self.conv_block(32, 64)

        # Bottleneck
        self.bottleneck = self.conv_block(64, 128)
        self.dropout_bottleneck = nn.Dropout2d(p=self.dropout_rate)

        # Decoder
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(128, 64)
        self.dropout_dec1 = nn.Dropout2d(p=self.dropout_rate)

        self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(64, 32)

        self.final = nn.Conv2d(32, 1, kernel_size=1)
        self.pool = nn.MaxPool2d(2)

    def conv_block(self, in_ch: int, out_ch: int) -> nn.Sequential:
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: nn.Sequential) -> nn.Conv2d:
        # Encode
        x1 = self.enc1(x)  # (B, 32, H, W)
        x2 = self.enc2(self.pool(x1))  # (B, 64, H/2, W/2)

        # Bottleneck + dropout
        x3 = self.bottleneck(self.pool(x2))
        x3 = self.dropout_bottleneck(x3)  # (B, 128, H/4, W/4)

        # Decode
        x4 = self.up1(x3)
        x4 = self.dec1(torch.cat([x4, x2], dim=1))
        x4 = self.dropout_dec1(x4)

        x5 = self.up2(x4)
        x5 = self.dec2(torch.cat([x5, x1], dim=1))

        return self.final(x5)  # (B, 1, H, W)


In [22]:
import numpy as np
import rasterio


def read_sentinel_channels(red_path, green_path, blue_path, nir_path):
    with (
        rasterio.open(red_path) as red_src,
        rasterio.open(green_path) as green_src,
        rasterio.open(blue_path) as blue_src,
        rasterio.open(nir_path) as nir_src,
    ):
        red = red_src.read(1)
        green = green_src.read(1)
        blue = blue_src.read(1)
        nir = nir_src.read(1)
    stacked = np.stack([red, green, blue, nir], axis=-1)
    return stacked.astype(np.float32) / 10000  # Normalisation


In [36]:
class TileManager:
    def __init__(self, tile_size=384, image_size=10180):
        self.tile_size = tile_size
        self.image_size = image_size
        self.indices = self._compute_indices()

    def _compute_indices(self):
        indices = []
        for y in range(0, self.image_size, self.tile_size):
            for x in range(0, self.image_size, self.tile_size):
                indices.append((y, x))
        return indices

    def extract_tiles(self, image):
        tiles = []
        valid_indices = []
        for y, x in self.indices:
            tile = image[y : y + self.tile_size, x : x + self.tile_size]

            # Si bord droit/bas : pad avec 0
            pad_y = self.tile_size - tile.shape[0]
            pad_x = self.tile_size - tile.shape[1]
            if pad_y > 0 or pad_x > 0:
                tile = np.pad(tile, ((0, pad_y), (0, pad_x), (0, 0)), mode="constant")

            tiles.append(tile)
            valid_indices.append((y, x))
        return np.stack(tiles), valid_indices

In [None]:
import torch


def load_unet_model(path):
    model = SimpleUNetV2()
    model.load_state_dict(torch.load(path, map_location="cpu"))  # ou "cuda" si dispo
    model.eval()
    return model


def infer_tiles(model, tiles, batch_size=32):
    model.eval()
    preds = []
    with torch.no_grad():
        for i in range(0, len(tiles), batch_size):
            batch = tiles[i : i + batch_size]
            batch_tensor = torch.from_numpy(batch.transpose(0, 3, 1, 2))  # BCHW
            output = model(batch_tensor).squeeze(1).numpy()
            preds.append(output)
    return np.concatenate(preds, axis=0)


In [25]:
def filter_useful_tiles(predictions, tiles, indices, threshold=0.5, max_cloud_coverage=0.5):
    useful_tiles = []
    useful_indices = []
    for pred, tile, idx in zip(predictions, tiles, indices):
        low_pred_ratio = (pred < threshold).mean()
        if low_pred_ratio < max_cloud_coverage:
            useful_tiles.append(tile[:, :, :3])  # RGB only
            useful_indices.append(idx)
    return useful_tiles, useful_indices


In [None]:
def update_prediction_state(pred_mask, indices, tile_size=350):
    for y, x in indices:
        pred_mask[y : y + tile_size, x : x + tile_size] = True
    return pred_mask


In [None]:
def assemble_rgb_image(tiles, indices, image_shape=(10180, 10180, 3), tile_size=350):
    full_image = np.zeros(image_shape, dtype=np.float32)
    for tile, (y, x) in zip(tiles, indices):
        full_image[y : y + tile_size, x : x + tile_size] = tile
    return full_image


In [38]:
def run_inference_pipeline(image_paths_list, model_path):
    model = load_unet_model(model_path)
    tile_manager = TileManager()

    prediction_mask = np.zeros((10180, 10180), dtype=bool)
    all_rgb_tiles = []
    all_indices = []

    for red, green, blue, nir in image_paths_list:
        print(f"Processing: {red}, {green}, {blue}, {nir}")
        image = read_sentinel_channels(red, green, blue, nir)
        tiles, indices = tile_manager.extract_tiles(image)

        # Ne prédire que les tuiles non déjà validées
        print(f"Extracted {len(tiles)} tiles, {len(indices)} indices.")
        to_predict = [(t, i) for t, i in zip(tiles, indices) if not prediction_mask[i[0], i[1]]]
        if not to_predict:
            continue

        print(f"Predicting {len(to_predict)} new tiles.")
        t_tiles, t_indices = zip(*to_predict)
        t_tiles = np.stack(t_tiles)

        print("Running inference on tiles...")

        preds = infer_tiles(model, t_tiles, batch_size=16)

        print("Filtering useful tiles...")
        useful_tiles, useful_indices = filter_useful_tiles(preds, t_tiles, t_indices)

        print(f"Found {len(useful_tiles)} useful tiles.")
        if not useful_tiles:
            continue

        all_rgb_tiles.extend(useful_tiles)
        all_indices.extend(useful_indices)

        print(f"Updating prediction mask with {len(useful_indices)} indices.")
        prediction_mask = update_prediction_state(prediction_mask, useful_indices)

        if prediction_mask.all():
            break

    print("Assembling final RGB image from useful tiles...")

    final_image = assemble_rgb_image(all_rgb_tiles, all_indices)
    return final_image


In [39]:
run_inference_pipeline(
    [
        (
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\red\B04.jp2",
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\green\B03.jp2",
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\blue\B02.jp2",
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\nir\B08.jp2",
        ),
        (
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-26\31UDQ\red\B04.jp2",
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-26\31UDQ\green\B03.jp2",
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-26\31UDQ\blue\B02.jp2",
            r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-26\31UDQ\nir\B08.jp2",
        ),
    ],
    r"C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite\satellite\exploration\models\simple_unet_v2_subset4000_epoch20.pth",
)

Processing: C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\red\B04.jp2, C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\green\B03.jp2, C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\blue\B02.jp2, C:\Users\giand\OneDrive\Documents\__packages__\_perso\satellite_data\sentinel2-31UDQ\2025-05-16\31UDQ\nir\B08.jp2


MemoryError: Unable to allocate 1.80 GiB for an array with shape (10980, 10980, 4) and data type float32