In [None]:
# Copyright 2025 AIT Austrian Institute of Technology GmbH
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Author: Miguel Castells

In [None]:
import os
import json
import numpy as np 
from PIL import Image
import torchvision
import torch
import torchvision.transforms.functional as Fvis
from torchmetrics.image import PeakSignalNoiseRatio
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage.exposure import match_histograms
from tqdm.auto import tqdm
from IPython.display import clear_output

# Manually pick the low-resolution images for the chips with the worst PSNR after generation
## Model used for generation :
* Pretraining : SEN2NAIPv2-UNET
* Finetuning : Our dataset 

In [None]:
def load_json(json_path):
    """
    Read a JSON file and return its decoded content.

    json_path: path of the JSON file to read.

    Returns:
        The Python object produced by json.load (dict, list, ...).
    """
    with open(json_path, 'r') as handle:
        return json.load(handle)

def extract_metrics(results_dict, metric_name):
    """
    Collect one metric from a results dictionary and rank the entries.

    results_dict: dict whose keys are image/folder names and whose values are
        dicts mapping metric names to values.
    metric_name: str, name of the metric to extract, e.g. 'psnr_rgb', 'clip_rgb'.

    Returns:
        list of tuples (image_name, metric_value), sorted by metric_value in
        descending order. Entries that do not contain `metric_name` are skipped.
    """
    pairs = []
    for name, metrics in results_dict.items():
        # Entries lacking the requested metric are silently ignored.
        if metric_name in metrics:
            pairs.append((name, metrics[metric_name]))

    # Highest value first; the sort is stable, so ties keep insertion order.
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs

def get_worst_chips_by_threshold(psnr_list, threshold):
    """
    Select the chips whose score is strictly below a threshold.

    psnr_list: list of tuples (chip_name, psnr); may be empty or None.
    threshold: float, chips with a PSNR strictly lower than this are kept.

    Returns:
        list of chip names with PSNR < threshold, in the order of `psnr_list`.
    """
    # Nothing to filter when the input is empty (or None).
    if not psnr_list:
        return []

    selected = []
    for name, value in psnr_list:
        if value < threshold:
            selected.append(name)
    return selected

In [None]:
def find_hr_path(hr_root, chip):
    """
    Walk `hr_root` recursively and return the path of the first 'rgb.png' found
    in a directory whose path contains a component equal to `chip`.

    Returns:
        The path of the matching 'rgb.png', or None when no directory matches.
    """
    target = 'rgb.png'
    for current_dir, _subdirs, filenames in os.walk(hr_root):
        # Only directories that actually hold the RGB image are candidates.
        if target not in filenames:
            continue
        # The chip must match a whole path component, not a substring.
        if chip in current_dir.split(os.sep):
            return os.path.join(current_dir, target)
    return None

def get_list_unique_id_chip(chips, hr_root):
    """
    Return the list of "<unique_id>/<chip>" identifiers for a list of chip names.

    For each chip, the matching directory is the first one (in os.walk order)
    that contains 'rgb.png' and whose path has a component equal to the chip.
    `unique_id` is the name of the parent of that directory.

    chips: iterable of chip names.
    hr_root: root directory of the high-resolution images to search.

    Returns:
        list of "<unique_id>/<chip>" strings, in the order of `chips`. Chips
        without a matching image are reported on stdout and skipped.
    """
    chips = list(chips)
    wanted = set(chips)
    found_dirs = {}

    # Walk the tree once for all chips instead of restarting a walk per chip.
    if wanted:
        for root, _dirs, files in os.walk(hr_root):
            if 'rgb.png' not in files:
                continue
            for chip in wanted.intersection(root.split(os.sep)):
                # Keep the first match only, like a per-chip search would.
                found_dirs.setdefault(chip, root)
            if len(found_dirs) == len(wanted):
                break

    identifiers = []
    for chip in chips:
        chip_dir = found_dirs.get(chip)
        if chip_dir is None:
            print(f" Missing image for chip: {chip}")
            continue
        # Parse with os.path so this works whatever the platform separator is.
        unique_id = os.path.basename(os.path.dirname(chip_dir))
        # Identifiers are always joined with '/', the separator load_images splits on.
        identifiers.append(f"{unique_id}/{chip}")
    return identifiers

