In [None]:
import numpy as np
import zarr 
from pathlib import Path

# 1. Data Loading

In [None]:
# Root of the Phisat-2 simulation data; change this single constant to run elsewhere.
DATA_ROOT = Path("/Data_large/marine/PythonProjects/OtherProjects/lpl-PyNas/data/Phisat2Simulation")


def list_image_mask_pairs(split_root):
    """Return the sorted image paths of one split and their index-aligned mask paths.

    Images are the ``*.npy`` files in ``<split_root>/numpy_images``. The mask for each
    image is the file in ``<split_root>/numpy_masks`` whose name is the image name with
    ``'image'`` replaced by ``'mask'``.

    Args:
        split_root (Path | str): Split directory containing ``numpy_images`` and ``numpy_masks``.

    Returns:
        tuple[list[Path], list[Path]]: ``(image_paths, mask_paths)``; image paths are sorted
        and ``mask_paths[i]`` belongs to ``image_paths[i]``.

    Raises:
        FileNotFoundError: If any expected mask file is missing (fail here rather than
            deep inside the dataset-writing loop).
    """
    split_root = Path(split_root)
    image_paths = sorted((split_root / "numpy_images").glob("*.npy"))
    mask_dir = split_root / "numpy_masks"
    mask_paths = [mask_dir / p.name.replace("image", "mask") for p in image_paths]

    missing = [p for p in mask_paths if not p.exists()]
    if missing:
        raise FileNotFoundError(
            f"{len(missing)} mask file(s) not found under {mask_dir}, e.g. {missing[0].name}"
        )
    return image_paths, mask_paths


trainval_root = DATA_ROOT / "TrainVal"
trainval_imgs, trainval_masks = list_image_mask_pairs(trainval_root)

test_root = DATA_ROOT / "Test"
test_imgs, test_masks = list_image_mask_pairs(test_root)

## Helper Function

In [None]:
# Accepted (lower-case) values for the ``dataset_set`` and ``task`` arguments of ``add_sample``.
_DATASET_SETS = frozenset({"trainval", "test"})
_TASKS = frozenset({"classification", "segmentation", "regression", "compression"})
# Corner keys of the per-sample ``geolocation`` attribute.
_GEO_CORNERS = ("UL", "UR", "LL", "LR")


def _build_sample_attrs(task, metadata):
    """Build the attribute mapping stored on a sample group, filling defaults for missing keys.

    Args:
        task (str): Normalised (lower-case) task name.
        metadata (dict | None): Optional user metadata. ``"geolocation"`` and ``"ancillary"``
            may be nested dicts (or None); any missing field falls back to its default.

    Returns:
        dict: Attributes in the order task, sensor fields, geolocation, ancillary fields.
    """
    meta = metadata or {}
    # `or {}` also tolerates an explicit None for the nested sections.
    geo = meta.get("geolocation") or {}
    anc = meta.get("ancillary") or {}
    return {
        "task": task,
        # Sensor-related
        "sensor": meta.get("sensor", "S2A"),
        "sensor_resolution": meta.get("sensor_resolution", 10),
        "sensor_orbit": meta.get("sensor_orbit", "ASCENDING"),
        # NOTE(review): "B4" appears twice in this default — possibly a typo (B8?); confirm.
        "spectral_bands_ordered": meta.get("spectral_bands_ordered", "B2-B3-B4-B4"),
        "sensor_orbit_number": meta.get("sensor_orbit_number", 0),
        "datatake": meta.get("datatake", "00-00-0000 00:00:00"),
        # Geolocation: a 2-element coordinate pair per corner, NaN when not provided.
        # A fresh default list is built per corner so corners never share one list object.
        "geolocation": {corner: geo.get(corner, [np.nan, np.nan]) for corner in _GEO_CORNERS},
        # Ancillary
        "cloud_cover": anc.get("cloud_cover", np.nan),
        "sun_azimuth": anc.get("sun_azimuth", np.nan),
        "sun_elevation": anc.get("sun_elevation", np.nan),
        "view_azimuth": anc.get("view_azimuth", np.nan),
        "view_elevation": anc.get("view_elevation", np.nan),
    }


def add_sample(root, 
               dataset_set, 
               idx, 
               task, 
               img, 
               label,
               metadata=None,
               overwrite=False):
    """
    Adds a sample to the Zarr dataset as ``<dataset_set>/<idx:07d>/{img, label}``.

    Both arrays are stored as float32; metadata is stored as attributes of the sample group.

    Args:
        root (zarr.Group): Root of the Zarr dataset.
        dataset_set (str): One of {"trainval", "test"} (case-insensitive).
        idx (int): Non-negative sample index (Python or NumPy integer); it becomes the
            7-digit zero-padded sample group name.
        task (str): One of {"classification", "segmentation", "regression", "compression"} (case-insensitive).
        img (np.ndarray): Image array, shape (C, H, W).
        label (np.ndarray): Label array; 1D for classification, 3D (C, H, W) for other tasks.
        metadata (dict, optional): Metadata to attach to the sample group. Missing fields
            get defaults (NaN for unknown geolocation / ancillary values).
        overwrite (bool): If True, overwrites existing sample with same index.

    Raises:
        ValueError: If task or dataset_set is not recognized.
        AssertionError: If idx, img or label have an invalid type, value or shape.
        FileExistsError: If the sample already exists and overwrite is False.
    """
    dataset_set = dataset_set.lower()
    task = task.lower()

    if dataset_set not in _DATASET_SETS:
        raise ValueError(f"Unknown dataset_set: {dataset_set}")
    if task not in _TASKS:
        raise ValueError(f"Unknown task: {task}")

    assert isinstance(idx, (int, np.integer)), "idx must be an integer"
    # A negative index would format as e.g. "-000001" and break the sample naming scheme.
    assert idx >= 0, "idx must be non-negative"
    assert isinstance(img, np.ndarray) and img.ndim == 3, "img must be a 3D numpy array"
    assert isinstance(label, np.ndarray), "label must be a numpy array"

    if task == "classification":
        assert label.ndim == 1, "label must be 1D for classification"
    else:
        assert label.ndim == 3, "label must be 3D for non-classification tasks"
        assert img.shape[1:] == label.shape[1:], "img and label must have same spatial dimensions"
        if task == "compression":
            assert img.shape[0] == label.shape[0], "img and label must have same number of channels for compression"

    # Creates the split group on first use, otherwise returns the existing one.
    dataset_group = root.require_group(dataset_set)
    sample_id = f"{int(idx):07d}"

    if sample_id in dataset_group:
        if not overwrite:
            raise FileExistsError(f"Sample '{dataset_set}/{sample_id}' already exists. Use overwrite=True to replace.")
        del dataset_group[sample_id]

    g = dataset_group.create_group(sample_id)
    g.create_dataset("img", data=img.astype(np.float32))
    g.create_dataset("label", data=label.astype(np.float32))

    # Write all metadata attributes in a single update call.
    g.attrs.update(_build_sample_attrs(task, metadata))

# 2. Dataset Creation

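Each sample is written with `add_sample`, which:
1. Ensures the split group (`"trainval"` or `"test"`) exists in the Zarr root.
2. Gets a reference to that split group.
3. Stores each sample under a 7-digit zero-padded ID inside that group (e.g. `trainval/0000042`).
4. Keeps the hierarchy explicit: `<split>/<sample_id>/{img, label}`, with the sample's metadata stored as attributes of the sample group.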

### a) TrainVal

In [None]:
# mode="w" deletes any existing store at this path, so re-running this cell rebuilds
# the whole dataset from scratch (the Test cell below must be re-run afterwards).
root = zarr.open("burned_area_dataset.zarr", mode="w")

dataset_set = "trainval"
task = "segmentation"

seen_shapes = set()
for idx, (img_path, mask_path) in enumerate(zip(trainval_imgs, trainval_masks)):
    img = np.load(img_path)
    mask = np.load(mask_path)
    # A single-channel mask saved as (H, W) needs an explicit channel axis first:
    # moving the last axis of a 2D array would swap H and W instead.
    if mask.ndim == 2:
        mask = mask[..., np.newaxis]
    # H, W, C -> C, H, W
    img = np.moveaxis(img, -1, 0)
    mask = np.moveaxis(mask, -1, 0)

    # Report each distinct (image, mask) shape pair once instead of once per sample.
    shapes = (img.shape, mask.shape)
    if shapes not in seen_shapes:
        seen_shapes.add(shapes)
        print(f"Image shape: {img.shape}, Mask shape: {mask.shape}")

    add_sample(root, 
               dataset_set, 
               idx, 
               task, 
               img, 
               label=mask,
               metadata=None,
               overwrite=False)

print(f"Wrote {len(trainval_imgs)} samples to '{dataset_set}'")


### b) Test

In [None]:
# Reopen the store in append mode so this cell also works on a fresh kernel (it does not
# rely on `root` from the TrainVal cell) and leaves the trainval split untouched.
root = zarr.open_group("burned_area_dataset.zarr", mode="a")

dataset_set = "test"
task = "segmentation"

# Drop any previously written test split so re-running this cell does not hit
# FileExistsError and no stale samples survive.
if dataset_set in root:
    del root[dataset_set]

seen_shapes = set()
for idx, (img_path, mask_path) in enumerate(zip(test_imgs, test_masks)):
    img = np.load(img_path)
    mask = np.load(mask_path)
    # A single-channel mask saved as (H, W) needs an explicit channel axis first:
    # moving the last axis of a 2D array would swap H and W instead.
    if mask.ndim == 2:
        mask = mask[..., np.newaxis]
    # H, W, C -> C, H, W
    img = np.moveaxis(img, -1, 0)
    mask = np.moveaxis(mask, -1, 0)

    # Report each distinct (image, mask) shape pair once instead of once per sample.
    shapes = (img.shape, mask.shape)
    if shapes not in seen_shapes:
        seen_shapes.add(shapes)
        print(f"Image shape: {img.shape}, Mask shape: {mask.shape}")

    add_sample(root, 
               dataset_set, 
               idx, 
               task, 
               img, 
               label=mask,
               metadata=None,
               overwrite=False)

print(f"Wrote {len(test_imgs)} samples to '{dataset_set}'")


# Part B) Dataloader

In [1]:
# Reload edited local modules (such as data_loader) automatically before each cell runs.
# NOTE(review): move these magics to the first configuration cell so they are active
# before any local module is imported.
%load_ext autoreload
%autoreload 2

In [1]:
from data_loader import get_zarr_dataloader, NormalizeChannels
from tqdm import tqdm


# Example usage
zarr_path = "burned_area_dataset.zarr"
dataset_set = "trainval"
BATCH_SIZE = 16
NUM_WORKERS = 4

# Create DataLoader
dataloader = get_zarr_dataloader(
    zarr_path=zarr_path,
    dataset_set=dataset_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    transform=NormalizeChannels(min_max=True),
    task_filter="segmentation",
    # add_sample writes a "datatake" date-time attribute; it never writes "timestamp".
    metadata_keys=["sensor", "datatake"],
)

# Iterate through batches
for batch in tqdm(dataloader, desc="Processing Batches"):
    # Access data based on tasks in the batch
    for task in batch["tasks"]:
        images = batch[f"{task}_img"]
        labels = batch[f"{task}_label"]
        # Forward pass, compute loss, etc.

Processing Batches: 100%|██████████| 487/487 [00:30<00:00, 16.00it/s]
