In [None]:
# To reload local package without restarting kernel every time
%load_ext autoreload
%autoreload 1
%aimport climTools.helperFunctions

# Basic libraries
import numpy as np
import pandas as pd
import xarray as xr
import xskillscore as xs
import dask.array as da
from tqdm.notebook import tqdm
import time
from memory_profiler import memory_usage

# My local package
from climTools.helperFunctions import *

# Images and file libraries
import glob

In [None]:
# Preprocessed IPSL-CM6A-LR historical files: the `tas` field and its `tas_anom` counterpart.
tas_path_ipsl = "/homedata/pchevali/clean_data_ipsl/preprocessed/tas_day_IPSL-CM6A-LR_historical_allruns_19500101-20141231_30W_40E_30N_75N.nc"
tas_anom_path_ipsl = "/homedata/pchevali/clean_data_ipsl/preprocessed/tas_anom_day_IPSL-CM6A-LR_historical_allruns_19500101-20141231_30W_40E_30N_75N.nc"

# Open both files in one pass, in the same (raw, anomaly) order as the paths above.
tas_ipsl, tas_anom_ipsl = (
    xr.open_dataset(nc_path) for nc_path in (tas_path_ipsl, tas_anom_path_ipsl)
)

In [None]:
def compute_pairwise_rmse_fast(data1, data2):
    """Computes the pairwise RMSE between every sample of two arrays.

    The first axis of each input indexes the samples (e.g. time steps); every
    remaining axis is treated as a feature axis and averaged over. Explicit
    broadcasting is used to vectorise the computation, which means an
    intermediate of shape ``(n1, n2, *features)`` is built before reduction.

    Only indexing, arithmetic and ``mean`` are applied to the inputs, so
    dask-backed arrays are expected to be handled lazily as well
    (NOTE(review): not verified here).

    Args:
        data1: Array of shape ``(n1, *features)``, e.g. ``[time, lat, lon]``,
            ``[time, space]`` or ``[time]``.
        data2: Array of shape ``(n2, *features)`` with the same number of
            dimensions as ``data1``.

    Returns:
        An array of shape ``(n1, n2)`` whose element ``[i, j]`` is the
        root-mean-square difference between ``data1[i]`` and ``data2[j]``.
        For 1-D inputs this reduces to the absolute difference.

    Raises:
        ValueError: If the inputs are 0-dimensional or do not have the same
            number of dimensions.
    """
    if data1.ndim < 1 or data1.ndim != data2.ndim:
        raise ValueError(
            "data1 and data2 must have the same number of dimensions (>= 1), "
            f"got {data1.ndim} and {data2.ndim}"
        )

    # Explicit broadcasting: (n1, 1, *features) - (1, n2, *features)
    # -> (n1, n2, *features), so all pairs are computed in one vectorised step.
    squared_diff = (data1[:, None] - data2[None, :]) ** 2

    # Average over every feature axis, whatever the input rank is. With 1-D
    # inputs there is no feature axis and the squared difference is used as is.
    feature_axes = tuple(range(2, squared_diff.ndim))
    if feature_axes:
        squared_diff = squared_diff.mean(axis=feature_axes)

    return np.sqrt(squared_diff)

In [None]:
def compute_pairwise_rmse_fast_matmul(data1, data2):
    """Computes the pairwise RMSE between every row of two 2-D arrays.

    Uses the expansion ``mean((a - b)**2) = mean(a**2) + mean(b**2)
    - 2 * (a . b) / n`` so that the cross term is a single matrix product
    instead of a broadcast ``(n1, n2, space)`` intermediate.

    Args:
        data1: Array of shape ``[time_run, space]``.
        data2: Array of shape ``[time_run, space]`` (same ``space`` length).

    Returns:
        An array of shape ``(data1.shape[0], data2.shape[0])`` holding the
        RMSE between each row of ``data1`` and each row of ``data2``.
    """
    n_space = data1.shape[1]

    # Per-row mean of squares for each input.
    mean_sq_1 = (data1**2).mean(axis=1)
    mean_sq_2 = (data2**2).mean(axis=1)

    # Cross term of the expansion, computed with one matrix multiplication.
    cross_term = 2 * (data1 @ data2.T) / n_space

    mean_sq_diff = mean_sq_1[:, None] + mean_sq_2[None, :] - cross_term

    # The matrix multiplication is fast but can be imprecise, which may yield
    # slightly negative values; clip at 0 before taking the square root.
    return np.sqrt(np.maximum(mean_sq_diff, 0))

### Perf test

In [None]:
# Benchmark the broadcasting and matmul implementations on growing slices
# of the flattened anomaly field.
test = tas_anom_ipsl.tas.data
test_flat = tas_anom_ipsl.tas.stack(space=["lat", "lon"]).data

