## Zarr DataLoader

In [None]:
import os
import glob
import natsort
import xarray as xr

import sys
import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append('../')
from data.healpix import *
from utils.plot import plot_all_chunks

The main goal of this notebook is to create a Dataset and a DataLoader that read Zarr stores and their associated chunks. PyTorch provides two data primitives, `torch.utils.data.DataLoader` and `torch.utils.data.Dataset`, that allow you to use pre-loaded datasets as well as your own data. A `Dataset` stores the samples and their corresponding labels, and a `DataLoader` wraps an iterable around the `Dataset` to enable easy access to the samples.

## Import data

In [None]:
def prepare_paths(path_dir):
    """Load the input/target catalogues of one split and attach product paths.

    Reads ``input.csv`` and ``target.csv`` from ``path_dir``. Each table must
    have a ``Name`` column; a ``path`` column is added that points to
    ``<path_dir>/input/<name>`` (resp. ``<path_dir>/target/<name>``), where
    ``<name>`` is the basename of ``Name`` with every ``.SAFE`` occurrence
    removed. No file extension is appended to the resulting path.

    Args:
        path_dir: Directory of the split holding both CSV files.

    Returns:
        Tuple ``(df_input, df_output)`` of DataFrames with the extra ``path``
        column.
    """

    def _load_catalogue(csv_name, subfolder):
        # One catalogue: read the CSV, then derive the product path per row.
        table = pd.read_csv(f"{path_dir}/{csv_name}")

        def _to_local_path(name):
            product = os.path.basename(name).replace(".SAFE", "")
            return os.path.join(path_dir, subfolder, product)

        table["path"] = table["Name"].apply(_to_local_path)
        return table

    return _load_catalogue("input.csv", "input"), _load_catalogue("target.csv", "target")

# Root folder of the dataset and the dataset version to read.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
base_dir = "/mnt/disk/dataset/sentinel-ai-processor"
version = "V4"

# One directory per split; each is expected to hold input.csv and target.csv
# (these are the files prepare_paths reads).
TRAIN_DIR = f"{base_dir}/{version}/train/"
VAL_DIR = f"{base_dir}/{version}/val/"
TEST_DIR = f"{base_dir}/{version}/test/"
# Build the (input, target) path tables of every split.
df_train_input, df_train_output =  prepare_paths(TRAIN_DIR)
df_val_input, df_val_output =  prepare_paths(VAL_DIR)
df_test_input, df_test_output =  prepare_paths(TEST_DIR)
# Keep only the first two test targets (presumably to keep the demo cells fast).
df_test_output = df_test_output[:2]

In [None]:
# 0-based position of the scene to inspect; used with .iloc on df_test_output.
zarr_index = 1

In [None]:
# open .zarr datatree
# prepare_paths strips the ".SAFE" suffix from the product name, so the
# ".zarr" extension is appended here.
x_path = df_test_output["path"].iloc[zarr_index] + ".zarr"
# chunks={} asks xarray for dask-backed (lazy) arrays using the store's own
# chunking; mask_and_scale=False keeps the stored values undecoded.
dt = xr.open_datatree(x_path, engine="zarr", mask_and_scale=False, chunks={})
# Last expression: display the resolved store path as the cell output.
x_path

## Plot all chunks for a given resolution

In [None]:
# Resolution and band passed to get_chunk_info to read the chunk layout.
res = "60m"
band = "b01"
# get_chunk_info is not defined in this notebook; presumably it comes from the
# star import of data.healpix — TODO confirm. It is unpacked into the chunk
# height/width and the number of chunks along y/x.
chunk_size_y, chunk_size_x, nb_chunks_y, nb_chunks_x = get_chunk_info(data_tree=dt, band=band, res=res)

In [None]:
# Draw the chunks of the selected band/resolution using the layout computed
# in the previous cell (plot_all_chunks comes from utils.plot).
plot_all_chunks(dt, band, res, chunk_size_y, chunk_size_x, nb_chunks_y, nb_chunks_x, cmap="viridis", verbose= False, figsize_scale=3)

## Extract Chunk from Xarray

In [None]:
# Grid position (row, column) of the chunk to extract.
chunk_y_idx = 4
chunk_x_idx = 4
# get_chunk is not defined in this notebook; presumably it comes from the star
# import of data.healpix — TODO confirm.
chunk = get_chunk(data_tree=dt, res=res,
                  chunk_y_idx=chunk_y_idx, chunk_x_idx=chunk_x_idx,
                  chunk_size_y=chunk_size_y, chunk_size_x=chunk_size_x)

# Convert the extracted node to a Dataset, then combine its variables into a
# single DataArray.
chunk_array = chunk.to_dataset().to_dataarray()
# Last expression: display the array summary as the cell output.
chunk_array

## Datasets & DataLoaders

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