In [None]:
def load_images(unique_id_chip, rgb_dir, nir_dir):
    '''
    Load the high-resolution (NAIP) image and the stack of low-resolution
    (Sentinel-2) images of one chip, each with RGB and NIR channels.

    Args:
        unique_id_chip (str): A string formatted as "<unique_id>/<chip_coord>".
        rgb_dir (str): Root directory containing the 'naip' and 'sentinel2' RGB images.
        nir_dir (str): Root directory containing the NAIP NIR images.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: `hr`, the RGB and NIR high-resolution
        channels concatenated on dim 0, and `lrs`, the low-resolution images as
        (N, RGB + NIR, 32, 32).
    '''
    parts = unique_id_chip.split('/')
    unique_id, chip_coord = parts[0], parts[1]

    # All four files are addressed from the chip coordinates.
    hr_path = os.path.join(rgb_dir, 'naip', unique_id, chip_coord, 'rgb.png')
    lr_path = os.path.join(rgb_dir, 'sentinel2', chip_coord, 'tci.png')
    nir_lr_path = str(lr_path.replace('tci.png', 'b08.png'))
    nir_hr_path = os.path.join(nir_dir, 'naip', unique_id, chip_coord, 'nir.png')

    # Low-resolution NIR: cut the image into 32x32 tiles -> (N, 1, 32, 32).
    # NOTE(review): assumes the PNG stores the N acquisitions as stacked 32x32 tiles - confirm.
    lr_nir = Fvis.pil_to_tensor(Image.open(nir_lr_path))
    lrs_nir = lr_nir.reshape(-1, 32, 32).unsqueeze(1)

    hr_nir = Fvis.pil_to_tensor(Image.open(nir_hr_path))
    hr_rgb = Fvis.pil_to_tensor(Image.open(hr_path))

    # Low-resolution RGB: (3, N, 32, 32) -> (N, 3, 32, 32).
    lr_tensor = Fvis.pil_to_tensor(Image.open(lr_path))
    lrs_rgb = lr_tensor.reshape(3, -1, 32, 32).transpose(0, 1)

    # NIR is appended as the last channel of both outputs.
    hr = torch.cat([hr_rgb, hr_nir], dim=0)
    lrs = torch.cat([lrs_rgb, lrs_nir], dim=1)

    return hr, lrs

In [None]:
def step_1_cloud_removal(hr, lrs, device, num_imgs=16):
    '''
    Performs a naive cloud-removal step by selecting the `num_imgs` best-matching low-resolution (LR)
    images using Peak Signal-to-Noise Ratio (PSNR) between the bicubic-downsampled HR image and each
    LR image.

    Args:
        hr (torch.Tensor): High-resolution image tensor (C, H, W).
        lrs (torch.Tensor): Low-resolution image tensor batch (N, C, h, w).
        device (str): Device string ("cuda" or "cpu") on which the PSNR is computed.
        num_imgs (int): Number of LR images to keep.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: `hr` and `lrs` moved to the CPU, and the
        int64 indices of the kept LR images. When N > `num_imgs`, these are the `num_imgs` indices
        with the highest PSNR, best first; otherwise all indices 0..N-1.
    '''
    lrs = lrs.to(device)
    num_lr = lrs.shape[0]

    if num_imgs < num_lr:
        # Downsample HR to the LR resolution so both can be compared pixel-wise.
        hr_res = F.interpolate(hr.unsqueeze(0).float(), tuple(lrs.shape[-2:]), mode='bicubic').to(device)
        # Broadcast the single HR reference against every LR image (no copy, any channel count).
        hrs_res = hr_res.expand(num_lr, -1, -1, -1)

        psnr = PeakSignalNoiseRatio(reduction=None, dim=[1,2,3], data_range=255).to(device)
        # Explicit float cast: the LR images may be uint8 while the reference is float.
        scores = psnr(lrs.float(), hrs_res)
        indices = torch.argsort(scores, descending=True)[:num_imgs]
    else:
        # int64 like argsort's output: uint8 cannot represent positions above 255 and
        # is treated as a mask dtype, not as positions, when used to index a tensor.
        indices = torch.arange(num_lr, dtype=torch.long)

    return hr.cpu(), lrs.cpu(), indices.cpu()

