# Setup

In [1]:
import xarray as xr
import numpy as np
import torch
import random
from sklearn.feature_extraction import image
from sklearn.utils import check_random_state
import gcsfs
from google.oauth2.credentials import Credentials
from scipy import linalg
import xarray as xr
import gcsfs
import matplotlib.pyplot as plt

In [2]:
# Seed every RNG this notebook touches (Python, NumPy, PyTorch CPU and GPU) so a
# Restart-and-Run-All reproduces the same random draws.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # Current CUDA device first, then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [3]:
# Source fields live in a shared LEAP bucket and are opened with xr.open_zarr.
PERSISTENT_BUCKET = 'gs://leap-persistent/dhruvbalwada'
unfiltered_link = f"{PERSISTENT_BUCKET}{'/ssh_reconstruction_project/unfiltered_data.zarr'}"
filtered_link = f"{PERSISTENT_BUCKET}{'/ssh_reconstruction_project/filtered_data.zarr'}"

da_unfiltered, da_filtered = (xr.open_zarr(link) for link in (unfiltered_link, filtered_link))

# "_it" fields = unfiltered minus filtered, i.e. the part the filter removed
# (NOTE(review): presumably the internal-tide signal — confirm).
ssh_it = da_unfiltered['ssh_unfiltered'] - da_filtered['ssh_filtered']
u_it = da_unfiltered['u_unfiltered'] - da_filtered['u_filtered']
v_it = da_unfiltered['v_unfiltered'] - da_filtered['v_filtered']
da_it = xr.Dataset({'ssh_it': ssh_it, 'u_it': u_it, 'v_it': v_it})

# Data Augmentation

In [4]:
def augment_and_extract_patches(data, start_time, end_time, patch_size=108, is_ssh=False,
                                domain_size=2160):
    """Tile each time step, plus three ``np.rot90`` rotations of it, into square patches.

    Every time index in ``data.isel(time=slice(start_time, end_time))`` yields four
    fields: the original (i, j) field and ``np.rot90`` of it with k = 1, 2, 3. Each
    field is cut into non-overlapping ``patch_size x patch_size`` patches covering
    the top-left ``domain_size x domain_size`` region, scanned row by row.

    Parameters
    ----------
    data : xr.DataArray
        Field with dims ``time``, ``i`` and ``j``. NOTE(review): assumed square in
        (i, j); a 90-degree rotation of a non-square field changes its shape.
    start_time, end_time : int
        Positional ``time`` range to use (end exclusive).
    patch_size : int, default 108
        Side length of each patch.
    is_ssh : bool, default False
        If True, NaNs are replaced by 0 before patching; otherwise NaNs are kept.
    domain_size : int or None, default 2160
        Side length of the tiled region (clipped to the field size); None tiles the
        whole field. Only complete patches are emitted.

    Returns
    -------
    xr.DataArray
        Dims ``('sample', 'i', 'j')``, ordered by time step, then rotation
        (k = 0, 1, 2, 3), then patch row, then patch column.
    """
    window = data.isel(time=slice(start_time, end_time))

    def rotate_data(field, k):
        # np.rot90 on the (i, j) plane; the result keeps the dim names (i, j).
        return xr.apply_ufunc(
            lambda x: np.rot90(x, k),
            field,
            input_core_dims=[['i', 'j']],
            output_core_dims=[['i', 'j']],
            vectorize=True,
            dask='allowed'
        )

    def patch_starts(length):
        # Offsets of complete patches only: a trailing partial patch would make
        # the stacked result ragged and the final stacking would fail.
        extent = length if domain_size is None else min(domain_size, length)
        return range(0, extent - patch_size + 1, patch_size)

    all_patches = []
    for t in range(window.sizes['time']):
        # Materialise the time step once rather than once per rotation.
        time_slice = window.isel(time=t).compute()
        for k in range(4):
            field = time_slice if k == 0 else rotate_data(time_slice, k)
            if is_ssh:
                field = field.fillna(0)
            # Read every rotation out in the un-rotated slice's dim order so all
            # four share one memory layout.
            arr_2d = field.transpose(*time_slice.dims).values
            all_patches.extend(
                arr_2d[r:r + patch_size, c:c + patch_size]
                for r in patch_starts(arr_2d.shape[0])
                for c in patch_starts(arr_2d.shape[1])
            )

    if not all_patches:
        # Empty time range (or field smaller than one patch): keep the 3-D layout.
        empty = np.empty((0, patch_size, patch_size), dtype=data.dtype)
        return xr.DataArray(empty, dims=['sample', 'i', 'j'])

    return xr.DataArray(np.stack(all_patches), dims=['sample', 'i', 'j'])