# --------- Sentinel2 Dataset ----------
class Sentinel2ZarrDataset(Dataset):
    """Map-style dataset that returns every Zarr chunk of one scene per item.

    Each item corresponds to one row of ``df_x``: the store ``<path>.zarr`` is
    opened, the ``measurements/reflectance/r{res}`` group is selected and cut
    along its on-disk chunk grid.

    Args:
        df_x: DataFrame with a ``"path"`` column holding the store path
            without the ``.zarr`` extension.
        res: Resolution label such as ``"60m"``; selects the ``r{res}`` group
            and the ``x_{res}`` / ``y_{res}`` dimensions.
        bands: Names of the variables to read, in the order they should appear
            on the channel axis. ``None`` or an empty sequence reads every
            variable of the resolution group.
    """

    def __init__(self, df_x, res, bands):
        self.df_x = df_x
        self.res_key = f"r{res}"
        self.x_res = f"x_{res}"
        self.y_res = f"y_{res}"
        self.bands = bands
        self.res = res

    def _band_names(self, scene):
        """Return the variable names to read: ``self.bands``, or all of them."""
        if self.bands is None or len(self.bands) == 0:
            return list(scene.data_vars)
        return list(self.bands)

    def __getitem__(self, index):
        """Load all chunks of the scene at position ``index``.

        Returns:
            Tuple ``(chunks, layout)`` where ``chunks`` is a tensor shaped
            ``[nb_chunks_y, nb_chunks_x, C, chunk_size_y, chunk_size_x]`` and
            ``layout`` is ``(nb_chunks_y, nb_chunks_x, chunk_size_y,
            chunk_size_x)``.

        Raises:
            KeyError: If a requested band is missing from the resolution group.
            ValueError: If the chunk grid is not regular, since chunks of
                different sizes cannot be stacked into one tensor.
        """
        x_path = self.df_x["path"].iloc[index] + ".zarr"
        dt = xr.open_datatree(x_path, engine="zarr", mask_and_scale=False, chunks={})
        try:
            scene = dt.measurements.reflectance[self.res_key].to_dataset()

            # Honour the requested bands (and their order) instead of
            # returning whatever the group happens to contain.
            band_names = self._band_names(scene)
            scene = scene[band_names]

            # Read the chunk layout from the first selected band rather than a
            # hard-coded "b01", which not every resolution group provides.
            reference = scene[band_names[0]]
            y_chunks = tuple(reference.chunksizes[self.y_res])
            x_chunks = tuple(reference.chunksizes[self.x_res])
            if len(set(y_chunks)) > 1 or len(set(x_chunks)) > 1:
                raise ValueError(
                    f"Irregular chunk grid in {x_path} "
                    f"({self.y_res}={y_chunks}, {self.x_res}={x_chunks}); "
                    "chunks of different sizes cannot be stacked."
                )
            chunk_size_y, chunk_size_x = y_chunks[0], x_chunks[0]
            nb_chunks_y, nb_chunks_x = len(y_chunks), len(x_chunks)

            all_chunks = []
            for row in range(nb_chunks_y):  # matrix row = Y
                y_start = row * chunk_size_y
                for col in range(nb_chunks_x):  # matrix col = X
                    x_start = col * chunk_size_x
                    chunk = scene.isel(
                        {self.y_res: slice(y_start, y_start + chunk_size_y),
                         self.x_res: slice(x_start, x_start + chunk_size_x)}
                    )
                    # load() materialises the lazy data, so the store can be
                    # closed safely once the loop is done.
                    chunk = np.array(chunk.load().to_dataarray())
                    all_chunks.append(torch.from_numpy(chunk))
        finally:
            # Everything needed is in memory (or an error occurred): release
            # the store either way.
            dt.close()

        all_chunks = torch.stack(all_chunks)  # [nb_chunks, bands, h, w]
        # Row-major append order makes this reshape a plain (row, col) grid.
        all_chunks = all_chunks.view(nb_chunks_y, nb_chunks_x, *all_chunks.shape[1:])
        return all_chunks, (nb_chunks_y, nb_chunks_x, chunk_size_y, chunk_size_x)

    def __len__(self):
        """Number of scenes, i.e. rows of ``df_x``."""
        return len(self.df_x)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Last expression: display the selected device as the cell output.
device

The total number of chunks in a batch is equal to the number of Zarr scenes in the batch * nb_chunks_x * nb_chunks_y.

example: 
 - 60m resolution
 - nb_chunks_x = 6 
 - nb_chunks_y = 6 
 - batch_size in train_loader = 2

final batch = `batch_size * nb_chunks_x * nb_chunks_y` = 2 * 6 * 6 = 72

Final output ->>>>  [72, 11, 305, 305]

## Check data 

Let's iterate over the entire dataset with a batch size equal to the length of the dataset (2 scenes), so that all scenes are loaded in a single batch.

In [None]:
from tqdm import tqdm
import torch

# --------- Parameters ----------
res = "60m"
bands = ['b01', 'b02', 'b03', 'b04', 'b05', 'b06', 'b07', 'b09', 'b11', 'b12', 'b8a']
batch_size = 2

