In [None]:
import rasterio as rio
from pathlib import Path
import numpy as np
from tqdm.auto import tqdm
from matplotlib import pyplot as plt

In [None]:
# Input rasters for one tile; both live in the same scenes directory.
# NOTE(review): absolute, machine-specific path — point SCENE_DIR at your copy.
SCENE_DIR = Path("/media/nick/SNEAKERNET/SA working v6 2022-2023/scenes")
TILE_ID = "18FWF"

t = SCENE_DIR / f"{TILE_ID}.tif"  # stacked band raster
m = SCENE_DIR / f"{TILE_ID}_scl.tif"  # SCL raster read alongside it

In [None]:
# Read every band of both rasters. The SCL raster is resampled on read to
# 5490 x 5490 pixels via `out_shape`. `with` closes each dataset handle once
# the data is in memory (the bare `rio.open(...).read()` form leaked them).
with rio.open(t) as src:
    array = src.read()
with rio.open(m) as src:
    scl = src.read(out_shape=(5490, 5490))

In [None]:
# Shapes of the band stack and the SCL stack, as a tuple of tuples.
tuple(arr.shape for arr in (array, scl))

In [None]:
# Split the flat band axis into (orbit, band, row, col). A C-order reshape puts
# the first block of bands in orbit 0, the next block in orbit 1, and so on.
# The spatial size comes from the data, so this works at 10980 or 5490 pixels,
# and re-running the cell on an already reshaped array is a no-op.
ORBIT_COUNT = 3
array = array.reshape(ORBIT_COUNT, -1, *array.shape[-2:])
scl = scl.reshape(ORBIT_COUNT, -1, *scl.shape[-2:])

In [None]:
# Smallest and largest value present in the SCL stack.
np.min(scl), np.max(scl)

In [None]:
def combine_orbits(
    all_orbits_bands: np.ndarray,
    target_band_count: int,
    pbar: "tqdm | None" = None,
) -> np.ndarray:
    """
    Combine several orbits of a band stack into one per-pixel mean composite.

    Bands are handled in consecutive pairs (``bands_per_scene == 2``). For each
    orbit, a pixel contributes to a pair only where *every* band of that pair is
    non-zero, so a zero in either band excludes that orbit for both bands. Each
    output pixel is the mean over the contributing orbits, truncated to
    ``uint16``. Pixels with no contributing orbit are 0.

    Args:
        all_orbits_bands: array indexed as (orbit, band, *spatial), e.g.
            (3, 12, 10980, 10980).
        target_band_count: number of leading bands to combine. Must be between
            0 and ``all_orbits_bands.shape[1]``. With an odd count, the last
            band forms a "pair" on its own.
        pbar: optional progress bar. It is reset to the number of
            (pair, orbit) steps, advanced once per step and closed at the end.

    Returns:
        ``uint16`` array of shape ``(target_band_count, *spatial)``.

    Raises:
        ValueError: if ``target_band_count`` is negative or exceeds the number
            of bands in ``all_orbits_bands``.
    """
    bands_per_scene = 2
    orbit_count, band_count = all_orbits_bands.shape[:2]
    if not 0 <= target_band_count <= band_count:
        raise ValueError(
            f"target_band_count={target_band_count} must be between 0 and "
            f"the {band_count} bands available"
        )

    out_shape = (target_band_count, *all_orbits_bands.shape[2:])
    # Running per-pixel sum of valid values and number of contributing orbits.
    sums = np.zeros(out_shape, dtype=np.float32)
    counts = np.zeros(out_shape, dtype=np.uint16)

    pair_starts = range(0, target_band_count, bands_per_scene)
    if pbar is not None:
        # reset(total=...) rather than assigning `.total` afterwards, so that
        # notebook progress-bar widgets are resized as well.
        pbar.reset(total=orbit_count * len(pair_starts))
        pbar.set_description("Combining")

    for start in pair_starts:
        # Clip so an odd target_band_count never reads past the output slice.
        stop = min(start + bands_per_scene, target_band_count)
        # Basic slices are views: in-place ops below write into sums/counts.
        pair_sums = sums[start:stop]
        pair_counts = counts[start:stop]
        for orbit in range(orbit_count):
            pair = all_orbits_bands[orbit, start:stop]
            # Valid only where no band of the pair is 0 (0 = no data).
            valid = np.all(pair != 0, axis=0)
            # `where=` broadcasts the spatial mask across the pair's bands.
            np.add(pair_sums, pair, out=pair_sums, where=valid)
            pair_counts += valid
            if pbar is not None:
                pbar.update(1)

    # Avoid division by zero; pixels with no valid orbit have sum 0 and stay 0.
    np.maximum(counts, 1, out=counts)
    np.divide(sums, counts, out=sums)
    if pbar is not None:
        pbar.close()
    return sums.astype(np.uint16)

In [None]:
# Average the orbits into one composite, using every band the reshaped stack
# has, instead of a hard-coded band count.
pbar = tqdm()
r = combine_orbits(array, array.shape[1], pbar)

