In [1]:
"""Data utilities to grab data from google cloud bucket.

Meant to be used for both training and prediction so the model is
trained on exactly the same data that will be used for predictions.
"""

from __future__ import annotations

import io
import logging
import os
import re
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict

import google.auth
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import requests
from serving.constants import (
    BUCKET,
    HIST_BINS_LIST,
    HIST_DEST_PREFIX,
    IMG_SOURCE_PREFIX,
    SELECTED_BANDS,
    PROJECT,
    PIX_COUNT,
    REFLECTANCE_CONST,
    NUM_BINS,
    MAP_NAN,
    NORMALIZE
)
from google.api_core import exceptions, retry
from google.cloud import storage
from numpy.lib.recfunctions import structured_to_unstructured
from osgeo import gdal
from rasterio.io import MemoryFile
from serving.common import list_blobs_with_prefix

logging.basicConfig(
    filename="hist.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)


def hist_init():
    """Authenticate and initialize Earth Engine with the default credentials."""
    # Use the Earth Engine High Volume endpoint.
    #   https://developers.google.com/earth-engine/cloud/highvolume
    credentials, project = google.auth.default()

def process_band(bucket, blob_name, band, bins, skip_nan, normalise):

    storage_client = storage.Client()
    blob = storage_client.bucket(bucket).blob(blob_name)

    with blob.open("rb") as f:
        with rasterio.open(f) as src:

            data = src.read(band).flatten()
            na_mask = np.isnan(data)
            
            if skip_nan == False:
                data[na_mask] = 0.0
                valid_data = data
            else:
                valid_data = data[~na_mask]
                
            if normalise:
                valid_data = valid_data / REFLECTANCE_CONST
                bins = bins / REFLECTANCE_CONST
                
            valid_max = np.max(valid_data)
            valid_min = np.min(valid_data)

            if valid_max > bins[-1]:
                logging.warning(
                    f"image: {blob_name}, band: {band}, {valid_max} value is larger than assumed possible values for this band: {bins[-1]}"
                )
            elif valid_min < bins[0]:
                logging.warning(
                    f"image: {blob_name}, band: {band}, {valid_min} value is smaller than assumed possible values for this band {bins[0]}"
                )
            
            if valid_data.size > 0:
                total_sum = np.sum(valid_data)
                total_count = valid_data.size
                mean = total_sum / total_count
                hist, _ = np.histogram(valid_data, bins=bins, density=False)
            else:
                logging.error(f"image: {blob_name}, band: {band} has 0 valid pixels. Investigate")
                mean = np.nan
                hist = np.zeros_like(
                    bins[:-1]
                )  # histogram will have one less element than bins

    return hist


def process_tiff(bucket, blob_name, bin_list, selected_bands, skip_nan, normalise, max_workers=6):
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_band = {
            executor.submit(
                process_band, bucket, blob_name, band, bins, skip_nan, normalise
            ): band
            for band, bins in zip(selected_bands, bin_list)
        }
        results = []

        for future in as_completed(future_to_band):
            band = future_to_band[future]
            try:
                result = future.result()
                results.append(result)
                logging.info(f"Processed band {band} successfully")
            except Exception as exc:
                logging.exception(f"Band {band} generated an exception: {exc}")

    sorted_results = sorted(results, key=lambda x: x[0])
    return np.array(sorted_results).flatten()  # one long array instead of bands


def recombine_image(bucket, core_image_name, bin_list, selected_bands, skip_nan=False, normalise=False):
    start_time = time.time()

    hist_per_blob = []
    blobs = list_blobs_with_prefix(core_image_name)
    for blob in blobs:
        results = process_tiff(bucket, blob.name, bin_list, selected_bands, skip_nan, normalise)
        hist_per_blob.append(results)

    combined_hist = np.sum(np.array(hist_per_blob), axis=0)

    end_time = time.time()
    execution_time = end_time - start_time
    logging.info(
        f"Image {core_image_name} has been processed in {execution_time/60:.4f} minuntes"
    )

    return combined_hist

def write_histogram_to_gcs(histogram, bucket_name, blob_name):
    """
    Write a NumPy array (histogram) to Google Cloud Storage.

    Args:
    histogram (np.array): The histogram to save.
    bucket_name (str): The name of the GCS bucket.
    blob_name (str): The name to give the file in GCS (including any 'path').

    Returns:
    str: The public URL of the uploaded file.
    """
    # Ensure the blob_name ends with .npy
    if not blob_name.endswith('.npy'):
        blob_name += '.npy'

    # Create a GCS client
    client = storage.Client()

    # Get the bucket
    bucket = client.bucket(bucket_name)

    # Create a blob
    blob = bucket.blob(blob_name)

    # Convert the numpy array to bytes
    array_bytes = io.BytesIO()
    np.save(array_bytes, histogram)
    array_bytes.seek(0)

    # Upload the bytes to GCS
    blob.upload_from_file(array_bytes, content_type='application/octet-stream')

    logging.info(f"Histogram uploaded to gs://{bucket_name}/{blob_name}")



In [17]:
image_name="images/60/Sutter_06/2021/5-6"
file_name="histograms/nan_map_True/norm_True/32_buckets/60/Sutter_06/2021/5-6.npy"
     

In [18]:
hist = recombine_image(BUCKET, image_name, HIST_BINS_LIST, SELECTED_BANDS, MAP_NAN, NORMALIZE)

In [19]:
96/3

32.0

In [20]:
hist

array([   0,    0,    0,    0,  432, 1695,  593,  289,  175,  101,   86,
         80,   82,   78,   46,   52,   33,   19,    9,    7,    8,    4,
          1,    1,    1,    3,    1,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,  424, 1971,  592,  310,  132,
        141,   96,   81,   19,   17,    9,    3,    1,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    1,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0, 2560,  623,  194,  215,
         78,   73,   39,   11,    1,    2,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    1,    0])

In [10]:
client = storage.Client()
bucket = client.get_bucket(BUCKET)
hist_blob = bucket.blob(file_name)
content = hist_blob.download_as_bytes()
binary_data = io.BytesIO(content)
array = np.load(binary_data)

In [16]:
array

array([   0,    0,    0,    0,    0,    0,    0,  424, 1971,  592,  310,
        132,  141,   96,   81,   19,   17,    9,    3,    1,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    1,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0, 2560,  623,  194,
        215,   78,   73,   39,   11,    1,    2,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    1,    0,    0,    0,
          0,    0,  432, 1695,  593,  289,  175,  101,   86,   80,   82,
         78,   46,   52,   33,   19,    9,    7,    8,    4,    1,    1,
          1,    3,    1,    0,    0,    0,    0,    0])

In [21]:
lstm_hist = hist.reshape(-1,3,32)

In [22]:
lstm_hist

array([[[   0,    0,    0,    0,  432, 1695,  593,  289,  175,  101,
           86,   80,   82,   78,   46,   52,   33,   19,    9,    7,
            8,    4,    1,    1,    1,    3,    1,    0,    0,    0,
            0,    0],
        [   0,    0,    0,    0,    0,    0,    0,  424, 1971,  592,
          310,  132,  141,   96,   81,   19,   17,    9,    3,    1,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            1,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0,    0, 2560,
          623,  194,  215,   78,   73,   39,   11,    1,    2,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            1,    0]]])

In [None]:
[[0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  4.240e+02 1.971e+03 5.920e+02 3.100e+02 1.320e+02 1.410e+02 9.600e+01
  8.100e+01 1.900e+01 1.700e+01 9.000e+00 3.000e+00 1.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 1.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 2.560e+03
  6.230e+02 1.940e+02 2.150e+02 7.800e+01 7.300e+01 3.900e+01 1.100e+01
  1.000e+00 2.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 4.320e+02 1.695e+03
  5.930e+02 2.890e+02 1.750e+02 1.010e+02 8.600e+01 8.000e+01 8.200e+01
  7.800e+01 4.600e+01 5.200e+01 3.300e+01 1.900e+01 9.000e+00 7.000e+00
  8.000e+00 4.000e+00 1.000e+00 1.000e+00 1.000e+00 3.000e+00 1.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.620e+02
  4.270e+02 6.760e+02 7.660e+02 6.520e+02 3.570e+02 2.740e+02 1.860e+02
  1.190e+02 8.800e+01 3.600e+01 2.200e+01 1.400e+01 7.000e+00 2.000e+00
  2.000e+00 0.000e+00 2.000e+00 5.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.500e+01 4.750e+02
  1.490e+03 1.082e+03 3.700e+02 2.090e+02 9.600e+01 3.800e+01 1.100e+01
  6.000e+00 3.000e+00 0.000e+00 0.000e+00 1.000e+00 0.000e+00 1.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 3.860e+02 1.641e+03 1.196e+03
  3.960e+02 1.200e+02 4.000e+01 1.500e+01 1.000e+00 0.000e+00 0.000e+00
  0.000e+00 1.000e+00 0.000e+00 1.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00]
 [0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00 7.000e+00 4.600e+01
  7.600e+01 6.300e+01 3.400e+01 3.400e+01 1.040e+02 1.340e+02 1.390e+02
  2.230e+02 2.930e+02 3.070e+02 3.360e+02 2.480e+02 1.100e+02 7.400e+01
  9.900e+01 1.540e+02 1.670e+02 1.760e+02 1.650e+02 1.850e+02 1.340e+02
  1.230e+02 1.720e+02 6.600e+01 5.000e+01 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 3.400e+01
  1.560e+02 2.130e+02 4.160e+02 4.730e+02 3.800e+02 5.910e+02 4.660e+02
  5.330e+02 4.050e+02 1.130e+02 9.000e+00 7.000e+00 0.000e+00 1.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
  0.000e+00 6.000e+00 9.200e+01 9.200e+01 2.160e+02 3.770e+02 4.050e+02
  2.580e+02 3.220e+02 4.160e+02 2.010e+02 3.070e+02 3.290e+02 3.530e+02
  2.510e+02 1.490e+02 7.000e+00 7.000e+00 9.000e+00 0.000e+00 0.000e+00
  0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00]]

In [27]:
lstm_hist.shape

(1, 3, 32)

In [23]:
skip_nan = True
normalise = True
bins = HIST_BINS_LIST[0]
blob_name = "images/60/Sutter_06/2021/5-6.tif"
storage_client = storage.Client()
blob = storage_client.bucket(BUCKET).blob(blob_name)
band = 2
with blob.open("rb") as f:
    with rasterio.open(f) as src:

        data = src.read(band).flatten()
        na_mask = np.isnan(data)

        if skip_nan == False:
            data[na_mask] = 0.0
            valid_data = data
        else:
            valid_data = data[~na_mask]

        if normalise:
            valid_data = valid_data / REFLECTANCE_CONST
            bins = bins / REFLECTANCE_CONST

        valid_max = np.max(valid_data)
        valid_min = np.min(valid_data)

        if valid_max > bins[-1]:
            logging.warning(
                f"image: {blob_name}, band: {band}, {valid_max} value is larger than assumed possible values for this band: {bins[-1]}"
            )
        elif valid_min < bins[0]:
            logging.warning(
                f"image: {blob_name}, band: {band}, {valid_min} value is smaller than assumed possible values for this band {bins[0]}"
            )

        if valid_data.size > 0:
            total_sum = np.sum(valid_data)
            total_count = valid_data.size
            mean = total_sum / total_count
            hist, _ = np.histogram(valid_data, bins=bins, density=False)
        else:
            logging.error(f"image: {blob_name}, band: {band} has 0 valid pixels. Investigate")
            mean = np.nan
            hist = np.zeros_like(
                bins[:-1]
            ) 

In [24]:
data

array([nan, nan, nan, ..., nan, nan, nan], dtype=float32)

In [25]:
len(data)

1050460

In [7]:
valid_data

array([0.09785, 0.10395, 0.0999 , ..., 0.1089 , 0.1187 , 0.1149 ],
      dtype=float32)

In [68]:
sorted_valid_data = sorted(valid_data)

In [69]:
sorted_valid_data[-5:]

[np.float32(0.1533),
 np.float32(0.1605),
 np.float32(0.1694),
 np.float32(0.17225),
 np.float32(0.28635)]

In [51]:
len(sorted_valid_data)

3797

In [53]:
np.sum(array)/3

np.float64(3797.0)