In [1]:
import os
from glob import glob
from typing import Dict, List, Optional, Tuple, Union

import dask
import dask.array as da
import dask.bag
import matplotlib.pyplot as plt
import numpy as np
from dask.diagnostics import ProgressBar
from dask.distributed import Client
from PIL import Image
from tqdm import tqdm

# Register dask's progress bar globally so dask computations print progress.
# NOTE(review): ProgressBar hooks into dask's local schedulers; computations
# submitted through a distributed Client (as the scaler below does) are
# presumably not reported by it — confirm.
pbar = ProgressBar()
pbar.register()


%matplotlib inline

# Root of the WorldClim weather rasters (get_files expects one sub-directory
# per variable here) and the directory the scaled rasters are written to.
# TODO(review): absolute, machine-specific path — consider making it
# configurable (e.g. via an environment variable).
data_dir = "/media/xultaeculcis/2TB/datasets/sr/fine-tuning/wc/weather"
output_dir = os.path.join(data_dir, "scaled")

In [2]:
# WorldClim variable sub-directories to process
# (presumably precipitation, minimum and maximum temperature).
variables = "pre tmin tmax".split()
# Resolution tags, as they appear in the raster file names, to collect.
resolutions = ["2.5m"]


def get_files(
    data_dir: str, variables: List[str], resolutions: List[str]
) -> Dict[str, Dict[str, List[str]]]:
    """
    Collect raster paths for every (variable, resolution) pair.

    Files are searched recursively below ``<data_dir>/<variable>/`` and must
    match ``*_<resolution>*.tif``. Every glob pattern is printed right before
    it is searched.

    Args:
        data_dir: Root directory containing one sub-directory per variable.
        variables: Variable sub-directory names to search.
        resolutions: Resolution tags that must appear in the file names.

    Returns:
        Nested mapping ``{variable: {resolution: sorted list of paths}}``.
    """

    def _collect(variable: str, resolution: str) -> List[str]:
        search_pattern = os.path.join(data_dir, variable, "**", f"*_{resolution}*.tif")
        print(search_pattern)
        return sorted(glob(search_pattern, recursive=True))

    return {
        variable: {resolution: _collect(variable, resolution) for resolution in resolutions}
        for variable in variables
    }


fpaths = get_files(data_dir=data_dir, variables=variables, resolutions=resolutions)

/media/xultaeculcis/2TB/datasets/sr/fine-tuning/wc/weather/pre/**/*_2.5m*.tif
/media/xultaeculcis/2TB/datasets/sr/fine-tuning/wc/weather/tmin/**/*_2.5m*.tif
/media/xultaeculcis/2TB/datasets/sr/fine-tuning/wc/weather/tmax/**/*_2.5m*.tif


In [3]:
def plot_clim_from_file_path(img_path: str) -> None:
    img = Image.open(img_path)
    arr = np.array(img, dtype=np.float)
    arr[arr == -32768] = np.nan

    fig = plt.figure(figsize=(20, 8))
    fig.suptitle(os.path.basename(img_path), fontsize=20)
    ax = fig.add_subplot(1, 1, 1)
    ax.set_aspect("equal")
    plt.imshow(arr, interpolation="nearest", cmap=plt.cm.jet)
    plt.colorbar()
    plt.show()


def plot_clim_from_array(arr: np.ndarray, title="") -> None:
    fig = plt.figure(figsize=(20, 8))
    fig.suptitle(os.path.basename(title), fontsize=20)
    ax = fig.add_subplot(1, 1, 1)
    ax.set_aspect("equal")
    plt.imshow(arr, interpolation="nearest", cmap=plt.cm.jet)
    plt.colorbar()
    plt.show()

