In [1]:
import os
import json
import numpy as np
from numpy.linalg import svd
import xarray as xr
import time
from memory_profiler import memory_usage


In [2]:
def _load_snapshot_matrix(nc_path: str, var_name: str) -> np.ndarray:
    """Load a 3-D (time, lat, lon) variable and return it as an (ny*nx, nt) snapshot matrix."""
    # The context manager closes the file handle; ``.values`` has already
    # materialised the data as an in-memory NumPy array by the time we leave.
    with xr.open_dataset(nc_path) as ds:
        data = ds[var_name].values
    if data.ndim != 3:
        raise ValueError(
            f"Expected {var_name!r} to be 3-D (time, lat, lon); got shape {data.shape}"
        )
    nt, ny, nx = data.shape
    # Rows = grid points (space), columns = time snapshots.
    return data.reshape(nt, ny * nx).T


def _append_jsonl(results: dict, output_file: str) -> None:
    """Append one results record as a JSON line, converting NumPy scalars to Python types."""
    os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
    safe = {key: (val.item() if hasattr(val, 'item') else val)
            for key, val in results.items()}
    with open(output_file, 'a') as f:
        f.write(json.dumps(safe) + '\n')


def run_economy_svd(nc_path: str,
                    var_name: str = 'msl',
                    k: int = None,
                    output_file: str = None,
                    **mem_kwargs) -> dict:
    """
    Unified economy (reduced) SVD driver with single-pass profiling.

    The variable is reshaped to a space x time matrix A (m = ny*nx grid points,
    n = nt time steps) and decomposed with ``numpy.linalg.svd(full_matrices=False)``.
    The data are NOT centred (no time-mean removal), so energy fractions are
    relative to the raw field, not to anomalies.

    Parameters
    ----------
    nc_path     : Path to NetCDF ('slp.nc' or 't2m.nc').
    var_name    : Variable name ('msl' or 't2m'); must be 3-D (time, lat, lon).
    k           : Number of singular modes to retain, 1 <= k <= min(m, n).
                  Defaults to full rank if None.
    output_file : Optional JSONL path to append results.
    **mem_kwargs: Extra args for memory_profiler.memory_usage (e.g. multiprocess=True).

    Returns
    -------
    dict with keys:
      'method', 'dataset', 'shape', 'k', 'elapsed_time_s', 'peak_memory_MiB',
      'reconstruction_error', 'energy_captured', 'cond_full', 'cond_trunc'

    Raises
    ------
    ValueError
        If the variable is not 3-D, or k is outside [1, min(m, n)].
        The k check happens before the (potentially very long) SVD.
    """
    A = _load_snapshot_matrix(nc_path, var_name)   # (m, n)
    m, n = A.shape
    full_rank = min(m, n)

    if k is None:
        k = full_rank
    # Validate up front: previously k <= 0 only failed (IndexError) after the
    # full SVD had run, and k > rank was silently truncated but reported as-is.
    if not 1 <= k <= full_rank:
        raise ValueError(
            f"k must be in [1, {full_rank}] for a {m}x{n} matrix; got {k}"
        )

    def compute_svd_task():
        return svd(A, full_matrices=False)

    # Profile and capture U, S, VT in one pass. Elapsed time includes the
    # memory_profiler sampling overhead, not just the LAPACK call.
    t0 = time.perf_counter()
    peak_mem, (U, S, VT) = memory_usage(
        (compute_svd_task, (), {}),
        retval=True,
        max_usage=True,
        **mem_kwargs
    )
    elapsed = time.perf_counter() - t0
    # Some memory_profiler versions return [peak] rather than peak for
    # max_usage=True; np.max normalises both. Presumably this is absolute
    # process memory (including A itself), not an increment -- verify.
    peak_mem = float(np.max(peak_mem))

    # Truncate to k modes
    U_k = U[:, :k]
    S_k = S[:k]
    VT_k = VT[:k, :]

    # Reconstruction error ||A - A_k||_F (broadcast U_k * S_k == U_k @ diag(S_k))
    A_rec = (U_k * S_k) @ VT_k
    recon_err = np.linalg.norm(A - A_rec, ord='fro')

    # Fraction of squared singular-value energy retained
    energy = float(np.sum(S_k ** 2) / np.sum(S ** 2))

    # Condition numbers. NOTE: if A is float32, singular values below roughly
    # eps*S[0] (~1e-7 relative) are at rounding level, so cond_full is then
    # not numerically meaningful.
    cond_full = float(S[0] / S[-1])       # kappa(A)
    cond_trunc = float(S[0] / S_k[-1])    # kappa(A_k)

    results = {
        'method': 'Economy SVD',
        'dataset': os.path.basename(nc_path),
        'shape': (m, n),
        'k': k,
        'elapsed_time_s': float(elapsed),
        'peak_memory_MiB': peak_mem,
        'reconstruction_error': float(recon_err),
        'energy_captured': energy,
        'cond_full': cond_full,
        'cond_trunc': cond_trunc
    }

    # Print unified report
    print(f"=== {results['method']} on {results['dataset']} (m={m}, n={n}, k={k}) ===")
    for key, val in results.items():
        if key not in ('method', 'dataset', 'shape'):
            print(f"{key.replace('_',' ').capitalize():<20}: {val}")

    if output_file:
        _append_jsonl(results, output_file)

    return results

In [3]:
# Rank-10 economy SVD of sea-level pressure; results are appended as one JSONL record.
res_econ = run_economy_svd(
    'slp.nc',
    var_name='msl',
    k=10,
    output_file='svd_results/economy_svd_k10.jsonl',
    multiprocess=True,  # forwarded to memory_profiler.memory_usage via **mem_kwargs
)

=== Economy SVD on slp.nc (m=16261, n=16071, k=10) ===
K                   : 10
Elapsed time s      : 1120.7144243717194
Peak memory mib     : 17775.1171875
Reconstruction error: 2068757.375
Energy captured     : 0.9999984502792358
Cond full           : 3680575488.0
Cond trunc          : 2293.06982421875
