In [1]:
import gcsfs
import joblib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from sklearn.decomposition import PCA

In [13]:
# Zarr stores read by the cells below
EMBEDDINGS_PATH = "gs://aurora-encoder-storage/encoder_embedding_20240713_20241821.zarr"
STATIC_PATH = "gs://aurora-encoder-storage/static.zarr"

# Time window selected from the embeddings.
# END_DATE must be at least 6 hours after START_DATE.
START_DATE = "2024-07-13T18:00:00"
END_DATE = "2024-07-18T18:00:00"

# Test variables: longitude bounds
# NOTE(review): presumably degrees east delimiting the test region - confirm.
TEST_LON_MIN = 120
TEST_LON_MAX = 210

# Encoder variables
# Kept as a string; it is converted with int() where it is passed to reduce_mask.
PATCH_SIZE = "4"

In [20]:

def reduce_mask(land_sea_mask: np.ndarray, patch_size: int) -> np.ndarray:
    """Downsample a 2-D mask to one binary value per square patch.

    The mask is cut into non-overlapping ``patch_size`` x ``patch_size``
    tiles. A tile is labelled 1 when the mean of its values is at least 0.5,
    and 0 otherwise. Trailing rows/columns that do not fill a complete tile
    are dropped.

    Args:
        land_sea_mask: 2-D array, or anything ``np.asarray`` can convert to
            one (the notebook passes an xarray object here).
        patch_size: Edge length of a tile in grid cells; must be >= 1.

    Returns:
        ``int8`` array of shape
        ``(rows // patch_size, cols // patch_size)`` holding 0/1 labels.

    Raises:
        ValueError: If ``patch_size`` is smaller than 1 or the mask is not 2-D.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")

    # Convert once up front: indexing a labelled array patch-by-patch is far
    # slower than a single reshape on a plain ndarray.
    mask = np.asarray(land_sea_mask)
    if mask.ndim != 2:
        raise ValueError(f"land_sea_mask must be 2-D, got {mask.ndim} dimension(s)")

    n_lat_patches = mask.shape[0] // patch_size
    n_lon_patches = mask.shape[1] // patch_size

    # Drop the incomplete trailing tiles, then expose each tile on its own
    # pair of axes (1 and 3) so a single mean() reduces every tile at once.
    trimmed = mask[: n_lat_patches * patch_size, : n_lon_patches * patch_size]
    tiles = trimmed.reshape(n_lat_patches, patch_size, n_lon_patches, patch_size)

    return (tiles.mean(axis=(1, 3)) >= 0.5).astype(np.int8)

def reduce_lon_lat(
        patch_size: int,
        orig_lat_values: np.ndarray,
        orig_lon_values: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
    """Compute the centre latitude and longitude of every patch on a grid.

    Each 1-D coordinate axis is cut into consecutive groups of ``patch_size``
    values and every group is replaced by its arithmetic mean (no special
    handling of longitude wrap-around). Trailing values that do not fill a
    complete group are dropped. The per-axis centres are then expanded to
    2-D grids.

    Args:
        patch_size: Number of grid cells per patch along each axis; >= 1.
        orig_lat_values: 1-D latitude coordinate (anything ``np.asarray``
            can convert).
        orig_lon_values: 1-D longitude coordinate (anything ``np.asarray``
            can convert).

    Returns:
        ``(patch_center_lat, patch_center_lon)``: two ``float64`` arrays of
        shape ``(len(lat) // patch_size, len(lon) // patch_size)``. Row ``i``
        of the first holds the centre latitude of patch row ``i``; column
        ``j`` of the second holds the centre longitude of patch column ``j``.

    Raises:
        ValueError: If ``patch_size`` is smaller than 1.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")

    lats = np.asarray(orig_lat_values, dtype=np.float64)
    lons = np.asarray(orig_lon_values, dtype=np.float64)

    n_lat_patches = len(lats) // patch_size
    n_lon_patches = len(lons) // patch_size

    # Each centre depends on one axis only, so average once per axis instead
    # of once per (lat patch, lon patch) pair.
    lat_centres = (
        lats[: n_lat_patches * patch_size]
        .reshape(n_lat_patches, patch_size)
        .mean(axis=1)
    )
    lon_centres = (
        lons[: n_lon_patches * patch_size]
        .reshape(n_lon_patches, patch_size)
        .mean(axis=1)
    )

    # repeat/tile return fresh, writable arrays, like the np.zeros buffers
    # this replaces.
    patch_center_lat = np.repeat(lat_centres[:, np.newaxis], n_lon_patches, axis=1)
    patch_center_lon = np.tile(lon_centres, (n_lat_patches, 1))

    return patch_center_lat, patch_center_lon

def prepare_x(embeddings: xr.Dataset, mask: np.ndarray | None = None) -> np.ndarray:
    X = []
    for i in range(embeddings.time.shape[0]):
        embedding = embeddings.isel(time=i)
        embedding = embedding.data.reshape(512, -1)
        if mask is not None:
            embedding = embedding[:, mask]
        X.append(embedding)
    X = np.stack(X).transpose(1, 0, 2).reshape(512, -1)
    X = X.compute()
    return X

In [None]:
# Anonymous GCS filesystem handle; the static-data cell reuses `fs`.
fs = gcsfs.GCSFileSystem(token="anon")

# Open the consolidated Zarr store that holds the encoder embeddings.
store = fs.get_mapper(EMBEDDINGS_PATH)
aurora_embeddings = xr.open_zarr(store, consolidated=True)

# Keep only the surface latents that fall inside the configured time window.
surf_embeddings = aurora_embeddings["surface_latent"].sel(
    time=slice(START_DATE, END_DATE),
)

In [18]:
# Static data
store = fs.get_mapper(STATIC_PATH)
static_data = xr.open_zarr(store, consolidated=True)
print("Read static data")

# Get land-sea mask and reduce to patches
land_sea_mask = static_data["lsm"].squeeze().compute()
# Pass the raw ndarray: reduce_mask is annotated for np.ndarray, and slicing
# an xarray object once per patch is much slower than slicing plain NumPy.
land_sea_mask_patched = reduce_mask(land_sea_mask.values, int(PATCH_SIZE))
# Repeat the flattened patch mask once per selected time step.
y_mask = np.tile(land_sea_mask_patched.ravel(), surf_embeddings.time.shape[0])
print("Got LS mask")

# Get lat/lon centres
# NOTE(review): patch size 1 is passed here rather than PATCH_SIZE, presumably
# because the embedding coordinates are already on the patch grid - confirm.
lat_patched, lon_patched = reduce_lon_lat(
    1,
    surf_embeddings.lat.values,
    surf_embeddings.lon.values,
)
print("Got lat/lon centres")

Read static data
Got LS mask
Got lat/lon centres


In [21]:
# Build the feature matrix from every selected time step, with no column
# filtering (mask=None is prepare_x's default).
X_ls = prepare_x(surf_embeddings, mask=None)
print("Prepared X")

: 