In [5]:
# Augmented training set: time steps 0-59, each with three rotations.
# Only the SSH input is NaN -> 0 filled; BM and UBM keep their NaNs.
SSH_train_patches = augment_and_extract_patches(
    da_unfiltered['ssh_unfiltered'], start_time=0, end_time=60, is_ssh=True)
BM_train_patches = augment_and_extract_patches(
    da_filtered['ssh_filtered'], start_time=0, end_time=60, is_ssh=False)
UBM_train_patches = augment_and_extract_patches(
    da_it['ssh_it'], start_time=0, end_time=60, is_ssh=False)

In [8]:
# Authenticate to GCS with an OAuth access token kept on disk.
# NOTE(review): absolute, user-specific token path; consider an environment variable.
with open("/home/jovyan/SSH/token.txt") as token_file:
    access_token = token_file.read().strip()
credentials = Credentials(access_token)
fs = gcsfs.GCSFileSystem(token=credentials)

# Persist the augmented SSH input and UBM target patches (mode='w' overwrites).
mapper_ssh, mapper_ubm = (
    fs.get_mapper(store)
    for store in (
        "gs://leap-persistent/YueWang/SSH/data/Spencer/ssh_train_aug.zarr",
        "gs://leap-persistent/YueWang/SSH/data/Spencer/ubm_train_aug.zarr",
    )
)
SSH_train_patches.to_zarr(mapper_ssh, mode='w')
UBM_train_patches.to_zarr(mapper_ubm, mode='w')

<xarray.backends.zarr.ZarrStore at 0x7d44bc179d40>

In [33]:
# ssh_augmented_with_bm = xr.concat([SSH_train_patches, BM_train_patches.isel(sample=slice(0,5000))], dim='sample')
# zeros_for_bm = xr.zeros_like(BM_train_patches.isel(sample=slice(0,5000)))
# ubm_augmented_with_zeros = xr.concat([UBM_train_patches, zeros_for_bm], dim='sample')

# ssh_final_augmented = xr.concat([ssh_augmented_with_bm, UBM_train_patches.isel(sample=slice(0,5000))], dim='sample')
# ubm_final_augmented = xr.concat([ubm_augmented_with_zeros, UBM_train_patches.isel(sample=slice(0,5000))], dim='sample')

# with open("/home/jovyan/SSH/token.txt") as f:
#     access_token = f.read().strip()

# credentials = Credentials(access_token)
# fs = gcsfs.GCSFileSystem(token=credentials)

# mapper_ssh = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/ssh_train_aug.zarr")
# mapper_ubm = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/ubm_train_aug.zarr")

# ssh_final_augmented.to_zarr(mapper_ssh, mode='w')
# ubm_final_augmented.to_zarr(mapper_ubm, mode='w')

<xarray.backends.zarr.ZarrStore at 0x7d0fb7e5f1c0>

# Data without Augmentation