In [None]:
# Smallest and largest value of the combined composite.
np.min(r), np.max(r)

In [None]:
# Project-local inference helper (from the `helpers` package).
from helpers.inference import run_inference

# Model file (a .pkl) handed to run_inference in the next cell.
model_path = Path("models") / "regnety_002_v1.31_model.pkl"

In [None]:
# Pass the source scene's raster profile to run_inference (presumably used to
# georeference the output "pppp.tif" — confirm in helpers.inference). The
# dataset is opened in a `with` block so its handle is closed again.
with rio.open(t) as src:
    scene_profile = src.profile
run_inference(
    model_path,
    "pppp.tif",
    r,
    scene_profile,
    pbar=pbar,
)

In [None]:
# Inspect one band of the combined composite, clipped to a 0-2000 display range.
index = 10
fig, ax = plt.subplots()
im = ax.imshow(r[index], vmin=0, vmax=2000)
ax.set_title(f"combine_orbits composite (r), band {index}")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Write the composite using the source scene's profile, with dtype and band
# count adjusted to match `r`. The source dataset is closed via `with`.
with rio.open(t) as src:
    profile = src.profile
profile.update(dtype=np.uint16, count=r.shape[0])

with rio.open("test.tif", "w", **profile) as dst:
    # copy=False: combine_orbits already returns uint16, so skip a full copy.
    dst.write(r.astype(np.uint16, copy=False))

In [None]:
def combine_orbits_old(
    all_orbits_bands: np.ndarray,
    all_orbits_scls: np.ndarray,
    target_band_count: int,
    pbar: "tqdm | None" = None,
) -> np.ndarray:
    """
    Combine multiple orbits into one per-band mean composite (earlier variant).

    Each band is averaged independently. For every pixel, zero values are
    treated as no data and excluded, and the mean of the remaining orbits is
    truncated to ``uint16``. Pixels that are zero in every orbit stay 0.

    Args:
        all_orbits_bands: array indexed as (orbit, band, *spatial), e.g.
            (3, 12, 10980, 10980).
        all_orbits_scls: accepted for interface compatibility; it is not used.
        target_band_count: number of output bands. Must be at least the number
            of input bands; any extra output bands are left as 0.
        pbar: optional progress bar, reset to one step per input band and
            advanced after each band. It is not closed here.

    Returns:
        ``uint16`` array of shape ``(target_band_count, *spatial)``.

    Raises:
        ValueError: if the input has more bands than ``target_band_count``
            (checked up front, before any work is done).
    """
    band_count = all_orbits_bands.shape[1]
    if band_count > target_band_count:
        raise ValueError(
            f"input has {band_count} bands but target_band_count is "
            f"{target_band_count}"
        )

    out_array = np.zeros(
        (target_band_count, *all_orbits_bands.shape[2:]), dtype=np.uint16
    )
    if pbar is not None:
        # reset(total=...) also resizes notebook progress-bar widgets, which
        # assigning `.total` after reset() does not.
        pbar.reset(total=band_count)
        pbar.set_description("Combining")

    for index in range(band_count):
        # Every orbit's copy of this band: (orbit, *spatial).
        orbit_stack = all_orbits_bands[:, index]
        # Zeros add nothing to a sum, so only the orbit count needs masking.
        totals = orbit_stack.sum(axis=0, dtype=np.float32)
        counts = (orbit_stack != 0).sum(axis=0, dtype=np.float32)
        # Avoid division by zero; all-zero pixels keep a total of 0.
        counts = np.maximum(counts, 1)
        out_array[index] = (totals / counts).astype(np.uint16)
        if pbar is not None:
            pbar.update(1)

    return out_array

In [None]:
# The same composite via the earlier per-band implementation, for comparison
# with r. scl is passed only because the signature requires it; the function
# does not use it.
pbar = tqdm()
p = combine_orbits_old(array, scl, array.shape[1], pbar)

In [None]:
# One band of the combine_orbits composite, shown next to the old variant below.
index = 10
fig, ax = plt.subplots()
im = ax.imshow(r[index], vmin=0, vmax=2000)
ax.set_title(f"combine_orbits composite (r), band {index}")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# The same band from the combine_orbits_old composite, on the same 0-2000 scale.
fig, ax = plt.subplots()
im = ax.imshow(p[index], vmin=0, vmax=2000)
ax.set_title(f"combine_orbits_old composite (p), band {index}")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Per-pixel difference between the two composites. Both are uint16, so cast to
# a signed type first: uint16 subtraction wraps negative differences around to
# large positive values (e.g. 5 - 6 -> 65535), which hid real disagreements.
diff = p[index].astype(np.int32) - r[index].astype(np.int32)
# Symmetric limits so a diverging colormap puts "no difference" at the centre.
limit = max(int(np.abs(diff).max()), 1)
fig, ax = plt.subplots()
im = ax.imshow(diff, cmap="RdBu", vmin=-limit, vmax=limit)
ax.set_title(f"p - r (old minus new), band {index}")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Dimensions of the combined composite.
np.shape(r)