In [1]:
import os
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from loguru import logger
from tqdm import tqdm
import natsort
import glob
import cv2
from PIL import Image
from data.dataset import Sentinel2TCIDataset, Sentinel2Dataset
from data.loader import define_loaders
from model_zoo.models import define_model
from training.metrics import MultiSpectralMetrics, avg_metric_bands
from utils.torch import count_parameters, load_model_weights, seed_everything
from utils.utils import load_config
from utils.wandb_logger import WandbLogger
from training.losses import WeightedMSELoss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def prepare_paths(path_dir):
    """
    Load the input/target manifests of one dataset split and attach local paths.

    Parameters
    ----------
    path_dir : str
        Split directory containing ``input.csv`` and ``target.csv``.

    Returns
    -------
    df_input, df_output : pd.DataFrame
        The two manifests as read from CSV, each with an added ``path`` column
        equal to ``os.path.join(path_dir, <"input"|"target">, name)``. Here
        ``name`` is the basename of the ``Name`` column with every ".SAFE"
        occurrence removed.
    """

    def _with_local_paths(manifest, subdir):
        # Maps a product name to its folder under <path_dir>/<subdir>.
        def _to_local(product_name):
            folder = os.path.basename(product_name).replace(".SAFE", "")
            return os.path.join(path_dir, subdir, folder)

        manifest["path"] = manifest["Name"].map(_to_local)
        return manifest

    df_input = _with_local_paths(pd.read_csv(f"{path_dir}/input.csv"), "input")
    df_output = _with_local_paths(pd.read_csv(f"{path_dir}/target.csv"), "target")
    return df_input, df_output


def prepare_data(config):
    """
    Build the train / validation / test data loaders for one dataset version.

    Each split directory is ``<DATASET.base_dir>/<DATASET.version>/<split>``
    (split in train, val, test) and is read with ``prepare_paths``.

    Parameters
    ----------
    config : dict
        Parsed configuration. Keys read: ``DATASET.base_dir``, ``DATASET.version``,
        ``TRAINING.resize``, ``TRAINING.batch_size``, ``TRAINING.num_workers``.

    Returns
    -------
    train_loader, val_loader, test_loader
        ``train_loader`` and ``val_loader`` are unpacked from a single
        ``define_loaders(train=True)`` call. ``test_loader`` is whatever
        ``define_loaders(train=False)`` returns.
    """
    base_dir = config['DATASET']['base_dir']
    version = config['DATASET']['version']
    resize = config['TRAINING']['resize']
    loader_kwargs = dict(
        batch_size=config['TRAINING']['batch_size'],
        num_workers=config['TRAINING']['num_workers'],
    )

    # Split directories are resolved from the configured base_dir so the
    # dataset location is controlled by the config, not a fixed mount point.
    version_dir = os.path.join(str(base_dir), str(version))
    df_train_input, df_train_output = prepare_paths(os.path.join(version_dir, "train"))
    df_val_input, df_val_output = prepare_paths(os.path.join(version_dir, "val"))
    df_test_input, df_test_output = prepare_paths(os.path.join(version_dir, "test"))

    logger.info(f"Number of training samples: {len(df_train_input)}")
    logger.info(f"Number of validation samples: {len(df_val_input)}")
    logger.info(f"Number of test samples: {len(df_test_input)}")

    def _make_dataset(df_x, df_y):
        # Shared constructor for all three splits (no augmentation, same size).
        # NOTE(review): every split, including val/test, is built with train=True;
        # confirm whether Sentinel2Dataset's `train` flag should be False for
        # evaluation splits.
        return Sentinel2Dataset(df_x=df_x, df_y=df_y, train=True, augmentation=False, img_size=resize)

    train_dataset = _make_dataset(df_train_input, df_train_output)
    val_dataset = _make_dataset(df_val_input, df_val_output)
    test_dataset = _make_dataset(df_test_input, df_test_output)

    train_loader, val_loader = define_loaders(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        train=True,
        **loader_kwargs)

    test_loader = define_loaders(
        train_dataset=test_dataset,
        val_dataset=None,
        train=False,
        **loader_kwargs)

    return train_loader, val_loader, test_loader


In [55]:
# Load the project config, then index the test split inspected in the cells below.
config = load_config(config_path="cfg/config.yaml")

# Version and image size are fixed here rather than read from `config`.
version, resize = "V3", 1024
TEST_DIR = "/mnt/disk/dataset/sentinel-ai-processor/" + version + "/test/"
df_input, df_output = prepare_paths(TEST_DIR)