In [4]:
def extract_patches(data, start_time, end_time, patch_size=108, is_ssh=False, domain_size=2160):
    """Tile each time step of ``data`` into square, non-overlapping patches.

    Each time index in ``data.isel(time=slice(start_time, end_time))`` is cut into
    ``patch_size x patch_size`` patches covering the top-left
    ``domain_size x domain_size`` region, scanned row by row. No augmentation.

    Parameters
    ----------
    data : xr.DataArray
        Field with a ``time`` dim plus two spatial dims (read in ``data``'s order).
    start_time, end_time : int
        Positional ``time`` range to use (end exclusive).
    patch_size : int, default 108
        Side length of each patch.
    is_ssh : bool, default False
        If True, NaNs are replaced by 0 before patching; otherwise NaNs are kept.
    domain_size : int or None, default 2160
        Side length of the tiled region (clipped to the field size); None tiles the
        whole field. Only complete patches are emitted.

    Returns
    -------
    xr.DataArray
        Dims ``('sample', 'i', 'j')``, ordered by time step, then patch row, then
        patch column.
    """
    window = data.isel(time=slice(start_time, end_time))

    def patch_starts(length):
        # Offsets of complete patches only: a trailing partial patch would make
        # the stacked result ragged and the final stacking would fail.
        extent = length if domain_size is None else min(domain_size, length)
        return range(0, extent - patch_size + 1, patch_size)

    all_patches = []
    for t in range(window.sizes['time']):
        field = window.isel(time=t)
        if is_ssh:
            field = field.fillna(0)
        arr_2d = field.values
        all_patches.extend(
            arr_2d[r:r + patch_size, c:c + patch_size]
            for r in patch_starts(arr_2d.shape[0])
            for c in patch_starts(arr_2d.shape[1])
        )

    if not all_patches:
        # Empty time range (or field smaller than one patch): keep the 3-D layout.
        empty = np.empty((0, patch_size, patch_size), dtype=data.dtype)
        return xr.DataArray(empty, dims=['sample', 'i', 'j'])

    return xr.DataArray(np.stack(all_patches), dims=['sample', 'i', 'j'])

In [11]:
# Non-augmented splits along `time`: train = steps 0-59, val = 60-64, test = 65-69.
# SSH (input) is NaN -> 0 filled; BM and UBM keep their NaNs.
SSH_Train_patches = extract_patches(da_unfiltered['ssh_unfiltered'], start_time=0, end_time=60, is_ssh=True)
SSH_val_patches = extract_patches(da_unfiltered['ssh_unfiltered'], start_time=60, end_time=65, is_ssh=True)
SSH_test_patches = extract_patches(da_unfiltered['ssh_unfiltered'], start_time=65, end_time=70, is_ssh=True)

BM_Train_patches = extract_patches(da_filtered['ssh_filtered'], start_time=0, end_time=60, is_ssh=False)
BM_val_patches = extract_patches(da_filtered['ssh_filtered'], start_time=60, end_time=65, is_ssh=False)
BM_test_patches = extract_patches(da_filtered['ssh_filtered'], start_time=65, end_time=70, is_ssh=False)

UBM_Train_patches = extract_patches(da_it['ssh_it'], start_time=0, end_time=60, is_ssh=False)
UBM_val_patches = extract_patches(da_it['ssh_it'], start_time=60, end_time=65, is_ssh=False)
UBM_test_patches = extract_patches(da_it['ssh_it'], start_time=65, end_time=70, is_ssh=False)

# Authenticate to GCS with the OAuth access token kept on disk.
with open("/home/jovyan/SSH/token.txt") as token_file:
    access_token = token_file.read().strip()
credentials = Credentials(access_token)
fs = gcsfs.GCSFileSystem(token=credentials)

mapper_ssh_train = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/ssh_train.zarr")
mapper_ssh_val = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/ssh_val.zarr")
mapper_ssh_test = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/ssh_test.zarr")

mapper_bm_train = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/bm_train.zarr")
mapper_bm_val = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/bm_val.zarr")
mapper_bm_test = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/bm_test.zarr")

mapper_ubm_train = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/ubm_train.zarr")
mapper_ubm_val = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/ubm_val.zarr")
mapper_ubm_test = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/ubm_test.zarr")

# Write every split; mode='w' replaces any existing store.
for patches, mapper in (
    (SSH_Train_patches, mapper_ssh_train),
    (SSH_val_patches, mapper_ssh_val),
    (SSH_test_patches, mapper_ssh_test),
    (BM_Train_patches, mapper_bm_train),
    (BM_val_patches, mapper_bm_val),
    (BM_test_patches, mapper_bm_test),
    (UBM_Train_patches, mapper_ubm_train),
    (UBM_val_patches, mapper_ubm_val),
    (UBM_test_patches, mapper_ubm_test),
):
    patches.to_zarr(mapper, mode='w')

<xarray.backends.zarr.ZarrStore at 0x791337489ac0>