slice_sizes = list(range(10, 2200, 100))

times_fast = []
times_matmul = []

mem_fast = []
mem_matmul = []

# (function under test, list receiving its timings, list receiving its peak memory)
benchmarks = [
    (compute_pairwise_rmse_fast, times_fast, mem_fast),
    (compute_pairwise_rmse_fast_matmul, times_matmul, mem_matmul),
]

for size in slice_sizes:
    data_flat = test_flat[:size]

    for func, timings, peaks in benchmarks:
        # perf_counter is a monotonic, high-resolution clock meant for
        # measuring durations. The timer wraps the memory_usage call, so the
        # reported time includes the memory sampling overhead.
        start = time.perf_counter()
        mem_usage = memory_usage((func, (data_flat, data_flat)))
        timings.append(time.perf_counter() - start)
        peaks.append(np.max(mem_usage))  # peak memory during call

# NOTE(review): `plt` is not imported explicitly in the visible import cell;
# presumably it comes from the climTools star import — confirm.
fig, (ax_time, ax_mem) = plt.subplots(1, 2, figsize=(12, 5))

# Plot execution time
ax_time.plot(slice_sizes, times_fast, label='compute_pairwise_rmse_fast', marker='o')
ax_time.plot(slice_sizes, times_matmul, label='compute_pairwise_rmse_fast_matmul', marker='x')
ax_time.set_xlabel('Slice size (time dimension)')
ax_time.set_ylabel('Execution time (s)')
ax_time.legend()
ax_time.grid(True)
ax_time.set_title("Execution Time")

# Plot memory usage (memory_usage samples are in MiB, hence MiB / 1024 = GiB)
ax_mem.plot(slice_sizes, np.array(mem_fast)/1024, label='compute_pairwise_rmse_fast', marker='o')
ax_mem.plot(slice_sizes, np.array(mem_matmul)/1024, label='compute_pairwise_rmse_fast_matmul', marker='x')
ax_mem.set_xlabel('Slice size (time dimension)')
ax_mem.set_ylabel('Peak memory (GiB)')
ax_mem.legend()
ax_mem.grid(True)
ax_mem.set_title("Memory Usage")

fig.tight_layout()
plt.show()

In [None]:
# Run both implementations on the same small sample so their outputs can be compared.
sample_flat = test_flat[1:100]
A = compute_pairwise_rmse_fast_matmul(sample_flat, sample_flat)
B = compute_pairwise_rmse_fast(sample_flat, sample_flat)

In [None]:
# Loose tolerances: the matmul formulation can be slightly imprecise compared
# with the direct broadcasting computation.
np.allclose(a=A, b=B, rtol=1e-4, atol=1e-4)

In [None]:
# Benchmark the matmul implementation alone on much larger slices.
test = tas_anom_ipsl.tas.data
test_flat = tas_anom_ipsl.tas.stack(space=["lat", "lon"]).data

slice_sizes = list(range(10, 50000, 1000))

times_matmul = []

mem_matmul = []

for size in slice_sizes:
    data_flat = test_flat[:size]

    # Time and memory for compute_pairwise_rmse_fast_matmul.
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # durations. The timer wraps the memory_usage call, so the reported time
    # includes the memory sampling overhead.
    start = time.perf_counter()
    mem_usage = memory_usage((compute_pairwise_rmse_fast_matmul, (data_flat, data_flat)))
    times_matmul.append(time.perf_counter() - start)
    mem_matmul.append(np.max(mem_usage))  # peak memory during call

# NOTE(review): `plt` is not imported explicitly in the visible import cell;
# presumably it comes from the climTools star import — confirm.
fig, (ax_time, ax_mem) = plt.subplots(1, 2, figsize=(12, 5))

# Plot execution time
ax_time.plot(slice_sizes, times_matmul, label='compute_pairwise_rmse_fast_matmul', marker='x')
ax_time.set_xlabel('Slice size (time dimension)')
ax_time.set_ylabel('Execution time (s)')
ax_time.legend()
ax_time.grid(True)
ax_time.set_title("Execution Time")

# Plot memory usage (memory_usage samples are in MiB, hence MiB / 1024 = GiB)
ax_mem.plot(slice_sizes, np.array(mem_matmul)/1024, label='compute_pairwise_rmse_fast_matmul', marker='x')
ax_mem.set_xlabel('Slice size (time dimension)')
ax_mem.set_ylabel('Peak memory (GiB)')
ax_mem.legend()
ax_mem.grid(True)
ax_mem.set_title("Memory Usage")

fig.tight_layout()
plt.show()

**Conclusion:** we keep the matmul version (`compute_pairwise_rmse_fast_matmul`).