In [56]:
def normalize(data_array):
    """
    Normalize each band in a multi-band image array to the range [0, 1].

    For each band, pixels with values > 0 are considered valid.
    Normalization is done based on the min and max of valid pixels only.
    Invalid pixels (≤ 0) are set to 0 after normalization.

    Degenerate bands are handled explicitly:

    - a band with no valid pixel yields all zeros and norm params ``(0, 0)``;
    - a constant band (min == max) yields 0.0 for its valid pixels instead of
      NaN, so ``value * (max - min) + min`` restores the original constant.

    Parameters
    ----------
    data_array : np.ndarray
        Input image array of shape (H, W, C), where H = height, W = width, C = number of bands.

    Returns
    -------
    normalized_data : np.ndarray
        float32 array with values in [0, 1], same shape as input.
    valid_masks : np.ndarray
        Boolean mask array indicating valid pixels, shape (H, W, C).
    norm_params : list of tuple
        List of (min, max) values used for normalization per band.
    """

    normalized_data = []
    valid_masks = []
    norm_params = []
    for i in range(data_array.shape[2]):
        band_data = data_array[:, :, i]
        valid_mask = band_data > 0
        valid_pixels = band_data[valid_mask]

        # Invalid pixels stay at 0.0; only valid ones are overwritten below.
        result = np.zeros(band_data.shape, dtype=np.float32)
        if valid_pixels.size == 0:
            # np.min/np.max raise on an empty array; record a neutral (0, 0).
            zero = band_data.dtype.type(0)
            min_val, max_val = zero, zero
        else:
            min_val = np.min(valid_pixels)
            max_val = np.max(valid_pixels)
            value_range = max_val - min_val
            if value_range > 0:
                result[valid_mask] = (valid_pixels - min_val) / value_range
            # else: constant band — leave valid pixels at 0.0 to avoid 0/0 = NaN.
        norm_params.append((min_val, max_val))
        normalized_data.append(result)
        valid_masks.append(valid_mask)
    return np.dstack(normalized_data), np.dstack(valid_masks), norm_params

def denormalize(norm_data, valid_masks, norm_params):
    """
    Denormalize a normalized multi-band image back to its original value range and convert to integers.

    Each valid pixel becomes ``value * (max - min) + min`` and invalid pixels
    become 0. NaNs are mapped to 0 and results are clipped to the uint16 range
    [0, 65535] before rounding and casting. Inputs outside [0, 1] (for example
    raw model predictions) therefore saturate instead of wrapping around.

    Parameters
    ----------
    norm_data : ndarray
        A 3D NumPy array of normalized data (H, W, C), nominally in the range [0, 1].

    valid_masks : ndarray
        Boolean mask array of shape (H, W, C) indicating which pixels were originally valid.

    norm_params : list of tuples
        List of (min, max) values per band used during normalization.

    Returns
    -------
    restored_data : ndarray
        Denormalized image array of shape (H, W, C) with dtype uint16.
    """

    uint16_max = np.iinfo(np.uint16).max
    restored_data = []
    for i in range(norm_data.shape[2]):
        band = norm_data[:, :, i]
        valid_mask = valid_masks[:, :, i]
        min_val, max_val = norm_params[i]

        restored_band = band.copy()
        restored_band[valid_mask] = band[valid_mask] * (max_val - min_val) + min_val
        restored_band[~valid_mask] = 0.0

        # Casting NaN, negative, or >65535 floats to uint16 is undefined and
        # typically wraps around, so sanitize and saturate first.
        restored_band = np.clip(np.nan_to_num(restored_band, nan=0.0), 0, uint16_max)
        restored_data.append(np.round(restored_band).astype(np.uint16))
    return np.dstack(restored_data)

def read_images(product_paths):
    """
    Read a list of image files and stack them into one multi-band image.

    Parameters
    ----------
    product_paths : iterable of str
        Paths of the images to read, one per band, in band order.

    Returns
    -------
    images : ndarray
        A 3D NumPy array of shape (H, W, C). When every file decodes to a 2-D
        (grayscale) array, C == len(product_paths). All images must share H and W.

    Raises
    ------
    ValueError
        If ``product_paths`` is empty (e.g. a glob that matched nothing).
    """

    product_paths = list(product_paths)
    if not product_paths:
        raise ValueError("read_images: no image paths given (empty band list)")

    bands = []
    for path in product_paths:
        # The context manager closes the file handle that Image.open keeps open;
        # np.array forces the pixel data to load before the file is closed.
        with Image.open(path) as img:
            bands.append(np.array(img))
    # H x W x C
    return np.dstack(bands)

In [57]:
# Pick one test product, load its input/target bands, and run them through
# normalize -> denormalize. A fixed seed keeps the chosen sample reproducible
# across kernel restarts; change SEED to inspect a different product.
SEED = 42
rng = np.random.default_rng(SEED)
random_index = rng.choice(df_input.index)
random_row = df_input.loc[random_index]

# NOTE(review): assumes input.csv and target.csv rows are aligned (same index
# label = same input/target pair) — confirm.
x_dir = random_row["path"]
y_dir = df_output.loc[random_index, "path"]
x_paths = natsort.natsorted(glob.glob(os.path.join(x_dir, "*.png")))
y_paths = natsort.natsorted(glob.glob(os.path.join(y_dir, "*.png")))
assert x_paths, f"no PNG bands found in {x_dir}"
assert y_paths, f"no PNG bands found in {y_dir}"

x_data = read_images(x_paths)
y_data = read_images(y_paths)

x_data_norm, x_valid_masks, x_norm_params = normalize(x_data)
y_data_norm, y_valid_masks, y_norm_params = normalize(y_data)

new_x_data = denormalize(x_data_norm, x_valid_masks, x_norm_params)
new_y_data = denormalize(y_data_norm, y_valid_masks, y_norm_params)