# Load Data from Cloud

In [9]:
# Reload the patch arrays exported above. They were written from unnamed
# DataArrays, so each store's single variable is `__xarray_dataarray_variable__`.
base_path = "gs://leap-persistent/YueWang/SSH/data/Spencer"

def open_zarr(path):
    """Open the zarr store at ``path`` with consolidated metadata."""
    return xr.open_zarr(path, consolidated=True)

def _load_patches(store_name):
    # `<base_path>/<store_name>.zarr` -> its unnamed DataArray.
    return open_zarr(f"{base_path}/{store_name}.zarr").__xarray_dataarray_variable__

ssh_aug = _load_patches("ssh_train_aug")
ubm_aug = _load_patches("ubm_train_aug")

SSH_Train_patches, SSH_val_patches, SSH_test_patches = (
    _load_patches(f"ssh_{split}") for split in ("train", "val", "test"))
BM_Train_patches, BM_val_patches, BM_test_patches = (
    _load_patches(f"bm_{split}") for split in ("train", "val", "test"))
UBM_Train_patches, UBM_val_patches, UBM_test_patches = (
    _load_patches(f"ubm_{split}") for split in ("train", "val", "test"))

# ZCA

In [12]:
def calculate_zca_params(data, epsilon=1e-5):
    """Fit a ZCA whitening transform to ``data``, tolerating NaN entries.

    Parameters
    ----------
    data : array_like, shape (n_samples, ...)
        Samples along the first axis; the remaining axes are flattened into
        ``n_features`` columns. NaNs mark missing values.
    epsilon : float, default 1e-5
        Added to each singular value before the inverse square root, which keeps
        near-zero directions from blowing up.

    Returns
    -------
    zca_matrix : ndarray, shape (n_features, n_features)
        ``U diag(1 / sqrt(S + epsilon)) U^T`` from the SVD of the covariance.
    mean : ndarray, shape (n_features,)
        Per-feature mean over non-NaN samples; 0 for a feature that is NaN in
        every sample.
    """
    import warnings

    data = np.asarray(data)
    data_flat = data.reshape((data.shape[0], -1))
    mask = ~np.isnan(data_flat)

    # All-NaN features have no mean (nanmean warns and returns NaN); use 0 so
    # they cannot inject NaN when the transform is applied later.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        mean = np.nanmean(data_flat, axis=0)
    mean[np.isnan(mean)] = 0

    # Centre, then zero missing entries so they drop out of X^T X.
    data_centered = data_flat - mean
    data_centered[~mask] = 0

    # Pairwise-complete covariance: entry (a, b) is averaged over the samples in
    # which both features are present, which keeps the matrix symmetric. A
    # per-column count vector would broadcast along columns only and make it
    # asymmetric whenever features have different NaN counts.
    if mask.all():
        pair_counts = data_flat.shape[0]
    else:
        # Costs one extra (n_samples, n_features) buffer and one extra matmul.
        valid = mask.astype(data_centered.dtype)
        pair_counts = valid.T @ valid
    cov_matrix = (data_centered.T @ data_centered) / np.maximum(pair_counts, 1)

    U, S, _ = linalg.svd(cov_matrix)

    # Scale U's columns instead of building diag(...) and doing an extra D^3 product.
    zca_matrix = (U / np.sqrt(S + epsilon)) @ U.T

    return zca_matrix, mean

def apply_zca_whitening(data, zca_matrix, mean):
    """Whiten ``data`` with a precomputed ZCA matrix and mean.

    Parameters
    ----------
    data : array_like, shape (n_samples, ...)
        Samples along the first axis; the remaining axes are flattened into
        ``n_features`` columns. NaNs mark missing values.
    zca_matrix : ndarray, shape (n_features, n_features)
        Whitening matrix, e.g. from ``calculate_zca_params``.
    mean : ndarray, shape (n_features,)
        Per-feature mean subtracted before whitening.

    Returns
    -------
    ndarray
        Same shape as ``data``; entries that were NaN in ``data`` are NaN again.
    """
    data = np.asarray(data)
    original_shape = data.shape
    data_flat = data.reshape((data.shape[0], -1))

    # Remember where the input was missing so those positions are NaN on output.
    missing = np.isnan(data_flat)

    centered = data_flat - mean
    # Zero every NaN that would enter the product: missing inputs and features
    # whose mean is NaN. Otherwise a single NaN turns the whole whitened sample
    # into NaN.
    centered[np.isnan(centered)] = 0

    whitened = centered @ zca_matrix
    whitened[missing] = np.nan

    return whitened.reshape(original_shape)