def get_indice_from_pip(chip, tracker_path):
    """
    Look up the low-resolution index that the pipeline selected for `chip`.

    chip: key of the entry to read in the tracker.
    tracker_path: path of the pipeline tracker JSON file.

    Returns:
        The value stored under `chip` in the tracker (KeyError if absent).
    """
    pipeline_choices = load_json(tracker_path)
    chosen = pipeline_choices[chip]
    return chosen

# Hyperparameters of plotting (module-level, read by plot_one_row)
# font: font size of each subplot title; the figure title uses 2 * font.
font = 50
# num_col: number of subplot columns in the grid.
num_col = 3

def plot_one_row(hr, lrs, indices, chip, tracker_path):
    """
    Plot one panel per candidate low-resolution image. Each panel shows, side by side, the
    high-resolution NAIP image (left), the same image histogram-matched to the low-resolution
    image (middle) and the low-resolution image upsampled to the HR size (right).

    hr: high-resolution tensor (C, H, W); only the first 3 channels are displayed.
    lrs: low-resolution tensor batch (N, C, h, w); only the first 3 channels are displayed.
    indices: positions in `lrs` of the candidates to plot (ints or 0-dim integer tensors).
    chip: key used to read the pipeline's own choice from the tracker (shown in the title).
    tracker_path: path of the pipeline tracker JSON file.

    The figure is created on the current pyplot state; nothing is returned.
    """
    # Read the tracker first so a missing entry fails before any figure is created.
    ind_pip = get_indice_from_pip(chip, tracker_path)

    m = len(indices)
    # Ceil division gives exactly the rows needed, and at least one so the grid is valid.
    n_rows = max(1, -(-m // num_col))
    fig, axes = plt.subplots(
        n_rows,
        num_col,
        # The height keeps its original scaling, clamped so it is never zero when m < num_col.
        figsize=(4.2 * (m + 1), 15 * max(1, m // num_col)),
        squeeze=False,  # always a 2-D array of axes, whatever the grid shape
    )

    for ax in axes.flat:
        ax.axis('off')

    axes = axes.flatten()

    fig.suptitle(f'HR | HR (histogram matched) | LR\nindice of pipeline : {ind_pip}', fontsize=2*font)

    # Loop-invariant: the HR image and the white separator are the same for every panel.
    hr_numpy = hr[:3].permute(1,2,0).numpy()
    padding_tensor = np.ones((hr_numpy.shape[0], 2, 3))*255

    for i, ind in enumerate(indices):
        # Plain int: avoids indexing with a tensor whose dtype could be read as a mask.
        ind = int(ind)
        lr_res = F.interpolate(lrs[ind, :3, :, :].unsqueeze(0).float(), hr.shape[1:], mode='bicubic').squeeze(0)
        lr_numpy = lr_res.permute(1,2,0).int().numpy()
        hr_matched = match_histograms(hr_numpy, lr_numpy, channel_axis=-1)
        to_plot = np.concatenate([
            hr_numpy,
            padding_tensor,
            hr_matched,
            padding_tensor,
            lr_numpy
            ], axis=1)
        # Bicubic upsampling can overshoot [0, 255]; clip after scaling to [0, 1].
        axes[i].imshow(np.clip(to_plot/255, a_min=0., a_max=1.))
        axes[i].set_title(f'indice: {ind}', fontsize=font)

    # One layout pass once every panel is drawn.
    fig.tight_layout()

In [None]:
def manual_picking(identifiers, rgb_dir, nir_dir, save_dir, tracker_path, device='cuda'):
    """
    Interactively ask the user to pick the best low-resolution image for each chip.

    For every identifier, the candidate low-resolution images are plotted and the
    user types the index of the one to keep.

    Args:
        identifiers: list of "<unique_id>/<chip>" strings to review.
        rgb_dir: root directory of the NAIP / Sentinel-2 RGB images.
        nir_dir: root directory of the NAIP NIR images.
        save_dir: path of the JSON file in which the picked indices are written.
        tracker_path: path of the pipeline tracker.json, used to display the pipeline's index.
        device: device on which the candidates are ranked.

    Output:
        Writes a JSON dictionary {"unique_id/chip": index_chosen_by_user} to `save_dir`.
        The value is an int, None (the user typed "None") or the raw string when the
        input is not an integer. If the session stops early, the picks collected so
        far are still written before the error propagates.
    """
    d = {}
    completed = False
    try:
        for identifier in identifiers:
            hr, lrs = load_images(identifier, rgb_dir, nir_dir)
            hr, lrs, indices_cloud_removal = step_1_cloud_removal(hr, lrs, device)
            # NOTE(review): the full identifier is used as tracker key, as in the original call;
            # confirm the tracker is keyed by "<unique_id>/<chip>".
            plot_one_row(hr, lrs, indices_cloud_removal, identifier, tracker_path)

            plt.show()

            best_indice = input('Enter the indice as in the plot (or type "None" if no acceptable image): ')

            # Release the figure so a long session does not accumulate open figures.
            plt.close('all')
            clear_output(wait=True)

            if best_indice.strip().lower() == 'none':
                best_indice_val = None
            else:
                try:
                    best_indice_val = int(best_indice)
                except ValueError:
                    # If not an int and not "None", keep the raw string for later inspection
                    best_indice_val = best_indice

            # The identifier already is "<unique_id>/<chip>", so it is the output key as is.
            d[identifier] = best_indice_val
        completed = True
    finally:
        # Write on normal completion, or when interrupted after at least one pick
        # (an early failure must not replace an existing file with an empty one).
        if completed or d:
            with open(save_dir, 'w') as file:
                json.dump(d, file)

# Main Pipeline for manual picking

## Parameters
 * results_path: Path of the JSON file containing the scores
 * metric_name: Metric used for ranking; available metrics are ["psnr_rgb", "psnr_all", "clip_rgb", "ssim_rgb", "ssim_all", "lpips_rgb"], where CLIP is Git-SCLIP
 * pipeline_directory: Path of the directory where the pipeline images are stored
 * S2NAIP_directory: Path of the directory where the S2NAIP images are stored
 * nir_dir: Directory of the NIR band of the NAIP images
 * json_save_path: Path of the JSON file where the result is stored

In [None]:
# JSON file with the per-chip scores of the evaluated model (read in the next cell).
results_path = "/workspace/framework_choose_lr/dev/compute_scores_test/results_scores/TEST_finetune_satlas_unet_hm_nir.json" 
# Metric used to rank the chips.
metric_name = "psnr_rgb" 

# Dataset produced by the pipeline; also holds the tracker.json read below.
pipeline_directory = "/workspace/pipeline-data/dataset_from_pipeline/ssim_0-7_hr-lr_0-3_hr-hrharm_NIR/test/"
pipeline_hr_directory = os.path.join(pipeline_directory, "naip") 
pipeline_lr_directory = os.path.join(pipeline_directory, "sentinel2") 

# S2NAIP images (RGB) and the directory holding the NAIP NIR band.
S2NAIP_directory = '/workspace/readonly-satlas/val_set'
nir_dir = '/workspace/readonly-satlasNIR/nirres/val_set' 

# Tracker written by the pipeline; read by get_indice_from_pip to show the pipeline's index.
tracker_path = os.path.join(pipeline_directory, "tracker.json") 

In [None]:
# Load the scores of every chip from the results file.
score_results = load_json(results_path)

# Rank the chips by the selected metric, best first.
sorted_metric = extract_metrics(score_results, metric_name)

# Keep the chips whose metric is strictly below the threshold.
metric_thresh = 17
worst = get_worst_chips_by_threshold(sorted_metric, metric_thresh)
# print(len(worst_17))  # 166

# Resolve each chip name to its "<unique_id>/<chip>" identifier in the S2NAIP NAIP folder.
identifiers = get_list_unique_id_chip(worst, os.path.join(S2NAIP_directory, 'naip'))


In [None]:
# The work is divided into fixed-size batches so it can be done in several sittings.
# The batches are derived from the actual number of identifiers, so they stay valid
# when the threshold or the dataset changes.
chunk_size = 20
slices = [
    slice(start, min(start + chunk_size, len(identifiers)))
    for start in range(0, len(identifiers), chunk_size)
]

i = 1  # index of the batch to process: 0 .. len(slices) - 1
if not 0 <= i < len(slices):
    raise IndexError(f'batch index {i} out of range: only {len(slices)} batch(es) available')
batch = slices[i]

# The file name carries the real first and last positions of the batch (inclusive),
# including for a shorter final batch.
json_save_path = f'/workspace/framework_choose_lr/dev/Manual_test/manual_picking_json/manual_pick_{batch.start}_{batch.stop - 1}.json'
manual_picking(identifiers[batch], S2NAIP_directory, nir_dir, json_save_path, tracker_path)