In [4]:
class WorldClimScaler:
    """
    The custom Min-Max scaler for the WorldClim dataset.

    Pixels equal to ``nodata_value`` (-32768) are treated as missing. They are
    ignored when computing the data range, and ``transform`` /
    ``transform_single`` write them out as ``nan_replacement`` (0.0).

    The observed range is padded by 1 on both sides (``data_min_ = min - 1``,
    ``data_max_ = max + 1``). As a result, the extreme values seen during
    ``fit`` map strictly inside ``feature_range``. That keeps 0.0 free to act
    as the missing-value marker, which ``inverse_transform`` turns back into
    NaN.

    Args:
        n_workers: Number of single-threaded dask workers started by ``fit``
            and ``transform`` (default 8).
    """

    def __init__(self, n_workers: int = 8):
        self.feature_range = (0.0, 1.0)
        self.lower_bound, self.upper_bound = self.feature_range
        self.nan_replacement = 0.0
        # Sentinel the raw rasters use for missing pixels.
        self.nodata_value = -32768.0
        self.n_workers = n_workers

    def fit(self, fpaths: List[str]) -> None:
        """
        Compute the global min/max of the valid pixels across ``fpaths``.

        Each file is reduced to its own (min, max) in parallel on a temporary
        dask ``Client``, and the per-file results are then combined.

        Raises:
            ValueError: If ``fpaths`` is empty or no file has a valid pixel.
        """
        fpaths = list(fpaths)
        if not fpaths:
            raise ValueError("fit() requires at least one file path")
        self.fpaths = fpaths

        # ``with`` guarantees the client is closed even if compute() raises.
        with Client(n_workers=self.n_workers, threads_per_worker=1):
            results = (
                dask.bag.from_sequence(fpaths, npartitions=len(fpaths))
                .map(self._find_min_max, nodata=self.nodata_value)
                .compute()
            )

        mins = np.array([xmin for xmin, _ in results], dtype=np.float64)
        maxs = np.array([xmax for _, xmax in results], dtype=np.float64)
        if np.isnan(mins).all():
            raise ValueError("no valid (non no-data) pixels found in the given files")
        current_min = float(np.nanmin(mins))
        current_max = float(np.nanmax(maxs))

        self.org_data_min_ = current_min
        self.org_data_max_ = current_max
        # Pad by 1 so the fitted minimum maps to a small positive value rather
        # than onto ``nan_replacement`` (0.0).
        self.data_min_ = current_min - 1
        self.data_max_ = current_max + 1
        self.data_range_ = self.data_max_ - self.data_min_
        self.scale_ = (self.upper_bound - self.lower_bound) / self.data_range_
        self.min_ = self.lower_bound - self.data_min_ * self.scale_

    def transform_single(
        self,
        X: np.ndarray,
    ) -> np.ndarray:
        """
        Scale an in-memory array with the fitted range.

        Pixels equal to ``nodata_value`` and NaNs become ``nan_replacement``.
        The input array is not modified; a new float64 array is returned.
        """
        X = np.array(X, dtype=np.float64)
        X[X == self.nodata_value] = np.nan
        return self._scale(
            X,
            self.data_min_,
            self.data_max_,
            self.lower_bound,
            self.upper_bound,
            self.nan_replacement,
        )

    def transform(self, fpaths: List[str], out_dir: str) -> None:
        """
        Scale every raster in ``fpaths`` and save it as ``<out_dir>/<stem>.tiff``.

        Outputs that already exist in ``out_dir`` are skipped, so an
        interrupted run can be resumed. ``out_dir`` is created if missing.
        """
        os.makedirs(out_dir, exist_ok=True)
        fpaths = list(fpaths)
        if not fpaths:
            return

        with Client(n_workers=self.n_workers, threads_per_worker=1):
            bag = dask.bag.from_sequence(fpaths, npartitions=len(fpaths)).map(
                self._transform_file,
                out_dir=out_dir,
                min_val=self.data_min_,
                max_val=self.data_max_,
                lower=self.lower_bound,
                upper=self.upper_bound,
                nan_replacement=self.nan_replacement,
                nodata=self.nodata_value,
            )
            bag.compute()

    def fit_transform(self, fpaths: List[str], out_dir: str) -> None:
        """Fit on ``fpaths`` and then write their scaled versions to ``out_dir``."""
        self.fit(fpaths)
        self.transform(fpaths, out_dir)

    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
        """
        Undo the scaling of X according to feature_range.

        Values equal to ``nan_replacement`` are treated as missing and are
        returned as NaN. Returns a new float64 array, so integer input works.
        """
        X = np.array(X, dtype=np.float64)
        X[X == self.nan_replacement] = np.nan
        X_std = (X - self.lower_bound) / (self.upper_bound - self.lower_bound)
        return X_std * (self.data_max_ - self.data_min_) + self.data_min_

    @staticmethod
    def _read_masked(file_path: str, nodata: float) -> np.ndarray:
        """Read a raster as float64 with ``nodata`` pixels replaced by NaN."""
        # The context manager closes the file handle PIL keeps open lazily.
        with Image.open(file_path) as img:
            arr = np.array(img, dtype=np.float64)
        arr[arr == nodata] = np.nan
        return arr

    @staticmethod
    def _find_min_max(file_path: str, nodata: float) -> Tuple[float, float]:
        """Return ``(min, max)`` over one raster's valid pixels, NaN if it has none."""
        arr = WorldClimScaler._read_masked(file_path, nodata)
        if np.isnan(arr).all():
            return float("nan"), float("nan")
        return float(np.nanmin(arr)), float(np.nanmax(arr))

    @staticmethod
    def _scale(
        X: np.ndarray,
        min_val: float,
        max_val: float,
        lower: float,
        upper: float,
        nan_replacement: float,
    ) -> np.ndarray:
        """Min-max scale ``X`` from ``[min_val, max_val]`` to ``[lower, upper]``; NaN -> ``nan_replacement``."""
        X_std = (X - min_val) / (max_val - min_val)
        X_scaled = X_std * (upper - lower) + lower
        X_scaled[np.isnan(X_scaled)] = nan_replacement
        return X_scaled

    @staticmethod
    def _transform_file(
        fpath: str,
        out_dir: str,
        min_val: float,
        max_val: float,
        lower: float = 0.0,
        upper: float = 1.0,
        nan_replacement: float = 0.0,
        nodata: float = -32768.0,
    ) -> None:
        """Scale one raster and save it as ``<out_dir>/<stem>.tiff`` unless that file already exists."""
        im_name = os.path.splitext(os.path.basename(fpath))[0] + ".tiff"
        out_path = os.path.join(out_dir, im_name)
        # Check the actual output location so re-runs skip finished files.
        if os.path.exists(out_path):
            return
        X = WorldClimScaler._read_masked(fpath, nodata)
        X_scaled = WorldClimScaler._scale(X, min_val, max_val, lower, upper, nan_replacement)
        # PIL stores float images as 32-bit ("F" mode); cast explicitly.
        Image.fromarray(X_scaled.astype(np.float32)).save(out_path)

In [5]:
# Fit an independent scaler per variable (each gets its own min/max range),
# then write the scaled rasters to <output_dir>/<variable>.
for var in variables:
    var_out_dir = os.path.join(output_dir, var)
    scaler = WorldClimScaler()
    scaler.fit_transform(fpaths=fpaths[var]["2.5m"], out_dir=var_out_dir)

In [6]:
# Elevation rasters sit next to the weather data (<wc>/elevation). Derive the
# directory from ``data_dir`` instead of repeating the absolute path, and sort
# the glob result so the processing order is reproducible.
elev_dir = os.path.join(os.path.dirname(data_dir), "elevation")
elev_files = sorted(glob(os.path.join(elev_dir, "*_2.5m*.tif")))
if not elev_files:
    raise FileNotFoundError(f"No 2.5m elevation rasters found in {elev_dir}")
scaler = WorldClimScaler()
scaler.fit(elev_files)
scaler.transform(fpaths=elev_files, out_dir=os.path.join(output_dir, "elev"))