In [13]:
# Fit ZCA on the augmented UBM training patches; val/test use the training fit.
# `ubm_aug` comes from xr.open_zarr (dask-backed by default), so every `.values`
# re-reads the whole store — read it once and reuse the array.
ubm_aug_values = ubm_aug.values
zca_matrix_ubm_aug, zca_mean_ubm_aug = calculate_zca_params(ubm_aug_values)

# One row/column per flattened patch pixel (108 * 108 = 11664 for 108x108
# patches). Derived from the fit so a different patch size still lines up.
n_features = zca_matrix_ubm_aug.shape[0]
zca_matrix_ubm_aug_da = xr.DataArray(zca_matrix_ubm_aug, dims=('i', 'j'), coords={'i': range(n_features), 'j': range(n_features)})
zca_mean_ubm_aug_da = xr.DataArray(zca_mean_ubm_aug, dims=('i',), coords={'i': range(n_features)})

UBM_train_zca_aug = apply_zca_whitening(ubm_aug_values, zca_matrix_ubm_aug, zca_mean_ubm_aug)
del ubm_aug_values  # release the in-memory copy; only coords are needed below
UBM_val_zca_aug = apply_zca_whitening(UBM_val_patches.values, zca_matrix_ubm_aug, zca_mean_ubm_aug)
UBM_test_zca_aug = apply_zca_whitening(UBM_test_patches.values, zca_matrix_ubm_aug, zca_mean_ubm_aug)

# Re-wrap the whitened arrays with the dims/coords of their source patches.
UBM_train_zca_aug = xr.DataArray(UBM_train_zca_aug, dims=ubm_aug.dims, coords=ubm_aug.coords)
UBM_val_zca_aug = xr.DataArray(UBM_val_zca_aug, dims=UBM_val_patches.dims, coords=UBM_val_patches.coords)
UBM_test_zca_aug = xr.DataArray(UBM_test_zca_aug, dims=UBM_test_patches.dims, coords=UBM_test_patches.coords)

In [17]:
import gcsfs
from google.oauth2.credentials import Credentials

# Authenticate to GCS with the OAuth access token kept on disk.
with open("/home/jovyan/SSH/token.txt") as token_file:
    access_token = token_file.read().strip()
credentials = Credentials(access_token)
fs = gcsfs.GCSFileSystem(token=credentials)

# Fitted ZCA parameters (the "_eps5" suffix matches epsilon=1e-5).
mapper_zca_aug_matrix_ubm = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/zca_aug_matrix_ubm_eps5.zarr")
mapper_zca_aug_mean_ubm = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/zca_aug_mean_ubm_eps5.zarr")

for zca_array, mapper in (
    (zca_matrix_ubm_aug_da, mapper_zca_aug_matrix_ubm),
    (zca_mean_ubm_aug_da, mapper_zca_aug_mean_ubm),
):
    zca_array.to_zarr(mapper, mode='w')

# Whitened UBM splits.
mapper_UBM_Train_zca_aug = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/UBM_Train_zca_aug_eps5.zarr")
mapper_UBM_Val_zca_aug = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/UBM_Val_zca_aug_eps5.zarr")
mapper_UBM_Test_zca_aug = fs.get_mapper("gs://leap-persistent/YueWang/SSH/data/Spencer/UBM_Test_zca_aug_eps5.zarr")

for zca_array, mapper in (
    (UBM_train_zca_aug, mapper_UBM_Train_zca_aug),
    (UBM_val_zca_aug, mapper_UBM_Val_zca_aug),
    (UBM_test_zca_aug, mapper_UBM_Test_zca_aug),
):
    zca_array.to_zarr(mapper, mode='w')

<xarray.backends.zarr.ZarrStore at 0x7d445601a1c0>