# NOTE(review): the objects are named "train_*" but are built from
# df_test_output (the test targets).
train_dataset = Sentinel2ZarrDataset(df_x=df_test_output, res=res, bands=bands)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
# Rebinds `device` to a plain string ("cuda" / "cpu").
device = "cuda" if torch.cuda.is_available() else "cpu"
# The progress bar counts scenes (dataset items), not chunks.
with tqdm(total=len(train_loader.dataset), ncols=100, colour='#3eedc4') as t:

    t.set_description("Training")
    # The second element of each batch is the collated chunk-layout tuple
    # returned by the dataset; it is not needed here.
    for chunks_grid, _ in train_loader:
        # chunks_grid: [B, nb_chunks_y, nb_chunks_x, C, H, W]
        # Flatten chunk grid → [B * nb_chunks_y * nb_chunks_x, C, H, W]
        B, ny, nx, C, H, W = chunks_grid.shape
        chunks_tensor = chunks_grid.view(B * ny * nx, C, H, W).to(device)
        print(chunks_tensor.shape)
        # Advance by the number of scenes contained in this batch.
        t.update(B)

In [None]:
# chunks_tensor is the last batch produced by the loop in the previous cell;
# its first axis is the flattened B * nb_chunks_y * nb_chunks_x dimension.
print(f"Final output shape {chunks_tensor.shape} - [B, C, H, W]")

## Batch Verification

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Visual check: for every scene of a batch, plot one chunk as rebuilt by the
# dataset next to the same chunk read directly from the Zarr store.
res = "60m"
bands = ['b01', 'b02', 'b03', 'b04', 'b05', 'b06', 'b07', 'b09', 'b11', 'b12', 'b8a']
# Reload the test tables from disk: df_test_output is the full test set here.
df_test_input, df_test_output =  prepare_paths(TEST_DIR)
# One batch holds the whole test set, so the loader yields a single batch.
batch_size = len(df_test_output)
train_dataset = Sentinel2ZarrDataset(df_x=df_test_output, res=res, bands=bands)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# The progress bar counts scenes (dataset items), not chunks.
with tqdm(total=len(train_loader.dataset), ncols=100, colour='#3eedc4') as t:
    t.set_description("Training")

    for batch_idx, (chunks_grid, _) in enumerate(train_loader):
        # chunks_grid: [B, nb_chunks_y, nb_chunks_x, C, H, W]
        B, ny, nx, C, H, W = chunks_grid.shape
        # Flattened copy on `device`; only its shape is printed below, the
        # comparison itself uses chunks_grid.
        chunks_tensor = chunks_grid.view(B * ny * nx, C, H, W).to(device)

        print(chunks_tensor.shape)

        # Loop over each scene in the batch
        for batch_scene in range(B):
            # Row of df_test_output for this scene; valid because the loader
            # runs with shuffle=False, so batches follow the table order.
            scene_index = batch_idx * batch_size + batch_scene
            scene_chunks = chunks_grid[batch_scene]  # [ny, nx, C, H, W]

            # Pick a position in the chunk grid
            # NOTE(review): hard-coded position; col_idx = 4 requires nx >= 5.
            row_idx = 0
            col_idx = 4

            # Get rebuilt chunk from dataset tensor
            rebuilt_chunk = scene_chunks[row_idx, col_idx]  # [C, H, W]

            # Load same chunk directly from Zarr
            x_path = df_test_output["path"].iloc[scene_index] + ".zarr"
            dt = xr.open_datatree(x_path, engine="zarr", mask_and_scale=False, chunks={})
            data_tree = dt.measurements.reflectance[f"r{res}"]

            # Compute pixel indices in full image
            # The chunk size is taken from the batch tensor (H, W) rather than
            # re-read from the store.
            chunk_size_y = H
            chunk_size_x = W
            y_start = row_idx * chunk_size_y
            x_start = col_idx * chunk_size_x

            original_chunk = data_tree.isel(
                {f"y_{res}": slice(y_start, y_start + chunk_size_y),
                 f"x_{res}": slice(x_start, x_start + chunk_size_x)}
            ).to_dataset().to_dataarray()

            # NOTE(review): the comparison is visual only (first channel of
            # each array); no numeric equality is asserted.
            # --- Plot rebuilt ---
            plt.figure(figsize=(8, 6))
            plt.imshow(rebuilt_chunk[0].cpu().numpy(), cmap="viridis")
            plt.title(f"Rebuilt - Scene {scene_index} - Chunk ({row_idx}, {col_idx}) - Band 0")
            plt.colorbar()
            plt.show()

            # --- Plot original ---
            plt.figure(figsize=(8, 6))
            plt.imshow(original_chunk[0].values, cmap="viridis")
            plt.title(f"Original - Scene {scene_index} - Chunk ({row_idx}, {col_idx}) - Band 0")
            plt.colorbar()
            plt.show()

        # Advance by the number of scenes contained in this batch.
        t.update(B)
